This commit is contained in:
Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -17,9 +17,8 @@ from inference.data.admin_models import (
AdminDocument,
AdminAnnotation,
TrainingTask,
FIELD_CLASSES,
CSV_TO_CLASS_MAPPING,
)
from shared.fields import FIELD_CLASSES, CSV_TO_CLASS_MAPPING
class TestBatchUpload:
@@ -507,7 +506,10 @@ class TestCSVToClassMapping:
assert len(CSV_TO_CLASS_MAPPING) > 0
def test_csv_mapping_values(self):
"""Test specific CSV column mappings."""
"""Test specific CSV column mappings.
Note: customer_number is class 8 (verified from trained model best.pt).
"""
assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0
assert CSV_TO_CLASS_MAPPING["InvoiceDate"] == 1
assert CSV_TO_CLASS_MAPPING["InvoiceDueDate"] == 2
@@ -516,7 +518,7 @@ class TestCSVToClassMapping:
assert CSV_TO_CLASS_MAPPING["Plusgiro"] == 5
assert CSV_TO_CLASS_MAPPING["Amount"] == 6
assert CSV_TO_CLASS_MAPPING["supplier_organisation_number"] == 7
assert CSV_TO_CLASS_MAPPING["customer_number"] == 9
assert CSV_TO_CLASS_MAPPING["customer_number"] == 8 # Fixed: was 9, model uses 8
def test_csv_mapping_matches_field_classes(self):
"""Test that CSV mapping is consistent with FIELD_CLASSES."""

View File

@@ -0,0 +1 @@
"""Tests for shared.fields module."""

View File

@@ -0,0 +1,200 @@
"""
Tests for field configuration - Single Source of Truth.
These tests ensure consistency across all field definitions and prevent
accidental changes that could break model inference.
CRITICAL: These tests verify that field definitions match the trained YOLO model.
If these tests fail, it likely means someone modified field IDs incorrectly.
"""
import pytest
from shared.fields import (
FIELD_DEFINITIONS,
CLASS_NAMES,
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
NUM_CLASSES,
FieldDefinition,
)
class TestFieldDefinitionsIntegrity:
    """Integrity checks: field definitions are complete, unique, and frozen."""

    def test_exactly_10_field_definitions(self):
        """Verify we have exactly 10 field classes (matching trained model)."""
        expected_total = 10
        assert len(FIELD_DEFINITIONS) == expected_total
        assert NUM_CLASSES == expected_total

    def test_class_ids_are_sequential(self):
        """Verify class IDs are 0-9 without gaps."""
        observed_ids = {definition.class_id for definition in FIELD_DEFINITIONS}
        assert observed_ids == set(range(10))

    def test_class_ids_are_unique(self):
        """Verify no duplicate class IDs."""
        all_ids = [definition.class_id for definition in FIELD_DEFINITIONS]
        assert len(set(all_ids)) == len(all_ids)

    def test_class_names_are_unique(self):
        """Verify no duplicate class names."""
        all_names = [definition.class_name for definition in FIELD_DEFINITIONS]
        assert len(set(all_names)) == len(all_names)

    def test_field_definition_is_immutable(self):
        """Verify FieldDefinition is frozen (immutable)."""
        sample = FIELD_DEFINITIONS[0]
        # Frozen dataclasses raise FrozenInstanceError, an AttributeError subclass.
        with pytest.raises(AttributeError):
            sample.class_id = 99  # type: ignore
class TestModelCompatibility:
    """Tests to verify field definitions match the trained YOLO model.

    These exact values are read from runs/train/invoice_fields/weights/best.pt
    and MUST NOT be changed without retraining the model.
    """

    # Expected model.names from best.pt - DO NOT CHANGE
    EXPECTED_MODEL_NAMES = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_org_number",
        8: "customer_number",
        9: "payment_line",
    }

    def test_field_classes_match_model(self):
        """CRITICAL: Verify FIELD_CLASSES matches trained model exactly."""
        assert FIELD_CLASSES == self.EXPECTED_MODEL_NAMES

    def test_class_names_order_matches_model(self):
        """CRITICAL: Verify CLASS_NAMES order matches model class IDs."""
        in_model_order = [self.EXPECTED_MODEL_NAMES[idx] for idx in range(10)]
        assert CLASS_NAMES == in_model_order

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9)."""
        assert FIELD_CLASSES[8] == "customer_number"
        assert FIELD_CLASS_IDS["customer_number"] == 8

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8)."""
        assert FIELD_CLASSES[9] == "payment_line"
        assert FIELD_CLASS_IDS["payment_line"] == 9
class TestMappingConsistency:
    """Tests to verify all mappings are consistent with each other."""

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
        # id -> name inverted must reproduce name -> id, and vice versa.
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id
        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name

    def test_class_names_matches_field_classes_values(self):
        """Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
        for position, name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[position] == name

    def test_class_to_field_has_all_classes(self):
        """Verify CLASS_TO_FIELD has mapping for all class names."""
        missing = [name for name in CLASS_NAMES if name not in CLASS_TO_FIELD]
        assert missing == []

    def test_csv_mapping_excludes_derived_fields(self):
        """Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line."""
        # payment_line is computed from other fields, so it never appears in CSVs.
        assert "payment_line" not in CSV_TO_CLASS_MAPPING
        # Every non-derived field must be present in the CSV mapping.
        for definition in FIELD_DEFINITIONS:
            if definition.is_derived:
                continue
            assert definition.field_name in CSV_TO_CLASS_MAPPING

    def test_training_field_classes_includes_all(self):
        """Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
        for definition in FIELD_DEFINITIONS:
            assert definition.field_name in TRAINING_FIELD_CLASSES
            assert TRAINING_FIELD_CLASSES[definition.field_name] == definition.class_id
class TestSpecificFieldDefinitions:
    """Tests for specific field definitions to catch common mistakes."""

    @pytest.mark.parametrize(
        "class_id,expected_class_name",
        [
            (0, "invoice_number"),
            (1, "invoice_date"),
            (2, "invoice_due_date"),
            (3, "ocr_number"),
            (4, "bankgiro"),
            (5, "plusgiro"),
            (6, "amount"),
            (7, "supplier_org_number"),
            (8, "customer_number"),
            (9, "payment_line"),
        ],
    )
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """Verify each class ID maps to the correct class name."""
        assert FIELD_CLASSES[class_id] == expected_class_name

    def test_payment_line_is_derived(self):
        """Verify payment_line is marked as derived."""
        definitions_by_name = {fd.class_name: fd for fd in FIELD_DEFINITIONS}
        assert definitions_by_name["payment_line"].is_derived is True

    def test_other_fields_are_not_derived(self):
        """Verify all fields except payment_line are not derived."""
        for definition in FIELD_DEFINITIONS:
            if definition.class_name == "payment_line":
                continue
            assert definition.is_derived is False, (
                f"{definition.class_name} should not be derived"
            )
class TestBackwardCompatibility:
    """Tests to ensure backward compatibility with existing code."""

    def test_csv_to_class_mapping_field_names(self):
        """Verify CSV_TO_CLASS_MAPPING uses correct field names."""
        # These are the field names used in CSV files.
        expected_fields = {
            "InvoiceNumber": 0,
            "InvoiceDate": 1,
            "InvoiceDueDate": 2,
            "OCR": 3,
            "Bankgiro": 4,
            "Plusgiro": 5,
            "Amount": 6,
            "supplier_organisation_number": 7,
            "customer_number": 8,
            # payment_line (9) is derived, not in CSV
        }
        assert CSV_TO_CLASS_MAPPING == expected_fields

    def test_class_to_field_returns_field_names(self):
        """Verify CLASS_TO_FIELD maps class names to field names correctly."""
        # Spot-check the key class-name -> CSV-field-name pairs.
        spot_checks = {
            "invoice_number": "InvoiceNumber",
            "invoice_date": "InvoiceDate",
            "ocr_number": "OCR",
            "customer_number": "customer_number",
            "payment_line": "payment_line",
        }
        for class_name, field_name in spot_checks.items():
            assert CLASS_TO_FIELD[class_name] == field_name

View File

@@ -0,0 +1 @@
# Tests for storage module

View File

@@ -0,0 +1,718 @@
"""
Tests for AzureBlobStorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
Uses mocking to avoid requiring actual Azure credentials.
"""
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@pytest.fixture
def mock_blob_service_client() -> MagicMock:
    """Provide a stand-in BlobServiceClient (no real Azure connection)."""
    service_client = MagicMock()
    return service_client
@pytest.fixture
def mock_container_client(mock_blob_service_client: MagicMock) -> MagicMock:
    """Provide a mock ContainerClient wired into the mock service client."""
    client = MagicMock()
    mock_blob_service_client.get_container_client.return_value = client
    return client
@pytest.fixture
def mock_blob_client(mock_container_client: MagicMock) -> MagicMock:
    """Provide a mock BlobClient wired into the mock container client."""
    client = MagicMock()
    mock_container_client.get_blob_client.return_value = client
    return client
