WIP
This commit is contained in:
1
tests/shared/fields/__init__.py
Normal file
1
tests/shared/fields/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for shared.fields module."""
|
||||
200
tests/shared/fields/test_field_config.py
Normal file
200
tests/shared/fields/test_field_config.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for field configuration - Single Source of Truth.
|
||||
|
||||
These tests ensure consistency across all field definitions and prevent
|
||||
accidental changes that could break model inference.
|
||||
|
||||
CRITICAL: These tests verify that field definitions match the trained YOLO model.
|
||||
If these tests fail, it likely means someone modified field IDs incorrectly.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.fields import (
|
||||
FIELD_DEFINITIONS,
|
||||
CLASS_NAMES,
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
NUM_CLASSES,
|
||||
FieldDefinition,
|
||||
)
|
||||
|
||||
|
||||
class TestFieldDefinitionsIntegrity:
    """Structural checks on FIELD_DEFINITIONS: size, ID range, uniqueness, immutability."""

    def test_exactly_10_field_definitions(self):
        """FIELD_DEFINITIONS has 10 entries and NUM_CLASSES reports 10 as well."""
        definition_count = len(FIELD_DEFINITIONS)
        assert definition_count == 10
        assert NUM_CLASSES == 10

    def test_class_ids_are_sequential(self):
        """The set of class IDs is exactly {0, ..., 9}, i.e. no gaps."""
        seen_ids = set()
        for definition in FIELD_DEFINITIONS:
            seen_ids.add(definition.class_id)
        assert seen_ids == set(range(10))

    def test_class_ids_are_unique(self):
        """No class ID appears on more than one definition."""
        all_ids = [definition.class_id for definition in FIELD_DEFINITIONS]
        distinct_ids = set(all_ids)
        assert len(all_ids) == len(distinct_ids)

    def test_class_names_are_unique(self):
        """No class name appears on more than one definition."""
        all_names = [definition.class_name for definition in FIELD_DEFINITIONS]
        distinct_names = set(all_names)
        assert len(all_names) == len(distinct_names)

    def test_field_definition_is_immutable(self):
        """Assigning to class_id on an existing definition raises AttributeError."""
        first_definition = FIELD_DEFINITIONS[0]
        with pytest.raises(AttributeError):
            first_definition.class_id = 99  # type: ignore
|
||||
|
||||
|
||||
class TestModelCompatibility:
    """Pin the field definitions to the class names of the trained YOLO model.

    The expected values are documented as being read from
    runs/train/invoice_fields/weights/best.pt and MUST NOT be changed
    without retraining the model.
    """

    # Expected model.names from best.pt - DO NOT CHANGE
    EXPECTED_MODEL_NAMES = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_org_number",
        8: "customer_number",
        9: "payment_line",
    }

    def test_field_classes_match_model(self):
        """CRITICAL: FIELD_CLASSES must equal the expected id -> name mapping."""
        reference = self.EXPECTED_MODEL_NAMES
        assert FIELD_CLASSES == reference

    def test_class_names_order_matches_model(self):
        """CRITICAL: CLASS_NAMES must list the names in class-ID order 0..9."""
        reference = self.EXPECTED_MODEL_NAMES
        expected_order = []
        for class_id in range(10):
            expected_order.append(reference[class_id])
        assert CLASS_NAMES == expected_order

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9), in both lookup directions."""
        assert FIELD_CLASSES[8] == "customer_number"
        assert FIELD_CLASS_IDS["customer_number"] == 8

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8), in both lookup directions."""
        assert FIELD_CLASSES[9] == "payment_line"
        assert FIELD_CLASS_IDS["payment_line"] == 9
|
||||
|
||||
|
||||
class TestMappingConsistency:
    """Cross-checks between the lookup tables exported by shared.fields."""

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """FIELD_CLASSES (id -> name) and FIELD_CLASS_IDS (name -> id) undo each other."""
        # Checking both directions means an entry missing from either table fails.
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id, (
                f"FIELD_CLASS_IDS[{class_name!r}] does not map back to {class_id}"
            )

        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name, (
                f"FIELD_CLASSES[{class_id}] does not map back to {class_name!r}"
            )

    def test_class_names_matches_field_classes_values(self):
        """CLASS_NAMES lists the FIELD_CLASSES values in class-ID order, with nothing missing."""
        # Without the length check an empty or truncated CLASS_NAMES would pass
        # the loop below vacuously.
        assert len(CLASS_NAMES) == len(FIELD_CLASSES)
        for i, class_name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[i] == class_name, (
                f"CLASS_NAMES[{i}] is {class_name!r}, FIELD_CLASSES[{i}] is {FIELD_CLASSES[i]!r}"
            )

    def test_class_to_field_has_all_classes(self):
        """Every class name in CLASS_NAMES has an entry in CLASS_TO_FIELD."""
        missing = [name for name in CLASS_NAMES if name not in CLASS_TO_FIELD]
        assert not missing, f"CLASS_TO_FIELD is missing: {missing}"

    def test_csv_mapping_excludes_derived_fields(self):
        """CSV_TO_CLASS_MAPPING holds every non-derived field and no derived one."""
        # payment_line is derived, should not be in CSV mapping
        assert "payment_line" not in CSV_TO_CLASS_MAPPING

        # Apply the same rule to every definition rather than only to the one
        # derived field known today.
        for fd in FIELD_DEFINITIONS:
            if fd.is_derived:
                assert fd.field_name not in CSV_TO_CLASS_MAPPING, (
                    f"derived field {fd.field_name!r} must not be in CSV mapping"
                )
            else:
                assert fd.field_name in CSV_TO_CLASS_MAPPING, (
                    f"non-derived field {fd.field_name!r} missing from CSV mapping"
                )

    def test_training_field_classes_includes_all(self):
        """TRAINING_FIELD_CLASSES maps every field name, derived or not, to its class ID."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in TRAINING_FIELD_CLASSES, (
                f"{fd.field_name!r} missing from TRAINING_FIELD_CLASSES"
            )
            assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id, (
                f"{fd.field_name!r} maps to the wrong class ID"
            )
|
||||
|
||||
|
||||
class TestSpecificFieldDefinitions:
    """Spot checks on individual definitions to catch common mistakes."""

    @pytest.mark.parametrize(
        "class_id,expected_class_name",
        [
            (0, "invoice_number"),
            (1, "invoice_date"),
            (2, "invoice_due_date"),
            (3, "ocr_number"),
            (4, "bankgiro"),
            (5, "plusgiro"),
            (6, "amount"),
            (7, "supplier_org_number"),
            (8, "customer_number"),
            (9, "payment_line"),
        ],
    )
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """FIELD_CLASSES returns the expected name for the given class ID."""
        actual_name = FIELD_CLASSES[class_id]
        assert actual_name == expected_class_name

    def test_payment_line_is_derived(self):
        """The definition whose class name is payment_line has is_derived set to True."""
        candidates = (
            definition
            for definition in FIELD_DEFINITIONS
            if definition.class_name == "payment_line"
        )
        payment_line_def = next(candidates)
        assert payment_line_def.is_derived is True

    def test_other_fields_are_not_derived(self):
        """Every definition other than payment_line has is_derived set to False."""
        for definition in FIELD_DEFINITIONS:
            if definition.class_name == "payment_line":
                continue
            assert definition.is_derived is False, (
                f"{definition.class_name} should not be derived"
            )
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
    """Guard the mapping contents that existing code relies on."""

    def test_csv_to_class_mapping_field_names(self):
        """CSV_TO_CLASS_MAPPING equals the expected CSV field name -> class ID table."""
        # Field names used in CSV files, listed in class-ID order (0..8).
        # payment_line (9) is derived, not in CSV
        csv_field_names = [
            "InvoiceNumber",
            "InvoiceDate",
            "InvoiceDueDate",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "supplier_organisation_number",
            "customer_number",
        ]
        expected_fields = {
            field_name: class_id for class_id, field_name in enumerate(csv_field_names)
        }
        assert CSV_TO_CLASS_MAPPING == expected_fields

    def test_class_to_field_returns_field_names(self):
        """CLASS_TO_FIELD resolves a sample of class names to their field names."""
        # Sample checks for key fields
        sample_pairs = (
            ("invoice_number", "InvoiceNumber"),
            ("invoice_date", "InvoiceDate"),
            ("ocr_number", "OCR"),
            ("customer_number", "customer_number"),
            ("payment_line", "payment_line"),
        )
        for class_name, field_name in sample_pairs:
            assert CLASS_TO_FIELD[class_name] == field_name
|
||||
1
tests/shared/storage/__init__.py
Normal file
1
tests/shared/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests for storage module
|
||||
718
tests/shared/storage/test_azure.py
Normal file
718
tests/shared/storage/test_azure.py
Normal file
@@ -0,0 +1,718 @@
|
||||
"""
|
||||
Tests for AzureBlobStorageBackend.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
Uses mocking to avoid requiring actual Azure credentials.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def mock_blob_service_client() -> MagicMock:
    """Provide a fresh MagicMock standing in for a BlobServiceClient."""
    service_client = MagicMock()
    return service_client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_container_client(mock_blob_service_client: MagicMock) -> MagicMock:
    """Provide a container mock that the service mock's get_container_client() returns."""
    stub_container = MagicMock()
    get_container = mock_blob_service_client.get_container_client
    get_container.return_value = stub_container
    return stub_container
