This commit is contained in:
Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -0,0 +1 @@
# Tests for storage module

View File

@@ -0,0 +1,718 @@
"""
Tests for AzureBlobStorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
Uses mocking to avoid requiring actual Azure credentials.
"""
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@pytest.fixture
def mock_blob_service_client() -> MagicMock:
    """Provide a fresh ``MagicMock`` standing in for a BlobServiceClient."""
    service_client = MagicMock()
    return service_client
@pytest.fixture
def mock_container_client(mock_blob_service_client: MagicMock) -> MagicMock:
    """Provide a mock ContainerClient served by the service-client mock.

    The returned mock is registered as the return value of
    ``mock_blob_service_client.get_container_client``.
    """
    container = MagicMock()
    mock_blob_service_client.configure_mock(
        **{"get_container_client.return_value": container}
    )
    return container
@pytest.fixture
def mock_blob_client(mock_container_client: MagicMock) -> MagicMock:
    """Provide a mock BlobClient served by the container-client mock.

    The returned mock is registered as the return value of
    ``mock_container_client.get_blob_client``.
    """
    blob = MagicMock()
    mock_container_client.configure_mock(
        **{"get_blob_client.return_value": blob}
    )
    return blob
class TestAzureBlobStorageBackendCreation:
    """Tests for AzureBlobStorageBackend instantiation.

    ``shared.storage.azure.BlobServiceClient`` is patched in every test, so
    the backend is constructed against mocks only.
    """

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_with_connection_string(
        self, mock_service_class: MagicMock
    ) -> None:
        """The connection string is passed to ``from_connection_string``."""
        from shared.storage.azure import AzureBlobStorageBackend

        connection_string = "DefaultEndpointsProtocol=https;AccountName=test;..."

        backend = AzureBlobStorageBackend(
            connection_string=connection_string,
            container_name="training-images",
        )

        mock_service_class.from_connection_string.assert_called_once_with(
            connection_string
        )
        assert backend.container_name == "training-images"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_creates_container_if_not_exists(
        self, mock_service_class: MagicMock
    ) -> None:
        """A missing container is created when ``create_container=True``."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_container.exists.return_value = False

        # Constructed only for its side effects on the mocks; the instance
        # itself is not inspected.
        AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="new-container",
            create_container=True,
        )

        mock_container.create_container.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_does_not_create_container_by_default(
        self, mock_service_class: MagicMock
    ) -> None:
        """No container is created when ``create_container`` is not passed."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_container.exists.return_value = True

        # Constructed only for its side effects on the mocks.
        AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="existing-container",
        )

        mock_container.create_container.assert_not_called()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_is_storage_backend_subclass(
        self, mock_service_class: MagicMock
    ) -> None:
        """An AzureBlobStorageBackend instance is a StorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageBackend

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        assert isinstance(backend, StorageBackend)
class TestAzureBlobStorageBackendUpload:
    """Tests for AzureBlobStorageBackend.upload method."""

    @staticmethod
    def _wire_container(service_cls: MagicMock) -> MagicMock:
        """Attach service and container mocks to the patched class.

        Returns the container mock handed out by ``get_container_client``.
        """
        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        return container

    @staticmethod
    def _wire_blob(container: MagicMock, *, exists: bool) -> MagicMock:
        """Attach a blob mock whose ``exists()`` returns ``exists``."""
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = exists
        return blob

    @staticmethod
    def _new_backend() -> Any:
        """Build the backend under test (imported lazily, under the patch)."""
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @staticmethod
    def _write_temp_file(payload: bytes) -> Path:
        """Write ``payload`` to a persistent temp ``.txt`` file; caller unlinks."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as handle:
            handle.write(payload)
        return Path(handle.name)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_file(self, mock_service_class: MagicMock) -> None:
        """Uploading a local file returns the remote path and calls upload_blob."""
        container = self._wire_container(mock_service_class)
        blob = self._wire_blob(container, exists=False)
        backend = self._new_backend()

        source = self._write_temp_file(b"Hello, World!")
        try:
            remote = backend.upload(source, "uploads/sample.txt")

            assert remote == "uploads/sample.txt"
            container.get_blob_client.assert_called_with("uploads/sample.txt")
            blob.upload_blob.assert_called_once()
        finally:
            source.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_fails_if_blob_exists_without_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """An existing blob plus ``overwrite=False`` raises StorageError."""
        from shared.storage.base import StorageError

        container = self._wire_container(mock_service_class)
        self._wire_blob(container, exists=True)
        backend = self._new_backend()

        source = self._write_temp_file(b"content")
        try:
            with pytest.raises(StorageError, match="already exists"):
                backend.upload(source, "existing.txt", overwrite=False)
        finally:
            source.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_succeeds_with_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """An existing blob plus ``overwrite=True`` uploads with overwrite set."""
        container = self._wire_container(mock_service_class)
        blob = self._wire_blob(container, exists=True)
        backend = self._new_backend()

        source = self._write_temp_file(b"content")
        try:
            remote = backend.upload(source, "existing.txt", overwrite=True)

            assert remote == "existing.txt"
            blob.upload_blob.assert_called_once()
            # The overwrite flag must be forwarded as a keyword argument.
            passed_kwargs = blob.upload_blob.call_args.kwargs
            assert passed_kwargs.get("overwrite") is True
        finally:
            source.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_nonexistent_file_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """A missing local source file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError

        # Only the container is wired here; the blob mock stays unconfigured.
        self._wire_container(mock_service_class)
        backend = self._new_backend()

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(Path("/nonexistent/file.txt"), "sample.txt")
class TestAzureBlobStorageBackendDownload:
    """Tests for AzureBlobStorageBackend.download method."""

    @staticmethod
    def _wire_blob(
        service_cls: MagicMock, *, exists: bool, content: bytes | None = None
    ) -> MagicMock:
        """Wire service, container and blob mocks onto the patched class.

        Args:
            service_cls: The patched ``BlobServiceClient`` class mock.
            exists: Value returned by the blob mock's ``exists()``.
            content: When given, ``download_blob()`` returns a stream mock
                whose ``readall()`` yields these bytes. When ``None`` the
                download side of the blob mock is left unconfigured.

        Returns:
            The blob mock handed out by ``get_blob_client``.
        """
        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = exists
        if content is not None:
            stream = MagicMock()
            stream.readall.return_value = content
            blob.download_blob.return_value = stream
        return blob

    @staticmethod
    def _new_backend() -> Any:
        """Build the backend under test (imported lazily, under the patch)."""
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_file(self, mock_service_class: MagicMock) -> None:
        """Downloading writes the blob bytes to the local path and returns it."""
        self._wire_blob(mock_service_class, exists=True, content=b"Hello, World!")
        backend = self._new_backend()

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir) / "downloaded.txt"

            result = backend.download("remote/sample.txt", local_path)

            assert result == local_path
            assert local_path.exists()
            assert local_path.read_bytes() == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_creates_parent_directories(
        self, mock_service_class: MagicMock
    ) -> None:
        """Missing parent directories of the local path are created."""
        self._wire_blob(mock_service_class, exists=True, content=b"content")
        backend = self._new_backend()

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir) / "deep" / "nested" / "downloaded.txt"

            # Only the file's existence is checked, so the return value is
            # deliberately not bound.
            backend.download("sample.txt", local_path)

            assert local_path.exists()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_nonexistent_blob_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """A missing blob raises FileNotFoundStorageError naming the blob."""
        from shared.storage.base import FileNotFoundStorageError

        self._wire_blob(mock_service_class, exists=False)
        backend = self._new_backend()

        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            backend.download("nonexistent.txt", Path("/tmp/file.txt"))
class TestAzureBlobStorageBackendExists:
    """Tests for AzureBlobStorageBackend.exists method."""

    @staticmethod
    def _backend_with_blob(service_cls: MagicMock, *, blob_exists: bool) -> Any:
        """Wire the mock chain, then build a backend over it.

        The blob mock's ``exists()`` is configured to return ``blob_exists``
        before the backend is constructed.
        """
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = blob_exists

        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_true_for_existing_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """``exists`` is True when the blob client reports the blob exists."""
        backend = self._backend_with_blob(mock_service_class, blob_exists=True)

        assert backend.exists("existing.txt") is True

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_false_for_nonexistent_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """``exists`` is False when the blob client reports no such blob."""
        backend = self._backend_with_blob(mock_service_class, blob_exists=False)

        assert backend.exists("nonexistent.txt") is False
class TestAzureBlobStorageBackendListFiles:
    """Tests for AzureBlobStorageBackend.list_files method."""

    @staticmethod
    def _wire_container(service_cls: MagicMock) -> MagicMock:
        """Attach service and container mocks; return the container mock."""
        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        return container

    @staticmethod
    def _blob_item(blob_name: str) -> MagicMock:
        """Create a listing entry mock exposing ``.name``.

        The attribute is assigned after construction because the ``name``
        keyword of ``MagicMock`` names the mock itself instead.
        """
        item = MagicMock()
        item.name = blob_name
        return item

    @staticmethod
    def _new_backend() -> Any:
        """Build the backend under test (imported lazily, under the patch)."""
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_empty_container(
        self, mock_service_class: MagicMock
    ) -> None:
        """An empty blob listing yields an empty list."""
        container = self._wire_container(mock_service_class)
        container.list_blobs.return_value = []
        backend = self._new_backend()

        assert backend.list_files("") == []

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_returns_all_blobs(
        self, mock_service_class: MagicMock
    ) -> None:
        """Every listed blob name is returned, including nested ones."""
        container = self._wire_container(mock_service_class)
        expected_names = ("file1.txt", "file2.txt", "subdir/file3.txt")
        container.list_blobs.return_value = [
            self._blob_item(blob_name) for blob_name in expected_names
        ]
        backend = self._new_backend()

        files = backend.list_files("")

        assert len(files) == 3
        for blob_name in expected_names:
            assert blob_name in files

    @patch("shared.storage.azure.BlobServiceClient")
    def test_list_files_with_prefix(
        self, mock_service_class: MagicMock
    ) -> None:
        """The prefix is forwarded to ``list_blobs`` as ``name_starts_with``."""
        container = self._wire_container(mock_service_class)
        container.list_blobs.return_value = [
            self._blob_item("images/a.png"),
            self._blob_item("images/b.png"),
        ]
        backend = self._new_backend()

        files = backend.list_files("images/")

        container.list_blobs.assert_called_with(name_starts_with="images/")
        assert len(files) == 2
class TestAzureBlobStorageBackendDelete:
    """Tests for AzureBlobStorageBackend.delete method."""

    @staticmethod
    def _backend_and_blob(
        service_cls: MagicMock, *, blob_exists: bool
    ) -> tuple[Any, MagicMock]:
        """Wire the mock chain and build a backend over it.

        Returns:
            ``(backend, blob)`` where ``blob`` is the mock handed out by
            ``get_blob_client`` and its ``exists()`` returns ``blob_exists``.
        """
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = blob_exists

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )
        return backend, blob

    @patch("shared.storage.azure.BlobServiceClient")
    def test_delete_existing_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """Deleting an existing blob returns True and calls ``delete_blob``."""
        backend, blob = self._backend_and_blob(
            mock_service_class, blob_exists=True
        )

        outcome = backend.delete("sample.txt")

        assert outcome is True
        blob.delete_blob.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_delete_nonexistent_blob_returns_false(
        self, mock_service_class: MagicMock
    ) -> None:
        """Deleting a missing blob returns False without calling ``delete_blob``."""
        backend, blob = self._backend_and_blob(
            mock_service_class, blob_exists=False
        )

        outcome = backend.delete("nonexistent.txt")

        assert outcome is False
        blob.delete_blob.assert_not_called()
class TestAzureBlobStorageBackendGetUrl:
    """Tests for AzureBlobStorageBackend.get_url method."""

    @staticmethod
    def _wire_blob(service_cls: MagicMock, *, exists: bool) -> MagicMock:
        """Wire service, container and blob mocks; return the blob mock."""
        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = exists
        return blob

    @staticmethod
    def _new_backend() -> Any:
        """Build the backend under test (imported lazily, under the patch)."""
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_url_returns_blob_url(
        self, mock_service_class: MagicMock
    ) -> None:
        """``get_url`` returns the blob client's ``url`` for an existing blob."""
        blob = self._wire_blob(mock_service_class, exists=True)
        blob.url = "https://account.blob.core.windows.net/container/sample.txt"
        backend = self._new_backend()

        url = backend.get_url("sample.txt")

        assert url == "https://account.blob.core.windows.net/container/sample.txt"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_url_nonexistent_blob_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """``get_url`` raises FileNotFoundStorageError for a missing blob."""
        from shared.storage.base import FileNotFoundStorageError

        self._wire_blob(mock_service_class, exists=False)
        backend = self._new_backend()

        with pytest.raises(FileNotFoundStorageError):
            backend.get_url("nonexistent.txt")
