You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
90 lines
2.7 KiB
Python
90 lines
2.7 KiB
Python
import pytest
|
|
from fastapi import HTTPException
|
|
from src.middleware import (
|
|
inspect_value,
|
|
inspect_json,
|
|
has_control_chars,
|
|
XSS_PATTERN,
|
|
SQLI_PATTERN,
|
|
RCE_PATTERN,
|
|
TRAVERSAL_PATTERN
|
|
)
|
|
|
|
def test_xss_patterns():
    """XSS_PATTERN must match every classic cross-site-scripting payload below."""
    xss_samples = (
        "<script>alert(1)</script>",
        "<img src=x onerror=alert(1)>",
        "javascript:alert(1)",
        "<iframe src='javascript:alert(1)'>",
        "onclick=alert(1)",
    )
    # Gather every sample the pattern fails to match; an empty list means all were caught.
    missed = [sample for sample in xss_samples if XSS_PATTERN.search(sample) is None]
    assert missed == []
|
|
|
|
def test_sqli_patterns():
    """SQLI_PATTERN must match every common SQL-injection payload below."""
    sqli_samples = (
        "UNION SELECT",
        "OR '1'='1'",
        "DROP TABLE users",
        "';--",
        "WAITFOR DELAY '0:0:5'",
        "INFORMATION_SCHEMA.TABLES",
    )
    # Gather every sample the pattern fails to match; an empty list means all were caught.
    missed = [sample for sample in sqli_samples if SQLI_PATTERN.search(sample) is None]
    assert missed == []
|
|
|
|
def test_rce_patterns():
    """RCE_PATTERN must match every command-injection / sensitive-path payload below."""
    rce_samples = (
        "$(whoami)",
        "`id`",
        "; cat /etc/passwd",
        "| ls -la",
        "/etc/shadow",
        "C:\\Windows\\System32",
    )
    # Gather every sample the pattern fails to match; an empty list means all were caught.
    missed = [sample for sample in rce_samples if RCE_PATTERN.search(sample) is None]
    assert missed == []
|
|
|
|
def test_traversal_patterns():
    """TRAVERSAL_PATTERN must match plain, backslash and percent-encoded path traversal."""
    traversal_samples = (
        "../../etc/passwd",
        "..\\windows",
        "%2e%2e%2f",
    )
    # Gather every sample the pattern fails to match; an empty list means all were caught.
    missed = [
        sample for sample in traversal_samples if TRAVERSAL_PATTERN.search(sample) is None
    ]
    assert missed == []
|
|
|
|
def test_inspect_value_raises():
    """inspect_value must reject malicious input with an HTTP 400 naming the attack kind."""
    # Each case pairs a hostile value with a fragment expected in the exception detail.
    cases = (
        ("<script>", "Potential XSS payload"),
        ("UNION SELECT", "Potential SQL injection"),
    )
    for hostile_value, expected_fragment in cases:
        with pytest.raises(HTTPException) as caught:
            inspect_value(hostile_value, "source")
        error = caught.value
        assert error.status_code == 400
        assert expected_fragment in error.detail
|
|
|
|
def test_inspect_json_raises():
    """inspect_json must reject forbidden keys and malicious values nested in a document."""
    # Each case pairs a JSON-like document with a fragment expected in the exception detail.
    cases = (
        # A top-level `__proto__` key, the classic prototype-pollution vector.
        ({"__proto__": "polluted"}, "Forbidden JSON key"),
        # A script payload one level deep, checking nested values are inspected too.
        ({"data": {"nested": "<script>"}}, "Potential XSS payload"),
    )
    for document, expected_fragment in cases:
        with pytest.raises(HTTPException) as caught:
            inspect_json(document)
        error = caught.value
        assert error.status_code == 400
        assert expected_fragment in error.detail
|
|
|
|
def test_has_control_chars():
    """has_control_chars flags control characters such as NUL but tolerates common whitespace."""
    assert has_control_chars("normal string") is False
    assert has_control_chars("string with \x00 null") is True

    # Newlines, tabs, and carriage returns are specifically allowed in has_control_chars,
    # so every one of them (not just the newline) must be reported as clean.
    assert has_control_chars("string with \n newline") is False
    assert has_control_chars("string with \t tab") is False
    assert has_control_chars("string with \r carriage return") is False
|