|
||||
|
||||
|
||||
@pytest.fixture
def mock_blob_client(mock_container_client: MagicMock) -> MagicMock:
    """Provide a blob mock that the container mock's get_blob_client() returns."""
    stub_blob = MagicMock()
    get_blob = mock_container_client.get_blob_client
    get_blob.return_value = stub_blob
    return stub_blob
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendCreation:
    """Constructor tests for AzureBlobStorageBackend with BlobServiceClient patched out."""

    @staticmethod
    def _container_for(service_cls: MagicMock) -> MagicMock:
        """Return the container mock reached via from_connection_string().get_container_client()."""
        service = service_cls.from_connection_string.return_value
        return service.get_container_client.return_value

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_with_connection_string(
        self, mock_service_class: MagicMock
    ) -> None:
        """The connection string is forwarded to the client factory and the name is kept."""
        from shared.storage.azure import AzureBlobStorageBackend

        connection_string = "DefaultEndpointsProtocol=https;AccountName=test;..."
        storage = AzureBlobStorageBackend(
            connection_string=connection_string,
            container_name="training-images",
        )

        assert storage.container_name == "training-images"
        mock_service_class.from_connection_string.assert_called_once_with(
            connection_string
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_creates_container_if_not_exists(
        self, mock_service_class: MagicMock
    ) -> None:
        """With create_container=True and a missing container, create_container() runs once."""
        from shared.storage.azure import AzureBlobStorageBackend

        container = self._container_for(mock_service_class)
        container.exists.return_value = False

        AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="new-container",
            create_container=True,
        )

        container.create_container.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_does_not_create_container_by_default(
        self, mock_service_class: MagicMock
    ) -> None:
        """Without create_container, an existing container is never (re)created."""
        from shared.storage.azure import AzureBlobStorageBackend

        container = self._container_for(mock_service_class)
        container.exists.return_value = True

        AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="existing-container",
        )

        container.create_container.assert_not_called()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_is_storage_backend_subclass(
        self, mock_service_class: MagicMock
    ) -> None:
        """An AzureBlobStorageBackend instance is also a StorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageBackend

        storage = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        assert isinstance(storage, StorageBackend)
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendUpload:
    """Tests for AzureBlobStorageBackend.upload with the Azure client mocked."""

    @staticmethod
    def _blob_mocks(service_cls: MagicMock) -> tuple:
        """Return the (container, blob) mocks the patched service class hands out."""
        service = service_cls.from_connection_string.return_value
        container = service.get_container_client.return_value
        return container, container.get_blob_client.return_value

    @staticmethod
    def _write_temp_file(payload: bytes) -> Path:
        """Write payload to a persistent .txt temp file; the caller must unlink it."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as handle:
            handle.write(payload)
            return Path(handle.name)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_file(self, mock_service_class: MagicMock) -> None:
        """Uploading to a free name returns the remote path and uploads the blob once."""
        from shared.storage.azure import AzureBlobStorageBackend

        container, blob = self._blob_mocks(mock_service_class)
        blob.exists.return_value = False

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        source_file = self._write_temp_file(b"Hello, World!")
        try:
            remote_path = backend.upload(source_file, "uploads/sample.txt")

            assert remote_path == "uploads/sample.txt"
            container.get_blob_client.assert_called_with("uploads/sample.txt")
            blob.upload_blob.assert_called_once()
        finally:
            source_file.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_fails_if_blob_exists_without_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """With overwrite=False, an existing blob makes upload raise StorageError."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageError

        _, blob = self._blob_mocks(mock_service_class)
        blob.exists.return_value = True

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        source_file = self._write_temp_file(b"content")
        try:
            with pytest.raises(StorageError, match="already exists"):
                backend.upload(source_file, "existing.txt", overwrite=False)
        finally:
            source_file.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_succeeds_with_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """With overwrite=True, an existing blob is uploaded over and overwrite is forwarded."""
        from shared.storage.azure import AzureBlobStorageBackend

        _, blob = self._blob_mocks(mock_service_class)
        blob.exists.return_value = True

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        source_file = self._write_temp_file(b"content")
        try:
            remote_path = backend.upload(source_file, "existing.txt", overwrite=True)

            assert remote_path == "existing.txt"
            blob.upload_blob.assert_called_once()
            # Check overwrite=True was passed
            _, upload_kwargs = blob.upload_blob.call_args
            assert upload_kwargs.get("overwrite") is True
        finally:
            source_file.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_nonexistent_file_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """Uploading a local path that does not exist raises FileNotFoundStorageError."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import FileNotFoundStorageError

        # No blob-level configuration here: the patched class supplies default mocks.
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        missing_file = Path("/nonexistent/file.txt")
        with pytest.raises(FileNotFoundStorageError):
            backend.upload(missing_file, "sample.txt")
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendDownload:
    """Tests for AzureBlobStorageBackend.download with the Azure client mocked."""

    @staticmethod
    def _wire_blob(service_cls: MagicMock, *, exists: bool, payload: bytes = b"") -> MagicMock:
        """Configure and return the blob mock the patched service class hands out.

        The blob reports ``exists`` from exists() and serves ``payload`` from
        download_blob().readall().
        """
        service = service_cls.from_connection_string.return_value
        blob = service.get_container_client.return_value.get_blob_client.return_value
        blob.exists.return_value = exists
        blob.download_blob.return_value.readall.return_value = payload
        return blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_file(self, mock_service_class: MagicMock) -> None:
        """Downloading an existing blob writes its bytes to, and returns, the local path."""
        from shared.storage.azure import AzureBlobStorageBackend

        self._wire_blob(mock_service_class, exists=True, payload=b"Hello, World!")

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir) / "downloaded.txt"
            result = backend.download("remote/sample.txt", local_path)

            assert result == local_path
            assert local_path.exists()
            assert local_path.read_bytes() == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_creates_parent_directories(
        self, mock_service_class: MagicMock
    ) -> None:
        """Downloading to a path whose parent directories are missing still creates the file."""
        from shared.storage.azure import AzureBlobStorageBackend

        self._wire_blob(mock_service_class, exists=True, payload=b"content")

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir) / "deep" / "nested" / "downloaded.txt"
            backend.download("sample.txt", local_path)

            assert local_path.exists()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_nonexistent_blob_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """Downloading a missing blob raises FileNotFoundStorageError naming the blob."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import FileNotFoundStorageError

        self._wire_blob(mock_service_class, exists=False)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        # Target a throwaway directory instead of a fixed /tmp path so the test
        # is portable and cannot touch a shared file if download misbehaves.
        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir) / "file.txt"
            with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
                backend.download("nonexistent.txt", local_path)
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendExists:
    """Tests for AzureBlobStorageBackend.exists with the Azure client mocked."""

    @staticmethod
    def _wire_blob(service_cls: MagicMock, *, exists: bool) -> MagicMock:
        """Make the blob mock handed out by the patched service class report ``exists``."""
        service = service_cls.from_connection_string.return_value
        blob = service.get_container_client.return_value.get_blob_client.return_value
        blob.exists.return_value = exists
        return blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_true_for_existing_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """exists() is True when the blob client reports the blob as present."""
        from shared.storage.azure import AzureBlobStorageBackend

        self._wire_blob(mock_service_class, exists=True)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        found = backend.exists("existing.txt")
        assert found is True

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_false_for_nonexistent_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """exists() is False when the blob client reports the blob as absent."""
        from shared.storage.azure import AzureBlobStorageBackend

        self._wire_blob(mock_service_class, exists=False)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        found = backend.exists("nonexistent.txt")
        assert found is False
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendListFiles:
    """Tests for AzureBlobStorageBackend.list_files with the Azure client mocked."""

    @staticmethod
    def _container_listing(service_cls: MagicMock, blob_names: list) -> MagicMock:
        """Make list_blobs() on the container mock return one item per given name."""
        service = service_cls.from_connection_string.return_value
        container = service.get_container_client.return_value
        items = []
        for blob_name in blob_names:
            item = MagicMock()
            # Assigned after construction: MagicMock(name=...) would name the
            # mock itself instead of setting a .name attribute.
            item.name = blob_name
            items.append(item)
        container.list_blobs.return_value = items
        return container

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_empty_container(
        self, mock_service_class: MagicMock
    ) -> None:
        """An empty container yields an empty list."""
        from shared.storage.azure import AzureBlobStorageBackend

        self._container_listing(mock_service_class, [])

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        listed = backend.list_files("")
        assert listed == []

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_returns_all_blobs(
        self, mock_service_class: MagicMock
    ) -> None:
        """With an empty prefix, every blob name in the container is returned."""
        from shared.storage.azure import AzureBlobStorageBackend

        # Create mock blob items
        blob_names = ["file1.txt", "file2.txt", "subdir/file3.txt"]
        self._container_listing(mock_service_class, blob_names)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        files = backend.list_files("")

        assert len(files) == 3
        for blob_name in blob_names:
            assert blob_name in files

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_with_prefix(
        self, mock_service_class: MagicMock
    ) -> None:
        """The prefix is forwarded to list_blobs as name_starts_with."""
        from shared.storage.azure import AzureBlobStorageBackend

        container = self._container_listing(
            mock_service_class, ["images/a.png", "images/b.png"]
        )

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        files = backend.list_files("images/")

        container.list_blobs.assert_called_with(name_starts_with="images/")
        assert len(files) == 2
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendDelete:
    """Tests for AzureBlobStorageBackend.delete with the Azure client mocked."""

    @staticmethod
    def _wire_blob(service_cls: MagicMock, *, exists: bool) -> MagicMock:
        """Make the blob mock handed out by the patched service class report ``exists``."""
        service = service_cls.from_connection_string.return_value
        blob = service.get_container_client.return_value.get_blob_client.return_value
        blob.exists.return_value = exists
        return blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_delete_existing_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """Deleting a present blob returns True and calls delete_blob exactly once."""
        from shared.storage.azure import AzureBlobStorageBackend

        blob = self._wire_blob(mock_service_class, exists=True)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        deleted = backend.delete("sample.txt")

        assert deleted is True
        blob.delete_blob.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_delete_nonexistent_blob_returns_false(
        self, mock_service_class: MagicMock
    ) -> None:
        """Deleting an absent blob returns False and never calls delete_blob."""
        from shared.storage.azure import AzureBlobStorageBackend

        blob = self._wire_blob(mock_service_class, exists=False)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        deleted = backend.delete("nonexistent.txt")

        assert deleted is False
        blob.delete_blob.assert_not_called()
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendGetUrl:
    """Tests for AzureBlobStorageBackend.get_url with the Azure client mocked."""

    @staticmethod
    def _wire_blob(service_cls: MagicMock, *, exists: bool) -> MagicMock:
        """Make the blob mock handed out by the patched service class report ``exists``."""
        service = service_cls.from_connection_string.return_value
        blob = service.get_container_client.return_value.get_blob_client.return_value
        blob.exists.return_value = exists
        return blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_url_returns_blob_url(
        self, mock_service_class: MagicMock
    ) -> None:
        """For a present blob, get_url returns the blob client's url attribute."""
        from shared.storage.azure import AzureBlobStorageBackend

        blob = self._wire_blob(mock_service_class, exists=True)
        blob.url = "https://account.blob.core.windows.net/container/sample.txt"

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        resolved_url = backend.get_url("sample.txt")

        assert resolved_url == "https://account.blob.core.windows.net/container/sample.txt"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_url_nonexistent_blob_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """For an absent blob, get_url raises FileNotFoundStorageError."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import FileNotFoundStorageError

        self._wire_blob(mock_service_class, exists=False)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.get_url("nonexistent.txt")
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendUploadBytes:
    """Tests for AzureBlobStorageBackend.upload_bytes with the Azure client mocked."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_bytes(self, mock_service_class: MagicMock) -> None:
        """Uploading raw bytes to a free name returns the remote path and uploads once."""
        from shared.storage.azure import AzureBlobStorageBackend

        # Configure the mocks the patched service class will hand to the backend.
        service = mock_service_class.from_connection_string.return_value
        blob = service.get_container_client.return_value.get_blob_client.return_value
        blob.exists.return_value = False

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        payload = b"Binary content here"
        remote_path = backend.upload_bytes(payload, "binary.dat")

        assert remote_path == "binary.dat"
        blob.upload_blob.assert_called_once()
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendDownloadBytes:
    """Tests for AzureBlobStorageBackend.download_bytes with the Azure client mocked."""

    @staticmethod
    def _wire_blob(service_cls: MagicMock, *, exists: bool) -> MagicMock:
        """Make the blob mock handed out by the patched service class report ``exists``."""
        service = service_cls.from_connection_string.return_value
        blob = service.get_container_client.return_value.get_blob_client.return_value
        blob.exists.return_value = exists
        return blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes(self, mock_service_class: MagicMock) -> None:
        """For a present blob, download_bytes returns what download_blob().readall() yields."""
        from shared.storage.azure import AzureBlobStorageBackend

        blob = self._wire_blob(mock_service_class, exists=True)
        stream = MagicMock()
        stream.readall.return_value = b"Hello, World!"
        blob.download_blob.return_value = stream

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        received = backend.download_bytes("sample.txt")

        assert received == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes_nonexistent(
        self, mock_service_class: MagicMock
    ) -> None:
        """For an absent blob, download_bytes raises FileNotFoundStorageError."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import FileNotFoundStorageError

        self._wire_blob(mock_service_class, exists=False)

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.download_bytes("nonexistent.txt")
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendBatchOperations:
    """Exercise the directory-level helpers of AzureBlobStorageBackend."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_directory(self, mock_service_class: MagicMock) -> None:
        """Every file under a local tree is uploaded beneath the given prefix."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = mock_service_class.from_connection_string.return_value
        blob_client = service.get_container_client.return_value.get_blob_client.return_value
        # No blob exists yet for any of the uploaded paths.
        blob_client.exists.return_value = False

        storage = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            root = Path(temp_dir)
            nested = root / "subdir"
            (root / "file1.txt").write_text("content1")
            nested.mkdir()
            (nested / "file2.txt").write_text("content2")

            uploaded = storage.upload_directory(root, "uploads/")

            assert len(uploaded) == 2
            assert "uploads/file1.txt" in uploaded
            assert "uploads/subdir/file2.txt" in uploaded

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_directory(self, mock_service_class: MagicMock) -> None:
        """Blobs listed under a prefix are written into the local directory."""
        from shared.storage.azure import AzureBlobStorageBackend

        container = mock_service_class.from_connection_string.return_value.get_container_client.return_value

        # Fake listing: `name` must be assigned after construction, because
        # MagicMock(name=...) would label the mock instead of setting the attribute.
        listing = []
        for blob_name in ("images/a.png", "images/b.png"):
            entry = MagicMock()
            entry.name = blob_name
            listing.append(entry)
        container.list_blobs.return_value = listing

        # A single blob client serves every download request.
        blob_client = container.get_blob_client.return_value
        blob_client.exists.return_value = True
        blob_client.download_blob.return_value.readall.return_value = b"image content"

        storage = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            destination = Path(temp_dir)
            downloaded = storage.download_directory("images/", destination)

            assert len(downloaded) == 2
            # Accept either layout: prefix stripped or prefix preserved.
            stripped = destination / "a.png"
            preserved = destination / "images" / "a.png"
            assert stripped.exists() or preserved.exists()
|
||||
301
tests/shared/storage/test_base.py
Normal file
301
tests/shared/storage/test_base.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Tests for storage base module.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStorageBackendInterface:
    """Tests for the StorageBackend abstract base class.

    Each ``test_subclass_must_implement_*`` test builds a subclass that
    implements every stub method except the one under test, then checks that
    instantiation fails with a ``TypeError`` naming that method. Leaving out
    only one method (and matching on its name) ensures the failure can be
    attributed to that method and not to some other missing implementation.
    """

    @staticmethod
    def _backend_class_without(*omitted: str) -> type:
        """Return a StorageBackend subclass lacking the *omitted* methods.

        Args:
            *omitted: Names of stub methods to leave unimplemented.

        Returns:
            A new subclass of StorageBackend providing trivial implementations
            of all other stub methods.
        """
        from shared.storage.base import StorageBackend

        def upload(
            self, local_path: Path, remote_path: str, overwrite: bool = False
        ) -> str:
            return remote_path

        def download(self, remote_path: str, local_path: Path) -> Path:
            return local_path

        def exists(self, remote_path: str) -> bool:
            return False

        def list_files(self, prefix: str) -> list[str]:
            return []

        def delete(self, remote_path: str) -> bool:
            return True

        def get_url(self, remote_path: str) -> str:
            return ""

        def get_presigned_url(
            self, remote_path: str, expires_in_seconds: int = 3600
        ) -> str:
            return ""

        stubs = (
            upload,
            download,
            exists,
            list_files,
            delete,
            get_url,
            get_presigned_url,
        )
        namespace = {
            stub.__name__: stub for stub in stubs if stub.__name__ not in omitted
        }
        # Build the class through StorageBackend's own metaclass so the
        # abstract-method bookkeeping runs exactly as for a `class` statement.
        return type(StorageBackend)("PartialBackend", (StorageBackend,), namespace)

    def test_cannot_instantiate_directly(self) -> None:
        """Test that StorageBackend cannot be instantiated."""
        from shared.storage.base import StorageBackend

        with pytest.raises(TypeError):
            StorageBackend()  # type: ignore

    def test_is_abstract_base_class(self) -> None:
        """Test that StorageBackend is an ABC."""
        from shared.storage.base import StorageBackend

        assert issubclass(StorageBackend, ABC)

    def test_subclass_must_implement_upload(self) -> None:
        """Test that subclass must implement upload method."""
        incomplete = self._backend_class_without("upload")

        with pytest.raises(TypeError, match="upload"):
            incomplete()

    def test_subclass_must_implement_download(self) -> None:
        """Test that subclass must implement download method."""
        incomplete = self._backend_class_without("download")

        with pytest.raises(TypeError, match="download"):
            incomplete()

    def test_subclass_must_implement_exists(self) -> None:
        """Test that subclass must implement exists method."""
        incomplete = self._backend_class_without("exists")

        with pytest.raises(TypeError, match="exists"):
            incomplete()

    def test_subclass_must_implement_list_files(self) -> None:
        """Test that subclass must implement list_files method."""
        incomplete = self._backend_class_without("list_files")

        with pytest.raises(TypeError, match="list_files"):
            incomplete()

    def test_subclass_must_implement_delete(self) -> None:
        """Test that subclass must implement delete method."""
        incomplete = self._backend_class_without("delete")

        with pytest.raises(TypeError, match="delete"):
            incomplete()

    def test_subclass_must_implement_get_url(self) -> None:
        """Test that subclass must implement get_url method."""
        incomplete = self._backend_class_without("get_url")

        # `get_url\b` keeps the pattern from also matching get_presigned_url-style names.
        with pytest.raises(TypeError, match=r"get_url\b"):
            incomplete()

    def test_valid_subclass_can_be_instantiated(self) -> None:
        """Test that a complete subclass can be instantiated."""
        from shared.storage.base import StorageBackend

        complete = self._backend_class_without()

        backend = complete()
        assert isinstance(backend, StorageBackend)
|
||||
|
||||
|
||||
class TestStorageError:
    """Checks on the StorageError exception type."""

    def test_storage_error_is_exception(self) -> None:
        """StorageError derives from the built-in Exception."""
        from shared.storage.base import StorageError

        derives_from_exception = issubclass(StorageError, Exception)
        assert derives_from_exception

    def test_storage_error_with_message(self) -> None:
        """The message passed to the constructor is what str() returns."""
        from shared.storage.base import StorageError

        assert str(StorageError("Upload failed")) == "Upload failed"

    def test_storage_error_can_be_raised(self) -> None:
        """A raised StorageError is catchable and carries its message."""
        from shared.storage.base import StorageError

        failure = StorageError("test error")
        with pytest.raises(StorageError, match="test error"):
            raise failure
|
||||
|
||||
|
||||
class TestFileNotFoundError:
    """Checks on the FileNotFoundStorageError exception type."""

    def test_file_not_found_is_storage_error(self) -> None:
        """FileNotFoundStorageError belongs to the StorageError hierarchy."""
        from shared.storage.base import FileNotFoundStorageError, StorageError

        is_storage_error = issubclass(FileNotFoundStorageError, StorageError)
        assert is_storage_error

    def test_file_not_found_with_path(self) -> None:
        """The offending path appears in the error's string form."""
        from shared.storage.base import FileNotFoundStorageError

        rendered = str(FileNotFoundStorageError("images/test.png"))
        assert "images/test.png" in rendered