class TestAzureBlobStorageBackendUploadBytes:
    """Tests for AzureBlobStorageBackend.upload_bytes method."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_bytes(self, mock_service_class: MagicMock) -> None:
        """Uploading raw bytes returns the remote path and calls upload_blob once."""
        from shared.storage.azure import AzureBlobStorageBackend

        # Mock chain: service class -> service -> container -> blob.
        service = MagicMock()
        mock_service_class.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = False

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        remote = backend.upload_bytes(b"Binary content here", "binary.dat")

        assert remote == "binary.dat"
        blob.upload_blob.assert_called_once()
class TestAzureBlobStorageBackendDownloadBytes:
    """Tests for AzureBlobStorageBackend.download_bytes method."""

    @staticmethod
    def _wire_blob(service_cls: MagicMock, *, exists: bool) -> MagicMock:
        """Wire service, container and blob mocks; return the blob mock."""
        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = exists
        return blob

    @staticmethod
    def _new_backend() -> Any:
        """Build the backend under test (imported lazily, under the patch)."""
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes(self, mock_service_class: MagicMock) -> None:
        """``download_bytes`` returns what the download stream's readall yields."""
        blob = self._wire_blob(mock_service_class, exists=True)
        stream = MagicMock()
        stream.readall.return_value = b"Hello, World!"
        blob.download_blob.return_value = stream
        backend = self._new_backend()

        payload = backend.download_bytes("sample.txt")

        assert payload == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_bytes_nonexistent(
        self, mock_service_class: MagicMock
    ) -> None:
        """``download_bytes`` raises FileNotFoundStorageError for a missing blob."""
        from shared.storage.base import FileNotFoundStorageError

        self._wire_blob(mock_service_class, exists=False)
        backend = self._new_backend()

        with pytest.raises(FileNotFoundStorageError):
            backend.download_bytes("nonexistent.txt")
class TestAzureBlobStorageBackendBatchOperations:
    """Tests for batch operations in AzureBlobStorageBackend."""

    @staticmethod
    def _wire_container(service_cls: MagicMock) -> MagicMock:
        """Attach service and container mocks; return the container mock."""
        service = MagicMock()
        service_cls.from_connection_string.return_value = service
        container = MagicMock()
        service.get_container_client.return_value = container
        return container

    @staticmethod
    def _new_backend() -> Any:
        """Build the backend under test (imported lazily, under the patch)."""
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_directory(self, mock_service_class: MagicMock) -> None:
        """Every file under a directory tree is uploaded below the prefix."""
        container = self._wire_container(mock_service_class)
        blob = MagicMock()
        container.get_blob_client.return_value = blob
        blob.exists.return_value = False
        backend = self._new_backend()

        with tempfile.TemporaryDirectory() as temp_dir:
            root = Path(temp_dir)
            nested = root / "subdir"
            (root / "file1.txt").write_text("content1")
            nested.mkdir()
            (nested / "file2.txt").write_text("content2")

            uploaded = backend.upload_directory(root, "uploads/")

            assert len(uploaded) == 2
            assert "uploads/file1.txt" in uploaded
            assert "uploads/subdir/file2.txt" in uploaded

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_directory(self, mock_service_class: MagicMock) -> None:
        """Every blob listed under a prefix is downloaded to the local directory."""
        container = self._wire_container(mock_service_class)

        # Listing entries returned by list_blobs; ``name`` must be assigned
        # after construction.
        listing = []
        for blob_name in ("images/a.png", "images/b.png"):
            entry = MagicMock()
            entry.name = blob_name
            listing.append(entry)
        container.list_blobs.return_value = listing

        # A single blob client mock serves every blob path.
        blob_client = MagicMock()
        container.get_blob_client.return_value = blob_client
        blob_client.exists.return_value = True
        stream = MagicMock()
        stream.readall.return_value = b"image content"
        blob_client.download_blob.return_value = stream

        backend = self._new_backend()

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir)

            downloaded = backend.download_directory("images/", local_path)

            assert len(downloaded) == 2
            # Either layout is accepted: relative to the prefix, or with the
            # prefix kept as a sub-directory.
            candidates = (local_path / "a.png", local_path / "images" / "a.png")
            assert any(candidate.exists() for candidate in candidates)

View File

@@ -0,0 +1,301 @@
"""
Tests for storage base module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from abc import ABC
from pathlib import Path
from typing import BinaryIO
from unittest.mock import MagicMock, patch
import pytest
class TestStorageBackendInterface:
    """Tests for StorageBackend abstract base class.

    Each "must implement X" test builds a subclass that provides every known
    interface method except X, so a ``TypeError`` on instantiation can only
    be caused by X being abstract.
    """

    @staticmethod
    def _stub_namespace() -> dict[str, object]:
        """Return trivial implementations of every known interface method.

        ``get_presigned_url`` is included because the complete-subclass test
        implements it; presumably it is abstract as well (confirm against
        ``shared.storage.base``). Supplying it is harmless if it is not.
        """

        def upload(
            self, local_path: Path, remote_path: str, overwrite: bool = False
        ) -> str:
            return remote_path

        def download(self, remote_path: str, local_path: Path) -> Path:
            return local_path

        def exists(self, remote_path: str) -> bool:
            return False

        def list_files(self, prefix: str) -> list[str]:
            return []

        def delete(self, remote_path: str) -> bool:
            return True

        def get_url(self, remote_path: str) -> str:
            return ""

        def get_presigned_url(
            self, remote_path: str, expires_in_seconds: int = 3600
        ) -> str:
            return ""

        stubs = (
            upload,
            download,
            exists,
            list_files,
            delete,
            get_url,
            get_presigned_url,
        )
        return {stub.__name__: stub for stub in stubs}

    def _assert_method_required(self, method_name: str) -> None:
        """Assert that a subclass lacking only ``method_name`` is abstract."""
        from shared.storage.base import StorageBackend

        namespace = self._stub_namespace()
        del namespace[method_name]
        # Use StorageBackend's own metaclass so abstract-method bookkeeping
        # is applied to the dynamically built subclass.
        incomplete_backend = type(StorageBackend)(
            "IncompleteBackend", (StorageBackend,), namespace
        )

        with pytest.raises(TypeError):
            incomplete_backend()

    def test_cannot_instantiate_directly(self) -> None:
        """Test that StorageBackend cannot be instantiated."""
        from shared.storage.base import StorageBackend

        with pytest.raises(TypeError):
            StorageBackend()  # type: ignore

    def test_is_abstract_base_class(self) -> None:
        """Test that StorageBackend is an ABC."""
        from shared.storage.base import StorageBackend

        assert issubclass(StorageBackend, ABC)

    def test_subclass_must_implement_upload(self) -> None:
        """Test that subclass must implement upload method."""
        self._assert_method_required("upload")

    def test_subclass_must_implement_download(self) -> None:
        """Test that subclass must implement download method."""
        self._assert_method_required("download")

    def test_subclass_must_implement_exists(self) -> None:
        """Test that subclass must implement exists method."""
        self._assert_method_required("exists")

    def test_subclass_must_implement_list_files(self) -> None:
        """Test that subclass must implement list_files method."""
        self._assert_method_required("list_files")

    def test_subclass_must_implement_delete(self) -> None:
        """Test that subclass must implement delete method."""
        self._assert_method_required("delete")

    def test_subclass_must_implement_get_url(self) -> None:
        """Test that subclass must implement get_url method."""
        self._assert_method_required("get_url")

    def test_valid_subclass_can_be_instantiated(self) -> None:
        """Test that a complete subclass can be instantiated."""
        from shared.storage.base import StorageBackend

        complete_backend = type(StorageBackend)(
            "CompleteBackend", (StorageBackend,), self._stub_namespace()
        )

        backend = complete_backend()

        assert isinstance(backend, StorageBackend)
class TestStorageError:
    """Tests for StorageError exception."""

    def test_storage_error_is_exception(self) -> None:
        """StorageError derives from the built-in Exception."""
        from shared.storage.base import StorageError

        assert issubclass(StorageError, Exception)

    def test_storage_error_with_message(self) -> None:
        """The message passed to StorageError is its string form."""
        from shared.storage.base import StorageError

        message = "Upload failed"

        assert str(StorageError(message)) == "Upload failed"

    def test_storage_error_can_be_raised(self) -> None:
        """StorageError can be raised and matched by message."""
        from shared.storage.base import StorageError

        def blow_up() -> None:
            raise StorageError("test error")

        with pytest.raises(StorageError, match="test error"):
            blow_up()
class TestFileNotFoundError:
    """Tests for FileNotFoundStorageError exception."""

    def test_file_not_found_is_storage_error(self) -> None:
        """FileNotFoundStorageError derives from StorageError."""
        from shared.storage.base import FileNotFoundStorageError, StorageError

        assert issubclass(FileNotFoundStorageError, StorageError)

    def test_file_not_found_with_path(self) -> None:
        """The path given to FileNotFoundStorageError appears in its message."""
        from shared.storage.base import FileNotFoundStorageError

        rendered = str(FileNotFoundStorageError("images/test.png"))

        assert "images/test.png" in rendered
class TestStorageConfig:
    """Tests for StorageConfig dataclass."""

    def test_storage_config_creation(self) -> None:
        """Fields passed to StorageConfig are stored unchanged."""
        from shared.storage.base import StorageConfig

        config = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )

        stored = (
            config.backend_type,
            config.connection_string,
            config.container_name,
        )
        assert stored[0] == "azure_blob"
        assert stored[1] == "DefaultEndpointsProtocol=https;..."
        assert stored[2] == "training-images"

    def test_storage_config_defaults(self) -> None:
        """Optional fields are None when only ``backend_type`` is given."""
        from shared.storage.base import StorageConfig

        config = StorageConfig(backend_type="local")

        assert config.backend_type == "local"
        assert config.connection_string is None
        assert config.container_name is None
        assert config.base_path is None

    def test_storage_config_with_base_path(self) -> None:
        """``base_path`` is stored for a local backend configuration."""
        from shared.storage.base import StorageConfig

        storage_root = Path("/data/images")

        config = StorageConfig(backend_type="local", base_path=storage_root)

        assert config.backend_type == "local"
        assert config.base_path == Path("/data/images")

    def test_storage_config_immutable(self) -> None:
        """Assigning to a field of a StorageConfig raises AttributeError."""
        from shared.storage.base import StorageConfig

        config = StorageConfig(backend_type="local")

        with pytest.raises(AttributeError):
            setattr(config, "backend_type", "azure_blob")

View File

@@ -0,0 +1,348 @@
"""
Tests for storage configuration file loader.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import patch