class TestAzureBlobStorageBackendCreation:
    """Tests for AzureBlobStorageBackend instantiation."""

    @staticmethod
    def _wire_container(service_class: MagicMock, container_exists: bool) -> MagicMock:
        """Attach a mock container behind the patched BlobServiceClient class."""
        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        container.exists.return_value = container_exists
        return container

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_with_connection_string(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test creating backend with connection string."""
        from shared.storage.azure import AzureBlobStorageBackend

        conn = "DefaultEndpointsProtocol=https;AccountName=test;..."
        backend = AzureBlobStorageBackend(
            connection_string=conn,
            container_name="training-images",
        )
        mock_service_class.from_connection_string.assert_called_once_with(conn)
        assert backend.container_name == "training-images"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_creates_container_if_not_exists(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that container is created if it doesn't exist."""
        from shared.storage.azure import AzureBlobStorageBackend

        container = self._wire_container(mock_service_class, container_exists=False)
        AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="new-container",
            create_container=True,
        )
        container.create_container.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_does_not_create_container_by_default(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that container is not created by default."""
        from shared.storage.azure import AzureBlobStorageBackend

        container = self._wire_container(mock_service_class, container_exists=True)
        AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="existing-container",
        )
        container.create_container.assert_not_called()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_is_storage_backend_subclass(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that AzureBlobStorageBackend is a StorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageBackend

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        assert isinstance(backend, StorageBackend)
class TestAzureBlobStorageBackendUpload:
    """Tests for AzureBlobStorageBackend.upload method."""

    @staticmethod
    def _backend_with_blob(service_class: MagicMock, blob_exists: bool):
        """Build a backend whose blob client reports the given existence state."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = blob_exists
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        return backend, container, blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_file(self, mock_service_class: MagicMock) -> None:
        """Test uploading a file."""
        backend, container, blob = self._backend_with_blob(
            mock_service_class, blob_exists=False
        )
        with tempfile.TemporaryDirectory() as tmp:
            source = Path(tmp) / "source.txt"
            source.write_bytes(b"Hello, World!")
            result = backend.upload(source, "uploads/sample.txt")
        assert result == "uploads/sample.txt"
        container.get_blob_client.assert_called_with("uploads/sample.txt")
        blob.upload_blob.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_fails_if_blob_exists_without_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that upload fails if blob exists and overwrite is False."""
        from shared.storage.base import StorageError

        backend, _, _ = self._backend_with_blob(mock_service_class, blob_exists=True)
        with tempfile.TemporaryDirectory() as tmp:
            source = Path(tmp) / "source.txt"
            source.write_bytes(b"content")
            with pytest.raises(StorageError, match="already exists"):
                backend.upload(source, "existing.txt", overwrite=False)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_succeeds_with_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that upload succeeds with overwrite=True."""
        backend, _, blob = self._backend_with_blob(mock_service_class, blob_exists=True)
        with tempfile.TemporaryDirectory() as tmp:
            source = Path(tmp) / "source.txt"
            source.write_bytes(b"content")
            result = backend.upload(source, "existing.txt", overwrite=True)
        assert result == "existing.txt"
        blob.upload_blob.assert_called_once()
        # The backend must forward overwrite=True to the SDK call.
        assert blob.upload_blob.call_args[1].get("overwrite") is True

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_nonexistent_file_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that uploading nonexistent file fails."""
        from shared.storage.base import FileNotFoundStorageError

        backend, _, _ = self._backend_with_blob(mock_service_class, blob_exists=False)
        with pytest.raises(FileNotFoundStorageError):
            backend.upload(Path("/nonexistent/file.txt"), "sample.txt")
class TestAzureBlobStorageBackendDownload:
    """Tests for AzureBlobStorageBackend.download method."""

    @staticmethod
    def _backend_serving(service_class: MagicMock, payload):
        """Build a backend whose blob streams *payload*, or is absent if None."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        if payload is None:
            blob.exists.return_value = False
        else:
            blob.exists.return_value = True
            stream = MagicMock()
            stream.readall.return_value = payload
            blob.download_blob.return_value = stream
        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_file(self, mock_service_class: MagicMock) -> None:
        """Test downloading a file."""
        backend = self._backend_serving(mock_service_class, b"Hello, World!")
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp) / "downloaded.txt"
            result = backend.download("remote/sample.txt", target)
            assert result == target
            assert target.exists()
            assert target.read_bytes() == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_creates_parent_directories(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that download creates parent directories."""
        backend = self._backend_serving(mock_service_class, b"content")
        with tempfile.TemporaryDirectory() as tmp:
            # Target path is nested two levels below an existing directory.
            target = Path(tmp) / "deep" / "nested" / "downloaded.txt"
            backend.download("sample.txt", target)
            assert target.exists()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_nonexistent_blob_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that downloading nonexistent blob fails."""
        from shared.storage.base import FileNotFoundStorageError

        backend = self._backend_serving(mock_service_class, None)
        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            backend.download("nonexistent.txt", Path("/tmp/file.txt"))
class TestAzureBlobStorageBackendExists:
    """Tests for AzureBlobStorageBackend.exists method."""

    @staticmethod
    def _backend_where_blob(service_class: MagicMock, present: bool):
        """Build a backend whose blob client reports the given presence."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = present
        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_true_for_existing_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test exists returns True for existing blob."""
        backend = self._backend_where_blob(mock_service_class, True)
        assert backend.exists("existing.txt") is True

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_false_for_nonexistent_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test exists returns False for nonexistent blob."""
        backend = self._backend_where_blob(mock_service_class, False)
        assert backend.exists("nonexistent.txt") is False
class TestAzureBlobStorageBackendListFiles:
    """Tests for AzureBlobStorageBackend.list_files method."""

    @staticmethod
    def _backend_listing(service_class: MagicMock, blob_names):
        """Build a backend whose container lists blobs named *blob_names*."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        items = []
        for blob_name in blob_names:
            item = MagicMock()
            item.name = blob_name  # .name must be set post-construction on MagicMock
            items.append(item)
        container.list_blobs.return_value = items
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        return backend, container

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_empty_container(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test listing files in empty container."""
        backend, _ = self._backend_listing(mock_service_class, [])
        assert backend.list_files("") == []

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_returns_all_blobs(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test listing all blobs."""
        names = ["file1.txt", "file2.txt", "subdir/file3.txt"]
        backend, _ = self._backend_listing(mock_service_class, names)
        files = backend.list_files("")
        assert len(files) == 3
        for expected in names:
            assert expected in files

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_with_prefix(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test listing files with prefix filter."""
        backend, container = self._backend_listing(
            mock_service_class, ["images/a.png", "images/b.png"]
        )
        files = backend.list_files("images/")
        # The prefix must be forwarded to the SDK as name_starts_with.
        container.list_blobs.assert_called_with(name_starts_with="images/")
        assert len(files) == 2
class TestAzureBlobStorageBackendDelete:
    """Tests for AzureBlobStorageBackend.delete method."""

    @staticmethod
    def _backend_and_blob(service_class: MagicMock, blob_exists: bool):
        """Build a backend plus its blob mock with the given existence state."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = blob_exists
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        return backend, blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_delete_existing_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test deleting an existing blob."""
        backend, blob = self._backend_and_blob(mock_service_class, True)
        assert backend.delete("sample.txt") is True
        blob.delete_blob.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_delete_nonexistent_blob_returns_false(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test deleting nonexistent blob returns False."""
        backend, blob = self._backend_and_blob(mock_service_class, False)
        assert backend.delete("nonexistent.txt") is False
        blob.delete_blob.assert_not_called()
class TestAzureBlobStorageBackendGetUrl:
    """Tests for AzureBlobStorageBackend.get_url method."""

    @staticmethod
    def _backend_and_blob(service_class: MagicMock, blob_exists: bool):
        """Build a backend plus its blob mock with the given existence state."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = blob_exists
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        return backend, blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_url_returns_blob_url(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test get_url returns blob URL."""
        expected = "https://account.blob.core.windows.net/container/sample.txt"
        backend, blob = self._backend_and_blob(mock_service_class, True)
        blob.url = expected
        assert backend.get_url("sample.txt") == expected

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_url_nonexistent_blob_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test get_url for nonexistent blob fails."""
        from shared.storage.base import FileNotFoundStorageError

        backend, _ = self._backend_and_blob(mock_service_class, False)
        with pytest.raises(FileNotFoundStorageError):
            backend.get_url("nonexistent.txt")
class TestAzureBlobStorageBackendUploadBytes:
    """Tests for AzureBlobStorageBackend.upload_bytes method."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_bytes(self, mock_service_class: MagicMock) -> None:
        """Test uploading bytes directly."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        mock_service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = False
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        payload = b"Binary content here"
        assert backend.upload_bytes(payload, "binary.dat") == "binary.dat"
        blob.upload_blob.assert_called_once()
class TestAzureBlobStorageBackendDownloadBytes:
    """Tests for AzureBlobStorageBackend.download_bytes method."""

    @staticmethod
    def _backend_streaming(service_class: MagicMock, payload):
        """Backend whose blob streams *payload*, or is missing when payload is None."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        if payload is None:
            blob.exists.return_value = False
        else:
            blob.exists.return_value = True
            stream = MagicMock()
            stream.readall.return_value = payload
            blob.download_blob.return_value = stream
        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes(self, mock_service_class: MagicMock) -> None:
        """Test downloading blob as bytes."""
        backend = self._backend_streaming(mock_service_class, b"Hello, World!")
        assert backend.download_bytes("sample.txt") == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes_nonexistent(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test downloading nonexistent blob as bytes."""
        from shared.storage.base import FileNotFoundStorageError

        backend = self._backend_streaming(mock_service_class, None)
        with pytest.raises(FileNotFoundStorageError):
            backend.download_bytes("nonexistent.txt")
class TestAzureBlobStorageBackendBatchOperations:
    """Tests for batch operations in AzureBlobStorageBackend."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_directory(self, mock_service_class: MagicMock) -> None:
        """Test uploading an entire directory."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        mock_service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = False
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            (root / "file1.txt").write_text("content1")
            (root / "subdir").mkdir()
            (root / "subdir" / "file2.txt").write_text("content2")
            uploaded = backend.upload_directory(root, "uploads/")
        # Both files, including the nested one, must be uploaded under the prefix.
        assert len(uploaded) == 2
        assert "uploads/file1.txt" in uploaded
        assert "uploads/subdir/file2.txt" in uploaded

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_directory(self, mock_service_class: MagicMock) -> None:
        """Test downloading blobs matching a prefix."""
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        mock_service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        # Container lists two blobs under the images/ prefix.
        listed = []
        for blob_name in ("images/a.png", "images/b.png"):
            entry = MagicMock()
            entry.name = blob_name
            listed.append(entry)
        container.list_blobs.return_value = listed
        # A single blob client serves the same payload for every blob.
        blob_client = MagicMock()
        container.get_blob_client.return_value = blob_client
        blob_client.exists.return_value = True
        stream = MagicMock()
        stream.readall.return_value = b"image content"
        blob_client.download_blob.return_value = stream
        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        with tempfile.TemporaryDirectory() as tmp:
            dest = Path(tmp)
            results = backend.download_directory("images/", dest)
            assert len(results) == 2
            # Implementation may strip the prefix or mirror the full blob path.
            assert (dest / "a.png").exists() or (dest / "images" / "a.png").exists()

View File

@@ -0,0 +1,301 @@
"""
Tests for storage base module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from abc import ABC
from pathlib import Path
from typing import BinaryIO
from unittest.mock import MagicMock, patch
import pytest
class TestStorageBackendInterface:
"""Tests for StorageBackend abstract base class."""
def test_cannot_instantiate_directly(self) -> None:
    """StorageBackend is abstract and must refuse direct construction."""
    from shared.storage.base import StorageBackend

    with pytest.raises(TypeError):
        _ = StorageBackend()  # type: ignore
def test_is_abstract_base_class(self) -> None:
    """StorageBackend must derive from abc.ABC."""
    from shared.storage.base import StorageBackend

    assert issubclass(StorageBackend, ABC)
def test_subclass_must_implement_upload(self) -> None:
    """A subclass missing upload() stays abstract and cannot be built."""
    from shared.storage.base import StorageBackend

    class MissingUpload(StorageBackend):
        # Every abstract method except upload() is provided.
        def download(self, remote_path: str, local_path: Path) -> Path:
            return local_path

        def exists(self, remote_path: str) -> bool:
            return False

        def list_files(self, prefix: str) -> list[str]:
            return []

        def delete(self, remote_path: str) -> bool:
            return True

        def get_url(self, remote_path: str) -> str:
            return ""

    with pytest.raises(TypeError):
        MissingUpload()  # type: ignore
def test_subclass_must_implement_download(self) -> None:
    """A subclass missing download() stays abstract and cannot be built."""
    from shared.storage.base import StorageBackend

    class MissingDownload(StorageBackend):
        # Every abstract method except download() is provided.
        def upload(
            self, local_path: Path, remote_path: str, overwrite: bool = False
        ) -> str:
            return remote_path

        def exists(self, remote_path: str) -> bool:
            return False

        def list_files(self, prefix: str) -> list[str]:
            return []

        def delete(self, remote_path: str) -> bool:
            return True

        def get_url(self, remote_path: str) -> str:
            return ""

    with pytest.raises(TypeError):
        MissingDownload()  # type: ignore
def test_subclass_must_implement_exists(self) -> None:
"""Test that subclass must implement exists method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_list_files(self) -> None:
"""Test that subclass must implement list_files method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_delete(self) -> None:
"""Test that subclass must implement delete method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_get_url(self) -> None:
"""Test that subclass must implement get_url method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_valid_subclass_can_be_instantiated(self) -> None:
"""Test that a complete subclass can be instantiated."""
from shared.storage.base import StorageBackend
class CompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
def get_presigned_url(
self, remote_path: str, expires_in_seconds: int = 3600
) -> str:
return ""
backend = CompleteBackend()
assert isinstance(backend, StorageBackend)
class TestStorageError:
    """Tests for the StorageError exception type."""

    def test_storage_error_is_exception(self) -> None:
        """StorageError must derive from the built-in Exception."""
        from shared.storage.base import StorageError

        assert issubclass(StorageError, Exception)

    def test_storage_error_with_message(self) -> None:
        """The message given to StorageError is its str() rendering."""
        from shared.storage.base import StorageError

        err = StorageError("Upload failed")
        assert str(err) == "Upload failed"

    def test_storage_error_can_be_raised(self) -> None:
        """StorageError can be raised and caught like any exception."""
        from shared.storage.base import StorageError

        with pytest.raises(StorageError, match="test error"):
            raise StorageError("test error")
class TestFileNotFoundError:
    """Tests for FileNotFoundStorageError exception."""

    def test_file_not_found_is_storage_error(self) -> None:
        """FileNotFoundStorageError must derive from StorageError."""
        from shared.storage.base import FileNotFoundStorageError, StorageError

        assert issubclass(FileNotFoundStorageError, StorageError)

    def test_file_not_found_with_path(self) -> None:
        """The offending path must appear in the rendered message."""
        from shared.storage.base import FileNotFoundStorageError

        err = FileNotFoundStorageError("images/test.png")
        assert "images/test.png" in str(err)
class TestStorageConfig:
    """Tests for the StorageConfig dataclass."""

    def test_storage_config_creation(self) -> None:
        """Explicitly-passed fields are stored exactly as given."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )
        assert cfg.backend_type == "azure_blob"
        assert cfg.connection_string == "DefaultEndpointsProtocol=https;..."
        assert cfg.container_name == "training-images"

    def test_storage_config_defaults(self) -> None:
        """Unspecified optional fields default to None."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(backend_type="local")
        assert cfg.backend_type == "local"
        assert cfg.connection_string is None
        assert cfg.container_name is None
        assert cfg.base_path is None

    def test_storage_config_with_base_path(self) -> None:
        """base_path is accepted for the local backend."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(backend_type="local", base_path=Path("/data/images"))
        assert cfg.backend_type == "local"
        assert cfg.base_path == Path("/data/images")

    def test_storage_config_immutable(self) -> None:
        """The dataclass is frozen, so field assignment raises."""
        from shared.storage.base import StorageConfig

        cfg = StorageConfig(backend_type="local")
        with pytest.raises(AttributeError):
            cfg.backend_type = "azure_blob"  # type: ignore

View File

@@ -0,0 +1,348 @@
"""
Tests for storage configuration file loader.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
import shutil
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
@pytest.fixture
def temp_dir() -> Path:
    """Yield a throwaway directory, removed when the test finishes."""
    workdir = Path(tempfile.mkdtemp())
    yield workdir
    shutil.rmtree(workdir, ignore_errors=True)
class TestEnvVarSubstitution:
    """Tests for environment variable substitution in config values."""

    def test_substitute_simple_env_var(self) -> None:
        """A lone ${VAR} expands to the variable's value."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"MY_VAR": "my_value"}):
            assert substitute_env_vars("${MY_VAR}") == "my_value"

    def test_substitute_env_var_with_default(self) -> None:
        """${VAR:-default} falls back to the default when VAR is unset."""
        from shared.storage.config_loader import substitute_env_vars

        os.environ.pop("UNSET_VAR", None)  # make sure the var is absent
        assert substitute_env_vars("${UNSET_VAR:-default_value}") == "default_value"

    def test_substitute_env_var_ignores_default_when_set(self) -> None:
        """${VAR:-default} prefers the real value over the default."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"SET_VAR": "actual_value"}):
            assert substitute_env_vars("${SET_VAR:-default_value}") == "actual_value"

    def test_substitute_multiple_env_vars(self) -> None:
        """Several ${...} references in one string are all expanded."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"HOST": "localhost", "PORT": "5432"}):
            expanded = substitute_env_vars("postgres://${HOST}:${PORT}/db")
            assert expanded == "postgres://localhost:5432/db"

    def test_substitute_preserves_non_env_text(self) -> None:
        """Literal text around a reference is left untouched."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"VAR": "value"}):
            assert substitute_env_vars("prefix_${VAR}_suffix") == "prefix_value_suffix"

    def test_substitute_empty_string_when_not_set_and_no_default(self) -> None:
        """An unset var with no default expands to the empty string."""
        from shared.storage.config_loader import substitute_env_vars

        os.environ.pop("MISSING_VAR", None)
        assert substitute_env_vars("${MISSING_VAR}") == ""
class TestLoadStorageConfigYaml:
    """Tests for loading storage configuration from YAML files."""

    def test_load_local_backend_config(self, temp_dir: Path) -> None:
        """A 'local' backend section maps onto LocalConfig."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = temp_dir / "storage.yaml"
        cfg_file.write_text(
            "backend: local\n"
            "presigned_url_expiry: 3600\n"
            "local:\n"
            "  base_path: ./data/storage\n"
        )
        cfg = load_storage_config(cfg_file)
        assert cfg.backend_type == "local"
        assert cfg.presigned_url_expiry == 3600
        assert cfg.local is not None
        assert cfg.local.base_path == Path("./data/storage")

    def test_load_azure_backend_config(self, temp_dir: Path) -> None:
        """An 'azure_blob' backend section maps onto AzureConfig."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = temp_dir / "storage.yaml"
        cfg_file.write_text(
            "backend: azure_blob\n"
            "presigned_url_expiry: 7200\n"
            "azure:\n"
            "  connection_string: DefaultEndpointsProtocol=https;AccountName=test\n"
            "  container_name: documents\n"
            "  create_container: true\n"
        )
        cfg = load_storage_config(cfg_file)
        assert cfg.backend_type == "azure_blob"
        assert cfg.presigned_url_expiry == 7200
        assert cfg.azure is not None
        assert cfg.azure.connection_string == "DefaultEndpointsProtocol=https;AccountName=test"
        assert cfg.azure.container_name == "documents"
        assert cfg.azure.create_container is True

    def test_load_s3_backend_config(self, temp_dir: Path) -> None:
        """An 's3' backend section maps onto S3Config."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = temp_dir / "storage.yaml"
        cfg_file.write_text(
            "backend: s3\n"
            "presigned_url_expiry: 1800\n"
            "s3:\n"
            "  bucket_name: my-bucket\n"
            "  region_name: us-west-2\n"
            "  endpoint_url: http://localhost:9000\n"
            "  create_bucket: false\n"
        )
        cfg = load_storage_config(cfg_file)
        assert cfg.backend_type == "s3"
        assert cfg.presigned_url_expiry == 1800
        assert cfg.s3 is not None
        assert cfg.s3.bucket_name == "my-bucket"
        assert cfg.s3.region_name == "us-west-2"
        assert cfg.s3.endpoint_url == "http://localhost:9000"
        assert cfg.s3.create_bucket is False

    def test_load_config_with_env_var_substitution(self, temp_dir: Path) -> None:
        """${VAR} references in the YAML expand from the environment."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = temp_dir / "storage.yaml"
        cfg_file.write_text(
            "backend: ${STORAGE_BACKEND:-local}\n"
            "local:\n"
            "  base_path: ${STORAGE_PATH:-./default/path}\n"
        )
        env = {"STORAGE_BACKEND": "local", "STORAGE_PATH": "/custom/path"}
        with patch.dict(os.environ, env):
            cfg = load_storage_config(cfg_file)
        assert cfg.backend_type == "local"
        assert cfg.local is not None
        assert cfg.local.base_path == Path("/custom/path")

    def test_load_config_file_not_found_raises(self, temp_dir: Path) -> None:
        """A missing config file raises FileNotFoundError."""
        from shared.storage.config_loader import load_storage_config

        with pytest.raises(FileNotFoundError):
            load_storage_config(temp_dir / "nonexistent.yaml")

    def test_load_config_invalid_yaml_raises(self, temp_dir: Path) -> None:
        """Malformed YAML raises ValueError mentioning 'Invalid'."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = temp_dir / "storage.yaml"
        cfg_file.write_text("invalid: yaml: content: [")
        with pytest.raises(ValueError, match="Invalid"):
            load_storage_config(cfg_file)

    def test_load_config_missing_backend_raises(self, temp_dir: Path) -> None:
        """Omitting the 'backend' key raises a ValueError naming it."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = temp_dir / "storage.yaml"
        cfg_file.write_text("local:\n  base_path: ./data\n")
        with pytest.raises(ValueError, match="backend"):
            load_storage_config(cfg_file)

    def test_load_config_default_presigned_url_expiry(self, temp_dir: Path) -> None:
        """presigned_url_expiry defaults to 3600 when not specified."""
        from shared.storage.config_loader import load_storage_config

        cfg_file = temp_dir / "storage.yaml"
        cfg_file.write_text("backend: local\nlocal:\n  base_path: ./data\n")
        cfg = load_storage_config(cfg_file)
        assert cfg.presigned_url_expiry == 3600  # default value
class TestStorageFileConfig:
    """Tests for StorageFileConfig dataclass."""

    def test_storage_file_config_is_immutable(self) -> None:
        """The dataclass is frozen, so field assignment raises."""
        from shared.storage.config_loader import StorageFileConfig

        cfg = StorageFileConfig(backend_type="local")
        with pytest.raises(AttributeError):
            cfg.backend_type = "azure_blob"  # type: ignore

    def test_storage_file_config_defaults(self) -> None:
        """Optional sections default to None; expiry defaults to 3600."""
        from shared.storage.config_loader import StorageFileConfig

        cfg = StorageFileConfig(backend_type="local")
        assert cfg.backend_type == "local"
        assert cfg.local is None
        assert cfg.azure is None
        assert cfg.s3 is None
        assert cfg.presigned_url_expiry == 3600
class TestLocalConfig:
    """Tests for LocalConfig dataclass."""

    def test_local_config_creation(self) -> None:
        """base_path is stored exactly as given."""
        from shared.storage.config_loader import LocalConfig

        cfg = LocalConfig(base_path=Path("/data/storage"))
        assert cfg.base_path == Path("/data/storage")

    def test_local_config_is_immutable(self) -> None:
        """The frozen dataclass rejects field assignment."""
        from shared.storage.config_loader import LocalConfig

        cfg = LocalConfig(base_path=Path("/data"))
        with pytest.raises(AttributeError):
            cfg.base_path = Path("/other")  # type: ignore
class TestAzureConfig:
    """Tests for AzureConfig dataclass."""

    def test_azure_config_creation(self) -> None:
        """Explicitly-passed fields are stored exactly as given."""
        from shared.storage.config_loader import AzureConfig

        cfg = AzureConfig(
            connection_string="test_connection",
            container_name="test_container",
            create_container=True,
        )
        assert cfg.connection_string == "test_connection"
        assert cfg.container_name == "test_container"
        assert cfg.create_container is True

    def test_azure_config_defaults(self) -> None:
        """create_container defaults to False."""
        from shared.storage.config_loader import AzureConfig

        cfg = AzureConfig(connection_string="conn", container_name="container")
        assert cfg.create_container is False

    def test_azure_config_is_immutable(self) -> None:
        """The frozen dataclass rejects field assignment."""
        from shared.storage.config_loader import AzureConfig

        cfg = AzureConfig(connection_string="conn", container_name="container")
        with pytest.raises(AttributeError):
            cfg.container_name = "other"  # type: ignore
class TestS3Config:
    """Tests for S3Config dataclass."""

    def test_s3_config_creation(self) -> None:
        """Explicitly-passed fields are stored exactly as given."""
        from shared.storage.config_loader import S3Config

        cfg = S3Config(
            bucket_name="my-bucket",
            region_name="us-east-1",
            access_key_id="AKIAIOSFODNN7EXAMPLE",
            secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            endpoint_url="http://localhost:9000",
            create_bucket=True,
        )
        assert cfg.bucket_name == "my-bucket"
        assert cfg.region_name == "us-east-1"
        assert cfg.access_key_id == "AKIAIOSFODNN7EXAMPLE"
        assert cfg.secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
        assert cfg.endpoint_url == "http://localhost:9000"
        assert cfg.create_bucket is True

    def test_s3_config_minimal(self) -> None:
        """Only bucket_name is required; everything else defaults."""
        from shared.storage.config_loader import S3Config

        cfg = S3Config(bucket_name="bucket")
        assert cfg.bucket_name == "bucket"
        assert cfg.region_name is None
        assert cfg.access_key_id is None
        assert cfg.secret_access_key is None
        assert cfg.endpoint_url is None
        assert cfg.create_bucket is False

    def test_s3_config_is_immutable(self) -> None:
        """The frozen dataclass rejects field assignment."""
        from shared.storage.config_loader import S3Config

        cfg = S3Config(bucket_name="bucket")
        with pytest.raises(AttributeError):
            cfg.bucket_name = "other"  # type: ignore

View File

@@ -0,0 +1,423 @@
"""
Tests for storage factory.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
class TestStorageFactory:
    """Tests for create_storage_backend factory function."""

    def test_create_local_backend(self) -> None:
        """A 'local' config yields a LocalStorageBackend rooted at base_path."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as workdir:
            cfg = StorageConfig(backend_type="local", base_path=Path(workdir))
            backend = create_storage_backend(cfg)
            assert isinstance(backend, LocalStorageBackend)
            assert backend.base_path == Path(workdir)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
        """An 'azure_blob' config yields an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )
        assert isinstance(create_storage_backend(cfg), AzureBlobStorageBackend)

    def test_create_unknown_backend_raises(self) -> None:
        """An unrecognized backend type raises ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(backend_type="unknown_backend")
        with pytest.raises(ValueError, match="Unknown storage backend"):
            create_storage_backend(cfg)

    def test_create_local_requires_base_path(self) -> None:
        """'local' without base_path raises a ValueError naming the field."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        with pytest.raises(ValueError, match="base_path"):
            create_storage_backend(StorageConfig(backend_type="local"))

    def test_create_azure_requires_connection_string(self) -> None:
        """'azure_blob' without connection_string raises ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(backend_type="azure_blob", container_name="container")
        with pytest.raises(ValueError, match="connection_string"):
            create_storage_backend(cfg)

    def test_create_azure_requires_container_name(self) -> None:
        """'azure_blob' without container_name raises ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(
            backend_type="azure_blob", connection_string="connection_string"
        )
        with pytest.raises(ValueError, match="container_name"):
            create_storage_backend(cfg)
class TestStorageFactoryFromEnv:
    """Tests for create_storage_backend_from_env factory function."""

    def test_create_from_env_local(self) -> None:
        """STORAGE_BACKEND=local builds a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as workdir:
            env = {"STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": workdir}
            with patch.dict(os.environ, env, clear=False):
                backend = create_storage_backend_from_env()
                assert isinstance(backend, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
        """STORAGE_BACKEND=azure_blob builds an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_env

        env = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }
        with patch.dict(os.environ, env, clear=False):
            backend = create_storage_backend_from_env()
            assert isinstance(backend, AzureBlobStorageBackend)

    def test_create_from_env_defaults_to_local(self) -> None:
        """When STORAGE_BACKEND is unset the factory defaults to local."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as workdir:
            with patch.dict(os.environ, {"STORAGE_BASE_PATH": workdir}, clear=False):
                # patch.dict restores the environment on exit, so popping
                # the backend selector here is safe.
                os.environ.pop("STORAGE_BACKEND", None)
                backend = create_storage_backend_from_env()
                assert isinstance(backend, LocalStorageBackend)

    def test_create_from_env_missing_azure_vars(self) -> None:
        """Missing AZURE_STORAGE_CONNECTION_STRING raises a ValueError naming it."""
        from shared.storage.factory import create_storage_backend_from_env

        with patch.dict(os.environ, {"STORAGE_BACKEND": "azure_blob"}, clear=False):
            # Drop the connection string so the factory must complain about it.
            os.environ.pop("AZURE_STORAGE_CONNECTION_STRING", None)
            with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
                create_storage_backend_from_env()
class TestGetDefaultStorageConfig:
    """Tests for get_default_storage_config function."""

    def test_get_default_config_local(self) -> None:
        """Local env vars translate into a local StorageConfig."""
        from shared.storage.factory import get_default_storage_config

        with tempfile.TemporaryDirectory() as workdir:
            env = {"STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": workdir}
            with patch.dict(os.environ, env, clear=False):
                cfg = get_default_storage_config()
                assert cfg.backend_type == "local"
                assert cfg.base_path == Path(workdir)

    def test_get_default_config_azure(self) -> None:
        """Azure env vars translate into an azure_blob StorageConfig."""
        from shared.storage.factory import get_default_storage_config

        env = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }
        with patch.dict(os.environ, env, clear=False):
            cfg = get_default_storage_config()
            assert cfg.backend_type == "azure_blob"
            assert cfg.connection_string == "DefaultEndpointsProtocol=https;..."
            assert cfg.container_name == "training-images"
class TestStorageFactoryS3:
    """Tests for S3 backend support in factory."""

    @patch("boto3.client")
    def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
        """An 's3' config yields an S3StorageBackend (boto3 mocked out)."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.s3 import S3StorageBackend

        cfg = StorageConfig(
            backend_type="s3", bucket_name="test-bucket", region_name="us-west-2"
        )
        assert isinstance(create_storage_backend(cfg), S3StorageBackend)

    def test_create_s3_requires_bucket_name(self) -> None:
        """'s3' without bucket_name raises a ValueError naming the field."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        cfg = StorageConfig(backend_type="s3", region_name="us-west-2")
        with pytest.raises(ValueError, match="bucket_name"):
            create_storage_backend(cfg)

    @patch("boto3.client")
    def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
        """S3 env vars build an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.s3 import S3StorageBackend

        env = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-east-1",
        }
        with patch.dict(os.environ, env, clear=False):
            backend = create_storage_backend_from_env()
            assert isinstance(backend, S3StorageBackend)

    def test_create_from_env_s3_missing_bucket(self) -> None:
        """Missing AWS_S3_BUCKET raises a ValueError naming it."""
        from shared.storage.factory import create_storage_backend_from_env

        with patch.dict(os.environ, {"STORAGE_BACKEND": "s3"}, clear=False):
            # Drop the bucket so the factory must complain about it;
            # patch.dict restores the environment afterwards.
            os.environ.pop("AWS_S3_BUCKET", None)
            with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
                create_storage_backend_from_env()

    def test_get_default_config_s3(self) -> None:
        """S3 env vars translate into an s3 StorageConfig."""
        from shared.storage.factory import get_default_storage_config

        env = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-west-2",
            "AWS_ENDPOINT_URL": "http://localhost:9000",
        }
        with patch.dict(os.environ, env, clear=False):
            cfg = get_default_storage_config()
            assert cfg.backend_type == "s3"
            assert cfg.bucket_name == "test-bucket"
            assert cfg.region_name == "us-west-2"
            assert cfg.endpoint_url == "http://localhost:9000"
class TestStorageFactoryFromFile:
    """Tests for create_storage_backend_from_file factory function."""

    def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
        """A YAML file selecting 'local' builds a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        cfg_file = tmp_path / "storage.yaml"
        storage_root = tmp_path / "storage"
        cfg_file.write_text(f"backend: local\nlocal:\n  base_path: {storage_root}\n")
        backend = create_storage_backend_from_file(cfg_file)
        assert isinstance(backend, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_yaml_file_azure(
        self, mock_service_class: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting 'azure_blob' builds an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_file

        cfg_file = tmp_path / "storage.yaml"
        cfg_file.write_text(
            "backend: azure_blob\n"
            "azure:\n"
            "  connection_string: DefaultEndpointsProtocol=https;AccountName=test\n"
            "  container_name: documents\n"
        )
        backend = create_storage_backend_from_file(cfg_file)
        assert isinstance(backend, AzureBlobStorageBackend)

    @patch("boto3.client")
    def test_create_from_yaml_file_s3(
        self, mock_boto3_client: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting 's3' builds an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.s3 import S3StorageBackend

        cfg_file = tmp_path / "storage.yaml"
        cfg_file.write_text(
            "backend: s3\ns3:\n  bucket_name: my-bucket\n  region_name: us-east-1\n"
        )
        backend = create_storage_backend_from_file(cfg_file)
        assert isinstance(backend, S3StorageBackend)

    def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
        """${VAR} references in the file are expanded from the environment."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        cfg_file = tmp_path / "storage.yaml"
        storage_root = tmp_path / "storage"
        cfg_file.write_text(
            "backend: ${STORAGE_BACKEND:-local}\n"
            "local:\n"
            "  base_path: ${CUSTOM_STORAGE_PATH}\n"
        )
        env = {"STORAGE_BACKEND": "local", "CUSTOM_STORAGE_PATH": str(storage_root)}
        with patch.dict(os.environ, env):
            backend = create_storage_backend_from_file(cfg_file)
            assert isinstance(backend, LocalStorageBackend)

    def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
        """A missing config file raises FileNotFoundError."""
        from shared.storage.factory import create_storage_backend_from_file

        with pytest.raises(FileNotFoundError):
            create_storage_backend_from_file(tmp_path / "nonexistent.yaml")
class TestGetStorageBackend:
    """Tests for get_storage_backend convenience function."""

    def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
        """An explicitly supplied config file takes effect."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        cfg_file = tmp_path / "storage.yaml"
        storage_root = tmp_path / "storage"
        cfg_file.write_text(f"backend: local\nlocal:\n  base_path: {storage_root}\n")
        backend = get_storage_backend(config_path=cfg_file)
        assert isinstance(backend, LocalStorageBackend)

    def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
        """Without a config file the factory reads environment variables."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        storage_root = tmp_path / "storage"
        env = {"STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": str(storage_root)}
        with patch.dict(os.environ, env, clear=False):
            # No config file provided, so the env vars must win.
            backend = get_storage_backend(config_path=None)
            assert isinstance(backend, LocalStorageBackend)

View File

@@ -0,0 +1,712 @@
"""
Tests for LocalStorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from pathlib import Path
import pytest
@pytest.fixture
def temp_storage_dir() -> Path:
    """Yield a scratch directory, removed when the test finishes."""
    scratch = Path(tempfile.mkdtemp())
    yield scratch
    shutil.rmtree(scratch, ignore_errors=True)
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Create a small text file with known content for upload tests."""
    target = temp_storage_dir / "sample.txt"
    target.write_text("Hello, World!")
    return target
@pytest.fixture
def sample_image(temp_storage_dir: Path) -> Path:
    """Create a sample PNG file for testing.

    The payload is a minimal valid PNG: a single 1x1 transparent pixel,
    spelled out chunk by chunk below.
    """
    file_path = temp_storage_dir / "sample.png"
    png_data = bytes.fromhex(
        "89504E470D0A1A0A"      # PNG signature
        "0000000D"              # IHDR length
        "49484452"              # "IHDR"
        "00000001"              # width: 1
        "00000001"              # height: 1
        "0806000000"            # 8-bit RGBA
        "1F15C489"              # CRC
        "0000000A"              # IDAT length
        "49444154"              # "IDAT"
        "789C6300010000050001"  # compressed data
        "0D0A2DB4"              # CRC
        "00000000"              # IEND length
        "49454E44"              # "IEND"
        "AE426082"              # CRC
    )
    file_path.write_bytes(png_data)
    return file_path
class TestLocalStorageBackendCreation:
    """Tests for LocalStorageBackend instantiation."""

    def test_create_with_base_path(self, temp_storage_dir: Path) -> None:
        """A Path base_path is stored as-is."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)
        assert backend.base_path == temp_storage_dir

    def test_create_with_string_path(self, temp_storage_dir: Path) -> None:
        """A string base_path is normalized to a Path."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=str(temp_storage_dir))
        assert backend.base_path == temp_storage_dir

    def test_create_creates_directory_if_not_exists(
        self, temp_storage_dir: Path
    ) -> None:
        """Constructing the backend creates a missing base directory."""
        from shared.storage.local import LocalStorageBackend

        target = temp_storage_dir / "new_storage"
        assert not target.exists()
        backend = LocalStorageBackend(base_path=target)
        assert target.exists()
        assert backend.base_path == target

    def test_is_storage_backend_subclass(self, temp_storage_dir: Path) -> None:
        """LocalStorageBackend satisfies the StorageBackend interface."""
        from shared.storage.base import StorageBackend
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)
        assert isinstance(backend, StorageBackend)
class TestLocalStorageBackendUpload:
    """Behaviour of LocalStorageBackend.upload."""

    def test_upload_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Uploading copies the file and returns the remote path."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        remote = store.upload(sample_file, "uploads/sample.txt")
        assert remote == "uploads/sample.txt"
        stored = root / "uploads" / "sample.txt"
        assert stored.exists()
        assert stored.read_text() == "Hello, World!"

    def test_upload_creates_subdirectories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Intermediate directories are created on demand."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "deep/nested/path/sample.txt")
        assert (root / "deep" / "nested" / "path" / "sample.txt").exists()

    def test_upload_fails_if_file_exists_without_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A second upload without overwrite raises StorageError."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")  # first one is fine
        with pytest.raises(StorageError, match="already exists"):
            store.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """overwrite=True replaces the stored content."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        sample_file.write_text("Modified content")
        remote = store.upload(sample_file, "sample.txt", overwrite=True)
        assert remote == "sample.txt"
        assert (root / "sample.txt").read_text() == "Modified content"

    def test_upload_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Uploading a missing local file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(FileNotFoundStorageError):
            store.upload(Path("/nonexistent/file.txt"), "sample.txt")

    def test_upload_binary_file(
        self, temp_storage_dir: Path, sample_image: Path
    ) -> None:
        """Binary content survives the upload byte-for-byte."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        remote = store.upload(sample_image, "images/sample.png")
        assert remote == "images/sample.png"
        assert (root / "images" / "sample.png").read_bytes() == sample_image.read_bytes()
class TestLocalStorageBackendDownload:
    """Behaviour of LocalStorageBackend.download."""

    def test_download_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A stored file can be downloaded to a local destination."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        dest_dir = temp_storage_dir / "downloads"
        dest_dir.mkdir()
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        dest = dest_dir / "downloaded.txt"
        returned = store.download("sample.txt", dest)
        assert returned == dest
        assert dest.exists()
        assert dest.read_text() == "Hello, World!"

    def test_download_creates_parent_directories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Missing parent directories of the destination are created."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        dest = temp_storage_dir / "deep" / "nested" / "downloaded.txt"
        store.download("sample.txt", dest)
        assert dest.exists()
        assert dest.read_text() == "Hello, World!"

    def test_download_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Downloading a missing remote path raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            store.download("nonexistent.txt", Path("/tmp/file.txt"))

    def test_download_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Files stored under nested remote paths download correctly."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "a/b/c/sample.txt")
        dest = temp_storage_dir / "downloaded.txt"
        store.download("a/b/c/sample.txt", dest)
        assert dest.read_text() == "Hello, World!"
class TestLocalStorageBackendExists:
    """Behaviour of LocalStorageBackend.exists."""

    def test_exists_returns_true_for_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """exists() reports True once a file has been uploaded."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        assert store.exists("sample.txt") is True

    def test_exists_returns_false_for_nonexistent_file(
        self, temp_storage_dir: Path
    ) -> None:
        """exists() reports False for a path never uploaded."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        assert store.exists("nonexistent.txt") is False

    def test_exists_with_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """exists() distinguishes present and absent nested paths."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "a/b/sample.txt")
        assert store.exists("a/b/sample.txt") is True
        assert store.exists("a/b/other.txt") is False
class TestLocalStorageBackendListFiles:
    """Behaviour of LocalStorageBackend.list_files."""

    def test_list_files_empty_storage(self, temp_storage_dir: Path) -> None:
        """An empty backend lists no files."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        assert store.list_files("") == []

    def test_list_files_returns_all_files(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Listing with an empty prefix returns every stored file."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        for remote in ("file1.txt", "file2.txt", "subdir/file3.txt"):
            store.upload(sample_file, remote)
        listing = store.list_files("")
        assert len(listing) == 3
        assert "file1.txt" in listing
        assert "file2.txt" in listing
        assert "subdir/file3.txt" in listing

    def test_list_files_with_prefix(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A prefix restricts the listing to matching paths only."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "images/a.png")
        store.upload(sample_file, "images/b.png")
        store.upload(sample_file, "labels/a.txt")
        listing = store.list_files("images/")
        assert len(listing) == 2
        assert "images/a.png" in listing
        assert "images/b.png" in listing
        assert "labels/a.txt" not in listing

    def test_list_files_returns_sorted(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """The listing is sorted regardless of upload order."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        for remote in ("c.txt", "a.txt", "b.txt"):
            store.upload(sample_file, remote)
        assert store.list_files("") == ["a.txt", "b.txt", "c.txt"]
class TestLocalStorageBackendDelete:
    """Behaviour of LocalStorageBackend.delete."""

    def test_delete_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Deleting an uploaded file returns True and removes it."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        assert store.delete("sample.txt") is True
        assert not (root / "sample.txt").exists()

    def test_delete_nonexistent_file_returns_false(
        self, temp_storage_dir: Path
    ) -> None:
        """Deleting an absent path returns False rather than raising."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        assert store.delete("nonexistent.txt") is False

    def test_delete_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Files under nested remote paths can be deleted."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "a/b/sample.txt")
        assert store.delete("a/b/sample.txt") is True
        assert not (root / "a" / "b" / "sample.txt").exists()
class TestLocalStorageBackendGetUrl:
    """Behaviour of LocalStorageBackend.get_url."""

    def test_get_url_returns_file_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """get_url yields a file:// URI or absolute path locating the file."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        url = store.get_url("sample.txt")
        assert "sample.txt" in url
        # Either the raw path or its file:// form is acceptable.
        on_disk = root / "sample.txt"
        assert str(on_disk) in url or on_disk.as_uri() == url

    def test_get_url_nonexistent_file(self, temp_storage_dir: Path) -> None:
        """get_url raises FileNotFoundStorageError for missing files."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(FileNotFoundStorageError):
            store.get_url("nonexistent.txt")
class TestLocalStorageBackendUploadBytes:
    """Behaviour of LocalStorageBackend.upload_bytes."""

    def test_upload_bytes(self, temp_storage_dir: Path) -> None:
        """Raw bytes are written verbatim to the remote path."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        payload = b"Binary content here"
        remote = store.upload_bytes(payload, "binary.dat")
        assert remote == "binary.dat"
        assert (root / "binary.dat").read_bytes() == payload

    def test_upload_bytes_creates_subdirectories(
        self, temp_storage_dir: Path
    ) -> None:
        """Intermediate directories are created for byte uploads too."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload_bytes(b"content", "a/b/c/file.dat")
        assert (root / "a" / "b" / "c" / "file.dat").exists()
class TestLocalStorageBackendDownloadBytes:
    """Behaviour of LocalStorageBackend.download_bytes."""

    def test_download_bytes(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Stored content is returned as the original bytes."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        assert store.download_bytes("sample.txt") == b"Hello, World!"

    def test_download_bytes_nonexistent(self, temp_storage_dir: Path) -> None:
        """download_bytes raises FileNotFoundStorageError for missing files."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(FileNotFoundStorageError):
            store.download_bytes("nonexistent.txt")
class TestLocalStorageBackendSecurity:
    """Path-traversal and absolute-path rejection across all operations."""

    def test_path_traversal_with_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A plain ../ escape in upload is rejected."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.upload(sample_file, "../escape.txt")

    def test_path_traversal_with_nested_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A ../ sequence hidden below a subdirectory is rejected."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.upload(sample_file, "subdir/../../escape.txt")

    def test_path_traversal_with_many_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Many stacked ../ components are rejected as well."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.upload(sample_file, "a/b/c/../../../../escape.txt")

    def test_absolute_path_unix_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Absolute POSIX remote paths are rejected."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Absolute paths not allowed"):
            store.upload(sample_file, "/etc/passwd")

    def test_absolute_path_windows_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Absolute Windows remote paths are rejected."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Absolute paths not allowed"):
            store.upload(sample_file, "C:\\Windows\\System32\\config")

    def test_download_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """download refuses traversal in the remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.download("../escape.txt", Path("/tmp/file.txt"))

    def test_exists_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """exists refuses traversal in the remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.exists("../escape.txt")

    def test_delete_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """delete refuses traversal in the remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.delete("../escape.txt")

    def test_get_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """get_url refuses traversal in the remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.get_url("../escape.txt")

    def test_upload_bytes_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """upload_bytes refuses traversal in the remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.upload_bytes(b"content", "../escape.txt")

    def test_download_bytes_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """download_bytes refuses traversal in the remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.download_bytes("../escape.txt")

    def test_valid_nested_path_still_works(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Legitimate nested remote paths are unaffected by the checks."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        remote = store.upload(sample_file, "a/b/c/d/file.txt")
        assert remote == "a/b/c/d/file.txt"
        assert (root / "a" / "b" / "c" / "d" / "file.txt").exists()

View File

@@ -0,0 +1,158 @@
"""Tests for storage prefixes module."""
import pytest
from shared.storage.prefixes import PREFIXES, StoragePrefixes
class TestStoragePrefixes:
    """Tests for the StoragePrefixes constants container."""

    def test_prefixes_are_strings(self) -> None:
        """All prefix constants should be strings."""
        assert isinstance(PREFIXES.DOCUMENTS, str)
        assert isinstance(PREFIXES.IMAGES, str)
        assert isinstance(PREFIXES.UPLOADS, str)
        assert isinstance(PREFIXES.RESULTS, str)
        assert isinstance(PREFIXES.EXPORTS, str)
        assert isinstance(PREFIXES.DATASETS, str)
        assert isinstance(PREFIXES.MODELS, str)
        assert isinstance(PREFIXES.RAW_PDFS, str)
        assert isinstance(PREFIXES.STRUCTURED_DATA, str)
        assert isinstance(PREFIXES.ADMIN_IMAGES, str)

    def test_prefixes_are_non_empty(self) -> None:
        """All prefix constants should be non-empty."""
        assert PREFIXES.DOCUMENTS
        assert PREFIXES.IMAGES
        assert PREFIXES.UPLOADS
        assert PREFIXES.RESULTS
        assert PREFIXES.EXPORTS
        assert PREFIXES.DATASETS
        assert PREFIXES.MODELS
        assert PREFIXES.RAW_PDFS
        assert PREFIXES.STRUCTURED_DATA
        assert PREFIXES.ADMIN_IMAGES

    def test_prefixes_have_no_leading_slash(self) -> None:
        """Prefixes should not start with a slash for portability."""
        assert not PREFIXES.DOCUMENTS.startswith("/")
        assert not PREFIXES.IMAGES.startswith("/")
        assert not PREFIXES.UPLOADS.startswith("/")
        assert not PREFIXES.RESULTS.startswith("/")

    def test_prefixes_have_no_trailing_slash(self) -> None:
        """Prefixes should not end with a slash."""
        assert not PREFIXES.DOCUMENTS.endswith("/")
        assert not PREFIXES.IMAGES.endswith("/")
        assert not PREFIXES.UPLOADS.endswith("/")
        assert not PREFIXES.RESULTS.endswith("/")

    def test_frozen_dataclass(self) -> None:
        """StoragePrefixes should be immutable (a frozen dataclass).

        Expect dataclasses.FrozenInstanceError specifically instead of a
        blanket ``Exception`` — the old broad catch would also pass on any
        unrelated error, making the immutability check meaningless.
        """
        from dataclasses import FrozenInstanceError

        with pytest.raises(FrozenInstanceError):
            PREFIXES.DOCUMENTS = "new_value"  # type: ignore
class TestDocumentPath:
    """Behaviour of the document_path helper."""

    def test_document_path_with_extension(self) -> None:
        """Document id plus dotted extension maps under documents/."""
        assert PREFIXES.document_path("abc123", ".pdf") == "documents/abc123.pdf"

    def test_document_path_without_leading_dot(self) -> None:
        """An extension given without a leading dot is normalised."""
        assert PREFIXES.document_path("abc123", "pdf") == "documents/abc123.pdf"

    def test_document_path_default_extension(self) -> None:
        """The extension defaults to .pdf when omitted."""
        assert PREFIXES.document_path("abc123") == "documents/abc123.pdf"
class TestImagePath:
    """Behaviour of the image_path helper."""

    def test_image_path_basic(self) -> None:
        """Document id and page map under images/<doc>/page_<n>.png."""
        assert PREFIXES.image_path("doc123", 1) == "images/doc123/page_1.png"

    def test_image_path_page_number(self) -> None:
        """The requested page number appears in the path."""
        assert PREFIXES.image_path("doc123", 5) == "images/doc123/page_5.png"

    def test_image_path_custom_extension(self) -> None:
        """A caller-supplied extension replaces the .png default."""
        assert PREFIXES.image_path("doc123", 1, ".jpg") == "images/doc123/page_1.jpg"
class TestUploadPath:
    """Behaviour of the upload_path helper."""

    def test_upload_path_basic(self) -> None:
        """A bare filename maps under uploads/."""
        assert PREFIXES.upload_path("invoice.pdf") == "uploads/invoice.pdf"

    def test_upload_path_with_subfolder(self) -> None:
        """An optional subfolder is inserted between prefix and filename."""
        assert PREFIXES.upload_path("invoice.pdf", "async") == "uploads/async/invoice.pdf"
class TestResultPath:
    """Behaviour of the result_path helper."""

    def test_result_path_basic(self) -> None:
        """A filename maps under results/."""
        assert PREFIXES.result_path("output.json") == "results/output.json"
class TestExportPath:
    """Behaviour of the export_path helper."""

    def test_export_path_basic(self) -> None:
        """Export id and filename map under exports/<id>/."""
        assert PREFIXES.export_path("exp123", "dataset.zip") == "exports/exp123/dataset.zip"
class TestDatasetPath:
    """Behaviour of the dataset_path helper."""

    def test_dataset_path_basic(self) -> None:
        """Dataset id and filename map under datasets/<id>/."""
        assert PREFIXES.dataset_path("ds123", "data.yaml") == "datasets/ds123/data.yaml"
class TestModelPath:
    """Behaviour of the model_path helper."""

    def test_model_path_basic(self) -> None:
        """Version and filename map under models/<version>/."""
        assert PREFIXES.model_path("v1.0.0", "best.pt") == "models/v1.0.0/best.pt"
class TestExportsFromInit:
    """Re-export checks for the shared.storage package __init__."""

    def test_prefixes_exported(self) -> None:
        """The PREFIXES singleton is re-exported as the same object."""
        from shared.storage import PREFIXES as reexported

        assert reexported is PREFIXES

    def test_storage_prefixes_exported(self) -> None:
        """The StoragePrefixes class is re-exported as the same object."""
        from shared.storage import StoragePrefixes as reexported_cls

        assert reexported_cls is StoragePrefixes

View File

@@ -0,0 +1,264 @@
"""
Tests for pre-signed URL functionality across all storage backends.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def temp_storage_dir() -> Path:
    """Yield a throwaway directory, removed again after the test."""
    workdir = Path(tempfile.mkdtemp())
    yield workdir
    shutil.rmtree(workdir, ignore_errors=True)
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Create a small text file containing "Hello, World!" for tests."""
    target = temp_storage_dir / "sample.txt"
    target.write_text("Hello, World!")
    return target
class TestStorageBackendInterfacePresignedUrl:
    """get_presigned_url as part of the abstract StorageBackend contract."""

    def test_subclass_must_implement_get_presigned_url(self) -> None:
        """Omitting get_presigned_url leaves the subclass abstract."""
        from shared.storage.base import StorageBackend

        class IncompleteBackend(StorageBackend):
            # Implements every abstract method EXCEPT get_presigned_url.
            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def exists(self, remote_path: str) -> bool:
                return False

            def list_files(self, prefix: str) -> list[str]:
                return []

            def delete(self, remote_path: str) -> bool:
                return True

            def get_url(self, remote_path: str) -> str:
                return ""

        with pytest.raises(TypeError):
            IncompleteBackend()  # type: ignore

    def test_valid_subclass_with_get_presigned_url_can_be_instantiated(self) -> None:
        """A subclass providing every method, including get_presigned_url, instantiates."""
        from shared.storage.base import StorageBackend

        class CompleteBackend(StorageBackend):
            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def exists(self, remote_path: str) -> bool:
                return False

            def list_files(self, prefix: str) -> list[str]:
                return []

            def delete(self, remote_path: str) -> bool:
                return True

            def get_url(self, remote_path: str) -> str:
                return ""

            def get_presigned_url(
                self, remote_path: str, expires_in_seconds: int = 3600
            ) -> str:
                return f"https://example.com/{remote_path}?token=abc"

        instance = CompleteBackend()
        assert isinstance(instance, StorageBackend)
class TestLocalStorageBackendPresignedUrl:
    """Behaviour of LocalStorageBackend.get_presigned_url."""

    def test_get_presigned_url_returns_file_uri(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An existing file yields a file:// URI naming it."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        url = store.get_presigned_url("sample.txt")
        assert url.startswith("file://")
        assert "sample.txt" in url

    def test_get_presigned_url_with_custom_expiry(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """expires_in_seconds is accepted (and ignored) for local storage."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")
        url = store.get_presigned_url("sample.txt", expires_in_seconds=7200)
        assert url.startswith("file://")

    def test_get_presigned_url_nonexistent_file_raises(
        self, temp_storage_dir: Path
    ) -> None:
        """Missing files raise FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(FileNotFoundStorageError):
            store.get_presigned_url("nonexistent.txt")

    def test_get_presigned_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Traversal in the remote path is rejected."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)
        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.get_presigned_url("../escape.txt")

    def test_get_presigned_url_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Nested remote paths yield a valid file:// URI as well."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "a/b/c/sample.txt")
        url = store.get_presigned_url("a/b/c/sample.txt")
        assert url.startswith("file://")
        assert "sample.txt" in url
class TestAzureBlobStorageBackendPresignedUrl:
    """Tests for AzureBlobStorageBackend.get_presigned_url (fully mocked SDK)."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_generates_sas_url(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """The returned URL is the blob URL with the SAS token appended."""
        from shared.storage.azure import AzureBlobStorageBackend

        # Wire up the mocked service -> container -> blob client chain.
        mock_blob_service = MagicMock()
        mock_blob_service.account_name = "testaccount"
        mock_blob_service_class.from_connection_string.return_value = mock_blob_service
        mock_container = MagicMock()
        mock_container.exists.return_value = True
        mock_blob_service.get_container_client.return_value = mock_container
        mock_blob_client = MagicMock()
        mock_blob_client.exists.return_value = True
        mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt"
        mock_container.get_blob_client.return_value = mock_blob_client
        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            container_name="container",
        )
        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"
            url = backend.get_presigned_url("test.txt", expires_in_seconds=3600)
            assert "https://testaccount.blob.core.windows.net" in url
            # Fixed: the old check `"sv=2021-06-08" in url or "test.txt" in url`
            # was vacuous — the mocked blob URL always contains "test.txt",
            # so a missing SAS token could never fail the test.
            assert "sv=2021-06-08" in url
            mock_generate_sas.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_nonexistent_blob_raises(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """A blob reported as absent raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.azure import AzureBlobStorageBackend

        mock_blob_service = MagicMock()
        mock_blob_service_class.from_connection_string.return_value = mock_blob_service
        mock_container = MagicMock()
        mock_container.exists.return_value = True
        mock_blob_service.get_container_client.return_value = mock_container
        mock_blob_client = MagicMock()
        mock_blob_client.exists.return_value = False  # blob does not exist
        mock_container.get_blob_client.return_value = mock_blob_client
        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key==;EndpointSuffix=core.windows.net",
            container_name="container",
        )
        with pytest.raises(FileNotFoundStorageError):
            backend.get_presigned_url("nonexistent.txt")

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_uses_custom_expiry(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """A custom expires_in_seconds reaches generate_blob_sas."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_blob_service = MagicMock()
        mock_blob_service.account_name = "testaccount"
        mock_blob_service_class.from_connection_string.return_value = mock_blob_service
        mock_container = MagicMock()
        mock_container.exists.return_value = True
        mock_blob_service.get_container_client.return_value = mock_container
        mock_blob_client = MagicMock()
        mock_blob_client.exists.return_value = True
        mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt"
        mock_container.get_blob_client.return_value = mock_blob_client
        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            container_name="container",
        )
        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"
            backend.get_presigned_url("test.txt", expires_in_seconds=7200)
            # The expiry is encoded in the generate_blob_sas call arguments.
            mock_generate_sas.assert_called_once()

View File

@@ -0,0 +1,520 @@
"""
Tests for S3StorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch, call
import pytest
@pytest.fixture
def temp_dir() -> Path:
    """Yield a throwaway directory, removed again after the test."""
    workdir = Path(tempfile.mkdtemp())
    yield workdir
    shutil.rmtree(workdir, ignore_errors=True)
@pytest.fixture
def sample_file(temp_dir: Path) -> Path:
    """Create a small text file containing "Hello, World!" for tests."""
    target = temp_dir / "sample.txt"
    target.write_text("Hello, World!")
    return target
@pytest.fixture
def mock_boto3_client():
    """Patch boto3.client and yield the MagicMock S3 client it returns."""
    with patch("boto3.client") as patched_factory:
        fake_client = MagicMock()
        patched_factory.return_value = fake_client
        yield fake_client
class TestS3StorageBackendCreation:
    """Construction-time behaviour of S3StorageBackend."""

    def test_create_with_bucket_name(self, mock_boto3_client: MagicMock) -> None:
        """The bucket name passed in is stored on the backend."""
        from shared.storage.s3 import S3StorageBackend

        assert S3StorageBackend(bucket_name="test-bucket").bucket_name == "test-bucket"

    def test_create_with_region(self, mock_boto3_client: MagicMock) -> None:
        """``region_name`` is forwarded to boto3.client."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as factory:
            S3StorageBackend(bucket_name="test-bucket", region_name="us-west-2")
            factory.assert_called_once()
            assert factory.call_args[1].get("region_name") == "us-west-2"

    def test_create_with_credentials(self, mock_boto3_client: MagicMock) -> None:
        """Explicit access key / secret are forwarded to boto3.client."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                access_key_id="AKIATEST",
                secret_access_key="secret123",
            )
            factory.assert_called_once()
            forwarded = factory.call_args[1]
            assert forwarded.get("aws_access_key_id") == "AKIATEST"
            assert forwarded.get("aws_secret_access_key") == "secret123"

    def test_create_with_endpoint_url(self, mock_boto3_client: MagicMock) -> None:
        """A custom endpoint (S3-compatible services) is forwarded to boto3.client."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                endpoint_url="http://localhost:9000",
            )
            factory.assert_called_once()
            assert factory.call_args[1].get("endpoint_url") == "http://localhost:9000"

    def test_create_bucket_when_requested(self, mock_boto3_client: MagicMock) -> None:
        """``create_bucket=True`` creates the bucket when HeadBucket reports 404."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_bucket.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadBucket"
        )
        S3StorageBackend(bucket_name="test-bucket", create_bucket=True)
        mock_boto3_client.create_bucket.assert_called_once()

    def test_is_storage_backend_subclass(self, mock_boto3_client: MagicMock) -> None:
        """S3StorageBackend satisfies the StorageBackend interface."""
        from shared.storage.base import StorageBackend
        from shared.storage.s3 import S3StorageBackend

        assert isinstance(S3StorageBackend(bucket_name="test-bucket"), StorageBackend)
class TestS3StorageBackendUpload:
    """Behaviour of S3StorageBackend.upload."""

    def test_upload_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path, sample_file: Path
    ) -> None:
        """A new object uploads successfully and the key is returned."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # head_object raising 404 means "no such object yet".
        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.upload(sample_file, "uploads/sample.txt") == "uploads/sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """Uploading over an existing key without overwrite raises StorageError."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # key already present
        backend = S3StorageBackend(bucket_name="test-bucket")
        with pytest.raises(StorageError, match="already exists"):
            backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """``overwrite=True`` replaces an existing object."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # key already present
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.upload(sample_file, "sample.txt", overwrite=True) == "sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_nonexistent_file_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Uploading a missing local file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        backend = S3StorageBackend(bucket_name="test-bucket")
        with pytest.raises(FileNotFoundStorageError):
            backend.upload(temp_dir / "nonexistent.txt", "sample.txt")
class TestS3StorageBackendDownload:
    """Behaviour of S3StorageBackend.download."""

    def test_download_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """An existing object downloads to the requested local path."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # object exists
        backend = S3StorageBackend(bucket_name="test-bucket")
        target = temp_dir / "downloaded.txt"
        assert backend.download("sample.txt", target) == target
        mock_boto3_client.download_file.assert_called_once()

    def test_download_creates_parent_directories(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Missing parent directories are created before downloading."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        backend = S3StorageBackend(bucket_name="test-bucket")
        target = temp_dir / "deep" / "nested" / "downloaded.txt"
        backend.download("sample.txt", target)
        assert target.parent.exists()

    def test_download_nonexistent_object_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """A 404 from head_object surfaces as FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )
        backend = S3StorageBackend(bucket_name="test-bucket")
        with pytest.raises(FileNotFoundStorageError):
            backend.download("nonexistent.txt", temp_dir / "file.txt")
class TestS3StorageBackendExists:
    """Behaviour of S3StorageBackend.exists."""

    def test_exists_returns_true_for_existing_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A successful head_object means the key exists."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        assert S3StorageBackend(bucket_name="test-bucket").exists("sample.txt") is True

    def test_exists_returns_false_for_nonexistent_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A 404 from head_object maps to False rather than an exception."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.exists("nonexistent.txt") is False
class TestS3StorageBackendListFiles:
    """Behaviour of S3StorageBackend.list_files."""

    def test_list_files_returns_objects(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """All keys present in the ListObjectsV2 response are returned."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "file1.txt"},
                {"Key": "file2.txt"},
                {"Key": "subdir/file3.txt"},
            ]
        }
        backend = S3StorageBackend(bucket_name="test-bucket")
        files = backend.list_files("")
        assert len(files) == 3
        assert "file1.txt" in files
        assert "file2.txt" in files
        assert "subdir/file3.txt" in files

    def test_list_files_with_prefix(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """The prefix is forwarded to ListObjectsV2 and matching keys returned.

        Fixed: the original bound the result to an unused ``files`` variable
        and never asserted on it; the returned keys are now verified too.
        """
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "images/a.png"},
                {"Key": "images/b.png"},
            ]
        }
        backend = S3StorageBackend(bucket_name="test-bucket")
        files = backend.list_files("images/")
        mock_boto3_client.list_objects_v2.assert_called_with(
            Bucket="test-bucket", Prefix="images/"
        )
        assert sorted(files) == ["images/a.png", "images/b.png"]

    def test_list_files_empty_bucket(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A response without a 'Contents' key yields an empty list."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {}  # no Contents key
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.list_files("") == []
class TestS3StorageBackendDelete:
    """Behaviour of S3StorageBackend.delete."""

    def test_delete_existing_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Deleting an existing key returns True and issues DeleteObject."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.delete("sample.txt") is True
        mock_boto3_client.delete_object.assert_called_once()

    def test_delete_nonexistent_object_returns_false(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Deleting a missing key returns False instead of raising."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.delete("nonexistent.txt") is False
class TestS3StorageBackendGetUrl:
    """Behaviour of S3StorageBackend.get_url."""

    def test_get_url_returns_s3_url(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """get_url yields a URL that references the object key."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        mock_boto3_client.generate_presigned_url.return_value = (
            "https://test-bucket.s3.amazonaws.com/sample.txt"
        )
        url = S3StorageBackend(bucket_name="test-bucket").get_url("sample.txt")
        assert "sample.txt" in url

    def test_get_url_nonexistent_object_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """get_url on a missing key raises FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )
        with pytest.raises(FileNotFoundStorageError):
            S3StorageBackend(bucket_name="test-bucket").get_url("nonexistent.txt")
class TestS3StorageBackendUploadBytes:
    """Behaviour of S3StorageBackend.upload_bytes."""

    def test_upload_bytes(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Raw bytes are stored via PutObject and the key is returned."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # 404 on head_object -> key does not exist yet.
        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.upload_bytes(b"Binary content here", "binary.dat") == "binary.dat"
        mock_boto3_client.put_object.assert_called_once()

    def test_upload_bytes_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """upload_bytes refuses to clobber an existing key without overwrite."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # key already present
        backend = S3StorageBackend(bucket_name="test-bucket")
        with pytest.raises(StorageError, match="already exists"):
            backend.upload_bytes(b"content", "sample.txt", overwrite=False)
class TestS3StorageBackendDownloadBytes:
    """Behaviour of S3StorageBackend.download_bytes."""

    def test_download_bytes(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """The object body stream is read and returned as bytes."""
        from shared.storage.s3 import S3StorageBackend

        body = MagicMock()
        body.read.return_value = b"Hello, World!"
        mock_boto3_client.get_object.return_value = {"Body": body}
        backend = S3StorageBackend(bucket_name="test-bucket")
        assert backend.download_bytes("sample.txt") == b"Hello, World!"

    def test_download_bytes_nonexistent_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A NoSuchKey error surfaces as FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.get_object.side_effect = ClientError(
            {"Error": {"Code": "NoSuchKey"}}, "GetObject"
        )
        backend = S3StorageBackend(bucket_name="test-bucket")
        with pytest.raises(FileNotFoundStorageError):
            backend.download_bytes("nonexistent.txt")
class TestS3StorageBackendPresignedUrl:
    """Behaviour of S3StorageBackend.get_presigned_url."""

    def test_get_presigned_url_generates_url(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A presigned URL is produced via generate_presigned_url."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        mock_boto3_client.generate_presigned_url.return_value = (
            "https://test-bucket.s3.amazonaws.com/sample.txt?X-Amz-Algorithm=..."
        )
        url = S3StorageBackend(bucket_name="test-bucket").get_presigned_url("sample.txt")
        assert "X-Amz-Algorithm" in url or "sample.txt" in url
        mock_boto3_client.generate_presigned_url.assert_called_once()

    def test_get_presigned_url_with_custom_expiry(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """``expires_in_seconds`` is forwarded as the ExpiresIn argument."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        mock_boto3_client.generate_presigned_url.return_value = "https://..."
        backend = S3StorageBackend(bucket_name="test-bucket")
        backend.get_presigned_url("sample.txt", expires_in_seconds=7200)
        _, kwargs = mock_boto3_client.generate_presigned_url.call_args
        assert kwargs.get("ExpiresIn") == 7200

    def test_get_presigned_url_nonexistent_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Presigning a missing key raises FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )
        with pytest.raises(FileNotFoundStorageError):
            S3StorageBackend(bucket_name="test-bucket").get_presigned_url("nonexistent.txt")

View File

@@ -9,7 +9,8 @@ from uuid import UUID
from fastapi import HTTPException
from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from inference.data.admin_models import AdminAnnotation, AdminDocument
from shared.fields import FIELD_CLASSES
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from inference.web.schemas.admin import (
AnnotationCreate,

View File

@@ -31,6 +31,7 @@ class MockAdminDocument:
self.batch_id = kwargs.get('batch_id', None)
self.csv_field_values = kwargs.get('csv_field_values', None)
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
self.category = kwargs.get('category', 'invoice')
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
@@ -67,12 +68,13 @@ class MockAdminDB:
def get_documents_by_token(
self,
admin_token,
admin_token=None,
status=None,
upload_source=None,
has_annotations=None,
auto_label_status=None,
batch_id=None,
category=None,
limit=20,
offset=0
):
@@ -95,6 +97,8 @@ class MockAdminDB:
docs = [d for d in docs if d.auto_label_status == auto_label_status]
if batch_id:
docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
if category:
docs = [d for d in docs if d.category == category]
total = len(docs)
return docs[offset:offset+limit], total

View File

@@ -215,8 +215,10 @@ class TestAsyncProcessingService:
def test_cleanup_orphan_files(self, async_service, mock_db):
"""Test cleanup of orphan files."""
# Create an orphan file
# Create the async upload directory
temp_dir = async_service._async_config.temp_upload_dir
temp_dir.mkdir(parents=True, exist_ok=True)
orphan_file = temp_dir / "orphan-request.pdf"
orphan_file.write_bytes(b"orphan content")
@@ -228,7 +230,13 @@ class TestAsyncProcessingService:
# Mock database to say file doesn't exist
mock_db.get_request.return_value = None
count = async_service._cleanup_orphan_files()
# Mock the storage helper to return the same directory as the fixture
with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
mock_helper = MagicMock()
mock_helper.get_uploads_base_path.return_value = temp_dir
mock_storage.return_value = mock_helper
count = async_service._cleanup_orphan_files()
assert count == 1
assert not orphan_file.exists()

View File

@@ -5,7 +5,75 @@ TDD Phase 5: RED - Write tests first, then implement to pass.
"""
import pytest
from unittest.mock import MagicMock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
import numpy as np
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.core.auth import validate_admin_token, get_admin_db
TEST_ADMIN_TOKEN = "test-admin-token-12345"
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"
@pytest.fixture
def admin_token() -> str:
    """Return the admin token used by authenticated test requests."""
    return TEST_ADMIN_TOKEN
@pytest.fixture
def mock_admin_db() -> MagicMock:
    """MagicMock AdminDB preloaded with 'nothing found' defaults."""
    db = MagicMock()
    db.get_document_by_token.return_value = None
    db.get_dataset.return_value = None
    db.get_augmented_datasets.return_value = ([], 0)
    return db
@pytest.fixture
def admin_client(mock_admin_db: MagicMock) -> TestClient:
    """TestClient whose auth and DB dependencies are both overridden."""
    app = FastAPI()
    app.dependency_overrides[validate_admin_token] = lambda: TEST_ADMIN_TOKEN
    app.dependency_overrides[get_admin_db] = lambda: mock_admin_db
    # The router already carries the /augmentation prefix, so mounting it
    # under /api/v1/admin yields /api/v1/admin/augmentation.
    app.include_router(create_augmentation_router(), prefix="/api/v1/admin")
    return TestClient(app)
@pytest.fixture
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
    """TestClient with the DB overridden but REAL token validation in place."""
    app = FastAPI()
    # Deliberately do NOT override validate_admin_token here.
    app.dependency_overrides[get_admin_db] = lambda: mock_admin_db
    app.include_router(create_augmentation_router(), prefix="/api/v1/admin")
    return TestClient(app)
class TestAugmentationTypesEndpoint:
@@ -34,10 +102,10 @@ class TestAugmentationTypesEndpoint:
assert "stage" in aug_type
def test_list_augmentation_types_unauthorized(
self, admin_client: TestClient
self, unauthenticated_client: TestClient
) -> None:
"""Test that unauthorized request is rejected."""
response = admin_client.get("/api/v1/admin/augmentation/types")
response = unauthenticated_client.get("/api/v1/admin/augmentation/types")
assert response.status_code == 401
@@ -74,16 +142,30 @@ class TestAugmentationPreviewEndpoint:
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
mock_admin_db: MagicMock,
) -> None:
"""Test previewing augmentation on a document."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "gaussian_noise",
"params": {"std": 15},
},
)
# Mock document exists
mock_document = MagicMock()
mock_document.images_dir = "/fake/path"
mock_admin_db.get_document.return_value = mock_document
# Create a fake image (100x100 RGB)
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
with patch(
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
) as mock_load:
mock_load.return_value = fake_image
response = admin_client.post(
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "gaussian_noise",
"params": {"std": 15},
},
)
assert response.status_code == 200
data = response.json()
@@ -136,18 +218,32 @@ class TestAugmentationPreviewConfigEndpoint:
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
mock_admin_db: MagicMock,
) -> None:
"""Test previewing full config on a document."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"gaussian_noise": {"enabled": True, "probability": 1.0},
"lighting_variation": {"enabled": True, "probability": 1.0},
"preserve_bboxes": True,
"seed": 42,
},
)
# Mock document exists
mock_document = MagicMock()
mock_document.images_dir = "/fake/path"
mock_admin_db.get_document.return_value = mock_document
# Create a fake image (100x100 RGB)
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
with patch(
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
) as mock_load:
mock_load.return_value = fake_image
response = admin_client.post(
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"gaussian_noise": {"enabled": True, "probability": 1.0},
"lighting_variation": {"enabled": True, "probability": 1.0},
"preserve_bboxes": True,
"seed": 42,
},
)
assert response.status_code == 200
data = response.json()
@@ -164,8 +260,14 @@ class TestAugmentationBatchEndpoint:
admin_client: TestClient,
admin_token: str,
sample_dataset_id: str,
mock_admin_db: MagicMock,
) -> None:
"""Test creating augmented dataset."""
# Mock dataset exists
mock_dataset = MagicMock()
mock_dataset.total_images = 100
mock_admin_db.get_dataset.return_value = mock_dataset
response = admin_client.post(
"/api/v1/admin/augmentation/batch",
headers={"X-Admin-Token": admin_token},
@@ -250,12 +352,10 @@ class TestAugmentedDatasetsListEndpoint:
@pytest.fixture
def sample_document_id() -> str:
"""Provide a sample document ID for testing."""
# This would need to be created in test setup
return "test-document-id"
return TEST_DOCUMENT_UUID
@pytest.fixture
def sample_dataset_id() -> str:
"""Provide a sample dataset ID for testing."""
# This would need to be created in test setup
return "test-dataset-id"
return TEST_DATASET_UUID

View File

@@ -35,6 +35,8 @@ def _make_dataset(**overrides) -> MagicMock:
name="test-dataset",
description="Test dataset",
status="ready",
training_status=None,
active_training_task_id=None,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
@@ -183,6 +185,8 @@ class TestListDatasetsRoute:
mock_db = MagicMock()
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
# Mock the active training tasks lookup to return empty dict
mock_db.get_active_training_tasks_for_datasets.return_value = {}
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))

View File

@@ -0,0 +1,363 @@
"""
Tests for dataset training status feature.
Tests cover:
1. Database model fields (training_status, active_training_task_id)
2. AdminDB update_dataset_training_status method
3. API response includes training status fields
4. Scheduler updates dataset status during training lifecycle
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
# =============================================================================
# Test Database Model
# =============================================================================
class TestTrainingDatasetModel:
    """Field-level checks for the TrainingDataset model."""

    def test_training_dataset_has_training_status_field(self):
        """training_status can be set at construction time."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test-dataset", training_status="running")
        assert ds.training_status == "running"

    def test_training_dataset_has_active_training_task_id_field(self):
        """active_training_task_id stores the UUID of the active task."""
        from inference.data.admin_models import TrainingDataset

        tid = uuid4()
        ds = TrainingDataset(name="test-dataset", active_training_task_id=tid)
        assert ds.active_training_task_id == tid

    def test_training_dataset_defaults(self):
        """Both new fields default to None."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test-dataset")
        assert ds.training_status is None
        assert ds.active_training_task_id is None
