Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

packages/shared/README.md Normal file

@@ -0,0 +1,205 @@
# Shared Package
Shared utilities and abstractions for the Invoice Master system.
## Storage Abstraction Layer
A unified storage abstraction supporting multiple backends:
- **Local filesystem** - Development and testing
- **Azure Blob Storage** - Azure cloud deployments
- **AWS S3** - AWS cloud deployments
### Installation
```bash
# Basic installation (local storage only)
pip install -e packages/shared
# With Azure support
pip install -e "packages/shared[azure]"
# With S3 support
pip install -e "packages/shared[s3]"
# All cloud providers
pip install -e "packages/shared[all]"
```
### Quick Start
```python
from pathlib import Path

from shared.storage import get_storage_backend
# Option 1: From configuration file
storage = get_storage_backend("storage.yaml")
# Option 2: From environment variables
from shared.storage import create_storage_backend_from_env
storage = create_storage_backend_from_env()
# Upload a file
storage.upload(Path("local/file.pdf"), "documents/file.pdf")
# Download a file
storage.download("documents/file.pdf", Path("local/downloaded.pdf"))
# Get pre-signed URL for frontend access
url = storage.get_presigned_url("documents/file.pdf", expires_in_seconds=3600)
```
### Configuration File Format
Create a `storage.yaml` file with environment variable substitution support:
```yaml
# Backend selection: local, azure_blob, or s3
backend: ${STORAGE_BACKEND:-local}
# Default pre-signed URL expiry (seconds)
presigned_url_expiry: 3600
# Local storage configuration
local:
base_path: ${STORAGE_BASE_PATH:-./data/storage}
# Azure Blob Storage configuration
azure:
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
create_container: false
# AWS S3 configuration
s3:
bucket_name: ${AWS_S3_BUCKET}
region_name: ${AWS_REGION:-us-east-1}
access_key_id: ${AWS_ACCESS_KEY_ID}
secret_access_key: ${AWS_SECRET_ACCESS_KEY}
endpoint_url: ${AWS_ENDPOINT_URL} # Optional, for S3-compatible services
create_bucket: false
```
### Environment Variables
| Variable | Backend | Description |
|----------|---------|-------------|
| `STORAGE_BACKEND` | All | Backend type: `local`, `azure_blob`, `s3` |
| `STORAGE_BASE_PATH` | Local | Base directory path |
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | Connection string |
| `AZURE_STORAGE_CONTAINER` | Azure | Container name |
| `AWS_S3_BUCKET` | S3 | Bucket name |
| `AWS_REGION` | S3 | AWS region (default: us-east-1) |
| `AWS_ACCESS_KEY_ID` | S3 | Access key (optional, uses credential chain) |
| `AWS_SECRET_ACCESS_KEY` | S3 | Secret key (optional) |
| `AWS_ENDPOINT_URL` | S3 | Custom endpoint for S3-compatible services |
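For example, a purely environment-driven local setup could look like the following sketch (the path is a placeholder; substitute your own values or cloud credentials):

```python
import os

from shared.storage import create_storage_backend_from_env

# Placeholder values for illustration only.
os.environ["STORAGE_BACKEND"] = "local"
os.environ["STORAGE_BASE_PATH"] = "~/invoice-data/storage"

storage = create_storage_backend_from_env()
print(type(storage).__name__)  # LocalStorageBackend
```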
### API Reference
#### StorageBackend Interface
```python
class StorageBackend(ABC):
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
"""Upload a file to storage."""
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from storage."""
def exists(self, remote_path: str) -> bool:
"""Check if a file exists."""
def list_files(self, prefix: str) -> list[str]:
"""List files with given prefix."""
def delete(self, remote_path: str) -> bool:
"""Delete a file."""
def get_url(self, remote_path: str) -> str:
"""Get URL for a file."""
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
"""Generate a pre-signed URL for temporary access (1-604800 seconds)."""
def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str:
"""Upload bytes directly."""
def download_bytes(self, remote_path: str) -> bytes:
"""Download file as bytes."""
```
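Any additional backend only needs to implement these abstract methods. As an illustration (not part of the package), a minimal in-memory backend for unit tests could look like this:

```python
from pathlib import Path

from shared.storage import (
    FileNotFoundStorageError,
    PresignedUrlNotSupportedError,
    StorageBackend,
    StorageError,
)


class InMemoryStorageBackend(StorageBackend):
    """Hypothetical test-only backend that keeps file contents in a dict."""

    def __init__(self) -> None:
        self._blobs: dict[str, bytes] = {}

    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))
        if remote_path in self._blobs and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")
        self._blobs[remote_path] = local_path.read_bytes()
        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        local_path.parent.mkdir(parents=True, exist_ok=True)
        local_path.write_bytes(self.download_bytes(remote_path))
        return local_path

    def exists(self, remote_path: str) -> bool:
        return remote_path in self._blobs

    def list_files(self, prefix: str) -> list[str]:
        return sorted(key for key in self._blobs if key.startswith(prefix))

    def delete(self, remote_path: str) -> bool:
        return self._blobs.pop(remote_path, None) is not None

    def get_url(self, remote_path: str) -> str:
        if remote_path not in self._blobs:
            raise FileNotFoundStorageError(remote_path)
        return f"memory://{remote_path}"

    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        # No temporary-URL mechanism for an in-memory store.
        raise PresignedUrlNotSupportedError("memory")

    def download_bytes(self, remote_path: str) -> bytes:
        if remote_path not in self._blobs:
            raise FileNotFoundStorageError(remote_path)
        return self._blobs[remote_path]
```

The inherited `upload_bytes` default (temp file plus `upload`) works unchanged on top of this sketch.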
#### Factory Functions
```python
from pathlib import Path

from shared.storage import (
    StorageConfig,
    create_storage_backend,
    create_storage_backend_from_env,
    create_storage_backend_from_file,
    get_storage_backend,
)

# Create from configuration file
storage = create_storage_backend_from_file("storage.yaml")
# Create from environment variables
storage = create_storage_backend_from_env()
# Create from StorageConfig object
config = StorageConfig(backend_type="local", base_path=Path("./data"))
storage = create_storage_backend(config)
# Convenience function with fallback chain: config file -> environment variables
storage = get_storage_backend("storage.yaml") # or None for env-only
```
### Pre-signed URLs
Pre-signed URLs provide temporary access to files without exposing credentials:
```python
# Generate URL valid for 1 hour (default)
url = storage.get_presigned_url("documents/invoice.pdf")
# Generate URL valid for 24 hours
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=86400)
# Maximum expiry: 7 days (604800 seconds)
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=604800)
```
**Note:** Local storage returns `file://` URLs that don't actually expire.
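If a backend ever lacks pre-signed URL support (the local, Azure, and S3 backends here all implement it), callers can fall back to `get_url`. A hypothetical helper, not part of the package, might look like:

```python
from shared.storage import PresignedUrlNotSupportedError, StorageBackend


def share_link(storage: StorageBackend, remote_path: str, expires_in_seconds: int = 3600) -> str:
    """Illustrative helper: prefer a pre-signed URL, fall back to the plain URL."""
    try:
        return storage.get_presigned_url(remote_path, expires_in_seconds=expires_in_seconds)
    except PresignedUrlNotSupportedError:
        # Backend has no temporary-URL mechanism; hand out its regular URL instead.
        return storage.get_url(remote_path)
```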
### Error Handling
```python
from shared.storage import (
StorageError,
FileNotFoundStorageError,
PresignedUrlNotSupportedError,
)
try:
storage.download("nonexistent.pdf", Path("local.pdf"))
except FileNotFoundStorageError as e:
print(f"File not found: {e}")
except StorageError as e:
print(f"Storage error: {e}")
```
### Testing with MinIO (S3-compatible)
```bash
# Start MinIO locally
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"
# Configure environment
export STORAGE_BACKEND=s3
export AWS_S3_BUCKET=test-bucket
export AWS_ENDPOINT_URL=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
```
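A quick round-trip check against MinIO (a sketch; it assumes the container above is running and that creating `test-bucket` on the fly is acceptable):

```python
from shared.storage import S3StorageBackend

# MinIO defaults from the docker command above; create_bucket makes "test-bucket" on demand.
storage = S3StorageBackend(
    bucket_name="test-bucket",
    region_name="us-east-1",
    access_key_id="minioadmin",
    secret_access_key="minioadmin",
    endpoint_url="http://localhost:9000",
    create_bucket=True,
)

storage.upload_bytes(b"hello", "smoke-test/hello.txt", overwrite=True)
assert storage.download_bytes("smoke-test/hello.txt") == b"hello"
print(storage.get_presigned_url("smoke-test/hello.txt", expires_in_seconds=300))
```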
### Module Structure
```
shared/storage/
├── __init__.py # Public exports
├── base.py # Abstract interface and exceptions
├── local.py # Local filesystem backend
├── azure.py # Azure Blob Storage backend
├── s3.py # AWS S3 backend
├── config_loader.py # YAML configuration loader
└── factory.py # Backend factory functions
```


@@ -16,4 +16,18 @@ setup(
"pyyaml>=6.0",
"thefuzz>=0.20.0",
],
extras_require={
"azure": [
"azure-storage-blob>=12.19.0",
"azure-identity>=1.15.0",
],
"s3": [
"boto3>=1.34.0",
],
"all": [
"azure-storage-blob>=12.19.0",
"azure-identity>=1.15.0",
"boto3>=1.34.0",
],
},
)


@@ -58,23 +58,16 @@ def get_db_connection_string():
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
# Paths Configuration - auto-detect WSL vs Windows
if _is_wsl():
# WSL: use native Linux filesystem for better I/O performance
PATHS = {
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
'output_dir': os.path.expanduser('~/invoice-data/dataset'),
'reports_dir': 'reports', # Keep reports in project directory
}
else:
# Windows or native Linux: use relative paths
PATHS = {
'csv_dir': 'data/structured_data',
'pdf_dir': 'data/raw_pdfs',
'output_dir': 'data/dataset',
'reports_dir': 'reports',
}
# Paths Configuration - uses STORAGE_BASE_PATH for consistency
# All paths are relative to STORAGE_BASE_PATH (defaults to ~/invoice-data/data)
_storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data'))
PATHS = {
'csv_dir': f'{_storage_base}/structured_data',
'pdf_dir': f'{_storage_base}/raw_pdfs',
'output_dir': f'{_storage_base}/datasets',
'reports_dir': 'reports', # Keep reports in project directory
}
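# Example (illustrative): with STORAGE_BASE_PATH unset, csv_dir resolves to
# ~/invoice-data/data/structured_data and output_dir to ~/invoice-data/data/datasets.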
# Auto-labeling Configuration
AUTOLABEL = {


@@ -0,0 +1,46 @@
"""
Shared Field Definitions - Single Source of Truth.
This module provides centralized field class definitions used throughout
the invoice extraction system. All field mappings are derived from
FIELD_DEFINITIONS to ensure consistency.
Usage:
from shared.fields import FIELD_CLASSES, CLASS_NAMES, FIELD_CLASS_IDS
Available exports:
- FieldDefinition: Dataclass for field definition
- FIELD_DEFINITIONS: Tuple of all field definitions (immutable)
- NUM_CLASSES: Total number of field classes (10)
- CLASS_NAMES: List of class names in order [0..9]
- FIELD_CLASSES: dict[int, str] - class_id to class_name
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
"""
from .field_config import FieldDefinition, FIELD_DEFINITIONS, NUM_CLASSES
from .mappings import (
CLASS_NAMES,
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
ACCOUNT_FIELD_MAPPING,
)
__all__ = [
"FieldDefinition",
"FIELD_DEFINITIONS",
"NUM_CLASSES",
"CLASS_NAMES",
"FIELD_CLASSES",
"FIELD_CLASS_IDS",
"CLASS_TO_FIELD",
"CSV_TO_CLASS_MAPPING",
"TRAINING_FIELD_CLASSES",
"ACCOUNT_FIELD_MAPPING",
]


@@ -0,0 +1,58 @@
"""
Field Configuration - Single Source of Truth
This module defines all invoice field classes used throughout the system.
The class IDs are verified against the trained YOLO model (best.pt).
IMPORTANT: Do not modify class_id values without retraining the model!
"""
from dataclasses import dataclass
from typing import Final
@dataclass(frozen=True)
class FieldDefinition:
"""Immutable field definition for invoice extraction.
Attributes:
class_id: YOLO class ID (0-9), must match trained model
class_name: YOLO class name (lowercase_underscore)
field_name: Business field name used in API responses
csv_name: CSV column name for data import/export
is_derived: True if field is derived from other fields (not in CSV)
"""
class_id: int
class_name: str
field_name: str
csv_name: str
is_derived: bool = False
# Verified from model weights (runs/train/invoice_fields/weights/best.pt)
# model.names = {0: 'invoice_number', 1: 'invoice_date', ..., 8: 'customer_number', 9: 'payment_line'}
#
# DO NOT CHANGE THE ORDER - it must match the trained model!
FIELD_DEFINITIONS: Final[tuple[FieldDefinition, ...]] = (
FieldDefinition(0, "invoice_number", "InvoiceNumber", "InvoiceNumber"),
FieldDefinition(1, "invoice_date", "InvoiceDate", "InvoiceDate"),
FieldDefinition(2, "invoice_due_date", "InvoiceDueDate", "InvoiceDueDate"),
FieldDefinition(3, "ocr_number", "OCR", "OCR"),
FieldDefinition(4, "bankgiro", "Bankgiro", "Bankgiro"),
FieldDefinition(5, "plusgiro", "Plusgiro", "Plusgiro"),
FieldDefinition(6, "amount", "Amount", "Amount"),
FieldDefinition(
7,
"supplier_org_number",
"supplier_organisation_number",
"supplier_organisation_number",
),
FieldDefinition(8, "customer_number", "customer_number", "customer_number"),
FieldDefinition(
9, "payment_line", "payment_line", "payment_line", is_derived=True
),
)
# Total number of field classes
NUM_CLASSES: Final[int] = len(FIELD_DEFINITIONS)


@@ -0,0 +1,57 @@
"""
Field Mappings - Auto-generated from FIELD_DEFINITIONS.
All mappings in this file are derived from field_config.FIELD_DEFINITIONS.
This ensures consistency across the entire codebase.
DO NOT hardcode field mappings elsewhere - always import from this module.
"""
from typing import Final
from .field_config import FIELD_DEFINITIONS
# List of class names in order (for YOLO classes.txt generation)
# Index matches class_id: CLASS_NAMES[0] = "invoice_number"
CLASS_NAMES: Final[list[str]] = [fd.class_name for fd in FIELD_DEFINITIONS]
# class_id -> class_name mapping
# Example: {0: "invoice_number", 1: "invoice_date", ...}
FIELD_CLASSES: Final[dict[int, str]] = {
fd.class_id: fd.class_name for fd in FIELD_DEFINITIONS
}
# class_name -> class_id mapping (reverse of FIELD_CLASSES)
# Example: {"invoice_number": 0, "invoice_date": 1, ...}
FIELD_CLASS_IDS: Final[dict[str, int]] = {
fd.class_name: fd.class_id for fd in FIELD_DEFINITIONS
}
# class_name -> field_name mapping (for API responses)
# Example: {"invoice_number": "InvoiceNumber", "ocr_number": "OCR", ...}
CLASS_TO_FIELD: Final[dict[str, str]] = {
fd.class_name: fd.field_name for fd in FIELD_DEFINITIONS
}
# field_name -> class_id mapping (for CSV import)
# Excludes derived fields like payment_line
# Example: {"InvoiceNumber": 0, "InvoiceDate": 1, ...}
CSV_TO_CLASS_MAPPING: Final[dict[str, int]] = {
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS if not fd.is_derived
}
# field_name -> class_id mapping (for training, includes all fields)
# Example: {"InvoiceNumber": 0, ..., "payment_line": 9}
TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
}
# Account field mapping for supplier_accounts special handling
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
"supplier_accounts": {
"BG": "Bankgiro",
"PG": "Plusgiro",
}
}
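# Example (illustrative): FIELD_CLASSES[3] == "ocr_number",
# CLASS_TO_FIELD["ocr_number"] == "OCR", and CSV_TO_CLASS_MAPPING has no
# "payment_line" entry because that field is derived.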


@@ -0,0 +1,59 @@
"""
Storage abstraction layer for training data.
Provides a unified interface for local filesystem, Azure Blob Storage, and AWS S3.
"""
from shared.storage.base import (
FileNotFoundStorageError,
PresignedUrlNotSupportedError,
StorageBackend,
StorageConfig,
StorageError,
)
from shared.storage.factory import (
create_storage_backend,
create_storage_backend_from_env,
create_storage_backend_from_file,
get_default_storage_config,
get_storage_backend,
)
from shared.storage.local import LocalStorageBackend
from shared.storage.prefixes import PREFIXES, StoragePrefixes
__all__ = [
# Base classes and exceptions
"StorageBackend",
"StorageConfig",
"StorageError",
"FileNotFoundStorageError",
"PresignedUrlNotSupportedError",
# Backends
"LocalStorageBackend",
# Factory functions
"create_storage_backend",
"create_storage_backend_from_env",
"create_storage_backend_from_file",
"get_default_storage_config",
"get_storage_backend",
# Path prefixes
"PREFIXES",
"StoragePrefixes",
]
# Lazy imports to avoid dependencies when not using specific backends
def __getattr__(name: str):
if name == "AzureBlobStorageBackend":
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend
if name == "S3StorageBackend":
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend
if name == "load_storage_config":
from shared.storage.config_loader import load_storage_config
return load_storage_config
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


@@ -0,0 +1,335 @@
"""
Azure Blob Storage backend.
Provides storage operations using Azure Blob Storage.
"""
from pathlib import Path
from azure.storage.blob import (
BlobSasPermissions,
BlobServiceClient,
ContainerClient,
generate_blob_sas,
)
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class AzureBlobStorageBackend(StorageBackend):
"""Storage backend using Azure Blob Storage.
Files are stored as blobs in an Azure Blob Storage container.
"""
def __init__(
self,
connection_string: str,
container_name: str,
create_container: bool = False,
) -> None:
"""Initialize Azure Blob Storage backend.
Args:
connection_string: Azure Storage connection string.
container_name: Name of the blob container.
create_container: If True, create the container if it doesn't exist.
"""
self._connection_string = connection_string
self._container_name = container_name
self._blob_service = BlobServiceClient.from_connection_string(connection_string)
self._container = self._blob_service.get_container_client(container_name)
# Extract account key from connection string for SAS token generation
self._account_key = self._extract_account_key(connection_string)
if create_container and not self._container.exists():
self._container.create_container()
@staticmethod
def _extract_account_key(connection_string: str) -> str | None:
"""Extract account key from connection string.
Args:
connection_string: Azure Storage connection string.
Returns:
Account key if found, None otherwise.
"""
for part in connection_string.split(";"):
if part.startswith("AccountKey="):
return part[len("AccountKey=") :]
return None
@property
def container_name(self) -> str:
"""Get the container name for this storage backend."""
return self._container_name
@property
def container_client(self) -> ContainerClient:
"""Get the Azure container client."""
return self._container
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to Azure Blob Storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination blob path.
overwrite: If True, overwrite existing blob.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If blob exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
blob_client = self._container.get_blob_client(remote_path)
if blob_client.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
with open(local_path, "rb") as f:
blob_client.upload_blob(f, overwrite=overwrite)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a blob from Azure Blob Storage.
Args:
remote_path: Blob path in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
stream = blob_client.download_blob()
local_path.write_bytes(stream.readall())
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if a blob exists in storage.
Args:
remote_path: Blob path to check.
Returns:
True if the blob exists, False otherwise.
"""
blob_client = self._container.get_blob_client(remote_path)
return blob_client.exists()
def list_files(self, prefix: str) -> list[str]:
"""List blobs in storage with given prefix.
Args:
prefix: Blob path prefix to filter.
Returns:
List of blob paths matching the prefix.
"""
if prefix:
blobs = self._container.list_blobs(name_starts_with=prefix)
else:
blobs = self._container.list_blobs()
return [blob.name for blob in blobs]
def delete(self, remote_path: str) -> bool:
"""Delete a blob from storage.
Args:
remote_path: Blob path to delete.
Returns:
True if blob was deleted, False if it didn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
return False
blob_client.delete_blob()
return True
def get_url(self, remote_path: str) -> str:
"""Get the URL for a blob.
Args:
remote_path: Blob path in storage.
Returns:
URL to access the blob.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
return blob_client.url
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to Azure Blob Storage.
Args:
data: Bytes to upload.
remote_path: Destination blob path.
overwrite: If True, overwrite existing blob.
Returns:
The remote path where the data was stored.
"""
blob_client = self._container.get_blob_client(remote_path)
if blob_client.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
blob_client.upload_blob(data, overwrite=overwrite)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download a blob as bytes.
Args:
remote_path: Blob path in storage.
Returns:
The blob contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
stream = blob_client.download_blob()
return stream.readall()
def upload_directory(
self, local_dir: Path, remote_prefix: str, overwrite: bool = False
) -> list[str]:
"""Upload all files in a directory to Azure Blob Storage.
Args:
local_dir: Local directory to upload.
remote_prefix: Prefix for remote blob paths.
overwrite: If True, overwrite existing blobs.
Returns:
List of remote paths where files were stored.
"""
uploaded: list[str] = []
for file_path in local_dir.rglob("*"):
if file_path.is_file():
relative_path = file_path.relative_to(local_dir)
remote_path = f"{remote_prefix}{relative_path}".replace("\\", "/")
self.upload(file_path, remote_path, overwrite=overwrite)
uploaded.append(remote_path)
return uploaded
def download_directory(
self, remote_prefix: str, local_dir: Path
) -> list[Path]:
"""Download all blobs with a prefix to a local directory.
Args:
remote_prefix: Blob path prefix to download.
local_dir: Local directory to download to.
Returns:
List of local paths where files were downloaded.
"""
downloaded: list[Path] = []
blobs = self.list_files(remote_prefix)
for blob_path in blobs:
# Remove prefix to get relative path
if remote_prefix:
relative_path = blob_path[len(remote_prefix):]
if relative_path.startswith("/"):
relative_path = relative_path[1:]
else:
relative_path = blob_path
local_path = local_dir / relative_path
self.download(blob_path, local_path)
downloaded.append(local_path)
return downloaded
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a SAS URL for temporary blob access.
Args:
remote_path: Blob path in storage.
expires_in_seconds: SAS token validity duration (1 to 604800 seconds / 7 days).
Returns:
Blob URL with SAS token.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
ValueError: If expires_in_seconds is out of valid range.
"""
if expires_in_seconds < 1 or expires_in_seconds > 604800:
raise ValueError(
"expires_in_seconds must be between 1 and 604800 (7 days)"
)
from datetime import datetime, timedelta, timezone
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
# Generate SAS token
sas_token = generate_blob_sas(
account_name=self._blob_service.account_name,
container_name=self._container_name,
blob_name=remote_path,
account_key=self._account_key,
permission=BlobSasPermissions(read=True),
expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
)
return f"{blob_client.url}?{sas_token}"


@@ -0,0 +1,229 @@
"""
Base classes and interfaces for storage backends.
Defines the abstract StorageBackend interface and common exceptions.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
class StorageError(Exception):
"""Base exception for storage operations."""
pass
class FileNotFoundStorageError(StorageError):
"""Raised when a file is not found in storage."""
def __init__(self, path: str) -> None:
self.path = path
super().__init__(f"File not found in storage: {path}")
class PresignedUrlNotSupportedError(StorageError):
"""Raised when pre-signed URLs are not supported by a backend."""
def __init__(self, backend_type: str) -> None:
self.backend_type = backend_type
super().__init__(f"Pre-signed URLs not supported for backend: {backend_type}")
@dataclass(frozen=True)
class StorageConfig:
"""Configuration for storage backend.
Attributes:
backend_type: Type of storage backend ("local", "azure_blob", or "s3").
connection_string: Azure Blob Storage connection string (for azure_blob).
container_name: Azure Blob Storage container name (for azure_blob).
base_path: Base path for local storage (for local).
bucket_name: S3 bucket name (for s3).
region_name: AWS region name (for s3).
access_key_id: AWS access key ID (for s3).
secret_access_key: AWS secret access key (for s3).
endpoint_url: Custom endpoint URL for S3-compatible services (for s3).
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
"""
backend_type: str
connection_string: str | None = None
container_name: str | None = None
base_path: Path | None = None
bucket_name: str | None = None
region_name: str | None = None
access_key_id: str | None = None
secret_access_key: str | None = None
endpoint_url: str | None = None
presigned_url_expiry: int = 3600
class StorageBackend(ABC):
"""Abstract base class for storage backends.
Provides a unified interface for storing and retrieving files
from different storage systems (local filesystem, Azure Blob, etc.).
"""
@abstractmethod
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If file exists and overwrite is False.
"""
pass
@abstractmethod
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from storage.
Args:
remote_path: Path to the file in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
pass
@abstractmethod
def exists(self, remote_path: str) -> bool:
"""Check if a file exists in storage.
Args:
remote_path: Path to check in storage.
Returns:
True if the file exists, False otherwise.
"""
pass
@abstractmethod
def list_files(self, prefix: str) -> list[str]:
"""List files in storage with given prefix.
Args:
prefix: Path prefix to filter files.
Returns:
List of file paths matching the prefix.
"""
pass
@abstractmethod
def delete(self, remote_path: str) -> bool:
"""Delete a file from storage.
Args:
remote_path: Path to the file to delete.
Returns:
True if file was deleted, False if it didn't exist.
"""
pass
@abstractmethod
def get_url(self, remote_path: str) -> str:
"""Get a URL or path to access a file.
Args:
remote_path: Path to the file in storage.
Returns:
URL or path to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
pass
@abstractmethod
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a pre-signed URL for temporary access.
Args:
remote_path: Path to the file in storage.
expires_in_seconds: URL validity duration (default 1 hour).
Returns:
Pre-signed URL string.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
PresignedUrlNotSupportedError: If backend doesn't support pre-signed URLs.
"""
pass
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to storage.
Default implementation writes to temp file then uploads.
Subclasses may override for more efficient implementation.
Args:
data: Bytes to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the data was stored.
"""
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(data)
temp_path = Path(f.name)
try:
return self.upload(temp_path, remote_path, overwrite=overwrite)
finally:
temp_path.unlink(missing_ok=True)
def download_bytes(self, remote_path: str) -> bytes:
"""Download a file as bytes.
Default implementation downloads to temp file then reads.
Subclasses may override for more efficient implementation.
Args:
remote_path: Path to the file in storage.
Returns:
The file contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as f:
temp_path = Path(f.name)
try:
self.download(remote_path, temp_path)
return temp_path.read_bytes()
finally:
temp_path.unlink(missing_ok=True)


@@ -0,0 +1,242 @@
"""
Configuration file loader for storage backends.
Supports YAML configuration files with environment variable substitution.
"""
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LocalConfig:
"""Local storage backend configuration."""
base_path: Path
@dataclass(frozen=True)
class AzureConfig:
"""Azure Blob Storage configuration."""
connection_string: str
container_name: str
create_container: bool = False
@dataclass(frozen=True)
class S3Config:
"""AWS S3 configuration."""
bucket_name: str
region_name: str | None = None
access_key_id: str | None = None
secret_access_key: str | None = None
endpoint_url: str | None = None
create_bucket: bool = False
@dataclass(frozen=True)
class StorageFileConfig:
"""Extended storage configuration from file.
Attributes:
backend_type: Type of storage backend.
local: Local backend configuration.
azure: Azure Blob configuration.
s3: S3 configuration.
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
"""
backend_type: str
local: LocalConfig | None = None
azure: AzureConfig | None = None
s3: S3Config | None = None
presigned_url_expiry: int = 3600
def substitute_env_vars(value: str) -> str:
"""Substitute environment variables in a string.
Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax.
Args:
value: String potentially containing env var references.
Returns:
String with env vars substituted.
"""
pattern = r"\$\{([A-Z_][A-Z0-9_]*)(?::-([^}]*))?\}"
def replace(match: re.Match[str]) -> str:
var_name = match.group(1)
default = match.group(2)
return os.environ.get(var_name, default or "")
return re.sub(pattern, replace, value)
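# Example (illustrative): with STORAGE_BACKEND unset,
# substitute_env_vars("${STORAGE_BACKEND:-local}") returns "local"; a reference
# with no default whose variable is unset substitutes to "".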
def _substitute_in_dict(data: dict[str, Any]) -> dict[str, Any]:
"""Recursively substitute env vars in a dictionary.
Args:
data: Dictionary to process.
Returns:
New dictionary with substitutions applied.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, str):
result[key] = substitute_env_vars(value)
elif isinstance(value, dict):
result[key] = _substitute_in_dict(value)
elif isinstance(value, list):
result[key] = [
substitute_env_vars(item) if isinstance(item, str) else item
for item in value
]
else:
result[key] = value
return result
def _parse_local_config(data: dict[str, Any]) -> LocalConfig:
"""Parse local configuration section.
Args:
data: Dictionary containing local config.
Returns:
LocalConfig instance.
Raises:
ValueError: If required fields are missing.
"""
base_path = data.get("base_path")
if not base_path:
raise ValueError("local.base_path is required")
return LocalConfig(base_path=Path(base_path))
def _parse_azure_config(data: dict[str, Any]) -> AzureConfig:
"""Parse Azure configuration section.
Args:
data: Dictionary containing Azure config.
Returns:
AzureConfig instance.
Raises:
ValueError: If required fields are missing.
"""
connection_string = data.get("connection_string")
container_name = data.get("container_name")
if not connection_string:
raise ValueError("azure.connection_string is required")
if not container_name:
raise ValueError("azure.container_name is required")
return AzureConfig(
connection_string=connection_string,
container_name=container_name,
create_container=data.get("create_container", False),
)
def _parse_s3_config(data: dict[str, Any]) -> S3Config:
"""Parse S3 configuration section.
Args:
data: Dictionary containing S3 config.
Returns:
S3Config instance.
Raises:
ValueError: If required fields are missing.
"""
bucket_name = data.get("bucket_name")
if not bucket_name:
raise ValueError("s3.bucket_name is required")
return S3Config(
bucket_name=bucket_name,
region_name=data.get("region_name"),
access_key_id=data.get("access_key_id"),
secret_access_key=data.get("secret_access_key"),
endpoint_url=data.get("endpoint_url"),
create_bucket=data.get("create_bucket", False),
)
def load_storage_config(config_path: Path | str) -> StorageFileConfig:
"""Load storage configuration from YAML file.
Supports environment variable substitution using ${VAR_NAME} or
${VAR_NAME:-default} syntax.
Args:
config_path: Path to configuration file.
Returns:
Parsed StorageFileConfig.
Raises:
FileNotFoundError: If config file doesn't exist.
ValueError: If config is invalid.
"""
config_path = Path(config_path)
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
try:
raw_content = config_path.read_text(encoding="utf-8")
data = yaml.safe_load(raw_content)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in config file: {e}") from e
if not isinstance(data, dict):
raise ValueError("Config file must contain a YAML dictionary")
# Substitute environment variables
data = _substitute_in_dict(data)
# Extract backend type
backend_type = data.get("backend")
if not backend_type:
raise ValueError("'backend' field is required in config file")
# Parse presigned URL expiry
presigned_url_expiry = data.get("presigned_url_expiry", 3600)
# Parse backend-specific configurations
local_config = None
azure_config = None
s3_config = None
if "local" in data:
local_config = _parse_local_config(data["local"])
if "azure" in data:
azure_config = _parse_azure_config(data["azure"])
if "s3" in data:
s3_config = _parse_s3_config(data["s3"])
return StorageFileConfig(
backend_type=backend_type,
local=local_config,
azure=azure_config,
s3=s3_config,
presigned_url_expiry=presigned_url_expiry,
)


@@ -0,0 +1,296 @@
"""
Factory functions for creating storage backends.
Provides convenient functions for creating storage backends from
configuration or environment variables.
"""
import os
from pathlib import Path
from shared.storage.base import StorageBackend, StorageConfig
def create_storage_backend(config: StorageConfig) -> StorageBackend:
"""Create a storage backend from configuration.
Args:
config: Storage configuration.
Returns:
A configured storage backend.
Raises:
ValueError: If configuration is invalid.
"""
if config.backend_type == "local":
if config.base_path is None:
raise ValueError("base_path is required for local storage backend")
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=config.base_path)
elif config.backend_type == "azure_blob":
if config.connection_string is None:
raise ValueError(
"connection_string is required for Azure blob storage backend"
)
if config.container_name is None:
raise ValueError(
"container_name is required for Azure blob storage backend"
)
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=config.connection_string,
container_name=config.container_name,
)
elif config.backend_type == "s3":
if config.bucket_name is None:
raise ValueError("bucket_name is required for S3 storage backend")
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=config.bucket_name,
region_name=config.region_name,
access_key_id=config.access_key_id,
secret_access_key=config.secret_access_key,
endpoint_url=config.endpoint_url,
)
else:
raise ValueError(f"Unknown storage backend type: {config.backend_type}")
def get_default_storage_config() -> StorageConfig:
"""Get storage configuration from environment variables.
Environment variables:
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
STORAGE_BASE_PATH: Base path for local storage.
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
AZURE_STORAGE_CONTAINER: Azure container name.
AWS_S3_BUCKET: S3 bucket name.
AWS_REGION: AWS region name.
AWS_ACCESS_KEY_ID: AWS access key ID.
AWS_SECRET_ACCESS_KEY: AWS secret access key.
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
Returns:
StorageConfig from environment.
"""
backend_type = os.environ.get("STORAGE_BACKEND", "local")
if backend_type == "local":
base_path_str = os.environ.get("STORAGE_BASE_PATH")
# Expand ~ to home directory
base_path = Path(os.path.expanduser(base_path_str)) if base_path_str else None
return StorageConfig(
backend_type="local",
base_path=base_path,
)
elif backend_type == "azure_blob":
return StorageConfig(
backend_type="azure_blob",
connection_string=os.environ.get("AZURE_STORAGE_CONNECTION_STRING"),
container_name=os.environ.get("AZURE_STORAGE_CONTAINER"),
)
elif backend_type == "s3":
return StorageConfig(
backend_type="s3",
bucket_name=os.environ.get("AWS_S3_BUCKET"),
region_name=os.environ.get("AWS_REGION"),
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
endpoint_url=os.environ.get("AWS_ENDPOINT_URL"),
)
else:
return StorageConfig(backend_type=backend_type)
def create_storage_backend_from_env() -> StorageBackend:
"""Create a storage backend from environment variables.
Environment variables:
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
STORAGE_BASE_PATH: Base path for local storage.
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
AZURE_STORAGE_CONTAINER: Azure container name.
AWS_S3_BUCKET: S3 bucket name.
AWS_REGION: AWS region name.
AWS_ACCESS_KEY_ID: AWS access key ID.
AWS_SECRET_ACCESS_KEY: AWS secret access key.
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
Returns:
A configured storage backend.
Raises:
ValueError: If required environment variables are missing or empty.
"""
backend_type = os.environ.get("STORAGE_BACKEND", "local").strip()
if backend_type == "local":
base_path = os.environ.get("STORAGE_BASE_PATH", "").strip()
if not base_path:
raise ValueError(
"STORAGE_BASE_PATH environment variable is required and cannot be empty"
)
# Expand ~ to home directory
base_path_expanded = os.path.expanduser(base_path)
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=Path(base_path_expanded))
elif backend_type == "azure_blob":
connection_string = os.environ.get(
"AZURE_STORAGE_CONNECTION_STRING", ""
).strip()
if not connection_string:
raise ValueError(
"AZURE_STORAGE_CONNECTION_STRING environment variable is required "
"and cannot be empty"
)
container_name = os.environ.get("AZURE_STORAGE_CONTAINER", "").strip()
if not container_name:
raise ValueError(
"AZURE_STORAGE_CONTAINER environment variable is required "
"and cannot be empty"
)
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=connection_string,
container_name=container_name,
)
elif backend_type == "s3":
bucket_name = os.environ.get("AWS_S3_BUCKET", "").strip()
if not bucket_name:
raise ValueError(
"AWS_S3_BUCKET environment variable is required and cannot be empty"
)
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=bucket_name,
region_name=os.environ.get("AWS_REGION", "").strip() or None,
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "").strip() or None,
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "").strip()
or None,
endpoint_url=os.environ.get("AWS_ENDPOINT_URL", "").strip() or None,
)
else:
raise ValueError(f"Unknown storage backend type: {backend_type}")
def create_storage_backend_from_file(config_path: Path | str) -> StorageBackend:
"""Create a storage backend from a configuration file.
Args:
config_path: Path to YAML configuration file.
Returns:
A configured storage backend.
Raises:
FileNotFoundError: If config file doesn't exist.
ValueError: If configuration is invalid.
"""
from shared.storage.config_loader import load_storage_config
file_config = load_storage_config(config_path)
if file_config.backend_type == "local":
if file_config.local is None:
raise ValueError("local configuration section is required")
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=file_config.local.base_path)
elif file_config.backend_type == "azure_blob":
if file_config.azure is None:
raise ValueError("azure configuration section is required")
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=file_config.azure.connection_string,
container_name=file_config.azure.container_name,
create_container=file_config.azure.create_container,
)
elif file_config.backend_type == "s3":
if file_config.s3 is None:
raise ValueError("s3 configuration section is required")
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=file_config.s3.bucket_name,
region_name=file_config.s3.region_name,
access_key_id=file_config.s3.access_key_id,
secret_access_key=file_config.s3.secret_access_key,
endpoint_url=file_config.s3.endpoint_url,
create_bucket=file_config.s3.create_bucket,
)
else:
raise ValueError(f"Unknown storage backend type: {file_config.backend_type}")
def get_storage_backend(config_path: Path | str | None = None) -> StorageBackend:
"""Get storage backend with fallback chain.
Priority:
1. Config file (if provided)
2. Environment variables
Args:
config_path: Optional path to config file.
Returns:
A configured storage backend.
Raises:
ValueError: If configuration is invalid.
FileNotFoundError: If specified config file doesn't exist.
"""
if config_path:
return create_storage_backend_from_file(config_path)
# Fall back to environment variables
return create_storage_backend_from_env()


@@ -0,0 +1,262 @@
"""
Local filesystem storage backend.
Provides storage operations using the local filesystem.
"""
import shutil
from pathlib import Path
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class LocalStorageBackend(StorageBackend):
"""Storage backend using local filesystem.
Files are stored relative to a base path on the local filesystem.
"""
def __init__(self, base_path: str | Path) -> None:
"""Initialize local storage backend.
Args:
base_path: Base directory for all storage operations.
Will be created if it doesn't exist.
"""
self._base_path = Path(base_path)
self._base_path.mkdir(parents=True, exist_ok=True)
@property
def base_path(self) -> Path:
"""Get the base path for this storage backend."""
return self._base_path
def _get_full_path(self, remote_path: str) -> Path:
"""Convert a remote path to a full local path with security validation.
Args:
remote_path: The remote path to resolve.
Returns:
The full local path.
Raises:
StorageError: If the path attempts to escape the base directory.
"""
# Reject absolute paths
if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"):
raise StorageError(f"Absolute paths not allowed: {remote_path}")
# Resolve to prevent path traversal attacks
full_path = (self._base_path / remote_path).resolve()
base_resolved = self._base_path.resolve()
# Verify the resolved path is within base_path
try:
full_path.relative_to(base_resolved)
except ValueError:
raise StorageError(f"Path traversal not allowed: {remote_path}")
return full_path
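# Example (illustrative): "documents/a.pdf" resolves under base_path, while
# "/etc/passwd" and "../secrets.txt" both raise StorageError.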
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to local storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If file exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
dest_path = self._get_full_path(remote_path)
if dest_path.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(local_path, dest_path)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from local storage.
Args:
remote_path: Path to the file in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
source_path = self._get_full_path(remote_path)
if not source_path.exists():
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(source_path, local_path)
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if a file exists in storage.
Args:
remote_path: Path to check in storage.
Returns:
True if the file exists, False otherwise.
"""
return self._get_full_path(remote_path).exists()
def list_files(self, prefix: str) -> list[str]:
"""List files in storage with given prefix.
Args:
prefix: Path prefix to filter files.
Returns:
Sorted list of file paths matching the prefix.
"""
if prefix:
search_path = self._get_full_path(prefix)
if not search_path.exists():
return []
base_for_relative = self._base_path
else:
search_path = self._base_path
base_for_relative = self._base_path
files: list[str] = []
if search_path.is_file():
files.append(str(search_path.relative_to(self._base_path)))
elif search_path.is_dir():
for file_path in search_path.rglob("*"):
if file_path.is_file():
relative_path = file_path.relative_to(self._base_path)
files.append(str(relative_path).replace("\\", "/"))
return sorted(files)
def delete(self, remote_path: str) -> bool:
"""Delete a file from storage.
Args:
remote_path: Path to the file to delete.
Returns:
True if file was deleted, False if it didn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
return False
file_path.unlink()
return True
def get_url(self, remote_path: str) -> str:
"""Get a file:// URL to access a file.
Args:
remote_path: Path to the file in storage.
Returns:
file:// URL to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.as_uri()
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to storage.
Args:
data: Bytes to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the data was stored.
"""
dest_path = self._get_full_path(remote_path)
if dest_path.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
dest_path.parent.mkdir(parents=True, exist_ok=True)
dest_path.write_bytes(data)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download a file as bytes.
Args:
remote_path: Path to the file in storage.
Returns:
The file contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.read_bytes()
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get a file:// URL for local file access.
For local storage, this returns a file:// URI.
Note: Local file:// URLs don't actually expire.
Args:
remote_path: Path to the file in storage.
expires_in_seconds: Ignored for local storage (URLs don't expire).
Returns:
file:// URL to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.as_uri()


@@ -0,0 +1,158 @@
"""
Storage path prefixes for unified file organization.
Provides standardized path prefixes for organizing files within
the storage backend, ensuring consistent structure across
local, Azure Blob, and S3 storage.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class StoragePrefixes:
"""Standardized storage path prefixes.
All paths are relative to the storage backend root.
These prefixes ensure consistent file organization across
all storage backends (local, Azure, S3).
Usage:
from shared.storage.prefixes import PREFIXES
path = f"{PREFIXES.DOCUMENTS}/{document_id}.pdf"
storage.upload_bytes(content, path)
"""
# Document storage
DOCUMENTS: str = "documents"
"""Original document files (PDFs, etc.)"""
IMAGES: str = "images"
"""Page images extracted from documents"""
# Processing directories
UPLOADS: str = "uploads"
"""Temporary upload staging area"""
RESULTS: str = "results"
"""Inference results and visualizations"""
EXPORTS: str = "exports"
"""Exported datasets and annotations"""
# Training data
DATASETS: str = "datasets"
"""Training dataset files"""
MODELS: str = "models"
"""Trained model weights and checkpoints"""
# Data pipeline directories (legacy compatibility)
RAW_PDFS: str = "raw_pdfs"
"""Raw PDF files for auto-labeling pipeline"""
STRUCTURED_DATA: str = "structured_data"
"""CSV/structured data for matching"""
ADMIN_IMAGES: str = "admin_images"
"""Admin UI page images"""
@staticmethod
def document_path(document_id: str, extension: str = ".pdf") -> str:
"""Get path for a document file.
Args:
document_id: Unique document identifier.
extension: File extension (include leading dot).
Returns:
Storage path like "documents/abc123.pdf"
"""
ext = extension if extension.startswith(".") else f".{extension}"
return f"{PREFIXES.DOCUMENTS}/{document_id}{ext}"
@staticmethod
def image_path(document_id: str, page_num: int, extension: str = ".png") -> str:
"""Get path for a page image file.
Args:
document_id: Unique document identifier.
page_num: Page number (1-indexed).
extension: File extension (include leading dot).
Returns:
Storage path like "images/abc123/page_1.png"
"""
ext = extension if extension.startswith(".") else f".{extension}"
return f"{PREFIXES.IMAGES}/{document_id}/page_{page_num}{ext}"
@staticmethod
def upload_path(filename: str, subfolder: str | None = None) -> str:
"""Get path for a temporary upload file.
Args:
filename: Original filename.
subfolder: Optional subfolder (e.g., "async").
Returns:
Storage path like "uploads/filename.pdf" or "uploads/async/filename.pdf"
"""
if subfolder:
return f"{PREFIXES.UPLOADS}/{subfolder}/{filename}"
return f"{PREFIXES.UPLOADS}/{filename}"
@staticmethod
def result_path(filename: str) -> str:
"""Get path for a result file.
Args:
filename: Result filename.
Returns:
Storage path like "results/filename.json"
"""
return f"{PREFIXES.RESULTS}/{filename}"
@staticmethod
def export_path(export_id: str, filename: str) -> str:
"""Get path for an export file.
Args:
export_id: Unique export identifier.
filename: Export filename.
Returns:
Storage path like "exports/abc123/filename.zip"
"""
return f"{PREFIXES.EXPORTS}/{export_id}/{filename}"
@staticmethod
def dataset_path(dataset_id: str, filename: str) -> str:
"""Get path for a dataset file.
Args:
dataset_id: Unique dataset identifier.
filename: Dataset filename.
Returns:
Storage path like "datasets/abc123/filename.yaml"
"""
return f"{PREFIXES.DATASETS}/{dataset_id}/{filename}"
@staticmethod
def model_path(version: str, filename: str) -> str:
"""Get path for a model file.
Args:
version: Model version string.
filename: Model filename.
Returns:
Storage path like "models/v1.0.0/best.pt"
"""
return f"{PREFIXES.MODELS}/{version}/{filename}"
# Default instance for convenient access
PREFIXES = StoragePrefixes()
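# Example (illustrative): StoragePrefixes.image_path("abc123", 2) returns
# "images/abc123/page_2.png".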


@@ -0,0 +1,309 @@
"""
AWS S3 Storage backend.
Provides storage operations using AWS S3.
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class S3StorageBackend(StorageBackend):
"""Storage backend using AWS S3.
Files are stored as objects in an S3 bucket.
"""
def __init__(
self,
bucket_name: str,
region_name: str | None = None,
access_key_id: str | None = None,
secret_access_key: str | None = None,
endpoint_url: str | None = None,
create_bucket: bool = False,
) -> None:
"""Initialize S3 storage backend.
Args:
bucket_name: Name of the S3 bucket.
region_name: AWS region name (optional, uses default if not set).
access_key_id: AWS access key ID (optional, uses credentials chain).
secret_access_key: AWS secret access key (optional).
endpoint_url: Custom endpoint URL (for S3-compatible services).
create_bucket: If True, create the bucket if it doesn't exist.
"""
import boto3
self._bucket_name = bucket_name
self._region_name = region_name
# Build client kwargs
client_kwargs: dict[str, Any] = {}
if region_name:
client_kwargs["region_name"] = region_name
if endpoint_url:
client_kwargs["endpoint_url"] = endpoint_url
if access_key_id and secret_access_key:
client_kwargs["aws_access_key_id"] = access_key_id
client_kwargs["aws_secret_access_key"] = secret_access_key
self._s3: "S3Client" = boto3.client("s3", **client_kwargs)
if create_bucket:
self._ensure_bucket_exists()
def _ensure_bucket_exists(self) -> None:
"""Create the bucket if it doesn't exist."""
from botocore.exceptions import ClientError
try:
self._s3.head_bucket(Bucket=self._bucket_name)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchBucket"):
# Bucket doesn't exist, create it
create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
if self._region_name and self._region_name != "us-east-1":
create_kwargs["CreateBucketConfiguration"] = {
"LocationConstraint": self._region_name
}
self._s3.create_bucket(**create_kwargs)
else:
# Re-raise permission errors, network issues, etc.
raise
def _object_exists(self, key: str) -> bool:
"""Check if an object exists in S3.
Args:
key: Object key to check.
Returns:
True if object exists, False otherwise.
"""
from botocore.exceptions import ClientError
try:
self._s3.head_object(Bucket=self._bucket_name, Key=key)
return True
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchKey"):
return False
raise
@property
def bucket_name(self) -> str:
"""Get the bucket name for this storage backend."""
return self._bucket_name
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to S3.
Args:
local_path: Path to the local file to upload.
remote_path: Destination object key.
overwrite: If True, overwrite existing object.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If object exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
if not overwrite and self._object_exists(remote_path):
raise StorageError(f"File already exists: {remote_path}")
self._s3.upload_file(str(local_path), self._bucket_name, remote_path)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download an object from S3.
Args:
remote_path: Object key in S3.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
self._s3.download_file(self._bucket_name, remote_path, str(local_path))
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if an object exists in S3.
Args:
remote_path: Object key to check.
Returns:
True if the object exists, False otherwise.
"""
return self._object_exists(remote_path)
def list_files(self, prefix: str) -> list[str]:
"""List objects in S3 with given prefix.
Handles pagination to return all matching objects (S3 returns max 1000 per request).
Args:
prefix: Object key prefix to filter.
Returns:
List of object keys matching the prefix.
"""
kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
if prefix:
kwargs["Prefix"] = prefix
all_keys: list[str] = []
while True:
response = self._s3.list_objects_v2(**kwargs)
contents = response.get("Contents", [])
all_keys.extend(obj["Key"] for obj in contents)
if not response.get("IsTruncated"):
break
kwargs["ContinuationToken"] = response["NextContinuationToken"]
return all_keys
def delete(self, remote_path: str) -> bool:
"""Delete an object from S3.
Args:
remote_path: Object key to delete.
Returns:
True if object was deleted, False if it didn't exist.
"""
if not self._object_exists(remote_path):
return False
self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path)
return True
def get_url(self, remote_path: str) -> str:
"""Get a URL for an object.
Args:
remote_path: Object key in S3.
Returns:
URL to access the object.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
return self._s3.generate_presigned_url(
"get_object",
Params={"Bucket": self._bucket_name, "Key": remote_path},
ExpiresIn=3600,
)
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a pre-signed URL for temporary object access.
Args:
remote_path: Object key in S3.
expires_in_seconds: URL validity duration (1 to 604800 seconds / 7 days).
Returns:
Pre-signed URL string.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
ValueError: If expires_in_seconds is out of valid range.
"""
if expires_in_seconds < 1 or expires_in_seconds > 604800:
raise ValueError(
"expires_in_seconds must be between 1 and 604800 (7 days)"
)
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
return self._s3.generate_presigned_url(
"get_object",
Params={"Bucket": self._bucket_name, "Key": remote_path},
ExpiresIn=expires_in_seconds,
)
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to S3.
Args:
data: Bytes to upload.
remote_path: Destination object key.
overwrite: If True, overwrite existing object.
Returns:
The remote path where the data was stored.
Raises:
StorageError: If object exists and overwrite is False.
"""
if not overwrite and self._object_exists(remote_path):
raise StorageError(f"File already exists: {remote_path}")
self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download an object as bytes.
Args:
remote_path: Object key in S3.
Returns:
The object contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
from botocore.exceptions import ClientError
try:
response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path)
return response["Body"].read()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchKey"):
raise FileNotFoundStorageError(remote_path) from e
raise