|
||||
|
||||
|
||||
class TestStorageConfig:
    """Checks on the StorageConfig dataclass."""

    def test_storage_config_creation(self) -> None:
        """Explicitly supplied fields are stored unchanged."""
        from shared.storage.base import StorageConfig

        supplied = {
            "backend_type": "azure_blob",
            "connection_string": "DefaultEndpointsProtocol=https;...",
            "container_name": "training-images",
        }

        config = StorageConfig(**supplied)

        for field_name, expected in supplied.items():
            assert getattr(config, field_name) == expected

    def test_storage_config_defaults(self) -> None:
        """Only backend_type is required; the optional fields default to None."""
        from shared.storage.base import StorageConfig

        config = StorageConfig(backend_type="local")

        assert config.backend_type == "local"
        for optional_field in ("connection_string", "container_name", "base_path"):
            assert getattr(config, optional_field) is None

    def test_storage_config_with_base_path(self) -> None:
        """A base_path given for the local backend is kept as passed."""
        from shared.storage.base import StorageConfig

        root = Path("/data/images")
        config = StorageConfig(backend_type="local", base_path=root)

        assert config.backend_type == "local"
        assert config.base_path == Path("/data/images")

    def test_storage_config_immutable(self) -> None:
        """Assigning to a field of an existing StorageConfig raises AttributeError."""
        from shared.storage.base import StorageConfig

        config = StorageConfig(backend_type="local")

        with pytest.raises(AttributeError):
            setattr(config, "backend_type", "azure_blob")
|
||||
348
tests/shared/storage/test_config_loader.py
Normal file
348
tests/shared/storage/test_config_loader.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Tests for storage configuration file loader.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import os
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import patch

import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory and remove it after the test.

    Yields:
        Path to a newly created, empty directory.
    """
    # A distinct local name avoids shadowing the fixture function itself.
    scratch_dir = Path(tempfile.mkdtemp())
    yield scratch_dir
    # Best-effort cleanup: a failure to delete must not fail the test.
    shutil.rmtree(scratch_dir, ignore_errors=True)
|
||||
|
||||
|
||||
class TestEnvVarSubstitution:
    """Tests for environment variable substitution in config values.

    Every test that touches ``os.environ`` does so inside ``patch.dict`` so the
    process environment is restored afterwards and tests cannot leak state
    into one another.
    """

    def test_substitute_simple_env_var(self) -> None:
        """Test substituting a simple environment variable."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"MY_VAR": "my_value"}):
            result = substitute_env_vars("${MY_VAR}")

        assert result == "my_value"

    def test_substitute_env_var_with_default(self) -> None:
        """Test substituting env var with default when var is not set."""
        from shared.storage.config_loader import substitute_env_vars

        # patch.dict snapshots os.environ and restores it on exit, so removing
        # the variable here does not permanently alter the environment.
        with patch.dict(os.environ):
            os.environ.pop("UNSET_VAR", None)
            result = substitute_env_vars("${UNSET_VAR:-default_value}")

        assert result == "default_value"

    def test_substitute_env_var_ignores_default_when_set(self) -> None:
        """Test that default is ignored when env var is set."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"SET_VAR": "actual_value"}):
            result = substitute_env_vars("${SET_VAR:-default_value}")

        assert result == "actual_value"

    def test_substitute_multiple_env_vars(self) -> None:
        """Test substituting multiple env vars in one string."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"HOST": "localhost", "PORT": "5432"}):
            result = substitute_env_vars("postgres://${HOST}:${PORT}/db")

        assert result == "postgres://localhost:5432/db"

    def test_substitute_preserves_non_env_text(self) -> None:
        """Test that non-env-var text is preserved."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"VAR": "value"}):
            result = substitute_env_vars("prefix_${VAR}_suffix")

        assert result == "prefix_value_suffix"

    def test_substitute_empty_string_when_not_set_and_no_default(self) -> None:
        """Test that empty string is returned when var not set and no default."""
        from shared.storage.config_loader import substitute_env_vars

        # Same isolation as above: the removal is undone when the block exits.
        with patch.dict(os.environ):
            os.environ.pop("MISSING_VAR", None)
            result = substitute_env_vars("${MISSING_VAR}")

        assert result == ""
|
||||
|
||||
|
||||
class TestLoadStorageConfigYaml:
    """Checks on load_storage_config when reading YAML files from disk."""

    @staticmethod
    def _write_yaml(directory: Path, body: str) -> Path:
        """Write *body* to ``storage.yaml`` inside *directory* and return its path."""
        target = directory / "storage.yaml"
        target.write_text(body)
        return target

    def test_load_local_backend_config(self, temp_dir: Path) -> None:
        """A local-backend file yields a populated ``local`` section."""
        from shared.storage.config_loader import load_storage_config

        yaml_file = self._write_yaml(temp_dir, """
backend: local
presigned_url_expiry: 3600

local:
  base_path: ./data/storage
""")

        loaded = load_storage_config(yaml_file)

        assert loaded.backend_type == "local"
        assert loaded.presigned_url_expiry == 3600
        assert loaded.local is not None
        assert loaded.local.base_path == Path("./data/storage")

    def test_load_azure_backend_config(self, temp_dir: Path) -> None:
        """An Azure-backend file yields a populated ``azure`` section."""
        from shared.storage.config_loader import load_storage_config

        yaml_file = self._write_yaml(temp_dir, """
backend: azure_blob
presigned_url_expiry: 7200

azure:
  connection_string: DefaultEndpointsProtocol=https;AccountName=test
  container_name: documents
  create_container: true
