424 lines
15 KiB
Python
424 lines
15 KiB
Python
"""
Tests for storage factory.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""
|
|
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestStorageFactory:
    """Exercise ``create_storage_backend`` with hand-built ``StorageConfig`` objects."""

    def test_create_local_backend(self) -> None:
        """A ``local`` config produces a LocalStorageBackend with the given base_path."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            root = Path(temp_dir)
            local_backend = create_storage_backend(
                StorageConfig(backend_type="local", base_path=root)
            )

            assert isinstance(local_backend, LocalStorageBackend)
            assert local_backend.base_path == root

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
        """An ``azure_blob`` config produces an AzureBlobStorageBackend.

        ``BlobServiceClient`` is patched for the duration of the test; the
        mock itself is not inspected.
        """
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        azure_config = StorageConfig(
            backend_type="azure_blob",
            connection_string="DefaultEndpointsProtocol=https;...",
            container_name="training-images",
        )

        assert isinstance(
            create_storage_backend(azure_config), AzureBlobStorageBackend
        )

    def test_create_unknown_backend_raises(self) -> None:
        """An unrecognised backend_type is rejected with ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        # Built outside the ``raises`` block so only the factory call is checked.
        bogus_config = StorageConfig(backend_type="unknown_backend")

        with pytest.raises(ValueError, match="Unknown storage backend"):
            create_storage_backend(bogus_config)

    def test_create_local_requires_base_path(self) -> None:
        """A ``local`` config without base_path is rejected with ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        pathless_config = StorageConfig(backend_type="local")

        with pytest.raises(ValueError, match="base_path"):
            create_storage_backend(pathless_config)

    def test_create_azure_requires_connection_string(self) -> None:
        """An ``azure_blob`` config without connection_string is rejected."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(
            backend_type="azure_blob",
            container_name="container",
        )

        with pytest.raises(ValueError, match="connection_string"):
            create_storage_backend(incomplete_config)

    def test_create_azure_requires_container_name(self) -> None:
        """An ``azure_blob`` config without container_name is rejected."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        incomplete_config = StorageConfig(
            backend_type="azure_blob",
            connection_string="connection_string",
        )

        with pytest.raises(ValueError, match="container_name"):
            create_storage_backend(incomplete_config)