# =============================================================================
# Test AdminDB Methods
# =============================================================================
class TestAdminDBDatasetTrainingStatus:
    """Tests for AdminDB.update_dataset_training_status.

    Refactored: every test previously repeated the same session-patching
    boilerplate; it now lives in the private ``_call_update`` helper.
    """

    @pytest.fixture
    def mock_session(self):
        """Create mock database session."""
        return MagicMock()

    def _call_update(self, mock_session, dataset, **kwargs):
        """Invoke update_dataset_training_status against a patched session.

        *dataset* (or None) is what ``session.get`` returns; ``**kwargs`` are
        passed straight through to the method under test.
        """
        mock_session.get.return_value = dataset
        with patch("inference.data.admin_db.get_session_context") as mock_ctx:
            mock_ctx.return_value.__enter__.return_value = mock_session
            from inference.data.admin_db import AdminDB

            AdminDB().update_dataset_training_status(**kwargs)

    def test_update_dataset_training_status_sets_status(self, mock_session):
        """update_dataset_training_status should set training_status."""
        from inference.data.admin_models import TrainingDataset

        dataset_id = uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id, name="test-dataset", status="ready"
        )
        self._call_update(
            mock_session,
            dataset,
            dataset_id=str(dataset_id),
            training_status="running",
        )
        assert dataset.training_status == "running"
        mock_session.add.assert_called_once_with(dataset)
        mock_session.commit.assert_called_once()

    def test_update_dataset_training_status_sets_task_id(self, mock_session):
        """update_dataset_training_status should set active_training_task_id."""
        from inference.data.admin_models import TrainingDataset

        dataset_id, task_id = uuid4(), uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id, name="test-dataset", status="ready"
        )
        self._call_update(
            mock_session,
            dataset,
            dataset_id=str(dataset_id),
            training_status="running",
            active_training_task_id=str(task_id),
        )
        assert dataset.active_training_task_id == task_id

    def test_update_dataset_training_status_updates_main_status_on_complete(
        self, mock_session
    ):
        """Main status flips to 'trained' on completion with update_main_status."""
        from inference.data.admin_models import TrainingDataset

        dataset_id = uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id, name="test-dataset", status="ready"
        )
        self._call_update(
            mock_session,
            dataset,
            dataset_id=str(dataset_id),
            training_status="completed",
            update_main_status=True,
        )
        assert dataset.status == "trained"
        assert dataset.training_status == "completed"

    def test_update_dataset_training_status_clears_task_id_on_complete(
        self, mock_session
    ):
        """Passing active_training_task_id=None clears the stored task id."""
        from inference.data.admin_models import TrainingDataset

        dataset_id = uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id,
            name="test-dataset",
            status="ready",
            training_status="running",
            active_training_task_id=uuid4(),
        )
        self._call_update(
            mock_session,
            dataset,
            dataset_id=str(dataset_id),
            training_status="completed",
            active_training_task_id=None,
        )
        assert dataset.active_training_task_id is None

    def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
        """A missing dataset is a no-op: nothing added, nothing committed."""
        self._call_update(
            mock_session,
            None,
            dataset_id=str(uuid4()),
            training_status="running",
        )
        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()
