This commit is contained in:
Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -0,0 +1,423 @@
"""
Tests for storage factory.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
class TestStorageFactory:
"""Tests for create_storage_backend factory function."""
def test_create_local_backend(self) -> None:
"""Test creating local storage backend."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
from shared.storage.local import LocalStorageBackend
with tempfile.TemporaryDirectory() as temp_dir:
config = StorageConfig(
backend_type="local",
base_path=Path(temp_dir),
)
backend = create_storage_backend(config)
assert isinstance(backend, LocalStorageBackend)
assert backend.base_path == Path(temp_dir)
@patch("shared.storage.azure.BlobServiceClient")
def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
"""Test creating Azure blob storage backend."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="azure_blob",
connection_string="DefaultEndpointsProtocol=https;...",
container_name="training-images",
)
backend = create_storage_backend(config)
assert isinstance(backend, AzureBlobStorageBackend)
def test_create_unknown_backend_raises(self) -> None:
"""Test that unknown backend type raises ValueError."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(backend_type="unknown_backend")
with pytest.raises(ValueError, match="Unknown storage backend"):
create_storage_backend(config)
def test_create_local_requires_base_path(self) -> None:
"""Test that local backend requires base_path."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(backend_type="local")
with pytest.raises(ValueError, match="base_path"):
create_storage_backend(config)
def test_create_azure_requires_connection_string(self) -> None:
"""Test that Azure backend requires connection_string."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="azure_blob",
container_name="container",
)
with pytest.raises(ValueError, match="connection_string"):
create_storage_backend(config)
def test_create_azure_requires_container_name(self) -> None:
"""Test that Azure backend requires container_name."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="azure_blob",
connection_string="connection_string",
)
with pytest.raises(ValueError, match="container_name"):
create_storage_backend(config)
class TestStorageFactoryFromEnv:
"""Tests for create_storage_backend_from_env factory function."""
def test_create_from_env_local(self) -> None:
"""Test creating local backend from environment variables."""
from shared.storage.factory import create_storage_backend_from_env
from shared.storage.local import LocalStorageBackend
with tempfile.TemporaryDirectory() as temp_dir:
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": temp_dir,
}
with patch.dict(os.environ, env, clear=False):
backend = create_storage_backend_from_env()
assert isinstance(backend, LocalStorageBackend)
@patch("shared.storage.azure.BlobServiceClient")
def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
"""Test creating Azure backend from environment variables."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.factory import create_storage_backend_from_env
env = {
"STORAGE_BACKEND": "azure_blob",
"AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
"AZURE_STORAGE_CONTAINER": "training-images",
}
with patch.dict(os.environ, env, clear=False):
backend = create_storage_backend_from_env()
assert isinstance(backend, AzureBlobStorageBackend)
def test_create_from_env_defaults_to_local(self) -> None:
"""Test that factory defaults to local backend."""
from shared.storage.factory import create_storage_backend_from_env
from shared.storage.local import LocalStorageBackend
with tempfile.TemporaryDirectory() as temp_dir:
env = {
"STORAGE_BASE_PATH": temp_dir,
}
# Remove STORAGE_BACKEND if present
with patch.dict(os.environ, env, clear=False):
if "STORAGE_BACKEND" in os.environ:
del os.environ["STORAGE_BACKEND"]
backend = create_storage_backend_from_env()
assert isinstance(backend, LocalStorageBackend)
def test_create_from_env_missing_azure_vars(self) -> None:
"""Test error when Azure env vars are missing."""
from shared.storage.factory import create_storage_backend_from_env
env = {
"STORAGE_BACKEND": "azure_blob",
# Missing AZURE_STORAGE_CONNECTION_STRING
}
with patch.dict(os.environ, env, clear=False):
# Remove the connection string if present
if "AZURE_STORAGE_CONNECTION_STRING" in os.environ:
del os.environ["AZURE_STORAGE_CONNECTION_STRING"]
with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
create_storage_backend_from_env()
class TestGetDefaultStorageConfig:
"""Tests for get_default_storage_config function."""
def test_get_default_config_local(self) -> None:
"""Test getting default local config."""
from shared.storage.factory import get_default_storage_config
with tempfile.TemporaryDirectory() as temp_dir:
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": temp_dir,
}
with patch.dict(os.environ, env, clear=False):
config = get_default_storage_config()
assert config.backend_type == "local"
assert config.base_path == Path(temp_dir)
def test_get_default_config_azure(self) -> None:
"""Test getting default Azure config."""
from shared.storage.factory import get_default_storage_config
env = {
"STORAGE_BACKEND": "azure_blob",
"AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
"AZURE_STORAGE_CONTAINER": "training-images",
}
with patch.dict(os.environ, env, clear=False):
config = get_default_storage_config()
assert config.backend_type == "azure_blob"
assert config.connection_string == "DefaultEndpointsProtocol=https;..."
assert config.container_name == "training-images"
class TestStorageFactoryS3:
"""Tests for S3 backend support in factory."""
@patch("boto3.client")
def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
"""Test creating S3 storage backend."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
from shared.storage.s3 import S3StorageBackend
config = StorageConfig(
backend_type="s3",
bucket_name="test-bucket",
region_name="us-west-2",
)
backend = create_storage_backend(config)
assert isinstance(backend, S3StorageBackend)
def test_create_s3_requires_bucket_name(self) -> None:
"""Test that S3 backend requires bucket_name."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="s3",
region_name="us-west-2",
)
with pytest.raises(ValueError, match="bucket_name"):
create_storage_backend(config)
@patch("boto3.client")
def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
"""Test creating S3 backend from environment variables."""
from shared.storage.factory import create_storage_backend_from_env
from shared.storage.s3 import S3StorageBackend
env = {
"STORAGE_BACKEND": "s3",
"AWS_S3_BUCKET": "test-bucket",
"AWS_REGION": "us-east-1",
}
with patch.dict(os.environ, env, clear=False):
backend = create_storage_backend_from_env()
assert isinstance(backend, S3StorageBackend)
def test_create_from_env_s3_missing_bucket(self) -> None:
"""Test error when S3 bucket env var is missing."""
from shared.storage.factory import create_storage_backend_from_env
env = {
"STORAGE_BACKEND": "s3",
# Missing AWS_S3_BUCKET
}
with patch.dict(os.environ, env, clear=False):
if "AWS_S3_BUCKET" in os.environ:
del os.environ["AWS_S3_BUCKET"]
with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
create_storage_backend_from_env()
def test_get_default_config_s3(self) -> None:
"""Test getting default S3 config."""
from shared.storage.factory import get_default_storage_config
env = {
"STORAGE_BACKEND": "s3",
"AWS_S3_BUCKET": "test-bucket",
"AWS_REGION": "us-west-2",
"AWS_ENDPOINT_URL": "http://localhost:9000",
}
with patch.dict(os.environ, env, clear=False):
config = get_default_storage_config()
assert config.backend_type == "s3"
assert config.bucket_name == "test-bucket"
assert config.region_name == "us-west-2"
assert config.endpoint_url == "http://localhost:9000"
class TestStorageFactoryFromFile:
"""Tests for create_storage_backend_from_file factory function."""
def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
"""Test creating local backend from YAML config file."""
from shared.storage.factory import create_storage_backend_from_file
from shared.storage.local import LocalStorageBackend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text(f"""
backend: local
local:
base_path: {storage_path}
""")
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, LocalStorageBackend)
@patch("shared.storage.azure.BlobServiceClient")
def test_create_from_yaml_file_azure(
self, mock_service_class: MagicMock, tmp_path: Path
) -> None:
"""Test creating Azure backend from YAML config file."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.factory import create_storage_backend_from_file
config_file = tmp_path / "storage.yaml"
config_file.write_text("""
backend: azure_blob
azure:
connection_string: DefaultEndpointsProtocol=https;AccountName=test
container_name: documents
""")
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, AzureBlobStorageBackend)
@patch("boto3.client")
def test_create_from_yaml_file_s3(
self, mock_boto3_client: MagicMock, tmp_path: Path
) -> None:
"""Test creating S3 backend from YAML config file."""
from shared.storage.factory import create_storage_backend_from_file
from shared.storage.s3 import S3StorageBackend
config_file = tmp_path / "storage.yaml"
config_file.write_text("""
backend: s3
s3:
bucket_name: my-bucket
region_name: us-east-1
""")
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, S3StorageBackend)
def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
"""Test that env vars are substituted in config file."""
from shared.storage.factory import create_storage_backend_from_file
from shared.storage.local import LocalStorageBackend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text("""
backend: ${STORAGE_BACKEND:-local}
local:
base_path: ${CUSTOM_STORAGE_PATH}
""")
with patch.dict(
os.environ,
{"STORAGE_BACKEND": "local", "CUSTOM_STORAGE_PATH": str(storage_path)},
):
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, LocalStorageBackend)
def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
"""Test that FileNotFoundError is raised for missing file."""
from shared.storage.factory import create_storage_backend_from_file
with pytest.raises(FileNotFoundError):
create_storage_backend_from_file(tmp_path / "nonexistent.yaml")
class TestGetStorageBackend:
"""Tests for get_storage_backend convenience function."""
def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
"""Test getting backend from explicit config file."""
from shared.storage.factory import get_storage_backend
from shared.storage.local import LocalStorageBackend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text(f"""
backend: local
local:
base_path: {storage_path}
""")
backend = get_storage_backend(config_path=config_file)
assert isinstance(backend, LocalStorageBackend)
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
"""Test that get_storage_backend falls back to env vars."""
from shared.storage.factory import get_storage_backend
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(storage_path),
}
with patch.dict(os.environ, env, clear=False):
# No config file provided, should use env vars
backend = get_storage_backend(config_path=None)
assert isinstance(backend, LocalStorageBackend)