diff --git a/src/manpower_cost/schema.py b/src/manpower_cost/schema.py index 91eda64..99ca8c3 100644 --- a/src/manpower_cost/schema.py +++ b/src/manpower_cost/schema.py @@ -34,5 +34,14 @@ class ManpowerCostPagination(Pagination): class QueryParams(CommonParams): - items_per_page: Optional[int] = Field(5) - search: Optional[str] = Field(None) \ No newline at end of file + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) diff --git a/src/masterdata_simulations/schema.py b/src/masterdata_simulations/schema.py index b5726e0..8c0f048 100644 --- a/src/masterdata_simulations/schema.py +++ b/src/masterdata_simulations/schema.py @@ -46,8 +46,13 @@ class QueryParams(CommonParams): description="Simulation identifier", ) items_per_page: Optional[int] = Field( - 5, + default=5, ge=1, - description="Items per page" + le=1000, + description="Items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", ) - search: Optional[str] = Field(None) \ No newline at end of file diff --git a/src/middleware.py b/src/middleware.py index b226813..f428ef5 100644 --- a/src/middleware.py +++ b/src/middleware.py @@ -14,6 +14,37 @@ ALLOWED_MULTI_PARAMS = { "exclude[]", } +# Whitelist of ALL allowed query parameter names across the application. +# Any param NOT in this set will be rejected. +ALLOWED_QUERY_PARAMS = { + # CommonParameters (from database/service.py common_parameters) + "currentUser", + "page", + "itemsPerPage", + "q", + "filter", + "sortBy[]", + "descending[]", + "all", + # ListQueryParams / QueryParams used across routers + "items_per_page", + "search", + # equipment_master specific + "parent_id", + # masterdata_simulations / plant_transaction_data_simulations specific + "simulation_id", + # exclude + "exclude[]", +} + +# Query params that are ONLY allowed for "write" operations (read operations use ALLOWED_QUERY_PARAMS). +# For GET/POST/PUT/etc, whitelisting still applies. +WRITE_METHOD_ALLOWED_PARAMS = { + # Only auth/session params are allowed in query for write methods. + # Data values (like simulation_id) must be in the JSON body for these methods. + "currentUser", +} + MAX_QUERY_PARAMS = 50 MAX_QUERY_LENGTH = 2000 MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB @@ -62,31 +93,31 @@ def has_control_chars(value: str) -> bool: def inspect_value(value: str, source: str): if XSS_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential XSS payload detected in {source}", ) if SQLI_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential SQL injection payload detected in {source}", ) if RCE_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential RCE payload detected in {source}", ) if TRAVERSAL_PATTERN.search(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Potential Path Traversal payload detected in {source}", ) if has_control_chars(value): raise HTTPException( - status_code=400, + status_code=422, detail=f"Invalid control characters detected in {source}", ) @@ -96,7 +127,7 @@ def inspect_json(obj, path="body"): for key, value in obj.items(): if key in FORBIDDEN_JSON_KEYS: raise HTTPException( - status_code=400, + status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}", ) inspect_json(value, f"{path}.{key}") @@ -126,12 +157,28 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): if len(params) > MAX_QUERY_PARAMS: raise HTTPException( - status_code=400, + status_code=422, detail="Too many query parameters", ) # ------------------------- - # 2. Duplicate parameters + # 2. Query param whitelist + # ------------------------- + # For GET, we allow data parameters like page, search, etc. + # For POST, PUT, DELETE, PATCH, we ONLY allow auth/session params. + active_whitelist = ALLOWED_QUERY_PARAMS if request.method == "GET" else WRITE_METHOD_ALLOWED_PARAMS + + unknown_params = [ + key for key, _ in params if key not in active_whitelist + ] + if unknown_params: + raise HTTPException( + status_code=422, + detail=f"Unknown query parameters are not allowed for {request.method} request: {unknown_params}", + ) + + # ------------------------- + # 3. Duplicate parameters # ------------------------- counter = Counter(key for key, _ in params) duplicates = [ @@ -141,12 +188,40 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): if duplicates: raise HTTPException( - status_code=400, + status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}", ) # ------------------------- - # 3. Query param inspection + # 4. Single source enforcement + # Ensuring data comes from ONLY one source (Query OR Body). + # ------------------------- + content_type = request.headers.get("content-type", "") + has_json_body = content_type.startswith("application/json") + + # Check for data parameters in query (anything whitelisted as 'data' but not 'session/auth') + data_params_in_query = [ + key for key, _ in params + if key in ALLOWED_QUERY_PARAMS and key not in WRITE_METHOD_ALLOWED_PARAMS + ] + + if has_json_body: + # If sending JSON body, we forbid any data in query string (one source only) + if data_params_in_query: + raise HTTPException( + status_code=422, + detail=f"Single source enforcement: Data received from both JSON body and query string ({data_params_in_query}). Use only one source.", + ) + + # Special case: GET with body is discouraged/forbidden in many strict security contexts + if request.method == "GET": + raise HTTPException( + status_code=422, + detail="GET requests must use query parameters, not JSON body.", + ) + + # ------------------------- + # 5. Query param inspection # ------------------------- pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"} for key, value in params: @@ -159,24 +234,23 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): size_val = int(value) if size_val > 50: raise HTTPException( - status_code=400, + status_code=422, detail=f"Pagination size '{key}' cannot exceed 50", ) if size_val % 5 != 0: raise HTTPException( - status_code=400, + status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5", ) except ValueError: raise HTTPException( - status_code=400, + status_code=422, detail=f"Pagination size '{key}' must be an integer", ) # ------------------------- - # 4. Content-Type sanity + # 6. Content-Type sanity # ------------------------- - content_type = request.headers.get("content-type", "") if content_type and not any( content_type.startswith(t) for t in ( @@ -191,7 +265,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): ) # ------------------------- - # 5. JSON body inspection + # 7. JSON body inspection # ------------------------- if content_type.startswith("application/json"): body = await request.body() @@ -207,7 +281,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): payload = json.loads(body) except json.JSONDecodeError: raise HTTPException( - status_code=400, + status_code=422, detail="Invalid JSON body", ) diff --git a/src/plant_fs_transaction_data/router.py b/src/plant_fs_transaction_data/router.py index 80a126b..93e7bd3 100644 --- a/src/plant_fs_transaction_data/router.py +++ b/src/plant_fs_transaction_data/router.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from fastapi import APIRouter, HTTPException, Query, status @@ -16,6 +16,7 @@ from .schema import ( PlantFSTransactionDataRead, PlantFSTransactionDataUpdate, PlantFSChartData, + ListQueryParams, ) from .service import create, delete, get, get_all, update, update_fs_charts_from_matrix, get_charts @@ -28,15 +29,14 @@ router = APIRouter() async def list_fs_transactions( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Return paginated financial statement transaction data.""" records = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) @@ -166,8 +166,3 @@ async def delete_fs_transaction( await delete(db_session=db_session, fs_transaction_id=str(fs_transaction_id)) return StandardResponse(data=record, message="Data deleted successfully") - - - - - diff --git a/src/plant_fs_transaction_data/schema.py b/src/plant_fs_transaction_data/schema.py index 8af5592..d2497e4 100644 --- a/src/plant_fs_transaction_data/schema.py +++ b/src/plant_fs_transaction_data/schema.py @@ -4,7 +4,7 @@ from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class PlantFSTransactionDataBase(DefaultBase): @@ -100,3 +100,18 @@ class PlantFSChartData(DefaultBase): bep_year: Optional[int] = Field(None, ge=0, le=9999) bep_total_lcc: Optional[float] = Field(None, ge=0, le=1_000_000_000_000_000) + +class ListQueryParams(CommonParams): + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + + diff --git a/src/plant_masterdata/router.py b/src/plant_masterdata/router.py index 6064ce2..f12f7c7 100644 --- a/src/plant_masterdata/router.py +++ b/src/plant_masterdata/router.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, HTTPException, status, Query from .model import PlantMasterData @@ -7,6 +7,7 @@ from .schema import ( PlantMasterDataRead, PlantMasterDataCreate, PlantMasterDataUpdate, + ListQueryParams, ) from .service import get, get_all, create, update, delete @@ -22,15 +23,14 @@ router = APIRouter() async def get_masterdatas( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all documents.""" # return master_datas = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) return StandardResponse( diff --git a/src/plant_masterdata/schema.py b/src/plant_masterdata/schema.py index 016003d..c938a6b 100644 --- a/src/plant_masterdata/schema.py +++ b/src/plant_masterdata/schema.py @@ -3,7 +3,7 @@ from typing import List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination from src.auth.service import CurrentUser @@ -85,3 +85,18 @@ class PlantMasterDataRead(PlantMasterdataBase): class PlantMasterDataPagination(Pagination): items: List[PlantMasterDataRead] = [] + + +class ListQueryParams(CommonParams): + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + diff --git a/src/plant_transaction_data/router.py b/src/plant_transaction_data/router.py index bac841f..7baeee0 100644 --- a/src/plant_transaction_data/router.py +++ b/src/plant_transaction_data/router.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from fastapi import APIRouter, HTTPException, status, Query from .model import PlantTransactionData @@ -10,6 +10,7 @@ from .schema import ( PlantTransactionDataCreate, PlantTransactionDataUpdate, PlantTransactionFSImport, + ListQueryParams, ) from .service import ( get, @@ -33,14 +34,13 @@ router = APIRouter() async def get_transaction_datas( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all transaction_data pagination.""" plant_transaction_data = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) # return diff --git a/src/plant_transaction_data/schema.py b/src/plant_transaction_data/schema.py index a9db71f..9813a80 100644 --- a/src/plant_transaction_data/schema.py +++ b/src/plant_transaction_data/schema.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class PlantTransactionDataBase(DefaultBase): @@ -117,3 +117,18 @@ class PlantTransactionDataRead(PlantTransactionDataBase): class PlantTransactionDataPagination(Pagination): items: List[PlantTransactionDataRead] = [] + + +class ListQueryParams(CommonParams): + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + diff --git a/src/plant_transaction_data_simulations/router.py b/src/plant_transaction_data_simulations/router.py index afe05cf..b0b845c 100644 --- a/src/plant_transaction_data_simulations/router.py +++ b/src/plant_transaction_data_simulations/router.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from fastapi import APIRouter, HTTPException, status, Query @@ -11,6 +11,7 @@ from src.plant_transaction_data_simulations.schema import ( PlantTransactionDataSimulationsCreate, PlantTransactionDataSimulationsUpdate, PlantTransactionFSImportSimulations, + ListQueryParams, ) from src.plant_transaction_data_simulations.service import ( get, @@ -34,17 +35,15 @@ router = APIRouter() async def get_transaction_datas( db_session: DbSession, common: CommonParameters, - simulation_id: UUID = Query(..., description="Simulation identifier"), - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all transaction_data pagination.""" plant_transaction_data = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, - simulation_id=simulation_id, + simulation_id=params.simulation_id, ) # return return StandardResponse( diff --git a/src/plant_transaction_data_simulations/schema.py b/src/plant_transaction_data_simulations/schema.py index 0668854..939ad68 100644 --- a/src/plant_transaction_data_simulations/schema.py +++ b/src/plant_transaction_data_simulations/schema.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class PlantTransactionDataSimulationsBase(DefaultBase): @@ -140,3 +140,22 @@ class PlantTransactionDataSimulationsRead(PlantTransactionDataSimulationsBase): class PlantTransactionDataSimulationsPagination(Pagination): items: List[PlantTransactionDataSimulationsRead] = [] + + +class ListQueryParams(CommonParams): + simulation_id: UUID = Field( + ..., + description="Simulation identifier", + ) + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + diff --git a/src/simulations/router.py b/src/simulations/router.py index 26c3faf..0f4c7ad 100644 --- a/src/simulations/router.py +++ b/src/simulations/router.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, HTTPException, Query, status @@ -13,6 +13,7 @@ from src.simulations.schema import ( SimulationRead, SimulationRunPayload, SimulationUpdate, + ListQueryParams, ) from src.simulations.service import create, delete, get, get_all, run_simulation, update @@ -24,13 +25,12 @@ async def get_simulations( db_session: DbSession, common: CommonParameters, current_user: CurrentUser, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): simulations = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, owner=current_user.name, ) diff --git a/src/simulations/schema.py b/src/simulations/schema.py index 7b5a7cd..c4780bf 100644 --- a/src/simulations/schema.py +++ b/src/simulations/schema.py @@ -4,7 +4,7 @@ from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination from src.masterdata_simulations.schema import MasterDataSimulationRead from src.plant_transaction_data_simulations.schema import ( PlantTransactionDataSimulationsRead, @@ -51,3 +51,18 @@ class MasterDataOverride(DefaultBase): class SimulationRunPayload(DefaultBase): label: Optional[str] = Field(None) overrides: List[MasterDataOverride] = Field(default_factory=list) + + +class ListQueryParams(CommonParams): + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) + diff --git a/src/uploaded_file/router.py b/src/uploaded_file/router.py index e3232dc..142af32 100644 --- a/src/uploaded_file/router.py +++ b/src/uploaded_file/router.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, Form, HTTPException, status, Query, UploadFile, File from .model import UploadedFileData -from src.uploaded_file.schema import UploadedFileDataCreate, UploadedFileDataUpdate, UploadedFileDataRead, UploadedFileDataPagination +from src.uploaded_file.schema import UploadedFileDataCreate, UploadedFileDataUpdate, UploadedFileDataRead, UploadedFileDataPagination, ListQueryParams from src.uploaded_file.service import get, get_all, create, update, delete from src.database.service import CommonParameters, search_filter_sort_paginate @@ -20,14 +20,13 @@ router = APIRouter() async def get_uploaded_files( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all uploaded files pagination.""" uploaded_files = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) # return diff --git a/src/uploaded_file/schema.py b/src/uploaded_file/schema.py index baf5a44..1750ea7 100644 --- a/src/uploaded_file/schema.py +++ b/src/uploaded_file/schema.py @@ -3,7 +3,7 @@ from typing import List, Optional from uuid import UUID from pydantic import Field -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class UploadedFileDataBase(DefaultBase): filename: str = Field(...) @@ -28,3 +28,17 @@ class UploadedFileDataRead(UploadedFileDataBase): class UploadedFileDataPagination(Pagination): items: List[UploadedFileDataRead] = [] + + +class ListQueryParams(CommonParams): + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) diff --git a/src/yeardata/router.py b/src/yeardata/router.py index 8a536aa..4b6efa5 100644 --- a/src/yeardata/router.py +++ b/src/yeardata/router.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, HTTPException, status, Query from .model import Yeardata -from .schema import YeardataPagination, YeardataRead, YeardataCreate, YeardataUpdate +from .schema import YeardataPagination, YeardataRead, YeardataCreate, YeardataUpdate, ListQueryParams from .service import get, get_all, create, update, delete from src.database.service import CommonParameters, search_filter_sort_paginate @@ -17,14 +17,13 @@ router = APIRouter() async def get_yeardatas( db_session: DbSession, common: CommonParameters, - items_per_page: Optional[int] = Query(5), - search: Optional[str] = Query(None), + params: Annotated[ListQueryParams, Query()], ): """Get all yeardata pagination.""" year_data = await get_all( db_session=db_session, - items_per_page=items_per_page, - search=search, + items_per_page=params.items_per_page, + search=params.search, common=common, ) # return diff --git a/src/yeardata/schema.py b/src/yeardata/schema.py index 039a4e7..778888d 100644 --- a/src/yeardata/schema.py +++ b/src/yeardata/schema.py @@ -3,7 +3,7 @@ from typing import List, Optional from uuid import UUID from pydantic import Field, field_validator -from src.models import DefaultBase, Pagination +from src.models import CommonParams, DefaultBase, Pagination class YeardataBase(DefaultBase): @@ -61,3 +61,18 @@ class YeardataRead(YeardataBase): class YeardataPagination(Pagination): items: List[YeardataRead] = [] + + +class ListQueryParams(CommonParams): + items_per_page: Optional[int] = Field( + default=5, + ge=1, + le=1000, + description="Number of items per page", + alias="itemsPerPage", + ) + search: Optional[str] = Field( + default=None, + description="Search keyword", + ) +