# =============================================================================
# Test API Response
# =============================================================================
class TestDatasetDetailResponseTrainingStatus:
    """Tests for DatasetDetailResponse including training status fields.

    Refactored: the two tests previously duplicated a ~20-line constructor
    call; defaults now live in the ``_make_response`` factory.
    """

    @staticmethod
    def _make_response(**overrides):
        """Build a DatasetDetailResponse with sane defaults, overridable per test."""
        from inference.web.schemas.admin.datasets import DatasetDetailResponse

        now = datetime.utcnow()
        kwargs = dict(
            dataset_id=str(uuid4()),
            name="test-dataset",
            description=None,
            status="ready",
            training_status=None,
            active_training_task_id=None,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            total_documents=10,
            total_images=15,
            total_annotations=100,
            dataset_path=None,
            error_message=None,
            documents=[],
            created_at=now,
            updated_at=now,
        )
        kwargs.update(overrides)
        return DatasetDetailResponse(**kwargs)

    def test_dataset_detail_response_includes_training_status(self):
        """Schema accepts and exposes training_status / active_training_task_id."""
        response = self._make_response(
            training_status="running",
            active_training_task_id=str(uuid4()),
            dataset_path="/path/to/dataset",
        )
        assert response.training_status == "running"
        assert response.active_training_task_id is not None

    def test_dataset_detail_response_allows_null_training_status(self):
        """Both training-status fields are nullable."""
        response = self._make_response()
        assert response.training_status is None
        assert response.active_training_task_id is None