import pytest
@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory and remove it after the test.

    The fixture is a generator, so its return annotation is
    ``Iterator[Path]`` rather than ``Path``. Cleanup is best-effort: errors
    raised while deleting the tree are ignored.
    """
    directory = Path(tempfile.mkdtemp())
    yield directory
    shutil.rmtree(directory, ignore_errors=True)
class TestEnvVarSubstitution:
    """Tests for environment variable substitution in config values.

    Every test that touches ``os.environ`` does so inside
    ``patch.dict(os.environ, ...)``, which restores the environment on exit
    so no change leaks into other tests.
    """

    def test_substitute_simple_env_var(self) -> None:
        """Test substituting a simple environment variable."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"MY_VAR": "my_value"}):
            result = substitute_env_vars("${MY_VAR}")

        assert result == "my_value"

    def test_substitute_env_var_with_default(self) -> None:
        """Test substituting env var with default when var is not set."""
        from shared.storage.config_loader import substitute_env_vars

        # Remove the variable inside patch.dict so that, if it happened to
        # be set, its original value is restored afterwards.
        with patch.dict(os.environ):
            os.environ.pop("UNSET_VAR", None)
            result = substitute_env_vars("${UNSET_VAR:-default_value}")

        assert result == "default_value"

    def test_substitute_env_var_ignores_default_when_set(self) -> None:
        """Test that default is ignored when env var is set."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"SET_VAR": "actual_value"}):
            result = substitute_env_vars("${SET_VAR:-default_value}")

        assert result == "actual_value"

    def test_substitute_multiple_env_vars(self) -> None:
        """Test substituting multiple env vars in one string."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"HOST": "localhost", "PORT": "5432"}):
            result = substitute_env_vars("postgres://${HOST}:${PORT}/db")

        assert result == "postgres://localhost:5432/db"

    def test_substitute_preserves_non_env_text(self) -> None:
        """Test that non-env-var text is preserved."""
        from shared.storage.config_loader import substitute_env_vars

        with patch.dict(os.environ, {"VAR": "value"}):
            result = substitute_env_vars("prefix_${VAR}_suffix")

        assert result == "prefix_value_suffix"

    def test_substitute_empty_string_when_not_set_and_no_default(self) -> None:
        """Test that empty string is returned when var not set and no default."""
        from shared.storage.config_loader import substitute_env_vars

        # Same isolation as above: the removal is undone when the block exits.
        with patch.dict(os.environ):
            os.environ.pop("MISSING_VAR", None)
            result = substitute_env_vars("${MISSING_VAR}")

        assert result == ""
class TestLoadStorageConfigYaml:
    """Tests for loading storage configuration from YAML files.

    Backend-specific settings live under a nested mapping (``local:``,
    ``azure:``, ``s3:``); the YAML fixtures indent those keys so they parse
    as children of the section rather than as top-level keys.
    """

    def test_load_local_backend_config(self, temp_dir: Path) -> None:
        """A ``local`` section is loaded into ``config.local``."""
        from shared.storage.config_loader import load_storage_config

        config_path = temp_dir / "storage.yaml"
        config_path.write_text("""
backend: local
presigned_url_expiry: 3600
local:
  base_path: ./data/storage
""")

        config = load_storage_config(config_path)

        assert config.backend_type == "local"
        assert config.presigned_url_expiry == 3600
        assert config.local is not None
        assert config.local.base_path == Path("./data/storage")

    def test_load_azure_backend_config(self, temp_dir: Path) -> None:
        """An ``azure`` section is loaded into ``config.azure``."""
        from shared.storage.config_loader import load_storage_config

        config_path = temp_dir / "storage.yaml"
        config_path.write_text("""
backend: azure_blob
presigned_url_expiry: 7200
azure:
  connection_string: DefaultEndpointsProtocol=https;AccountName=test
  container_name: documents
  create_container: true
""")

        config = load_storage_config(config_path)

        assert config.backend_type == "azure_blob"
        assert config.presigned_url_expiry == 7200
        assert config.azure is not None
        assert config.azure.connection_string == "DefaultEndpointsProtocol=https;AccountName=test"
        assert config.azure.container_name == "documents"
        assert config.azure.create_container is True

    def test_load_s3_backend_config(self, temp_dir: Path) -> None:
        """An ``s3`` section is loaded into ``config.s3``."""
        from shared.storage.config_loader import load_storage_config

        config_path = temp_dir / "storage.yaml"
        config_path.write_text("""
backend: s3
presigned_url_expiry: 1800
s3:
  bucket_name: my-bucket
  region_name: us-west-2
  endpoint_url: http://localhost:9000
  create_bucket: false
""")

        config = load_storage_config(config_path)

        assert config.backend_type == "s3"
        assert config.presigned_url_expiry == 1800
        assert config.s3 is not None
        assert config.s3.bucket_name == "my-bucket"
        assert config.s3.region_name == "us-west-2"
        assert config.s3.endpoint_url == "http://localhost:9000"
        assert config.s3.create_bucket is False

    def test_load_config_with_env_var_substitution(self, temp_dir: Path) -> None:
        """``${VAR:-default}`` references in the file use the environment value."""
        from shared.storage.config_loader import load_storage_config

        config_path = temp_dir / "storage.yaml"
        config_path.write_text("""
backend: ${STORAGE_BACKEND:-local}
local:
  base_path: ${STORAGE_PATH:-./default/path}
""")

        with patch.dict(os.environ, {"STORAGE_BACKEND": "local", "STORAGE_PATH": "/custom/path"}):
            config = load_storage_config(config_path)

        assert config.backend_type == "local"
        assert config.local is not None
        assert config.local.base_path == Path("/custom/path")

    def test_load_config_file_not_found_raises(self, temp_dir: Path) -> None:
        """A missing config file raises FileNotFoundError."""
        from shared.storage.config_loader import load_storage_config

        with pytest.raises(FileNotFoundError):
            load_storage_config(temp_dir / "nonexistent.yaml")

    def test_load_config_invalid_yaml_raises(self, temp_dir: Path) -> None:
        """Unparseable YAML raises ValueError whose message matches "Invalid"."""
        from shared.storage.config_loader import load_storage_config

        config_path = temp_dir / "storage.yaml"
        config_path.write_text("invalid: yaml: content: [")

        with pytest.raises(ValueError, match="Invalid"):
            load_storage_config(config_path)

    def test_load_config_missing_backend_raises(self, temp_dir: Path) -> None:
        """A file without a ``backend`` key raises ValueError mentioning it."""
        from shared.storage.config_loader import load_storage_config

        config_path = temp_dir / "storage.yaml"
        config_path.write_text("""
local:
  base_path: ./data
""")

        with pytest.raises(ValueError, match="backend"):
            load_storage_config(config_path)

    def test_load_config_default_presigned_url_expiry(self, temp_dir: Path) -> None:
        """``presigned_url_expiry`` is 3600 when the file does not set it."""
        from shared.storage.config_loader import load_storage_config

        config_path = temp_dir / "storage.yaml"
        config_path.write_text("""
backend: local
local:
  base_path: ./data
""")

        config = load_storage_config(config_path)

        assert config.presigned_url_expiry == 3600  # Default value
