"""
Tests for database security (SQL injection prevention).
"""
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, MagicMock, patch
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add project root to path
|
|
project_root = Path(__file__).parent.parent
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
from src.data.db import DocumentDB
|
|
|
|
|
|
class TestSQLInjectionPrevention:
    """Test that SQL injection attacks are prevented.

    Each test replaces the database connection with a ``MagicMock`` and then
    inspects the arguments handed to ``cursor.execute``: hostile input has to
    travel in the parameter sequence and must never show up in the SQL text.
    """

    @pytest.fixture
    def mock_db(self):
        """Create a ``DocumentDB`` whose ``conn`` attribute is a ``MagicMock``.

        NOTE(review): assumes ``DocumentDB()`` can be constructed without a
        live database -- confirm against ``src.data.db``.
        """
        db = DocumentDB()
        db.conn = MagicMock()
        return db

    @staticmethod
    def _attach_cursor(db):
        """Install and return the mock yielded by ``with db.conn.cursor()``."""
        cursor_mock = MagicMock()
        db.conn.cursor.return_value.__enter__.return_value = cursor_mock
        return cursor_mock

    def test_check_document_status_uses_parameterized_query(self, mock_db):
        """Test that check_document_status uses parameterized query."""
        cursor_mock = self._attach_cursor(mock_db)
        cursor_mock.fetchone.return_value = (True,)

        # Try SQL injection
        malicious_id = "doc123' OR '1'='1"
        mock_db.check_document_status(malicious_id)

        # Verify parameterized query was used
        cursor_mock.execute.assert_called_once()
        query, params = cursor_mock.execute.call_args[0]

        # Should use %s placeholder and pass value as parameter
        assert "%s" in query
        assert malicious_id in params
        # Injection attempt should not be in query string
        assert malicious_id not in query
        assert "OR" not in query

    def test_delete_document_uses_parameterized_query(self, mock_db):
        """Test that delete_document uses parameterized query."""
        cursor_mock = self._attach_cursor(mock_db)

        # Try SQL injection
        malicious_id = "doc123'; DROP TABLE documents; --"
        mock_db.delete_document(malicious_id)

        # Verify parameterized query was used
        cursor_mock.execute.assert_called_once()
        query, params = cursor_mock.execute.call_args[0]

        # Should use %s placeholder and pass value as parameter
        assert "%s" in query
        assert malicious_id in params
        # Injection attempt should not be in query
        assert malicious_id not in query
        assert "DROP TABLE" not in query

    def test_get_document_uses_parameterized_query(self, mock_db):
        """Test that get_document uses parameterized query."""
        cursor_mock = self._attach_cursor(mock_db)
        cursor_mock.fetchone.return_value = None  # No document found

        # Try SQL injection
        malicious_id = "doc123' UNION SELECT * FROM users --"
        mock_db.get_document(malicious_id)

        # Verify every issued query uses the parameterized approach
        assert cursor_mock.execute.call_count >= 1
        for call in cursor_mock.execute.call_args_list:
            query = call[0][0]
            # Should use %s placeholder
            assert "%s" in query
            # Injection should not be in query
            assert malicious_id not in query
            assert "UNION" not in query

    def test_get_all_documents_summary_limit_is_safe(self, mock_db):
        """Test that get_all_documents_summary uses parameterized LIMIT."""
        cursor_mock = self._attach_cursor(mock_db)
        cursor_mock.fetchall.return_value = []

        # Try SQL injection via limit parameter
        malicious_limit = "10; DROP TABLE documents; --"

        # Since limit is expected to be int, passing a string should either:
        # 1. Fail type validation
        # 2. Be safely parameterized
        try:
            mock_db.get_all_documents_summary(limit=malicious_limit)
        except (TypeError, ValueError):
            # Expected - type validation should catch this
            pass
        else:
            # Accepted without error: the payload must not have been spliced
            # into the SQL text of any statement that was executed.
            for call in cursor_mock.execute.call_args_list:
                assert malicious_limit not in call[0][0]

        # Test with valid integer limit
        mock_db.get_all_documents_summary(limit=10)

        # Verify parameterized query was used (call_args is the latest call)
        query = cursor_mock.execute.call_args[0][0]

        # Should use %s placeholder for LIMIT
        assert "LIMIT %s" in query or "LIMIT" not in query

    def test_get_failed_matches_uses_parameterized_limit(self, mock_db):
        """Test that get_failed_matches uses parameterized LIMIT."""
        cursor_mock = self._attach_cursor(mock_db)
        cursor_mock.fetchall.return_value = []

        # Call with normal parameters
        mock_db.get_failed_matches(field_name="amount", limit=50)

        # Verify parameterized query
        query, params = cursor_mock.execute.call_args[0]

        # Should use %s placeholder for both field_name and limit
        assert query.count("%s") == 2  # Two parameters
        assert "amount" in params
        assert 50 in params

    def test_check_documents_status_batch_uses_any_array(self, mock_db):
        """Test that batch status check uses ANY(%s) safely."""
        cursor_mock = self._attach_cursor(mock_db)
        cursor_mock.fetchall.return_value = []

        # Try with potentially malicious IDs
        malicious_ids = [
            "doc1",
            "doc2' OR '1'='1",
            "doc3'; DROP TABLE documents; --",
        ]
        mock_db.check_documents_status_batch(malicious_ids)

        # Verify ANY(%s) pattern is used
        query, params = cursor_mock.execute.call_args[0]

        assert "ANY(%s)" in query
        assert isinstance(params[0], list)
        # Malicious strings should be passed as parameters, not in query
        assert "DROP TABLE" not in query

    def test_get_documents_batch_uses_any_array(self, mock_db):
        """Test that get_documents_batch uses ANY(%s) safely."""
        cursor_mock = self._attach_cursor(mock_db)
        cursor_mock.fetchall.return_value = []

        # Try with potentially malicious IDs
        malicious_ids = ["doc1", "doc2' UNION SELECT * FROM users --"]
        mock_db.get_documents_batch(malicious_ids)

        # Guard against a vacuous pass: the loop below checks nothing if no
        # statement was executed at all.
        assert cursor_mock.execute.call_count >= 1

        # Verify every issued query uses the ANY(%s) pattern
        for call in cursor_mock.execute.call_args_list:
            query = call[0][0]
            assert "ANY(%s)" in query
            assert "UNION" not in query