# =============================================================================
# Test Scheduler Training Status Updates
# =============================================================================
class TestSchedulerDatasetStatusUpdates:
    """Tests for scheduler updating dataset status during training."""

    @pytest.fixture
    def mock_db(self):
        """Create mock AdminDB."""
        mock = MagicMock()
        # get_dataset returns a plausible dataset record; attribute values are
        # arbitrary but shaped like what the scheduler reads.
        mock.get_dataset.return_value = MagicMock(
            dataset_id=uuid4(),
            name="test-dataset",
            dataset_path="/path/to/dataset",
            total_images=100,
        )
        mock.get_pending_training_tasks.return_value = []
        return mock

    def test_scheduler_sets_running_status_on_task_start(self, mock_db):
        """Scheduler should set dataset training_status to 'running' when task starts."""
        from inference.web.core.scheduler import TrainingScheduler

        # Patch out the actual YOLO training so no model/GPU is needed.
        with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
            mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
            scheduler = TrainingScheduler()
            scheduler._db = mock_db
            task_id = str(uuid4())
            dataset_id = str(uuid4())
            # Execute task (will fail but we check the status update call)
            try:
                scheduler._execute_task(
                    task_id=task_id,
                    config={"model_name": "yolo11n.pt"},
                    dataset_id=dataset_id,
                )
            except Exception:
                pass  # Expected to fail in test environment
            # Check that training status was updated to running.
            # NOTE(review): assumes the FIRST call to update_dataset_training_status
            # is the "running" transition — confirm against scheduler implementation.
            mock_db.update_dataset_training_status.assert_called()
            first_call = mock_db.update_dataset_training_status.call_args_list[0]
            assert first_call.kwargs["training_status"] == "running"
            assert first_call.kwargs["active_training_task_id"] == task_id