""")

        loaded = load_storage_config(yaml_file)

        assert loaded.backend_type == "azure_blob"
        assert loaded.presigned_url_expiry == 7200
        azure_section = loaded.azure
        assert azure_section is not None
        assert azure_section.connection_string == "DefaultEndpointsProtocol=https;AccountName=test"
        assert azure_section.container_name == "documents"
        assert azure_section.create_container is True

    def test_load_s3_backend_config(self, temp_dir: Path) -> None:
        """An S3-backend file yields a populated ``s3`` section."""
        from shared.storage.config_loader import load_storage_config

        yaml_file = self._write_yaml(temp_dir, """
backend: s3
presigned_url_expiry: 1800

s3:
  bucket_name: my-bucket
  region_name: us-west-2
  endpoint_url: http://localhost:9000
  create_bucket: false
""")

        loaded = load_storage_config(yaml_file)

        assert loaded.backend_type == "s3"
        assert loaded.presigned_url_expiry == 1800
        s3_section = loaded.s3
        assert s3_section is not None
        assert s3_section.bucket_name == "my-bucket"
        assert s3_section.region_name == "us-west-2"
        assert s3_section.endpoint_url == "http://localhost:9000"
        assert s3_section.create_bucket is False

    def test_load_config_with_env_var_substitution(self, temp_dir: Path) -> None:
        """``${VAR:-default}`` placeholders are resolved from the environment."""
        from shared.storage.config_loader import load_storage_config

        yaml_file = self._write_yaml(temp_dir, """
backend: ${STORAGE_BACKEND:-local}

local:
  base_path: ${STORAGE_PATH:-./default/path}
""")

        overrides = {"STORAGE_BACKEND": "local", "STORAGE_PATH": "/custom/path"}
        with patch.dict(os.environ, overrides):
            loaded = load_storage_config(yaml_file)

        assert loaded.backend_type == "local"
        assert loaded.local is not None
        assert loaded.local.base_path == Path("/custom/path")

    def test_load_config_file_not_found_raises(self, temp_dir: Path) -> None:
        """A path that does not exist raises FileNotFoundError."""
        from shared.storage.config_loader import load_storage_config

        missing_file = temp_dir / "nonexistent.yaml"

        with pytest.raises(FileNotFoundError):
            load_storage_config(missing_file)

    def test_load_config_invalid_yaml_raises(self, temp_dir: Path) -> None:
        """Malformed YAML is reported as a ValueError mentioning "Invalid"."""
        from shared.storage.config_loader import load_storage_config

        yaml_file = self._write_yaml(temp_dir, "invalid: yaml: content: [")

        with pytest.raises(ValueError, match="Invalid"):
            load_storage_config(yaml_file)

    def test_load_config_missing_backend_raises(self, temp_dir: Path) -> None:
        """A file without a ``backend`` key is rejected with a ValueError."""
        from shared.storage.config_loader import load_storage_config

        yaml_file = self._write_yaml(temp_dir, """
local:
  base_path: ./data
""")

        with pytest.raises(ValueError, match="backend"):
            load_storage_config(yaml_file)

    def test_load_config_default_presigned_url_expiry(self, temp_dir: Path) -> None:
        """When the expiry is omitted, the loaded config reports 3600."""
        from shared.storage.config_loader import load_storage_config

        yaml_file = self._write_yaml(temp_dir, """
backend: local

local:
  base_path: ./data
""")

        loaded = load_storage_config(yaml_file)

        # 3600 is the value expected when the file does not set the key.
        assert loaded.presigned_url_expiry == 3600
|
||||
|
||||
|
||||
class TestStorageFileConfig:
    """Checks on the StorageFileConfig dataclass."""

    def test_storage_file_config_is_immutable(self) -> None:
        """Assigning to a field of an existing instance raises AttributeError."""
        from shared.storage.config_loader import StorageFileConfig

        file_config = StorageFileConfig(backend_type="local")

        with pytest.raises(AttributeError):
            setattr(file_config, "backend_type", "azure_blob")

    def test_storage_file_config_defaults(self) -> None:
        """Backend sections default to None and the expiry to 3600."""
        from shared.storage.config_loader import StorageFileConfig

        file_config = StorageFileConfig(backend_type="local")

        assert file_config.backend_type == "local"
        for section in ("local", "azure", "s3"):
            assert getattr(file_config, section) is None
        assert file_config.presigned_url_expiry == 3600
|
||||
|
||||
|
||||
class TestLocalConfig:
    """Checks on the LocalConfig dataclass."""

    def test_local_config_creation(self) -> None:
        """The base_path passed in is stored unchanged."""
        from shared.storage.config_loader import LocalConfig

        storage_root = Path("/data/storage")
        local_config = LocalConfig(base_path=storage_root)

        assert local_config.base_path == Path("/data/storage")

    def test_local_config_is_immutable(self) -> None:
        """Assigning to base_path on an existing instance raises AttributeError."""
        from shared.storage.config_loader import LocalConfig

        local_config = LocalConfig(base_path=Path("/data"))

        with pytest.raises(AttributeError):
            setattr(local_config, "base_path", Path("/other"))
|
||||
|
||||
|
||||
class TestAzureConfig:
    """Checks on the AzureConfig dataclass."""

    def test_azure_config_creation(self) -> None:
        """All explicitly supplied fields are stored unchanged."""
        from shared.storage.config_loader import AzureConfig

        azure_config = AzureConfig(
            connection_string="test_connection",
            container_name="test_container",
            create_container=True,
        )

        observed = (azure_config.connection_string, azure_config.container_name)
        assert observed == ("test_connection", "test_container")
        assert azure_config.create_container is True

    def test_azure_config_defaults(self) -> None:
        """create_container is False when not supplied."""
        from shared.storage.config_loader import AzureConfig

        azure_config = AzureConfig(connection_string="conn", container_name="container")

        assert azure_config.create_container is False

    def test_azure_config_is_immutable(self) -> None:
        """Assigning to a field of an existing instance raises AttributeError."""
        from shared.storage.config_loader import AzureConfig

        azure_config = AzureConfig(connection_string="conn", container_name="container")

        with pytest.raises(AttributeError):
            setattr(azure_config, "container_name", "other")
|
||||
|
||||
|
||||
class TestS3Config:
    """Checks on the S3Config dataclass."""

    def test_s3_config_creation(self) -> None:
        """All explicitly supplied fields are stored unchanged."""
        from shared.storage.config_loader import S3Config

        supplied = {
            "bucket_name": "my-bucket",
            "region_name": "us-east-1",
            "access_key_id": "AKIAIOSFODNN7EXAMPLE",
            "secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "endpoint_url": "http://localhost:9000",
        }

        s3_config = S3Config(create_bucket=True, **supplied)

        for field_name, expected in supplied.items():
            assert getattr(s3_config, field_name) == expected
        assert s3_config.create_bucket is True

    def test_s3_config_minimal(self) -> None:
        """With only bucket_name given, the rest default to None / False."""
        from shared.storage.config_loader import S3Config

        s3_config = S3Config(bucket_name="bucket")

        assert s3_config.bucket_name == "bucket"
        optional_fields = (
            "region_name",
            "access_key_id",
            "secret_access_key",
            "endpoint_url",
        )
        for field_name in optional_fields:
            assert getattr(s3_config, field_name) is None
        assert s3_config.create_bucket is False

    def test_s3_config_is_immutable(self) -> None:
        """Assigning to a field of an existing instance raises AttributeError."""
        from shared.storage.config_loader import S3Config

        s3_config = S3Config(bucket_name="bucket")

        with pytest.raises(AttributeError):
            setattr(s3_config, "bucket_name", "other")
|
||||
423
tests/shared/storage/test_factory.py
Normal file
423
tests/shared/storage/test_factory.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Tests for storage factory.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStorageFactory:
    """Checks on the create_storage_backend factory function."""

    def test_create_local_backend(self) -> None:
        """A local config produces a LocalStorageBackend rooted at base_path."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            root = Path(temp_dir)
            local_config = StorageConfig(backend_type="local", base_path=root)

            created = create_storage_backend(local_config)

            assert isinstance(created, LocalStorageBackend)
            assert created.base_path == Path(temp_dir)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
        """An azure_blob config produces an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        azure_config = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )

        created = create_storage_backend(azure_config)

        assert isinstance(created, AzureBlobStorageBackend)

    def test_create_unknown_backend_raises(self) -> None:
        """An unrecognised backend_type is rejected with a ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        bogus_config = StorageConfig(backend_type="unknown_backend")

        with pytest.raises(ValueError, match="Unknown storage backend"):
            create_storage_backend(bogus_config)

    def test_create_local_requires_base_path(self) -> None:
        """A local config without base_path is rejected with a ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(backend_type="local")

        with pytest.raises(ValueError, match="base_path"):
            create_storage_backend(incomplete_config)

    def test_create_azure_requires_connection_string(self) -> None:
        """An Azure config without connection_string is rejected with a ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(
            backend_type="azure_blob", container_name="container"
        )

        with pytest.raises(ValueError, match="connection_string"):
            create_storage_backend(incomplete_config)

    def test_create_azure_requires_container_name(self) -> None:
        """An Azure config without container_name is rejected with a ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(
            backend_type="azure_blob", connection_string="connection_string"
        )

        with pytest.raises(ValueError, match="container_name"):
            create_storage_backend(incomplete_config)
|
||||
|
||||
|
||||
class TestStorageFactoryFromEnv:
    """Checks on the create_storage_backend_from_env factory function."""

    def test_create_from_env_local(self) -> None:
        """STORAGE_BACKEND=local plus a base path yields a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {"STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": temp_dir}

            with patch.dict(os.environ, overrides, clear=False):
                created = create_storage_backend_from_env()

            assert isinstance(created, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
        """Azure environment variables yield an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_env

        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }

        with patch.dict(os.environ, overrides, clear=False):
            created = create_storage_backend_from_env()

        assert isinstance(created, AzureBlobStorageBackend)

    def test_create_from_env_defaults_to_local(self) -> None:
        """With STORAGE_BACKEND unset, the factory returns a local backend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {"STORAGE_BASE_PATH": temp_dir}

            with patch.dict(os.environ, overrides, clear=False):
                # Drop any ambient selection; patch.dict restores it on exit.
                os.environ.pop("STORAGE_BACKEND", None)
                created = create_storage_backend_from_env()

            assert isinstance(created, LocalStorageBackend)

    def test_create_from_env_missing_azure_vars(self) -> None:
        """Selecting Azure without a connection string raises a ValueError."""
        from shared.storage.factory import create_storage_backend_from_env

        # The connection string is deliberately left out of the overrides.
        overrides = {"STORAGE_BACKEND": "azure_blob"}

        with patch.dict(os.environ, overrides, clear=False):
            # Also remove any ambient value so the variable is truly missing.
            os.environ.pop("AZURE_STORAGE_CONNECTION_STRING", None)

            with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
                create_storage_backend_from_env()
|
||||
|
||||
|
||||
class TestGetDefaultStorageConfig:
    """Check the StorageConfig that get_default_storage_config builds from env vars."""

    def test_get_default_config_local(self) -> None:
        """Local env vars map onto backend_type and base_path."""
        from shared.storage.factory import get_default_storage_config

        with tempfile.TemporaryDirectory() as scratch_dir:
            overrides = {
                "STORAGE_BACKEND": "local",
                "STORAGE_BASE_PATH": scratch_dir,
            }

            with patch.dict(os.environ, overrides, clear=False):
                resolved = get_default_storage_config()

            assert resolved.backend_type == "local"
            # base_path is compared as a Path, not as the raw env string.
            assert resolved.base_path == Path(scratch_dir)

    def test_get_default_config_azure(self) -> None:
        """Azure env vars map onto connection_string and container_name."""
        from shared.storage.factory import get_default_storage_config

        connection = "DefaultEndpointsProtocol=https;..."
        container = "training-images"
        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": connection,
            "AZURE_STORAGE_CONTAINER": container,
        }

        with patch.dict(os.environ, overrides, clear=False):
            resolved = get_default_storage_config()

        assert resolved.backend_type == "azure_blob"
        assert resolved.connection_string == connection
        assert resolved.container_name == container
