WIP
This commit is contained in:
@@ -17,9 +17,8 @@ from inference.data.admin_models import (
|
||||
AdminDocument,
|
||||
AdminAnnotation,
|
||||
TrainingTask,
|
||||
FIELD_CLASSES,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
)
|
||||
from shared.fields import FIELD_CLASSES, CSV_TO_CLASS_MAPPING
|
||||
|
||||
|
||||
class TestBatchUpload:
|
||||
@@ -507,7 +506,10 @@ class TestCSVToClassMapping:
|
||||
assert len(CSV_TO_CLASS_MAPPING) > 0
|
||||
|
||||
def test_csv_mapping_values(self):
|
||||
"""Test specific CSV column mappings."""
|
||||
"""Test specific CSV column mappings.
|
||||
|
||||
Note: customer_number is class 8 (verified from trained model best.pt).
|
||||
"""
|
||||
assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0
|
||||
assert CSV_TO_CLASS_MAPPING["InvoiceDate"] == 1
|
||||
assert CSV_TO_CLASS_MAPPING["InvoiceDueDate"] == 2
|
||||
@@ -516,7 +518,7 @@ class TestCSVToClassMapping:
|
||||
assert CSV_TO_CLASS_MAPPING["Plusgiro"] == 5
|
||||
assert CSV_TO_CLASS_MAPPING["Amount"] == 6
|
||||
assert CSV_TO_CLASS_MAPPING["supplier_organisation_number"] == 7
|
||||
assert CSV_TO_CLASS_MAPPING["customer_number"] == 9
|
||||
assert CSV_TO_CLASS_MAPPING["customer_number"] == 8 # Fixed: was 9, model uses 8
|
||||
|
||||
def test_csv_mapping_matches_field_classes(self):
|
||||
"""Test that CSV mapping is consistent with FIELD_CLASSES."""
|
||||
|
||||
1
tests/shared/fields/__init__.py
Normal file
1
tests/shared/fields/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for shared.fields module."""
|
||||
200
tests/shared/fields/test_field_config.py
Normal file
200
tests/shared/fields/test_field_config.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for field configuration - Single Source of Truth.
|
||||
|
||||
These tests ensure consistency across all field definitions and prevent
|
||||
accidental changes that could break model inference.
|
||||
|
||||
CRITICAL: These tests verify that field definitions match the trained YOLO model.
|
||||
If these tests fail, it likely means someone modified field IDs incorrectly.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.fields import (
|
||||
FIELD_DEFINITIONS,
|
||||
CLASS_NAMES,
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
NUM_CLASSES,
|
||||
FieldDefinition,
|
||||
)
|
||||
|
||||
|
||||
class TestFieldDefinitionsIntegrity:
|
||||
"""Tests to ensure field definitions are complete and consistent."""
|
||||
|
||||
def test_exactly_10_field_definitions(self):
|
||||
"""Verify we have exactly 10 field classes (matching trained model)."""
|
||||
assert len(FIELD_DEFINITIONS) == 10
|
||||
assert NUM_CLASSES == 10
|
||||
|
||||
def test_class_ids_are_sequential(self):
|
||||
"""Verify class IDs are 0-9 without gaps."""
|
||||
class_ids = {fd.class_id for fd in FIELD_DEFINITIONS}
|
||||
assert class_ids == set(range(10))
|
||||
|
||||
def test_class_ids_are_unique(self):
|
||||
"""Verify no duplicate class IDs."""
|
||||
class_ids = [fd.class_id for fd in FIELD_DEFINITIONS]
|
||||
assert len(class_ids) == len(set(class_ids))
|
||||
|
||||
def test_class_names_are_unique(self):
|
||||
"""Verify no duplicate class names."""
|
||||
class_names = [fd.class_name for fd in FIELD_DEFINITIONS]
|
||||
assert len(class_names) == len(set(class_names))
|
||||
|
||||
def test_field_definition_is_immutable(self):
|
||||
"""Verify FieldDefinition is frozen (immutable)."""
|
||||
fd = FIELD_DEFINITIONS[0]
|
||||
with pytest.raises(AttributeError):
|
||||
fd.class_id = 99 # type: ignore
|
||||
|
||||
|
||||
class TestModelCompatibility:
|
||||
"""Tests to verify field definitions match the trained YOLO model.
|
||||
|
||||
These exact values are read from runs/train/invoice_fields/weights/best.pt
|
||||
and MUST NOT be changed without retraining the model.
|
||||
"""
|
||||
|
||||
# Expected model.names from best.pt - DO NOT CHANGE
|
||||
EXPECTED_MODEL_NAMES = {
|
||||
0: "invoice_number",
|
||||
1: "invoice_date",
|
||||
2: "invoice_due_date",
|
||||
3: "ocr_number",
|
||||
4: "bankgiro",
|
||||
5: "plusgiro",
|
||||
6: "amount",
|
||||
7: "supplier_org_number",
|
||||
8: "customer_number",
|
||||
9: "payment_line",
|
||||
}
|
||||
|
||||
def test_field_classes_match_model(self):
|
||||
"""CRITICAL: Verify FIELD_CLASSES matches trained model exactly."""
|
||||
assert FIELD_CLASSES == self.EXPECTED_MODEL_NAMES
|
||||
|
||||
def test_class_names_order_matches_model(self):
|
||||
"""CRITICAL: Verify CLASS_NAMES order matches model class IDs."""
|
||||
expected_order = [
|
||||
self.EXPECTED_MODEL_NAMES[i] for i in range(10)
|
||||
]
|
||||
assert CLASS_NAMES == expected_order
|
||||
|
||||
def test_customer_number_is_class_8(self):
|
||||
"""CRITICAL: customer_number must be class 8 (not 9)."""
|
||||
assert FIELD_CLASS_IDS["customer_number"] == 8
|
||||
assert FIELD_CLASSES[8] == "customer_number"
|
||||
|
||||
def test_payment_line_is_class_9(self):
|
||||
"""CRITICAL: payment_line must be class 9 (not 8)."""
|
||||
assert FIELD_CLASS_IDS["payment_line"] == 9
|
||||
assert FIELD_CLASSES[9] == "payment_line"
|
||||
|
||||
|
||||
class TestMappingConsistency:
|
||||
"""Tests to verify all mappings are consistent with each other."""
|
||||
|
||||
def test_field_classes_and_field_class_ids_are_inverses(self):
|
||||
"""Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
|
||||
for class_id, class_name in FIELD_CLASSES.items():
|
||||
assert FIELD_CLASS_IDS[class_name] == class_id
|
||||
|
||||
for class_name, class_id in FIELD_CLASS_IDS.items():
|
||||
assert FIELD_CLASSES[class_id] == class_name
|
||||
|
||||
def test_class_names_matches_field_classes_values(self):
|
||||
"""Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
|
||||
for i, class_name in enumerate(CLASS_NAMES):
|
||||
assert FIELD_CLASSES[i] == class_name
|
||||
|
||||
def test_class_to_field_has_all_classes(self):
|
||||
"""Verify CLASS_TO_FIELD has mapping for all class names."""
|
||||
for class_name in CLASS_NAMES:
|
||||
assert class_name in CLASS_TO_FIELD
|
||||
|
||||
def test_csv_mapping_excludes_derived_fields(self):
|
||||
"""Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line."""
|
||||
# payment_line is derived, should not be in CSV mapping
|
||||
assert "payment_line" not in CSV_TO_CLASS_MAPPING
|
||||
|
||||
# All non-derived fields should be in CSV mapping
|
||||
for fd in FIELD_DEFINITIONS:
|
||||
if not fd.is_derived:
|
||||
assert fd.field_name in CSV_TO_CLASS_MAPPING
|
||||
|
||||
def test_training_field_classes_includes_all(self):
|
||||
"""Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
|
||||
for fd in FIELD_DEFINITIONS:
|
||||
assert fd.field_name in TRAINING_FIELD_CLASSES
|
||||
assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id
|
||||
|
||||
|
||||
class TestSpecificFieldDefinitions:
|
||||
"""Tests for specific field definitions to catch common mistakes."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"class_id,expected_class_name",
|
||||
[
|
||||
(0, "invoice_number"),
|
||||
(1, "invoice_date"),
|
||||
(2, "invoice_due_date"),
|
||||
(3, "ocr_number"),
|
||||
(4, "bankgiro"),
|
||||
(5, "plusgiro"),
|
||||
(6, "amount"),
|
||||
(7, "supplier_org_number"),
|
||||
(8, "customer_number"),
|
||||
(9, "payment_line"),
|
||||
],
|
||||
)
|
||||
def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
|
||||
"""Verify each class ID maps to the correct class name."""
|
||||
assert FIELD_CLASSES[class_id] == expected_class_name
|
||||
|
||||
def test_payment_line_is_derived(self):
|
||||
"""Verify payment_line is marked as derived."""
|
||||
payment_line_def = next(
|
||||
fd for fd in FIELD_DEFINITIONS if fd.class_name == "payment_line"
|
||||
)
|
||||
assert payment_line_def.is_derived is True
|
||||
|
||||
def test_other_fields_are_not_derived(self):
|
||||
"""Verify all fields except payment_line are not derived."""
|
||||
for fd in FIELD_DEFINITIONS:
|
||||
if fd.class_name != "payment_line":
|
||||
assert fd.is_derived is False, f"{fd.class_name} should not be derived"
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""Tests to ensure backward compatibility with existing code."""
|
||||
|
||||
def test_csv_to_class_mapping_field_names(self):
|
||||
"""Verify CSV_TO_CLASS_MAPPING uses correct field names."""
|
||||
# These are the field names used in CSV files
|
||||
expected_fields = {
|
||||
"InvoiceNumber": 0,
|
||||
"InvoiceDate": 1,
|
||||
"InvoiceDueDate": 2,
|
||||
"OCR": 3,
|
||||
"Bankgiro": 4,
|
||||
"Plusgiro": 5,
|
||||
"Amount": 6,
|
||||
"supplier_organisation_number": 7,
|
||||
"customer_number": 8,
|
||||
# payment_line (9) is derived, not in CSV
|
||||
}
|
||||
assert CSV_TO_CLASS_MAPPING == expected_fields
|
||||
|
||||
def test_class_to_field_returns_field_names(self):
|
||||
"""Verify CLASS_TO_FIELD maps class names to field names correctly."""
|
||||
# Sample checks for key fields
|
||||
assert CLASS_TO_FIELD["invoice_number"] == "InvoiceNumber"
|
||||
assert CLASS_TO_FIELD["invoice_date"] == "InvoiceDate"
|
||||
assert CLASS_TO_FIELD["ocr_number"] == "OCR"
|
||||
assert CLASS_TO_FIELD["customer_number"] == "customer_number"
|
||||
assert CLASS_TO_FIELD["payment_line"] == "payment_line"
|
||||
1
tests/shared/storage/__init__.py
Normal file
1
tests/shared/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests for storage module
|
||||
718
tests/shared/storage/test_azure.py
Normal file
718
tests/shared/storage/test_azure.py
Normal file
@@ -0,0 +1,718 @@
|
||||
"""
|
||||
Tests for AzureBlobStorageBackend.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
Uses mocking to avoid requiring actual Azure credentials.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_blob_service_client() -> MagicMock:
|
||||
"""Create a mock BlobServiceClient."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_container_client(mock_blob_service_client: MagicMock) -> MagicMock:
|
||||
"""Create a mock ContainerClient."""
|
||||
container_client = MagicMock()
|
||||
mock_blob_service_client.get_container_client.return_value = container_client
|
||||
return container_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_blob_client(mock_container_client: MagicMock) -> MagicMock:
|
||||
"""Create a mock BlobClient."""
|
||||
blob_client = MagicMock()
|
||||
mock_container_client.get_blob_client.return_value = blob_client
|
||||
return blob_client
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendCreation:
|
||||
"""Tests for AzureBlobStorageBackend instantiation."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_create_with_connection_string(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test creating backend with connection string."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
connection_string = "DefaultEndpointsProtocol=https;AccountName=test;..."
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string=connection_string,
|
||||
container_name="training-images",
|
||||
)
|
||||
|
||||
mock_service_class.from_connection_string.assert_called_once_with(
|
||||
connection_string
|
||||
)
|
||||
assert backend.container_name == "training-images"
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_create_creates_container_if_not_exists(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that container is created if it doesn't exist."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_container.exists.return_value = False
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="new-container",
|
||||
create_container=True,
|
||||
)
|
||||
|
||||
mock_container.create_container.assert_called_once()
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_create_does_not_create_container_by_default(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that container is not created by default."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_container.exists.return_value = True
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="existing-container",
|
||||
)
|
||||
|
||||
mock_container.create_container.assert_not_called()
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_is_storage_backend_subclass(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that AzureBlobStorageBackend is a StorageBackend."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
assert isinstance(backend, StorageBackend)
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendUpload:
|
||||
"""Tests for AzureBlobStorageBackend.upload method."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_upload_file(self, mock_service_class: MagicMock) -> None:
|
||||
"""Test uploading a file."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = False
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
|
||||
f.write(b"Hello, World!")
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
result = backend.upload(temp_path, "uploads/sample.txt")
|
||||
|
||||
assert result == "uploads/sample.txt"
|
||||
mock_container.get_blob_client.assert_called_with("uploads/sample.txt")
|
||||
mock_blob.upload_blob.assert_called_once()
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_upload_fails_if_blob_exists_without_overwrite(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that upload fails if blob exists and overwrite is False."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
from shared.storage.base import StorageError
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = True
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
|
||||
f.write(b"content")
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
with pytest.raises(StorageError, match="already exists"):
|
||||
backend.upload(temp_path, "existing.txt", overwrite=False)
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_upload_succeeds_with_overwrite(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that upload succeeds with overwrite=True."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = True
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
|
||||
f.write(b"content")
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
result = backend.upload(temp_path, "existing.txt", overwrite=True)
|
||||
|
||||
assert result == "existing.txt"
|
||||
mock_blob.upload_blob.assert_called_once()
|
||||
# Check overwrite=True was passed
|
||||
call_kwargs = mock_blob.upload_blob.call_args[1]
|
||||
assert call_kwargs.get("overwrite") is True
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_upload_nonexistent_file_fails(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that uploading nonexistent file fails."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
from shared.storage.base import FileNotFoundStorageError
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with pytest.raises(FileNotFoundStorageError):
|
||||
backend.upload(Path("/nonexistent/file.txt"), "sample.txt")
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendDownload:
|
||||
"""Tests for AzureBlobStorageBackend.download method."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_download_file(self, mock_service_class: MagicMock) -> None:
|
||||
"""Test downloading a file."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = True
|
||||
|
||||
# Mock download_blob to return stream
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.readall.return_value = b"Hello, World!"
|
||||
mock_blob.download_blob.return_value = mock_stream
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
local_path = Path(temp_dir) / "downloaded.txt"
|
||||
result = backend.download("remote/sample.txt", local_path)
|
||||
|
||||
assert result == local_path
|
||||
assert local_path.exists()
|
||||
assert local_path.read_bytes() == b"Hello, World!"
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_download_creates_parent_directories(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that download creates parent directories."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = True
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.readall.return_value = b"content"
|
||||
mock_blob.download_blob.return_value = mock_stream
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
local_path = Path(temp_dir) / "deep" / "nested" / "downloaded.txt"
|
||||
result = backend.download("sample.txt", local_path)
|
||||
|
||||
assert local_path.exists()
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_download_nonexistent_blob_fails(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test that downloading nonexistent blob fails."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
from shared.storage.base import FileNotFoundStorageError
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = False
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
|
||||
backend.download("nonexistent.txt", Path("/tmp/file.txt"))
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendExists:
|
||||
"""Tests for AzureBlobStorageBackend.exists method."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_exists_returns_true_for_existing_blob(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test exists returns True for existing blob."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = True
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
assert backend.exists("existing.txt") is True
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_exists_returns_false_for_nonexistent_blob(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test exists returns False for nonexistent blob."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = False
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
assert backend.exists("nonexistent.txt") is False
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendListFiles:
|
||||
"""Tests for AzureBlobStorageBackend.list_files method."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_list_files_empty_container(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test listing files in empty container."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_container.list_blobs.return_value = []
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
assert backend.list_files("") == []
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_list_files_returns_all_blobs(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test listing all blobs."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
|
||||
# Create mock blob items
|
||||
mock_blob1 = MagicMock()
|
||||
mock_blob1.name = "file1.txt"
|
||||
mock_blob2 = MagicMock()
|
||||
mock_blob2.name = "file2.txt"
|
||||
mock_blob3 = MagicMock()
|
||||
mock_blob3.name = "subdir/file3.txt"
|
||||
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3]
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
files = backend.list_files("")
|
||||
|
||||
assert len(files) == 3
|
||||
assert "file1.txt" in files
|
||||
assert "file2.txt" in files
|
||||
assert "subdir/file3.txt" in files
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_list_files_with_prefix(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test listing files with prefix filter."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
|
||||
mock_blob1 = MagicMock()
|
||||
mock_blob1.name = "images/a.png"
|
||||
mock_blob2 = MagicMock()
|
||||
mock_blob2.name = "images/b.png"
|
||||
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2]
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
files = backend.list_files("images/")
|
||||
|
||||
mock_container.list_blobs.assert_called_with(name_starts_with="images/")
|
||||
assert len(files) == 2
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendDelete:
|
||||
"""Tests for AzureBlobStorageBackend.delete method."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_delete_existing_blob(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test deleting an existing blob."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = True
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
result = backend.delete("sample.txt")
|
||||
|
||||
assert result is True
|
||||
mock_blob.delete_blob.assert_called_once()
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_delete_nonexistent_blob_returns_false(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test deleting nonexistent blob returns False."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = False
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
result = backend.delete("nonexistent.txt")
|
||||
|
||||
assert result is False
|
||||
mock_blob.delete_blob.assert_not_called()
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendGetUrl:
|
||||
"""Tests for AzureBlobStorageBackend.get_url method."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_get_url_returns_blob_url(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test get_url returns blob URL."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = True
|
||||
mock_blob.url = "https://account.blob.core.windows.net/container/sample.txt"
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
url = backend.get_url("sample.txt")
|
||||
|
||||
assert url == "https://account.blob.core.windows.net/container/sample.txt"
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_get_url_nonexistent_blob_fails(
|
||||
self, mock_service_class: MagicMock
|
||||
) -> None:
|
||||
"""Test get_url for nonexistent blob fails."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
from shared.storage.base import FileNotFoundStorageError
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = False
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
with pytest.raises(FileNotFoundStorageError):
|
||||
backend.get_url("nonexistent.txt")
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendUploadBytes:
|
||||
"""Tests for AzureBlobStorageBackend.upload_bytes method."""
|
||||
|
||||
@patch("shared.storage.azure.BlobServiceClient")
|
||||
def test_upload_bytes(self, mock_service_class: MagicMock) -> None:
|
||||
"""Test uploading bytes directly."""
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service_class.from_connection_string.return_value = mock_service
|
||||
mock_container = MagicMock()
|
||||
mock_service.get_container_client.return_value = mock_container
|
||||
mock_blob = MagicMock()
|
||||
mock_container.get_blob_client.return_value = mock_blob
|
||||
mock_blob.exists.return_value = False
|
||||
|
||||
backend = AzureBlobStorageBackend(
|
||||
connection_string="connection_string",
|
||||
container_name="container",
|
||||
)
|
||||
|
||||
data = b"Binary content here"
|
||||
result = backend.upload_bytes(data, "binary.dat")
|
||||
|
||||
assert result == "binary.dat"
|
||||
mock_blob.upload_blob.assert_called_once()
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendDownloadBytes:
    """Tests for AzureBlobStorageBackend.download_bytes method."""

    @staticmethod
    def _wire_mocks(service_class: MagicMock, *, blob_exists: bool) -> MagicMock:
        """Connect service -> container -> blob mocks; return the blob mock."""
        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = blob_exists
        return blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes(self, mock_service_class: MagicMock) -> None:
        """Test downloading blob as bytes."""
        from shared.storage.azure import AzureBlobStorageBackend

        blob = self._wire_mocks(mock_service_class, blob_exists=True)
        # download_blob returns a stream whose readall() yields the payload.
        stream = MagicMock()
        stream.readall.return_value = b"Hello, World!"
        blob.download_blob.return_value = stream

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        assert backend.download_bytes("sample.txt") == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes_nonexistent(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test downloading nonexistent blob as bytes."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import FileNotFoundStorageError

        self._wire_mocks(mock_service_class, blob_exists=False)
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.download_bytes("nonexistent.txt")
|
||||
|
||||
|
||||
class TestAzureBlobStorageBackendBatchOperations:
    """Tests for batch operations in AzureBlobStorageBackend."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_directory(self, mock_service_class: MagicMock) -> None:
        """Test uploading an entire directory."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        mock_service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = False

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            (root / "file1.txt").write_text("content1")
            (root / "subdir").mkdir()
            (root / "subdir" / "file2.txt").write_text("content2")

            uploaded = backend.upload_directory(root, "uploads/")

            # Both files are uploaded, preserving their relative layout.
            assert len(uploaded) == 2
            assert "uploads/file1.txt" in uploaded
            assert "uploads/subdir/file2.txt" in uploaded

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_directory(self, mock_service_class: MagicMock) -> None:
        """Test downloading blobs matching a prefix."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        mock_service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container

        # Two blobs are listed under the prefix.
        listed = []
        for blob_name in ("images/a.png", "images/b.png"):
            entry = MagicMock()
            entry.name = blob_name
            listed.append(entry)
        container.list_blobs.return_value = listed

        # Every per-blob client reports existence and the same payload.
        blob_client = MagicMock()
        container.get_blob_client.return_value = blob_client
        blob_client.exists.return_value = True
        stream = MagicMock()
        stream.readall.return_value = b"image content"
        blob_client.download_blob.return_value = stream

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as workdir:
            destination = Path(workdir)
            downloaded = backend.download_directory("images/", destination)

            assert len(downloaded) == 2
            # Files should be created relative to prefix
            assert (destination / "a.png").exists() or (destination / "images" / "a.png").exists()
|
||||
301
tests/shared/storage/test_base.py
Normal file
301
tests/shared/storage/test_base.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Tests for storage base module.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStorageBackendInterface:
    """Tests for StorageBackend abstract base class."""

    @staticmethod
    def _make_backend_class(omit=None):
        """Build a StorageBackend subclass implementing every abstract method.

        When *omit* names a method, that method is left out so instantiation
        should fail with TypeError.
        """
        from shared.storage.base import StorageBackend

        def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
            return remote_path

        def download(self, remote_path: str, local_path: Path) -> Path:
            return local_path

        def exists(self, remote_path: str) -> bool:
            return False

        def list_files(self, prefix: str) -> list[str]:
            return []

        def delete(self, remote_path: str) -> bool:
            return True

        def get_url(self, remote_path: str) -> str:
            return ""

        def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
            return ""

        members = {
            "upload": upload,
            "download": download,
            "exists": exists,
            "list_files": list_files,
            "delete": delete,
            "get_url": get_url,
            "get_presigned_url": get_presigned_url,
        }
        if omit is not None:
            del members[omit]
        return type("GeneratedBackend", (StorageBackend,), members)

    def test_cannot_instantiate_directly(self) -> None:
        """Test that StorageBackend cannot be instantiated."""
        from shared.storage.base import StorageBackend

        with pytest.raises(TypeError):
            StorageBackend()  # type: ignore

    def test_is_abstract_base_class(self) -> None:
        """Test that StorageBackend is an ABC."""
        from shared.storage.base import StorageBackend

        assert issubclass(StorageBackend, ABC)

    def test_subclass_must_implement_upload(self) -> None:
        """Test that subclass must implement upload method."""
        with pytest.raises(TypeError):
            self._make_backend_class(omit="upload")()

    def test_subclass_must_implement_download(self) -> None:
        """Test that subclass must implement download method."""
        with pytest.raises(TypeError):
            self._make_backend_class(omit="download")()

    def test_subclass_must_implement_exists(self) -> None:
        """Test that subclass must implement exists method."""
        with pytest.raises(TypeError):
            self._make_backend_class(omit="exists")()

    def test_subclass_must_implement_list_files(self) -> None:
        """Test that subclass must implement list_files method."""
        with pytest.raises(TypeError):
            self._make_backend_class(omit="list_files")()

    def test_subclass_must_implement_delete(self) -> None:
        """Test that subclass must implement delete method."""
        with pytest.raises(TypeError):
            self._make_backend_class(omit="delete")()

    def test_subclass_must_implement_get_url(self) -> None:
        """Test that subclass must implement get_url method."""
        with pytest.raises(TypeError):
            self._make_backend_class(omit="get_url")()

    def test_valid_subclass_can_be_instantiated(self) -> None:
        """Test that a complete subclass can be instantiated."""
        from shared.storage.base import StorageBackend

        backend = self._make_backend_class()()
        assert isinstance(backend, StorageBackend)
|
||||
|
||||
|
||||
class TestStorageError:
    """Tests for StorageError exception."""

    def test_storage_error_is_exception(self) -> None:
        """StorageError must be rooted in the builtin Exception hierarchy."""
        from shared.storage.base import StorageError

        assert issubclass(StorageError, Exception)

    def test_storage_error_with_message(self) -> None:
        """The constructor message is what str() reports."""
        from shared.storage.base import StorageError

        assert str(StorageError("Upload failed")) == "Upload failed"

    def test_storage_error_can_be_raised(self) -> None:
        """StorageError can be raised and caught like any exception."""
        from shared.storage.base import StorageError

        with pytest.raises(StorageError, match="test error"):
            raise StorageError("test error")
|
||||
|
||||
|
||||
class TestFileNotFoundError:
    """Tests for FileNotFoundStorageError exception."""

    def test_file_not_found_is_storage_error(self) -> None:
        """FileNotFoundStorageError must subclass the package StorageError."""
        from shared.storage.base import FileNotFoundStorageError, StorageError

        assert issubclass(FileNotFoundStorageError, StorageError)

    def test_file_not_found_with_path(self) -> None:
        """The offending path must appear in the error message."""
        from shared.storage.base import FileNotFoundStorageError

        err = FileNotFoundStorageError("images/test.png")
        assert "images/test.png" in str(err)
|
||||
|
||||
|
||||
class TestStorageConfig:
    """Tests for StorageConfig dataclass."""

    def test_storage_config_creation(self) -> None:
        """All explicitly supplied fields round-trip through the dataclass."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )

        assert cfg.backend_type == "azure_blob"
        assert cfg.connection_string == "DefaultEndpointsProtocol=https;..."
        assert cfg.container_name == "training-images"

    def test_storage_config_defaults(self) -> None:
        """Unspecified optional fields default to None."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(backend_type="local")

        assert cfg.backend_type == "local"
        assert cfg.connection_string is None
        assert cfg.container_name is None
        assert cfg.base_path is None

    def test_storage_config_with_base_path(self) -> None:
        """A local backend config carries its base_path."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(backend_type="local", base_path=Path("/data/images"))

        assert cfg.backend_type == "local"
        assert cfg.base_path == Path("/data/images")

    def test_storage_config_immutable(self) -> None:
        """Frozen dataclass: attribute assignment must raise."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(backend_type="local")

        with pytest.raises(AttributeError):
            cfg.backend_type = "azure_blob"  # type: ignore
|
||||
348
tests/shared/storage/test_config_loader.py
Normal file
348
tests/shared/storage/test_config_loader.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Tests for storage configuration file loader.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import os
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import patch

import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory, removed after the test.

    Fix: the fixture is a generator (it yields), so the correct annotation is
    Iterator[Path], not Path. Cleanup is wrapped in try/finally so the
    directory is removed even if teardown is interrupted.
    """
    workdir = Path(tempfile.mkdtemp())
    try:
        yield workdir
    finally:
        shutil.rmtree(workdir, ignore_errors=True)
|
||||
|
||||
|
||||
class TestEnvVarSubstitution:
    """Tests for environment variable substitution in config values."""

    def test_substitute_simple_env_var(self) -> None:
        """${VAR} expands to the variable's value."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"MY_VAR": "my_value"}):
            assert substitute_env_vars("${MY_VAR}") == "my_value"

    def test_substitute_env_var_with_default(self) -> None:
        """${VAR:-default} falls back to the default when VAR is unset."""
        from shared.storage.config_loader import substitute_env_vars

        os.environ.pop("UNSET_VAR", None)  # guarantee the var is absent

        assert substitute_env_vars("${UNSET_VAR:-default_value}") == "default_value"

    def test_substitute_env_var_ignores_default_when_set(self) -> None:
        """A set variable wins over its inline default."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"SET_VAR": "actual_value"}):
            assert substitute_env_vars("${SET_VAR:-default_value}") == "actual_value"

    def test_substitute_multiple_env_vars(self) -> None:
        """Every placeholder in the string is expanded."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"HOST": "localhost", "PORT": "5432"}):
            expanded = substitute_env_vars("postgres://${HOST}:${PORT}/db")
            assert expanded == "postgres://localhost:5432/db"

    def test_substitute_preserves_non_env_text(self) -> None:
        """Literal text around a placeholder is untouched."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"VAR": "value"}):
            assert substitute_env_vars("prefix_${VAR}_suffix") == "prefix_value_suffix"

    def test_substitute_empty_string_when_not_set_and_no_default(self) -> None:
        """An unset variable with no default expands to the empty string."""
        from shared.storage.config_loader import substitute_env_vars

        os.environ.pop("MISSING_VAR", None)

        assert substitute_env_vars("${MISSING_VAR}") == ""
|
||||
|
||||
|
||||
class TestLoadStorageConfigYaml:
    """Tests for loading storage configuration from YAML files."""

    @staticmethod
    def _write_config(directory: Path, body: str) -> Path:
        """Write *body* to storage.yaml inside *directory*; return the path."""
        cfg_file = directory / "storage.yaml"
        cfg_file.write_text(body)
        return cfg_file

    def test_load_local_backend_config(self, temp_dir: Path) -> None:
        """Test loading configuration for local backend."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = self._write_config(temp_dir, """
backend: local
presigned_url_expiry: 3600

local:
  base_path: ./data/storage
""")

        config = load_storage_config(cfg_file)

        assert config.backend_type == "local"
        assert config.presigned_url_expiry == 3600
        assert config.local is not None
        assert config.local.base_path == Path("./data/storage")

    def test_load_azure_backend_config(self, temp_dir: Path) -> None:
        """Test loading configuration for Azure backend."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = self._write_config(temp_dir, """
backend: azure_blob
presigned_url_expiry: 7200

azure:
  connection_string: DefaultEndpointsProtocol=https;AccountName=test
  container_name: documents
  create_container: true
""")

        config = load_storage_config(cfg_file)

        assert config.backend_type == "azure_blob"
        assert config.presigned_url_expiry == 7200
        assert config.azure is not None
        assert config.azure.connection_string == "DefaultEndpointsProtocol=https;AccountName=test"
        assert config.azure.container_name == "documents"
        assert config.azure.create_container is True

    def test_load_s3_backend_config(self, temp_dir: Path) -> None:
        """Test loading configuration for S3 backend."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = self._write_config(temp_dir, """
backend: s3
presigned_url_expiry: 1800

s3:
  bucket_name: my-bucket
  region_name: us-west-2
  endpoint_url: http://localhost:9000
  create_bucket: false
""")

        config = load_storage_config(cfg_file)

        assert config.backend_type == "s3"
        assert config.presigned_url_expiry == 1800
        assert config.s3 is not None
        assert config.s3.bucket_name == "my-bucket"
        assert config.s3.region_name == "us-west-2"
        assert config.s3.endpoint_url == "http://localhost:9000"
        assert config.s3.create_bucket is False

    def test_load_config_with_env_var_substitution(self, temp_dir: Path) -> None:
        """Test that environment variables are substituted in config."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = self._write_config(temp_dir, """
backend: ${STORAGE_BACKEND:-local}

local:
  base_path: ${STORAGE_PATH:-./default/path}
""")

        env = {"STORAGE_BACKEND": "local", "STORAGE_PATH": "/custom/path"}
        with patch.dict(os.environ, env):
            config = load_storage_config(cfg_file)

        assert config.backend_type == "local"
        assert config.local is not None
        assert config.local.base_path == Path("/custom/path")

    def test_load_config_file_not_found_raises(self, temp_dir: Path) -> None:
        """Test that FileNotFoundError is raised for missing config file."""
        from shared.storage.config_loader import load_storage_config

        with pytest.raises(FileNotFoundError):
            load_storage_config(temp_dir / "nonexistent.yaml")

    def test_load_config_invalid_yaml_raises(self, temp_dir: Path) -> None:
        """Test that ValueError is raised for invalid YAML."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = self._write_config(temp_dir, "invalid: yaml: content: [")

        with pytest.raises(ValueError, match="Invalid"):
            load_storage_config(cfg_file)

    def test_load_config_missing_backend_raises(self, temp_dir: Path) -> None:
        """Test that ValueError is raised when backend is missing."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = self._write_config(temp_dir, """
local:
  base_path: ./data
""")

        with pytest.raises(ValueError, match="backend"):
            load_storage_config(cfg_file)

    def test_load_config_default_presigned_url_expiry(self, temp_dir: Path) -> None:
        """Test default presigned_url_expiry when not specified."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = self._write_config(temp_dir, """
backend: local

local:
  base_path: ./data
""")

        config = load_storage_config(cfg_file)

        assert config.presigned_url_expiry == 3600  # Default value
|
||||
|
||||
|
||||
class TestStorageFileConfig:
    """Tests for StorageFileConfig dataclass."""

    def test_storage_file_config_is_immutable(self) -> None:
        """Frozen dataclass: attribute assignment must raise."""
        from shared.storage.config_loader import StorageFileConfig

        cfg = StorageFileConfig(backend_type="local")

        with pytest.raises(AttributeError):
            cfg.backend_type = "azure_blob"  # type: ignore

    def test_storage_file_config_defaults(self) -> None:
        """Per-backend sections default to None; expiry defaults to 3600s."""
        from shared.storage.config_loader import StorageFileConfig

        cfg = StorageFileConfig(backend_type="local")

        assert cfg.backend_type == "local"
        assert cfg.local is None
        assert cfg.azure is None
        assert cfg.s3 is None
        assert cfg.presigned_url_expiry == 3600
|
||||
|
||||
|
||||
class TestLocalConfig:
    """Tests for LocalConfig dataclass."""

    def test_local_config_creation(self) -> None:
        """base_path is stored as provided."""
        from shared.storage.config_loader import LocalConfig

        cfg = LocalConfig(base_path=Path("/data/storage"))
        assert cfg.base_path == Path("/data/storage")

    def test_local_config_is_immutable(self) -> None:
        """Frozen dataclass: attribute assignment must raise."""
        from shared.storage.config_loader import LocalConfig

        cfg = LocalConfig(base_path=Path("/data"))

        with pytest.raises(AttributeError):
            cfg.base_path = Path("/other")  # type: ignore
|
||||
|
||||
|
||||
class TestAzureConfig:
    """Tests for AzureConfig dataclass."""

    def test_azure_config_creation(self) -> None:
        """All supplied fields round-trip through the dataclass."""
        from shared.storage.config_loader import AzureConfig

        cfg = AzureConfig(
            connection_string="test_connection",
            container_name="test_container",
            create_container=True,
        )

        assert cfg.connection_string == "test_connection"
        assert cfg.container_name == "test_container"
        assert cfg.create_container is True

    def test_azure_config_defaults(self) -> None:
        """create_container defaults to False."""
        from shared.storage.config_loader import AzureConfig

        cfg = AzureConfig(connection_string="conn", container_name="container")
        assert cfg.create_container is False

    def test_azure_config_is_immutable(self) -> None:
        """Frozen dataclass: attribute assignment must raise."""
        from shared.storage.config_loader import AzureConfig

        cfg = AzureConfig(connection_string="conn", container_name="container")

        with pytest.raises(AttributeError):
            cfg.container_name = "other"  # type: ignore
|
||||
|
||||
|
||||
class TestS3Config:
    """Tests for S3Config dataclass."""

    def test_s3_config_creation(self) -> None:
        """All supplied fields round-trip through the dataclass."""
        from shared.storage.config_loader import S3Config

        cfg = S3Config(
            bucket_name="my-bucket",
            region_name="us-east-1",
            access_key_id="AKIAIOSFODNN7EXAMPLE",
            secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            endpoint_url="http://localhost:9000",
            create_bucket=True,
        )

        assert cfg.bucket_name == "my-bucket"
        assert cfg.region_name == "us-east-1"
        assert cfg.access_key_id == "AKIAIOSFODNN7EXAMPLE"
        assert cfg.secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
        assert cfg.endpoint_url == "http://localhost:9000"
        assert cfg.create_bucket is True

    def test_s3_config_minimal(self) -> None:
        """Only bucket_name is required; everything else defaults."""
        from shared.storage.config_loader import S3Config

        cfg = S3Config(bucket_name="bucket")

        assert cfg.bucket_name == "bucket"
        assert cfg.region_name is None
        assert cfg.access_key_id is None
        assert cfg.secret_access_key is None
        assert cfg.endpoint_url is None
        assert cfg.create_bucket is False

    def test_s3_config_is_immutable(self) -> None:
        """Frozen dataclass: attribute assignment must raise."""
        from shared.storage.config_loader import S3Config

        cfg = S3Config(bucket_name="bucket")

        with pytest.raises(AttributeError):
            cfg.bucket_name = "other"  # type: ignore
|
||||
423
tests/shared/storage/test_factory.py
Normal file
423
tests/shared/storage/test_factory.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Tests for storage factory.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStorageFactory:
    """Tests for create_storage_backend factory function."""

    def test_create_local_backend(self) -> None:
        """A local config yields a LocalStorageBackend rooted at base_path."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as workdir:
            cfg = StorageConfig(backend_type="local", base_path=Path(workdir))
            backend = create_storage_backend(cfg)

            assert isinstance(backend, LocalStorageBackend)
            assert backend.base_path == Path(workdir)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
        """An azure_blob config yields an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )

        assert isinstance(create_storage_backend(cfg), AzureBlobStorageBackend)

    def test_create_unknown_backend_raises(self) -> None:
        """An unrecognized backend_type must raise ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        with pytest.raises(ValueError, match="Unknown storage backend"):
            create_storage_backend(StorageConfig(backend_type="unknown_backend"))

    def test_create_local_requires_base_path(self) -> None:
        """Local backend without base_path must raise ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        with pytest.raises(ValueError, match="base_path"):
            create_storage_backend(StorageConfig(backend_type="local"))

    def test_create_azure_requires_connection_string(self) -> None:
        """Azure backend without connection_string must raise ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(backend_type="azure_blob", container_name="container")

        with pytest.raises(ValueError, match="connection_string"):
            create_storage_backend(cfg)

    def test_create_azure_requires_container_name(self) -> None:
        """Azure backend without container_name must raise ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(
            backend_type="azure_blob",
            connection_string="connection_string",
        )

        with pytest.raises(ValueError, match="container_name"):
            create_storage_backend(cfg)
|
||||
|
||||
|
||||
class TestStorageFactoryFromEnv:
    """Behaviour of create_storage_backend_from_env for each backend type."""

    def test_create_from_env_local(self) -> None:
        """STORAGE_BACKEND=local plus a base path yields a local backend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {
                "STORAGE_BACKEND": "local",
                "STORAGE_BASE_PATH": temp_dir,
            }
            with patch.dict(os.environ, overrides, clear=False):
                built = create_storage_backend_from_env()

            assert isinstance(built, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
        """STORAGE_BACKEND=azure_blob with Azure vars yields an Azure backend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_env

        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }

        with patch.dict(os.environ, overrides, clear=False):
            built = create_storage_backend_from_env()

        assert isinstance(built, AzureBlobStorageBackend)

    def test_create_from_env_defaults_to_local(self) -> None:
        """With no STORAGE_BACKEND set, the factory defaults to local."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {"STORAGE_BASE_PATH": temp_dir}

            with patch.dict(os.environ, overrides, clear=False):
                # Ensure the backend selector is genuinely unset.
                os.environ.pop("STORAGE_BACKEND", None)
                built = create_storage_backend_from_env()

            assert isinstance(built, LocalStorageBackend)

    def test_create_from_env_missing_azure_vars(self) -> None:
        """Selecting Azure without its connection string must raise."""
        from shared.storage.factory import create_storage_backend_from_env

        overrides = {"STORAGE_BACKEND": "azure_blob"}

        with patch.dict(os.environ, overrides, clear=False):
            # Make sure the required variable is genuinely absent.
            os.environ.pop("AZURE_STORAGE_CONNECTION_STRING", None)

            with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
                create_storage_backend_from_env()
|
||||
|
||||
|
||||
class TestGetDefaultStorageConfig:
    """Behaviour of get_default_storage_config for each backend type."""

    def test_get_default_config_local(self) -> None:
        """Local env vars produce a local StorageConfig with the right path."""
        from shared.storage.factory import get_default_storage_config

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {
                "STORAGE_BACKEND": "local",
                "STORAGE_BASE_PATH": temp_dir,
            }

            with patch.dict(os.environ, overrides, clear=False):
                built = get_default_storage_config()

            assert built.backend_type == "local"
            assert built.base_path == Path(temp_dir)

    def test_get_default_config_azure(self) -> None:
        """Azure env vars are carried verbatim into the StorageConfig."""
        from shared.storage.factory import get_default_storage_config

        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }

        with patch.dict(os.environ, overrides, clear=False):
            built = get_default_storage_config()

        assert built.backend_type == "azure_blob"
        assert built.connection_string == "DefaultEndpointsProtocol=https;..."
        assert built.container_name == "training-images"
|
||||
|
||||
|
||||
class TestStorageFactoryS3:
    """S3-specific paths through the storage factory."""

    @patch("boto3.client")
    def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
        """A complete S3 config produces an S3StorageBackend."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.s3 import S3StorageBackend

        cfg = StorageConfig(
            backend_type="s3",
            bucket_name="test-bucket",
            region_name="us-west-2",
        )

        assert isinstance(create_storage_backend(cfg), S3StorageBackend)

    def test_create_s3_requires_bucket_name(self) -> None:
        """An S3 config without bucket_name must be rejected."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(
            backend_type="s3",
            region_name="us-west-2",
        )

        with pytest.raises(ValueError, match="bucket_name"):
            create_storage_backend(cfg)

    @patch("boto3.client")
    def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
        """S3 env vars produce an S3StorageBackend via the env factory."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.s3 import S3StorageBackend

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-east-1",
        }

        with patch.dict(os.environ, overrides, clear=False):
            built = create_storage_backend_from_env()

        assert isinstance(built, S3StorageBackend)

    def test_create_from_env_s3_missing_bucket(self) -> None:
        """Selecting S3 without a bucket env var must raise."""
        from shared.storage.factory import create_storage_backend_from_env

        overrides = {"STORAGE_BACKEND": "s3"}

        with patch.dict(os.environ, overrides, clear=False):
            # Guarantee the bucket variable is absent.
            os.environ.pop("AWS_S3_BUCKET", None)

            with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
                create_storage_backend_from_env()

    def test_get_default_config_s3(self) -> None:
        """S3 env vars (including custom endpoint) flow into the config."""
        from shared.storage.factory import get_default_storage_config

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-west-2",
            "AWS_ENDPOINT_URL": "http://localhost:9000",
        }

        with patch.dict(os.environ, overrides, clear=False):
            built = get_default_storage_config()

        assert built.backend_type == "s3"
        assert built.bucket_name == "test-bucket"
        assert built.region_name == "us-west-2"
        assert built.endpoint_url == "http://localhost:9000"
|
||||
|
||||
|
||||
class TestStorageFactoryFromFile:
    """Behaviour of create_storage_backend_from_file with YAML configs."""

    def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
        """A YAML file selecting 'local' yields a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        yaml_path = tmp_path / "storage.yaml"
        data_root = tmp_path / "storage"
        yaml_path.write_text(f"""
backend: local

local:
  base_path: {data_root}
""")

        assert isinstance(
            create_storage_backend_from_file(yaml_path), LocalStorageBackend
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_yaml_file_azure(
        self, mock_service_class: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting 'azure_blob' yields an Azure backend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_file

        yaml_path = tmp_path / "storage.yaml"
        yaml_path.write_text("""
backend: azure_blob

azure:
  connection_string: DefaultEndpointsProtocol=https;AccountName=test
  container_name: documents
""")

        assert isinstance(
            create_storage_backend_from_file(yaml_path), AzureBlobStorageBackend
        )

    @patch("boto3.client")
    def test_create_from_yaml_file_s3(
        self, mock_boto3_client: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting 's3' yields an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.s3 import S3StorageBackend

        yaml_path = tmp_path / "storage.yaml"
        yaml_path.write_text("""
backend: s3

s3:
  bucket_name: my-bucket
  region_name: us-east-1
""")

        assert isinstance(
            create_storage_backend_from_file(yaml_path), S3StorageBackend
        )

    def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
        """${VAR} and ${VAR:-default} placeholders are expanded from env."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        yaml_path = tmp_path / "storage.yaml"
        data_root = tmp_path / "storage"
        yaml_path.write_text("""
backend: ${STORAGE_BACKEND:-local}

local:
  base_path: ${CUSTOM_STORAGE_PATH}
""")

        overrides = {
            "STORAGE_BACKEND": "local",
            "CUSTOM_STORAGE_PATH": str(data_root),
        }
        with patch.dict(os.environ, overrides):
            built = create_storage_backend_from_file(yaml_path)

        assert isinstance(built, LocalStorageBackend)

    def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
        """A missing config file raises FileNotFoundError."""
        from shared.storage.factory import create_storage_backend_from_file

        with pytest.raises(FileNotFoundError):
            create_storage_backend_from_file(tmp_path / "nonexistent.yaml")
|
||||
|
||||
|
||||
class TestGetStorageBackend:
    """Behaviour of the get_storage_backend convenience wrapper."""

    def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
        """An explicit config file takes effect."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        yaml_path = tmp_path / "storage.yaml"
        data_root = tmp_path / "storage"
        yaml_path.write_text(f"""
backend: local

local:
  base_path: {data_root}
""")

        built = get_storage_backend(config_path=yaml_path)

        assert isinstance(built, LocalStorageBackend)

    def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
        """Without a config file, environment variables are used instead."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        data_root = tmp_path / "storage"
        overrides = {
            "STORAGE_BACKEND": "local",
            "STORAGE_BASE_PATH": str(data_root),
        }

        with patch.dict(os.environ, overrides, clear=False):
            # config_path=None forces the env-var fallback path.
            built = get_storage_backend(config_path=None)

        assert isinstance(built, LocalStorageBackend)
|
||||
712
tests/shared/storage/test_local.py
Normal file
712
tests/shared/storage/test_local.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""
|
||||
Tests for LocalStorageBackend.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path

import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_storage_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory, removing it after the test.

    Fix: this is a yield-fixture (a generator), so the correct return
    annotation is Iterator[Path], not Path; pytest still injects the
    yielded Path into tests.
    """
    temp_dir = Path(tempfile.mkdtemp())
    yield temp_dir
    # Best-effort cleanup; ignore files the test may have locked/removed.
    shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Write a small text file into the temp dir and return its path."""
    path = temp_storage_dir / "sample.txt"
    path.write_text("Hello, World!")
    return path
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image(temp_storage_dir: Path) -> Path:
    """Write a minimal valid PNG (1x1 transparent pixel) and return its path.

    The bytes are identical to the hand-listed PNG used previously,
    expressed as hex for readability.
    """
    file_path = temp_storage_dir / "sample.png"
    png_data = bytes.fromhex(
        "89504e470d0a1a0a"      # PNG signature
        "0000000d"              # IHDR length (13)
        "49484452"              # "IHDR"
        "00000001"              # width: 1
        "00000001"              # height: 1
        "0806000000"            # 8-bit RGBA
        "1f15c489"              # IHDR CRC
        "0000000a"              # IDAT length (10)
        "49444154"              # "IDAT"
        "789c6300010000050001"  # compressed pixel data
        "0d0a2db4"              # IDAT CRC
        "00000000"              # IEND length
        "49454e44"              # "IEND"
        "ae426082"              # IEND CRC
    )
    file_path.write_bytes(png_data)
    return file_path
|
||||
|
||||
|
||||
class TestLocalStorageBackendCreation:
    """Construction behaviour of LocalStorageBackend."""

    def test_create_with_base_path(self, temp_storage_dir: Path) -> None:
        """A Path base_path is stored on the backend."""
        from shared.storage.local import LocalStorageBackend

        built = LocalStorageBackend(base_path=temp_storage_dir)
        assert built.base_path == temp_storage_dir

    def test_create_with_string_path(self, temp_storage_dir: Path) -> None:
        """A str base_path is accepted and normalised to a Path."""
        from shared.storage.local import LocalStorageBackend

        built = LocalStorageBackend(base_path=str(temp_storage_dir))
        assert built.base_path == temp_storage_dir

    def test_create_creates_directory_if_not_exists(
        self, temp_storage_dir: Path
    ) -> None:
        """Construction materialises a missing base directory on disk."""
        from shared.storage.local import LocalStorageBackend

        target = temp_storage_dir / "new_storage"
        assert not target.exists()

        built = LocalStorageBackend(base_path=target)

        assert target.exists()
        assert built.base_path == target

    def test_is_storage_backend_subclass(self, temp_storage_dir: Path) -> None:
        """LocalStorageBackend instances satisfy the StorageBackend ABC."""
        from shared.storage.base import StorageBackend
        from shared.storage.local import LocalStorageBackend

        built = LocalStorageBackend(base_path=temp_storage_dir)
        assert isinstance(built, StorageBackend)
|
||||
|
||||
|
||||
class TestLocalStorageBackendUpload:
    """Tests for LocalStorageBackend.upload method."""

    def test_upload_file(self, temp_storage_dir: Path, sample_file: Path) -> None:
        """Uploading copies the file under the base path and returns its key."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_file, "uploads/sample.txt")

        assert result == "uploads/sample.txt"
        assert (storage_dir / "uploads" / "sample.txt").exists()
        assert (storage_dir / "uploads" / "sample.txt").read_text() == "Hello, World!"

    def test_upload_creates_subdirectories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Upload creates any missing intermediate directories."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_file, "deep/nested/path/sample.txt")

        # Fixed: previously `result` was assigned but never checked.
        assert result == "deep/nested/path/sample.txt"
        assert (storage_dir / "deep" / "nested" / "path" / "sample.txt").exists()

    def test_upload_fails_if_file_exists_without_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A second upload to the same key fails when overwrite=False."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload succeeds
        backend.upload(sample_file, "sample.txt")

        # Second upload should fail
        with pytest.raises(StorageError, match="already exists"):
            backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """With overwrite=True a second upload replaces the stored content."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Modify original file
        sample_file.write_text("Modified content")

        # Second upload with overwrite
        result = backend.upload(sample_file, "sample.txt", overwrite=True)

        assert result == "sample.txt"
        assert (storage_dir / "sample.txt").read_text() == "Modified content"

    def test_upload_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Uploading a path that does not exist raises the storage error."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(Path("/nonexistent/file.txt"), "sample.txt")

    def test_upload_binary_file(
        self, temp_storage_dir: Path, sample_image: Path
    ) -> None:
        """Binary content survives upload byte-for-byte."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_image, "images/sample.png")

        assert result == "images/sample.png"
        uploaded_content = (storage_dir / "images" / "sample.png").read_bytes()
        assert uploaded_content == sample_image.read_bytes()
|
||||
|
||||
|
||||
class TestLocalStorageBackendDownload:
    """Tests for LocalStorageBackend.download method."""

    def test_download_file(self, temp_storage_dir: Path, sample_file: Path) -> None:
        """Download copies the stored file to the requested local path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        download_dir = temp_storage_dir / "downloads"
        download_dir.mkdir()
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Then download
        local_path = download_dir / "downloaded.txt"
        result = backend.download("sample.txt", local_path)

        assert result == local_path
        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_creates_parent_directories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Download creates missing parent directories of the target path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        local_path = temp_storage_dir / "deep" / "nested" / "downloaded.txt"
        result = backend.download("sample.txt", local_path)

        # Fixed: previously `result` was assigned but never checked.
        assert result == local_path
        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Downloading an unknown key raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        # Fixed: target the test's temp dir instead of a hardcoded /tmp
        # path, which is not portable (e.g. Windows). The error must be
        # raised before anything is written anyway.
        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            backend.download("nonexistent.txt", temp_storage_dir / "file.txt")

    def test_download_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A file stored under a nested key can be downloaded."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/c/sample.txt")

        local_path = temp_storage_dir / "downloaded.txt"
        result = backend.download("a/b/c/sample.txt", local_path)

        # Fixed: previously `result` was assigned but never checked.
        assert result == local_path
        assert local_path.read_text() == "Hello, World!"
|
||||
|
||||
|
||||
class TestLocalStorageBackendExists:
    """Behaviour of LocalStorageBackend.exists."""

    def test_exists_returns_true_for_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An uploaded key reports as existing."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        backend.upload(sample_file, "sample.txt")

        assert backend.exists("sample.txt") is True

    def test_exists_returns_false_for_nonexistent_file(
        self, temp_storage_dir: Path
    ) -> None:
        """An unknown key reports as absent."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        assert backend.exists("nonexistent.txt") is False

    def test_exists_with_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """exists distinguishes present vs absent keys under a nested prefix."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        backend.upload(sample_file, "a/b/sample.txt")

        assert backend.exists("a/b/sample.txt") is True
        assert backend.exists("a/b/other.txt") is False
|
||||
|
||||
|
||||
class TestLocalStorageBackendListFiles:
    """Behaviour of LocalStorageBackend.list_files."""

    def test_list_files_empty_storage(self, temp_storage_dir: Path) -> None:
        """An empty store lists no files."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        assert backend.list_files("") == []

    def test_list_files_returns_all_files(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An empty prefix lists every stored key, including nested ones."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        for key in ("file1.txt", "file2.txt", "subdir/file3.txt"):
            backend.upload(sample_file, key)

        listed = backend.list_files("")

        assert len(listed) == 3
        assert "file1.txt" in listed
        assert "file2.txt" in listed
        assert "subdir/file3.txt" in listed

    def test_list_files_with_prefix(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A prefix restricts the listing to matching keys only."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        for key in ("images/a.png", "images/b.png", "labels/a.txt"):
            backend.upload(sample_file, key)

        listed = backend.list_files("images/")

        assert len(listed) == 2
        assert "images/a.png" in listed
        assert "images/b.png" in listed
        assert "labels/a.txt" not in listed

    def test_list_files_returns_sorted(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Keys come back lexicographically sorted regardless of upload order."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        for key in ("c.txt", "a.txt", "b.txt"):
            backend.upload(sample_file, key)

        assert backend.list_files("") == ["a.txt", "b.txt", "c.txt"]
|
||||
|
||||
|
||||
class TestLocalStorageBackendDelete:
    """Behaviour of LocalStorageBackend.delete."""

    def test_delete_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Deleting an existing key returns True and removes the file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        assert backend.delete("sample.txt") is True
        assert not (storage_dir / "sample.txt").exists()

    def test_delete_nonexistent_file_returns_false(
        self, temp_storage_dir: Path
    ) -> None:
        """Deleting an unknown key is a no-op that returns False."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        assert backend.delete("nonexistent.txt") is False

    def test_delete_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A nested key can be deleted and disappears from disk."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/sample.txt")

        assert backend.delete("a/b/sample.txt") is True
        assert not (storage_dir / "a" / "b" / "sample.txt").exists()
|
||||
|
||||
|
||||
class TestLocalStorageBackendGetUrl:
    """Behaviour of LocalStorageBackend.get_url."""

    def test_get_url_returns_file_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """get_url resolves a stored key to a file:// URL or absolute path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        url = backend.get_url("sample.txt")

        # Should return file:// URL or absolute path
        assert "sample.txt" in url
        # Either representation must point at the stored file.
        expected_path = storage_dir / "sample.txt"
        assert str(expected_path) in url or expected_path.as_uri() == url

    def test_get_url_nonexistent_file(self, temp_storage_dir: Path) -> None:
        """Asking for a URL of an unknown key raises the storage error."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.get_url("nonexistent.txt")
|
||||
|
||||
|
||||
class TestLocalStorageBackendUploadBytes:
    """Behaviour of LocalStorageBackend.upload_bytes."""

    def test_upload_bytes(self, temp_storage_dir: Path) -> None:
        """Raw bytes are written under the given key and the key returned."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        payload = b"Binary content here"
        assert backend.upload_bytes(payload, "binary.dat") == "binary.dat"
        assert (storage_dir / "binary.dat").read_bytes() == payload

    def test_upload_bytes_creates_subdirectories(
        self, temp_storage_dir: Path
    ) -> None:
        """Missing intermediate directories are created for nested keys."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        backend.upload_bytes(b"content", "a/b/c/file.dat")

        assert (storage_dir / "a" / "b" / "c" / "file.dat").exists()
|
||||
|
||||
|
||||
class TestLocalStorageBackendDownloadBytes:
    """Behaviour of LocalStorageBackend.download_bytes."""

    def test_download_bytes(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A stored key can be read back as raw bytes."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        backend.upload(sample_file, "sample.txt")

        assert backend.download_bytes("sample.txt") == b"Hello, World!"

    def test_download_bytes_nonexistent(self, temp_storage_dir: Path) -> None:
        """Reading an unknown key raises the storage error."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.download_bytes("nonexistent.txt")
|
||||
|
||||
|
||||
class TestLocalStorageBackendSecurity:
|
||||
"""Security tests for LocalStorageBackend - path traversal prevention."""
|
||||
|
||||
def test_path_traversal_with_dotdot_blocked(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """A key beginning with ../ must be rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        backend.upload(sample_file, "../escape.txt")
|
||||
|
||||
def test_path_traversal_with_nested_dotdot_blocked(
    self, temp_storage_dir: Path, sample_file: Path
) -> None:
    """../ segments buried inside a key must also be rejected."""
    from shared.storage.base import StorageError
    from shared.storage.local import LocalStorageBackend

    backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")

    with pytest.raises(StorageError, match="Path traversal not allowed"):
        backend.upload(sample_file, "subdir/../../escape.txt")
|
||||
|
||||
def test_path_traversal_with_many_dotdot_blocked(
|
||||
self, temp_storage_dir: Path, sample_file: Path
|
||||
) -> None:
|
||||
"""Test that deeply nested path traversal is blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_dir = temp_storage_dir / "storage"
|
||||
backend = LocalStorageBackend(base_path=storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Path traversal not allowed"):
|
||||
backend.upload(sample_file, "a/b/c/../../../../escape.txt")
|
||||
|
||||
def test_absolute_path_unix_blocked(
|
||||
self, temp_storage_dir: Path, sample_file: Path
|
||||
) -> None:
|
||||
"""Test that absolute Unix paths are blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Absolute paths not allowed"):
|
||||
backend.upload(sample_file, "/etc/passwd")
|
||||
|
||||
def test_absolute_path_windows_blocked(
|
||||
self, temp_storage_dir: Path, sample_file: Path
|
||||
) -> None:
|
||||
"""Test that absolute Windows paths are blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Absolute paths not allowed"):
|
||||
backend.upload(sample_file, "C:\\Windows\\System32\\config")
|
||||
|
||||
def test_download_path_traversal_blocked(
|
||||
self, temp_storage_dir: Path
|
||||
) -> None:
|
||||
"""Test that path traversal in download is blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Path traversal not allowed"):
|
||||
backend.download("../escape.txt", Path("/tmp/file.txt"))
|
||||
|
||||
def test_exists_path_traversal_blocked(
|
||||
self, temp_storage_dir: Path
|
||||
) -> None:
|
||||
"""Test that path traversal in exists is blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Path traversal not allowed"):
|
||||
backend.exists("../escape.txt")
|
||||
|
||||
def test_delete_path_traversal_blocked(
|
||||
self, temp_storage_dir: Path
|
||||
) -> None:
|
||||
"""Test that path traversal in delete is blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Path traversal not allowed"):
|
||||
backend.delete("../escape.txt")
|
||||
|
||||
def test_get_url_path_traversal_blocked(
|
||||
self, temp_storage_dir: Path
|
||||
) -> None:
|
||||
"""Test that path traversal in get_url is blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Path traversal not allowed"):
|
||||
backend.get_url("../escape.txt")
|
||||
|
||||
def test_upload_bytes_path_traversal_blocked(
|
||||
self, temp_storage_dir: Path
|
||||
) -> None:
|
||||
"""Test that path traversal in upload_bytes is blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Path traversal not allowed"):
|
||||
backend.upload_bytes(b"content", "../escape.txt")
|
||||
|
||||
def test_download_bytes_path_traversal_blocked(
|
||||
self, temp_storage_dir: Path
|
||||
) -> None:
|
||||
"""Test that path traversal in download_bytes is blocked."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
backend = LocalStorageBackend(base_path=temp_storage_dir)
|
||||
|
||||
with pytest.raises(StorageError, match="Path traversal not allowed"):
|
||||
backend.download_bytes("../escape.txt")
|
||||
|
||||
def test_valid_nested_path_still_works(
|
||||
self, temp_storage_dir: Path, sample_file: Path
|
||||
) -> None:
|
||||
"""Test that valid nested paths still work after security fix."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_dir = temp_storage_dir / "storage"
|
||||
backend = LocalStorageBackend(base_path=storage_dir)
|
||||
|
||||
# Valid nested paths should still work
|
||||
result = backend.upload(sample_file, "a/b/c/d/file.txt")
|
||||
|
||||
assert result == "a/b/c/d/file.txt"
|
||||
assert (storage_dir / "a" / "b" / "c" / "d" / "file.txt").exists()
|
||||
158
tests/shared/storage/test_prefixes.py
Normal file
158
tests/shared/storage/test_prefixes.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for storage prefixes module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.storage.prefixes import PREFIXES, StoragePrefixes
|
||||
|
||||
|
||||
class TestStoragePrefixes:
    """Tests for StoragePrefixes class."""

    # Every prefix attribute defined on StoragePrefixes, used by the
    # generic checks below (keeps the asserts from drifting out of sync).
    _PREFIX_NAMES = (
        "DOCUMENTS",
        "IMAGES",
        "UPLOADS",
        "RESULTS",
        "EXPORTS",
        "DATASETS",
        "MODELS",
        "RAW_PDFS",
        "STRUCTURED_DATA",
        "ADMIN_IMAGES",
    )

    def test_prefixes_are_strings(self) -> None:
        """All prefix constants should be strings."""
        for name in self._PREFIX_NAMES:
            assert isinstance(getattr(PREFIXES, name), str), name

    def test_prefixes_are_non_empty(self) -> None:
        """All prefix constants should be non-empty."""
        for name in self._PREFIX_NAMES:
            assert getattr(PREFIXES, name), name

    def test_prefixes_have_no_leading_slash(self) -> None:
        """Prefixes should not start with a slash for portability."""
        for name in ("DOCUMENTS", "IMAGES", "UPLOADS", "RESULTS"):
            assert not getattr(PREFIXES, name).startswith("/"), name

    def test_prefixes_have_no_trailing_slash(self) -> None:
        """Prefixes should not end with a slash."""
        for name in ("DOCUMENTS", "IMAGES", "UPLOADS", "RESULTS"):
            assert not getattr(PREFIXES, name).endswith("/"), name

    def test_frozen_dataclass(self) -> None:
        """StoragePrefixes should be immutable (frozen dataclass)."""
        # Fixed: pytest.raises(Exception) was too broad -- it would pass for
        # ANY error (e.g. an AttributeError from a typo'd attribute name).
        # Assert the specific exception raised by frozen dataclasses.
        from dataclasses import FrozenInstanceError

        with pytest.raises(FrozenInstanceError):
            PREFIXES.DOCUMENTS = "new_value"  # type: ignore
||||
class TestDocumentPath:
    """Tests for the document_path helper."""

    def test_document_path_with_extension(self) -> None:
        """An extension with a leading dot is appended as-is."""
        assert PREFIXES.document_path("abc123", ".pdf") == "documents/abc123.pdf"

    def test_document_path_without_leading_dot(self) -> None:
        """A missing leading dot is handled transparently."""
        assert PREFIXES.document_path("abc123", "pdf") == "documents/abc123.pdf"

    def test_document_path_default_extension(self) -> None:
        """The default extension is .pdf."""
        assert PREFIXES.document_path("abc123") == "documents/abc123.pdf"
|
||||
class TestImagePath:
    """Tests for the image_path helper."""

    def test_image_path_basic(self) -> None:
        """A document id and page number form the image path."""
        assert PREFIXES.image_path("doc123", 1) == "images/doc123/page_1.png"

    def test_image_path_page_number(self) -> None:
        """The page number is embedded in the file name."""
        assert PREFIXES.image_path("doc123", 5) == "images/doc123/page_5.png"

    def test_image_path_custom_extension(self) -> None:
        """A custom extension replaces the .png default."""
        assert PREFIXES.image_path("doc123", 1, ".jpg") == "images/doc123/page_1.jpg"
|
||||
class TestUploadPath:
    """Tests for the upload_path helper."""

    def test_upload_path_basic(self) -> None:
        """A bare filename lands directly under uploads/."""
        assert PREFIXES.upload_path("invoice.pdf") == "uploads/invoice.pdf"

    def test_upload_path_with_subfolder(self) -> None:
        """An optional subfolder is inserted between prefix and filename."""
        assert PREFIXES.upload_path("invoice.pdf", "async") == "uploads/async/invoice.pdf"
||||
class TestResultPath:
    """Tests for the result_path helper."""

    def test_result_path_basic(self) -> None:
        """A filename lands under results/."""
        assert PREFIXES.result_path("output.json") == "results/output.json"
||||
class TestExportPath:
    """Tests for the export_path helper."""

    def test_export_path_basic(self) -> None:
        """Export paths are namespaced by export id."""
        assert PREFIXES.export_path("exp123", "dataset.zip") == "exports/exp123/dataset.zip"
||||
class TestDatasetPath:
    """Tests for the dataset_path helper."""

    def test_dataset_path_basic(self) -> None:
        """Dataset paths are namespaced by dataset id."""
        assert PREFIXES.dataset_path("ds123", "data.yaml") == "datasets/ds123/data.yaml"
||||
class TestModelPath:
    """Tests for the model_path helper."""

    def test_model_path_basic(self) -> None:
        """Model paths are namespaced by version."""
        assert PREFIXES.model_path("v1.0.0", "best.pt") == "models/v1.0.0/best.pt"
||||
class TestExportsFromInit:
    """Tests for exports from storage __init__.py."""

    def test_prefixes_exported(self) -> None:
        """The PREFIXES singleton is re-exported from shared.storage."""
        from shared.storage import PREFIXES as exported_prefixes

        # Identity, not equality: it must be the same singleton object.
        assert exported_prefixes is PREFIXES

    def test_storage_prefixes_exported(self) -> None:
        """The StoragePrefixes class is re-exported from shared.storage."""
        from shared.storage import StoragePrefixes as exported_class

        assert exported_class is StoragePrefixes
||||
264
tests/shared/storage/test_presigned_urls.py
Normal file
264
tests/shared/storage/test_presigned_urls.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Tests for pre-signed URL functionality across all storage backends.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_storage_dir() -> Path:
    """Yield a fresh temporary directory for storage tests.

    The directory is removed on teardown.
    """
    temp_dir = Path(tempfile.mkdtemp())
    try:
        yield temp_dir
    finally:
        # Fixed: cleanup now runs even if an exception is thrown into the
        # generator, so the temporary directory is never leaked.
        shutil.rmtree(temp_dir, ignore_errors=True)
||||
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Write a small text file with known content and return its path."""
    path = temp_storage_dir / "sample.txt"
    path.write_text("Hello, World!")
    return path
||||
|
||||
class TestStorageBackendInterfacePresignedUrl:
    """Tests for get_presigned_url in StorageBackend interface."""

    def test_subclass_must_implement_get_presigned_url(self) -> None:
        """A subclass missing get_presigned_url cannot be instantiated."""
        from shared.storage.base import StorageBackend

        class IncompleteBackend(StorageBackend):
            # Implements every abstract method EXCEPT get_presigned_url.
            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def delete(self, remote_path: str) -> bool:
                return True

            def exists(self, remote_path: str) -> bool:
                return False

            def list_files(self, prefix: str) -> list[str]:
                return []

            def get_url(self, remote_path: str) -> str:
                return ""

        with pytest.raises(TypeError):
            IncompleteBackend()  # type: ignore

    def test_valid_subclass_with_get_presigned_url_can_be_instantiated(self) -> None:
        """A subclass implementing the full interface instantiates cleanly."""
        from shared.storage.base import StorageBackend

        class CompleteBackend(StorageBackend):
            # All abstract methods implemented, including get_presigned_url.
            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def delete(self, remote_path: str) -> bool:
                return True

            def exists(self, remote_path: str) -> bool:
                return False

            def list_files(self, prefix: str) -> list[str]:
                return []

            def get_url(self, remote_path: str) -> str:
                return ""

            def get_presigned_url(
                self, remote_path: str, expires_in_seconds: int = 3600
            ) -> str:
                return f"https://example.com/{remote_path}?token=abc"

        instance = CompleteBackend()
        assert isinstance(instance, StorageBackend)
|
||||
class TestLocalStorageBackendPresignedUrl:
    """Tests for LocalStorageBackend.get_presigned_url method."""

    def test_get_presigned_url_returns_file_uri(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An existing file yields a file:// URI containing its name."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        storage.upload(sample_file, "sample.txt")

        url = storage.get_presigned_url("sample.txt")

        assert url.startswith("file://")
        assert "sample.txt" in url

    def test_get_presigned_url_with_custom_expiry(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """expires_in_seconds is accepted (and ignored) for local storage."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        storage.upload(sample_file, "sample.txt")

        url = storage.get_presigned_url("sample.txt", expires_in_seconds=7200)

        assert url.startswith("file://")

    def test_get_presigned_url_nonexistent_file_raises(
        self, temp_storage_dir: Path
    ) -> None:
        """A missing file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            storage.get_presigned_url("nonexistent.txt")

    def test_get_presigned_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Path traversal attempts are rejected with StorageError."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            storage.get_presigned_url("../escape.txt")

    def test_get_presigned_url_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Nested keys also resolve to a file:// URI."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        storage.upload(sample_file, "a/b/c/sample.txt")

        url = storage.get_presigned_url("a/b/c/sample.txt")

        assert url.startswith("file://")
        assert "sample.txt" in url
|
||||
class TestAzureBlobStorageBackendPresignedUrl:
    """Tests for AzureBlobStorageBackend.get_presigned_url method."""

    @staticmethod
    def _wire_existing_blob(mock_blob_service_class: MagicMock) -> MagicMock:
        """Wire the mocked service -> container -> blob chain for an existing blob."""
        service = MagicMock()
        service.account_name = "testaccount"
        mock_blob_service_class.from_connection_string.return_value = service

        container = MagicMock()
        container.exists.return_value = True
        service.get_container_client.return_value = container

        blob = MagicMock()
        blob.exists.return_value = True
        blob.url = "https://testaccount.blob.core.windows.net/container/test.txt"
        container.get_blob_client.return_value = blob
        return blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_generates_sas_url(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """get_presigned_url returns a URL carrying a SAS token."""
        from shared.storage.azure import AzureBlobStorageBackend

        self._wire_existing_blob(mock_blob_service_class)

        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            container_name="container",
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"

            url = backend.get_presigned_url("test.txt", expires_in_seconds=3600)

            assert "https://testaccount.blob.core.windows.net" in url
            assert "sv=2021-06-08" in url or "test.txt" in url

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_nonexistent_blob_raises(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """A missing blob raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        mock_blob_service_class.from_connection_string.return_value = service

        container = MagicMock()
        container.exists.return_value = True
        service.get_container_client.return_value = container

        blob = MagicMock()
        blob.exists.return_value = False  # blob is absent
        container.get_blob_client.return_value = blob

        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key==;EndpointSuffix=core.windows.net",
            container_name="container",
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.get_presigned_url("nonexistent.txt")

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_uses_custom_expiry(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """A custom expires_in_seconds is forwarded to SAS generation."""
        from shared.storage.azure import AzureBlobStorageBackend

        self._wire_existing_blob(mock_blob_service_class)

        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            container_name="container",
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"

            backend.get_presigned_url("test.txt", expires_in_seconds=7200)

            # Verify generate_blob_sas was called (expiry is part of the call).
            mock_generate_sas.assert_called_once()
||||
520
tests/shared/storage/test_s3.py
Normal file
520
tests/shared/storage/test_s3.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""
|
||||
Tests for S3StorageBackend.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir() -> Path:
    """Yield a fresh temporary directory; removed on teardown.

    Cleanup happens in a finally block so the directory is removed even if
    an exception is thrown into the generator (fixes a temp-dir leak).
    """
    temp_dir = Path(tempfile.mkdtemp())
    try:
        yield temp_dir
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
||||
|
||||
@pytest.fixture
def sample_file(temp_dir: Path) -> Path:
    """Write a small text file with known content and return its path."""
    path = temp_dir / "sample.txt"
    path.write_text("Hello, World!")
    return path
||||
|
||||
@pytest.fixture
def mock_boto3_client():
    """Patch boto3.client and yield the mocked S3 client instance."""
    with patch("boto3.client") as client_factory:
        s3_client = MagicMock()
        client_factory.return_value = s3_client
        yield s3_client
||||
|
||||
class TestS3StorageBackendCreation:
    """Tests for S3StorageBackend instantiation."""

    def test_create_with_bucket_name(self, mock_boto3_client: MagicMock) -> None:
        """The bucket name is stored on the backend."""
        from shared.storage.s3 import S3StorageBackend

        backend = S3StorageBackend(bucket_name="test-bucket")

        assert backend.bucket_name == "test-bucket"

    def test_create_with_region(self, mock_boto3_client: MagicMock) -> None:
        """region_name is forwarded to boto3.client."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                region_name="us-west-2",
            )

            client_factory.assert_called_once()
            assert client_factory.call_args[1].get("region_name") == "us-west-2"

    def test_create_with_credentials(self, mock_boto3_client: MagicMock) -> None:
        """Explicit AWS credentials are forwarded to boto3.client."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                access_key_id="AKIATEST",
                secret_access_key="secret123",
            )

            client_factory.assert_called_once()
            kwargs = client_factory.call_args[1]
            assert kwargs.get("aws_access_key_id") == "AKIATEST"
            assert kwargs.get("aws_secret_access_key") == "secret123"

    def test_create_with_endpoint_url(self, mock_boto3_client: MagicMock) -> None:
        """A custom endpoint (for S3-compatible services) is forwarded."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                endpoint_url="http://localhost:9000",
            )

            client_factory.assert_called_once()
            assert client_factory.call_args[1].get("endpoint_url") == "http://localhost:9000"

    def test_create_bucket_when_requested(self, mock_boto3_client: MagicMock) -> None:
        """create_bucket=True creates the bucket when head_bucket 404s."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # head_bucket 404 -> bucket is missing, so it must be created.
        mock_boto3_client.head_bucket.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadBucket"
        )

        S3StorageBackend(
            bucket_name="test-bucket",
            create_bucket=True,
        )

        mock_boto3_client.create_bucket.assert_called_once()

    def test_is_storage_backend_subclass(self, mock_boto3_client: MagicMock) -> None:
        """S3StorageBackend satisfies the StorageBackend interface."""
        from shared.storage.base import StorageBackend
        from shared.storage.s3 import S3StorageBackend

        backend = S3StorageBackend(bucket_name="test-bucket")

        assert isinstance(backend, StorageBackend)
||||
|
||||
class TestS3StorageBackendUpload:
    """Tests for S3StorageBackend.upload method."""

    def test_upload_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path, sample_file: Path
    ) -> None:
        """Uploading a fresh key calls upload_file and returns the key."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # head_object 404 -> the object does not exist yet.
        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )

        backend = S3StorageBackend(bucket_name="test-bucket")
        key = backend.upload(sample_file, "uploads/sample.txt")

        assert key == "uploads/sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """An existing key without overwrite raises StorageError."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # key already present

        backend = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(StorageError, match="already exists"):
            backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """overwrite=True replaces an existing key."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # key already present

        backend = S3StorageBackend(bucket_name="test-bucket")
        key = backend.upload(sample_file, "sample.txt", overwrite=True)

        assert key == "sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_nonexistent_file_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """A missing local file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        backend = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(temp_dir / "nonexistent.txt", "sample.txt")
|
||||
class TestS3StorageBackendDownload:
    """Tests for S3StorageBackend.download method."""

    def test_download_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Downloading an existing key calls download_file and returns the path."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # object exists

        backend = S3StorageBackend(bucket_name="test-bucket")
        target = temp_dir / "downloaded.txt"

        assert backend.download("sample.txt", target) == target
        mock_boto3_client.download_file.assert_called_once()

    def test_download_creates_parent_directories(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Missing parent directories are created for the local target."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}

        backend = S3StorageBackend(bucket_name="test-bucket")
        target = temp_dir / "deep" / "nested" / "downloaded.txt"

        backend.download("sample.txt", target)

        assert target.parent.exists()

    def test_download_nonexistent_object_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """A 404 from head_object raises FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )

        backend = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            backend.download("nonexistent.txt", temp_dir / "file.txt")
|
||||
class TestS3StorageBackendExists:
    """Tests for S3StorageBackend.exists method."""

    def test_exists_returns_true_for_existing_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A successful head_object means exists() is True."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}

        backend = S3StorageBackend(bucket_name="test-bucket")

        assert backend.exists("sample.txt") is True

    def test_exists_returns_false_for_nonexistent_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A 404 from head_object means exists() is False."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )

        backend = S3StorageBackend(bucket_name="test-bucket")

        assert backend.exists("nonexistent.txt") is False
|
||||
class TestS3StorageBackendListFiles:
    """Tests for S3StorageBackend.list_files method."""

    def test_list_files_returns_objects(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """All object keys from the Contents list are returned."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "file1.txt"},
                {"Key": "file2.txt"},
                {"Key": "subdir/file3.txt"},
            ]
        }

        backend = S3StorageBackend(bucket_name="test-bucket")
        files = backend.list_files("")

        assert len(files) == 3
        assert "file1.txt" in files
        assert "file2.txt" in files
        assert "subdir/file3.txt" in files

    def test_list_files_with_prefix(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """The prefix is forwarded to list_objects_v2 and keys are returned."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "images/a.png"},
                {"Key": "images/b.png"},
            ]
        }

        backend = S3StorageBackend(bucket_name="test-bucket")
        files = backend.list_files("images/")

        # Fixed: `files` was assigned but never inspected; assert the
        # returned keys in addition to the request parameters.
        assert sorted(files) == ["images/a.png", "images/b.png"]
        mock_boto3_client.list_objects_v2.assert_called_with(
            Bucket="test-bucket", Prefix="images/"
        )

    def test_list_files_empty_bucket(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """An empty bucket (no Contents key) yields an empty list."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {}  # no Contents key

        backend = S3StorageBackend(bucket_name="test-bucket")

        assert backend.list_files("") == []
|
||||
class TestS3StorageBackendDelete:
|
||||
"""Tests for S3StorageBackend.delete method."""
|
||||
|
||||
def test_delete_existing_object(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test deleting an existing object."""
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.return_value = {}
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
result = backend.delete("sample.txt")
|
||||
|
||||
assert result is True
|
||||
mock_boto3_client.delete_object.assert_called_once()
|
||||
|
||||
def test_delete_nonexistent_object_returns_false(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test deleting nonexistent object returns False."""
|
||||
from botocore.exceptions import ClientError
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "404"}}, "HeadObject"
|
||||
)
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
result = backend.delete("nonexistent.txt")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestS3StorageBackendGetUrl:
|
||||
"""Tests for S3StorageBackend.get_url method."""
|
||||
|
||||
def test_get_url_returns_s3_url(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test get_url returns S3 URL."""
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.return_value = {}
|
||||
mock_boto3_client.generate_presigned_url.return_value = (
|
||||
"https://test-bucket.s3.amazonaws.com/sample.txt"
|
||||
)
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
url = backend.get_url("sample.txt")
|
||||
|
||||
assert "sample.txt" in url
|
||||
|
||||
def test_get_url_nonexistent_object_raises(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test get_url raises for nonexistent object."""
|
||||
from botocore.exceptions import ClientError
|
||||
from shared.storage.base import FileNotFoundStorageError
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "404"}}, "HeadObject"
|
||||
)
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
|
||||
with pytest.raises(FileNotFoundStorageError):
|
||||
backend.get_url("nonexistent.txt")
|
||||
|
||||
|
||||
class TestS3StorageBackendUploadBytes:
|
||||
"""Tests for S3StorageBackend.upload_bytes method."""
|
||||
|
||||
def test_upload_bytes(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test uploading bytes directly."""
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
mock_boto3_client.head_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "404"}}, "HeadObject"
|
||||
)
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
data = b"Binary content here"
|
||||
|
||||
result = backend.upload_bytes(data, "binary.dat")
|
||||
|
||||
assert result == "binary.dat"
|
||||
mock_boto3_client.put_object.assert_called_once()
|
||||
|
||||
def test_upload_bytes_fails_if_exists_without_overwrite(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test upload_bytes fails if object exists and overwrite is False."""
|
||||
from shared.storage.base import StorageError
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.return_value = {} # Object exists
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
|
||||
with pytest.raises(StorageError, match="already exists"):
|
||||
backend.upload_bytes(b"content", "sample.txt", overwrite=False)
|
||||
|
||||
|
||||
class TestS3StorageBackendDownloadBytes:
|
||||
"""Tests for S3StorageBackend.download_bytes method."""
|
||||
|
||||
def test_download_bytes(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test downloading object as bytes."""
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"Hello, World!"
|
||||
mock_boto3_client.get_object.return_value = {"Body": mock_response}
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
data = backend.download_bytes("sample.txt")
|
||||
|
||||
assert data == b"Hello, World!"
|
||||
|
||||
def test_download_bytes_nonexistent_raises(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test downloading nonexistent object as bytes."""
|
||||
from botocore.exceptions import ClientError
|
||||
from shared.storage.base import FileNotFoundStorageError
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.get_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "NoSuchKey"}}, "GetObject"
|
||||
)
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
|
||||
with pytest.raises(FileNotFoundStorageError):
|
||||
backend.download_bytes("nonexistent.txt")
|
||||
|
||||
|
||||
class TestS3StorageBackendPresignedUrl:
|
||||
"""Tests for S3StorageBackend.get_presigned_url method."""
|
||||
|
||||
def test_get_presigned_url_generates_url(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test get_presigned_url generates presigned URL."""
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.return_value = {}
|
||||
mock_boto3_client.generate_presigned_url.return_value = (
|
||||
"https://test-bucket.s3.amazonaws.com/sample.txt?X-Amz-Algorithm=..."
|
||||
)
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
url = backend.get_presigned_url("sample.txt")
|
||||
|
||||
assert "X-Amz-Algorithm" in url or "sample.txt" in url
|
||||
mock_boto3_client.generate_presigned_url.assert_called_once()
|
||||
|
||||
def test_get_presigned_url_with_custom_expiry(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test get_presigned_url uses custom expiry."""
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.return_value = {}
|
||||
mock_boto3_client.generate_presigned_url.return_value = "https://..."
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
backend.get_presigned_url("sample.txt", expires_in_seconds=7200)
|
||||
|
||||
call_args = mock_boto3_client.generate_presigned_url.call_args
|
||||
assert call_args[1].get("ExpiresIn") == 7200
|
||||
|
||||
def test_get_presigned_url_nonexistent_raises(
|
||||
self, mock_boto3_client: MagicMock
|
||||
) -> None:
|
||||
"""Test get_presigned_url raises for nonexistent object."""
|
||||
from botocore.exceptions import ClientError
|
||||
from shared.storage.base import FileNotFoundStorageError
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
mock_boto3_client.head_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "404"}}, "HeadObject"
|
||||
)
|
||||
|
||||
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||
|
||||
with pytest.raises(FileNotFoundStorageError):
|
||||
backend.get_presigned_url("nonexistent.txt")
|
||||
@@ -9,7 +9,8 @@ from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from shared.fields import FIELD_CLASSES
|
||||
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationCreate,
|
||||
|
||||
@@ -31,6 +31,7 @@ class MockAdminDocument:
|
||||
self.batch_id = kwargs.get('batch_id', None)
|
||||
self.csv_field_values = kwargs.get('csv_field_values', None)
|
||||
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
|
||||
self.category = kwargs.get('category', 'invoice')
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
@@ -67,12 +68,13 @@ class MockAdminDB:
|
||||
|
||||
def get_documents_by_token(
|
||||
self,
|
||||
admin_token,
|
||||
admin_token=None,
|
||||
status=None,
|
||||
upload_source=None,
|
||||
has_annotations=None,
|
||||
auto_label_status=None,
|
||||
batch_id=None,
|
||||
category=None,
|
||||
limit=20,
|
||||
offset=0
|
||||
):
|
||||
@@ -95,6 +97,8 @@ class MockAdminDB:
|
||||
docs = [d for d in docs if d.auto_label_status == auto_label_status]
|
||||
if batch_id:
|
||||
docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
|
||||
if category:
|
||||
docs = [d for d in docs if d.category == category]
|
||||
|
||||
total = len(docs)
|
||||
return docs[offset:offset+limit], total
|
||||
|
||||
@@ -215,8 +215,10 @@ class TestAsyncProcessingService:
|
||||
|
||||
def test_cleanup_orphan_files(self, async_service, mock_db):
|
||||
"""Test cleanup of orphan files."""
|
||||
# Create an orphan file
|
||||
# Create the async upload directory
|
||||
temp_dir = async_service._async_config.temp_upload_dir
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
orphan_file = temp_dir / "orphan-request.pdf"
|
||||
orphan_file.write_bytes(b"orphan content")
|
||||
|
||||
@@ -228,7 +230,13 @@ class TestAsyncProcessingService:
|
||||
# Mock database to say file doesn't exist
|
||||
mock_db.get_request.return_value = None
|
||||
|
||||
count = async_service._cleanup_orphan_files()
|
||||
# Mock the storage helper to return the same directory as the fixture
|
||||
with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
|
||||
mock_helper = MagicMock()
|
||||
mock_helper.get_uploads_base_path.return_value = temp_dir
|
||||
mock_storage.return_value = mock_helper
|
||||
|
||||
count = async_service._cleanup_orphan_files()
|
||||
|
||||
assert count == 1
|
||||
assert not orphan_file.exists()
|
||||
|
||||
@@ -5,7 +5,75 @@ TDD Phase 5: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
import numpy as np
|
||||
|
||||
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
|
||||
|
||||
TEST_ADMIN_TOKEN = "test-admin-token-12345"
|
||||
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
|
||||
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_token() -> str:
|
||||
"""Provide admin token for testing."""
|
||||
return TEST_ADMIN_TOKEN
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db() -> MagicMock:
|
||||
"""Create a mock AdminDB for testing."""
|
||||
mock = MagicMock()
|
||||
# Default return values
|
||||
mock.get_document_by_token.return_value = None
|
||||
mock.get_dataset.return_value = None
|
||||
mock.get_augmented_datasets.return_value = ([], 0)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
"""Create test client with admin authentication."""
|
||||
app = FastAPI()
|
||||
|
||||
# Override dependencies
|
||||
def get_token_override():
|
||||
return TEST_ADMIN_TOKEN
|
||||
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
|
||||
app.dependency_overrides[validate_admin_token] = get_token_override
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
|
||||
# Include router - the router already has /augmentation prefix
|
||||
# so we add /api/v1/admin to get /api/v1/admin/augmentation
|
||||
router = create_augmentation_router()
|
||||
app.include_router(router, prefix="/api/v1/admin")
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
"""Create test client WITHOUT admin authentication override."""
|
||||
app = FastAPI()
|
||||
|
||||
# Only override the database, NOT the token validation
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
|
||||
router = create_augmentation_router()
|
||||
app.include_router(router, prefix="/api/v1/admin")
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestAugmentationTypesEndpoint:
|
||||
@@ -34,10 +102,10 @@ class TestAugmentationTypesEndpoint:
|
||||
assert "stage" in aug_type
|
||||
|
||||
def test_list_augmentation_types_unauthorized(
|
||||
self, admin_client: TestClient
|
||||
self, unauthenticated_client: TestClient
|
||||
) -> None:
|
||||
"""Test that unauthorized request is rejected."""
|
||||
response = admin_client.get("/api/v1/admin/augmentation/types")
|
||||
response = unauthenticated_client.get("/api/v1/admin/augmentation/types")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -74,16 +142,30 @@ class TestAugmentationPreviewEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing augmentation on a document."""
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"augmentation_type": "gaussian_noise",
|
||||
"params": {"std": 15},
|
||||
},
|
||||
)
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
) as mock_load:
|
||||
mock_load.return_value = fake_image
|
||||
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"augmentation_type": "gaussian_noise",
|
||||
"params": {"std": 15},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -136,18 +218,32 @@ class TestAugmentationPreviewConfigEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing full config on a document."""
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"gaussian_noise": {"enabled": True, "probability": 1.0},
|
||||
"lighting_variation": {"enabled": True, "probability": 1.0},
|
||||
"preserve_bboxes": True,
|
||||
"seed": 42,
|
||||
},
|
||||
)
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
) as mock_load:
|
||||
mock_load.return_value = fake_image
|
||||
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"gaussian_noise": {"enabled": True, "probability": 1.0},
|
||||
"lighting_variation": {"enabled": True, "probability": 1.0},
|
||||
"preserve_bboxes": True,
|
||||
"seed": 42,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -164,8 +260,14 @@ class TestAugmentationBatchEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_dataset_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test creating augmented dataset."""
|
||||
# Mock dataset exists
|
||||
mock_dataset = MagicMock()
|
||||
mock_dataset.total_images = 100
|
||||
mock_admin_db.get_dataset.return_value = mock_dataset
|
||||
|
||||
response = admin_client.post(
|
||||
"/api/v1/admin/augmentation/batch",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
@@ -250,12 +352,10 @@ class TestAugmentedDatasetsListEndpoint:
|
||||
@pytest.fixture
|
||||
def sample_document_id() -> str:
|
||||
"""Provide a sample document ID for testing."""
|
||||
# This would need to be created in test setup
|
||||
return "test-document-id"
|
||||
return TEST_DOCUMENT_UUID
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset_id() -> str:
|
||||
"""Provide a sample dataset ID for testing."""
|
||||
# This would need to be created in test setup
|
||||
return "test-dataset-id"
|
||||
return TEST_DATASET_UUID
|
||||
|
||||
@@ -35,6 +35,8 @@ def _make_dataset(**overrides) -> MagicMock:
|
||||
name="test-dataset",
|
||||
description="Test dataset",
|
||||
status="ready",
|
||||
training_status=None,
|
||||
active_training_task_id=None,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
@@ -183,6 +185,8 @@ class TestListDatasetsRoute:
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
|
||||
# Mock the active training tasks lookup to return empty dict
|
||||
mock_db.get_active_training_tasks_for_datasets.return_value = {}
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||
|
||||
|
||||
363
tests/web/test_dataset_training_status.py
Normal file
363
tests/web/test_dataset_training_status.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Tests for dataset training status feature.
|
||||
|
||||
Tests cover:
|
||||
1. Database model fields (training_status, active_training_task_id)
|
||||
2. AdminDB update_dataset_training_status method
|
||||
3. API response includes training status fields
|
||||
4. Scheduler updates dataset status during training lifecycle
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Database Model
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTrainingDatasetModel:
|
||||
"""Tests for TrainingDataset model fields."""
|
||||
|
||||
def test_training_dataset_has_training_status_field(self):
|
||||
"""TrainingDataset model should have training_status field."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(
|
||||
name="test-dataset",
|
||||
training_status="running",
|
||||
)
|
||||
assert dataset.training_status == "running"
|
||||
|
||||
def test_training_dataset_has_active_training_task_id_field(self):
|
||||
"""TrainingDataset model should have active_training_task_id field."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
task_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
name="test-dataset",
|
||||
active_training_task_id=task_id,
|
||||
)
|
||||
assert dataset.active_training_task_id == task_id
|
||||
|
||||
def test_training_dataset_defaults(self):
|
||||
"""TrainingDataset should have correct defaults for new fields."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test-dataset")
|
||||
assert dataset.training_status is None
|
||||
assert dataset.active_training_task_id is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test AdminDB Methods
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdminDBDatasetTrainingStatus:
|
||||
"""Tests for AdminDB.update_dataset_training_status method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
def test_update_dataset_training_status_sets_status(self, mock_session):
|
||||
"""update_dataset_training_status should set training_status."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="running",
|
||||
)
|
||||
|
||||
assert dataset.training_status == "running"
|
||||
mock_session.add.assert_called_once_with(dataset)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_dataset_training_status_sets_task_id(self, mock_session):
|
||||
"""update_dataset_training_status should set active_training_task_id."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
task_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="running",
|
||||
active_training_task_id=str(task_id),
|
||||
)
|
||||
|
||||
assert dataset.active_training_task_id == task_id
|
||||
|
||||
def test_update_dataset_training_status_updates_main_status_on_complete(
|
||||
self, mock_session
|
||||
):
|
||||
"""update_dataset_training_status should update main status to 'trained' when completed."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="completed",
|
||||
update_main_status=True,
|
||||
)
|
||||
|
||||
assert dataset.status == "trained"
|
||||
assert dataset.training_status == "completed"
|
||||
|
||||
def test_update_dataset_training_status_clears_task_id_on_complete(
|
||||
self, mock_session
|
||||
):
|
||||
"""update_dataset_training_status should clear task_id when training completes."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
task_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
training_status="running",
|
||||
active_training_task_id=task_id,
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="completed",
|
||||
active_training_task_id=None,
|
||||
)
|
||||
|
||||
assert dataset.active_training_task_id is None
|
||||
|
||||
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
|
||||
"""update_dataset_training_status should handle missing dataset gracefully."""
|
||||
mock_session.get.return_value = None
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
# Should not raise
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(uuid4()),
|
||||
training_status="running",
|
||||
)
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test API Response
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDatasetDetailResponseTrainingStatus:
|
||||
"""Tests for DatasetDetailResponse including training status fields."""
|
||||
|
||||
def test_dataset_detail_response_includes_training_status(self):
|
||||
"""DatasetDetailResponse schema should include training_status field."""
|
||||
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
|
||||
response = DatasetDetailResponse(
|
||||
dataset_id=str(uuid4()),
|
||||
name="test-dataset",
|
||||
description=None,
|
||||
status="ready",
|
||||
training_status="running",
|
||||
active_training_task_id=str(uuid4()),
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=10,
|
||||
total_images=15,
|
||||
total_annotations=100,
|
||||
dataset_path="/path/to/dataset",
|
||||
error_message=None,
|
||||
documents=[],
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert response.training_status == "running"
|
||||
assert response.active_training_task_id is not None
|
||||
|
||||
def test_dataset_detail_response_allows_null_training_status(self):
|
||||
"""DatasetDetailResponse should allow null training_status."""
|
||||
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
|
||||
response = DatasetDetailResponse(
|
||||
dataset_id=str(uuid4()),
|
||||
name="test-dataset",
|
||||
description=None,
|
||||
status="ready",
|
||||
training_status=None,
|
||||
active_training_task_id=None,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=10,
|
||||
total_images=15,
|
||||
total_annotations=100,
|
||||
dataset_path=None,
|
||||
error_message=None,
|
||||
documents=[],
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert response.training_status is None
|
||||
assert response.active_training_task_id is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Scheduler Training Status Updates
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSchedulerDatasetStatusUpdates:
|
||||
"""Tests for scheduler updating dataset status during training."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(self):
|
||||
"""Create mock AdminDB."""
|
||||
mock = MagicMock()
|
||||
mock.get_dataset.return_value = MagicMock(
|
||||
dataset_id=uuid4(),
|
||||
name="test-dataset",
|
||||
dataset_path="/path/to/dataset",
|
||||
total_images=100,
|
||||
)
|
||||
mock.get_pending_training_tasks.return_value = []
|
||||
return mock
|
||||
|
||||
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
|
||||
"""Scheduler should set dataset training_status to 'running' when task starts."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
|
||||
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
|
||||
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
scheduler._db = mock_db
|
||||
|
||||
task_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Execute task (will fail but we check the status update call)
|
||||
try:
|
||||
scheduler._execute_task(
|
||||
task_id=task_id,
|
||||
config={"model_name": "yolo11n.pt"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
except Exception:
|
||||
pass # Expected to fail in test environment
|
||||
|
||||
# Check that training status was updated to running
|
||||
mock_db.update_dataset_training_status.assert_called()
|
||||
first_call = mock_db.update_dataset_training_status.call_args_list[0]
|
||||
assert first_call.kwargs["training_status"] == "running"
|
||||
assert first_call.kwargs["active_training_task_id"] == task_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Dataset Status Values
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDatasetStatusValues:
    """Verify which status values a TrainingDataset accepts."""

    def test_dataset_status_building(self):
        """A dataset may be created with status 'building'."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="building")
        assert ds.status == "building"

    def test_dataset_status_ready(self):
        """A dataset may be created with status 'ready'."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="ready")
        assert ds.status == "ready"

    def test_dataset_status_trained(self):
        """A dataset may be created with status 'trained'."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="trained")
        assert ds.status == "trained"

    def test_dataset_status_failed(self):
        """A dataset may be created with status 'failed'."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="failed")
        assert ds.status == "failed"

    def test_training_status_values(self):
        """Every recognized training status value is stored verbatim."""
        from inference.data.admin_models import TrainingDataset

        for value in ("pending", "scheduled", "running", "completed", "failed", "cancelled"):
            assert TrainingDataset(name="test", training_status=value).training_status == value
|
||||
207
tests/web/test_document_category.py
Normal file
207
tests/web/test_document_category.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Tests for Document Category Feature.
|
||||
|
||||
TDD tests for adding category field to admin_documents table.
|
||||
Documents can be categorized (e.g., invoice, letter, receipt) for training different models.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from inference.data.admin_models import AdminDocument
|
||||
|
||||
|
||||
# Test constants
|
||||
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
class TestAdminDocumentCategoryField:
    """Behavior of the ``category`` attribute on AdminDocument."""

    @staticmethod
    def _make_doc(**overrides):
        """Build an AdminDocument with standard test values; overrides win."""
        kwargs = dict(
            document_id=UUID(TEST_DOC_UUID),
            filename="test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/path/to/file.pdf",
        )
        kwargs.update(overrides)
        return AdminDocument(**kwargs)

    def test_document_has_category_field(self):
        """The model exposes a ``category`` attribute."""
        assert hasattr(self._make_doc(), "category")

    def test_document_category_defaults_to_invoice(self):
        """Omitting ``category`` yields the default value 'invoice'."""
        assert self._make_doc().category == "invoice"

    def test_document_accepts_custom_category(self):
        """Arbitrary category strings are stored unchanged."""
        for cat in ("invoice", "letter", "receipt", "contract", "custom_type"):
            doc = self._make_doc(document_id=uuid4(), category=cat)
            assert doc.category == cat

    def test_document_category_is_string_type(self):
        """The stored category value is a plain string."""
        assert isinstance(self._make_doc(category="letter").category, str)
|
||||
|
||||
|
||||
class TestDocumentCategoryInReadModel:
    """Category exposure in the read/response models."""

    def test_admin_document_read_has_category(self):
        """AdminDocumentRead declares a ``category`` field in its schema."""
        from inference.data.admin_models import AdminDocumentRead

        assert "category" in AdminDocumentRead.model_fields
|
||||
|
||||
|
||||
class TestDocumentCategoryAPI:
    """Category handling in the document API response schemas."""

    @pytest.fixture
    def mock_admin_db(self):
        """Mock AdminDB that accepts any admin token."""
        mock = MagicMock()
        mock.is_valid_admin_token.return_value = True
        return mock

    def test_upload_document_with_category(self, mock_admin_db):
        """DocumentUploadResponse round-trips an explicit category."""
        from inference.web.schemas.admin import DocumentUploadResponse

        resp = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            message="Upload successful",
            category="letter",
        )
        assert resp.category == "letter"

    def test_list_documents_returns_category(self, mock_admin_db):
        """DocumentItem carries the category value it was built with."""
        from inference.web.schemas.admin import DocumentItem

        entry = DocumentItem(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            annotation_count=0,
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
            category="invoice",
        )
        assert entry.category == "invoice"

    def test_document_detail_includes_category(self, mock_admin_db):
        """DocumentDetailResponse schema declares a ``category`` field."""
        from inference.web.schemas.admin import DocumentDetailResponse

        assert "category" in DocumentDetailResponse.model_fields
|
||||
|
||||
|
||||
class TestDocumentCategoryFiltering:
    """Filtering documents by their category."""

    @pytest.fixture
    def mock_admin_db(self):
        """Mock AdminDB whose category query returns a single invoice doc."""
        mock = MagicMock()
        mock.is_valid_admin_token.return_value = True

        doc_invoice = MagicMock()
        doc_invoice.document_id = uuid4()
        doc_invoice.category = "invoice"

        doc_letter = MagicMock()
        doc_letter.document_id = uuid4()
        doc_letter.category = "letter"

        mock.get_documents_by_category.return_value = [doc_invoice]
        return mock

    def test_filter_documents_by_category(self, mock_admin_db):
        """get_documents_by_category returns only matching documents."""
        docs = mock_admin_db.get_documents_by_category("invoice")
        assert len(docs) == 1
        assert docs[0].category == "invoice"
|
||||
|
||||
|
||||
class TestDocumentCategoryUpdate:
    """Updating a document's category via the update request schema."""

    def test_update_document_category_schema(self):
        """DocumentUpdateRequest stores an explicit category."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        assert DocumentUpdateRequest(category="letter").category == "letter"

    def test_update_document_category_optional(self):
        """Category may be omitted; it then defaults to None."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        assert DocumentUpdateRequest().category is None
