""" Tests for database security (SQL injection prevention). """ import pytest from unittest.mock import Mock, MagicMock, patch import sys from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) from shared.data.db import DocumentDB class TestSQLInjectionPrevention: """Test that SQL injection attacks are prevented.""" @pytest.fixture def mock_db(self): """Create a mock database connection.""" db = DocumentDB() db.conn = MagicMock() return db def test_check_document_status_uses_parameterized_query(self, mock_db): """Test that check_document_status uses parameterized query.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock cursor_mock.fetchone.return_value = (True,) # Try SQL injection malicious_id = "doc123' OR '1'='1" mock_db.check_document_status(malicious_id) # Verify parameterized query was used cursor_mock.execute.assert_called_once() call_args = cursor_mock.execute.call_args query = call_args[0][0] params = call_args[0][1] # Should use %s placeholder and pass value as parameter assert "%s" in query assert malicious_id in params assert "OR" not in query # Injection attempt should not be in query string def test_delete_document_uses_parameterized_query(self, mock_db): """Test that delete_document uses parameterized query.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock # Try SQL injection malicious_id = "doc123'; DROP TABLE documents; --" mock_db.delete_document(malicious_id) # Verify parameterized query was used cursor_mock.execute.assert_called_once() call_args = cursor_mock.execute.call_args query = call_args[0][0] params = call_args[0][1] # Should use %s placeholder assert "%s" in query assert "DROP TABLE" not in query # Injection attempt should not be in query def test_get_document_uses_parameterized_query(self, mock_db): """Test that get_document uses parameterized query.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock cursor_mock.fetchone.return_value = None # No document found # Try SQL injection malicious_id = "doc123' UNION SELECT * FROM users --" mock_db.get_document(malicious_id) # Verify both queries use parameterized approach assert cursor_mock.execute.call_count >= 1 for call in cursor_mock.execute.call_args_list: query = call[0][0] # Should use %s placeholder assert "%s" in query assert "UNION" not in query # Injection should not be in query def test_get_all_documents_summary_limit_is_safe(self, mock_db): """Test that get_all_documents_summary uses parameterized LIMIT.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock cursor_mock.fetchall.return_value = [] # Try SQL injection via limit parameter malicious_limit = "10; DROP TABLE documents; --" # This should raise error or be safely handled # Since limit is expected to be int, passing string should either: # 1. Fail type validation # 2. Be safely parameterized try: mock_db.get_all_documents_summary(limit=malicious_limit) except Exception: # Expected - type validation should catch this pass # Test with valid integer limit mock_db.get_all_documents_summary(limit=10) # Verify parameterized query was used call_args = cursor_mock.execute.call_args query = call_args[0][0] # Should use %s placeholder for LIMIT assert "LIMIT %s" in query or "LIMIT" not in query def test_get_failed_matches_uses_parameterized_limit(self, mock_db): """Test that get_failed_matches uses parameterized LIMIT.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock cursor_mock.fetchall.return_value = [] # Call with normal parameters mock_db.get_failed_matches(field_name="amount", limit=50) # Verify parameterized query call_args = cursor_mock.execute.call_args query = call_args[0][0] params = call_args[0][1] # Should use %s placeholder for both field_name and limit assert query.count("%s") == 2 # Two parameters assert "amount" in params assert 50 in params def test_check_documents_status_batch_uses_any_array(self, mock_db): """Test that batch status check uses ANY(%s) safely.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock cursor_mock.fetchall.return_value = [] # Try with potentially malicious IDs malicious_ids = [ "doc1", "doc2' OR '1'='1", "doc3'; DROP TABLE documents; --" ] mock_db.check_documents_status_batch(malicious_ids) # Verify ANY(%s) pattern is used call_args = cursor_mock.execute.call_args query = call_args[0][0] params = call_args[0][1] assert "ANY(%s)" in query assert isinstance(params[0], list) # Malicious strings should be passed as parameters, not in query assert "DROP TABLE" not in query def test_get_documents_batch_uses_any_array(self, mock_db): """Test that get_documents_batch uses ANY(%s) safely.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock cursor_mock.fetchall.return_value = [] # Try with potentially malicious IDs malicious_ids = ["doc1", "doc2' UNION SELECT * FROM users --"] mock_db.get_documents_batch(malicious_ids) # Verify both queries use ANY(%s) pattern for call in cursor_mock.execute.call_args_list: query = call[0][0] assert "ANY(%s)" in query assert "UNION" not in query class TestInputValidation: """Test input validation and type safety.""" @pytest.fixture def mock_db(self): """Create a mock database connection.""" db = DocumentDB() db.conn = MagicMock() return db def test_limit_parameter_type_validation(self, mock_db): """Test that limit parameter expects integer.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock cursor_mock.fetchall.return_value = [] # Valid integer should work mock_db.get_all_documents_summary(limit=10) assert cursor_mock.execute.called # String should either raise error or be safely handled # (Type hints suggest int, runtime may vary) cursor_mock.reset_mock() try: result = mock_db.get_all_documents_summary(limit="malicious") # If it doesn't raise, verify it was parameterized call_args = cursor_mock.execute.call_args if call_args: query = call_args[0][0] assert "%s" in query or "LIMIT" not in query except (TypeError, ValueError): # Expected - type validation pass def test_doc_id_list_validation(self, mock_db): """Test that document ID lists are properly validated.""" cursor_mock = MagicMock() mock_db.conn.cursor.return_value.__enter__.return_value = cursor_mock # Empty list should be handled gracefully result = mock_db.get_documents_batch([]) assert result == {} assert not cursor_mock.execute.called # Valid list should work cursor_mock.fetchall.return_value = [] mock_db.get_documents_batch(["doc1", "doc2"]) assert cursor_mock.execute.called