|
||||
|
||||
|
||||
class TestStorageFactoryS3:
    """Factory behaviour for the S3 backend, with boto3.client patched where needed."""

    @patch("boto3.client")
    def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
        """A config with bucket and region produces an S3StorageBackend."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.s3 import S3StorageBackend

        s3_config = StorageConfig(
            backend_type="s3",
            bucket_name="test-bucket",
            region_name="us-west-2",
        )

        created = create_storage_backend(s3_config)

        assert isinstance(created, S3StorageBackend)

    def test_create_s3_requires_bucket_name(self) -> None:
        """Leaving bucket_name unset makes the factory raise ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        bucketless_config = StorageConfig(backend_type="s3", region_name="us-west-2")

        with pytest.raises(ValueError, match="bucket_name"):
            create_storage_backend(bucketless_config)

    @patch("boto3.client")
    def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
        """S3 env vars yield an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.s3 import S3StorageBackend

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-east-1",
        }

        with patch.dict(os.environ, overrides, clear=False):
            created = create_storage_backend_from_env()

        assert isinstance(created, S3StorageBackend)

    def test_create_from_env_s3_missing_bucket(self) -> None:
        """Selecting s3 without AWS_S3_BUCKET raises ValueError naming the variable."""
        from shared.storage.factory import create_storage_backend_from_env

        # AWS_S3_BUCKET is deliberately not provided.
        overrides = {"STORAGE_BACKEND": "s3"}

        with patch.dict(os.environ, overrides, clear=False):
            # Remove any ambient bucket so the error path is really exercised.
            os.environ.pop("AWS_S3_BUCKET", None)

            with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
                create_storage_backend_from_env()

    def test_get_default_config_s3(self) -> None:
        """S3 env vars, including the endpoint URL, are copied into the config."""
        from shared.storage.factory import get_default_storage_config

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-west-2",
            "AWS_ENDPOINT_URL": "http://localhost:9000",
        }

        with patch.dict(os.environ, overrides, clear=False):
            resolved = get_default_storage_config()

        assert resolved.backend_type == "s3"
        assert resolved.bucket_name == "test-bucket"
        assert resolved.region_name == "us-west-2"
        assert resolved.endpoint_url == "http://localhost:9000"
|
||||
|
||||
|
||||
class TestStorageFactoryFromFile:
    """Exercise create_storage_backend_from_file with YAML files written to tmp_path."""

    def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
        """A YAML file selecting ``local`` yields a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        yaml_file = tmp_path / "storage.yaml"
        storage_root = tmp_path / "storage"
        # YAML is spelled line by line so the nesting under "local:" is explicit.
        yaml_text = (
            "\n"
            "backend: local\n"
            "\n"
            "local:\n"
            f"  base_path: {storage_root}\n"
        )
        yaml_file.write_text(yaml_text)

        created = create_storage_backend_from_file(yaml_file)

        assert isinstance(created, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_yaml_file_azure(
        self, mock_service_class: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting ``azure_blob`` yields an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_file

        yaml_file = tmp_path / "storage.yaml"
        yaml_text = (
            "\n"
            "backend: azure_blob\n"
            "\n"
            "azure:\n"
            "  connection_string: DefaultEndpointsProtocol=https;AccountName=test\n"
            "  container_name: documents\n"
        )
        yaml_file.write_text(yaml_text)

        created = create_storage_backend_from_file(yaml_file)

        assert isinstance(created, AzureBlobStorageBackend)

    @patch("boto3.client")
    def test_create_from_yaml_file_s3(
        self, mock_boto3_client: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting ``s3`` yields an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.s3 import S3StorageBackend

        yaml_file = tmp_path / "storage.yaml"
        yaml_text = (
            "\n"
            "backend: s3\n"
            "\n"
            "s3:\n"
            "  bucket_name: my-bucket\n"
            "  region_name: us-east-1\n"
        )
        yaml_file.write_text(yaml_text)

        created = create_storage_backend_from_file(yaml_file)

        assert isinstance(created, S3StorageBackend)

    def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
        """``${VAR}`` / ``${VAR:-default}`` placeholders in the file are filled from env."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        yaml_file = tmp_path / "storage.yaml"
        storage_root = tmp_path / "storage"
        # Plain (non-f) strings: the ${...} placeholders must reach the file verbatim.
        yaml_text = (
            "\n"
            "backend: ${STORAGE_BACKEND:-local}\n"
            "\n"
            "local:\n"
            "  base_path: ${CUSTOM_STORAGE_PATH}\n"
        )
        yaml_file.write_text(yaml_text)

        substitutions = {
            "STORAGE_BACKEND": "local",
            "CUSTOM_STORAGE_PATH": str(storage_root),
        }
        with patch.dict(os.environ, substitutions):
            created = create_storage_backend_from_file(yaml_file)

        assert isinstance(created, LocalStorageBackend)

    def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
        """Pointing the factory at a missing file raises FileNotFoundError."""
        from shared.storage.factory import create_storage_backend_from_file

        missing_file = tmp_path / "nonexistent.yaml"

        with pytest.raises(FileNotFoundError):
            create_storage_backend_from_file(missing_file)
|
||||
|
||||
|
||||
class TestGetStorageBackend:
    """Cover both input routes of get_storage_backend: config file and env vars."""

    def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
        """An explicit config_path is used to build the backend."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        yaml_file = tmp_path / "storage.yaml"
        storage_root = tmp_path / "storage"
        # YAML is spelled line by line so the nesting under "local:" is explicit.
        yaml_text = (
            "\n"
            "backend: local\n"
            "\n"
            "local:\n"
            f"  base_path: {storage_root}\n"
        )
        yaml_file.write_text(yaml_text)

        created = get_storage_backend(config_path=yaml_file)

        assert isinstance(created, LocalStorageBackend)

    def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
        """With config_path=None the backend is built from environment variables."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        storage_root = tmp_path / "storage"
        overrides = {
            "STORAGE_BACKEND": "local",
            "STORAGE_BASE_PATH": str(storage_root),
        }

        with patch.dict(os.environ, overrides, clear=False):
            # No config file is given, so only the env vars can drive the choice.
            created = get_storage_backend(config_path=None)

        assert isinstance(created, LocalStorageBackend)
|
||||
712
tests/shared/storage/test_local.py
Normal file
712
tests/shared/storage/test_local.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""
|
||||
Tests for LocalStorageBackend.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import shutil
import tempfile
from pathlib import Path
from typing import Iterator

import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_storage_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory and remove it once the test is done.

    The function is a generator (it yields), so its return annotation is
    ``Iterator[Path]``; tests that request the fixture receive the yielded Path.

    Yields:
        Path of a directory created with ``tempfile.mkdtemp``.
    """
    temp_dir = Path(tempfile.mkdtemp())
    yield temp_dir
    # Best-effort cleanup: ignore_errors=True keeps a failed removal from
    # raising during teardown.
    shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Write ``sample.txt`` containing "Hello, World!" into the temp dir; return its path."""
    text_path = temp_storage_dir.joinpath("sample.txt")
    text_path.write_text("Hello, World!")
    return text_path
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image(temp_storage_dir: Path) -> Path:
    """Write a tiny hand-assembled ``sample.png`` into the temp dir; return its path.

    The bytes are spelled out chunk by chunk (signature, IHDR, IDAT, IEND),
    so no imaging library is needed to obtain a binary test file.
    """
    image_path = temp_storage_dir / "sample.png"

    # Minimal valid PNG (1x1 transparent pixel), written as hex per field.
    png_hex_fields = (
        "89504E470D0A1A0A",  # PNG signature
        "0000000D",  # IHDR length
        "49484452",  # IHDR
        "00000001",  # width: 1
        "00000001",  # height: 1
        "0806000000",  # 8-bit RGBA
        "1F15C489",  # CRC
        "0000000A",  # IDAT length
        "49444154",  # IDAT
        "789C6300010000050001",  # compressed data
        "0D0A2DB4",  # CRC
        "00000000",  # IEND length
        "49454E44",  # IEND
        "AE426082",  # CRC
    )
    image_path.write_bytes(bytes.fromhex("".join(png_hex_fields)))
    return image_path
|
||||
|
||||
|
||||
class TestLocalStorageBackendCreation:
    """Construction behaviour of LocalStorageBackend."""

    def test_create_with_base_path(self, temp_storage_dir: Path) -> None:
        """A Path given as base_path is exposed on the backend."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        assert store.base_path == temp_storage_dir

    def test_create_with_string_path(self, temp_storage_dir: Path) -> None:
        """A str base_path compares equal to the corresponding Path afterwards."""
        from shared.storage.local import LocalStorageBackend

        root_as_text = str(temp_storage_dir)
        store = LocalStorageBackend(base_path=root_as_text)

        assert store.base_path == temp_storage_dir

    def test_create_creates_directory_if_not_exists(
        self, temp_storage_dir: Path
    ) -> None:
        """Constructing the backend creates a base directory that was missing."""
        from shared.storage.local import LocalStorageBackend

        missing_root = temp_storage_dir / "new_storage"
        # Precondition: the directory must not exist before construction.
        assert not missing_root.exists()

        store = LocalStorageBackend(base_path=missing_root)

        assert missing_root.exists()
        assert store.base_path == missing_root

    def test_is_storage_backend_subclass(self, temp_storage_dir: Path) -> None:
        """Instances satisfy isinstance against the StorageBackend base class."""
        from shared.storage.base import StorageBackend
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        assert isinstance(store, StorageBackend)
|
||||
|
||||
|
||||
class TestLocalStorageBackendUpload:
    """Tests for LocalStorageBackend.upload method."""

    def test_upload_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Uploading returns the remote path and copies the content under base_path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_file, "uploads/sample.txt")

        stored = storage_dir / "uploads" / "sample.txt"
        assert result == "uploads/sample.txt"
        assert stored.exists()
        assert stored.read_text() == "Hello, World!"

    def test_upload_creates_subdirectories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Uploading to a nested remote path creates the intermediate directories."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_file, "deep/nested/path/sample.txt")

        # The return value used to be captured and ignored; check it the same
        # way the other upload tests in this class do.
        assert result == "deep/nested/path/sample.txt"
        assert (storage_dir / "deep" / "nested" / "path" / "sample.txt").exists()

    def test_upload_fails_if_file_exists_without_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A second upload to the same remote path with overwrite=False raises."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload succeeds
        backend.upload(sample_file, "sample.txt")

        # Second upload should fail
        with pytest.raises(StorageError, match="already exists"):
            backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """With overwrite=True the stored file takes the new content."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Modify original file so the overwrite is observable
        sample_file.write_text("Modified content")

        # Second upload with overwrite
        result = backend.upload(sample_file, "sample.txt", overwrite=True)

        assert result == "sample.txt"
        assert (storage_dir / "sample.txt").read_text() == "Modified content"

    def test_upload_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Uploading a local file that does not exist raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(Path("/nonexistent/file.txt"), "sample.txt")

    def test_upload_binary_file(
        self, temp_storage_dir: Path, sample_image: Path
    ) -> None:
        """Binary content is stored byte-for-byte."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_image, "images/sample.png")

        assert result == "images/sample.png"
        uploaded_content = (storage_dir / "images" / "sample.png").read_bytes()
        assert uploaded_content == sample_image.read_bytes()
|
||||
|
||||
|
||||
class TestLocalStorageBackendDownload:
    """Tests for LocalStorageBackend.download method."""

    def test_download_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Downloading returns the local path and writes the stored content there."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        download_dir = temp_storage_dir / "downloads"
        download_dir.mkdir()
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Then download
        local_path = download_dir / "downloaded.txt"
        result = backend.download("sample.txt", local_path)

        assert result == local_path
        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_creates_parent_directories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Downloading to a path whose parents are missing creates them."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        local_path = temp_storage_dir / "deep" / "nested" / "downloaded.txt"
        result = backend.download("sample.txt", local_path)

        # The return value used to be captured and ignored; check it the same
        # way test_download_file does.
        assert result == local_path
        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Downloading a missing remote path raises an error naming that path."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            backend.download("nonexistent.txt", Path("/tmp/file.txt"))

    def test_download_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A file stored under a nested remote path can be downloaded."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/c/sample.txt")

        local_path = temp_storage_dir / "downloaded.txt"
        result = backend.download("a/b/c/sample.txt", local_path)

        # Previously unused; asserted for consistency with test_download_file.
        assert result == local_path
        assert local_path.read_text() == "Hello, World!"
|
||||
|
||||
|
||||
class TestLocalStorageBackendExists:
    """Behaviour of LocalStorageBackend.exists for present and absent paths."""

    def test_exists_returns_true_for_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """exists() is exactly True for a path that was just uploaded."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "sample.txt")

        found = store.exists("sample.txt")

        assert found is True

    def test_exists_returns_false_for_nonexistent_file(
        self, temp_storage_dir: Path
    ) -> None:
        """exists() is exactly False for a path that was never stored."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        found = store.exists("nonexistent.txt")

        assert found is False

    def test_exists_with_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Nested paths are resolved; a sibling name in the same folder is absent."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "a/b/sample.txt")

        assert store.exists("a/b/sample.txt") is True
        assert store.exists("a/b/other.txt") is False
|
||||
|
||||
|
||||
class TestLocalStorageBackendListFiles:
    """Behaviour of LocalStorageBackend.list_files."""

    def test_list_files_empty_storage(self, temp_storage_dir: Path) -> None:
        """An empty store lists as an empty list."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        listing = store.list_files("")

        assert listing == []

    def test_list_files_returns_all_files(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An empty prefix lists every stored file, including nested ones."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")

        # Upload multiple files, one of them inside a subdirectory.
        remote_names = ("file1.txt", "file2.txt", "subdir/file3.txt")
        for remote_name in remote_names:
            store.upload(sample_file, remote_name)

        listing = store.list_files("")

        assert len(listing) == 3
        for remote_name in remote_names:
            assert remote_name in listing

    def test_list_files_with_prefix(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A prefix restricts the listing to paths underneath it."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")

        for remote_name in ("images/a.png", "images/b.png", "labels/a.txt"):
            store.upload(sample_file, remote_name)

        listing = store.list_files("images/")

        assert len(listing) == 2
        assert "images/a.png" in listing
        assert "images/b.png" in listing
        assert "labels/a.txt" not in listing

    def test_list_files_returns_sorted(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """The listing comes back sorted regardless of upload order."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")

        # Deliberately uploaded out of alphabetical order.
        for remote_name in ("c.txt", "a.txt", "b.txt"):
            store.upload(sample_file, remote_name)

        listing = store.list_files("")

        assert listing == ["a.txt", "b.txt", "c.txt"]