|
||||
|
||||
|
||||
class TestDatasetWithCategory:
    """Dataset creation with an optional category filter."""

    def test_dataset_create_with_category_filter(self):
        """DatasetCreateRequest accepts a category filter."""
        from inference.web.schemas.admin import DatasetCreateRequest

        req = DatasetCreateRequest(
            name="Invoice Training Set",
            document_ids=[TEST_DOC_UUID],
            category="invoice",  # optional filter
        )
        assert req.category == "invoice"

    def test_dataset_create_category_is_optional(self):
        """Omitting the category filter is allowed."""
        from inference.web.schemas.admin import DatasetCreateRequest

        req = DatasetCreateRequest(
            name="Mixed Training Set",
            document_ids=[TEST_DOC_UUID],
        )
        # Either the field is absent entirely or it defaults to None.
        assert not hasattr(req, "category") or req.category is None
|
||||
165
tests/web/test_document_category_api.py
Normal file
165
tests/web/test_document_category_api.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Tests for Document Category API Endpoints.
|
||||
|
||||
TDD tests for category filtering and management in document endpoints.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# Test constants
|
||||
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
class TestGetCategoriesEndpoint:
    """Schema checks for GET /admin/documents/categories."""

    def test_categories_endpoint_returns_list(self):
        """DocumentCategoriesResponse carries a category list and a total."""
        from inference.web.schemas.admin import DocumentCategoriesResponse

        resp = DocumentCategoriesResponse(
            categories=["invoice", "letter", "receipt"],
            total=3,
        )
        assert resp.categories == ["invoice", "letter", "receipt"]
        assert resp.total == 3

    def test_categories_response_schema(self):
        """The response schema declares both expected fields."""
        from inference.web.schemas.admin import DocumentCategoriesResponse

        for field in ("categories", "total"):
            assert field in DocumentCategoriesResponse.model_fields
