diff --git a/Jenkinsfile b/Jenkinsfile index 2559cb0..2d43bd4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -4,7 +4,6 @@ pipeline { environment { DOCKER_HUB_USERNAME = 'aimodocker' // This creates DOCKER_AUTH_USR and DOCKER_AUTH_PSW - DOCKER_AUTH = credentials('aimodocker') IMAGE_NAME = 'lcca-service' SERVICE_NAME = 'ahm-app' @@ -55,13 +54,6 @@ pipeline { // } // } - stage('Docker Login') { - steps { - // Fixed variable names based on the 'DOCKER_AUTH' environment key - sh "echo ${DOCKER_AUTH_PSW} | docker login -u ${DOCKER_AUTH_USR} --password-stdin" - } - } - stage('Build & Tag') { steps { script { @@ -75,14 +67,19 @@ pipeline { } } - stage('Push to Docker Hub') { + stage('Docker Login & Push') { steps { script { def fullImageName = "${DOCKER_HUB_USERNAME}/${IMAGE_NAME}" - sh "docker push ${fullImageName}:${IMAGE_TAG}" - - if (SECONDARY_TAG) { - sh "docker push ${fullImageName}:${SECONDARY_TAG}" + withCredentials([usernamePassword(credentialsId: 'aimodocker', passwordVariable: 'DOCKER_PSW', usernameVariable: 'DOCKER_USR')]) { + // Use single quotes to prevent Groovy from interpolating the secret in logs + sh 'echo $DOCKER_PSW | docker login -u $DOCKER_USR --password-stdin' + + sh "docker push ${fullImageName}:${IMAGE_TAG}" + + if (SECONDARY_TAG) { + sh "docker push ${fullImageName}:${SECONDARY_TAG}" + } } } } diff --git a/src/acquisition_cost/router.py b/src/acquisition_cost/router.py index 194de28..7fd49a7 100644 --- a/src/acquisition_cost/router.py +++ b/src/acquisition_cost/router.py @@ -33,6 +33,23 @@ async def get_yeardatas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[AcquisitionCostDataPagination]) +async def get_yeardatas_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all acquisition_cost_data for export.""" + common["all"] = True + get_acquisition_cost_data = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=get_acquisition_cost_data, + message="All Acquisition Cost Data retrieved successfully", + ) + @router.get("/{acquisition_cost_data_id}", response_model=StandardResponse[AcquisitionCostDataRead]) async def get_acquisition_cost_data(db_session: DbSession, acquisition_cost_data_id: str): diff --git a/src/acquisition_cost/schema.py b/src/acquisition_cost/schema.py index 23dbdd8..c455e26 100644 --- a/src/acquisition_cost/schema.py +++ b/src/acquisition_cost/schema.py @@ -34,13 +34,4 @@ class AcquisitionCostDataPagination(Pagination): class ListQueryParams(CommonParams): - items_per_page: Optional[int] = Field( - default=5, - ge=1, - le=1000, - description="Number of items per page" - ) - search: Optional[str] = Field( - default=None, - description="Search keyword" - ) \ No newline at end of file + pass \ No newline at end of file diff --git a/src/database/service.py b/src/database/service.py index baa61e7..62e29df 100644 --- a/src/database/service.py +++ b/src/database/service.py @@ -1,5 +1,5 @@ import logging -from typing import Annotated, List +from typing import Annotated, List, Optional from sqlalchemy import desc, func, or_, Select from sqlalchemy_filters import apply_pagination @@ -18,9 +18,11 @@ QueryStr = constr(pattern=r"^[ -~]+$", min_length=1) def common_parameters( db_session: DbSession, # type: ignore - current_user: QueryStr = Query(None, alias="currentUser"), # type: ignore + current_user: Optional[str] = Query(None, alias="currentUser"), # type: ignore + current_user_snake: Optional[str] = Query(None, alias="current_user"), # type: ignore page: int = Query(1, gt=0, lt=2147483647), - items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647), + items_per_page: Optional[int] = Query(None, alias="items_per_page", gt=-2, lt=2147483647), + items_per_page_camel: Optional[int] = Query(None, alias="itemsPerPage", gt=-2, lt=2147483647), query_str: QueryStr = Query(None, alias="q"), # type: ignore filter_spec: QueryStr = Query(None, alias="filter"), # type: ignore sort_by: List[str] = Query([], alias="sortBy[]"), @@ -28,15 +30,23 @@ def common_parameters( all: int = Query(0), # role: QueryStr = Depends(get_current_role), ): + # Support both snake_case and camelCase for pagination size + final_items_per_page = items_per_page_camel if items_per_page_camel is not None else ( + items_per_page if items_per_page is not None else 5 + ) + + # Support both snake_case and camelCase for current user + final_current_user = current_user or current_user_snake + return { "db_session": db_session, "page": page, - "items_per_page": items_per_page, + "items_per_page": final_items_per_page, "query_str": query_str, "filter_spec": filter_spec, "sort_by": sort_by, "descending": descending, - "current_user": current_user, + "current_user": final_current_user, # "role": role, "all": bool(all), } diff --git a/src/equipment/router.py b/src/equipment/router.py index 4a49db6..6927e5c 100644 --- a/src/equipment/router.py +++ b/src/equipment/router.py @@ -62,6 +62,23 @@ async def get_equipments( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[EquipmentPagination]) +async def get_equipments_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all equipment for export.""" + common["all"] = True + equipment_data = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=equipment_data, + message="All Equipment Data retrieved successfully", + ) + @router.get("/maximo/{assetnum}", response_model=StandardResponse[List[dict]]) @@ -211,6 +228,18 @@ async def get_calculated_top_10_replacement_priorities(db_session: DbSession, co message="Top 10 Replacement Priorities Data retrieved successfully", ) +@router.get( + "/top-10-replacement-priorities-export-all", + response_model=StandardResponse[EquipmentTop10Pagination], +) +async def get_calculated_top_10_replacement_priorities_all(db_session: DbSession, common: CommonParameters): + common["all"] = True + equipment_data = await get_top_10_replacement_priorities(db_session=db_session, common=common) + return StandardResponse( + data=equipment_data, + message="All Replacement Priorities Data retrieved successfully", + ) + @router.get( "/top-10-economic-life", response_model=StandardResponse[EquipmentTop10Pagination], @@ -224,6 +253,18 @@ async def get_calculated_top_10_economic_life(db_session: DbSession, common: Com message="Top 10 Economic Life Data retrieved successfully", ) +@router.get( + "/top-10-economic-life-export-all", + response_model=StandardResponse[EquipmentTop10Pagination], +) +async def get_calculated_top_10_economic_life_all(db_session: DbSession, common: CommonParameters): + common["all"] = True + equipment_data = await get_top_10_economic_life(db_session=db_session, common=common) + return StandardResponse( + data=equipment_data, + message="All Economic Life Data retrieved successfully", + ) + @router.get("/tree", response_model=StandardResponse[EquipmentRead]) async def get_equipment_tree(): diff --git a/src/equipment/schema.py b/src/equipment/schema.py index 50fc188..0d6b449 100644 --- a/src/equipment/schema.py +++ b/src/equipment/schema.py @@ -34,9 +34,17 @@ class EquipmentBase(DefaultBase): updated_by: Optional[str] = Field(None) class EquipmentMasterBase(DefaultBase): - location_tag: Optional[str] = Field(None) - assetnum: Optional[str] = Field(None) + id: Optional[UUID] = Field(None) name: Optional[str] = Field(None) + parent_id: Optional[UUID] = Field(None) + equipment_tree_id: Optional[UUID] = Field(None) + category_id: Optional[UUID] = Field(None) + system_tag: Optional[str] = Field(None) + assetnum: Optional[str] = Field(None) + location_tag: Optional[str] = Field(None) + image_name: Optional[str] = Field(None) + description: Optional[str] = Field(None) + class MasterBase(DefaultBase): assetnum: Optional[str] = Field(None) @@ -162,15 +170,5 @@ class CountRemainingLifeResponse(DefaultBase): critical: int class ListQueryParams(CommonParams): - items_per_page: Optional[int] = Field( - default=5, - ge=1, - le=1000, - description="Number of items per page", - alias="itemsPerPage" - ) - search: Optional[str] = Field( - default=None, - description="Search keyword" - ) + pass diff --git a/src/equipment_master/model.py b/src/equipment_master/model.py index f9d64ba..c86e64b 100644 --- a/src/equipment_master/model.py +++ b/src/equipment_master/model.py @@ -31,7 +31,8 @@ class EquipmentMaster(Base, DefaultMixin): system_tag = Column(String, nullable=True) assetnum = Column(String, nullable=True) location_tag = Column(String, nullable=True) - + image_name = Column(String, nullable=True) + description = Column(String, nullable=True) # Relationship definitions # Define both sides of the relationship # parent = relationship( diff --git a/src/equipment_master/router.py b/src/equipment_master/router.py index a201d76..e843dbb 100644 --- a/src/equipment_master/router.py +++ b/src/equipment_master/router.py @@ -28,6 +28,21 @@ async def get_all_equipment_master_tree( data=equipment_masters, message="Data retrieved successfully" ) +@router.get("/export-all", response_model=StandardResponse[EquipmentMasterPaginated]) +async def get_all_equipment_master_tree_export_all( + db_session: DbSession, + common: CommonParameters, +): + common["all"] = True + equipment_masters = await get_all_master( + db_session=db_session, + common=common, + ) + + return StandardResponse( + data=equipment_masters, message="All Equipment Master Data retrieved successfully" + ) + @router.get( "/{equipment_master_id}", response_model=StandardResponse[EquipmentMasterRead] diff --git a/src/equipment_master/schema.py b/src/equipment_master/schema.py index ae94d76..54eed2a 100644 --- a/src/equipment_master/schema.py +++ b/src/equipment_master/schema.py @@ -46,5 +46,4 @@ class EquipmentMasterPaginated(Pagination): class EquipmentMasterQuery(CommonParams): parent_id : Optional[str] = None - items_per_page : Optional[int] = 5 search : Optional[str] = None \ No newline at end of file diff --git a/src/manpower_cost/router.py b/src/manpower_cost/router.py index d34a809..6eaa8da 100644 --- a/src/manpower_cost/router.py +++ b/src/manpower_cost/router.py @@ -32,6 +32,23 @@ async def get_yeardatas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[ManpowerCostPagination]) +async def get_yeardatas_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all manpower_cost_data for export.""" + common["all"] = True + get_acquisition_cost_data = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=get_acquisition_cost_data, + message="All Manpower Cost Data retrieved successfully", + ) + @router.get("/{acquisition_cost_data_id}", response_model=StandardResponse[ManpowerCostRead]) async def get_acquisition_cost_data(db_session: DbSession, acquisition_cost_data_id: str): diff --git a/src/manpower_cost/schema.py b/src/manpower_cost/schema.py index 91eda64..cf420a5 100644 --- a/src/manpower_cost/schema.py +++ b/src/manpower_cost/schema.py @@ -34,5 +34,4 @@ class ManpowerCostPagination(Pagination): class QueryParams(CommonParams): - items_per_page: Optional[int] = Field(5) - search: Optional[str] = Field(None) \ No newline at end of file + pass diff --git a/src/manpower_master/router.py b/src/manpower_master/router.py index d34a809..69b4929 100644 --- a/src/manpower_master/router.py +++ b/src/manpower_master/router.py @@ -32,6 +32,23 @@ async def get_yeardatas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[ManpowerCostPagination]) +async def get_yeardatas_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all manpower_master_data for export.""" + common["all"] = True + get_acquisition_cost_data = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=get_acquisition_cost_data, + message="All Manpower Master Data retrieved successfully", + ) + @router.get("/{acquisition_cost_data_id}", response_model=StandardResponse[ManpowerCostRead]) async def get_acquisition_cost_data(db_session: DbSession, acquisition_cost_data_id: str): diff --git a/src/manpower_master/schema.py b/src/manpower_master/schema.py index c945b86..26f010d 100644 --- a/src/manpower_master/schema.py +++ b/src/manpower_master/schema.py @@ -33,5 +33,4 @@ class ManpowerCostPagination(Pagination): items: List[ManpowerCostRead] = [] class QueryParams(CommonParams): - items_per_page: Optional[int] = Field(5) search: Optional[str] = Field(None) \ No newline at end of file diff --git a/src/masterdata/router.py b/src/masterdata/router.py index b351ee3..29bb2b0 100644 --- a/src/masterdata/router.py +++ b/src/masterdata/router.py @@ -1,9 +1,9 @@ from typing import Annotated, Optional, List -from fastapi import APIRouter, HTTPException, status, Query +from fastapi import APIRouter, HTTPException, status, Query, Depends from sqlalchemy import Select -from src.manpower_cost.schema import QueryParams +from .schema import QueryParams from .model import MasterData from .schema import ( MasterDataPagination, @@ -25,7 +25,7 @@ router = APIRouter() async def get_masterdatas( db_session: DbSession, common: CommonParameters, - params: Annotated[QueryParams, Query()], + params: Annotated[QueryParams, Depends()], ): """Get all documents.""" # return @@ -40,6 +40,23 @@ async def get_masterdatas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[MasterDataPagination]) +async def get_masterdatas_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all documents for export.""" + common["all"] = True + master_datas = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=master_datas, + message="All Master Data retrieved successfully", + ) + @router.get("/{masterdata_id}", response_model=StandardResponse[MasterDataRead]) async def get_masterdata(db_session: DbSession, masterdata_id: str): diff --git a/src/masterdata/schema.py b/src/masterdata/schema.py index b2fe898..43dea94 100644 --- a/src/masterdata/schema.py +++ b/src/masterdata/schema.py @@ -2,8 +2,8 @@ from datetime import datetime from typing import List, Optional from uuid import UUID -from pydantic import BaseModel, Field -from src.models import DefaultBase, Pagination +from pydantic import BaseModel, Field, model_validator +from src.models import CommonParams, DefaultBase, Pagination from src.auth.service import CurrentUser @@ -52,13 +52,5 @@ class MasterDataPagination(Pagination): items: List[MasterDataRead] = [] -class QueryParams(BaseModel): - items_per_page: Optional[int] = Field( - 5, - ge=1, - description="Items per page" - ) - search: Optional[str] = Field( - None, - description="Search keyword" - ) +class QueryParams(CommonParams): + pass diff --git a/src/masterdata_simulations/schema.py b/src/masterdata_simulations/schema.py index b5726e0..561b7f2 100644 --- a/src/masterdata_simulations/schema.py +++ b/src/masterdata_simulations/schema.py @@ -45,9 +45,3 @@ class QueryParams(CommonParams): ..., description="Simulation identifier", ) - items_per_page: Optional[int] = Field( - 5, - ge=1, - description="Items per page" - ) - search: Optional[str] = Field(None) \ No newline at end of file diff --git a/src/middleware.py b/src/middleware.py index 5599a59..65c9aa3 100644 --- a/src/middleware.py +++ b/src/middleware.py @@ -14,6 +14,37 @@ ALLOWED_MULTI_PARAMS = { "exclude[]", } +# Whitelist of ALL allowed query parameter names across the application. +# Any param NOT in this set will be rejected. +ALLOWED_QUERY_PARAMS = { + # CommonParameters (from database/service.py common_parameters) + "currentUser", + "page", + "itemsPerPage", + "q", + "filter", + "sortBy[]", + "descending[]", + "all", + # ListQueryParams / QueryParams used across routers + "items_per_page", + "search", + # equipment_master specific + "parent_id", + # masterdata_simulations / plant_transaction_data_simulations specific + "simulation_id", + # exclude + "exclude[]", +} + +# Query params that are ONLY allowed for "write" operations (read operations use ALLOWED_QUERY_PARAMS). +# For GET/POST/PUT/etc, whitelisting still applies. +WRITE_METHOD_ALLOWED_PARAMS = { + # Only auth/session params are allowed in query for write methods. + # Data values (like simulation_id) must be in the JSON body for these methods. + "currentUser", +} + MAX_QUERY_PARAMS = 50 MAX_QUERY_LENGTH = 2000 MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB @@ -62,31 +93,31 @@ def has_control_chars(value: str) -> bool: def inspect_value(value: str, source: str): if XSS_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential XSS payload detected in {source}", ) if SQLI_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential SQL injection payload detected in {source}", ) if RCE_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential RCE payload detected in {source}", ) if TRAVERSAL_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential Path Traversal payload detected in {source}", ) if has_control_chars(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Invalid control characters detected in {source}", ) @@ -96,7 +127,7 @@ def inspect_json(obj, path="body"): for key, value in obj.items(): if key in FORBIDDEN_JSON_KEYS: raise HTTPException( - status_code=400, + status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}", ) inspect_json(value, f"{path}.{key}") @@ -126,12 +157,28 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): if len(params) > MAX_QUERY_PARAMS: raise HTTPException( - status_code=400, + status_code=422, detail="Too many query parameters", ) # ------------------------- - # 2. Duplicate parameters + # 2. Query param whitelist + # ------------------------- + # For GET, we allow data parameters like page, search, etc. + # For POST, PUT, DELETE, PATCH, we ONLY allow auth/session params. + active_whitelist = ALLOWED_QUERY_PARAMS if request.method == "GET" else WRITE_METHOD_ALLOWED_PARAMS + + unknown_params = [ + key for key, _ in params if key not in active_whitelist + ] + if unknown_params: + raise HTTPException( + status_code=422, + detail=f"Unknown query parameters are not allowed for {request.method} request: {unknown_params}", + ) + + # ------------------------- + # 3. Duplicate parameters # ------------------------- counter = Counter(key for key, _ in params) duplicates = [ @@ -141,33 +188,77 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): if duplicates: raise HTTPException( - status_code=400, + status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}", ) # ------------------------- - # 3. Query param inspection + # 4. JSON body inspection & Single source enforcement + # Ensuring data comes from ONLY one source (Query OR Body). + # ------------------------- + content_type = request.headers.get("content-type", "") + has_json_header = content_type.startswith("application/json") + + # Read body now so we can check if it's actually empty + body = b"" + if has_json_header: + body = await request.body() + + # We consider it a "JSON body" source ONLY if it's not empty and not just "{}" + has_actual_json_body = has_json_header and body and body.strip() != b"{}" + + # Check for data parameters in query (anything whitelisted as 'data' but not 'session/auth') + data_params_in_query = [ + key for key, _ in params + if key in ALLOWED_QUERY_PARAMS and key not in WRITE_METHOD_ALLOWED_PARAMS + ] + + if has_actual_json_body: + # If sending actual JSON body, we forbid any data in query string (one source only) + if data_params_in_query: + raise HTTPException( + status_code=422, + detail=f"Single source enforcement: Data received from both JSON body and query string ({data_params_in_query}). Use only one source.", + ) + + # Special case: GET with actual body is discouraged/forbidden + if request.method == "GET": + raise HTTPException( + status_code=422, + detail="GET requests must use query parameters, not JSON body.", + ) + + # ------------------------- + # 5. Query param inspection # ------------------------- pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"} for key, value in params: if value: inspect_value(value, f"query param '{key}'") - + # Pagination constraint: multiples of 5, max 50 if key in pagination_size_keys and value: try: size_val = int(value) if size_val > 50: - raise HTTPException(status_code=400, detail=f"Pagination size '{key}' cannot exceed 50") + raise HTTPException( + status_code=422, + detail=f"Pagination size '{key}' cannot exceed 50", + ) if size_val % 5 != 0: - raise HTTPException(status_code=400, detail=f"Pagination size '{key}' must be a multiple of 5") + raise HTTPException( + status_code=422, + detail=f"Pagination size '{key}' must be a multiple of 5", + ) except ValueError: - raise HTTPException(status_code=400, detail=f"Pagination size '{key}' must be an integer") + raise HTTPException( + status_code=422, + detail=f"Pagination size '{key}' must be an integer", + ) # ------------------------- - # 4. Content-Type sanity + # 6. Content-Type sanity # ------------------------- - content_type = request.headers.get("content-type", "") if content_type and not any( content_type.startswith(t) for t in ( @@ -182,32 +273,22 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): ) # ------------------------- - # 5. JSON body inspection + # 7. JSON body inspection & Re-injection # ------------------------- - if content_type.startswith("application/json"): - body = await request.body() - - #if len(body) > MAX_JSON_BODY_SIZE: - # raise HTTPException( - # status_code=413, - # detail="JSON body too large", - # ) - + if has_json_header: if body: try: payload = json.loads(body) except json.JSONDecodeError: raise HTTPException( - status_code=400, + status_code=422, detail="Invalid JSON body", ) - inspect_json(payload) # Re-inject body for downstream handlers async def receive(): return {"type": "http.request", "body": body} - request._receive = receive # noqa: protected-access return await call_next(request) diff --git a/src/models.py b/src/models.py index 9fdaee8..456374b 100644 --- a/src/models.py +++ b/src/models.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Generic, List, Optional, TypeVar import uuid -from pydantic import BaseModel, Field, SecretStr, ConfigDict +from pydantic import BaseModel, Field, SecretStr, ConfigDict, model_validator from sqlalchemy import Column, DateTime, String, func, event from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -100,16 +100,29 @@ class StandardResponse(BaseModel, Generic[T]): class CommonParams(DefaultBase): # This ensures no extra query params are allowed - current_user: Optional[str] = Field(None, alias="currentUser") + current_user: Optional[str] = Field(None, alias="current_user") + currentUser: Optional[str] = Field(None, description="Alias for current_user") page: int = Field(1, gt=0, lt=2147483647) - items_per_page: int = Field(5, gt=-2, lt=2147483647, alias="itemsPerPage") + items_per_page: int = Field(5, gt=-2, lt=2147483647, alias="items_per_page") + itemsPerPage: Optional[int] = Field(None, description="Alias for items_per_page") query_str: Optional[str] = Field(None, alias="q") + search: Optional[str] = Field(None, description="Search keyword") filter_spec: Optional[str] = Field(None, alias="filter") - sort_by: List[str] = Field(default_factory=list, alias="sortBy[]") - descending: List[bool] = Field(default_factory=list, alias="descending[]") - exclude: List[str] = Field(default_factory=list, alias="exclude[]") + sort_by: List[str] = Field(default=[], alias="sortBy[]") + descending: List[bool] = Field(default=[], alias="descending[]") + exclude: List[str] = Field(default=[], alias="exclude[]") all_params: int = Field(0, alias="all") + @model_validator(mode="before") + @classmethod + def resolve_aliases(cls, data: any) -> any: + if isinstance(data, dict): + if "itemsPerPage" in data and data["itemsPerPage"] is not None: + data.setdefault("items_per_page", data["itemsPerPage"]) + if "currentUser" in data and data["currentUser"] is not None: + data.setdefault("current_user", data["currentUser"]) + return data + # Property to mirror your original return dict's bool conversion @property def is_all(self) -> bool: diff --git a/src/plant_fs_transaction_data/router.py b/src/plant_fs_transaction_data/router.py index 80a126b..93e7bd3 100644 --- a/src/plant_fs_transaction_data/router.py +++ b/src/plant_fs_transaction_data/router.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from fastapi import APIRouter, HTTPException, Query, status @@ -16,6 +16,7 @@ from .schema import ( PlantFSTransactionDataRead, PlantFSTransactionDataUpdate, PlantFSChartData, + ListQueryParams, ) from .service import create, delete, get, get_all, update, update_fs_charts_from_matrix, get_charts @@ -28,15 +29,14 @@ router = APIRouter() async def list_fs_transactions( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Return paginated financial statement transaction data.""" records = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) @@ -166,8 +166,3 @@ async def delete_fs_transaction( await delete(db_session=db_session, fs_transaction_id=str(fs_transaction_id)) return StandardResponse(data=record, message="Data deleted successfully") - - - - - diff --git a/src/plant_fs_transaction_data/schema.py b/src/plant_fs_transaction_data/schema.py index 8af5592..2341a18 100644 --- a/src/plant_fs_transaction_data/schema.py +++ b/src/plant_fs_transaction_data/schema.py @@ -4,7 +4,7 @@ from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class PlantFSTransactionDataBase(DefaultBase): @@ -100,3 +100,11 @@ class PlantFSChartData(DefaultBase): bep_year: Optional[int] = Field(None, ge=0, le=9999) bep_total_lcc: Optional[float] = Field(None, ge=0, le=1_000_000_000_000_000) + +class ListQueryParams(CommonParams): + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + + diff --git a/src/plant_masterdata/router.py b/src/plant_masterdata/router.py index f696acd..f12f7c7 100644 --- a/src/plant_masterdata/router.py +++ b/src/plant_masterdata/router.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, HTTPException, status, Query from .model import PlantMasterData @@ -7,6 +7,7 @@ from .schema import ( PlantMasterDataRead, PlantMasterDataCreate, PlantMasterDataUpdate, + ListQueryParams, ) from .service import get, get_all, create, update, delete @@ -22,15 +23,14 @@ router = APIRouter() async def get_masterdatas( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all documents.""" # return master_datas = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) return StandardResponse( @@ -38,6 +38,23 @@ async def get_masterdatas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[PlantMasterDataPagination]) +async def get_masterdatas_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all documents for export.""" + common["all"] = True + master_datas = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=master_datas, + message="All Plant Master Data retrieved successfully", + ) + @router.get("/{masterdata_id}", response_model=StandardResponse[PlantMasterDataRead]) async def get_masterdata(db_session: DbSession, masterdata_id: str): diff --git a/src/plant_masterdata/schema.py b/src/plant_masterdata/schema.py index 016003d..58c22d4 100644 --- a/src/plant_masterdata/schema.py +++ b/src/plant_masterdata/schema.py @@ -3,7 +3,7 @@ from typing import List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination from src.auth.service import CurrentUser @@ -85,3 +85,11 @@ class PlantMasterDataRead(PlantMasterdataBase): class PlantMasterDataPagination(Pagination): items: List[PlantMasterDataRead] = [] + + +class ListQueryParams(CommonParams): + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + diff --git a/src/plant_transaction_data/router.py b/src/plant_transaction_data/router.py index bd4f02c..7baeee0 100644 --- a/src/plant_transaction_data/router.py +++ b/src/plant_transaction_data/router.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from fastapi import APIRouter, HTTPException, status, Query from .model import PlantTransactionData @@ -10,6 +10,7 @@ from .schema import ( PlantTransactionDataCreate, PlantTransactionDataUpdate, PlantTransactionFSImport, + ListQueryParams, ) from .service import ( get, @@ -33,14 +34,13 @@ router = APIRouter() async def get_transaction_datas( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all transaction_data pagination.""" plant_transaction_data = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) # return @@ -49,6 +49,23 @@ async def get_transaction_datas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[PlantTransactionDataPagination]) +async def get_transaction_datas_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all transaction_data for export.""" + common["all"] = True + plant_transaction_data = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=plant_transaction_data, + message="All Plant Transaction Data retrieved successfully", + ) + @router.get("/charts", response_model=StandardResponse[PlantChartData]) async def get_chart_data(db_session: DbSession, common: CommonParameters): chart_data, bep_year, bep_total_lcc = await get_charts( diff --git a/src/plant_transaction_data/schema.py b/src/plant_transaction_data/schema.py index a9db71f..b31e11c 100644 --- a/src/plant_transaction_data/schema.py +++ b/src/plant_transaction_data/schema.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class PlantTransactionDataBase(DefaultBase): @@ -117,3 +117,8 @@ class PlantTransactionDataRead(PlantTransactionDataBase): class PlantTransactionDataPagination(Pagination): items: List[PlantTransactionDataRead] = [] + + +class ListQueryParams(CommonParams): + pass + diff --git a/src/plant_transaction_data_simulations/router.py b/src/plant_transaction_data_simulations/router.py index bd6a16c..b0b845c 100644 --- a/src/plant_transaction_data_simulations/router.py +++ b/src/plant_transaction_data_simulations/router.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from fastapi import APIRouter, HTTPException, status, Query @@ -11,6 +11,7 @@ from src.plant_transaction_data_simulations.schema import ( PlantTransactionDataSimulationsCreate, PlantTransactionDataSimulationsUpdate, PlantTransactionFSImportSimulations, + ListQueryParams, ) from src.plant_transaction_data_simulations.service import ( get, @@ -34,17 +35,15 @@ router = APIRouter() async def get_transaction_datas( db_session: DbSession, common: CommonParameters, - simulation_id: UUID = Query(..., description="Simulation identifier"), - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all transaction_data pagination.""" plant_transaction_data = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, - simulation_id=simulation_id, + simulation_id=params.simulation_id, ) # return return StandardResponse( @@ -52,6 +51,25 @@ async def get_transaction_datas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[PlantTransactionDataSimulationsPagination]) +async def get_transaction_datas_export_all( + db_session: DbSession, + common: CommonParameters, + simulation_id: UUID = Query(..., description="Simulation identifier"), +): + """Get all transaction_data for export.""" + common["all"] = True + plant_transaction_data = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + simulation_id=simulation_id, + ) + return StandardResponse( + data=plant_transaction_data, + message="All Plant Transaction Data Simulations retrieved successfully", + ) + @router.get("/charts", response_model=StandardResponse[PlantChartDataSimulations]) async def get_chart_data( db_session: DbSession, diff --git a/src/plant_transaction_data_simulations/schema.py b/src/plant_transaction_data_simulations/schema.py index 0668854..0d8f13e 100644 --- a/src/plant_transaction_data_simulations/schema.py +++ b/src/plant_transaction_data_simulations/schema.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class PlantTransactionDataSimulationsBase(DefaultBase): @@ -140,3 +140,11 @@ class PlantTransactionDataSimulationsRead(PlantTransactionDataSimulationsBase): class PlantTransactionDataSimulationsPagination(Pagination): items: List[PlantTransactionDataSimulationsRead] = [] + + +class ListQueryParams(CommonParams): + simulation_id: UUID = Field( + ..., + description="Simulation identifier", + ) + diff --git a/src/simulations/router.py b/src/simulations/router.py index 50b1d3b..0f4c7ad 100644 --- a/src/simulations/router.py +++ b/src/simulations/router.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, HTTPException, Query, status @@ -13,6 +13,7 @@ from src.simulations.schema import ( SimulationRead, SimulationRunPayload, SimulationUpdate, + ListQueryParams, ) from src.simulations.service import create, delete, get, get_all, run_simulation, update @@ -24,18 +25,33 @@ async def get_simulations( db_session: DbSession, common: CommonParameters, current_user: CurrentUser, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): simulations = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, owner=current_user.name, ) return StandardResponse(data=simulations, message="Data retrieved successfully") +@router.get("/export-all", response_model=StandardResponse[SimulationPagination]) +async def get_simulations_export_all( + db_session: DbSession, + common: CommonParameters, + current_user: CurrentUser, +): + """Get all simulations for export.""" + common["all"] = True + simulations = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + owner=current_user.name, + ) + return StandardResponse(data=simulations, message="All Simulations Data retrieved successfully") + @router.get("/{simulation_id}", response_model=StandardResponse[SimulationRead]) async def get_simulation( diff --git a/src/simulations/schema.py b/src/simulations/schema.py index 7b5a7cd..fe09cd0 100644 --- a/src/simulations/schema.py +++ b/src/simulations/schema.py @@ -4,7 +4,7 @@ from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination from src.masterdata_simulations.schema import MasterDataSimulationRead from src.plant_transaction_data_simulations.schema import ( PlantTransactionDataSimulationsRead, @@ -51,3 +51,11 @@ class MasterDataOverride(DefaultBase): class SimulationRunPayload(DefaultBase): label: Optional[str] = Field(None) overrides: List[MasterDataOverride] = Field(default_factory=list) + + +class ListQueryParams(CommonParams): + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + diff --git a/src/uploaded_file/router.py b/src/uploaded_file/router.py index d03b1b0..142af32 100644 --- a/src/uploaded_file/router.py +++ b/src/uploaded_file/router.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, Form, HTTPException, status, Query, UploadFile, File from .model import UploadedFileData -from src.uploaded_file.schema import UploadedFileDataCreate, UploadedFileDataUpdate, UploadedFileDataRead, UploadedFileDataPagination +from src.uploaded_file.schema import UploadedFileDataCreate, UploadedFileDataUpdate, UploadedFileDataRead, UploadedFileDataPagination, ListQueryParams from src.uploaded_file.service import get, get_all, create, update, delete from src.database.service import CommonParameters, search_filter_sort_paginate @@ -20,14 +20,13 @@ router = APIRouter() async def get_uploaded_files( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all uploaded files pagination.""" uploaded_files = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) # return @@ -36,6 +35,23 @@ async def get_uploaded_files( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[UploadedFileDataPagination]) +async def get_uploaded_files_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all uploaded files for export.""" + common["all"] = True + uploaded_files = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=uploaded_files, + message="All Uploaded Files Data retrieved successfully", + ) + @router.get("/{uploaded_file_id}", response_model=StandardResponse[UploadedFileDataRead]) async def get_uploaded_file(db_session: DbSession, uploaded_file_id: str): diff --git a/src/uploaded_file/schema.py b/src/uploaded_file/schema.py index baf5a44..ba32598 100644 --- a/src/uploaded_file/schema.py +++ b/src/uploaded_file/schema.py @@ -3,7 +3,7 @@ from typing import List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class UploadedFileDataBase(DefaultBase): filename: str = Field(...) @@ -28,3 +28,7 @@ class UploadedFileDataRead(UploadedFileDataBase): class UploadedFileDataPagination(Pagination): items: List[UploadedFileDataRead] = [] + + +class ListQueryParams(CommonParams): + pass diff --git a/src/yeardata/router.py b/src/yeardata/router.py index b92d443..4b6efa5 100644 --- a/src/yeardata/router.py +++ b/src/yeardata/router.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, HTTPException, status, Query from .model import Yeardata -from .schema import YeardataPagination, YeardataRead, YeardataCreate, YeardataUpdate +from .schema import YeardataPagination, YeardataRead, YeardataCreate, YeardataUpdate, ListQueryParams from .service import get, get_all, create, update, delete from src.database.service import CommonParameters, search_filter_sort_paginate @@ -17,14 +17,13 @@ router = APIRouter() async def get_yeardatas( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all yeardata pagination.""" year_data = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) # return @@ -33,6 +32,23 @@ async def get_yeardatas( message="Data retrieved successfully", ) +@router.get("/export-all", response_model=StandardResponse[YeardataPagination]) +async def get_yeardatas_export_all( + db_session: DbSession, + common: CommonParameters, +): + """Get all yeardata for export.""" + common["all"] = True + year_data = await get_all( + db_session=db_session, + items_per_page=-1, + common=common, + ) + return StandardResponse( + data=year_data, + message="All Year Data retrieved successfully", + ) + @router.get("/{yeardata_id}", response_model=StandardResponse[YeardataRead]) async def get_yeardata(db_session: DbSession, yeardata_id: str): diff --git a/src/yeardata/schema.py b/src/yeardata/schema.py index 039a4e7..cf2abc8 100644 --- a/src/yeardata/schema.py +++ b/src/yeardata/schema.py @@ -3,7 +3,7 @@ from typing import List, Optional from uuid import UUID from pydantic import Field, field_validator -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class YeardataBase(DefaultBase): @@ -61,3 +61,8 @@ class YeardataRead(YeardataBase): class YeardataPagination(Pagination): items: List[YeardataRead] = [] + + +class ListQueryParams(CommonParams): + pass +