|
||||
|
||||
|
||||
class TestLocalStorageBackendDelete:
    """Behaviour of LocalStorageBackend.delete."""

    def test_delete_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Deleting a stored file returns True and removes it from disk."""
        from shared.storage.local import LocalStorageBackend

        storage_root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=storage_root)
        store.upload(sample_file, "sample.txt")

        was_deleted = store.delete("sample.txt")

        assert was_deleted is True
        assert not (storage_root / "sample.txt").exists()

    def test_delete_nonexistent_file_returns_false(
        self, temp_storage_dir: Path
    ) -> None:
        """Deleting a path that was never stored returns False."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        was_deleted = store.delete("nonexistent.txt")

        assert was_deleted is False

    def test_delete_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Deleting works for a file stored under a nested remote path."""
        from shared.storage.local import LocalStorageBackend

        storage_root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=storage_root)
        store.upload(sample_file, "a/b/sample.txt")

        was_deleted = store.delete("a/b/sample.txt")

        assert was_deleted is True
        assert not (storage_root / "a" / "b" / "sample.txt").exists()
|
||||
|
||||
|
||||
class TestLocalStorageBackendGetUrl:
    """Behaviour of LocalStorageBackend.get_url."""

    def test_get_url_returns_file_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """get_url returns something that locates the stored file.

        Either form is accepted: a string containing the absolute path, or
        the exact ``file://`` URI of that path.
        """
        from shared.storage.local import LocalStorageBackend

        storage_root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=storage_root)
        store.upload(sample_file, "sample.txt")

        located = store.get_url("sample.txt")

        # Should return file:// URL or absolute path
        assert "sample.txt" in located
        # URL should be usable to locate the file
        stored_path = storage_root / "sample.txt"
        assert str(stored_path) in located or stored_path.as_uri() == located

    def test_get_url_nonexistent_file(self, temp_storage_dir: Path) -> None:
        """Asking for the URL of a missing file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            store.get_url("nonexistent.txt")
|
||||
|
||||
|
||||
class TestLocalStorageBackendUploadBytes:
    """Behaviour of LocalStorageBackend.upload_bytes."""

    def test_upload_bytes(self, temp_storage_dir: Path) -> None:
        """Raw bytes are written under base_path and the remote path is returned."""
        from shared.storage.local import LocalStorageBackend

        storage_root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=storage_root)
        payload = b"Binary content here"

        returned_path = store.upload_bytes(payload, "binary.dat")

        assert returned_path == "binary.dat"
        assert (storage_root / "binary.dat").read_bytes() == payload

    def test_upload_bytes_creates_subdirectories(
        self, temp_storage_dir: Path
    ) -> None:
        """Writing bytes to a nested remote path creates the intermediate directories."""
        from shared.storage.local import LocalStorageBackend

        storage_root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=storage_root)

        store.upload_bytes(b"content", "a/b/c/file.dat")

        nested_target = storage_root / "a" / "b" / "c" / "file.dat"
        assert nested_target.exists()
