security revision

main^2
MrWaradana 1 week ago
parent 589d5f099f
commit 8fc47edc1b

@ -34,5 +34,14 @@ class ManpowerCostPagination(Pagination):
class QueryParams(CommonParams):
items_per_page: Optional[int] = Field(5)
search: Optional[str] = Field(None)
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

@ -46,8 +46,13 @@ class QueryParams(CommonParams):
description="Simulation identifier",
)
items_per_page: Optional[int] = Field(
5,
default=5,
ge=1,
description="Items per page"
le=1000,
description="Items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)
search: Optional[str] = Field(None)

@ -14,6 +14,37 @@ ALLOWED_MULTI_PARAMS = {
"exclude[]",
}
# Whitelist of ALL allowed query parameter names across the application.
# Any param NOT in this set will be rejected.
ALLOWED_QUERY_PARAMS = {
# CommonParameters (from database/service.py common_parameters)
"currentUser",
"page",
"itemsPerPage",
"q",
"filter",
"sortBy[]",
"descending[]",
"all",
# ListQueryParams / QueryParams used across routers
"items_per_page",
"search",
# equipment_master specific
"parent_id",
# masterdata_simulations / plant_transaction_data_simulations specific
"simulation_id",
# exclude
"exclude[]",
}
# Query params that are ONLY allowed for "write" operations (read operations use ALLOWED_QUERY_PARAMS).
# For GET/POST/PUT/etc, whitelisting still applies.
WRITE_METHOD_ALLOWED_PARAMS = {
# Only auth/session params are allowed in query for write methods.
# Data values (like simulation_id) must be in the JSON body for these methods.
"currentUser",
}
MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
@ -62,31 +93,31 @@ def has_control_chars(value: str) -> bool:
def inspect_value(value: str, source: str):
if XSS_PATTERN.search(value):
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Potential XSS payload detected in {source}",
)
if SQLI_PATTERN.search(value):
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Potential SQL injection payload detected in {source}",
)
if RCE_PATTERN.search(value):
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Potential RCE payload detected in {source}",
)
if TRAVERSAL_PATTERN.search(value):
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Potential Path Traversal payload detected in {source}",
)
if has_control_chars(value):
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Invalid control characters detected in {source}",
)
@ -96,7 +127,7 @@ def inspect_json(obj, path="body"):
for key, value in obj.items():
if key in FORBIDDEN_JSON_KEYS:
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Forbidden JSON key detected: {path}.{key}",
)
inspect_json(value, f"{path}.{key}")
@ -126,12 +157,28 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
if len(params) > MAX_QUERY_PARAMS:
raise HTTPException(
status_code=400,
status_code=422,
detail="Too many query parameters",
)
# -------------------------
# 2. Duplicate parameters
# 2. Query param whitelist
# -------------------------
# For GET, we allow data parameters like page, search, etc.
# For POST, PUT, DELETE, PATCH, we ONLY allow auth/session params.
active_whitelist = ALLOWED_QUERY_PARAMS if request.method == "GET" else WRITE_METHOD_ALLOWED_PARAMS
unknown_params = [
key for key, _ in params if key not in active_whitelist
]
if unknown_params:
raise HTTPException(
status_code=422,
detail=f"Unknown query parameters are not allowed for {request.method} request: {unknown_params}",
)
# -------------------------
# 3. Duplicate parameters
# -------------------------
counter = Counter(key for key, _ in params)
duplicates = [
@ -141,12 +188,40 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
if duplicates:
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Duplicate query parameters are not allowed: {duplicates}",
)
# -------------------------
# 3. Query param inspection
# 4. Single source enforcement
# Ensuring data comes from ONLY one source (Query OR Body).
# -------------------------
content_type = request.headers.get("content-type", "")
has_json_body = content_type.startswith("application/json")
# Check for data parameters in query (anything whitelisted as 'data' but not 'session/auth')
data_params_in_query = [
key for key, _ in params
if key in ALLOWED_QUERY_PARAMS and key not in WRITE_METHOD_ALLOWED_PARAMS
]
if has_json_body:
# If sending JSON body, we forbid any data in query string (one source only)
if data_params_in_query:
raise HTTPException(
status_code=422,
detail=f"Single source enforcement: Data received from both JSON body and query string ({data_params_in_query}). Use only one source.",
)
# Special case: GET with body is discouraged/forbidden in many strict security contexts
if request.method == "GET":
raise HTTPException(
status_code=422,
detail="GET requests must use query parameters, not JSON body.",
)
# -------------------------
# 5. Query param inspection
# -------------------------
pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}
for key, value in params:
@ -159,24 +234,23 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
size_val = int(value)
if size_val > 50:
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Pagination size '{key}' cannot exceed 50",
)
if size_val % 5 != 0:
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Pagination size '{key}' must be a multiple of 5",
)
except ValueError:
raise HTTPException(
status_code=400,
status_code=422,
detail=f"Pagination size '{key}' must be an integer",
)
# -------------------------
# 4. Content-Type sanity
# 6. Content-Type sanity
# -------------------------
content_type = request.headers.get("content-type", "")
if content_type and not any(
content_type.startswith(t)
for t in (
@ -191,7 +265,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
)
# -------------------------
# 5. JSON body inspection
# 7. JSON body inspection
# -------------------------
if content_type.startswith("application/json"):
body = await request.body()
@ -207,7 +281,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
payload = json.loads(body)
except json.JSONDecodeError:
raise HTTPException(
status_code=400,
status_code=422,
detail="Invalid JSON body",
)

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Annotated, List, Optional
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query, status
@ -16,6 +16,7 @@ from .schema import (
PlantFSTransactionDataRead,
PlantFSTransactionDataUpdate,
PlantFSChartData,
ListQueryParams,
)
from .service import create, delete, get, get_all, update, update_fs_charts_from_matrix, get_charts
@ -28,15 +29,14 @@ router = APIRouter()
async def list_fs_transactions(
db_session: DbSession,
common: CommonParameters,
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
params: Annotated[ListQueryParams, Query()],
):
"""Return paginated financial statement transaction data."""
records = await get_all(
db_session=db_session,
items_per_page=items_per_page,
search=search,
items_per_page=params.items_per_page,
search=params.search,
common=common,
)
@ -166,8 +166,3 @@ async def delete_fs_transaction(
await delete(db_session=db_session, fs_transaction_id=str(fs_transaction_id))
return StandardResponse(data=record, message="Data deleted successfully")

@ -4,7 +4,7 @@ from uuid import UUID
from pydantic import Field
from src.models import DefaultBase, Pagination
from src.models import CommonParams, DefaultBase, Pagination
class PlantFSTransactionDataBase(DefaultBase):
@ -100,3 +100,18 @@ class PlantFSChartData(DefaultBase):
bep_year: Optional[int] = Field(None, ge=0, le=9999)
bep_total_lcc: Optional[float] = Field(None, ge=0, le=1_000_000_000_000_000)
class ListQueryParams(CommonParams):
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

@ -1,4 +1,4 @@
from typing import Optional
from typing import Annotated, Optional
from fastapi import APIRouter, HTTPException, status, Query
from .model import PlantMasterData
@ -7,6 +7,7 @@ from .schema import (
PlantMasterDataRead,
PlantMasterDataCreate,
PlantMasterDataUpdate,
ListQueryParams,
)
from .service import get, get_all, create, update, delete
@ -22,15 +23,14 @@ router = APIRouter()
async def get_masterdatas(
db_session: DbSession,
common: CommonParameters,
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
params: Annotated[ListQueryParams, Query()],
):
"""Get all documents."""
# return
master_datas = await get_all(
db_session=db_session,
items_per_page=items_per_page,
search=search,
items_per_page=params.items_per_page,
search=params.search,
common=common,
)
return StandardResponse(

@ -3,7 +3,7 @@ from typing import List, Optional
from uuid import UUID
from pydantic import Field
from src.models import DefaultBase, Pagination
from src.models import CommonParams, DefaultBase, Pagination
from src.auth.service import CurrentUser
@ -85,3 +85,18 @@ class PlantMasterDataRead(PlantMasterdataBase):
class PlantMasterDataPagination(Pagination):
items: List[PlantMasterDataRead] = []
class ListQueryParams(CommonParams):
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Annotated, List, Optional
from fastapi import APIRouter, HTTPException, status, Query
from .model import PlantTransactionData
@ -10,6 +10,7 @@ from .schema import (
PlantTransactionDataCreate,
PlantTransactionDataUpdate,
PlantTransactionFSImport,
ListQueryParams,
)
from .service import (
get,
@ -33,14 +34,13 @@ router = APIRouter()
async def get_transaction_datas(
db_session: DbSession,
common: CommonParameters,
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
params: Annotated[ListQueryParams, Query()],
):
"""Get all transaction_data pagination."""
plant_transaction_data = await get_all(
db_session=db_session,
items_per_page=items_per_page,
search=search,
items_per_page=params.items_per_page,
search=params.search,
common=common,
)
# return

@ -3,7 +3,7 @@ from typing import Any, List, Optional
from uuid import UUID
from pydantic import Field
from src.models import DefaultBase, Pagination
from src.models import CommonParams, DefaultBase, Pagination
class PlantTransactionDataBase(DefaultBase):
@ -117,3 +117,18 @@ class PlantTransactionDataRead(PlantTransactionDataBase):
class PlantTransactionDataPagination(Pagination):
items: List[PlantTransactionDataRead] = []
class ListQueryParams(CommonParams):
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Annotated, List, Optional
from uuid import UUID
from fastapi import APIRouter, HTTPException, status, Query
@ -11,6 +11,7 @@ from src.plant_transaction_data_simulations.schema import (
PlantTransactionDataSimulationsCreate,
PlantTransactionDataSimulationsUpdate,
PlantTransactionFSImportSimulations,
ListQueryParams,
)
from src.plant_transaction_data_simulations.service import (
get,
@ -34,17 +35,15 @@ router = APIRouter()
async def get_transaction_datas(
db_session: DbSession,
common: CommonParameters,
simulation_id: UUID = Query(..., description="Simulation identifier"),
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
params: Annotated[ListQueryParams, Query()],
):
"""Get all transaction_data pagination."""
plant_transaction_data = await get_all(
db_session=db_session,
items_per_page=items_per_page,
search=search,
items_per_page=params.items_per_page,
search=params.search,
common=common,
simulation_id=simulation_id,
simulation_id=params.simulation_id,
)
# return
return StandardResponse(

@ -3,7 +3,7 @@ from typing import Any, List, Optional
from uuid import UUID
from pydantic import Field
from src.models import DefaultBase, Pagination
from src.models import CommonParams, DefaultBase, Pagination
class PlantTransactionDataSimulationsBase(DefaultBase):
@ -140,3 +140,22 @@ class PlantTransactionDataSimulationsRead(PlantTransactionDataSimulationsBase):
class PlantTransactionDataSimulationsPagination(Pagination):
items: List[PlantTransactionDataSimulationsRead] = []
class ListQueryParams(CommonParams):
simulation_id: UUID = Field(
...,
description="Simulation identifier",
)
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

@ -1,4 +1,4 @@
from typing import Optional
from typing import Annotated, Optional
from fastapi import APIRouter, HTTPException, Query, status
@ -13,6 +13,7 @@ from src.simulations.schema import (
SimulationRead,
SimulationRunPayload,
SimulationUpdate,
ListQueryParams,
)
from src.simulations.service import create, delete, get, get_all, run_simulation, update
@ -24,13 +25,12 @@ async def get_simulations(
db_session: DbSession,
common: CommonParameters,
current_user: CurrentUser,
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
params: Annotated[ListQueryParams, Query()],
):
simulations = await get_all(
db_session=db_session,
items_per_page=items_per_page,
search=search,
items_per_page=params.items_per_page,
search=params.search,
common=common,
owner=current_user.name,
)

@ -4,7 +4,7 @@ from uuid import UUID
from pydantic import Field
from src.models import DefaultBase, Pagination
from src.models import CommonParams, DefaultBase, Pagination
from src.masterdata_simulations.schema import MasterDataSimulationRead
from src.plant_transaction_data_simulations.schema import (
PlantTransactionDataSimulationsRead,
@ -51,3 +51,18 @@ class MasterDataOverride(DefaultBase):
class SimulationRunPayload(DefaultBase):
label: Optional[str] = Field(None)
overrides: List[MasterDataOverride] = Field(default_factory=list)
class ListQueryParams(CommonParams):
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

@ -1,8 +1,8 @@
from typing import Optional
from typing import Annotated, Optional
from fastapi import APIRouter, Form, HTTPException, status, Query, UploadFile, File
from .model import UploadedFileData
from src.uploaded_file.schema import UploadedFileDataCreate, UploadedFileDataUpdate, UploadedFileDataRead, UploadedFileDataPagination
from src.uploaded_file.schema import UploadedFileDataCreate, UploadedFileDataUpdate, UploadedFileDataRead, UploadedFileDataPagination, ListQueryParams
from src.uploaded_file.service import get, get_all, create, update, delete
from src.database.service import CommonParameters, search_filter_sort_paginate
@ -20,14 +20,13 @@ router = APIRouter()
async def get_uploaded_files(
db_session: DbSession,
common: CommonParameters,
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
params: Annotated[ListQueryParams, Query()],
):
"""Get all uploaded files pagination."""
uploaded_files = await get_all(
db_session=db_session,
items_per_page=items_per_page,
search=search,
items_per_page=params.items_per_page,
search=params.search,
common=common,
)
# return

@ -3,7 +3,7 @@ from typing import List, Optional
from uuid import UUID
from pydantic import Field
from src.models import DefaultBase, Pagination
from src.models import CommonParams, DefaultBase, Pagination
class UploadedFileDataBase(DefaultBase):
filename: str = Field(...)
@ -28,3 +28,17 @@ class UploadedFileDataRead(UploadedFileDataBase):
class UploadedFileDataPagination(Pagination):
items: List[UploadedFileDataRead] = []
class ListQueryParams(CommonParams):
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

@ -1,8 +1,8 @@
from typing import Optional
from typing import Annotated, Optional
from fastapi import APIRouter, HTTPException, status, Query
from .model import Yeardata
from .schema import YeardataPagination, YeardataRead, YeardataCreate, YeardataUpdate
from .schema import YeardataPagination, YeardataRead, YeardataCreate, YeardataUpdate, ListQueryParams
from .service import get, get_all, create, update, delete
from src.database.service import CommonParameters, search_filter_sort_paginate
@ -17,14 +17,13 @@ router = APIRouter()
async def get_yeardatas(
db_session: DbSession,
common: CommonParameters,
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
params: Annotated[ListQueryParams, Query()],
):
"""Get all yeardata pagination."""
year_data = await get_all(
db_session=db_session,
items_per_page=items_per_page,
search=search,
items_per_page=params.items_per_page,
search=params.search,
common=common,
)
# return

@ -3,7 +3,7 @@ from typing import List, Optional
from uuid import UUID
from pydantic import Field, field_validator
from src.models import DefaultBase, Pagination
from src.models import CommonParams, DefaultBase, Pagination
class YeardataBase(DefaultBase):
@ -61,3 +61,18 @@ class YeardataRead(YeardataBase):
class YeardataPagination(Pagination):
items: List[YeardataRead] = []
class ListQueryParams(CommonParams):
items_per_page: Optional[int] = Field(
default=5,
ge=1,
le=1000,
description="Number of items per page",
alias="itemsPerPage",
)
search: Optional[str] = Field(
default=None,
description="Search keyword",
)

Loading…
Cancel
Save