""" Tests for storage factory. TDD Phase 1: RED - Write tests first, then implement to pass. """ import os import tempfile from pathlib import Path from unittest.mock import MagicMock, patch import pytest class TestStorageFactory: """Tests for create_storage_backend factory function.""" def test_create_local_backend(self) -> None: """Test creating local storage backend.""" from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend from shared.storage.local import LocalStorageBackend with tempfile.TemporaryDirectory() as temp_dir: config = StorageConfig( backend_type="local", base_path=Path(temp_dir), ) backend = create_storage_backend(config) assert isinstance(backend, LocalStorageBackend) assert backend.base_path == Path(temp_dir) @patch("shared.storage.azure.BlobServiceClient") def test_create_azure_backend(self, mock_service_class: MagicMock) -> None: """Test creating Azure blob storage backend.""" from shared.storage.azure import AzureBlobStorageBackend from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend config = StorageConfig( backend_type="azure_blob", connection_string="DefaultEndpointsProtocol=https;...", container_name="training-images", ) backend = create_storage_backend(config) assert isinstance(backend, AzureBlobStorageBackend) def test_create_unknown_backend_raises(self) -> None: """Test that unknown backend type raises ValueError.""" from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend config = StorageConfig(backend_type="unknown_backend") with pytest.raises(ValueError, match="Unknown storage backend"): create_storage_backend(config) def test_create_local_requires_base_path(self) -> None: """Test that local backend requires base_path.""" from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend config = StorageConfig(backend_type="local") with pytest.raises(ValueError, match="base_path"): create_storage_backend(config) def test_create_azure_requires_connection_string(self) -> None: """Test that Azure backend requires connection_string.""" from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend config = StorageConfig( backend_type="azure_blob", container_name="container", ) with pytest.raises(ValueError, match="connection_string"): create_storage_backend(config) def test_create_azure_requires_container_name(self) -> None: """Test that Azure backend requires container_name.""" from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend config = StorageConfig( backend_type="azure_blob", connection_string="connection_string", ) with pytest.raises(ValueError, match="container_name"): create_storage_backend(config) class TestStorageFactoryFromEnv: """Tests for create_storage_backend_from_env factory function.""" def test_create_from_env_local(self) -> None: """Test creating local backend from environment variables.""" from shared.storage.factory import create_storage_backend_from_env from shared.storage.local import LocalStorageBackend with tempfile.TemporaryDirectory() as temp_dir: env = { "STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": temp_dir, } with patch.dict(os.environ, env, clear=False): backend = create_storage_backend_from_env() assert isinstance(backend, LocalStorageBackend) @patch("shared.storage.azure.BlobServiceClient") def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None: """Test creating Azure backend from environment variables.""" from shared.storage.azure import AzureBlobStorageBackend from shared.storage.factory import create_storage_backend_from_env env = { "STORAGE_BACKEND": "azure_blob", "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...", "AZURE_STORAGE_CONTAINER": "training-images", } with patch.dict(os.environ, env, clear=False): backend = create_storage_backend_from_env() assert isinstance(backend, AzureBlobStorageBackend) def test_create_from_env_defaults_to_local(self) -> None: """Test that factory defaults to local backend.""" from shared.storage.factory import create_storage_backend_from_env from shared.storage.local import LocalStorageBackend with tempfile.TemporaryDirectory() as temp_dir: env = { "STORAGE_BASE_PATH": temp_dir, } # Remove STORAGE_BACKEND if present with patch.dict(os.environ, env, clear=False): if "STORAGE_BACKEND" in os.environ: del os.environ["STORAGE_BACKEND"] backend = create_storage_backend_from_env() assert isinstance(backend, LocalStorageBackend) def test_create_from_env_missing_azure_vars(self) -> None: """Test error when Azure env vars are missing.""" from shared.storage.factory import create_storage_backend_from_env env = { "STORAGE_BACKEND": "azure_blob", # Missing AZURE_STORAGE_CONNECTION_STRING } with patch.dict(os.environ, env, clear=False): # Remove the connection string if present if "AZURE_STORAGE_CONNECTION_STRING" in os.environ: del os.environ["AZURE_STORAGE_CONNECTION_STRING"] with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"): create_storage_backend_from_env() class TestGetDefaultStorageConfig: """Tests for get_default_storage_config function.""" def test_get_default_config_local(self) -> None: """Test getting default local config.""" from shared.storage.factory import get_default_storage_config with tempfile.TemporaryDirectory() as temp_dir: env = { "STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": temp_dir, } with patch.dict(os.environ, env, clear=False): config = get_default_storage_config() assert config.backend_type == "local" assert config.base_path == Path(temp_dir) def test_get_default_config_azure(self) -> None: """Test getting default Azure config.""" from shared.storage.factory import get_default_storage_config env = { "STORAGE_BACKEND": "azure_blob", "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...", "AZURE_STORAGE_CONTAINER": "training-images", } with patch.dict(os.environ, env, clear=False): config = get_default_storage_config() assert config.backend_type == "azure_blob" assert config.connection_string == "DefaultEndpointsProtocol=https;..." assert config.container_name == "training-images" class TestStorageFactoryS3: """Tests for S3 backend support in factory.""" @patch("boto3.client") def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None: """Test creating S3 storage backend.""" from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend from shared.storage.s3 import S3StorageBackend config = StorageConfig( backend_type="s3", bucket_name="test-bucket", region_name="us-west-2", ) backend = create_storage_backend(config) assert isinstance(backend, S3StorageBackend) def test_create_s3_requires_bucket_name(self) -> None: """Test that S3 backend requires bucket_name.""" from shared.storage.base import StorageConfig from shared.storage.factory import create_storage_backend config = StorageConfig( backend_type="s3", region_name="us-west-2", ) with pytest.raises(ValueError, match="bucket_name"): create_storage_backend(config) @patch("boto3.client") def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None: """Test creating S3 backend from environment variables.""" from shared.storage.factory import create_storage_backend_from_env from shared.storage.s3 import S3StorageBackend env = { "STORAGE_BACKEND": "s3", "AWS_S3_BUCKET": "test-bucket", "AWS_REGION": "us-east-1", } with patch.dict(os.environ, env, clear=False): backend = create_storage_backend_from_env() assert isinstance(backend, S3StorageBackend) def test_create_from_env_s3_missing_bucket(self) -> None: """Test error when S3 bucket env var is missing.""" from shared.storage.factory import create_storage_backend_from_env env = { "STORAGE_BACKEND": "s3", # Missing AWS_S3_BUCKET } with patch.dict(os.environ, env, clear=False): if "AWS_S3_BUCKET" in os.environ: del os.environ["AWS_S3_BUCKET"] with pytest.raises(ValueError, match="AWS_S3_BUCKET"): create_storage_backend_from_env() def test_get_default_config_s3(self) -> None: """Test getting default S3 config.""" from shared.storage.factory import get_default_storage_config env = { "STORAGE_BACKEND": "s3", "AWS_S3_BUCKET": "test-bucket", "AWS_REGION": "us-west-2", "AWS_ENDPOINT_URL": "http://localhost:9000", } with patch.dict(os.environ, env, clear=False): config = get_default_storage_config() assert config.backend_type == "s3" assert config.bucket_name == "test-bucket" assert config.region_name == "us-west-2" assert config.endpoint_url == "http://localhost:9000" class TestStorageFactoryFromFile: """Tests for create_storage_backend_from_file factory function.""" def test_create_from_yaml_file_local(self, tmp_path: Path) -> None: """Test creating local backend from YAML config file.""" from shared.storage.factory import create_storage_backend_from_file from shared.storage.local import LocalStorageBackend config_file = tmp_path / "storage.yaml" storage_path = tmp_path / "storage" config_file.write_text(f""" backend: local local: base_path: {storage_path} """) backend = create_storage_backend_from_file(config_file) assert isinstance(backend, LocalStorageBackend) @patch("shared.storage.azure.BlobServiceClient") def test_create_from_yaml_file_azure( self, mock_service_class: MagicMock, tmp_path: Path ) -> None: """Test creating Azure backend from YAML config file.""" from shared.storage.azure import AzureBlobStorageBackend from shared.storage.factory import create_storage_backend_from_file config_file = tmp_path / "storage.yaml" config_file.write_text(""" backend: azure_blob azure: connection_string: DefaultEndpointsProtocol=https;AccountName=test container_name: documents """) backend = create_storage_backend_from_file(config_file) assert isinstance(backend, AzureBlobStorageBackend) @patch("boto3.client") def test_create_from_yaml_file_s3( self, mock_boto3_client: MagicMock, tmp_path: Path ) -> None: """Test creating S3 backend from YAML config file.""" from shared.storage.factory import create_storage_backend_from_file from shared.storage.s3 import S3StorageBackend config_file = tmp_path / "storage.yaml" config_file.write_text(""" backend: s3 s3: bucket_name: my-bucket region_name: us-east-1 """) backend = create_storage_backend_from_file(config_file) assert isinstance(backend, S3StorageBackend) def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None: """Test that env vars are substituted in config file.""" from shared.storage.factory import create_storage_backend_from_file from shared.storage.local import LocalStorageBackend config_file = tmp_path / "storage.yaml" storage_path = tmp_path / "storage" config_file.write_text(""" backend: ${STORAGE_BACKEND:-local} local: base_path: ${CUSTOM_STORAGE_PATH} """) with patch.dict( os.environ, {"STORAGE_BACKEND": "local", "CUSTOM_STORAGE_PATH": str(storage_path)}, ): backend = create_storage_backend_from_file(config_file) assert isinstance(backend, LocalStorageBackend) def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None: """Test that FileNotFoundError is raised for missing file.""" from shared.storage.factory import create_storage_backend_from_file with pytest.raises(FileNotFoundError): create_storage_backend_from_file(tmp_path / "nonexistent.yaml") class TestGetStorageBackend: """Tests for get_storage_backend convenience function.""" def test_get_storage_backend_from_file(self, tmp_path: Path) -> None: """Test getting backend from explicit config file.""" from shared.storage.factory import get_storage_backend from shared.storage.local import LocalStorageBackend config_file = tmp_path / "storage.yaml" storage_path = tmp_path / "storage" config_file.write_text(f""" backend: local local: base_path: {storage_path} """) backend = get_storage_backend(config_path=config_file) assert isinstance(backend, LocalStorageBackend) def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None: """Test that get_storage_backend falls back to env vars.""" from shared.storage.factory import get_storage_backend from shared.storage.local import LocalStorageBackend storage_path = tmp_path / "storage" env = { "STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": str(storage_path), } with patch.dict(os.environ, env, clear=False): # No config file provided, should use env vars backend = get_storage_backend(config_path=None) assert isinstance(backend, LocalStorageBackend)