# =============================================================================
# Test Dataset Status Values
# =============================================================================
class TestDatasetStatusValues:
    """Verify which status strings TrainingDataset accepts."""

    def test_dataset_status_building(self):
        """'building' is an accepted dataset status."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="building")
        assert ds.status == "building"

    def test_dataset_status_ready(self):
        """'ready' is an accepted dataset status."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="ready")
        assert ds.status == "ready"

    def test_dataset_status_trained(self):
        """'trained' is an accepted dataset status."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="trained")
        assert ds.status == "trained"

    def test_dataset_status_failed(self):
        """'failed' is an accepted dataset status."""
        from inference.data.admin_models import TrainingDataset

        ds = TrainingDataset(name="test", status="failed")
        assert ds.status == "failed"

    def test_training_status_values(self):
        """Every lifecycle value round-trips through training_status."""
        from inference.data.admin_models import TrainingDataset

        for value in ("pending", "scheduled", "running", "completed", "failed", "cancelled"):
            assert TrainingDataset(name="test", training_status=value).training_status == value

View File

@@ -0,0 +1,207 @@
"""
Tests for Document Category Feature.
TDD tests for adding category field to admin_documents table.
Documents can be categorized (e.g., invoice, letter, receipt) for training different models.
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock
from uuid import UUID, uuid4
from inference.data.admin_models import AdminDocument
# Test constants
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"
class TestAdminDocumentCategoryField:
    """Behaviour of the AdminDocument.category attribute."""

    @staticmethod
    def _make_doc(**overrides):
        """Build a minimal AdminDocument, with keyword overrides applied."""
        kwargs = dict(
            document_id=UUID(TEST_DOC_UUID),
            filename="test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/path/to/file.pdf",
        )
        kwargs.update(overrides)
        return AdminDocument(**kwargs)

    def test_document_has_category_field(self):
        """The model exposes a 'category' attribute."""
        assert hasattr(self._make_doc(), "category")

    def test_document_category_defaults_to_invoice(self):
        """Without an explicit value, category is 'invoice'."""
        assert self._make_doc().category == "invoice"

    def test_document_accepts_custom_category(self):
        """Arbitrary category strings are stored verbatim."""
        for cat in ("invoice", "letter", "receipt", "contract", "custom_type"):
            doc = self._make_doc(document_id=uuid4(), category=cat)
            assert doc.category == cat

    def test_document_category_is_string_type(self):
        """The stored category value is a plain string."""
        assert isinstance(self._make_doc(category="letter").category, str)
class TestDocumentCategoryInReadModel:
    """Response-model coverage for the category field."""

    def test_admin_document_read_has_category(self):
        """AdminDocumentRead must declare 'category' in its schema."""
        from inference.data.admin_models import AdminDocumentRead

        declared = AdminDocumentRead.model_fields
        assert "category" in declared
class TestDocumentCategoryAPI:
    """Category handling in the document API schemas."""

    @pytest.fixture
    def mock_admin_db(self):
        """AdminDB stand-in that accepts any admin token."""
        fake_db = MagicMock()
        fake_db.is_valid_admin_token.return_value = True
        return fake_db

    def test_upload_document_with_category(self, mock_admin_db):
        """DocumentUploadResponse carries a caller-supplied category."""
        from inference.web.schemas.admin import DocumentUploadResponse

        resp = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            message="Upload successful",
            category="letter",
        )
        assert resp.category == "letter"

    def test_list_documents_returns_category(self, mock_admin_db):
        """DocumentItem carries the document's category."""
        from inference.web.schemas.admin import DocumentItem

        entry = DocumentItem(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            annotation_count=0,
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
            category="invoice",
        )
        assert entry.category == "invoice"

    def test_document_detail_includes_category(self, mock_admin_db):
        """DocumentDetailResponse declares a category field."""
        from inference.web.schemas.admin import DocumentDetailResponse

        assert "category" in DocumentDetailResponse.model_fields
class TestDocumentCategoryFiltering:
    """Filtering documents by category at the DB layer."""

    @pytest.fixture
    def mock_admin_db(self):
        """AdminDB mock whose category query yields a single invoice doc."""
        fake_db = MagicMock()
        fake_db.is_valid_admin_token.return_value = True
        doc_invoice = MagicMock()
        doc_invoice.document_id = uuid4()
        doc_invoice.category = "invoice"
        doc_letter = MagicMock()
        doc_letter.document_id = uuid4()
        doc_letter.category = "letter"
        fake_db.get_documents_by_category.return_value = [doc_invoice]
        return fake_db

    def test_filter_documents_by_category(self, mock_admin_db):
        """get_documents_by_category returns only matching documents."""
        docs = mock_admin_db.get_documents_by_category("invoice")
        assert len(docs) == 1
        assert docs[0].category == "invoice"
class TestDocumentCategoryUpdate:
    """Updating a document's category via the request schema."""

    def test_update_document_category_schema(self):
        """DocumentUpdateRequest accepts a category value."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        req = DocumentUpdateRequest(category="letter")
        assert req.category == "letter"

    def test_update_document_category_optional(self):
        """Omitting category leaves it as None (the field is optional)."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        req = DocumentUpdateRequest()
        assert req.category is None
class TestDatasetWithCategory:
    """Dataset creation with an optional category filter."""

    def test_dataset_create_with_category_filter(self):
        """DatasetCreateRequest accepts a category filter."""
        from inference.web.schemas.admin import DatasetCreateRequest

        req = DatasetCreateRequest(
            name="Invoice Training Set",
            document_ids=[TEST_DOC_UUID],
            category="invoice",  # Optional filter
        )
        assert req.category == "invoice"

    def test_dataset_create_category_is_optional(self):
        """Omitting the category filter is valid."""
        from inference.web.schemas.admin import DatasetCreateRequest

        req = DatasetCreateRequest(
            name="Mixed Training Set",
            document_ids=[TEST_DOC_UUID],
        )
        # Absent attribute or None both count as "no filter applied".
        assert getattr(req, "category", None) is None

View File