class TestStorageFileConfig:
    """Tests for the StorageFileConfig dataclass."""

    def test_storage_file_config_is_immutable(self) -> None:
        """Assigning to a StorageFileConfig field raises AttributeError."""
        from shared.storage.config_loader import StorageFileConfig

        file_config = StorageFileConfig(backend_type="local")

        with pytest.raises(AttributeError):
            file_config.backend_type = "azure_blob"  # type: ignore

    def test_storage_file_config_defaults(self) -> None:
        """With only backend_type given, the remaining fields take defaults."""
        from shared.storage.config_loader import StorageFileConfig

        file_config = StorageFileConfig(backend_type="local")

        assert file_config.backend_type == "local"
        # No backend-specific section is populated by default.
        assert file_config.local is None
        assert file_config.azure is None
        assert file_config.s3 is None
        assert file_config.presigned_url_expiry == 3600
class TestLocalConfig:
    """Tests for the LocalConfig dataclass."""

    def test_local_config_creation(self) -> None:
        """LocalConfig stores the base_path it is constructed with."""
        from shared.storage.config_loader import LocalConfig

        local_config = LocalConfig(base_path=Path("/data/storage"))

        assert local_config.base_path == Path("/data/storage")

    def test_local_config_is_immutable(self) -> None:
        """Assigning to a LocalConfig field raises AttributeError."""
        from shared.storage.config_loader import LocalConfig

        local_config = LocalConfig(base_path=Path("/data"))

        with pytest.raises(AttributeError):
            local_config.base_path = Path("/other")  # type: ignore
class TestAzureConfig:
    """Tests for the AzureConfig dataclass."""

    def test_azure_config_creation(self) -> None:
        """All explicitly supplied AzureConfig fields are stored as given."""
        from shared.storage.config_loader import AzureConfig

        azure_config = AzureConfig(
            connection_string="test_connection",
            container_name="test_container",
            create_container=True,
        )

        assert azure_config.connection_string == "test_connection"
        assert azure_config.container_name == "test_container"
        assert azure_config.create_container is True

    def test_azure_config_defaults(self) -> None:
        """create_container is False when it is not supplied."""
        from shared.storage.config_loader import AzureConfig

        azure_config = AzureConfig(connection_string="conn", container_name="container")

        assert azure_config.create_container is False

    def test_azure_config_is_immutable(self) -> None:
        """Assigning to an AzureConfig field raises AttributeError."""
        from shared.storage.config_loader import AzureConfig

        azure_config = AzureConfig(connection_string="conn", container_name="container")

        with pytest.raises(AttributeError):
            azure_config.container_name = "other"  # type: ignore
class TestS3Config:
    """Tests for the S3Config dataclass."""

    def test_s3_config_creation(self) -> None:
        """Every explicitly supplied S3Config field is stored as given."""
        from shared.storage.config_loader import S3Config

        s3_config = S3Config(
            bucket_name="my-bucket",
            region_name="us-east-1",
            access_key_id="AKIAIOSFODNN7EXAMPLE",
            secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            endpoint_url="http://localhost:9000",
            create_bucket=True,
        )

        assert s3_config.bucket_name == "my-bucket"
        assert s3_config.region_name == "us-east-1"
        assert s3_config.access_key_id == "AKIAIOSFODNN7EXAMPLE"
        assert s3_config.secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
        assert s3_config.endpoint_url == "http://localhost:9000"
        assert s3_config.create_bucket is True

    def test_s3_config_minimal(self) -> None:
        """With only bucket_name given, optional fields are None / False."""
        from shared.storage.config_loader import S3Config

        s3_config = S3Config(bucket_name="bucket")

        assert s3_config.bucket_name == "bucket"
        assert s3_config.region_name is None
        assert s3_config.access_key_id is None
        assert s3_config.secret_access_key is None
        assert s3_config.endpoint_url is None
        assert s3_config.create_bucket is False

    def test_s3_config_is_immutable(self) -> None:
        """Assigning to an S3Config field raises AttributeError."""
        from shared.storage.config_loader import S3Config

        s3_config = S3Config(bucket_name="bucket")

        with pytest.raises(AttributeError):
            s3_config.bucket_name = "other"  # type: ignore

View File

@@ -0,0 +1,423 @@
"""
Tests for storage factory.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
class TestStorageFactory:
    """Tests for the create_storage_backend factory function."""

    def test_create_local_backend(self) -> None:
        """A "local" config produces a LocalStorageBackend rooted at base_path."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            local_config = StorageConfig(backend_type="local", base_path=Path(temp_dir))

            backend = create_storage_backend(local_config)

            assert isinstance(backend, LocalStorageBackend)
            assert backend.base_path == Path(temp_dir)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
        """An "azure_blob" config produces an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        azure_config = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )

        backend = create_storage_backend(azure_config)

        assert isinstance(backend, AzureBlobStorageBackend)

    def test_create_unknown_backend_raises(self) -> None:
        """An unrecognised backend_type raises ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        bogus_config = StorageConfig(backend_type="unknown_backend")

        with pytest.raises(ValueError, match="Unknown storage backend"):
            create_storage_backend(bogus_config)

    def test_create_local_requires_base_path(self) -> None:
        """A "local" config without base_path raises ValueError naming it."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(backend_type="local")

        with pytest.raises(ValueError, match="base_path"):
            create_storage_backend(incomplete_config)

    def test_create_azure_requires_connection_string(self) -> None:
        """An Azure config without connection_string raises ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(
            backend_type="azure_blob",
            container_name="container",
        )

        with pytest.raises(ValueError, match="connection_string"):
            create_storage_backend(incomplete_config)

    def test_create_azure_requires_container_name(self) -> None:
        """An Azure config without container_name raises ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(
            backend_type="azure_blob",
            connection_string="connection_string",
        )

        with pytest.raises(ValueError, match="container_name"):
            create_storage_backend(incomplete_config)
class TestStorageFactoryFromEnv:
    """Tests for the create_storage_backend_from_env factory function."""

    def test_create_from_env_local(self) -> None:
        """STORAGE_BACKEND=local with a base path yields a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {
                "STORAGE_BACKEND": "local",
                "STORAGE_BASE_PATH": temp_dir,
            }
            with patch.dict(os.environ, overrides, clear=False):
                backend = create_storage_backend_from_env()

            assert isinstance(backend, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
        """Azure env vars yield an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_env

        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }
        with patch.dict(os.environ, overrides, clear=False):
            backend = create_storage_backend_from_env()

        assert isinstance(backend, AzureBlobStorageBackend)

    def test_create_from_env_defaults_to_local(self) -> None:
        """Without STORAGE_BACKEND set, the factory builds a local backend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {"STORAGE_BASE_PATH": temp_dir}
            with patch.dict(os.environ, overrides, clear=False):
                # Drop STORAGE_BACKEND if the surrounding environment set it.
                os.environ.pop("STORAGE_BACKEND", None)
                backend = create_storage_backend_from_env()

            assert isinstance(backend, LocalStorageBackend)

    def test_create_from_env_missing_azure_vars(self) -> None:
        """A missing Azure connection string raises ValueError naming the var."""
        from shared.storage.factory import create_storage_backend_from_env

        # AZURE_STORAGE_CONNECTION_STRING is deliberately left out.
        overrides = {"STORAGE_BACKEND": "azure_blob"}
        with patch.dict(os.environ, overrides, clear=False):
            # Drop the connection string if the surrounding environment set it.
            os.environ.pop("AZURE_STORAGE_CONNECTION_STRING", None)
            with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
                create_storage_backend_from_env()
class TestGetDefaultStorageConfig:
    """Tests for the get_default_storage_config function."""

    def test_get_default_config_local(self) -> None:
        """Local env vars are reflected in backend_type and base_path."""
        from shared.storage.factory import get_default_storage_config

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {
                "STORAGE_BACKEND": "local",
                "STORAGE_BASE_PATH": temp_dir,
            }
            with patch.dict(os.environ, overrides, clear=False):
                default_config = get_default_storage_config()

            assert default_config.backend_type == "local"
            assert default_config.base_path == Path(temp_dir)

    def test_get_default_config_azure(self) -> None:
        """Azure env vars are reflected in the returned config."""
        from shared.storage.factory import get_default_storage_config

        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }
        with patch.dict(os.environ, overrides, clear=False):
            default_config = get_default_storage_config()

        assert default_config.backend_type == "azure_blob"
        assert default_config.connection_string == "DefaultEndpointsProtocol=https;..."
        assert default_config.container_name == "training-images"