|
||||
|
||||
|
||||
class TestDocumentListFilterByCategory:
    """Filtering the document list by category."""

    @pytest.fixture
    def mock_admin_db(self):
        """Mock AdminDB preloaded with one invoice doc and three categories."""
        mock = MagicMock()
        mock.is_valid_admin_token.return_value = True

        doc_invoice = MagicMock()
        doc_invoice.document_id = uuid4()
        doc_invoice.category = "invoice"
        doc_invoice.filename = "invoice1.pdf"

        doc_letter = MagicMock()
        doc_letter.document_id = uuid4()
        doc_letter.category = "letter"
        doc_letter.filename = "letter1.pdf"

        mock.get_documents.return_value = ([doc_invoice], 1)
        mock.get_document_categories.return_value = ["invoice", "letter", "receipt"]
        return mock

    def test_list_documents_accepts_category_filter(self, mock_admin_db):
        """The list endpoint's response schema exists (?category=... filter)."""
        from inference.web.schemas.admin import DocumentListResponse

        assert DocumentListResponse is not None

    def test_get_document_categories_from_db(self, mock_admin_db):
        """Unique categories are fetched from the database layer."""
        cats = mock_admin_db.get_document_categories()
        assert "invoice" in cats
        assert "letter" in cats
        assert len(cats) == 3