@@ -0,0 +1,165 @@
"""
Tests for Document Category API Endpoints.
TDD tests for category filtering and management in document endpoints.
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import UUID, uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
# Test constants
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"
class TestGetCategoriesEndpoint:
    """GET /admin/documents/categories response schema."""

    def test_categories_endpoint_returns_list(self):
        """The response carries the category list and its size."""
        from inference.web.schemas.admin import DocumentCategoriesResponse

        resp = DocumentCategoriesResponse(
            categories=["invoice", "letter", "receipt"],
            total=3,
        )
        assert resp.total == 3
        assert resp.categories == ["invoice", "letter", "receipt"]

    def test_categories_response_schema(self):
        """The schema declares both 'categories' and 'total'."""
        from inference.web.schemas.admin import DocumentCategoriesResponse

        declared = DocumentCategoriesResponse.model_fields
        assert "categories" in declared
        assert "total" in declared
class TestDocumentListFilterByCategory:
    """Category filtering on the document list endpoint."""

    @pytest.fixture
    def mock_admin_db(self):
        """AdminDB mock exposing per-category documents and a category list."""
        fake_db = MagicMock()
        fake_db.is_valid_admin_token.return_value = True
        doc_a = MagicMock()
        doc_a.document_id = uuid4()
        doc_a.category = "invoice"
        doc_a.filename = "invoice1.pdf"
        doc_b = MagicMock()
        doc_b.document_id = uuid4()
        doc_b.category = "letter"
        doc_b.filename = "letter1.pdf"
        fake_db.get_documents.return_value = ([doc_a], 1)
        fake_db.get_document_categories.return_value = ["invoice", "letter", "receipt"]
        return fake_db

    def test_list_documents_accepts_category_filter(self, mock_admin_db):
        """The list endpoint's response schema is importable (filter supported)."""
        from inference.web.schemas.admin import DocumentListResponse

        assert DocumentListResponse is not None

    def test_get_document_categories_from_db(self, mock_admin_db):
        """Unique categories are fetched from the database."""
        cats = mock_admin_db.get_document_categories()
        assert len(cats) == 3
        assert "invoice" in cats
        assert "letter" in cats
class TestDocumentUploadWithCategory:
    """Category behaviour when uploading documents."""

    def test_upload_request_accepts_category(self):
        """Category arrives as a multipart form field, not a request schema.

        There is no schema object to assert against, so this test is a
        deliberate placeholder.
        """
        pass

    def test_upload_response_includes_category(self):
        """The response echoes back the category that was set."""
        from inference.web.schemas.admin import DocumentUploadResponse

        resp = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            category="letter",  # Custom category
            message="Upload successful",
        )
        assert resp.category == "letter"

    def test_upload_defaults_to_invoice_category(self):
        """Without an explicit category the response defaults to 'invoice'."""
        from inference.web.schemas.admin import DocumentUploadResponse

        resp = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            message="Upload successful",
            # No category supplied - the schema default must apply.
        )
        assert resp.category == "invoice"
class TestAdminDBCategoryMethods:
    """AdminDB surface for category queries."""

    def test_get_document_categories_method_exists(self):
        """AdminDB exposes get_document_categories."""
        from inference.data.admin_db import AdminDB

        assert hasattr(AdminDB(), "get_document_categories")

    def test_get_documents_accepts_category_filter(self):
        """get_documents_by_token takes a 'category' parameter."""
        import inspect

        from inference.data.admin_db import AdminDB

        fetch = getattr(AdminDB(), "get_documents_by_token", None)
        assert callable(fetch)
        assert "category" in inspect.signature(fetch).parameters
class TestUpdateDocumentCategory:
    """Updating a document's category."""

    def test_update_document_category_method_exists(self):
        """AdminDB exposes update_document_category."""
        from inference.data.admin_db import AdminDB

        assert hasattr(AdminDB(), "update_document_category")

    def test_update_request_schema(self):
        """DocumentUpdateRequest round-trips a category value."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        assert DocumentUpdateRequest(category="receipt").category == "receipt"

View File

@@ -32,10 +32,10 @@ def test_app(tmp_path):
use_gpu=False,
dpi=150,
),
storage=StorageConfig(
file=StorageConfig(
upload_dir=upload_dir,
result_dir=result_dir,
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
max_file_size_mb=50,
),
)
@@ -252,20 +252,25 @@ class TestResultsEndpoint:
response = client.get("/api/v1/results/nonexistent.png")
assert response.status_code == 404
def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
"""Test that existing result file is returned."""
# Get storage config from app
storage_config = test_app.extra.get("storage_config")
if not storage_config:
pytest.skip("Storage config not available in test app")
# Create a test result file
result_file = storage_config.result_dir / "test_result.png"
# Create a test result file in temp directory
result_dir = tmp_path / "results"
result_dir.mkdir(exist_ok=True)
result_file = result_dir / "test_result.png"
img = Image.new('RGB', (100, 100), color='red')
img.save(result_file)
# Request the file
response = client.get("/api/v1/results/test_result.png")
# Mock the storage helper to return our test file path
with patch(
"inference.web.api.v1.public.inference.get_storage_helper"
) as mock_storage:
mock_helper = Mock()
mock_helper.get_result_local_path.return_value = result_file
mock_storage.return_value = mock_helper
# Request the file
response = client.get("/api/v1/results/test_result.png")
assert response.status_code == 200
assert response.headers["content-type"] == "image/png"

View File

@@ -266,7 +266,11 @@ class TestActivateModelVersionRoute:
mock_db = MagicMock()
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
# Create mock request with app state
mock_request = MagicMock()
mock_request.app.state.inference_service = None
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result.status == "active"
@@ -278,10 +282,14 @@ class TestActivateModelVersionRoute:
mock_db = MagicMock()
mock_db.activate_model_version.return_value = None
# Create mock request with app state
mock_request = MagicMock()
mock_request.app.state.inference_service = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404

View File

@@ -0,0 +1,828 @@
"""Tests for storage helpers module."""
import pytest
from unittest.mock import MagicMock, patch
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
from shared.storage import PREFIXES
@pytest.fixture
def mock_storage() -> MagicMock:
    """Mock storage backend with canned return values for every operation."""
    backend = MagicMock()
    backend.upload_bytes = MagicMock()
    backend.download_bytes = MagicMock(return_value=b"test content")
    backend.get_presigned_url = MagicMock(return_value="https://example.com/file")
    backend.exists = MagicMock(return_value=True)
    backend.delete = MagicMock(return_value=True)
    backend.list_files = MagicMock(return_value=[])
    return backend
@pytest.fixture
def helper(mock_storage: MagicMock) -> StorageHelper:
    """StorageHelper wired to the mocked backend above.

    Most tests below take this fixture so that path construction can be
    asserted against the mock's recorded calls.
    """
    return StorageHelper(storage=mock_storage)
class TestStorageHelperInit:
    """Construction of StorageHelper."""

    def test_init_with_storage(self, mock_storage: MagicMock) -> None:
        """The injected backend is stored as-is."""
        wrapper = StorageHelper(storage=mock_storage)
        assert wrapper.storage is mock_storage

    def test_storage_property(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """The storage property exposes the injected backend."""
        assert helper.storage is mock_storage
class TestDocumentOperations:
    """Document upload/download/URL/exists/delete paths."""

    def test_upload_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Uploads go to documents/<id>.pdf with overwrite enabled."""
        returned_id, key = helper.upload_document(b"pdf content", "invoice.pdf", "doc123")
        assert returned_id == "doc123"
        assert key == "documents/doc123.pdf"
        mock_storage.upload_bytes.assert_called_once_with(
            b"pdf content", "documents/doc123.pdf", overwrite=True
        )

    def test_upload_document_generates_id(self, helper: StorageHelper) -> None:
        """A missing document ID is generated automatically."""
        generated_id, key = helper.upload_document(b"content", "file.pdf")
        assert generated_id is not None
        assert len(generated_id) > 0
        assert key.startswith("documents/")

    def test_download_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Downloads read from documents/<id>.pdf."""
        payload = helper.download_document("doc123")
        assert payload == b"test content"
        mock_storage.download_bytes.assert_called_once_with("documents/doc123.pdf")

    def test_get_document_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """A presigned URL is issued with the requested expiry."""
        link = helper.get_document_url("doc123", expires_in_seconds=7200)
        assert link == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "documents/doc123.pdf", 7200
        )

    def test_document_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Existence checks hit the backend with the document key."""
        assert helper.document_exists("doc123") is True
        mock_storage.exists.assert_called_once_with("documents/doc123.pdf")

    def test_delete_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Deletion delegates to the backend."""
        assert helper.delete_document("doc123") is True
        mock_storage.delete.assert_called_once_with("documents/doc123.pdf")
class TestImageOperations:
    """Page-image save/load/URL/delete/list paths."""

    def test_save_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Page images land at images/<doc>/page_<n>.png with overwrite on."""
        key = helper.save_page_image("doc123", 1, b"image data")
        assert key == "images/doc123/page_1.png"
        mock_storage.upload_bytes.assert_called_once_with(
            b"image data", "images/doc123/page_1.png", overwrite=True
        )

    def test_get_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Reads come from the page's canonical key."""
        payload = helper.get_page_image("doc123", 2)
        assert payload == b"test content"
        mock_storage.download_bytes.assert_called_once_with("images/doc123/page_2.png")

    def test_get_page_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Presigned URLs use the default one-hour expiry."""
        link = helper.get_page_image_url("doc123", 3)
        assert link == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "images/doc123/page_3.png", 3600
        )

    def test_delete_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """All page images under the document prefix are removed."""
        mock_storage.list_files.return_value = [
            "images/doc123/page_1.png",
            "images/doc123/page_2.png",
        ]
        removed = helper.delete_document_images("doc123")
        assert removed == 2
        mock_storage.list_files.assert_called_once_with("images/doc123/")

    def test_list_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Listing returns whatever the backend reports under the prefix."""
        mock_storage.list_files.return_value = ["images/doc123/page_1.png"]
        found = helper.list_document_images("doc123")
        assert found == ["images/doc123/page_1.png"]
class TestUploadOperations:
    """Tests for upload staging operations."""

    def test_save_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save upload to correct path."""
        path = helper.save_upload(b"content", "file.pdf")
        assert path == "uploads/file.pdf"
        mock_storage.upload_bytes.assert_called_once()

    def test_save_upload_with_subfolder(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save upload with subfolder."""
        path = helper.save_upload(b"content", "file.pdf", "async")
        assert path == "uploads/async/file.pdf"

    def test_get_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get upload from correct path and return its content."""
        content = helper.get_upload("file.pdf", "async")
        # Fixed: the return value was assigned but never verified.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("uploads/async/file.pdf")

    def test_delete_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should delete upload."""
        result = helper.delete_upload("file.pdf")
        assert result is True
        mock_storage.delete.assert_called_once_with("uploads/file.pdf")
class TestResultOperations:
    """Tests for result file operations."""

    def test_save_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save result to correct path."""
        path = helper.save_result(b"result data", "output.json")
        assert path == "results/output.json"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get result from correct path and return its content."""
        content = helper.get_result("output.json")
        # Fixed: the return value was assigned but never verified.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("results/output.json")

    def test_get_result_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get presigned URL for result."""
        url = helper.get_result_url("output.json")
        # Fixed: assert the backend's URL is propagated to the caller.
        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with("results/output.json", 3600)

    def test_result_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should check result existence."""
        exists = helper.result_exists("output.json")
        assert exists is True
        mock_storage.exists.assert_called_once_with("results/output.json")
class TestExportOperations:
    """Tests for export file operations."""

    def test_save_export(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save export to correct path."""
        path = helper.save_export(b"export data", "exp123", "dataset.zip")
        assert path == "exports/exp123/dataset.zip"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_export_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get presigned URL for export."""
        url = helper.get_export_url("exp123", "dataset.zip")
        # Fixed: the URL was assigned but never verified.
        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "exports/exp123/dataset.zip", 3600
        )
class TestRawPdfOperations:
    """Tests for raw PDF operations (legacy compatibility)."""

    def test_save_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save raw PDF to correct path."""
        path = helper.save_raw_pdf(b"pdf data", "invoice.pdf")
        assert path == "raw_pdfs/invoice.pdf"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get raw PDF from correct path and return its content."""
        content = helper.get_raw_pdf("invoice.pdf")
        # Fixed: the return value was assigned but never verified.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("raw_pdfs/invoice.pdf")

    def test_raw_pdf_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should check raw PDF existence."""
        exists = helper.raw_pdf_exists("invoice.pdf")
        assert exists is True
        mock_storage.exists.assert_called_once_with("raw_pdfs/invoice.pdf")
class TestAdminImageOperations:
    """Admin image save/load/URL/exists/list/delete paths."""

    def test_save_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Admin images land at admin_images/<doc>/page_<n>.png."""
        key = helper.save_admin_image("doc123", 1, b"image data")
        assert key == "admin_images/doc123/page_1.png"
        mock_storage.upload_bytes.assert_called_once_with(
            b"image data", "admin_images/doc123/page_1.png", overwrite=True
        )

    def test_get_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Reads come from the page's canonical admin key."""
        payload = helper.get_admin_image("doc123", 2)
        assert payload == b"test content"
        mock_storage.download_bytes.assert_called_once_with("admin_images/doc123/page_2.png")

    def test_get_admin_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Presigned URLs use the default one-hour expiry."""
        link = helper.get_admin_image_url("doc123", 3)
        assert link == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "admin_images/doc123/page_3.png", 3600
        )

    def test_admin_image_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Existence checks hit the backend with the admin key."""
        assert helper.admin_image_exists("doc123", 1) is True
        mock_storage.exists.assert_called_once_with("admin_images/doc123/page_1.png")

    def test_get_admin_image_path(self, helper: StorageHelper) -> None:
        """Path construction alone requires no backend call."""
        assert helper.get_admin_image_path("doc123", 2) == "admin_images/doc123/page_2.png"

    def test_list_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Listing returns every key under the document's admin prefix."""
        mock_storage.list_files.return_value = [
            "admin_images/doc123/page_1.png",
            "admin_images/doc123/page_2.png",
        ]
        found = helper.list_admin_images("doc123")
        assert found == ["admin_images/doc123/page_1.png", "admin_images/doc123/page_2.png"]
        mock_storage.list_files.assert_called_once_with("admin_images/doc123/")

    def test_delete_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Every listed admin image is deleted; the count is returned."""
        mock_storage.list_files.return_value = [
            "admin_images/doc123/page_1.png",
            "admin_images/doc123/page_2.png",
        ]
        removed = helper.delete_admin_images("doc123")
        assert removed == 2
        mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
class TestGetLocalPath:
    """Resolving admin images to local filesystem paths."""

    def test_get_admin_image_local_path_with_local_storage(self) -> None:
        """An existing admin image resolves to a real local path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            image_dir = Path(root) / "admin_images" / "doc123"
            image_dir.mkdir(parents=True, exist_ok=True)
            (image_dir / "page_1.png").write_bytes(b"test image")
            resolved = local_helper.get_admin_image_local_path("doc123", 1)
            assert resolved is not None
            assert resolved.exists()
            assert resolved.name == "page_1.png"

    def test_get_admin_image_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Backends whose get_local_path yields None produce None."""
        mock_storage.get_local_path = MagicMock(return_value=None)
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_admin_image_local_path("doc123", 1) is None

    def test_get_admin_image_local_path_nonexistent_file(self) -> None:
        """Missing images resolve to None."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            assert local_helper.get_admin_image_local_path("nonexistent", 1) is None