|
|
|
|
|
|
class TestStorageFactoryFromEnv:
    """Exercise ``create_storage_backend_from_env`` under patched environments."""

    def test_create_from_env_local(self) -> None:
        """STORAGE_BACKEND=local plus STORAGE_BASE_PATH gives a local backend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {"STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": temp_dir}

            with patch.dict(os.environ, overrides, clear=False):
                result = create_storage_backend_from_env()

            assert isinstance(result, LocalStorageBackend)

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
        """Azure env vars give an AzureBlobStorageBackend (client is patched)."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_env

        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
            "AZURE_STORAGE_CONTAINER": "training-images",
        }

        with patch.dict(os.environ, overrides, clear=False):
            result = create_storage_backend_from_env()

        assert isinstance(result, AzureBlobStorageBackend)

    def test_create_from_env_defaults_to_local(self) -> None:
        """With STORAGE_BACKEND unset, the factory returns a local backend."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as temp_dir:
            with patch.dict(os.environ, {"STORAGE_BASE_PATH": temp_dir}, clear=False):
                # Drop any ambient STORAGE_BACKEND so the default path is taken.
                os.environ.pop("STORAGE_BACKEND", None)
                result = create_storage_backend_from_env()

            assert isinstance(result, LocalStorageBackend)

    def test_create_from_env_missing_azure_vars(self) -> None:
        """azure_blob without AZURE_STORAGE_CONNECTION_STRING raises ValueError."""
        from shared.storage.factory import create_storage_backend_from_env

        # AZURE_STORAGE_CONNECTION_STRING is deliberately not supplied.
        with patch.dict(os.environ, {"STORAGE_BACKEND": "azure_blob"}, clear=False):
            # Also remove it in case the ambient environment defines it.
            os.environ.pop("AZURE_STORAGE_CONNECTION_STRING", None)

            with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
                create_storage_backend_from_env()
|
|
|
|
|
|
class TestGetDefaultStorageConfig:
    """Exercise ``get_default_storage_config`` under patched environments."""

    def test_get_default_config_local(self) -> None:
        """Local env vars are reflected in backend_type and base_path."""
        from shared.storage.factory import get_default_storage_config

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {"STORAGE_BACKEND": "local", "STORAGE_BASE_PATH": temp_dir}

            with patch.dict(os.environ, overrides, clear=False):
                resolved = get_default_storage_config()

            assert resolved.backend_type == "local"
            assert resolved.base_path == Path(temp_dir)

    def test_get_default_config_azure(self) -> None:
        """Azure env vars are reflected in the connection and container fields."""
        from shared.storage.factory import get_default_storage_config

        conn = "DefaultEndpointsProtocol=https;..."
        container = "training-images"
        overrides = {
            "STORAGE_BACKEND": "azure_blob",
            "AZURE_STORAGE_CONNECTION_STRING": conn,
            "AZURE_STORAGE_CONTAINER": container,
        }

        with patch.dict(os.environ, overrides, clear=False):
            resolved = get_default_storage_config()

        assert resolved.backend_type == "azure_blob"
        assert resolved.connection_string == conn
        assert resolved.container_name == container
|
|
|
|
|
|
class TestStorageFactoryS3:
    """Exercise the S3 paths of the storage factory functions."""

    @patch("boto3.client")
    def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
        """An ``s3`` config produces an S3StorageBackend (boto3.client is patched)."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend
        from shared.storage.s3 import S3StorageBackend

        s3_config = StorageConfig(
            backend_type="s3",
            bucket_name="test-bucket",
            region_name="us-west-2",
        )

        assert isinstance(create_storage_backend(s3_config), S3StorageBackend)

    def test_create_s3_requires_bucket_name(self) -> None:
        """An ``s3`` config without bucket_name is rejected with ValueError."""
        from shared.storage.base import StorageConfig
        from shared.storage.factory import create_storage_backend

        bucketless_config = StorageConfig(
            backend_type="s3",
            region_name="us-west-2",
        )

        with pytest.raises(ValueError, match="bucket_name"):
            create_storage_backend(bucketless_config)

    @patch("boto3.client")
    def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
        """S3 env vars give an S3StorageBackend (boto3.client is patched)."""
        from shared.storage.factory import create_storage_backend_from_env
        from shared.storage.s3 import S3StorageBackend

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-east-1",
        }

        with patch.dict(os.environ, overrides, clear=False):
            result = create_storage_backend_from_env()

        assert isinstance(result, S3StorageBackend)

    def test_create_from_env_s3_missing_bucket(self) -> None:
        """STORAGE_BACKEND=s3 without AWS_S3_BUCKET raises ValueError."""
        from shared.storage.factory import create_storage_backend_from_env

        # AWS_S3_BUCKET is deliberately not supplied.
        with patch.dict(os.environ, {"STORAGE_BACKEND": "s3"}, clear=False):
            # Also remove it in case the ambient environment defines it.
            os.environ.pop("AWS_S3_BUCKET", None)

            with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
                create_storage_backend_from_env()

    def test_get_default_config_s3(self) -> None:
        """S3 env vars are reflected in bucket, region and endpoint fields."""
        from shared.storage.factory import get_default_storage_config

        overrides = {
            "STORAGE_BACKEND": "s3",
            "AWS_S3_BUCKET": "test-bucket",
            "AWS_REGION": "us-west-2",
            "AWS_ENDPOINT_URL": "http://localhost:9000",
        }

        with patch.dict(os.environ, overrides, clear=False):
            resolved = get_default_storage_config()

        assert resolved.backend_type == "s3"
        assert resolved.bucket_name == "test-bucket"
        assert resolved.region_name == "us-west-2"
        assert resolved.endpoint_url == "http://localhost:9000"