class TestStorageFactoryS3:
    """Tests for S3 backend support in the storage factory."""

    @patch("boto3.client")
    def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
        """An "s3" config produces an S3StorageBackend."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.s3 import S3StorageBackend

        s3_config = StorageConfig(
            backend_type="s3",
            bucket_name="test-bucket",
            region_name="us-west-2",
        )

        backend = create_storage_backend(s3_config)

        assert isinstance(backend, S3StorageBackend)

    def test_create_s3_requires_bucket_name(self) -> None:
        """An "s3" config without bucket_name raises ValueError naming it."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(backend_type="s3", region_name="us-west-2")

        with pytest.raises(ValueError, match="bucket_name"):
            create_storage_backend(incomplete_config)

    @patch("boto3.client")
    def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
        """S3 env vars yield an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.s3 import S3StorageBackend

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-east-1",
        }
        with patch.dict(os.environ, overrides, clear=False):
            backend = create_storage_backend_from_env()

        assert isinstance(backend, S3StorageBackend)

    def test_create_from_env_s3_missing_bucket(self) -> None:
        """A missing AWS_S3_BUCKET raises ValueError naming the variable."""
        from shared.storage.factory import create_storage_backend_from_env

        # AWS_S3_BUCKET is deliberately left out.
        overrides = {"STORAGE_BACKEND": "s3"}
        with patch.dict(os.environ, overrides, clear=False):
            # Drop the bucket variable if the surrounding environment set it.
            os.environ.pop("AWS_S3_BUCKET", None)
            with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
                create_storage_backend_from_env()

    def test_get_default_config_s3(self) -> None:
        """S3 env vars are reflected in the returned default config."""
        from shared.storage.factory import get_default_storage_config

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-west-2",
            "AWS_ENDPOINT_URL": "http://localhost:9000",
        }
        with patch.dict(os.environ, overrides, clear=False):
            default_config = get_default_storage_config()

        assert default_config.backend_type == "s3"
        assert default_config.bucket_name == "test-bucket"
        assert default_config.region_name == "us-west-2"
        assert default_config.endpoint_url == "http://localhost:9000"
class TestStorageFactoryFromFile:
    """Tests for the create_storage_backend_from_file factory function.

    The YAML fixtures nest backend-specific keys under their section
    (``local:``, ``azure:``, ``s3:``) so the section parses as a mapping.
    """

    def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
        """A local YAML config yields a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        config_file = tmp_path / "storage.yaml"
        storage_path = tmp_path / "storage"
        config_file.write_text(f"""
backend: local
local:
  base_path: {storage_path}
""")

        backend = create_storage_backend_from_file(config_file)

        assert isinstance(backend, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_yaml_file_azure(
        self, mock_service_class: MagicMock, tmp_path: Path
    ) -> None:
        """An Azure YAML config yields an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_file

        config_file = tmp_path / "storage.yaml"
        config_file.write_text("""
backend: azure_blob
azure:
  connection_string: DefaultEndpointsProtocol=https;AccountName=test
  container_name: documents
""")

        backend = create_storage_backend_from_file(config_file)

        assert isinstance(backend, AzureBlobStorageBackend)

    @patch("boto3.client")
    def test_create_from_yaml_file_s3(
        self, mock_boto3_client: MagicMock, tmp_path: Path
    ) -> None:
        """An S3 YAML config yields an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.s3 import S3StorageBackend

        config_file = tmp_path / "storage.yaml"
        config_file.write_text("""
backend: s3
s3:
  bucket_name: my-bucket
  region_name: us-east-1
""")

        backend = create_storage_backend_from_file(config_file)

        assert isinstance(backend, S3StorageBackend)

    def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
        """``${VAR}`` references in the config file are taken from the environment."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        config_file = tmp_path / "storage.yaml"
        storage_path = tmp_path / "storage"
        config_file.write_text("""
backend: ${STORAGE_BACKEND:-local}
local:
  base_path: ${CUSTOM_STORAGE_PATH}
""")

        with patch.dict(
            os.environ,
            {"STORAGE_BACKEND": "local", "CUSTOM_STORAGE_PATH": str(storage_path)},
        ):
            backend = create_storage_backend_from_file(config_file)

        assert isinstance(backend, LocalStorageBackend)

    def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
        """A missing config file raises FileNotFoundError."""
        from shared.storage.factory import create_storage_backend_from_file

        with pytest.raises(FileNotFoundError):
            create_storage_backend_from_file(tmp_path / "nonexistent.yaml")
class TestGetStorageBackend:
    """Tests for the get_storage_backend convenience function."""

    def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
        """An explicit config_path is used to build the backend."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        config_file = tmp_path / "storage.yaml"
        storage_path = tmp_path / "storage"
        # base_path is indented so it parses as a child of the local section.
        config_file.write_text(f"""
backend: local
local:
  base_path: {storage_path}
""")

        backend = get_storage_backend(config_path=config_file)

        assert isinstance(backend, LocalStorageBackend)

    def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
        """With config_path=None the backend is built from env vars."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        storage_path = tmp_path / "storage"
        env = {
            "STORAGE_BACKEND": "local",
            "STORAGE_BASE_PATH": str(storage_path),
        }

        with patch.dict(os.environ, env, clear=False):
            # No config file provided, should use env vars
            backend = get_storage_backend(config_path=None)

        assert isinstance(backend, LocalStorageBackend)

View File

@@ -0,0 +1,712 @@
"""
Tests for LocalStorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from pathlib import Path
import pytest
@pytest.fixture
def temp_storage_dir() -> Path:
    """Yield a freshly created temporary directory, then remove it."""
    scratch_dir = Path(tempfile.mkdtemp())
    yield scratch_dir
    # Teardown: removal errors are ignored so cleanup never fails a test.
    shutil.rmtree(scratch_dir, ignore_errors=True)
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Write a small text file into the temp directory and return its path."""
    text_path = temp_storage_dir / "sample.txt"
    text_path.write_text("Hello, World!")
    return text_path
@pytest.fixture
def sample_image(temp_storage_dir: Path) -> Path:
    """Write a tiny PNG file into the temp directory and return its path."""
    image_path = temp_storage_dir / "sample.png"
    # Minimal PNG (1x1 pixel), spelled out chunk by chunk as hex.
    png_bytes = bytes.fromhex(
        "89504e470d0a1a0a"  # PNG signature
        "0000000d"  # IHDR length
        "49484452"  # IHDR
        "00000001"  # width: 1
        "00000001"  # height: 1
        "0806000000"  # 8-bit RGBA
        "1f15c489"  # CRC
        "0000000a"  # IDAT length
        "49444154"  # IDAT
        "789c6300010000050001"  # compressed data
        "0d0a2db4"  # CRC
        "00000000"  # IEND length
        "49454e44"  # IEND
        "ae426082"  # CRC
    )
    image_path.write_bytes(png_bytes)
    return image_path
class TestLocalStorageBackendCreation:
    """Tests for constructing a LocalStorageBackend."""

    def test_create_with_base_path(self, temp_storage_dir: Path) -> None:
        """A Path base_path is exposed unchanged on the backend."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir)

        assert storage.base_path == temp_storage_dir

    def test_create_with_string_path(self, temp_storage_dir: Path) -> None:
        """A string base_path compares equal to the corresponding Path."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=str(temp_storage_dir))

        assert storage.base_path == temp_storage_dir

    def test_create_creates_directory_if_not_exists(
        self, temp_storage_dir: Path
    ) -> None:
        """Constructing the backend creates a missing base directory."""
        from shared.storage.local import LocalStorageBackend

        missing_dir = temp_storage_dir / "new_storage"
        assert not missing_dir.exists()

        storage = LocalStorageBackend(base_path=missing_dir)

        assert missing_dir.exists()
        assert storage.base_path == missing_dir

    def test_is_storage_backend_subclass(self, temp_storage_dir: Path) -> None:
        """A LocalStorageBackend instance is also a StorageBackend."""
        from shared.storage.base import StorageBackend
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir)

        assert isinstance(storage, StorageBackend)
class TestLocalStorageBackendUpload:
    """Tests for LocalStorageBackend.upload."""

    def test_upload_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Uploading copies the file and returns the remote path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_file, "uploads/sample.txt")

        assert result == "uploads/sample.txt"
        assert (storage_dir / "uploads" / "sample.txt").exists()
        assert (storage_dir / "uploads" / "sample.txt").read_text() == "Hello, World!"

    def test_upload_creates_subdirectories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Uploading to a nested remote path creates the directories."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # The return value is not under test here, so it is not captured.
        backend.upload(sample_file, "deep/nested/path/sample.txt")

        assert (storage_dir / "deep" / "nested" / "path" / "sample.txt").exists()

    def test_upload_fails_if_file_exists_without_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A second upload to the same path with overwrite=False raises."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload succeeds
        backend.upload(sample_file, "sample.txt")

        # Second upload should fail
        with pytest.raises(StorageError, match="already exists"):
            backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """With overwrite=True the stored file is replaced by the new content."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Modify original file
        sample_file.write_text("Modified content")

        # Second upload with overwrite
        result = backend.upload(sample_file, "sample.txt", overwrite=True)

        assert result == "sample.txt"
        assert (storage_dir / "sample.txt").read_text() == "Modified content"

    def test_upload_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Uploading a source file that does not exist raises."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(Path("/nonexistent/file.txt"), "sample.txt")

    def test_upload_binary_file(
        self, temp_storage_dir: Path, sample_image: Path
    ) -> None:
        """Binary content is stored byte-for-byte."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_image, "images/sample.png")

        assert result == "images/sample.png"
        uploaded_content = (storage_dir / "images" / "sample.png").read_bytes()
        assert uploaded_content == sample_image.read_bytes()
class TestLocalStorageBackendDownload:
    """Tests for LocalStorageBackend.download."""

    def test_download_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Downloading writes the stored content and returns the local path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        download_dir = temp_storage_dir / "downloads"
        download_dir.mkdir()
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Then download
        local_path = download_dir / "downloaded.txt"
        result = backend.download("sample.txt", local_path)

        assert result == local_path
        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_creates_parent_directories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Downloading to a nested destination creates its parent directories."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        local_path = temp_storage_dir / "deep" / "nested" / "downloaded.txt"
        # The return value is not under test here, so it is not captured.
        backend.download("sample.txt", local_path)

        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Downloading an unknown remote path raises, naming that path."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)
        # Keep the destination inside the fixture directory: portable across
        # platforms and cleaned up even if the download unexpectedly succeeds.
        destination = temp_storage_dir / "downloads" / "file.txt"

        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            backend.download("nonexistent.txt", destination)

    def test_download_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A file stored under a nested remote path can be downloaded."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/c/sample.txt")

        local_path = temp_storage_dir / "downloaded.txt"
        backend.download("a/b/c/sample.txt", local_path)

        assert local_path.read_text() == "Hello, World!"
class TestLocalStorageBackendExists:
    """Tests for LocalStorageBackend.exists."""

    def test_exists_returns_true_for_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """exists() is True for a path that was uploaded."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        storage.upload(sample_file, "sample.txt")

        assert storage.exists("sample.txt") is True

    def test_exists_returns_false_for_nonexistent_file(
        self, temp_storage_dir: Path
    ) -> None:
        """exists() is False for a path that was never uploaded."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir)

        assert storage.exists("nonexistent.txt") is False

    def test_exists_with_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """exists() distinguishes sibling names under the same nested prefix."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        storage.upload(sample_file, "a/b/sample.txt")

        assert storage.exists("a/b/sample.txt") is True
        assert storage.exists("a/b/other.txt") is False
class TestLocalStorageBackendListFiles:
    """Tests for LocalStorageBackend.list_files."""

    def test_list_files_empty_storage(self, temp_storage_dir: Path) -> None:
        """An empty storage root lists no files."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir)

        assert storage.list_files("") == []

    def test_list_files_returns_all_files(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An empty prefix lists every uploaded file, including nested ones."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")

        # Upload multiple files
        for remote_name in ("file1.txt", "file2.txt", "subdir/file3.txt"):
            storage.upload(sample_file, remote_name)

        listed = storage.list_files("")

        assert len(listed) == 3
        assert "file1.txt" in listed
        assert "file2.txt" in listed
        assert "subdir/file3.txt" in listed

    def test_list_files_with_prefix(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """A prefix restricts the listing to paths under that prefix."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        for remote_name in ("images/a.png", "images/b.png", "labels/a.txt"):
            storage.upload(sample_file, remote_name)

        listed = storage.list_files("images/")

        assert len(listed) == 2
        assert "images/a.png" in listed
        assert "images/b.png" in listed
        assert "labels/a.txt" not in listed

    def test_list_files_returns_sorted(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """The listing is sorted regardless of upload order."""
        from shared.storage.local import LocalStorageBackend

        storage = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        # Deliberately uploaded out of order.
        for remote_name in ("c.txt", "a.txt", "b.txt"):
            storage.upload(sample_file, remote_name)

        listed = storage.list_files("")

        assert listed == ["a.txt", "b.txt", "c.txt"]
class TestLocalStorageBackendDelete:
    """Tests for LocalStorageBackend.delete method."""

    def test_delete_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Deleting an uploaded file returns True and removes it from disk."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")

        outcome = store.delete("sample.txt")

        assert outcome is True
        assert not (root / "sample.txt").exists()

    def test_delete_nonexistent_file_returns_false(
        self, temp_storage_dir: Path
    ) -> None:
        """Deleting a path that was never uploaded returns False."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        outcome = store.delete("nonexistent.txt")

        assert outcome is False

    def test_delete_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Deleting works for files stored under nested directories."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "a/b/sample.txt")

        outcome = store.delete("a/b/sample.txt")

        assert outcome is True
        assert not (root / "a" / "b" / "sample.txt").exists()
class TestLocalStorageBackendGetUrl:
    """Tests for LocalStorageBackend.get_url method."""

    def test_get_url_returns_file_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """get_url yields something that locates the stored file."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        store.upload(sample_file, "sample.txt")

        located = store.get_url("sample.txt")

        # Accept either a file:// URL or an absolute path.
        assert "sample.txt" in located
        on_disk = root / "sample.txt"
        assert str(on_disk) in located or on_disk.as_uri() == located

    def test_get_url_nonexistent_file(self, temp_storage_dir: Path) -> None:
        """get_url raises FileNotFoundStorageError for a missing file."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            store.get_url("nonexistent.txt")
class TestLocalStorageBackendUploadBytes:
    """Tests for LocalStorageBackend.upload_bytes method."""

    def test_upload_bytes(self, temp_storage_dir: Path) -> None:
        """Raw bytes are written verbatim and the remote path is returned."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)
        payload = b"Binary content here"

        returned_path = store.upload_bytes(payload, "binary.dat")

        assert returned_path == "binary.dat"
        assert (root / "binary.dat").read_bytes() == payload

    def test_upload_bytes_creates_subdirectories(
        self, temp_storage_dir: Path
    ) -> None:
        """Missing parent directories are created for nested remote paths."""
        from shared.storage.local import LocalStorageBackend

        root = temp_storage_dir / "storage"
        store = LocalStorageBackend(base_path=root)

        store.upload_bytes(b"content", "a/b/c/file.dat")

        assert (root / "a" / "b" / "c" / "file.dat").exists()
class TestLocalStorageBackendDownloadBytes:
    """Tests for LocalStorageBackend.download_bytes method."""

    def test_download_bytes(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An uploaded file can be read back as its original bytes."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "sample.txt")

        payload = store.download_bytes("sample.txt")

        assert payload == b"Hello, World!"

    def test_download_bytes_nonexistent(self, temp_storage_dir: Path) -> None:
        """Reading a missing file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            store.download_bytes("nonexistent.txt")
class TestLocalStorageBackendSecurity:
    """Security tests for LocalStorageBackend - path traversal prevention.

    Each rejection test expects a ``StorageError`` whose message matches
    either "Path traversal not allowed" or "Absolute paths not allowed".
    """

    def test_path_traversal_with_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """upload() rejects a remote path that starts with ``../``."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload(sample_file, "../escape.txt")

    def test_path_traversal_with_nested_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """upload() rejects ``..`` segments that climb out via a subdirectory."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload(sample_file, "subdir/../../escape.txt")

    def test_path_traversal_with_many_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """upload() rejects a deeply nested path with more ``..`` than depth."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir / "storage")

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload(sample_file, "a/b/c/../../../../escape.txt")

    def test_absolute_path_unix_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """upload() rejects an absolute Unix-style remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Absolute paths not allowed"):
            backend.upload(sample_file, "/etc/passwd")

    def test_absolute_path_windows_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """upload() rejects an absolute Windows-style remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Absolute paths not allowed"):
            backend.upload(sample_file, "C:\\Windows\\System32\\config")

    def test_download_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """download() rejects a traversing remote path and writes nothing."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)
        # Keep the destination inside the per-test temp dir instead of a
        # hard-coded /tmp path: it is portable across platforms, and a
        # regression in the guard cannot leave files outside the sandbox.
        destination = temp_storage_dir / "file.txt"

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.download("../escape.txt", destination)

        assert not destination.exists()

    def test_exists_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """exists() rejects a traversing remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.exists("../escape.txt")

    def test_delete_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """delete() rejects a traversing remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.delete("../escape.txt")

    def test_get_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """get_url() rejects a traversing remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.get_url("../escape.txt")

    def test_upload_bytes_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """upload_bytes() rejects a traversing remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload_bytes(b"content", "../escape.txt")

    def test_download_bytes_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """download_bytes() rejects a traversing remote path."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.download_bytes("../escape.txt")

    def test_valid_nested_path_still_works(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Legitimate nested paths are still accepted alongside the guards."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # A deep path with no ".." segments must upload normally.
        result = backend.upload(sample_file, "a/b/c/d/file.txt")

        assert result == "a/b/c/d/file.txt"
        assert (storage_dir / "a" / "b" / "c" / "d" / "file.txt").exists()

View File

@@ -0,0 +1,158 @@
"""Tests for storage prefixes module."""
import pytest
from shared.storage.prefixes import PREFIXES, StoragePrefixes
class TestStoragePrefixes:
    """Tests for StoragePrefixes class."""

    # Names of every prefix constant checked on PREFIXES. Keeping them in
    # one place makes all four invariant tests cover the same set.
    _PREFIX_NAMES = (
        "DOCUMENTS",
        "IMAGES",
        "UPLOADS",
        "RESULTS",
        "EXPORTS",
        "DATASETS",
        "MODELS",
        "RAW_PDFS",
        "STRUCTURED_DATA",
        "ADMIN_IMAGES",
    )

    def test_prefixes_are_strings(self) -> None:
        """All prefix constants should be strings."""
        for name in self._PREFIX_NAMES:
            assert isinstance(getattr(PREFIXES, name), str), name

    def test_prefixes_are_non_empty(self) -> None:
        """All prefix constants should be non-empty."""
        for name in self._PREFIX_NAMES:
            assert getattr(PREFIXES, name), name

    def test_prefixes_have_no_leading_slash(self) -> None:
        """Prefixes should not start with a slash for portability."""
        for name in self._PREFIX_NAMES:
            assert not getattr(PREFIXES, name).startswith("/"), name

    def test_prefixes_have_no_trailing_slash(self) -> None:
        """Prefixes should not end with a slash."""
        for name in self._PREFIX_NAMES:
            assert not getattr(PREFIXES, name).endswith("/"), name

    def test_frozen_dataclass(self) -> None:
        """StoragePrefixes should be immutable."""
        # dataclasses.FrozenInstanceError is a subclass of AttributeError,
        # so this is narrower than a blanket Exception and will not mask
        # unrelated failures.
        with pytest.raises(AttributeError):
            PREFIXES.DOCUMENTS = "new_value"  # type: ignore
class TestDocumentPath:
    """Tests for document_path helper."""

    def test_document_path_with_extension(self) -> None:
        """An extension given with a leading dot is appended as-is."""
        assert PREFIXES.document_path("abc123", ".pdf") == "documents/abc123.pdf"

    def test_document_path_without_leading_dot(self) -> None:
        """An extension given without a dot produces the same path."""
        assert PREFIXES.document_path("abc123", "pdf") == "documents/abc123.pdf"

    def test_document_path_default_extension(self) -> None:
        """Omitting the extension yields a .pdf path."""
        assert PREFIXES.document_path("abc123") == "documents/abc123.pdf"
class TestImagePath:
    """Tests for image_path helper."""

    def test_image_path_basic(self) -> None:
        """Page 1 of a document maps to page_1.png under the document folder."""
        assert PREFIXES.image_path("doc123", 1) == "images/doc123/page_1.png"

    def test_image_path_page_number(self) -> None:
        """The page number is embedded in the file name."""
        assert PREFIXES.image_path("doc123", 5) == "images/doc123/page_5.png"

    def test_image_path_custom_extension(self) -> None:
        """A third argument replaces the .png extension."""
        assert PREFIXES.image_path("doc123", 1, ".jpg") == "images/doc123/page_1.jpg"
class TestUploadPath:
    """Tests for upload_path helper."""

    def test_upload_path_basic(self) -> None:
        """A bare file name lands directly under uploads/."""
        assert PREFIXES.upload_path("invoice.pdf") == "uploads/invoice.pdf"

    def test_upload_path_with_subfolder(self) -> None:
        """A subfolder argument is inserted between the prefix and the file name."""
        assert PREFIXES.upload_path("invoice.pdf", "async") == "uploads/async/invoice.pdf"
class TestResultPath:
    """Tests for result_path helper."""

    def test_result_path_basic(self) -> None:
        """A file name lands directly under results/."""
        assert PREFIXES.result_path("output.json") == "results/output.json"
class TestExportPath:
    """Tests for export_path helper."""

    def test_export_path_basic(self) -> None:
        """The export id becomes a folder under exports/."""
        assert PREFIXES.export_path("exp123", "dataset.zip") == "exports/exp123/dataset.zip"
class TestDatasetPath:
    """Tests for dataset_path helper."""

    def test_dataset_path_basic(self) -> None:
        """The dataset id becomes a folder under datasets/."""
        assert PREFIXES.dataset_path("ds123", "data.yaml") == "datasets/ds123/data.yaml"
class TestModelPath:
    """Tests for model_path helper."""

    def test_model_path_basic(self) -> None:
        """The version string becomes a folder under models/."""
        assert PREFIXES.model_path("v1.0.0", "best.pt") == "models/v1.0.0/best.pt"
class TestExportsFromInit:
    """Tests for exports from storage __init__.py."""

    def test_prefixes_exported(self) -> None:
        """shared.storage re-exports the very same PREFIXES object."""
        import shared.storage as storage_pkg

        assert storage_pkg.PREFIXES is PREFIXES

    def test_storage_prefixes_exported(self) -> None:
        """shared.storage re-exports the very same StoragePrefixes class."""
        import shared.storage as storage_pkg

        assert storage_pkg.StoragePrefixes is StoragePrefixes

View File

@@ -0,0 +1,264 @@
"""
Tests for pre-signed URL functionality across all storage backends.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def temp_storage_dir() -> Path:
    """Yield a fresh temporary directory and remove it after the test."""
    scratch = tempfile.mkdtemp()
    scratch_path = Path(scratch)
    yield scratch_path
    # Teardown: best-effort removal, errors are ignored.
    shutil.rmtree(scratch_path, ignore_errors=True)
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Write a small text file into the temp directory and return its path."""
    target = temp_storage_dir / "sample.txt"
    target.write_text("Hello, World!")
    return target
class TestStorageBackendInterfacePresignedUrl:
    """Tests for get_presigned_url in StorageBackend interface."""

    def test_subclass_must_implement_get_presigned_url(self) -> None:
        """Instantiating a subclass that lacks get_presigned_url raises TypeError."""
        from shared.storage.base import StorageBackend

        class IncompleteBackend(StorageBackend):
            # Deliberately omits get_presigned_url.

            def exists(self, remote_path: str) -> bool:
                return False

            def delete(self, remote_path: str) -> bool:
                return True

            def list_files(self, prefix: str) -> list[str]:
                return []

            def get_url(self, remote_path: str) -> str:
                return ""

            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

        with pytest.raises(TypeError):
            IncompleteBackend()  # type: ignore

    def test_valid_subclass_with_get_presigned_url_can_be_instantiated(self) -> None:
        """A subclass that also defines get_presigned_url instantiates cleanly."""
        from shared.storage.base import StorageBackend

        class CompleteBackend(StorageBackend):
            def exists(self, remote_path: str) -> bool:
                return False

            def delete(self, remote_path: str) -> bool:
                return True

            def list_files(self, prefix: str) -> list[str]:
                return []

            def get_url(self, remote_path: str) -> str:
                return ""

            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def get_presigned_url(
                self, remote_path: str, expires_in_seconds: int = 3600
            ) -> str:
                return f"https://example.com/{remote_path}?token=abc"

        instance = CompleteBackend()

        assert isinstance(instance, StorageBackend)
class TestLocalStorageBackendPresignedUrl:
    """Tests for LocalStorageBackend.get_presigned_url method."""

    def test_get_presigned_url_returns_file_uri(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """An existing file yields a file:// URI that names the file."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "sample.txt")

        link = store.get_presigned_url("sample.txt")

        assert link.startswith("file://")
        assert "sample.txt" in link

    def test_get_presigned_url_with_custom_expiry(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Passing expires_in_seconds is accepted and still yields a file:// URI."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "sample.txt")

        # The call must succeed with a non-default expiry.
        link = store.get_presigned_url("sample.txt", expires_in_seconds=7200)

        assert link.startswith("file://")

    def test_get_presigned_url_nonexistent_file_raises(
        self, temp_storage_dir: Path
    ) -> None:
        """A missing file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            store.get_presigned_url("nonexistent.txt")

    def test_get_presigned_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """A remote path containing ``../`` is rejected with StorageError."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            store.get_presigned_url("../escape.txt")

    def test_get_presigned_url_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Files stored under nested directories also yield a file:// URI."""
        from shared.storage.local import LocalStorageBackend

        store = LocalStorageBackend(base_path=temp_storage_dir / "storage")
        store.upload(sample_file, "a/b/c/sample.txt")

        link = store.get_presigned_url("a/b/c/sample.txt")

        assert link.startswith("file://")
        assert "sample.txt" in link
class TestAzureBlobStorageBackendPresignedUrl:
    """Tests for AzureBlobStorageBackend.get_presigned_url method."""

    @staticmethod
    def _make_backend(
        mock_blob_service_class: MagicMock,
        *,
        connection_string: str,
        blob_exists: bool,
    ):
        """Wire up service/container/blob mocks and return a backend built on them.

        Args:
            mock_blob_service_class: The patched ``BlobServiceClient`` class.
            connection_string: Connection string handed to the backend.
            blob_exists: Value the mocked blob client's ``exists()`` returns.

        Returns:
            An ``AzureBlobStorageBackend`` for container ``"container"``.
        """
        from shared.storage.azure import AzureBlobStorageBackend

        service = MagicMock()
        service.account_name = "testaccount"
        mock_blob_service_class.from_connection_string.return_value = service

        container = MagicMock()
        container.exists.return_value = True
        service.get_container_client.return_value = container

        blob = MagicMock()
        blob.exists.return_value = blob_exists
        blob.url = "https://testaccount.blob.core.windows.net/container/test.txt"
        container.get_blob_client.return_value = blob

        return AzureBlobStorageBackend(
            connection_string=connection_string,
            container_name="container",
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_generates_sas_url(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """Test get_presigned_url generates URL with SAS token."""
        backend = self._make_backend(
            mock_blob_service_class,
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            blob_exists=True,
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"
            url = backend.get_presigned_url("test.txt", expires_in_seconds=3600)

        assert "https://testaccount.blob.core.windows.net" in url
        # Both parts are required. Joining these with `or` would pass even
        # if the SAS token were never appended to the blob URL.
        assert "test.txt" in url
        assert "sv=2021-06-08" in url

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_nonexistent_blob_raises(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """Test get_presigned_url raises for nonexistent blob."""
        from shared.storage.base import FileNotFoundStorageError

        backend = self._make_backend(
            mock_blob_service_class,
            connection_string="DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key==;EndpointSuffix=core.windows.net",
            blob_exists=False,
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.get_presigned_url("nonexistent.txt")

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_uses_custom_expiry(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """Test get_presigned_url uses custom expiry time."""
        backend = self._make_backend(
            mock_blob_service_class,
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            blob_exists=True,
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"
            backend.get_presigned_url("test.txt", expires_in_seconds=7200)

        # Only confirms the SAS generator ran exactly once; the expiry
        # value passed to it is not inspected here.
        mock_generate_sas.assert_called_once()

View File

@@ -0,0 +1,520 @@
"""
Tests for S3StorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch, call
import pytest
@pytest.fixture
def temp_dir() -> Path:
    """Yield a fresh temporary directory and remove it after the test."""
    scratch = Path(tempfile.mkdtemp())
    yield scratch
    # Teardown: best-effort removal, errors are ignored.
    shutil.rmtree(scratch, ignore_errors=True)
@pytest.fixture
def sample_file(temp_dir: Path) -> Path:
    """Write a small text file into the temp directory and return its path."""
    target = temp_dir / "sample.txt"
    target.write_text("Hello, World!")
    return target
@pytest.fixture
def mock_boto3_client():
    """Patch ``boto3.client`` for the test and yield the client mock it returns."""
    fake_client = MagicMock()
    with patch("boto3.client") as client_factory:
        client_factory.return_value = fake_client
        yield fake_client
class TestS3StorageBackendCreation:
    """Tests for S3StorageBackend instantiation."""

    def test_create_with_bucket_name(self, mock_boto3_client: MagicMock) -> None:
        """The bucket name passed in is exposed as ``bucket_name``."""
        from shared.storage.s3 import S3StorageBackend

        store = S3StorageBackend(bucket_name="test-bucket")

        assert store.bucket_name == "test-bucket"

    def test_create_with_region(self, mock_boto3_client: MagicMock) -> None:
        """region_name is forwarded to boto3.client as a keyword argument."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                region_name="us-west-2",
            )

            client_factory.assert_called_once()
            forwarded = client_factory.call_args.kwargs
            assert forwarded.get("region_name") == "us-west-2"

    def test_create_with_credentials(self, mock_boto3_client: MagicMock) -> None:
        """Explicit credentials are forwarded under boto3's keyword names."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                access_key_id="AKIATEST",
                secret_access_key="secret123",
            )

            client_factory.assert_called_once()
            forwarded = client_factory.call_args.kwargs
            assert forwarded.get("aws_access_key_id") == "AKIATEST"
            assert forwarded.get("aws_secret_access_key") == "secret123"

    def test_create_with_endpoint_url(self, mock_boto3_client: MagicMock) -> None:
        """A custom endpoint (for S3-compatible services) is forwarded to boto3.client."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as client_factory:
            S3StorageBackend(
                bucket_name="test-bucket",
                endpoint_url="http://localhost:9000",
            )

            client_factory.assert_called_once()
            forwarded = client_factory.call_args.kwargs
            assert forwarded.get("endpoint_url") == "http://localhost:9000"

    def test_create_bucket_when_requested(self, mock_boto3_client: MagicMock) -> None:
        """With create_bucket=True and a 404 from head_bucket, create_bucket is called."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # Simulate a missing bucket.
        not_found = ClientError({"Error": {"Code": "404"}}, "HeadBucket")
        mock_boto3_client.head_bucket.side_effect = not_found

        S3StorageBackend(
            bucket_name="test-bucket",
            create_bucket=True,
        )

        mock_boto3_client.create_bucket.assert_called_once()

    def test_is_storage_backend_subclass(self, mock_boto3_client: MagicMock) -> None:
        """An S3StorageBackend instance is a StorageBackend."""
        from shared.storage.base import StorageBackend
        from shared.storage.s3 import S3StorageBackend

        store = S3StorageBackend(bucket_name="test-bucket")

        assert isinstance(store, StorageBackend)
class TestS3StorageBackendUpload:
    """Tests for S3StorageBackend.upload method."""

    def test_upload_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path, sample_file: Path
    ) -> None:
        """Uploading to a key that does not exist returns the key and calls upload_file once."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # head_object raising 404 simulates "no such object".
        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found
        store = S3StorageBackend(bucket_name="test-bucket")

        returned_key = store.upload(sample_file, "uploads/sample.txt")

        assert returned_key == "uploads/sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """An existing object plus overwrite=False raises StorageError."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        # A successful head_object simulates an existing object.
        mock_boto3_client.head_object.return_value = {}
        store = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(StorageError, match="already exists"):
            store.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """An existing object plus overwrite=True uploads anyway."""
        from shared.storage.s3 import S3StorageBackend

        # A successful head_object simulates an existing object.
        mock_boto3_client.head_object.return_value = {}
        store = S3StorageBackend(bucket_name="test-bucket")

        returned_key = store.upload(sample_file, "sample.txt", overwrite=True)

        assert returned_key == "sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_nonexistent_file_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """A missing local source file raises FileNotFoundStorageError."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        store = S3StorageBackend(bucket_name="test-bucket")
        missing_source = temp_dir / "nonexistent.txt"

        with pytest.raises(FileNotFoundStorageError):
            store.upload(missing_source, "sample.txt")
class TestS3StorageBackendDownload:
    """Tests for S3StorageBackend.download method."""

    def test_download_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Downloading an existing object returns the local path and calls download_file once."""
        from shared.storage.s3 import S3StorageBackend

        # A successful head_object simulates an existing object.
        mock_boto3_client.head_object.return_value = {}
        store = S3StorageBackend(bucket_name="test-bucket")
        destination = temp_dir / "downloaded.txt"

        returned_path = store.download("sample.txt", destination)

        assert returned_path == destination
        mock_boto3_client.download_file.assert_called_once()

    def test_download_creates_parent_directories(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Missing parent directories of the destination are created."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        store = S3StorageBackend(bucket_name="test-bucket")
        destination = temp_dir / "deep" / "nested" / "downloaded.txt"

        store.download("sample.txt", destination)

        assert destination.parent.exists()

    def test_download_nonexistent_object_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """A 404 from head_object surfaces as FileNotFoundStorageError."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found
        store = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            store.download("nonexistent.txt", temp_dir / "file.txt")
class TestS3StorageBackendExists:
    """Tests for S3StorageBackend.exists method."""

    def test_exists_returns_true_for_existing_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A successful head_object means the object exists."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}
        store = S3StorageBackend(bucket_name="test-bucket")

        found = store.exists("sample.txt")

        assert found is True

    def test_exists_returns_false_for_nonexistent_object(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """A 404 ClientError from head_object means the object is absent."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found
        store = S3StorageBackend(bucket_name="test-bucket")

        found = store.exists("nonexistent.txt")

        assert found is False
class TestS3StorageBackendListFiles:
    """Tests for S3StorageBackend.list_files method."""

    def test_list_files_returns_objects(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Test listing objects."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "file1.txt"},
                {"Key": "file2.txt"},
                {"Key": "subdir/file3.txt"},
            ]
        }
        backend = S3StorageBackend(bucket_name="test-bucket")

        files = backend.list_files("")

        assert len(files) == 3
        assert "file1.txt" in files
        assert "file2.txt" in files
        assert "subdir/file3.txt" in files

    def test_list_files_with_prefix(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Test listing objects with prefix filter."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "images/a.png"},
                {"Key": "images/b.png"},
            ]
        }
        backend = S3StorageBackend(bucket_name="test-bucket")

        files = backend.list_files("images/")

        mock_boto3_client.list_objects_v2.assert_called_with(
            Bucket="test-bucket", Prefix="images/"
        )
        # The result was previously assigned but never checked. Compare
        # order-insensitively, since ordering is not pinned by these tests.
        assert sorted(files) == ["images/a.png", "images/b.png"]

    def test_list_files_empty_bucket(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """Test listing files in empty bucket."""
        from shared.storage.s3 import S3StorageBackend

        # A response with no "Contents" key is how the mock models an empty bucket.
        mock_boto3_client.list_objects_v2.return_value = {}
        backend = S3StorageBackend(bucket_name="test-bucket")

        files = backend.list_files("")

        assert files == []
class TestS3StorageBackendDelete:
    """Tests covering S3StorageBackend.delete."""

    def test_delete_existing_object(self, mock_boto3_client: MagicMock) -> None:
        """delete() returns True and calls delete_object once for a present key."""
        from shared.storage.s3 import S3StorageBackend

        # head_object succeeding marks the object as present.
        mock_boto3_client.head_object.return_value = {}

        storage = S3StorageBackend(bucket_name="test-bucket")
        deleted = storage.delete("sample.txt")

        assert deleted is True
        mock_boto3_client.delete_object.assert_called_once()

    def test_delete_nonexistent_object_returns_false(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """delete() returns False when head_object raises a 404 ClientError."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found

        storage = S3StorageBackend(bucket_name="test-bucket")
        deleted = storage.delete("nonexistent.txt")

        assert deleted is False
class TestS3StorageBackendGetUrl:
    """Tests covering S3StorageBackend.get_url."""

    def test_get_url_returns_s3_url(self, mock_boto3_client: MagicMock) -> None:
        """get_url() yields a URL that mentions the requested object name."""
        from shared.storage.s3 import S3StorageBackend

        stub_url = "https://test-bucket.s3.amazonaws.com/sample.txt"
        mock_boto3_client.head_object.return_value = {}
        mock_boto3_client.generate_presigned_url.return_value = stub_url

        storage = S3StorageBackend(bucket_name="test-bucket")
        url = storage.get_url("sample.txt")

        assert "sample.txt" in url

    def test_get_url_nonexistent_object_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """get_url() raises FileNotFoundStorageError when head_object reports 404."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found

        storage = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            storage.get_url("nonexistent.txt")
class TestS3StorageBackendUploadBytes:
    """Tests covering S3StorageBackend.upload_bytes."""

    def test_upload_bytes(self, mock_boto3_client: MagicMock) -> None:
        """upload_bytes() returns the target key and calls put_object exactly once."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # head_object raises a 404 ClientError for this scenario.
        not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
        mock_boto3_client.head_object.side_effect = not_found

        storage = S3StorageBackend(bucket_name="test-bucket")
        payload = b"Binary content here"
        stored_key = storage.upload_bytes(payload, "binary.dat")

        assert stored_key == "binary.dat"
        mock_boto3_client.put_object.assert_called_once()

    def test_upload_bytes_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """upload_bytes() raises StorageError for an existing key with overwrite=False."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # Object exists

        storage = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(StorageError, match="already exists"):
            storage.upload_bytes(b"content", "sample.txt", overwrite=False)
class TestS3StorageBackendDownloadBytes:
    """Tests covering S3StorageBackend.download_bytes."""

    def test_download_bytes(self, mock_boto3_client: MagicMock) -> None:
        """download_bytes() returns what the mocked get_object "Body" reads out."""
        from shared.storage.s3 import S3StorageBackend

        # Stand-in for the streaming body returned under the "Body" key.
        body_stream = MagicMock()
        body_stream.read.return_value = b"Hello, World!"
        mock_boto3_client.get_object.return_value = {"Body": body_stream}

        storage = S3StorageBackend(bucket_name="test-bucket")
        payload = storage.download_bytes("sample.txt")

        assert payload == b"Hello, World!"

    def test_download_bytes_nonexistent_raises(
        self, mock_boto3_client: MagicMock
    ) -> None:
        """download_bytes() raises FileNotFoundStorageError on a NoSuchKey error."""
        from botocore.exceptions import ClientError
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        no_such_key = ClientError({"Error": {"Code": "NoSuchKey"}}, "GetObject")
        mock_boto3_client.get_object.side_effect = no_such_key

        storage = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            storage.download_bytes("nonexistent.txt")
class TestS3StorageBackendPresignedUrl:
"""Tests for S3StorageBackend.get_presigned_url method."""
def test_get_presigned_url_generates_url(self, mock_boto3_client: MagicMock) -> None:
    """get_presigned_url() returns a URL and calls generate_presigned_url once."""
    from shared.storage.s3 import S3StorageBackend

    stub_url = "https://test-bucket.s3.amazonaws.com/sample.txt?X-Amz-Algorithm=..."
    mock_boto3_client.head_object.return_value = {}
    mock_boto3_client.generate_presigned_url.return_value = stub_url

    storage = S3StorageBackend(bucket_name="test-bucket")
    url = storage.get_presigned_url("sample.txt")

    # Either marker is accepted: the signing parameter or the object name.
    assert any(marker in url for marker in ("X-Amz-Algorithm", "sample.txt"))
    mock_boto3_client.generate_presigned_url.assert_called_once()
def test_get_presigned_url_with_custom_expiry(
    self, mock_boto3_client: MagicMock
) -> None:
    """expires_in_seconds is passed to generate_presigned_url as ExpiresIn."""
    from shared.storage.s3 import S3StorageBackend

    mock_boto3_client.head_object.return_value = {}
    mock_boto3_client.generate_presigned_url.return_value = "https://..."

    storage = S3StorageBackend(bucket_name="test-bucket")
    storage.get_presigned_url("sample.txt", expires_in_seconds=7200)

    # call_args unpacks into (positional args, keyword args) of the last call.
    _positional, keyword_args = mock_boto3_client.generate_presigned_url.call_args
    assert keyword_args.get("ExpiresIn") == 7200
def test_get_presigned_url_nonexistent_raises(
    self, mock_boto3_client: MagicMock
) -> None:
    """get_presigned_url() raises FileNotFoundStorageError when head_object reports 404."""
    from botocore.exceptions import ClientError
    from shared.storage.base import FileNotFoundStorageError
    from shared.storage.s3 import S3StorageBackend

    not_found = ClientError({"Error": {"Code": "404"}}, "HeadObject")
    mock_boto3_client.head_object.side_effect = not_found

    storage = S3StorageBackend(bucket_name="test-bucket")

    with pytest.raises(FileNotFoundStorageError):
        storage.get_presigned_url("nonexistent.txt")