|
||||
|
||||
|
||||
class TestDocumentUploadWithCategory:
    """Tests for uploading documents with category."""

    def test_upload_request_accepts_category(self):
        """Category arrives as multipart form data, not a schema field.

        Fixed: this was a bare ``pass``, which silently reported success
        for a test that checked nothing; skipping makes the missing
        coverage visible in the test report instead.
        """
        pytest.skip("category form-field upload needs a TestClient integration test")

    def test_upload_response_includes_category(self):
        """Test upload response includes the category that was set."""
        from inference.web.schemas.admin import DocumentUploadResponse

        response = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            category="letter",  # Custom category
            message="Upload successful",
        )
        assert response.category == "letter"

    def test_upload_defaults_to_invoice_category(self):
        """Test upload defaults to 'invoice' if no category specified."""
        from inference.web.schemas.admin import DocumentUploadResponse

        response = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            message="Upload successful",
            # No category specified - should default to "invoice"
        )
        assert response.category == "invoice"
|
||||
|
||||
|
||||
class TestAdminDBCategoryMethods:
    """AdminDB surface area for category queries."""

    def test_get_document_categories_method_exists(self):
        """AdminDB exposes get_document_categories."""
        from inference.data.admin_db import AdminDB

        assert hasattr(AdminDB(), "get_document_categories")

    def test_get_documents_accepts_category_filter(self):
        """get_documents_by_token is callable and takes a category parameter."""
        import inspect

        from inference.data.admin_db import AdminDB

        method = getattr(AdminDB(), "get_documents_by_token", None)
        assert callable(method)
        # The filter must appear in the method's signature.
        assert "category" in inspect.signature(method).parameters