|
|
|
|
|
|
class TestStorageFactoryFromFile:
    """Exercise ``create_storage_backend_from_file`` with YAML files in tmp_path."""

    def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
        """A YAML file selecting ``local`` produces a LocalStorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        config_file = tmp_path / "storage.yaml"
        storage_path = tmp_path / "storage"
        # Written line by line so the YAML nesting is explicit.
        yaml_text = (
            "\n"
            "backend: local\n"
            "\n"
            "local:\n"
            f"  base_path: {storage_path}\n"
        )
        config_file.write_text(yaml_text)

        assert isinstance(
            create_storage_backend_from_file(config_file), LocalStorageBackend
        )

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_from_yaml_file_azure(
        self, mock_service_class: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting ``azure_blob`` produces an AzureBlobStorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.factory import create_storage_backend_from_file

        config_file = tmp_path / "storage.yaml"
        yaml_text = (
            "\n"
            "backend: azure_blob\n"
            "\n"
            "azure:\n"
            "  connection_string: DefaultEndpointsProtocol=https;AccountName=test\n"
            "  container_name: documents\n"
        )
        config_file.write_text(yaml_text)

        assert isinstance(
            create_storage_backend_from_file(config_file), AzureBlobStorageBackend
        )

    @patch("boto3.client")
    def test_create_from_yaml_file_s3(
        self, mock_boto3_client: MagicMock, tmp_path: Path
    ) -> None:
        """A YAML file selecting ``s3`` produces an S3StorageBackend."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.s3 import S3StorageBackend

        config_file = tmp_path / "storage.yaml"
        yaml_text = (
            "\n"
            "backend: s3\n"
            "\n"
            "s3:\n"
            "  bucket_name: my-bucket\n"
            "  region_name: us-east-1\n"
        )
        config_file.write_text(yaml_text)

        assert isinstance(
            create_storage_backend_from_file(config_file), S3StorageBackend
        )

    def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
        """``${VAR}`` placeholders in the YAML are filled from the environment."""
        from shared.storage.factory import create_storage_backend_from_file
        from shared.storage.local import LocalStorageBackend

        config_file = tmp_path / "storage.yaml"
        storage_path = tmp_path / "storage"
        # Plain (non-f) strings: the ${...} placeholders must reach the file as-is.
        yaml_text = (
            "\n"
            "backend: ${STORAGE_BACKEND:-local}\n"
            "\n"
            "local:\n"
            "  base_path: ${CUSTOM_STORAGE_PATH}\n"
        )
        config_file.write_text(yaml_text)

        substitutions = {
            "STORAGE_BACKEND": "local",
            "CUSTOM_STORAGE_PATH": str(storage_path),
        }
        with patch.dict(os.environ, substitutions):
            result = create_storage_backend_from_file(config_file)

        assert isinstance(result, LocalStorageBackend)

    def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
        """A path that does not exist raises FileNotFoundError."""
        from shared.storage.factory import create_storage_backend_from_file

        missing_file = tmp_path / "nonexistent.yaml"

        with pytest.raises(FileNotFoundError):
            create_storage_backend_from_file(missing_file)
|
|
|
|
|
|
class TestGetStorageBackend:
    """Exercise the ``get_storage_backend`` convenience entry point."""

    def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
        """An explicit config_path is used to build the backend."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        config_file = tmp_path / "storage.yaml"
        storage_path = tmp_path / "storage"
        # Written line by line so the YAML nesting is explicit.
        yaml_text = (
            "\n"
            "backend: local\n"
            "\n"
            "local:\n"
            f"  base_path: {storage_path}\n"
        )
        config_file.write_text(yaml_text)

        assert isinstance(
            get_storage_backend(config_path=config_file), LocalStorageBackend
        )

    def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
        """With config_path=None, the backend is built from env vars."""
        from shared.storage.factory import get_storage_backend
        from shared.storage.local import LocalStorageBackend

        storage_path = tmp_path / "storage"
        overrides = {
            "STORAGE_BACKEND": "local",
            "STORAGE_BASE_PATH": str(storage_path),
        }

        with patch.dict(os.environ, overrides, clear=False):
            # No config file is passed, so the env vars above should be used.
            result = get_storage_backend(config_path=None)

        assert isinstance(result, LocalStorageBackend)
|