add validation

main
Cizz22 7 hours ago
parent c27cef35eb
commit afcf5ab780

@ -1,9 +1,10 @@
from typing import Dict, List, Optional
from typing import Annotated, Dict, List, Optional
from fastapi import APIRouter, HTTPException, status
from fastapi.params import Query
from src.auth.service import Token
from src.calculation_budget_constrains.schema import BudgetContraintQuery
from src.calculation_target_reliability.service import get_simulation_results
from src.config import TC_RBD_ID
from src.database.core import CollectorDbSession, DbSession
@ -20,10 +21,10 @@ async def get_target_reliability(
token: Token,
session_id: str,
collector_db: CollectorDbSession,
cost_threshold: float = Query(100),
params: Annotated[BudgetContraintQuery, Query()],
):
"""Get all scope pagination."""
cost_threshold = params.cost_threshold
results = await get_simulation_results(
simulation_id = TC_RBD_ID,
token=token

@ -32,6 +32,9 @@ class OverhaulRead(OverhaulBase):
systemComponents: Dict[str, Any]
class BudgetContraintQuery(DefultBase):
cost_threshold: float = 100
# {
# "overview": {
# "totalEquipment": 30,

@ -1,5 +1,6 @@
import asyncio
from typing import Dict, List, Optional
from typing_extensions import Annotated
from temporalio.client import Client
from fastapi import APIRouter, HTTPException, status
from fastapi.params import Query
@ -11,7 +12,7 @@ from src.auth.service import Token
from src.models import StandardResponse
from .service import run_rbd_simulation, get_simulation_results, identify_worst_eaf_contributors
from .schema import OptimizationResult
from .schema import OptimizationResult, TargetReliabiltiyQuery
router = APIRouter()
@ -37,13 +38,20 @@ async def get_target_reliability(
db_session: DbSession,
token: Token,
collector_db: CollectorDbSession,
oh_session_id: Optional[str] = Query(None),
eaf_input: float = Query(99.8),
duration: int = Query(17520),
simulation_id: Optional[str] = Query(None),
cut_hours = Query(0)
params: Annotated[TargetReliabiltiyQuery, Query()],
# oh_session_id: Optional[str] = Query(None),
# eaf_input: float = Query(99.8),
# duration: int = Query(17520),
# simulation_id: Optional[str] = Query(None),
# cut_hours = Query(0)
):
"""Get all scope pagination."""
oh_session_id = params.oh_session_id,
eaf_input = params.eaf_input,
duration = params.duration,
simulation_id = params.simulation_id,
cut_hours = params.cut_hours
if not oh_session_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

@ -64,6 +64,12 @@ class OptimizationResult(OverhaulBase):
simulation_id: Optional[str] = None
class TargetReliabiltiyQuery(DefultBase):
oh_session_id: Optional[str] = Field(None)
eaf_input: float = Field(99.8)
duration: int = Field(17520)
simulation_id: Optional[str] = Field(None)
cut_hours:int = Field(0)
# {
# "overview": {

@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import Annotated, List, Optional, Union
from fastapi import APIRouter
from fastapi.params import Query
@ -18,7 +18,7 @@ from .schema import (CalculationResultsRead,
CalculationTimeConstrainsParametersCreate,
CalculationTimeConstrainsParametersRead,
CalculationTimeConstrainsParametersRetrive,
CalculationTimeConstrainsRead, EquipmentResult)
CalculationTimeConstrainsRead, CreateCalculationQuery, EquipmentResult)
from .service import (bulk_update_equipment, get_calculation_result,
get_calculation_result_by_day, get_calculation_by_assetnum)
from src.database.core import CollectorDbSession
@ -36,11 +36,15 @@ async def create_calculation_time_constrains(
collector_db_session: CollectorDbSession,
current_user: CurrentUser,
calculation_time_constrains_in: CalculationTimeConstrainsParametersCreate,
scope_calculation_id: Optional[str] = Query(None),
with_results: Optional[int] = Query(0),
simulation_id = Query(None)
params: Annotated[CreateCalculationQuery, Query()],
# scope_calculation_id: Optional[str] = Query(None),
# with_results: Optional[int] = Query(0),
# simulation_id = Query(None)
):
"""Save calculation time constrains Here"""
scope_calculation_id = params.scope_calculation_id
with_results = params.with_results
simulation_id = params.simulation_id
if scope_calculation_id:
results = await get_or_create_scope_equipment_calculation(
@ -96,9 +100,9 @@ async def get_calculation_results(db_session: DbSession, calculation_id, token:I
db_session=db_session, calculation_id=calculation_id, token=token, include_risk_cost=include_risk_cost
)
requests.post(f"{config.AUTH_SERVICE_API}/sign-out", headers={
"Authorization": f"Bearer {token}"
})
# requests.post(f"{config.AUTH_SERVICE_API}/sign-out", headers={
# "Authorization": f"Bearer {token}"
# })
return StandardResponse(
data=results,

@ -128,3 +128,9 @@ class CalculationTimeConstrainsSimulationRead(CalculationTimeConstrainsBase):
class CalculationSelectedEquipmentUpdate(CalculationTimeConstrainsBase):
is_included: bool
location_tag: str
class CreateCalculationQuery(DefultBase):
scope_calculation_id: Optional[str] = Field(None)
with_results: Optional[int] = Field(0)
simulation_id: Optional[UUID] = Field(None)

@ -0,0 +1,22 @@
from typing import List, Optional
from pydantic import Field
from src.models import DefultBase
class CommonParams(DefultBase):
# This ensures no extra query params are allowed
current_user: Optional[str] = Field(None, alias="currentUser")
page: int = Field(1, gt=0, lt=2147483647)
items_per_page: int = Field(5, gt=-2, lt=2147483647)
query_str: Optional[str] = Field(None, alias="q")
filter_spec: Optional[str] = Field(None, alias="filter")
sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")
descending: List[bool] = Field(default_factory=list, alias="descending[]")
exclude: List[str] = Field(default_factory=list, alias="exclude[]")
all_params: int = Field(0, alias="all")
# Property to mirror your original return dict's bool conversion
@property
def is_all(self) -> bool:
return bool(self.all_params)

@ -1,5 +1,5 @@
import logging
from typing import Annotated, List
from typing import Annotated, List, Type, TypeVar
from fastapi import Depends, Query
from pydantic.types import Json, constr
@ -7,6 +7,8 @@ from sqlalchemy import Select, desc, func, or_
from sqlalchemy.exc import ProgrammingError
from sqlalchemy_filters import apply_pagination
from src.database.schema import CommonParams
from .core import DbSession
log = logging.getLogger(__name__)
@ -47,6 +49,21 @@ CommonParameters = Annotated[
Depends(common_parameters),
]
T = TypeVar("T", bound=CommonParams)
def get_params_factory(model_type: Type[T]):
async def wrapper(
db_session: DbSession,
params: Annotated[model_type, Query()] # type: ignore
):
res = params.model_dump()
return {
"db_session": db_session,
"all": params.is_all,
**res
}
return wrapper
def search(*, query_str: str, query: Query, model, sort=False):
"""Perform a search based on the query."""

@ -69,6 +69,8 @@ class DefultBase(BaseModel):
validate_assignment = True
arbitrary_types_allowed = True
str_strip_whitespace = True
populate_by_name = True
extra="forbid"
json_encoders = {
# custom output conversion for datetime

Loading…
Cancel
Save