class TestGetAdminImageDimensions:
    """Reading image dimensions via the helper."""

    def test_get_dimensions_with_local_storage(self) -> None:
        """Dimensions of a real image on local storage are reported."""
        import tempfile
        from pathlib import Path

        from PIL import Image

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            target_dir = Path(root) / "admin_images" / "doc123"
            target_dir.mkdir(parents=True, exist_ok=True)
            Image.new("RGB", (800, 600), color="white").save(target_dir / "page_1.png")
            assert local_helper.get_admin_image_dimensions("doc123", 1) == (800, 600)

    def test_get_dimensions_nonexistent_file(self) -> None:
        """Missing files yield None instead of raising."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            assert local_helper.get_admin_image_dimensions("nonexistent", 1) is None
class TestGetStorageHelper:
    """Module-level singleton accessor for StorageHelper."""

    def test_returns_helper_instance(self) -> None:
        """A StorageHelper is created on first access."""
        with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
            mock_get.return_value = MagicMock()
            import inference.web.services.storage_helpers as module
            module._default_helper = None  # force re-creation
            assert isinstance(get_storage_helper(), StorageHelper)

    def test_returns_same_instance(self) -> None:
        """Repeated calls hand back the cached singleton."""
        with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
            mock_get.return_value = MagicMock()
            import inference.web.services.storage_helpers as module
            module._default_helper = None  # force re-creation
            first = get_storage_helper()
            second = get_storage_helper()
            assert first is second
class TestDeleteResult:
    """Result deletion."""

    def test_delete_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """delete_result removes results/<name> via the backend."""
        assert helper.delete_result("output.json") is True
        mock_storage.delete.assert_called_once_with("results/output.json")
class TestResultLocalPath:
    """Resolving results to local filesystem paths."""

    def test_get_result_local_path_with_local_storage(self) -> None:
        """An existing result file resolves to a real local path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            result_dir = Path(root) / "results"
            result_dir.mkdir(parents=True, exist_ok=True)
            (result_dir / "output.json").write_bytes(b"test result")
            resolved = local_helper.get_result_local_path("output.json")
            assert resolved is not None
            assert resolved.exists()
            assert resolved.name == "output.json"

    def test_get_result_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Cloud-style backends yield no local path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_result_local_path("output.json") is None

    def test_get_result_local_path_nonexistent_file(self) -> None:
        """Missing result files resolve to None."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            assert local_helper.get_result_local_path("nonexistent.json") is None
class TestResultsBasePath:
    """Base directory resolution for results."""

    def test_get_results_base_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'results' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            base = local_helper.get_results_base_path()
            assert base is not None
            assert base.exists()
            assert base.name == "results"

    def test_get_results_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends have no results base path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_results_base_path() is None
class TestUploadLocalPath:
    """Tests for get_upload_local_path method."""

    def test_get_upload_local_path_with_local_storage(self) -> None:
        """An uploaded file on local storage resolves to a real path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            uploads_dir = Path(root, "uploads")
            uploads_dir.mkdir(parents=True, exist_ok=True)
            (uploads_dir / "file.pdf").write_bytes(b"test upload")
            resolved = local_helper.get_upload_local_path("file.pdf")
            assert resolved is not None
            assert resolved.exists()
            assert resolved.name == "file.pdf"

    def test_get_upload_local_path_with_subfolder(self) -> None:
        """A subfolder argument resolves through the nested directory."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            nested = Path(root, "uploads", "async")
            nested.mkdir(parents=True, exist_ok=True)
            (nested / "file.pdf").write_bytes(b"test upload")
            resolved = local_helper.get_upload_local_path("file.pdf", "async")
            assert resolved is not None
            assert resolved.exists()

    def test_get_upload_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Cloud-style backends cannot produce local upload paths."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_upload_local_path("file.pdf") is None
class TestUploadsBasePath:
    """Tests for get_uploads_base_path method."""

    def test_get_uploads_base_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'uploads' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            base = local_helper.get_uploads_base_path()
            assert base is not None
            assert base.exists()
            assert base.name == "uploads"

    def test_get_uploads_base_path_with_subfolder(self) -> None:
        """Requesting a subfolder yields that nested directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            base = local_helper.get_uploads_base_path("async")
            assert base is not None
            assert base.exists()
            assert base.name == "async"

    def test_get_uploads_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends report no uploads base path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_uploads_base_path() is None
class TestUploadExists:
    """Tests for upload_exists method."""

    def test_upload_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Existence check delegates to storage.exists with the uploads prefix."""
        assert helper.upload_exists("file.pdf") is True
        mock_storage.exists.assert_called_once_with("uploads/file.pdf")

    def test_upload_exists_with_subfolder(
        self, helper: StorageHelper, mock_storage: MagicMock
    ) -> None:
        """The subfolder is inserted between the prefix and the filename."""
        helper.upload_exists("file.pdf", "async")
        mock_storage.exists.assert_called_once_with("uploads/async/file.pdf")
class TestDatasetsBasePath:
    """Tests for get_datasets_base_path method."""

    def test_get_datasets_base_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'datasets' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            base = local_helper.get_datasets_base_path()
            assert base is not None
            assert base.exists()
            assert base.name == "datasets"

    def test_get_datasets_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends report no datasets base path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_datasets_base_path() is None
class TestAdminImagesBasePath:
    """Tests for get_admin_images_base_path method."""

    def test_get_admin_images_base_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'admin_images' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            base = local_helper.get_admin_images_base_path()
            assert base is not None
            assert base.exists()
            assert base.name == "admin_images"

    def test_get_admin_images_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends report no admin-images base path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_admin_images_base_path() is None
class TestRawPdfsBasePath:
    """Tests for get_raw_pdfs_base_path method."""

    def test_get_raw_pdfs_base_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'raw_pdfs' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            base = local_helper.get_raw_pdfs_base_path()
            assert base is not None
            assert base.exists()
            assert base.name == "raw_pdfs"

    def test_get_raw_pdfs_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends report no raw-pdfs base path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_raw_pdfs_base_path() is None
class TestRawPdfLocalPath:
    """Tests for get_raw_pdf_local_path method."""

    def test_get_raw_pdf_local_path_with_local_storage(self) -> None:
        """A raw PDF stored locally resolves to a real filesystem path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            pdf_dir = Path(root, "raw_pdfs")
            pdf_dir.mkdir(parents=True, exist_ok=True)
            (pdf_dir / "invoice.pdf").write_bytes(b"test pdf")
            resolved = local_helper.get_raw_pdf_local_path("invoice.pdf")
            assert resolved is not None
            assert resolved.exists()
            assert resolved.name == "invoice.pdf"

    def test_get_raw_pdf_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Cloud-style backends cannot produce local raw-PDF paths."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_raw_pdf_local_path("invoice.pdf") is None
class TestRawPdfPath:
    """Tests for get_raw_pdf_path method."""

    def test_get_raw_pdf_path(self, helper: StorageHelper) -> None:
        """Raw PDFs live under the 'raw_pdfs/' logical prefix."""
        assert helper.get_raw_pdf_path("invoice.pdf") == "raw_pdfs/invoice.pdf"
class TestAutolabelOutputPath:
    """Tests for get_autolabel_output_path method."""

    def test_get_autolabel_output_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'autolabel_output' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            out = local_helper.get_autolabel_output_path()
            assert out is not None
            assert out.exists()
            assert out.name == "autolabel_output"

    def test_get_autolabel_output_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends report no autolabel output path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_autolabel_output_path() is None
class TestTrainingDataPath:
    """Tests for get_training_data_path method."""

    def test_get_training_data_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'training' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            training = local_helper.get_training_data_path()
            assert training is not None
            assert training.exists()
            assert training.name == "training"

    def test_get_training_data_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends report no training data path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_training_data_path() is None
class TestExportsBasePath:
    """Tests for get_exports_base_path method."""

    def test_get_exports_base_path_with_local_storage(self) -> None:
        """Local storage exposes an existing 'exports' directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            local_helper = StorageHelper(storage=LocalStorageBackend(root))
            base = local_helper.get_exports_base_path()
            assert base is not None
            assert base.exists()
            assert base.name == "exports"

    def test_get_exports_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """Non-local backends report no exports base path."""
        cloud_helper = StorageHelper(storage=mock_storage)
        assert cloud_helper.get_exports_base_path() is None

View File

@@ -0,0 +1,306 @@
"""
Tests for storage backend integration in web application.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
class TestStorageBackendInitialization:
    """Tests for storage backend initialization in web config."""

    def test_get_storage_backend_returns_backend(self, tmp_path: Path) -> None:
        """get_storage_backend should produce a StorageBackend instance."""
        from inference.web.config import get_storage_backend
        from shared.storage.base import StorageBackend

        env_vars = {
            "STORAGE_BACKEND": "local",
            "STORAGE_BASE_PATH": str(tmp_path / "storage"),
        }
        with patch.dict(os.environ, env_vars, clear=False):
            assert isinstance(get_storage_backend(), StorageBackend)

    def test_get_storage_backend_uses_config_file_if_exists(
        self, tmp_path: Path
    ) -> None:
        """A storage config file takes precedence when present."""
        from inference.web.config import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        storage_path = tmp_path / "storage"
        config_file = tmp_path / "storage.yaml"
        config_file.write_text(f"""
backend: local
local:
  base_path: {storage_path}
""")
        result = get_storage_backend(config_path=config_file)
        assert isinstance(result, LocalStorageBackend)

    def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
        """Without a config file, environment variables drive the choice."""
        from inference.web.config import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        env_vars = {
            "STORAGE_BACKEND": "local",
            "STORAGE_BASE_PATH": str(tmp_path / "storage"),
        }
        with patch.dict(os.environ, env_vars, clear=False):
            result = get_storage_backend(config_path=None)
            assert isinstance(result, LocalStorageBackend)

    def test_app_config_has_storage_backend(self, tmp_path: Path) -> None:
        """AppConfig creation must attach a storage backend attribute."""
        from inference.web.config import AppConfig, create_app_config
        from shared.storage.base import StorageBackend

        env_vars = {
            "STORAGE_BACKEND": "local",
            "STORAGE_BASE_PATH": str(tmp_path / "storage"),
        }
        with patch.dict(os.environ, env_vars, clear=False):
            config = create_app_config()
            assert hasattr(config, "storage_backend")
            assert isinstance(config.storage_backend, StorageBackend)
class TestStorageBackendInDocumentUpload:
    """Tests for storage backend usage in document upload."""

    def test_upload_document_uses_storage_backend(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Uploading a document must write the file through the backend."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        result = service.upload_document(
            content=b"%PDF-1.4 test content",
            filename="test.pdf",
            dataset_id="dataset-1",
        )
        assert result is not None
        # The stored object must be reachable through the backend's namespace.
        assert backend.exists(f"documents/{result.id}.pdf")

    def test_upload_document_stores_logical_path(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """The persisted path must be logical (relative), never absolute."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        result = service.upload_document(
            content=b"%PDF-1.4 test content",
            filename="test.pdf",
            dataset_id="dataset-1",
        )
        assert not result.file_path.startswith("/")
        assert not result.file_path.startswith("C:")
        assert result.file_path.startswith("documents/")
class TestStorageBackendInDocumentDownload:
    """Tests for storage backend usage in document download/serving."""

    def test_get_document_url_returns_presigned_url(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Document URLs come from the backend (file:// locally, https:// in cloud)."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        doc_path = "documents/test-doc.pdf"
        backend.upload_bytes(b"%PDF-1.4 test", doc_path)
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        url = service.get_document_url(doc_path)
        assert url is not None
        assert "test-doc.pdf" in url

    def test_download_document_uses_storage_backend(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Downloading returns exactly the bytes stored via the backend."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        doc_path = "documents/test-doc.pdf"
        original_content = b"%PDF-1.4 test content"
        backend.upload_bytes(original_content, doc_path)
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        assert service.download_document(doc_path) == original_content
class TestStorageBackendInImageServing:
    """Tests for storage backend usage in image serving."""

    def test_get_page_image_url_returns_presigned_url(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Page image URLs are generated by the backend."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        backend.upload_bytes(b"fake png content", "images/doc-123/page_1.png")
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        url = service.get_page_image_url("doc-123", 1)
        assert url is not None
        assert "page_1.png" in url

    def test_save_page_image_uses_storage_backend(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Saving a page image must persist it through the backend."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        service.save_page_image("doc-123", 1, b"fake png content")
        assert backend.exists("images/doc-123/page_1.png")
class TestStorageBackendInDocumentDeletion:
    """Tests for storage backend usage in document deletion."""

    def test_delete_document_removes_from_storage(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Deleting a document removes its file from the backend."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        doc_path = "documents/test-doc.pdf"
        backend.upload_bytes(b"%PDF-1.4 test", doc_path)
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        service.delete_document_files(doc_path)
        assert not backend.exists(doc_path)

    def test_delete_document_removes_images(
        self, tmp_path: Path, mock_admin_db: MagicMock
    ) -> None:
        """Deleting a document removes every associated page image."""
        from shared.storage.local import LocalStorageBackend

        from inference.web.services.document_service import DocumentService

        root = tmp_path / "storage"
        root.mkdir(parents=True, exist_ok=True)
        backend = LocalStorageBackend(str(root))
        doc_id = "test-doc-123"
        backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
        backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
        service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
        service.delete_document_images(doc_id)
        assert not backend.exists(f"images/{doc_id}/page_1.png")
        assert not backend.exists(f"images/{doc_id}/page_2.png")
@pytest.fixture
def mock_admin_db() -> MagicMock:
    """Provide a stubbed AdminDB whose create_document returns a canned record."""
    admin_db = MagicMock()
    admin_db.get_document.return_value = None
    # upload_document callers expect an object carrying id and logical file_path.
    admin_db.create_document.return_value = MagicMock(
        id="test-doc-id",
        file_path="documents/test-doc-id.pdf",
    )
    return admin_db

View File

@@ -103,6 +103,31 @@ class MockAnnotation:
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockModelVersion:
    """Mock ModelVersion for testing.

    Accepts any subset of attributes as keyword arguments; anything not
    supplied falls back to a sensible default so tests can override only
    what they care about.
    """

    def __init__(self, **kwargs):
        # Dynamic defaults are factories so each instance gets its own
        # fresh uuid / timestamp rather than a value shared at class level.
        dynamic_defaults = {
            'version_id': uuid4,
            'trained_at': datetime.utcnow,
            'created_at': datetime.utcnow,
            'updated_at': datetime.utcnow,
        }
        # Static defaults are (re)built per call so mutable values such as
        # training_config are never aliased between instances.
        static_defaults = {
            'version': '1.0.0',
            'name': 'Test Model',
            'description': None,
            'model_path': 'runs/train/test/weights/best.pt',
            'status': 'inactive',
            'is_active': False,
            'task_id': None,
            'dataset_id': None,
            'metrics_mAP': 0.935,
            'metrics_precision': 0.92,
            'metrics_recall': 0.88,
            'document_count': 100,
            'training_config': {},
            'file_size': 52428800,
            'activated_at': None,
        }
        for attr, factory in dynamic_defaults.items():
            setattr(self, attr, kwargs[attr] if attr in kwargs else factory())
        for attr, default in static_defaults.items():
            setattr(self, attr, kwargs.get(attr, default))
class MockAdminDB:
"""Mock AdminDB for testing Phase 4."""
@@ -111,6 +136,7 @@ class MockAdminDB:
self.annotations = {}
self.training_tasks = {}
self.training_links = {}
self.model_versions = {}
def get_documents_for_training(
self,
@@ -174,6 +200,14 @@ class MockAdminDB:
"""Get training task by ID."""
return self.training_tasks.get(str(task_id))
def get_model_versions(self, status=None, limit=20, offset=0):
"""Get model versions with optional filtering."""
models = list(self.model_versions.values())
if status:
models = [m for m in models if m.status == status]
total = len(models)
return models[offset:offset+limit], total
@pytest.fixture
def app():
@@ -241,6 +275,30 @@ def app():
)
mock_db.training_links[str(doc1.document_id)] = [link1]
# Add model versions
model1 = MockModelVersion(
version="1.0.0",
name="Model v1.0.0",
status="inactive",
is_active=False,
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
)
model2 = MockModelVersion(
version="1.1.0",
name="Model v1.1.0",
status="active",
is_active=True,
metrics_mAP=0.951,
metrics_precision=0.94,
metrics_recall=0.92,
document_count=600,
)
mock_db.model_versions[str(model1.version_id)] = model1
mock_db.model_versions[str(model2.version_id)] = model2
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
class TestTrainingModels:
"""Tests for GET /admin/training/models endpoint."""
"""Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
def test_get_training_models_success(self, client):
"""Test getting trained models list."""
"""Test getting model versions list."""
response = client.get("/admin/training/models")
assert response.status_code == 200
@@ -338,43 +396,44 @@ class TestTrainingModels:
assert len(data["models"]) == 2
def test_get_training_models_includes_metrics(self, client):
"""Test that models include metrics."""
"""Test that model versions include metrics."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check first model has metrics
# Check first model has metrics fields
model = data["models"][0]
assert "metrics" in model
assert "mAP" in model["metrics"]
assert model["metrics"]["mAP"] is not None
assert "precision" in model["metrics"]
assert "recall" in model["metrics"]
assert "metrics_mAP" in model
assert model["metrics_mAP"] is not None
def test_get_training_models_includes_download_url(self, client):
"""Test that completed models have download URLs."""
def test_get_training_models_includes_version_fields(self, client):
"""Test that model versions include version fields."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check completed models have download URLs
for model in data["models"]:
if model["status"] == "completed":
assert "download_url" in model
assert model["download_url"] is not None
# Check model has expected fields
model = data["models"][0]
assert "version_id" in model
assert "version" in model
assert "name" in model
assert "status" in model
assert "is_active" in model
assert "document_count" in model
def test_get_training_models_filter_by_status(self, client):
"""Test filtering models by status."""
response = client.get("/admin/training/models?status=completed")
"""Test filtering model versions by status."""
response = client.get("/admin/training/models?status=active")
assert response.status_code == 200
data = response.json()
# All returned models should be completed
assert data["total"] == 1
# All returned models should be active
for model in data["models"]:
assert model["status"] == "completed"
assert model["status"] == "active"
def test_get_training_models_pagination(self, client):
"""Test pagination for models."""
"""Test pagination for model versions."""
response = client.get("/admin/training/models?limit=1&offset=0")
assert response.status_code == 200