|
||||
|
||||
|
||||
class TestUpdateDocumentCategory:
    """Updating a document's category through the DB layer and schema."""

    def test_update_document_category_method_exists(self):
        """AdminDB exposes update_document_category."""
        from inference.data.admin_db import AdminDB

        assert hasattr(AdminDB(), "update_document_category")

    def test_update_request_schema(self):
        """DocumentUpdateRequest accepts a category update."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        assert DocumentUpdateRequest(category="receipt").category == "receipt"
|
||||
@@ -32,10 +32,10 @@ def test_app(tmp_path):
|
||||
use_gpu=False,
|
||||
dpi=150,
|
||||
),
|
||||
storage=StorageConfig(
|
||||
file=StorageConfig(
|
||||
upload_dir=upload_dir,
|
||||
result_dir=result_dir,
|
||||
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
|
||||
allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
|
||||
max_file_size_mb=50,
|
||||
),
|
||||
)
|
||||
@@ -252,20 +252,25 @@ class TestResultsEndpoint:
|
||||
response = client.get("/api/v1/results/nonexistent.png")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
|
||||
def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
|
||||
"""Test that existing result file is returned."""
|
||||
# Get storage config from app
|
||||
storage_config = test_app.extra.get("storage_config")
|
||||
if not storage_config:
|
||||
pytest.skip("Storage config not available in test app")
|
||||
|
||||
# Create a test result file
|
||||
result_file = storage_config.result_dir / "test_result.png"
|
||||
# Create a test result file in temp directory
|
||||
result_dir = tmp_path / "results"
|
||||
result_dir.mkdir(exist_ok=True)
|
||||
result_file = result_dir / "test_result.png"
|
||||
img = Image.new('RGB', (100, 100), color='red')
|
||||
img.save(result_file)
|
||||
|
||||
# Request the file
|
||||
response = client.get("/api/v1/results/test_result.png")
|
||||
# Mock the storage helper to return our test file path
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper"
|
||||
) as mock_storage:
|
||||
mock_helper = Mock()
|
||||
mock_helper.get_result_local_path.return_value = result_file
|
||||
mock_storage.return_value = mock_helper
|
||||
|
||||
# Request the file
|
||||
response = client.get("/api/v1/results/test_result.png")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "image/png"
|
||||
|
||||
@@ -266,7 +266,11 @@ class TestActivateModelVersionRoute:
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
# Create mock request with app state
|
||||
mock_request = MagicMock()
|
||||
mock_request.app.state.inference_service = None
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
|
||||
assert result.status == "active"
|
||||
@@ -278,10 +282,14 @@ class TestActivateModelVersionRoute:
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = None
|
||||
|
||||
# Create mock request with app state
|
||||
mock_request = MagicMock()
|
||||
mock_request.app.state.inference_service = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
|
||||
828
tests/web/test_storage_helpers.py
Normal file
828
tests/web/test_storage_helpers.py
Normal file
@@ -0,0 +1,828 @@
|
||||
"""Tests for storage helpers module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
|
||||
from shared.storage import PREFIXES
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage() -> MagicMock:
|
||||
"""Create a mock storage backend."""
|
||||
storage = MagicMock()
|
||||
storage.upload_bytes = MagicMock()
|
||||
storage.download_bytes = MagicMock(return_value=b"test content")
|
||||
storage.get_presigned_url = MagicMock(return_value="https://example.com/file")
|
||||
storage.exists = MagicMock(return_value=True)
|
||||
storage.delete = MagicMock(return_value=True)
|
||||
storage.list_files = MagicMock(return_value=[])
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def helper(mock_storage: MagicMock) -> StorageHelper:
|
||||
"""Create a storage helper with mock backend."""
|
||||
return StorageHelper(storage=mock_storage)
|
||||
|
||||
|
||||
class TestStorageHelperInit:
    """Construction and backend wiring of StorageHelper."""

    def test_init_with_storage(self, mock_storage: MagicMock) -> None:
        """The helper keeps a reference to the backend it was given."""
        assert StorageHelper(storage=mock_storage).storage is mock_storage

    def test_storage_property(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """The ``storage`` property exposes the injected backend."""
        assert helper.storage is mock_storage
|
||||
|
||||
|
||||
class TestDocumentOperations:
    """Document upload/download/URL/exists/delete via StorageHelper."""

    def test_upload_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Uploading stores the bytes under documents/<id>.pdf."""
        returned_id, key = helper.upload_document(b"pdf content", "invoice.pdf", "doc123")

        assert returned_id == "doc123"
        assert key == "documents/doc123.pdf"
        mock_storage.upload_bytes.assert_called_once_with(
            b"pdf content", "documents/doc123.pdf", overwrite=True
        )

    def test_upload_document_generates_id(self, helper: StorageHelper) -> None:
        """A non-empty document ID is generated when none is supplied."""
        returned_id, key = helper.upload_document(b"content", "file.pdf")

        assert returned_id is not None
        assert len(returned_id) > 0
        assert key.startswith("documents/")

    def test_download_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Downloading reads from documents/<id>.pdf."""
        assert helper.download_document("doc123") == b"test content"
        mock_storage.download_bytes.assert_called_once_with("documents/doc123.pdf")

    def test_get_document_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """A presigned URL is requested with the given expiry."""
        url = helper.get_document_url("doc123", expires_in_seconds=7200)

        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "documents/doc123.pdf", 7200
        )

    def test_document_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Existence check delegates to the backend."""
        assert helper.document_exists("doc123") is True
        mock_storage.exists.assert_called_once_with("documents/doc123.pdf")

    def test_delete_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Deletion delegates to the backend."""
        assert helper.delete_document("doc123") is True
        mock_storage.delete.assert_called_once_with("documents/doc123.pdf")