|
|
|
|
|
|
class TestInputValidation:
    """Test input validation and type safety.

    The database connection is replaced with a ``MagicMock`` so the tests can
    observe whether, and with which SQL text, ``cursor.execute`` was invoked.
    """

    @pytest.fixture
    def mock_db(self):
        """Create a ``DocumentDB`` whose ``conn`` attribute is a ``MagicMock``.

        NOTE(review): assumes ``DocumentDB()`` can be constructed without a
        live database -- confirm against ``src.data.db``.
        """
        db = DocumentDB()
        db.conn = MagicMock()
        return db

    def test_limit_parameter_type_validation(self, mock_db):
        """Test that limit parameter expects integer."""
        cursor_mock = MagicMock()
        mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock
        cursor_mock.fetchall.return_value = []

        # Valid integer should work
        mock_db.get_all_documents_summary(limit=10)
        assert cursor_mock.execute.called

        # String should either raise error or be safely handled
        # (Type hints suggest int, runtime may vary)
        cursor_mock.reset_mock()  # forget the call recorded above
        try:
            mock_db.get_all_documents_summary(limit="malicious")
        except (TypeError, ValueError):
            # Expected - type validation
            pass
        else:
            # It did not raise: if a statement was issued, verify it was
            # parameterized rather than built from the raw value.
            call_args = cursor_mock.execute.call_args
            if call_args is not None:
                query = call_args[0][0]
                assert "%s" in query or "LIMIT" not in query

    def test_doc_id_list_validation(self, mock_db):
        """Test that document ID lists are properly validated."""
        cursor_mock = MagicMock()
        mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock

        # Empty list should be handled gracefully, without touching the DB
        result = mock_db.get_documents_batch([])
        assert result == {}
        assert not cursor_mock.execute.called

        # Valid list should work
        cursor_mock.fetchall.return_value = []
        mock_db.get_documents_batch(["doc1", "doc2"])
        assert cursor_mock.execute.called
|