add validation

main
Cizz22 15 hours ago
parent c27cef35eb
commit afcf5ab780

@ -1,9 +1,10 @@
from typing import Dict, List, Optional from typing import Annotated, Dict, List, Optional
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status
from fastapi.params import Query from fastapi.params import Query
from src.auth.service import Token from src.auth.service import Token
from src.calculation_budget_constrains.schema import BudgetContraintQuery
from src.calculation_target_reliability.service import get_simulation_results from src.calculation_target_reliability.service import get_simulation_results
from src.config import TC_RBD_ID from src.config import TC_RBD_ID
from src.database.core import CollectorDbSession, DbSession from src.database.core import CollectorDbSession, DbSession
@ -20,10 +21,10 @@ async def get_target_reliability(
token: Token, token: Token,
session_id: str, session_id: str,
collector_db: CollectorDbSession, collector_db: CollectorDbSession,
cost_threshold: float = Query(100), params: Annotated[BudgetContraintQuery, Query()],
): ):
"""Get all scope pagination.""" """Get all scope pagination."""
cost_threshold = params.cost_threshold
results = await get_simulation_results( results = await get_simulation_results(
simulation_id = TC_RBD_ID, simulation_id = TC_RBD_ID,
token=token token=token

@ -32,6 +32,9 @@ class OverhaulRead(OverhaulBase):
systemComponents: Dict[str, Any] systemComponents: Dict[str, Any]
class BudgetContraintQuery(DefultBase):
cost_threshold: float = 100
# { # {
# "overview": { # "overview": {
# "totalEquipment": 30, # "totalEquipment": 30,

@ -1,5 +1,6 @@
import asyncio import asyncio
from typing import Dict, List, Optional from typing import Dict, List, Optional
from typing_extensions import Annotated
from temporalio.client import Client from temporalio.client import Client
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status
from fastapi.params import Query from fastapi.params import Query
@ -11,7 +12,7 @@ from src.auth.service import Token
from src.models import StandardResponse from src.models import StandardResponse
from .service import run_rbd_simulation, get_simulation_results, identify_worst_eaf_contributors from .service import run_rbd_simulation, get_simulation_results, identify_worst_eaf_contributors
from .schema import OptimizationResult from .schema import OptimizationResult, TargetReliabiltiyQuery
router = APIRouter() router = APIRouter()
@ -37,13 +38,20 @@ async def get_target_reliability(
db_session: DbSession, db_session: DbSession,
token: Token, token: Token,
collector_db: CollectorDbSession, collector_db: CollectorDbSession,
oh_session_id: Optional[str] = Query(None), params: Annotated[TargetReliabiltiyQuery, Query()],
eaf_input: float = Query(99.8), # oh_session_id: Optional[str] = Query(None),
duration: int = Query(17520), # eaf_input: float = Query(99.8),
simulation_id: Optional[str] = Query(None), # duration: int = Query(17520),
cut_hours = Query(0) # simulation_id: Optional[str] = Query(None),
# cut_hours = Query(0)
): ):
"""Get all scope pagination.""" """Get all scope pagination."""
oh_session_id = params.oh_session_id,
eaf_input = params.eaf_input,
duration = params.duration,
simulation_id = params.simulation_id,
cut_hours = params.cut_hours
if not oh_session_id: if not oh_session_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

@ -64,6 +64,12 @@ class OptimizationResult(OverhaulBase):
simulation_id: Optional[str] = None simulation_id: Optional[str] = None
class TargetReliabiltiyQuery(DefultBase):
oh_session_id: Optional[str] = Field(None)
eaf_input: float = Field(99.8)
duration: int = Field(17520)
simulation_id: Optional[str] = Field(None)
cut_hours:int = Field(0)
# { # {
# "overview": { # "overview": {

@ -1,4 +1,4 @@
from typing import List, Optional, Union from typing import Annotated, List, Optional, Union
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.params import Query from fastapi.params import Query
@ -18,7 +18,7 @@ from .schema import (CalculationResultsRead,
CalculationTimeConstrainsParametersCreate, CalculationTimeConstrainsParametersCreate,
CalculationTimeConstrainsParametersRead, CalculationTimeConstrainsParametersRead,
CalculationTimeConstrainsParametersRetrive, CalculationTimeConstrainsParametersRetrive,
CalculationTimeConstrainsRead, EquipmentResult) CalculationTimeConstrainsRead, CreateCalculationQuery, EquipmentResult)
from .service import (bulk_update_equipment, get_calculation_result, from .service import (bulk_update_equipment, get_calculation_result,
get_calculation_result_by_day, get_calculation_by_assetnum) get_calculation_result_by_day, get_calculation_by_assetnum)
from src.database.core import CollectorDbSession from src.database.core import CollectorDbSession
@ -36,11 +36,15 @@ async def create_calculation_time_constrains(
collector_db_session: CollectorDbSession, collector_db_session: CollectorDbSession,
current_user: CurrentUser, current_user: CurrentUser,
calculation_time_constrains_in: CalculationTimeConstrainsParametersCreate, calculation_time_constrains_in: CalculationTimeConstrainsParametersCreate,
scope_calculation_id: Optional[str] = Query(None), params: Annotated[CreateCalculationQuery, Query()],
with_results: Optional[int] = Query(0), # scope_calculation_id: Optional[str] = Query(None),
simulation_id = Query(None) # with_results: Optional[int] = Query(0),
# simulation_id = Query(None)
): ):
"""Save calculation time constrains Here""" """Save calculation time constrains Here"""
scope_calculation_id = params.scope_calculation_id
with_results = params.with_results
simulation_id = params.simulation_id
if scope_calculation_id: if scope_calculation_id:
results = await get_or_create_scope_equipment_calculation( results = await get_or_create_scope_equipment_calculation(
@ -96,9 +100,9 @@ async def get_calculation_results(db_session: DbSession, calculation_id, token:I
db_session=db_session, calculation_id=calculation_id, token=token, include_risk_cost=include_risk_cost db_session=db_session, calculation_id=calculation_id, token=token, include_risk_cost=include_risk_cost
) )
requests.post(f"{config.AUTH_SERVICE_API}/sign-out", headers={ # requests.post(f"{config.AUTH_SERVICE_API}/sign-out", headers={
"Authorization": f"Bearer {token}" # "Authorization": f"Bearer {token}"
}) # })
return StandardResponse( return StandardResponse(
data=results, data=results,

@ -128,3 +128,9 @@ class CalculationTimeConstrainsSimulationRead(CalculationTimeConstrainsBase):
class CalculationSelectedEquipmentUpdate(CalculationTimeConstrainsBase): class CalculationSelectedEquipmentUpdate(CalculationTimeConstrainsBase):
is_included: bool is_included: bool
location_tag: str location_tag: str
class CreateCalculationQuery(DefultBase):
scope_calculation_id: Optional[str] = Field(None)
with_results: Optional[int] = Field(0)
simulation_id: Optional[UUID] = Field(None)

@ -0,0 +1,22 @@
from typing import List, Optional
from pydantic import Field
from src.models import DefultBase
class CommonParams(DefultBase):
# This ensures no extra query params are allowed
current_user: Optional[str] = Field(None, alias="currentUser")
page: int = Field(1, gt=0, lt=2147483647)
items_per_page: int = Field(5, gt=-2, lt=2147483647)
query_str: Optional[str] = Field(None, alias="q")
filter_spec: Optional[str] = Field(None, alias="filter")
sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")
descending: List[bool] = Field(default_factory=list, alias="descending[]")
exclude: List[str] = Field(default_factory=list, alias="exclude[]")
all_params: int = Field(0, alias="all")
# Property to mirror your original return dict's bool conversion
@property
def is_all(self) -> bool:
return bool(self.all_params)

@ -1,5 +1,5 @@
import logging import logging
from typing import Annotated, List from typing import Annotated, List, Type, TypeVar
from fastapi import Depends, Query from fastapi import Depends, Query
from pydantic.types import Json, constr from pydantic.types import Json, constr
@ -7,6 +7,8 @@ from sqlalchemy import Select, desc, func, or_
from sqlalchemy.exc import ProgrammingError from sqlalchemy.exc import ProgrammingError
from sqlalchemy_filters import apply_pagination from sqlalchemy_filters import apply_pagination
from src.database.schema import CommonParams
from .core import DbSession from .core import DbSession
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -47,6 +49,21 @@ CommonParameters = Annotated[
Depends(common_parameters), Depends(common_parameters),
] ]
T = TypeVar("T", bound=CommonParams)
def get_params_factory(model_type: Type[T]):
async def wrapper(
db_session: DbSession,
params: Annotated[model_type, Query()] # type: ignore
):
res = params.model_dump()
return {
"db_session": db_session,
"all": params.is_all,
**res
}
return wrapper
def search(*, query_str: str, query: Query, model, sort=False): def search(*, query_str: str, query: Query, model, sort=False):
"""Perform a search based on the query.""" """Perform a search based on the query."""

@ -69,6 +69,8 @@ class DefultBase(BaseModel):
validate_assignment = True validate_assignment = True
arbitrary_types_allowed = True arbitrary_types_allowed = True
str_strip_whitespace = True str_strip_whitespace = True
populate_by_name = True
extra="forbid"
json_encoders = { json_encoders = {
# custom output conversion for datetime # custom output conversion for datetime

Loading…
Cancel
Save