|
||||
|
||||
|
||||
class TestImageOperations:
    """Per-page image storage via StorageHelper."""

    def test_save_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Page images are stored under images/<doc>/page_<n>.png."""
        key = helper.save_page_image("doc123", 1, b"image data")

        assert key == "images/doc123/page_1.png"
        mock_storage.upload_bytes.assert_called_once_with(
            b"image data", "images/doc123/page_1.png", overwrite=True
        )

    def test_get_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Page images are read back from their per-page key."""
        assert helper.get_page_image("doc123", 2) == b"test content"
        mock_storage.download_bytes.assert_called_once_with("images/doc123/page_2.png")

    def test_get_page_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Presigned URLs use the default one-hour expiry."""
        assert helper.get_page_image_url("doc123", 3) == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "images/doc123/page_3.png", 3600
        )

    def test_delete_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Every listed page image for a document is removed."""
        mock_storage.list_files.return_value = [
            "images/doc123/page_1.png",
            "images/doc123/page_2.png",
        ]

        assert helper.delete_document_images("doc123") == 2
        mock_storage.list_files.assert_called_once_with("images/doc123/")

    def test_list_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Listing returns every stored page image key."""
        mock_storage.list_files.return_value = ["images/doc123/page_1.png"]

        assert helper.list_document_images("doc123") == ["images/doc123/page_1.png"]
|
||||
|
||||
|
||||
class TestUploadOperations:
    """Tests for upload staging operations."""

    def test_save_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save upload to correct path."""
        path = helper.save_upload(b"content", "file.pdf")

        assert path == "uploads/file.pdf"
        mock_storage.upload_bytes.assert_called_once()

    def test_save_upload_with_subfolder(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save upload with subfolder."""
        path = helper.save_upload(b"content", "file.pdf", "async")

        assert path == "uploads/async/file.pdf"

    def test_get_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get upload from correct path and return its content."""
        content = helper.get_upload("file.pdf", "async")

        # Fixed: content was fetched but never verified against the mock's
        # known return value.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("uploads/async/file.pdf")

    def test_delete_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should delete upload."""
        result = helper.delete_upload("file.pdf")

        assert result is True
        mock_storage.delete.assert_called_once_with("uploads/file.pdf")
|
||||
|
||||
|
||||
class TestResultOperations:
    """Tests for result file operations."""

    def test_save_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save result to correct path."""
        path = helper.save_result(b"result data", "output.json")

        assert path == "results/output.json"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get result from correct path and return its content."""
        content = helper.get_result("output.json")

        # Fixed: content was fetched but never verified against the mock's
        # known return value.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("results/output.json")

    def test_get_result_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get presigned URL for result."""
        url = helper.get_result_url("output.json")

        # Fixed: url was fetched but never verified.
        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with("results/output.json", 3600)

    def test_result_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should check result existence."""
        exists = helper.result_exists("output.json")

        assert exists is True
        mock_storage.exists.assert_called_once_with("results/output.json")
|
||||
|
||||
|
||||
class TestExportOperations:
    """Tests for export file operations."""

    def test_save_export(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save export to correct path."""
        path = helper.save_export(b"export data", "exp123", "dataset.zip")

        assert path == "exports/exp123/dataset.zip"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_export_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get presigned URL for export."""
        url = helper.get_export_url("exp123", "dataset.zip")

        # Fixed: url was fetched but never verified against the mock's
        # known return value.
        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "exports/exp123/dataset.zip", 3600
        )
|
||||
|
||||
|
||||
class TestRawPdfOperations:
    """Tests for raw PDF operations (legacy compatibility)."""

    def test_save_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save raw PDF to correct path."""
        path = helper.save_raw_pdf(b"pdf data", "invoice.pdf")

        assert path == "raw_pdfs/invoice.pdf"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get raw PDF from correct path and return its content."""
        content = helper.get_raw_pdf("invoice.pdf")

        # Fixed: content was fetched but never verified against the mock's
        # known return value.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("raw_pdfs/invoice.pdf")

    def test_raw_pdf_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should check raw PDF existence."""
        exists = helper.raw_pdf_exists("invoice.pdf")

        assert exists is True
        mock_storage.exists.assert_called_once_with("raw_pdfs/invoice.pdf")
|
||||
|
||||
|
||||
class TestAdminImageOperations:
    """Admin image storage via StorageHelper."""

    def test_save_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Admin images are stored under admin_images/<doc>/page_<n>.png."""
        key = helper.save_admin_image("doc123", 1, b"image data")

        assert key == "admin_images/doc123/page_1.png"
        mock_storage.upload_bytes.assert_called_once_with(
            b"image data", "admin_images/doc123/page_1.png", overwrite=True
        )

    def test_get_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Admin images are read back from their per-page key."""
        assert helper.get_admin_image("doc123", 2) == b"test content"
        mock_storage.download_bytes.assert_called_once_with("admin_images/doc123/page_2.png")

    def test_get_admin_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Presigned URLs use the default one-hour expiry."""
        assert helper.get_admin_image_url("doc123", 3) == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "admin_images/doc123/page_3.png", 3600
        )

    def test_admin_image_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Existence check delegates to the backend."""
        assert helper.admin_image_exists("doc123", 1) is True
        mock_storage.exists.assert_called_once_with("admin_images/doc123/page_1.png")

    def test_get_admin_image_path(self, helper: StorageHelper) -> None:
        """Path construction follows the admin-image naming scheme."""
        assert helper.get_admin_image_path("doc123", 2) == "admin_images/doc123/page_2.png"

    def test_list_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Listing returns every stored admin image key."""
        keys = [
            "admin_images/doc123/page_1.png",
            "admin_images/doc123/page_2.png",
        ]
        mock_storage.list_files.return_value = keys

        assert helper.list_admin_images("doc123") == keys
        mock_storage.list_files.assert_called_once_with("admin_images/doc123/")

    def test_delete_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Every listed admin image for a document is removed."""
        mock_storage.list_files.return_value = [
            "admin_images/doc123/page_1.png",
            "admin_images/doc123/page_2.png",
        ]

        assert helper.delete_admin_images("doc123") == 2
        mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
|
||||
|
||||
|
||||
class TestGetLocalPath:
    """Resolving admin images to local filesystem paths."""

    def test_get_admin_image_local_path_with_local_storage(self) -> None:
        """A local backend resolves an existing image to a real path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))

            # Place a test image where the helper expects it.
            image_dir = Path(root) / "admin_images" / "doc123"
            image_dir.mkdir(parents=True, exist_ok=True)
            (image_dir / "page_1.png").write_bytes(b"test image")

            resolved = local_helper.get_admin_image_local_path("doc123", 1)

            assert resolved is not None
            assert resolved.exists()
            assert resolved.name == "page_1.png"

    def test_get_admin_image_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Backends without local-path support yield None."""
        # Simulate a cloud backend: get_local_path always returns None.
        mock_storage.get_local_path = MagicMock(return_value=None)
        cloud_helper = StorageHelper(storage=mock_storage)

        assert cloud_helper.get_admin_image_local_path("doc123", 1) is None

    def test_get_admin_image_local_path_nonexistent_file(self) -> None:
        """A missing file resolves to None."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))

            assert local_helper.get_admin_image_local_path("nonexistent", 1) is None
|
||||
|
||||
|
||||
class TestGetAdminImageDimensions:
|
||||
"""Tests for get_admin_image_dimensions method."""
|
||||
|
||||
def test_get_dimensions_with_local_storage(self) -> None:
|
||||
"""Should return image dimensions when using local storage."""
|
||||
from pathlib import Path
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
from PIL import Image
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
# Create a test image with known dimensions
|
||||
test_path = Path(temp_dir) / "admin_images" / "doc123"
|
||||
test_path.mkdir(parents=True, exist_ok=True)
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
img.save(test_path / "page_1.png")
|
||||
|
||||
dimensions = helper.get_admin_image_dimensions("doc123", 1)
|
||||
|
||||
assert dimensions == (800, 600)
|
||||
|
||||
def test_get_dimensions_nonexistent_file(self) -> None:
|
||||
"""Should return None when file doesn't exist."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
dimensions = helper.get_admin_image_dimensions("nonexistent", 1)
|
||||
|
||||
assert dimensions is None
|
||||
|
||||
|
||||
class TestGetStorageHelper:
|
||||
"""Tests for get_storage_helper function."""
|
||||
|
||||
def test_returns_helper_instance(self) -> None:
|
||||
"""Should return a StorageHelper instance."""
|
||||
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||
mock_get.return_value = MagicMock()
|
||||
# Reset the global helper
|
||||
import inference.web.services.storage_helpers as module
|
||||
module._default_helper = None
|
||||
|
||||
helper = get_storage_helper()
|
||||
|
||||
assert isinstance(helper, StorageHelper)
|
||||
|
||||
def test_returns_same_instance(self) -> None:
|
||||
"""Should return the same instance on subsequent calls."""
|
||||
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||
mock_get.return_value = MagicMock()
|
||||
import inference.web.services.storage_helpers as module
|
||||
module._default_helper = None
|
||||
|
||||
helper1 = get_storage_helper()
|
||||
helper2 = get_storage_helper()
|
||||
|
||||
assert helper1 is helper2
|
||||
|
||||
|
||||
class TestDeleteResult:
|
||||
"""Tests for delete_result method."""
|
||||
|
||||
def test_delete_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||
"""Should delete result file."""
|
||||
result = helper.delete_result("output.json")
|
||||
|
||||
assert result is True
|
||||
mock_storage.delete.assert_called_once_with("results/output.json")
|
||||
|
||||
|
||||
class TestResultLocalPath:
|
||||
"""Tests for get_result_local_path method."""
|
||||
|
||||
def test_get_result_local_path_with_local_storage(self) -> None:
|
||||
"""Should return local path when using local storage backend."""
|
||||
from pathlib import Path
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
# Create a test result file
|
||||
test_path = Path(temp_dir) / "results"
|
||||
test_path.mkdir(parents=True, exist_ok=True)
|
||||
(test_path / "output.json").write_bytes(b"test result")
|
||||
|
||||
local_path = helper.get_result_local_path("output.json")
|
||||
|
||||
assert local_path is not None
|
||||
assert local_path.exists()
|
||||
assert local_path.name == "output.json"
|
||||
|
||||
def test_get_result_local_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when storage doesn't support local paths."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
local_path = helper.get_result_local_path("output.json")
|
||||
assert local_path is None
|
||||
|
||||
def test_get_result_local_path_nonexistent_file(self) -> None:
|
||||
"""Should return None when file doesn't exist."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
local_path = helper.get_result_local_path("nonexistent.json")
|
||||
|
||||
assert local_path is None
|
||||
|
||||
|
||||
class TestResultsBasePath:
|
||||
"""Tests for get_results_base_path method."""
|
||||
|
||||
def test_get_results_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from pathlib import Path
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_results_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "results"
|
||||
|
||||
def test_get_results_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_results_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestUploadLocalPath:
|
||||
"""Tests for get_upload_local_path method."""
|
||||
|
||||
def test_get_upload_local_path_with_local_storage(self) -> None:
|
||||
"""Should return local path when using local storage backend."""
|
||||
from pathlib import Path
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
# Create a test upload file
|
||||
test_path = Path(temp_dir) / "uploads"
|
||||
test_path.mkdir(parents=True, exist_ok=True)
|
||||
(test_path / "file.pdf").write_bytes(b"test upload")
|
||||
|
||||
local_path = helper.get_upload_local_path("file.pdf")
|
||||
|
||||
assert local_path is not None
|
||||
assert local_path.exists()
|
||||
assert local_path.name == "file.pdf"
|
||||
|
||||
def test_get_upload_local_path_with_subfolder(self) -> None:
|
||||
"""Should return local path with subfolder."""
|
||||
from pathlib import Path
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
# Create a test upload file with subfolder
|
||||
test_path = Path(temp_dir) / "uploads" / "async"
|
||||
test_path.mkdir(parents=True, exist_ok=True)
|
||||
(test_path / "file.pdf").write_bytes(b"test upload")
|
||||
|
||||
local_path = helper.get_upload_local_path("file.pdf", "async")
|
||||
|
||||
assert local_path is not None
|
||||
assert local_path.exists()
|
||||
|
||||
def test_get_upload_local_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
local_path = helper.get_upload_local_path("file.pdf")
|
||||
assert local_path is None
|
||||
|
||||
|
||||
class TestUploadsBasePath:
|
||||
"""Tests for get_uploads_base_path method."""
|
||||
|
||||
def test_get_uploads_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_uploads_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "uploads"
|
||||
|
||||
def test_get_uploads_base_path_with_subfolder(self) -> None:
|
||||
"""Should return base path with subfolder."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_uploads_base_path("async")
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "async"
|
||||
|
||||
def test_get_uploads_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_uploads_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestUploadExists:
|
||||
"""Tests for upload_exists method."""
|
||||
|
||||
def test_upload_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||
"""Should check upload existence."""
|
||||
exists = helper.upload_exists("file.pdf")
|
||||
|
||||
assert exists is True
|
||||
mock_storage.exists.assert_called_once_with("uploads/file.pdf")
|
||||
|
||||
def test_upload_exists_with_subfolder(
|
||||
self, helper: StorageHelper, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should check upload existence with subfolder."""
|
||||
helper.upload_exists("file.pdf", "async")
|
||||
|
||||
mock_storage.exists.assert_called_once_with("uploads/async/file.pdf")
|
||||
|
||||
|
||||
class TestDatasetsBasePath:
|
||||
"""Tests for get_datasets_base_path method."""
|
||||
|
||||
def test_get_datasets_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_datasets_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "datasets"
|
||||
|
||||
def test_get_datasets_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_datasets_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestAdminImagesBasePath:
|
||||
"""Tests for get_admin_images_base_path method."""
|
||||
|
||||
def test_get_admin_images_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_admin_images_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "admin_images"
|
||||
|
||||
def test_get_admin_images_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_admin_images_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestRawPdfsBasePath:
|
||||
"""Tests for get_raw_pdfs_base_path method."""
|
||||
|
||||
def test_get_raw_pdfs_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_raw_pdfs_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "raw_pdfs"
|
||||
|
||||
def test_get_raw_pdfs_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_raw_pdfs_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestRawPdfLocalPath:
|
||||
"""Tests for get_raw_pdf_local_path method."""
|
||||
|
||||
def test_get_raw_pdf_local_path_with_local_storage(self) -> None:
|
||||
"""Should return local path when using local storage backend."""
|
||||
from pathlib import Path
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
# Create a test raw PDF
|
||||
test_path = Path(temp_dir) / "raw_pdfs"
|
||||
test_path.mkdir(parents=True, exist_ok=True)
|
||||
(test_path / "invoice.pdf").write_bytes(b"test pdf")
|
||||
|
||||
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
|
||||
|
||||
assert local_path is not None
|
||||
assert local_path.exists()
|
||||
assert local_path.name == "invoice.pdf"
|
||||
|
||||
def test_get_raw_pdf_local_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
|
||||
assert local_path is None
|
||||
|
||||
|
||||
class TestRawPdfPath:
|
||||
"""Tests for get_raw_pdf_path method."""
|
||||
|
||||
def test_get_raw_pdf_path(self, helper: StorageHelper) -> None:
|
||||
"""Should return correct storage path."""
|
||||
path = helper.get_raw_pdf_path("invoice.pdf")
|
||||
assert path == "raw_pdfs/invoice.pdf"
|
||||
|
||||
|
||||
class TestAutolabelOutputPath:
|
||||
"""Tests for get_autolabel_output_path method."""
|
||||
|
||||
def test_get_autolabel_output_path_with_local_storage(self) -> None:
|
||||
"""Should return output path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
output_path = helper.get_autolabel_output_path()
|
||||
|
||||
assert output_path is not None
|
||||
assert output_path.exists()
|
||||
assert output_path.name == "autolabel_output"
|
||||
|
||||
def test_get_autolabel_output_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
output_path = helper.get_autolabel_output_path()
|
||||
assert output_path is None
|
||||
|
||||
|
||||
class TestTrainingDataPath:
|
||||
"""Tests for get_training_data_path method."""
|
||||
|
||||
def test_get_training_data_path_with_local_storage(self) -> None:
|
||||
"""Should return training path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
training_path = helper.get_training_data_path()
|
||||
|
||||
assert training_path is not None
|
||||
assert training_path.exists()
|
||||
assert training_path.name == "training"
|
||||
|
||||
def test_get_training_data_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
training_path = helper.get_training_data_path()
|
||||
assert training_path is None
|
||||
|
||||
|
||||
class TestExportsBasePath:
|
||||
"""Tests for get_exports_base_path method."""
|
||||
|
||||
def test_get_exports_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_exports_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "exports"
|
||||
|
||||
def test_get_exports_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_exports_base_path()
|
||||
assert base_path is None
|
||||
306
tests/web/test_storage_integration.py
Normal file
306
tests/web/test_storage_integration.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Tests for storage backend integration in web application.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStorageBackendInitialization:
|
||||
"""Tests for storage backend initialization in web config."""
|
||||
|
||||
def test_get_storage_backend_returns_backend(self, tmp_path: Path) -> None:
|
||||
"""Test that get_storage_backend returns a StorageBackend instance."""
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
backend = get_storage_backend()
|
||||
|
||||
assert isinstance(backend, StorageBackend)
|
||||
|
||||
def test_get_storage_backend_uses_config_file_if_exists(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that storage config file is used when present."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
|
||||
config_file = tmp_path / "storage.yaml"
|
||||
storage_path = tmp_path / "storage"
|
||||
config_file.write_text(f"""
|
||||
backend: local
|
||||
|
||||
local:
|
||||
base_path: {storage_path}
|
||||
""")
|
||||
|
||||
backend = get_storage_backend(config_path=config_file)
|
||||
|
||||
assert isinstance(backend, LocalStorageBackend)
|
||||
|
||||
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
|
||||
"""Test fallback to environment variables when no config file."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
backend = get_storage_backend(config_path=None)
|
||||
|
||||
assert isinstance(backend, LocalStorageBackend)
|
||||
|
||||
def test_app_config_has_storage_backend(self, tmp_path: Path) -> None:
|
||||
"""Test that AppConfig can be created with storage backend."""
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
from inference.web.config import AppConfig, create_app_config
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
config = create_app_config()
|
||||
|
||||
assert hasattr(config, "storage_backend")
|
||||
assert isinstance(config.storage_backend, StorageBackend)
|
||||
|
||||
|
||||
class TestStorageBackendInDocumentUpload:
    """Tests for storage backend usage in document upload."""

    def test_upload_document_uses_storage_backend(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Test that document upload writes the PDF through the storage backend.

        Fixed: removed an unused ``from unittest.mock import AsyncMock`` import
        that was never referenced in this test.
        """
        from shared.storage.local import LocalStorageBackend

        storage_path = tmp_path / "storage"
        storage_path.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(storage_path))

        # Minimal stand-in for an uploaded PDF body.
        pdf_content = b"%PDF-1.4 test content"

        from inference.web.services.document_service import DocumentService

        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)

        # Upload should go through the storage backend, not raw filesystem IO.
        result = service.upload_document(
            content=pdf_content,
            filename="test.pdf",
            dataset_id="dataset-1",
        )

        assert result is not None
        # The stored object must be addressable by the document id.
        assert backend.exists(f"documents/{result.id}.pdf")

    def test_upload_document_stores_logical_path(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Test that document stores logical path, not absolute path."""
        from shared.storage.local import LocalStorageBackend

        storage_path = tmp_path / "storage"
        storage_path.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(storage_path))

        pdf_content = b"%PDF-1.4 test content"

        from inference.web.services.document_service import DocumentService

        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)

        result = service.upload_document(
            content=pdf_content,
            filename="test.pdf",
            dataset_id="dataset-1",
        )

        # Path should be logical (relative) so storage backends stay swappable;
        # reject POSIX and Windows absolute forms.
        assert not result.file_path.startswith("/")
        assert not result.file_path.startswith("C:")
        assert result.file_path.startswith("documents/")
class TestStorageBackendInDocumentDownload:
|
||||
"""Tests for storage backend usage in document download/serving."""
|
||||
|
||||
def test_get_document_url_returns_presigned_url(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document URL uses presigned URL from storage backend."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create a test file
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
url = service.get_document_url(doc_path)
|
||||
|
||||
# Should return a URL (file:// for local, https:// for cloud)
|
||||
assert url is not None
|
||||
assert "test-doc.pdf" in url
|
||||
|
||||
def test_download_document_uses_storage_backend(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document download uses storage backend."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create a test file
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
original_content = b"%PDF-1.4 test content"
|
||||
backend.upload_bytes(original_content, doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
content = service.download_document(doc_path)
|
||||
|
||||
assert content == original_content
|
||||
|
||||
|
||||
class TestStorageBackendInImageServing:
|
||||
"""Tests for storage backend usage in image serving."""
|
||||
|
||||
def test_get_page_image_url_returns_presigned_url(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that page image URL uses presigned URL."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create a test image
|
||||
image_path = "images/doc-123/page_1.png"
|
||||
backend.upload_bytes(b"fake png content", image_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
url = service.get_page_image_url("doc-123", 1)
|
||||
|
||||
assert url is not None
|
||||
assert "page_1.png" in url
|
||||
|
||||
def test_save_page_image_uses_storage_backend(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that page image saving uses storage backend."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
image_content = b"fake png content"
|
||||
service.save_page_image("doc-123", 1, image_content)
|
||||
|
||||
# Verify image was stored
|
||||
assert backend.exists("images/doc-123/page_1.png")
|
||||
|
||||
|
||||
class TestStorageBackendInDocumentDeletion:
|
||||
"""Tests for storage backend usage in document deletion."""
|
||||
|
||||
def test_delete_document_removes_from_storage(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document deletion removes file from storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create test files
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
service.delete_document_files(doc_path)
|
||||
|
||||
assert not backend.exists(doc_path)
|
||||
|
||||
def test_delete_document_removes_images(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document deletion removes associated images."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create test files
|
||||
doc_id = "test-doc-123"
|
||||
backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
|
||||
backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
service.delete_document_images(doc_id)
|
||||
|
||||
assert not backend.exists(f"images/{doc_id}/page_1.png")
|
||||
assert not backend.exists(f"images/{doc_id}/page_2.png")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db() -> MagicMock:
|
||||
"""Create a mock AdminDB for testing."""
|
||||
mock = MagicMock()
|
||||
mock.get_document.return_value = None
|
||||
mock.create_document.return_value = MagicMock(
|
||||
id="test-doc-id",
|
||||
file_path="documents/test-doc-id.pdf",
|
||||
)
|
||||
return mock
|
||||
@@ -103,6 +103,31 @@ class MockAnnotation:
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockModelVersion:
|
||||
"""Mock ModelVersion for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.version_id = kwargs.get('version_id', uuid4())
|
||||
self.version = kwargs.get('version', '1.0.0')
|
||||
self.name = kwargs.get('name', 'Test Model')
|
||||
self.description = kwargs.get('description', None)
|
||||
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
|
||||
self.status = kwargs.get('status', 'inactive')
|
||||
self.is_active = kwargs.get('is_active', False)
|
||||
self.task_id = kwargs.get('task_id', None)
|
||||
self.dataset_id = kwargs.get('dataset_id', None)
|
||||
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
|
||||
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
|
||||
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
|
||||
self.document_count = kwargs.get('document_count', 100)
|
||||
self.training_config = kwargs.get('training_config', {})
|
||||
self.file_size = kwargs.get('file_size', 52428800)
|
||||
self.trained_at = kwargs.get('trained_at', datetime.utcnow())
|
||||
self.activated_at = kwargs.get('activated_at', None)
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing Phase 4."""
|
||||
|
||||
@@ -111,6 +136,7 @@ class MockAdminDB:
|
||||
self.annotations = {}
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
self.model_versions = {}
|
||||
|
||||
def get_documents_for_training(
|
||||
self,
|
||||
@@ -174,6 +200,14 @@ class MockAdminDB:
|
||||
"""Get training task by ID."""
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
def get_model_versions(self, status=None, limit=20, offset=0):
|
||||
"""Get model versions with optional filtering."""
|
||||
models = list(self.model_versions.values())
|
||||
if status:
|
||||
models = [m for m in models if m.status == status]
|
||||
total = len(models)
|
||||
return models[offset:offset+limit], total
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
@@ -241,6 +275,30 @@ def app():
|
||||
)
|
||||
mock_db.training_links[str(doc1.document_id)] = [link1]
|
||||
|
||||
# Add model versions
|
||||
model1 = MockModelVersion(
|
||||
version="1.0.0",
|
||||
name="Model v1.0.0",
|
||||
status="inactive",
|
||||
is_active=False,
|
||||
metrics_mAP=0.935,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
document_count=500,
|
||||
)
|
||||
model2 = MockModelVersion(
|
||||
version="1.1.0",
|
||||
name="Model v1.1.0",
|
||||
status="active",
|
||||
is_active=True,
|
||||
metrics_mAP=0.951,
|
||||
metrics_precision=0.94,
|
||||
metrics_recall=0.92,
|
||||
document_count=600,
|
||||
)
|
||||
mock_db.model_versions[str(model1.version_id)] = model1
|
||||
mock_db.model_versions[str(model2.version_id)] = model2
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
|
||||
|
||||
|
||||
class TestTrainingModels:
|
||||
"""Tests for GET /admin/training/models endpoint."""
|
||||
"""Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
|
||||
|
||||
def test_get_training_models_success(self, client):
|
||||
"""Test getting trained models list."""
|
||||
"""Test getting model versions list."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -338,43 +396,44 @@ class TestTrainingModels:
|
||||
assert len(data["models"]) == 2
|
||||
|
||||
def test_get_training_models_includes_metrics(self, client):
|
||||
"""Test that models include metrics."""
|
||||
"""Test that model versions include metrics."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check first model has metrics
|
||||
# Check first model has metrics fields
|
||||
model = data["models"][0]
|
||||
assert "metrics" in model
|
||||
assert "mAP" in model["metrics"]
|
||||
assert model["metrics"]["mAP"] is not None
|
||||
assert "precision" in model["metrics"]
|
||||
assert "recall" in model["metrics"]
|
||||
assert "metrics_mAP" in model
|
||||
assert model["metrics_mAP"] is not None
|
||||
|
||||
def test_get_training_models_includes_download_url(self, client):
|
||||
"""Test that completed models have download URLs."""
|
||||
def test_get_training_models_includes_version_fields(self, client):
|
||||
"""Test that model versions include version fields."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check completed models have download URLs
|
||||
for model in data["models"]:
|
||||
if model["status"] == "completed":
|
||||
assert "download_url" in model
|
||||
assert model["download_url"] is not None
|
||||
# Check model has expected fields
|
||||
model = data["models"][0]
|
||||
assert "version_id" in model
|
||||
assert "version" in model
|
||||
assert "name" in model
|
||||
assert "status" in model
|
||||
assert "is_active" in model
|
||||
assert "document_count" in model
|
||||
|
||||
def test_get_training_models_filter_by_status(self, client):
|
||||
"""Test filtering models by status."""
|
||||
response = client.get("/admin/training/models?status=completed")
|
||||
"""Test filtering model versions by status."""
|
||||
response = client.get("/admin/training/models?status=active")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# All returned models should be completed
|
||||
assert data["total"] == 1
|
||||
# All returned models should be active
|
||||
for model in data["models"]:
|
||||
assert model["status"] == "completed"
|
||||
assert model["status"] == "active"
|
||||
|
||||
def test_get_training_models_pagination(self, client):
|
||||
"""Test pagination for models."""
|
||||
"""Test pagination for model versions."""
|
||||
response = client.get("/admin/training/models?limit=1&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user