|
||||
|
||||
|
||||
class TestLocalStorageBackendDownloadBytes:
    """Behaviour of LocalStorageBackend.download_bytes."""

    def test_download_bytes(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A stored file is returned as its raw bytes."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "sample.txt")

        payload = store.download_bytes("sample.txt")

        assert payload == b"Hello, World!"

    def test_download_bytes_nonexistent(self, temp_storage_dir: Path) -> None:
        """Reading bytes of a missing file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            store.download_bytes("nonexistent.txt")
|
||||
|
||||
|
||||
class TestLocalStorageBackendSecurity:
|
||||
"""Security tests for LocalStorageBackend - path traversal prevention."""
|
||||
|
||||
def test_path_traversal_with_dotdot_blocked(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """Uploading to a remote path that starts with ``../`` is rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
    escaping_path = "../escape.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.upload(sample_file, escaping_path)
|
||||
|
||||
def test_path_traversal_with_nested_dotdot_blocked(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """A path that enters a subdirectory and then climbs above the root is rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
    # One level down, two levels up: ends outside the storage root.
    escaping_path = "subdir/../../escape.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.upload(sample_file, escaping_path)
|
||||
|
||||
def test_path_traversal_with_many_dotdot_blocked(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """A deep path followed by more ``..`` segments than its depth is rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
    # Three levels down, four levels up: ends outside the storage root.
    escaping_path = "a/b/c/../../../../escape.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.upload(sample_file, escaping_path)
|
||||
|
||||
def test_absolute_path_unix_blocked(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """An absolute Unix-style remote path is rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir)
    absolute_target = "/etc/passwd"

    with pytest.raises(StorageError, match="Absolute paths not allowed"):
        store.upload(sample_file, absolute_target)
|
||||
|
||||
def test_absolute_path_windows_blocked(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """Uploading to a drive-letter Windows path must be rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir)
    absolute_target = "C:\\Windows\\System32\\config"

    with pytest.raises(StorageError, match="Absolute paths not allowed"):
        store.upload(sample_file, absolute_target)
|
||||
|
||||
def test_download_path_traversal_blocked(
    self, temp_storage_dir: Path
) -> None:
    """Downloading from a remote path that begins with ``../`` must be rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    backend = LocalStorageBackend(base_path=temp_storage_dir)
    # Keep the destination inside the fixture directory rather than a
    # hard-coded /tmp path: if the traversal check ever regresses, nothing is
    # written outside the test sandbox, and the test does not depend on /tmp.
    destination = temp_storage_dir / "file.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        backend.download("../escape.txt", destination)
|
||||
|
||||
def test_exists_path_traversal_blocked(
    self, temp_storage_dir: Path
) -> None:
    """``exists`` must reject a remote path that begins with ``../``."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir)
    escaping_path = "../escape.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.exists(escaping_path)
|
||||
|
||||
def test_delete_path_traversal_blocked(
    self, temp_storage_dir: Path
) -> None:
    """``delete`` must reject a remote path that begins with ``../``."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir)
    escaping_path = "../escape.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.delete(escaping_path)
|
||||
|
||||
def test_get_url_path_traversal_blocked(
    self, temp_storage_dir: Path
) -> None:
    """``get_url`` must reject a remote path that begins with ``../``."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir)
    escaping_path = "../escape.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.get_url(escaping_path)
|
||||
|
||||
def test_upload_bytes_path_traversal_blocked(
    self, temp_storage_dir: Path
) -> None:
    """``upload_bytes`` must reject a remote path that begins with ``../``."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir)
    payload = b"content"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.upload_bytes(payload, "../escape.txt")
|
||||
|
||||
def test_download_bytes_path_traversal_blocked(
    self, temp_storage_dir: Path
) -> None:
    """``download_bytes`` must reject a remote path that begins with ``../``."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    store = LocalStorageBackend(base_path=temp_storage_dir)
    escaping_path = "../escape.txt"

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        store.download_bytes(escaping_path)
|
||||
|
||||
def test_valid_nested_path_still_works(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """A deep path with no ``..`` components is still accepted by ``upload``."""
    from shared.storage.local import LocalStorageBackend

    root = temp_storage_dir / "storage"
    store = LocalStorageBackend(base_path=root)

    # Legitimate nesting must not be caught by the traversal checks.
    returned = store.upload(sample_file, "a/b/c/d/file.txt")
    stored_file = root / "a" / "b" / "c" / "d" / "file.txt"

    assert returned == "a/b/c/d/file.txt"
    assert stored_file.exists()
|
||||
158
tests/shared/storage/test_prefixes.py
Normal file
158
tests/shared/storage/test_prefixes.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for storage prefixes module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.storage.prefixes import PREFIXES, StoragePrefixes
|
||||
|
||||
|
||||
class TestStoragePrefixes:
    """Tests for the prefix constants exposed on ``PREFIXES``."""

    # Names of every prefix constant exercised by this class. Keeping them in
    # one place guarantees each invariant below is checked against all ten,
    # instead of only the first few.
    _PREFIX_NAMES = (
        "DOCUMENTS",
        "IMAGES",
        "UPLOADS",
        "RESULTS",
        "EXPORTS",
        "DATASETS",
        "MODELS",
        "RAW_PDFS",
        "STRUCTURED_DATA",
        "ADMIN_IMAGES",
    )

    def test_prefixes_are_strings(self) -> None:
        """All prefix constants should be strings."""
        for name in self._PREFIX_NAMES:
            assert isinstance(getattr(PREFIXES, name), str), name

    def test_prefixes_are_non_empty(self) -> None:
        """All prefix constants should be non-empty."""
        for name in self._PREFIX_NAMES:
            assert getattr(PREFIXES, name), name

    def test_prefixes_have_no_leading_slash(self) -> None:
        """No prefix constant should start with a slash (for portability)."""
        for name in self._PREFIX_NAMES:
            assert not getattr(PREFIXES, name).startswith("/"), name

    def test_prefixes_have_no_trailing_slash(self) -> None:
        """No prefix constant should end with a slash."""
        for name in self._PREFIX_NAMES:
            assert not getattr(PREFIXES, name).endswith("/"), name

    def test_frozen_dataclass(self) -> None:
        """Assigning to a prefix constant on ``PREFIXES`` must fail."""
        # dataclasses.FrozenInstanceError is a subclass of AttributeError, so
        # this is narrower than catching Exception while still matching it.
        with pytest.raises(AttributeError):
            PREFIXES.DOCUMENTS = "new_value"  # type: ignore
|
||||
|
||||
|
||||
class TestDocumentPath:
    """Tests for ``PREFIXES.document_path``."""

    def test_document_path_with_extension(self) -> None:
        """An extension passed with its leading dot is appended to the id."""
        assert PREFIXES.document_path("abc123", ".pdf") == "documents/abc123.pdf"

    def test_document_path_without_leading_dot(self) -> None:
        """An extension passed without a dot yields the same path."""
        assert PREFIXES.document_path("abc123", "pdf") == "documents/abc123.pdf"

    def test_document_path_default_extension(self) -> None:
        """Omitting the extension produces a ``.pdf`` path."""
        assert PREFIXES.document_path("abc123") == "documents/abc123.pdf"
|
||||
|
||||
|
||||
class TestImagePath:
    """Tests for ``PREFIXES.image_path``."""

    def test_image_path_basic(self) -> None:
        """Page 1 of a document maps to ``images/<doc>/page_1.png``."""
        assert PREFIXES.image_path("doc123", 1) == "images/doc123/page_1.png"

    def test_image_path_page_number(self) -> None:
        """The page number passed in appears in the file name."""
        assert PREFIXES.image_path("doc123", 5) == "images/doc123/page_5.png"

    def test_image_path_custom_extension(self) -> None:
        """A third argument replaces the ``.png`` extension."""
        assert PREFIXES.image_path("doc123", 1, ".jpg") == "images/doc123/page_1.jpg"
|
||||
|
||||
|
||||
class TestUploadPath:
    """Tests for ``PREFIXES.upload_path``."""

    def test_upload_path_basic(self) -> None:
        """A bare filename is placed directly under ``uploads/``."""
        assert PREFIXES.upload_path("invoice.pdf") == "uploads/invoice.pdf"

    def test_upload_path_with_subfolder(self) -> None:
        """A subfolder argument is inserted between the prefix and the filename."""
        assert PREFIXES.upload_path("invoice.pdf", "async") == "uploads/async/invoice.pdf"
|
||||
|
||||
|
||||
class TestResultPath:
    """Tests for ``PREFIXES.result_path``."""

    def test_result_path_basic(self) -> None:
        """A filename is placed directly under ``results/``."""
        assert PREFIXES.result_path("output.json") == "results/output.json"
|
||||
|
||||
|
||||
class TestExportPath:
    """Tests for ``PREFIXES.export_path``."""

    def test_export_path_basic(self) -> None:
        """Export id and filename are joined under ``exports/``."""
        assert PREFIXES.export_path("exp123", "dataset.zip") == "exports/exp123/dataset.zip"
|
||||
|
||||
|
||||
class TestDatasetPath:
    """Tests for ``PREFIXES.dataset_path``."""

    def test_dataset_path_basic(self) -> None:
        """Dataset id and filename are joined under ``datasets/``."""
        assert PREFIXES.dataset_path("ds123", "data.yaml") == "datasets/ds123/data.yaml"
|
||||
|
||||
|
||||
class TestModelPath:
    """Tests for ``PREFIXES.model_path``."""

    def test_model_path_basic(self) -> None:
        """Version and filename are joined under ``models/``."""
        assert PREFIXES.model_path("v1.0.0", "best.pt") == "models/v1.0.0/best.pt"
|
||||
|
||||
|
||||
class TestExportsFromInit:
    """Checks that the ``shared.storage`` package re-exports the prefix names."""

    def test_prefixes_exported(self) -> None:
        """``shared.storage.PREFIXES`` is the same object as the module-level one."""
        from shared.storage import PREFIXES as package_level_prefixes

        assert package_level_prefixes is PREFIXES

    def test_storage_prefixes_exported(self) -> None:
        """``shared.storage.StoragePrefixes`` is the same class object."""
        from shared.storage import StoragePrefixes as package_level_class

        assert package_level_class is StoragePrefixes
|
||||
264
tests/shared/storage/test_presigned_urls.py
Normal file
264
tests/shared/storage/test_presigned_urls.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Tests for pre-signed URL functionality across all storage backends.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_storage_dir() -> Path:
    """Yield a freshly created temporary directory, then remove it."""
    scratch_root = Path(tempfile.mkdtemp())
    yield scratch_root
    # Best-effort teardown: errors raised while deleting are ignored.
    shutil.rmtree(scratch_root, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Write ``sample.txt`` into the temporary directory and return its path."""
    sample = temp_storage_dir / "sample.txt"
    sample.write_text("Hello, World!")
    return sample
|
||||
|
||||
|
||||
class TestStorageBackendInterfacePresignedUrl:
    """Checks that ``StorageBackend`` subclasses must provide ``get_presigned_url``."""

    def test_subclass_must_implement_get_presigned_url(self) -> None:
        """Instantiating a subclass without ``get_presigned_url`` raises TypeError."""
        from shared.storage.base import StorageBackend

        class IncompleteBackend(StorageBackend):
            # Defines the six methods below; get_presigned_url is deliberately absent.
            def get_url(self, remote_path: str) -> str:
                return ""

            def delete(self, remote_path: str) -> bool:
                return True

            def list_files(self, prefix: str) -> list[str]:
                return []

            def exists(self, remote_path: str) -> bool:
                return False

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

        with pytest.raises(TypeError):
            IncompleteBackend()  # type: ignore

    def test_valid_subclass_with_get_presigned_url_can_be_instantiated(self) -> None:
        """A subclass that also defines ``get_presigned_url`` can be instantiated."""
        from shared.storage.base import StorageBackend

        class CompleteBackend(StorageBackend):
            def get_presigned_url(
                self, remote_path: str, expires_in_seconds: int = 3600
            ) -> str:
                return f"https://example.com/{remote_path}?token=abc"

            def get_url(self, remote_path: str) -> str:
                return ""

            def delete(self, remote_path: str) -> bool:
                return True

            def list_files(self, prefix: str) -> list[str]:
                return []

            def exists(self, remote_path: str) -> bool:
                return False

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

        instance = CompleteBackend()
        assert isinstance(instance, StorageBackend)
|
||||
|
||||
|
||||
class TestLocalStorageBackendPresignedUrl:
    """Tests for ``LocalStorageBackend.get_presigned_url``."""

    def test_get_presigned_url_returns_file_uri(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An uploaded file yields a ``file://`` URL containing its name."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "sample.txt")

        presigned = store.get_presigned_url("sample.txt")

        assert presigned.startswith("file://")
        assert "sample.txt" in presigned

    def test_get_presigned_url_with_custom_expiry(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Passing ``expires_in_seconds`` is accepted and still yields a ``file://`` URL."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "sample.txt")

        # The local backend is expected to ignore the expiry; the call must not raise.
        presigned = store.get_presigned_url("sample.txt", expires_in_seconds=7200)

        assert presigned.startswith("file://")

    def test_get_presigned_url_nonexistent_file_raises(
        self, temp_storage_dir: Path
    ) -> None:
        """A path that was never uploaded raises ``FileNotFoundStorageError``."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            store.get_presigned_url("nonexistent.txt")

    def test_get_presigned_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """A remote path that begins with ``../`` is rejected."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        escaping_path = "../escape.txt"

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.get_presigned_url(escaping_path)

    def test_get_presigned_url_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A file uploaded under nested directories still yields a ``file://`` URL."""
        from shared.storage.local import LocalStorageBackend

        nested_remote = "a/b/c/sample.txt"
        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, nested_remote)

        presigned = store.get_presigned_url(nested_remote)

        assert presigned.startswith("file://")
        assert "sample.txt" in presigned
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendPresignedUrl:
    """Tests for ``AzureBlobStorageBackend.get_presigned_url``.

    ``BlobServiceClient`` and ``generate_blob_sas`` are patched in
    ``shared.storage.azure``, so no Azure service is contacted.
    """

    @staticmethod
    def _make_backend(
        mock_blob_service_class: MagicMock,
        *,
        connection_string: str,
        blob_exists: bool,
        account_name: str | None = None,
        blob_url: str | None = None,
    ):
        """Wire the patched ``BlobServiceClient`` and build a backend on top of it.

        Args:
            mock_blob_service_class: The patched ``BlobServiceClient`` class.
            connection_string: Connection string passed to the backend.
            blob_exists: Value returned by the blob client's ``exists()``.
            account_name: If given, set as ``account_name`` on the service mock.
            blob_url: If given, set as ``url`` on the blob client mock.

        Returns:
            An ``AzureBlobStorageBackend`` for the container named ``container``.
        """
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        if account_name is not None:
            service.account_name = account_name
        mock_blob_service_class.from_connection_string.return_value = service

        container = MagicMock()
        container.exists.return_value = True
        service.get_container_client.return_value = container

        blob_client = MagicMock()
        blob_client.exists.return_value = blob_exists
        if blob_url is not None:
            blob_client.url = blob_url
        container.get_blob_client.return_value = blob_client

        return AzureBlobStorageBackend(
            connection_string=connection_string,
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_generates_sas_url(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """The URL for an existing blob carries the blob name and the SAS token."""
        backend = self._make_backend(
            mock_blob_service_class,
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            blob_exists=True,
            account_name="testaccount",
            blob_url="https://testaccount.blob.core.windows.net/container/test.txt",
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"

            url = backend.get_presigned_url("test.txt", expires_in_seconds=3600)

        assert "https://testaccount.blob.core.windows.net" in url
        # Checked separately: joined with ``or``, the blob name alone would
        # satisfy the assertion and the SAS token would never be verified.
        assert "test.txt" in url
        assert "sv=2021-06-08" in url

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_nonexistent_blob_raises(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """A blob whose ``exists()`` is False raises ``FileNotFoundStorageError``."""
        from shared.storage.base import FileNotFoundStorageError

        backend = self._make_backend(
            mock_blob_service_class,
            connection_string="DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key==;EndpointSuffix=core.windows.net",
            blob_exists=False,
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.get_presigned_url("nonexistent.txt")

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_uses_custom_expiry(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """A call with a custom expiry generates exactly one SAS token."""
        backend = self._make_backend(
            mock_blob_service_class,
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            blob_exists=True,
            account_name="testaccount",
            blob_url="https://testaccount.blob.core.windows.net/container/test.txt",
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"

            backend.get_presigned_url("test.txt", expires_in_seconds=7200)

        # NOTE(review): this only proves a SAS token was generated once. The
        # expiry passed to generate_blob_sas is not inspected because how the
        # backend passes it is not visible from this test; tighten once confirmed.
        mock_generate_sas.assert_called_once()
|
||||
520
tests/shared/storage/test_s3.py
Normal file
520
tests/shared/storage/test_s3.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""
|
||||
Tests for S3StorageBackend.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir() -> Path:
    """Yield a freshly created temporary directory, then remove it."""
    scratch_root = Path(tempfile.mkdtemp())
    yield scratch_root
    # Best-effort teardown: errors raised while deleting are ignored.
    shutil.rmtree(scratch_root, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_file(temp_dir: Path) -> Path:
    """Write ``sample.txt`` into the temporary directory and return its path."""
    sample = temp_dir / "sample.txt"
    sample.write_text("Hello, World!")
    return sample
|
||||
|
||||
|
||||
@pytest.fixture
def mock_boto3_client():
    """Patch ``boto3.client`` for the test and yield the mock it returns."""
    fake_s3 = MagicMock()
    # While the patch is active, every boto3.client(...) call returns fake_s3.
    with patch("boto3.client", return_value=fake_s3):
        yield fake_s3
|
||||
|
||||
|
||||
class TestS3StorageBackendCreation:
    """Tests for constructing ``S3StorageBackend``."""

    def test_create_with_bucket_name(self, mock_boto3_client: MagicMock) -> None:
        """The bucket name passed in is exposed as ``bucket_name``."""
        from shared.storage.s3 import S3StorageBackend

        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        assert s3_backend.bucket_name == "test-bucket"

    def test_create_with_region(self, mock_boto3_client: MagicMock) -> None:
        """``region_name`` is forwarded to ``boto3.client`` as a keyword."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(bucket_name="test-bucket", region_name="us-west-2")

        client_factory.assert_called_once()
        forwarded = client_factory.call_args.kwargs
        assert forwarded.get("region_name") == "us-west-2"

    def test_create_with_credentials(self, mock_boto3_client: MagicMock) -> None:
        """Explicit credentials are forwarded to ``boto3.client`` as keywords."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                access_key_id="AKIATEST",
                secret_access_key="secret123",
            )

        client_factory.assert_called_once()
        forwarded = client_factory.call_args.kwargs
        assert forwarded.get("aws_access_key_id") == "AKIATEST"
        assert forwarded.get("aws_secret_access_key") == "secret123"

    def test_create_with_endpoint_url(self, mock_boto3_client: MagicMock) -> None:
        """A custom ``endpoint_url`` is forwarded to ``boto3.client``."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                endpoint_url="http://localhost:9000",
            )

        client_factory.assert_called_once()
        forwarded = client_factory.call_args.kwargs
        assert forwarded.get("endpoint_url") == "http://localhost:9000"

    def test_create_bucket_when_requested(self, mock_boto3_client: MagicMock) -> None:
        """With ``create_bucket=True`` and a 404 from ``head_bucket``, the bucket is created."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        bucket_missing = ClientError({"Error": {"Code": "404"}}, "HeadBucket")
        mock_boto3_client.head_bucket.side_effect = bucket_missing

        S3StorageBackend(bucket_name="test-bucket", create_bucket=True)

        mock_boto3_client.create_bucket.assert_called_once()

    def test_is_storage_backend_subclass(self, mock_boto3_client: MagicMock) -> None:
        """An ``S3StorageBackend`` instance is a ``StorageBackend``."""
        from shared.storage.base import StorageBackend
        from shared.storage.s3 import S3StorageBackend

        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        assert isinstance(s3_backend, StorageBackend)
|
||||
|
||||
|
||||
class TestS3StorageBackendUpload:
    """Tests for ``S3StorageBackend.upload``."""

    def test_upload_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path, sample_file: Path
    ) -> None:
        """Uploading to a key that does not exist returns the key and uploads once."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # head_object raising a 404 ClientError stands for "object not present".
        object_missing = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = object_missing

        s3_backend = S3StorageBackend(bucket_name="test-bucket")
        returned_key = s3_backend.upload(sample_file, "uploads/sample.txt")

        assert returned_key == "uploads/sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """With ``overwrite=False`` an existing object causes ``StorageError``."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        # A successful head_object response stands for "object present".
        mock_boto3_client.head_object.return_value = {}
        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(StorageError, match="already exists"):
            s3_backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """With ``overwrite=True`` an existing object is uploaded over."""
        from shared.storage.s3 import S3StorageBackend

        # A successful head_object response stands for "object present".
        mock_boto3_client.head_object.return_value = {}
        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        returned_key = s3_backend.upload(sample_file, "sample.txt", overwrite=True)

        assert returned_key == "sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_nonexistent_file_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """A local path that does not exist raises ``FileNotFoundStorageError``."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        s3_backend = S3StorageBackend(bucket_name="test-bucket")
        missing_local = temp_dir / "nonexistent.txt"

        with pytest.raises(FileNotFoundStorageError):
            s3_backend.upload(missing_local, "sample.txt")
|
||||
|
||||
|
||||
class TestS3StorageBackendDownload:
    """Tests for ``S3StorageBackend.download``."""

    def test_download_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Downloading an existing object returns the local path and downloads once."""
        from shared.storage.s3 import S3StorageBackend

        # A successful head_object response stands for "object present".
        mock_boto3_client.head_object.return_value = {}
        destination = temp_dir / "downloaded.txt"
        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        returned_path = s3_backend.download("sample.txt", destination)

        assert returned_path == destination
        mock_boto3_client.download_file.assert_called_once()

    def test_download_creates_parent_directories(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Missing parent directories of the destination exist after the call."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        destination = temp_dir / "deep" / "nested" / "downloaded.txt"
        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        s3_backend.download("sample.txt", destination)

        assert destination.parent.exists()

    def test_download_nonexistent_object_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """A 404 from ``head_object`` surfaces as ``FileNotFoundStorageError``."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        object_missing = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = object_missing
        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            s3_backend.download("nonexistent.txt", temp_dir / "file.txt")
|
||||
|
||||
|
||||
class TestS3StorageBackendExists:
    """Tests for ``S3StorageBackend.exists``."""

    def test_exists_returns_true_for_existing_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """``exists`` is True when ``head_object`` succeeds."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        found = s3_backend.exists("sample.txt")

        assert found is True

    def test_exists_returns_false_for_nonexistent_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """``exists`` is False when ``head_object`` raises a 404 ``ClientError``."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        object_missing = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = object_missing
        s3_backend = S3StorageBackend(bucket_name="test-bucket")

        found = s3_backend.exists("nonexistent.txt")

        assert found is False
|
||||
|
||||
|
||||
class TestS3StorageBackendListFiles:
    """Tests for ``S3StorageBackend.list_files``."""

    def test_list_files_returns_objects(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Every key in the ``list_objects_v2`` response is returned."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "file1.txt"},
                {"Key": "file2.txt"},
                {"Key": "subdir/file3.txt"},
            ]
        }

        backend = S3StorageBackend(bucket_name="test-bucket")
        files = backend.list_files("")

        assert len(files) == 3
        assert "file1.txt" in files
        assert "file2.txt" in files
        assert "subdir/file3.txt" in files

    def test_list_files_with_prefix(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """The prefix is forwarded to ``list_objects_v2`` with the bucket name."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "images/a.png"},
                {"Key": "images/b.png"},
            ]
        }

        backend = S3StorageBackend(bucket_name="test-bucket")
        # Only the outgoing request is under test here, so the return value is
        # intentionally not bound to a local.
        backend.list_files("images/")

        mock_boto3_client.list_objects_v2.assert_called_with(
            Bucket="test-bucket", Prefix="images/"
        )

    def test_list_files_empty_bucket(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A response without a ``Contents`` key yields an empty list."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {}  # No Contents key

        backend = S3StorageBackend(bucket_name="test-bucket")
        files = backend.list_files("")

        assert files == []
|
||||
|
||||
|
||||
class TestS3StorageBackendDelete:
    """Checks for ``S3StorageBackend.delete`` using a mocked boto3 client."""

    def test_delete_existing_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Deleting a present object returns True and calls ``delete_object`` once."""
        from shared.storage.s3 import S3StorageBackend

        # head_object succeeding (no exception) stands in for "object exists".
        mock_boto3_client.head_object.return_value = {}

        storage = S3StorageBackend(bucket_name="test-bucket")
        deleted = storage.delete("sample.txt")

        assert deleted is True
        mock_boto3_client.delete_object.assert_called_once()

    def test_delete_nonexistent_object_returns_false(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Deleting returns False when ``head_object`` raises a 404 ClientError."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found

        storage = S3StorageBackend(bucket_name="test-bucket")
        deleted = storage.delete("nonexistent.txt")

        assert deleted is False
|
||||
|
||||
|
||||
class TestS3StorageBackendGetUrl:
    """Checks for ``S3StorageBackend.get_url`` using a mocked boto3 client."""

    def test_get_url_returns_s3_url(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """The URL returned for an existing object contains the object key."""
        from shared.storage.s3 import S3StorageBackend

        canned_url = "https://test-bucket.s3.amazonaws.com/sample.txt"
        mock_boto3_client.head_object.return_value = {}
        mock_boto3_client.generate_presigned_url.return_value = canned_url

        storage = S3StorageBackend(bucket_name="test-bucket")
        result = storage.get_url("sample.txt")

        assert "sample.txt" in result

    def test_get_url_nonexistent_object_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """``get_url`` raises FileNotFoundStorageError when ``head_object`` 404s."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found

        storage = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            storage.get_url("nonexistent.txt")
|
||||
|
||||
|
||||
class TestS3StorageBackendUploadBytes:
    """Checks for ``S3StorageBackend.upload_bytes`` using a mocked boto3 client."""

    def test_upload_bytes(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Uploading to a free key returns the key and calls ``put_object`` once."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # A 404 from head_object stands in for "no object at this key yet".
        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found

        storage = S3StorageBackend(bucket_name="test-bucket")
        payload = b"Binary content here"

        stored_key = storage.upload_bytes(payload, "binary.dat")

        assert stored_key == "binary.dat"
        mock_boto3_client.put_object.assert_called_once()

    def test_upload_bytes_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """With ``overwrite=False`` an existing key raises StorageError."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        # head_object succeeding (no exception) stands in for "object exists".
        mock_boto3_client.head_object.return_value = {}

        storage = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(StorageError, match="already exists"):
            storage.upload_bytes(b"content", "sample.txt", overwrite=False)
|
||||
|
||||
|
||||
class TestS3StorageBackendDownloadBytes:
    """Checks for ``S3StorageBackend.download_bytes`` using a mocked boto3 client."""

    def test_download_bytes(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """The bytes read from the mocked ``get_object`` body are returned."""
        from shared.storage.s3 import S3StorageBackend

        # Stand-in for the response body: only ``read()`` is configured.
        body = MagicMock()
        body.read.return_value = b"Hello, World!"
        mock_boto3_client.get_object.return_value = {"Body": body}

        storage = S3StorageBackend(bucket_name="test-bucket")
        payload = storage.download_bytes("sample.txt")

        assert payload == b"Hello, World!"

    def test_download_bytes_nonexistent_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A ``NoSuchKey`` ClientError surfaces as FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        missing_key = ClientError({"Error": {"Code": "NoSuchKey"}}, "GetObject")
        mock_boto3_client.get_object.side_effect = missing_key

        storage = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            storage.download_bytes("nonexistent.txt")
|
||||
|
||||
|
||||
class TestS3StorageBackendPresignedUrl:
    """Checks for ``S3StorageBackend.get_presigned_url`` with a mocked boto3 client."""

    def test_get_presigned_url_generates_url(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A URL is returned and ``generate_presigned_url`` is called exactly once."""
        from shared.storage.s3 import S3StorageBackend

        signed = "https://test-bucket.s3.amazonaws.com/sample.txt?X-Amz-Algorithm=..."
        mock_boto3_client.head_object.return_value = {}
        mock_boto3_client.generate_presigned_url.return_value = signed

        storage = S3StorageBackend(bucket_name="test-bucket")
        result = storage.get_presigned_url("sample.txt")

        # Either marker is accepted: the signing parameter or the object key.
        assert "X-Amz-Algorithm" in result or "sample.txt" in result
        mock_boto3_client.generate_presigned_url.assert_called_once()

    def test_get_presigned_url_with_custom_expiry(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """``expires_in_seconds`` reaches the client as the ``ExpiresIn`` keyword."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        mock_boto3_client.generate_presigned_url.return_value = "https://..."

        storage = S3StorageBackend(bucket_name="test-bucket")
        storage.get_presigned_url("sample.txt", expires_in_seconds=7200)

        # call_args unpacks as (positional args, keyword args) of the last call.
        _, passed_kwargs = mock_boto3_client.generate_presigned_url.call_args
        assert passed_kwargs.get("ExpiresIn") == 7200

    def test_get_presigned_url_nonexistent_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A 404 from ``head_object`` surfaces as FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found

        storage = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            storage.get_presigned_url("nonexistent.txt")
|
||||
Reference in New Issue
Block a user