WIP

packages/shared/README.md (new file)
@@ -0,0 +1,205 @@
# Shared Package

Shared utilities and abstractions for the Invoice Master system.

## Storage Abstraction Layer

A unified storage abstraction supporting multiple backends:

- **Local filesystem** - development and testing
- **Azure Blob Storage** - Azure cloud deployments
- **AWS S3** - AWS cloud deployments

### Installation

```bash
# Basic installation (local storage only)
pip install -e packages/shared

# With Azure support
pip install -e "packages/shared[azure]"

# With S3 support
pip install -e "packages/shared[s3]"

# All cloud providers
pip install -e "packages/shared[all]"
```

### Quick Start

```python
from pathlib import Path

from shared.storage import get_storage_backend

# Option 1: From a configuration file
storage = get_storage_backend("storage.yaml")

# Option 2: From environment variables
from shared.storage import create_storage_backend_from_env
storage = create_storage_backend_from_env()

# Upload a file
storage.upload(Path("local/file.pdf"), "documents/file.pdf")

# Download a file
storage.download("documents/file.pdf", Path("local/downloaded.pdf"))

# Get a pre-signed URL for frontend access
url = storage.get_presigned_url("documents/file.pdf", expires_in_seconds=3600)
```

### Configuration File Format

Create a `storage.yaml` file; string values support environment variable substitution:

```yaml
# Backend selection: local, azure_blob, or s3
backend: ${STORAGE_BACKEND:-local}

# Default pre-signed URL expiry (seconds)
presigned_url_expiry: 3600

# Local storage configuration
local:
  base_path: ${STORAGE_BASE_PATH:-./data/storage}

# Azure Blob Storage configuration
azure:
  connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
  container_name: ${AZURE_STORAGE_CONTAINER:-documents}
  create_container: false

# AWS S3 configuration
s3:
  bucket_name: ${AWS_S3_BUCKET}
  region_name: ${AWS_REGION:-us-east-1}
  access_key_id: ${AWS_ACCESS_KEY_ID}
  secret_access_key: ${AWS_SECRET_ACCESS_KEY}
  endpoint_url: ${AWS_ENDPOINT_URL}  # Optional, for S3-compatible services
  create_bucket: false
```
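
Configuration files are parsed by `load_storage_config`, which applies the `${VAR}` / `${VAR:-default}` substitution before validating each section. A minimal sketch of using the loader directly (assuming the `storage.yaml` above is on disk):

```python
import os

from shared.storage import load_storage_config  # lazily re-exported from config_loader

os.environ["STORAGE_BACKEND"] = "local"
config = load_storage_config("storage.yaml")

print(config.backend_type)     # "local"
print(config.local.base_path)  # data/storage (unless STORAGE_BASE_PATH is set)
```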

### Environment Variables

| Variable | Backend | Description |
|----------|---------|-------------|
| `STORAGE_BACKEND` | All | Backend type: `local`, `azure_blob`, or `s3` |
| `STORAGE_BASE_PATH` | Local | Base directory path |
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | Connection string |
| `AZURE_STORAGE_CONTAINER` | Azure | Container name |
| `AWS_S3_BUCKET` | S3 | Bucket name |
| `AWS_REGION` | S3 | AWS region (default: `us-east-1`) |
| `AWS_ACCESS_KEY_ID` | S3 | Access key (optional, uses credential chain) |
| `AWS_SECRET_ACCESS_KEY` | S3 | Secret key (optional) |
| `AWS_ENDPOINT_URL` | S3 | Custom endpoint for S3-compatible services |
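
For example, selecting the Azure backend entirely through the environment might look like this (the connection string below is a placeholder):

```bash
export STORAGE_BACKEND=azure_blob
export AZURE_STORAGE_CONNECTION_STRING="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net"
export AZURE_STORAGE_CONTAINER=documents
```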

### API Reference

#### StorageBackend Interface

```python
class StorageBackend(ABC):
    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        """Upload a file to storage."""

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download a file from storage."""

    def exists(self, remote_path: str) -> bool:
        """Check if a file exists."""

    def list_files(self, prefix: str) -> list[str]:
        """List files with given prefix."""

    def delete(self, remote_path: str) -> bool:
        """Delete a file."""

    def get_url(self, remote_path: str) -> str:
        """Get URL for a file."""

    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        """Generate a pre-signed URL for temporary access (1-604800 seconds)."""

    def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str:
        """Upload bytes directly."""

    def download_bytes(self, remote_path: str) -> bytes:
        """Download file as bytes."""
```
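
Because every backend implements this interface, application code can stay backend-agnostic. A minimal sketch (`archive_report` is a hypothetical helper, not part of the package):

```python
from pathlib import Path

from shared.storage import StorageBackend, get_storage_backend


def archive_report(storage: StorageBackend, report: Path) -> str:
    """Store a report under a fixed prefix, regardless of backend."""
    return storage.upload(report, f"reports/{report.name}", overwrite=True)


storage = get_storage_backend()  # backend chosen by environment variables
archive_report(storage, Path("reports/2024-01.pdf"))
```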

#### Factory Functions

```python
# Create from a configuration file
storage = create_storage_backend_from_file("storage.yaml")

# Create from environment variables
storage = create_storage_backend_from_env()

# Create from a StorageConfig object
config = StorageConfig(backend_type="local", base_path=Path("./data"))
storage = create_storage_backend(config)

# Convenience function with fallback chain: config file -> env vars
storage = get_storage_backend("storage.yaml")  # pass None to use env vars only
```

### Pre-signed URLs

Pre-signed URLs provide temporary access to files without exposing credentials:

```python
# Generate a URL valid for 1 hour (default)
url = storage.get_presigned_url("documents/invoice.pdf")

# Generate a URL valid for 24 hours
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=86400)

# Maximum expiry: 7 days (604800 seconds)
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=604800)
```

**Note:** Local storage returns `file://` URLs, which don't actually expire.

### Error Handling

```python
from shared.storage import (
    StorageError,
    FileNotFoundStorageError,
    PresignedUrlNotSupportedError,
)

try:
    storage.download("nonexistent.pdf", Path("local.pdf"))
except FileNotFoundStorageError as e:
    print(f"File not found: {e}")
except StorageError as e:
    print(f"Storage error: {e}")
```

### Testing with MinIO (S3-compatible)

```bash
# Start MinIO locally
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"

# Configure environment
export STORAGE_BACKEND=s3
export AWS_S3_BUCKET=test-bucket
export AWS_ENDPOINT_URL=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
```
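
With those variables set, a quick round trip verifies the wiring (this sketch assumes the `test-bucket` bucket already exists, e.g. created through the MinIO console):

```python
from shared.storage import create_storage_backend_from_env

storage = create_storage_backend_from_env()
storage.upload_bytes(b"hello", "smoke-test/hello.txt", overwrite=True)
assert storage.download_bytes("smoke-test/hello.txt") == b"hello"
print(storage.get_presigned_url("smoke-test/hello.txt", expires_in_seconds=60))
```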

### Module Structure

```
shared/storage/
├── __init__.py       # Public exports
├── base.py           # Abstract interface and exceptions
├── local.py          # Local filesystem backend
├── azure.py          # Azure Blob Storage backend
├── s3.py             # AWS S3 backend
├── config_loader.py  # YAML configuration loader
└── factory.py        # Backend factory functions
```

@@ -16,4 +16,18 @@ setup(
         "pyyaml>=6.0",
         "thefuzz>=0.20.0",
     ],
+    extras_require={
+        "azure": [
+            "azure-storage-blob>=12.19.0",
+            "azure-identity>=1.15.0",
+        ],
+        "s3": [
+            "boto3>=1.34.0",
+        ],
+        "all": [
+            "azure-storage-blob>=12.19.0",
+            "azure-identity>=1.15.0",
+            "boto3>=1.34.0",
+        ],
+    },
 )

@@ -58,23 +58,16 @@ def get_db_connection_string():
     return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
 
 
-# Paths Configuration - auto-detect WSL vs Windows
-if _is_wsl():
-    # WSL: use native Linux filesystem for better I/O performance
-    PATHS = {
-        'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
-        'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
-        'output_dir': os.path.expanduser('~/invoice-data/dataset'),
-        'reports_dir': 'reports',  # Keep reports in project directory
-    }
-else:
-    # Windows or native Linux: use relative paths
-    PATHS = {
-        'csv_dir': 'data/structured_data',
-        'pdf_dir': 'data/raw_pdfs',
-        'output_dir': 'data/dataset',
-        'reports_dir': 'reports',
-    }
+# Paths Configuration - uses STORAGE_BASE_PATH for consistency
+# All paths are relative to STORAGE_BASE_PATH (defaults to ~/invoice-data/data)
+_storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data'))
+
+PATHS = {
+    'csv_dir': f'{_storage_base}/structured_data',
+    'pdf_dir': f'{_storage_base}/raw_pdfs',
+    'output_dir': f'{_storage_base}/datasets',
+    'reports_dir': 'reports',  # Keep reports in project directory
+}
 
 # Auto-labeling Configuration
 AUTOLABEL = {

packages/shared/shared/fields/__init__.py (new file)
@@ -0,0 +1,46 @@
"""
Shared Field Definitions - Single Source of Truth.

This module provides centralized field class definitions used throughout
the invoice extraction system. All field mappings are derived from
FIELD_DEFINITIONS to ensure consistency.

Usage:
    from shared.fields import FIELD_CLASSES, CLASS_NAMES, FIELD_CLASS_IDS

Available exports:
    - FieldDefinition: Dataclass for a field definition
    - FIELD_DEFINITIONS: Tuple of all field definitions (immutable)
    - NUM_CLASSES: Total number of field classes (10)
    - CLASS_NAMES: List of class names in order [0..9]
    - FIELD_CLASSES: dict[int, str] - class_id to class_name
    - FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
    - CLASS_TO_FIELD: dict[str, str] - class_name to field_name
    - CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
    - TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
    - ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
"""

from .field_config import FieldDefinition, FIELD_DEFINITIONS, NUM_CLASSES
from .mappings import (
    CLASS_NAMES,
    FIELD_CLASSES,
    FIELD_CLASS_IDS,
    CLASS_TO_FIELD,
    CSV_TO_CLASS_MAPPING,
    TRAINING_FIELD_CLASSES,
    ACCOUNT_FIELD_MAPPING,
)

__all__ = [
    "FieldDefinition",
    "FIELD_DEFINITIONS",
    "NUM_CLASSES",
    "CLASS_NAMES",
    "FIELD_CLASSES",
    "FIELD_CLASS_IDS",
    "CLASS_TO_FIELD",
    "CSV_TO_CLASS_MAPPING",
    "TRAINING_FIELD_CLASSES",
    "ACCOUNT_FIELD_MAPPING",
]

packages/shared/shared/fields/field_config.py (new file)
@@ -0,0 +1,58 @@
"""
Field Configuration - Single Source of Truth.

This module defines all invoice field classes used throughout the system.
The class IDs are verified against the trained YOLO model (best.pt).

IMPORTANT: Do not modify class_id values without retraining the model!
"""

from dataclasses import dataclass
from typing import Final


@dataclass(frozen=True)
class FieldDefinition:
    """Immutable field definition for invoice extraction.

    Attributes:
        class_id: YOLO class ID (0-9), must match the trained model
        class_name: YOLO class name (lowercase_underscore)
        field_name: Business field name used in API responses
        csv_name: CSV column name for data import/export
        is_derived: True if field is derived from other fields (not in CSV)
    """

    class_id: int
    class_name: str
    field_name: str
    csv_name: str
    is_derived: bool = False


# Verified from model weights (runs/train/invoice_fields/weights/best.pt)
# model.names = {0: 'invoice_number', 1: 'invoice_date', ..., 8: 'customer_number', 9: 'payment_line'}
#
# DO NOT CHANGE THE ORDER - it must match the trained model!
FIELD_DEFINITIONS: Final[tuple[FieldDefinition, ...]] = (
    FieldDefinition(0, "invoice_number", "InvoiceNumber", "InvoiceNumber"),
    FieldDefinition(1, "invoice_date", "InvoiceDate", "InvoiceDate"),
    FieldDefinition(2, "invoice_due_date", "InvoiceDueDate", "InvoiceDueDate"),
    FieldDefinition(3, "ocr_number", "OCR", "OCR"),
    FieldDefinition(4, "bankgiro", "Bankgiro", "Bankgiro"),
    FieldDefinition(5, "plusgiro", "Plusgiro", "Plusgiro"),
    FieldDefinition(6, "amount", "Amount", "Amount"),
    FieldDefinition(
        7,
        "supplier_org_number",
        "supplier_organisation_number",
        "supplier_organisation_number",
    ),
    FieldDefinition(8, "customer_number", "customer_number", "customer_number"),
    FieldDefinition(
        9, "payment_line", "payment_line", "payment_line", is_derived=True
    ),
)

# Total number of field classes
NUM_CLASSES: Final[int] = len(FIELD_DEFINITIONS)

packages/shared/shared/fields/mappings.py (new file)
@@ -0,0 +1,57 @@
"""
Field Mappings - Auto-generated from FIELD_DEFINITIONS.

All mappings in this file are derived from field_config.FIELD_DEFINITIONS.
This ensures consistency across the entire codebase.

DO NOT hardcode field mappings elsewhere - always import from this module.
"""

from typing import Final

from .field_config import FIELD_DEFINITIONS


# List of class names in order (for YOLO classes.txt generation)
# Index matches class_id: CLASS_NAMES[0] = "invoice_number"
CLASS_NAMES: Final[list[str]] = [fd.class_name for fd in FIELD_DEFINITIONS]

# class_id -> class_name mapping
# Example: {0: "invoice_number", 1: "invoice_date", ...}
FIELD_CLASSES: Final[dict[int, str]] = {
    fd.class_id: fd.class_name for fd in FIELD_DEFINITIONS
}

# class_name -> class_id mapping (reverse of FIELD_CLASSES)
# Example: {"invoice_number": 0, "invoice_date": 1, ...}
FIELD_CLASS_IDS: Final[dict[str, int]] = {
    fd.class_name: fd.class_id for fd in FIELD_DEFINITIONS
}

# class_name -> field_name mapping (for API responses)
# Example: {"invoice_number": "InvoiceNumber", "ocr_number": "OCR", ...}
CLASS_TO_FIELD: Final[dict[str, str]] = {
    fd.class_name: fd.field_name for fd in FIELD_DEFINITIONS
}

# field_name -> class_id mapping (for CSV import)
# Excludes derived fields like payment_line
# Example: {"InvoiceNumber": 0, "InvoiceDate": 1, ...}
CSV_TO_CLASS_MAPPING: Final[dict[str, int]] = {
    fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS if not fd.is_derived
}

# field_name -> class_id mapping (for training, includes all fields)
# Example: {"InvoiceNumber": 0, ..., "payment_line": 9}
TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
    fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
}

# Account field mapping for supplier_accounts special handling
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
    "supplier_accounts": {
        "BG": "Bankgiro",
        "PG": "Plusgiro",
    }
}
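
As a quick illustration of how the derived mappings line up (a sketch, not part of the commit):

```python
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS, CLASS_TO_FIELD, CSV_TO_CLASS_MAPPING

assert CLASS_NAMES[3] == "ocr_number"
assert FIELD_CLASS_IDS["ocr_number"] == 3
assert CLASS_TO_FIELD["ocr_number"] == "OCR"

# Derived fields such as payment_line are excluded from CSV import:
assert "payment_line" not in CSV_TO_CLASS_MAPPING
```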

packages/shared/shared/storage/__init__.py (new file)
@@ -0,0 +1,59 @@
"""
Storage abstraction layer for training data.

Provides a unified interface for local filesystem, Azure Blob Storage, and AWS S3.
"""

from shared.storage.base import (
    FileNotFoundStorageError,
    PresignedUrlNotSupportedError,
    StorageBackend,
    StorageConfig,
    StorageError,
)
from shared.storage.factory import (
    create_storage_backend,
    create_storage_backend_from_env,
    create_storage_backend_from_file,
    get_default_storage_config,
    get_storage_backend,
)
from shared.storage.local import LocalStorageBackend
from shared.storage.prefixes import PREFIXES, StoragePrefixes

__all__ = [
    # Base classes and exceptions
    "StorageBackend",
    "StorageConfig",
    "StorageError",
    "FileNotFoundStorageError",
    "PresignedUrlNotSupportedError",
    # Backends
    "LocalStorageBackend",
    # Factory functions
    "create_storage_backend",
    "create_storage_backend_from_env",
    "create_storage_backend_from_file",
    "get_default_storage_config",
    "get_storage_backend",
    # Path prefixes
    "PREFIXES",
    "StoragePrefixes",
]


# Lazy imports to avoid dependencies when not using specific backends
def __getattr__(name: str):
    if name == "AzureBlobStorageBackend":
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend
    if name == "S3StorageBackend":
        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend
    if name == "load_storage_config":
        from shared.storage.config_loader import load_storage_config

        return load_storage_config
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
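
The module-level `__getattr__` (PEP 562) keeps the cloud SDKs optional: `boto3` or `azure-storage-blob` is only imported when the corresponding backend name is actually requested. A sketch of the effect:

```python
import shared.storage

# No cloud SDK has been imported at this point.
backend_cls = shared.storage.S3StorageBackend  # triggers the lazy, boto3-backed import
```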

packages/shared/shared/storage/azure.py (new file)
@@ -0,0 +1,335 @@
"""
Azure Blob Storage backend.

Provides storage operations using Azure Blob Storage.
"""

from pathlib import Path

from azure.storage.blob import (
    BlobSasPermissions,
    BlobServiceClient,
    ContainerClient,
    generate_blob_sas,
)

from shared.storage.base import (
    FileNotFoundStorageError,
    StorageBackend,
    StorageError,
)


class AzureBlobStorageBackend(StorageBackend):
    """Storage backend using Azure Blob Storage.

    Files are stored as blobs in an Azure Blob Storage container.
    """

    def __init__(
        self,
        connection_string: str,
        container_name: str,
        create_container: bool = False,
    ) -> None:
        """Initialize the Azure Blob Storage backend.

        Args:
            connection_string: Azure Storage connection string.
            container_name: Name of the blob container.
            create_container: If True, create the container if it doesn't exist.
        """
        self._connection_string = connection_string
        self._container_name = container_name

        self._blob_service = BlobServiceClient.from_connection_string(connection_string)
        self._container = self._blob_service.get_container_client(container_name)

        # Extract account key from connection string for SAS token generation
        self._account_key = self._extract_account_key(connection_string)

        if create_container and not self._container.exists():
            self._container.create_container()

    @staticmethod
    def _extract_account_key(connection_string: str) -> str | None:
        """Extract the account key from a connection string.

        Args:
            connection_string: Azure Storage connection string.

        Returns:
            Account key if found, None otherwise.
        """
        for part in connection_string.split(";"):
            if part.startswith("AccountKey="):
                return part[len("AccountKey=") :]
        return None

    @property
    def container_name(self) -> str:
        """Get the container name for this storage backend."""
        return self._container_name

    @property
    def container_client(self) -> ContainerClient:
        """Get the Azure container client."""
        return self._container

    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to Azure Blob Storage.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination blob path.
            overwrite: If True, overwrite an existing blob.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If the blob exists and overwrite is False.
        """
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))

        blob_client = self._container.get_blob_client(remote_path)

        if blob_client.exists() and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")

        with open(local_path, "rb") as f:
            blob_client.upload_blob(f, overwrite=overwrite)

        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download a blob from Azure Blob Storage.

        Args:
            remote_path: Blob path in storage.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        local_path.parent.mkdir(parents=True, exist_ok=True)

        stream = blob_client.download_blob()
        local_path.write_bytes(stream.readall())

        return local_path

    def exists(self, remote_path: str) -> bool:
        """Check if a blob exists in storage.

        Args:
            remote_path: Blob path to check.

        Returns:
            True if the blob exists, False otherwise.
        """
        blob_client = self._container.get_blob_client(remote_path)
        return blob_client.exists()

    def list_files(self, prefix: str) -> list[str]:
        """List blobs in storage with the given prefix.

        Args:
            prefix: Blob path prefix to filter on.

        Returns:
            List of blob paths matching the prefix.
        """
        if prefix:
            blobs = self._container.list_blobs(name_starts_with=prefix)
        else:
            blobs = self._container.list_blobs()

        return [blob.name for blob in blobs]

    def delete(self, remote_path: str) -> bool:
        """Delete a blob from storage.

        Args:
            remote_path: Blob path to delete.

        Returns:
            True if the blob was deleted, False if it didn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            return False

        blob_client.delete_blob()
        return True

    def get_url(self, remote_path: str) -> str:
        """Get the URL for a blob.

        Args:
            remote_path: Blob path in storage.

        Returns:
            URL to access the blob.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        return blob_client.url

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to Azure Blob Storage.

        Args:
            data: Bytes to upload.
            remote_path: Destination blob path.
            overwrite: If True, overwrite an existing blob.

        Returns:
            The remote path where the data was stored.

        Raises:
            StorageError: If the blob exists and overwrite is False.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if blob_client.exists() and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")

        blob_client.upload_blob(data, overwrite=overwrite)

        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        """Download a blob as bytes.

        Args:
            remote_path: Blob path in storage.

        Returns:
            The blob contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        stream = blob_client.download_blob()
        return stream.readall()

    def upload_directory(
        self, local_dir: Path, remote_prefix: str, overwrite: bool = False
    ) -> list[str]:
        """Upload all files in a directory to Azure Blob Storage.

        Args:
            local_dir: Local directory to upload.
            remote_prefix: Prefix for remote blob paths.
            overwrite: If True, overwrite existing blobs.

        Returns:
            List of remote paths where files were stored.
        """
        uploaded: list[str] = []

        for file_path in local_dir.rglob("*"):
            if file_path.is_file():
                relative_path = file_path.relative_to(local_dir)
                remote_path = f"{remote_prefix}{relative_path}".replace("\\", "/")
                self.upload(file_path, remote_path, overwrite=overwrite)
                uploaded.append(remote_path)

        return uploaded

    def download_directory(
        self, remote_prefix: str, local_dir: Path
    ) -> list[Path]:
        """Download all blobs with a prefix to a local directory.

        Args:
            remote_prefix: Blob path prefix to download.
            local_dir: Local directory to download into.

        Returns:
            List of local paths where files were downloaded.
        """
        downloaded: list[Path] = []

        blobs = self.list_files(remote_prefix)

        for blob_path in blobs:
            # Remove prefix to get relative path
            if remote_prefix:
                relative_path = blob_path[len(remote_prefix):]
                if relative_path.startswith("/"):
                    relative_path = relative_path[1:]
            else:
                relative_path = blob_path

            local_path = local_dir / relative_path
            self.download(blob_path, local_path)
            downloaded.append(local_path)

        return downloaded

    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a SAS URL for temporary blob access.

        Args:
            remote_path: Blob path in storage.
            expires_in_seconds: SAS token validity duration (1 to 604800 seconds / 7 days).

        Returns:
            Blob URL with SAS token.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            ValueError: If expires_in_seconds is out of the valid range.
        """
        if expires_in_seconds < 1 or expires_in_seconds > 604800:
            raise ValueError(
                "expires_in_seconds must be between 1 and 604800 (7 days)"
            )

        from datetime import datetime, timedelta, timezone

        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        # Generate SAS token
        sas_token = generate_blob_sas(
            account_name=self._blob_service.account_name,
            container_name=self._container_name,
            blob_name=remote_path,
            account_key=self._account_key,
            permission=BlobSasPermissions(read=True),
            expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
        )

        return f"{blob_client.url}?{sas_token}"

packages/shared/shared/storage/base.py (new file)
@@ -0,0 +1,229 @@
"""
Base classes and interfaces for storage backends.

Defines the abstract StorageBackend interface and common exceptions.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path


class StorageError(Exception):
    """Base exception for storage operations."""

    pass


class FileNotFoundStorageError(StorageError):
    """Raised when a file is not found in storage."""

    def __init__(self, path: str) -> None:
        self.path = path
        super().__init__(f"File not found in storage: {path}")


class PresignedUrlNotSupportedError(StorageError):
    """Raised when pre-signed URLs are not supported by a backend."""

    def __init__(self, backend_type: str) -> None:
        self.backend_type = backend_type
        super().__init__(f"Pre-signed URLs not supported for backend: {backend_type}")


@dataclass(frozen=True)
class StorageConfig:
    """Configuration for a storage backend.

    Attributes:
        backend_type: Type of storage backend ("local", "azure_blob", or "s3").
        connection_string: Azure Blob Storage connection string (for azure_blob).
        container_name: Azure Blob Storage container name (for azure_blob).
        base_path: Base path for local storage (for local).
        bucket_name: S3 bucket name (for s3).
        region_name: AWS region name (for s3).
        access_key_id: AWS access key ID (for s3).
        secret_access_key: AWS secret access key (for s3).
        endpoint_url: Custom endpoint URL for S3-compatible services (for s3).
        presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
    """

    backend_type: str
    connection_string: str | None = None
    container_name: str | None = None
    base_path: Path | None = None
    bucket_name: str | None = None
    region_name: str | None = None
    access_key_id: str | None = None
    secret_access_key: str | None = None
    endpoint_url: str | None = None
    presigned_url_expiry: int = 3600


class StorageBackend(ABC):
    """Abstract base class for storage backends.

    Provides a unified interface for storing and retrieving files
    from different storage systems (local filesystem, Azure Blob, etc.).
    """

    @abstractmethod
    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to storage.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination path in storage.
            overwrite: If True, overwrite an existing file.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If the file exists and overwrite is False.
        """
        pass

    @abstractmethod
    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download a file from storage.

        Args:
            remote_path: Path to the file in storage.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        pass

    @abstractmethod
    def exists(self, remote_path: str) -> bool:
        """Check if a file exists in storage.

        Args:
            remote_path: Path to check in storage.

        Returns:
            True if the file exists, False otherwise.
        """
        pass

    @abstractmethod
    def list_files(self, prefix: str) -> list[str]:
        """List files in storage with the given prefix.

        Args:
            prefix: Path prefix to filter files.

        Returns:
            List of file paths matching the prefix.
        """
        pass

    @abstractmethod
    def delete(self, remote_path: str) -> bool:
        """Delete a file from storage.

        Args:
            remote_path: Path to the file to delete.

        Returns:
            True if the file was deleted, False if it didn't exist.
        """
        pass

    @abstractmethod
    def get_url(self, remote_path: str) -> str:
        """Get a URL or path to access a file.

        Args:
            remote_path: Path to the file in storage.

        Returns:
            URL or path to access the file.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        pass

    @abstractmethod
    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a pre-signed URL for temporary access.

        Args:
            remote_path: Path to the file in storage.
            expires_in_seconds: URL validity duration (default 1 hour).

        Returns:
            Pre-signed URL string.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            PresignedUrlNotSupportedError: If the backend doesn't support pre-signed URLs.
        """
        pass

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to storage.

        Default implementation writes to a temp file, then uploads.
        Subclasses may override for a more efficient implementation.

        Args:
            data: Bytes to upload.
            remote_path: Destination path in storage.
            overwrite: If True, overwrite an existing file.

        Returns:
            The remote path where the data was stored.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(data)
            temp_path = Path(f.name)

        try:
            return self.upload(temp_path, remote_path, overwrite=overwrite)
        finally:
            temp_path.unlink(missing_ok=True)

    def download_bytes(self, remote_path: str) -> bytes:
        """Download a file as bytes.

        Default implementation downloads to a temp file, then reads it.
        Subclasses may override for a more efficient implementation.

        Args:
            remote_path: Path to the file in storage.

        Returns:
            The file contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(delete=False) as f:
            temp_path = Path(f.name)

        try:
            self.download(remote_path, temp_path)
            return temp_path.read_bytes()
        finally:
            temp_path.unlink(missing_ok=True)

packages/shared/shared/storage/config_loader.py (new file)
@@ -0,0 +1,242 @@
"""
Configuration file loader for storage backends.

Supports YAML configuration files with environment variable substitution.
"""

import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import yaml


@dataclass(frozen=True)
class LocalConfig:
    """Local storage backend configuration."""

    base_path: Path


@dataclass(frozen=True)
class AzureConfig:
    """Azure Blob Storage configuration."""

    connection_string: str
    container_name: str
    create_container: bool = False


@dataclass(frozen=True)
class S3Config:
    """AWS S3 configuration."""

    bucket_name: str
    region_name: str | None = None
    access_key_id: str | None = None
    secret_access_key: str | None = None
    endpoint_url: str | None = None
    create_bucket: bool = False


@dataclass(frozen=True)
class StorageFileConfig:
    """Extended storage configuration from file.

    Attributes:
        backend_type: Type of storage backend.
        local: Local backend configuration.
        azure: Azure Blob configuration.
        s3: S3 configuration.
        presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
    """

    backend_type: str
    local: LocalConfig | None = None
    azure: AzureConfig | None = None
    s3: S3Config | None = None
    presigned_url_expiry: int = 3600


def substitute_env_vars(value: str) -> str:
    """Substitute environment variables in a string.

    Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax.

    Args:
        value: String potentially containing env var references.

    Returns:
        String with env vars substituted.
    """
    pattern = r"\$\{([A-Z_][A-Z0-9_]*)(?::-([^}]*))?\}"

    def replace(match: re.Match[str]) -> str:
        var_name = match.group(1)
        default = match.group(2)
        return os.environ.get(var_name, default or "")

    return re.sub(pattern, replace, value)


def _substitute_in_dict(data: dict[str, Any]) -> dict[str, Any]:
    """Recursively substitute env vars in a dictionary.

    Args:
        data: Dictionary to process.

    Returns:
        New dictionary with substitutions applied.
    """
    result: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, str):
            result[key] = substitute_env_vars(value)
        elif isinstance(value, dict):
            result[key] = _substitute_in_dict(value)
        elif isinstance(value, list):
            result[key] = [
                substitute_env_vars(item) if isinstance(item, str) else item
                for item in value
            ]
        else:
            result[key] = value
    return result


def _parse_local_config(data: dict[str, Any]) -> LocalConfig:
    """Parse the local configuration section.

    Args:
        data: Dictionary containing local config.

    Returns:
        LocalConfig instance.

    Raises:
        ValueError: If required fields are missing.
    """
    base_path = data.get("base_path")
    if not base_path:
        raise ValueError("local.base_path is required")
    return LocalConfig(base_path=Path(base_path))


def _parse_azure_config(data: dict[str, Any]) -> AzureConfig:
    """Parse the Azure configuration section.

    Args:
        data: Dictionary containing Azure config.

    Returns:
        AzureConfig instance.

    Raises:
        ValueError: If required fields are missing.
    """
    connection_string = data.get("connection_string")
    container_name = data.get("container_name")

    if not connection_string:
        raise ValueError("azure.connection_string is required")
    if not container_name:
        raise ValueError("azure.container_name is required")

    return AzureConfig(
        connection_string=connection_string,
        container_name=container_name,
        create_container=data.get("create_container", False),
    )


def _parse_s3_config(data: dict[str, Any]) -> S3Config:
    """Parse the S3 configuration section.

    Args:
        data: Dictionary containing S3 config.

    Returns:
        S3Config instance.

    Raises:
        ValueError: If required fields are missing.
    """
    bucket_name = data.get("bucket_name")

    if not bucket_name:
        raise ValueError("s3.bucket_name is required")

    return S3Config(
        bucket_name=bucket_name,
        region_name=data.get("region_name"),
        access_key_id=data.get("access_key_id"),
        secret_access_key=data.get("secret_access_key"),
        endpoint_url=data.get("endpoint_url"),
        create_bucket=data.get("create_bucket", False),
    )


def load_storage_config(config_path: Path | str) -> StorageFileConfig:
    """Load storage configuration from a YAML file.

    Supports environment variable substitution using ${VAR_NAME} or
    ${VAR_NAME:-default} syntax.

    Args:
        config_path: Path to the configuration file.

    Returns:
        Parsed StorageFileConfig.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        ValueError: If the config is invalid.
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    try:
        raw_content = config_path.read_text(encoding="utf-8")
        data = yaml.safe_load(raw_content)
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in config file: {e}") from e

    if not isinstance(data, dict):
        raise ValueError("Config file must contain a YAML dictionary")

    # Substitute environment variables
    data = _substitute_in_dict(data)

    # Extract backend type
    backend_type = data.get("backend")
    if not backend_type:
        raise ValueError("'backend' field is required in config file")

    # Parse presigned URL expiry
    presigned_url_expiry = data.get("presigned_url_expiry", 3600)

    # Parse backend-specific configurations
    local_config = None
    azure_config = None
    s3_config = None

    if "local" in data:
        local_config = _parse_local_config(data["local"])

    if "azure" in data:
        azure_config = _parse_azure_config(data["azure"])

    if "s3" in data:
        s3_config = _parse_s3_config(data["s3"])

    return StorageFileConfig(
        backend_type=backend_type,
        local=local_config,
        azure=azure_config,
        s3=s3_config,
        presigned_url_expiry=presigned_url_expiry,
    )
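
The substitution helper can be exercised on its own; a small sketch of the supported forms (names here are arbitrary):

```python
import os

from shared.storage.config_loader import substitute_env_vars

os.environ["AZURE_STORAGE_CONTAINER"] = "invoices"

assert substitute_env_vars("${AZURE_STORAGE_CONTAINER}") == "invoices"
assert substitute_env_vars("${NOT_SET_ANYWHERE:-fallback}") == "fallback"
assert substitute_env_vars("${NOT_SET_ANYWHERE}") == ""  # unset without default -> empty string
```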

packages/shared/shared/storage/factory.py (new file)
@@ -0,0 +1,296 @@
"""
Factory functions for creating storage backends.

Provides convenient functions for creating storage backends from
configuration or environment variables.
"""

import os
from pathlib import Path

from shared.storage.base import StorageBackend, StorageConfig


def create_storage_backend(config: StorageConfig) -> StorageBackend:
    """Create a storage backend from configuration.

    Args:
        config: Storage configuration.

    Returns:
        A configured storage backend.

    Raises:
        ValueError: If the configuration is invalid.
    """
    if config.backend_type == "local":
        if config.base_path is None:
            raise ValueError("base_path is required for local storage backend")

        from shared.storage.local import LocalStorageBackend

        return LocalStorageBackend(base_path=config.base_path)

    elif config.backend_type == "azure_blob":
        if config.connection_string is None:
            raise ValueError(
                "connection_string is required for Azure blob storage backend"
            )
        if config.container_name is None:
            raise ValueError(
                "container_name is required for Azure blob storage backend"
            )

        # Import here to allow lazy loading of the Azure SDK
        from azure.storage.blob import BlobServiceClient  # noqa: F401

        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string=config.connection_string,
            container_name=config.container_name,
        )

    elif config.backend_type == "s3":
        if config.bucket_name is None:
            raise ValueError("bucket_name is required for S3 storage backend")

        # Import here to allow lazy loading of boto3
        import boto3  # noqa: F401

        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend(
            bucket_name=config.bucket_name,
            region_name=config.region_name,
            access_key_id=config.access_key_id,
            secret_access_key=config.secret_access_key,
            endpoint_url=config.endpoint_url,
        )

    else:
        raise ValueError(f"Unknown storage backend type: {config.backend_type}")


def get_default_storage_config() -> StorageConfig:
    """Get storage configuration from environment variables.

    Environment variables:
        STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
        STORAGE_BASE_PATH: Base path for local storage.
        AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
        AZURE_STORAGE_CONTAINER: Azure container name.
        AWS_S3_BUCKET: S3 bucket name.
        AWS_REGION: AWS region name.
        AWS_ACCESS_KEY_ID: AWS access key ID.
        AWS_SECRET_ACCESS_KEY: AWS secret access key.
        AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.

    Returns:
        StorageConfig from environment.
    """
    backend_type = os.environ.get("STORAGE_BACKEND", "local")

    if backend_type == "local":
        base_path_str = os.environ.get("STORAGE_BASE_PATH")
        # Expand ~ to home directory
        base_path = Path(os.path.expanduser(base_path_str)) if base_path_str else None

        return StorageConfig(
            backend_type="local",
            base_path=base_path,
        )

    elif backend_type == "azure_blob":
        return StorageConfig(
            backend_type="azure_blob",
            connection_string=os.environ.get("AZURE_STORAGE_CONNECTION_STRING"),
            container_name=os.environ.get("AZURE_STORAGE_CONTAINER"),
        )

    elif backend_type == "s3":
        return StorageConfig(
            backend_type="s3",
            bucket_name=os.environ.get("AWS_S3_BUCKET"),
            region_name=os.environ.get("AWS_REGION"),
            access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
            secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
            endpoint_url=os.environ.get("AWS_ENDPOINT_URL"),
        )

    else:
        return StorageConfig(backend_type=backend_type)


def create_storage_backend_from_env() -> StorageBackend:
    """Create a storage backend from environment variables.

    Environment variables:
        STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
        STORAGE_BASE_PATH: Base path for local storage.
        AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
        AZURE_STORAGE_CONTAINER: Azure container name.
        AWS_S3_BUCKET: S3 bucket name.
        AWS_REGION: AWS region name.
        AWS_ACCESS_KEY_ID: AWS access key ID.
        AWS_SECRET_ACCESS_KEY: AWS secret access key.
        AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.

    Returns:
        A configured storage backend.

    Raises:
        ValueError: If required environment variables are missing or empty.
    """
    backend_type = os.environ.get("STORAGE_BACKEND", "local").strip()

    if backend_type == "local":
        base_path = os.environ.get("STORAGE_BASE_PATH", "").strip()
        if not base_path:
            raise ValueError(
                "STORAGE_BASE_PATH environment variable is required and cannot be empty"
            )

        # Expand ~ to home directory
        base_path_expanded = os.path.expanduser(base_path)

        from shared.storage.local import LocalStorageBackend

        return LocalStorageBackend(base_path=Path(base_path_expanded))

    elif backend_type == "azure_blob":
        connection_string = os.environ.get(
            "AZURE_STORAGE_CONNECTION_STRING", ""
        ).strip()
        if not connection_string:
            raise ValueError(
                "AZURE_STORAGE_CONNECTION_STRING environment variable is required "
                "and cannot be empty"
            )

        container_name = os.environ.get("AZURE_STORAGE_CONTAINER", "").strip()
        if not container_name:
            raise ValueError(
                "AZURE_STORAGE_CONTAINER environment variable is required "
                "and cannot be empty"
            )

        # Import here to allow lazy loading of the Azure SDK
        from azure.storage.blob import BlobServiceClient  # noqa: F401

        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string=connection_string,
            container_name=container_name,
        )

    elif backend_type == "s3":
        bucket_name = os.environ.get("AWS_S3_BUCKET", "").strip()
        if not bucket_name:
            raise ValueError(
                "AWS_S3_BUCKET environment variable is required and cannot be empty"
            )

        # Import here to allow lazy loading of boto3
        import boto3  # noqa: F401

        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend(
            bucket_name=bucket_name,
            region_name=os.environ.get("AWS_REGION", "").strip() or None,
            access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "").strip() or None,
            secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "").strip()
            or None,
            endpoint_url=os.environ.get("AWS_ENDPOINT_URL", "").strip() or None,
        )

    else:
        raise ValueError(f"Unknown storage backend type: {backend_type}")


def create_storage_backend_from_file(config_path: Path | str) -> StorageBackend:
    """Create a storage backend from a configuration file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        A configured storage backend.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        ValueError: If the configuration is invalid.
    """
    from shared.storage.config_loader import load_storage_config

    file_config = load_storage_config(config_path)

    if file_config.backend_type == "local":
        if file_config.local is None:
            raise ValueError("local configuration section is required")

        from shared.storage.local import LocalStorageBackend

        return LocalStorageBackend(base_path=file_config.local.base_path)

    elif file_config.backend_type == "azure_blob":
        if file_config.azure is None:
            raise ValueError("azure configuration section is required")

        # Import here to allow lazy loading of the Azure SDK
        from azure.storage.blob import BlobServiceClient  # noqa: F401

        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string=file_config.azure.connection_string,
            container_name=file_config.azure.container_name,
            create_container=file_config.azure.create_container,
        )

    elif file_config.backend_type == "s3":
        if file_config.s3 is None:
            raise ValueError("s3 configuration section is required")

        # Import here to allow lazy loading of boto3
        import boto3  # noqa: F401

        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend(
            bucket_name=file_config.s3.bucket_name,
            region_name=file_config.s3.region_name,
            access_key_id=file_config.s3.access_key_id,
            secret_access_key=file_config.s3.secret_access_key,
            endpoint_url=file_config.s3.endpoint_url,
            create_bucket=file_config.s3.create_bucket,
        )

    else:
        raise ValueError(f"Unknown storage backend type: {file_config.backend_type}")


def get_storage_backend(config_path: Path | str | None = None) -> StorageBackend:
    """Get a storage backend with a fallback chain.

    Priority:
        1. Config file (if provided)
        2. Environment variables

    Args:
        config_path: Optional path to a config file.

    Returns:
        A configured storage backend.

    Raises:
        ValueError: If the configuration is invalid.
        FileNotFoundError: If the specified config file doesn't exist.
    """
    if config_path:
        return create_storage_backend_from_file(config_path)

    # Fall back to environment variables
    return create_storage_backend_from_env()
262
packages/shared/shared/storage/local.py
Normal file
262
packages/shared/shared/storage/local.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Local filesystem storage backend.
|
||||
|
||||
Provides storage operations using the local filesystem.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from shared.storage.base import (
|
||||
FileNotFoundStorageError,
|
||||
StorageBackend,
|
||||
StorageError,
|
||||
)
|
||||
|
||||
|
||||
class LocalStorageBackend(StorageBackend):
|
||||
"""Storage backend using local filesystem.
|
||||
|
||||
Files are stored relative to a base path on the local filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, base_path: str | Path) -> None:
|
||||
"""Initialize local storage backend.
|
||||
|
||||
Args:
|
||||
base_path: Base directory for all storage operations.
|
||||
Will be created if it doesn't exist.
|
||||
"""
|
||||
self._base_path = Path(base_path)
|
||||
self._base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def base_path(self) -> Path:
|
||||
"""Get the base path for this storage backend."""
|
||||
return self._base_path
|
||||
|
||||
def _get_full_path(self, remote_path: str) -> Path:
|
||||
"""Convert a remote path to a full local path with security validation.
|
||||
|
||||
Args:
|
||||
remote_path: The remote path to resolve.
|
||||
|
||||
Returns:
|
||||
The full local path.
|
||||
|
||||
Raises:
|
||||
StorageError: If the path attempts to escape the base directory.
|
||||
"""
|
||||
# Reject absolute paths
|
||||
if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"):
|
||||
raise StorageError(f"Absolute paths not allowed: {remote_path}")
|
||||
|
||||
# Resolve to prevent path traversal attacks
|
||||
full_path = (self._base_path / remote_path).resolve()
|
||||
base_resolved = self._base_path.resolve()
|
||||
|
||||
# Verify the resolved path is within base_path
|
||||
try:
|
||||
full_path.relative_to(base_resolved)
|
||||
except ValueError:
|
||||
raise StorageError(f"Path traversal not allowed: {remote_path}")
|
||||
|
||||
return full_path
|
||||
|
||||
def upload(
|
||||
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload a file to local storage.
|
||||
|
||||
Args:
|
||||
local_path: Path to the local file to upload.
|
||||
remote_path: Destination path in storage.
|
||||
overwrite: If True, overwrite existing file.
|
||||
|
||||
Returns:
|
||||
The remote path where the file was stored.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If local_path doesn't exist.
|
||||
StorageError: If file exists and overwrite is False.
|
||||
"""
|
||||
if not local_path.exists():
|
||||
raise FileNotFoundStorageError(str(local_path))
|
||||
|
||||
dest_path = self._get_full_path(remote_path)
|
||||
|
||||
if dest_path.exists() and not overwrite:
|
||||
raise StorageError(f"File already exists: {remote_path}")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(local_path, dest_path)
|
||||
|
||||
return remote_path
|
||||
|
||||
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||
"""Download a file from local storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
local_path: Local destination path.
|
||||
|
||||
Returns:
|
||||
The local path where the file was downloaded.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
source_path = self._get_full_path(remote_path)
|
||||
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_path, local_path)
|
||||
|
||||
return local_path

    def exists(self, remote_path: str) -> bool:
        """Check if a file exists in storage.

        Args:
            remote_path: Path to check in storage.

        Returns:
            True if the file exists, False otherwise.
        """
        return self._get_full_path(remote_path).exists()

    def list_files(self, prefix: str) -> list[str]:
        """List files in storage with the given prefix.

        Args:
            prefix: Path prefix to filter files. An empty prefix lists
                everything under the base path.

        Returns:
            Sorted list of file paths matching the prefix, using "/" as
            the separator on all platforms.
        """
        if prefix:
            search_path = self._get_full_path(prefix)
            if not search_path.exists():
                return []
        else:
            search_path = self._base_path

        files: list[str] = []
        if search_path.is_file():
            files.append(str(search_path.relative_to(self._base_path)).replace("\\", "/"))
        elif search_path.is_dir():
            for file_path in search_path.rglob("*"):
                if file_path.is_file():
                    relative_path = file_path.relative_to(self._base_path)
                    files.append(str(relative_path).replace("\\", "/"))

        return sorted(files)

    def delete(self, remote_path: str) -> bool:
        """Delete a file from storage.

        Args:
            remote_path: Path to the file to delete.

        Returns:
            True if the file was deleted, False if it didn't exist.
        """
        file_path = self._get_full_path(remote_path)

        if not file_path.exists():
            return False

        file_path.unlink()
        return True

    def get_url(self, remote_path: str) -> str:
        """Get a file:// URL to access a file.

        Args:
            remote_path: Path to the file in storage.

        Returns:
            file:// URL to access the file.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        file_path = self._get_full_path(remote_path)

        if not file_path.exists():
            raise FileNotFoundStorageError(remote_path)

        return file_path.as_uri()

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to storage.

        Args:
            data: Bytes to upload.
            remote_path: Destination path in storage.
            overwrite: If True, overwrite existing file.

        Returns:
            The remote path where the data was stored.

        Raises:
            StorageError: If the file exists and overwrite is False.
        """
        dest_path = self._get_full_path(remote_path)

        if dest_path.exists() and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")

        dest_path.parent.mkdir(parents=True, exist_ok=True)
        dest_path.write_bytes(data)

        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        """Download a file as bytes.

        Args:
            remote_path: Path to the file in storage.

        Returns:
            The file contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        file_path = self._get_full_path(remote_path)

        if not file_path.exists():
            raise FileNotFoundStorageError(remote_path)

        return file_path.read_bytes()

    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get a file:// URL for local file access.

        For local storage, this returns a file:// URI.
        Note: Local file:// URLs don't actually expire.

        Args:
            remote_path: Path to the file in storage.
            expires_in_seconds: Ignored for local storage (URLs don't expire).

        Returns:
            file:// URL to access the file.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        file_path = self._get_full_path(remote_path)

        if not file_path.exists():
            raise FileNotFoundStorageError(remote_path)

        return file_path.as_uri()
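
A minimal usage sketch for the class above (illustrative only; it assumes the module lives at `shared.storage.local`, and the base path and keys are placeholders):

```python
from shared.storage.local import LocalStorageBackend  # assumed module path

storage = LocalStorageBackend("./data/storage")

# Round-trip some bytes; keys always use forward slashes, even on Windows.
storage.upload_bytes(b"%PDF-1.4 demo", "documents/invoice.pdf", overwrite=True)
assert storage.exists("documents/invoice.pdf")
print(storage.list_files("documents"))                      # ['documents/invoice.pdf']
print(storage.get_presigned_url("documents/invoice.pdf"))   # file:///.../invoice.pdf
```
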
158 packages/shared/shared/storage/prefixes.py Normal file
@@ -0,0 +1,158 @@
"""
Storage path prefixes for unified file organization.

Provides standardized path prefixes for organizing files within
the storage backend, ensuring consistent structure across
local, Azure Blob, and S3 storage.
"""

from dataclasses import dataclass


@dataclass(frozen=True)
class StoragePrefixes:
    """Standardized storage path prefixes.

    All paths are relative to the storage backend root.
    These prefixes ensure consistent file organization across
    all storage backends (local, Azure, S3).

    Usage:
        from shared.storage.prefixes import PREFIXES

        path = f"{PREFIXES.DOCUMENTS}/{document_id}.pdf"
        storage.upload_bytes(content, path)
    """

    # Document storage
    DOCUMENTS: str = "documents"
    """Original document files (PDFs, etc.)"""

    IMAGES: str = "images"
    """Page images extracted from documents"""

    # Processing directories
    UPLOADS: str = "uploads"
    """Temporary upload staging area"""

    RESULTS: str = "results"
    """Inference results and visualizations"""

    EXPORTS: str = "exports"
    """Exported datasets and annotations"""

    # Training data
    DATASETS: str = "datasets"
    """Training dataset files"""

    MODELS: str = "models"
    """Trained model weights and checkpoints"""

    # Data pipeline directories (legacy compatibility)
    RAW_PDFS: str = "raw_pdfs"
    """Raw PDF files for auto-labeling pipeline"""

    STRUCTURED_DATA: str = "structured_data"
    """CSV/structured data for matching"""

    ADMIN_IMAGES: str = "admin_images"
    """Admin UI page images"""

    @staticmethod
    def document_path(document_id: str, extension: str = ".pdf") -> str:
        """Get path for a document file.

        Args:
            document_id: Unique document identifier.
            extension: File extension (leading dot optional; added if missing).

        Returns:
            Storage path like "documents/abc123.pdf"
        """
        ext = extension if extension.startswith(".") else f".{extension}"
        return f"{PREFIXES.DOCUMENTS}/{document_id}{ext}"

    @staticmethod
    def image_path(document_id: str, page_num: int, extension: str = ".png") -> str:
        """Get path for a page image file.

        Args:
            document_id: Unique document identifier.
            page_num: Page number (1-indexed).
            extension: File extension (leading dot optional; added if missing).

        Returns:
            Storage path like "images/abc123/page_1.png"
        """
        ext = extension if extension.startswith(".") else f".{extension}"
        return f"{PREFIXES.IMAGES}/{document_id}/page_{page_num}{ext}"

    @staticmethod
    def upload_path(filename: str, subfolder: str | None = None) -> str:
        """Get path for a temporary upload file.

        Args:
            filename: Original filename.
            subfolder: Optional subfolder (e.g., "async").

        Returns:
            Storage path like "uploads/filename.pdf" or "uploads/async/filename.pdf"
        """
        if subfolder:
            return f"{PREFIXES.UPLOADS}/{subfolder}/{filename}"
        return f"{PREFIXES.UPLOADS}/{filename}"

    @staticmethod
    def result_path(filename: str) -> str:
        """Get path for a result file.

        Args:
            filename: Result filename.

        Returns:
            Storage path like "results/filename.json"
        """
        return f"{PREFIXES.RESULTS}/{filename}"

    @staticmethod
    def export_path(export_id: str, filename: str) -> str:
        """Get path for an export file.

        Args:
            export_id: Unique export identifier.
            filename: Export filename.

        Returns:
            Storage path like "exports/abc123/filename.zip"
        """
        return f"{PREFIXES.EXPORTS}/{export_id}/{filename}"

    @staticmethod
    def dataset_path(dataset_id: str, filename: str) -> str:
        """Get path for a dataset file.

        Args:
            dataset_id: Unique dataset identifier.
            filename: Dataset filename.

        Returns:
            Storage path like "datasets/abc123/filename.yaml"
        """
        return f"{PREFIXES.DATASETS}/{dataset_id}/{filename}"

    @staticmethod
    def model_path(version: str, filename: str) -> str:
        """Get path for a model file.

        Args:
            version: Model version string.
            filename: Model filename.

        Returns:
            Storage path like "models/v1.0.0/best.pt"
        """
        return f"{PREFIXES.MODELS}/{version}/{filename}"


# Default instance for convenient access
PREFIXES = StoragePrefixes()
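
A short sketch of the helpers in action (illustrative; the IDs and filenames are placeholders):

```python
from shared.storage.prefixes import PREFIXES

# Helpers normalize extensions and keep one layout across all backends.
print(PREFIXES.document_path("abc123"))            # documents/abc123.pdf
print(PREFIXES.document_path("abc123", "png"))     # documents/abc123.png (dot added)
print(PREFIXES.image_path("abc123", page_num=2))   # images/abc123/page_2.png
print(PREFIXES.upload_path("scan.pdf", "async"))   # uploads/async/scan.pdf
print(PREFIXES.model_path("v1.0.0", "best.pt"))    # models/v1.0.0/best.pt
```
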
309 packages/shared/shared/storage/s3.py Normal file
@@ -0,0 +1,309 @@
"""
AWS S3 Storage backend.

Provides storage operations using AWS S3.
"""

from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from mypy_boto3_s3 import S3Client

from shared.storage.base import (
    FileNotFoundStorageError,
    StorageBackend,
    StorageError,
)


class S3StorageBackend(StorageBackend):
    """Storage backend using AWS S3.

    Files are stored as objects in an S3 bucket.
    """

    def __init__(
        self,
        bucket_name: str,
        region_name: str | None = None,
        access_key_id: str | None = None,
        secret_access_key: str | None = None,
        endpoint_url: str | None = None,
        create_bucket: bool = False,
    ) -> None:
        """Initialize S3 storage backend.

        Args:
            bucket_name: Name of the S3 bucket.
            region_name: AWS region name (optional, uses default if not set).
            access_key_id: AWS access key ID (optional, uses credential chain).
            secret_access_key: AWS secret access key (optional).
            endpoint_url: Custom endpoint URL (for S3-compatible services).
            create_bucket: If True, create the bucket if it doesn't exist.
        """
        import boto3

        self._bucket_name = bucket_name
        self._region_name = region_name

        # Build client kwargs; anything omitted falls back to boto3 defaults
        # (environment variables, shared config, instance metadata, etc.).
        client_kwargs: dict[str, Any] = {}
        if region_name:
            client_kwargs["region_name"] = region_name
        if endpoint_url:
            client_kwargs["endpoint_url"] = endpoint_url
        if access_key_id and secret_access_key:
            client_kwargs["aws_access_key_id"] = access_key_id
            client_kwargs["aws_secret_access_key"] = secret_access_key

        self._s3: "S3Client" = boto3.client("s3", **client_kwargs)

        if create_bucket:
            self._ensure_bucket_exists()

    def _ensure_bucket_exists(self) -> None:
        """Create the bucket if it doesn't exist."""
        from botocore.exceptions import ClientError

        try:
            self._s3.head_bucket(Bucket=self._bucket_name)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchBucket"):
                # Bucket doesn't exist, create it
                create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
                if self._region_name and self._region_name != "us-east-1":
                    create_kwargs["CreateBucketConfiguration"] = {
                        "LocationConstraint": self._region_name
                    }
                self._s3.create_bucket(**create_kwargs)
            else:
                # Re-raise permission errors, network issues, etc.
                raise
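
    # Note on the region check above: the S3 CreateBucket API rejects a
    # LocationConstraint of "us-east-1" (the default region is requested by
    # omitting CreateBucketConfiguration entirely), hence the special case.
    # For example, both of these should succeed ("my-bucket" is a placeholder):
    #
    #   S3StorageBackend("my-bucket", create_bucket=True)                        # us-east-1
    #   S3StorageBackend("my-bucket", region_name="eu-west-1", create_bucket=True)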

    def _object_exists(self, key: str) -> bool:
        """Check if an object exists in S3.

        Args:
            key: Object key to check.

        Returns:
            True if object exists, False otherwise.
        """
        from botocore.exceptions import ClientError

        try:
            self._s3.head_object(Bucket=self._bucket_name, Key=key)
            return True
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchKey"):
                return False
            raise

    @property
    def bucket_name(self) -> str:
        """Get the bucket name for this storage backend."""
        return self._bucket_name

    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to S3.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination object key.
            overwrite: If True, overwrite existing object.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If object exists and overwrite is False.
        """
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))

        if not overwrite and self._object_exists(remote_path):
            raise StorageError(f"File already exists: {remote_path}")

        self._s3.upload_file(str(local_path), self._bucket_name, remote_path)

        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download an object from S3.

        Args:
            remote_path: Object key in S3.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)

        local_path.parent.mkdir(parents=True, exist_ok=True)

        self._s3.download_file(self._bucket_name, remote_path, str(local_path))

        return local_path

    def exists(self, remote_path: str) -> bool:
        """Check if an object exists in S3.

        Args:
            remote_path: Object key to check.

        Returns:
            True if the object exists, False otherwise.
        """
        return self._object_exists(remote_path)

    def list_files(self, prefix: str) -> list[str]:
        """List objects in S3 with the given prefix.

        Handles pagination to return all matching objects (S3 returns at
        most 1000 keys per request).

        Args:
            prefix: Object key prefix to filter.

        Returns:
            List of object keys matching the prefix.
        """
        kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
        if prefix:
            kwargs["Prefix"] = prefix

        all_keys: list[str] = []
        while True:
            response = self._s3.list_objects_v2(**kwargs)
            contents = response.get("Contents", [])
            all_keys.extend(obj["Key"] for obj in contents)

            if not response.get("IsTruncated"):
                break
            kwargs["ContinuationToken"] = response["NextContinuationToken"]

        return all_keys

    def delete(self, remote_path: str) -> bool:
        """Delete an object from S3.

        Args:
            remote_path: Object key to delete.

        Returns:
            True if the object was deleted, False if it didn't exist.
        """
        if not self._object_exists(remote_path):
            return False

        self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path)
        return True

    def get_url(self, remote_path: str) -> str:
        """Get a URL for an object.

        S3 buckets are private by default, so this returns a pre-signed
        URL with a fixed one-hour expiry; use get_presigned_url() to
        control the expiry.

        Args:
            remote_path: Object key in S3.

        Returns:
            URL to access the object.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)

        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket_name, "Key": remote_path},
            ExpiresIn=3600,
        )

    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a pre-signed URL for temporary object access.

        Args:
            remote_path: Object key in S3.
            expires_in_seconds: URL validity duration (1 to 604800 seconds / 7 days).

        Returns:
            Pre-signed URL string.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            ValueError: If expires_in_seconds is out of valid range.
        """
        if expires_in_seconds < 1 or expires_in_seconds > 604800:
            raise ValueError(
                "expires_in_seconds must be between 1 and 604800 (7 days)"
            )

        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)

        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket_name, "Key": remote_path},
            ExpiresIn=expires_in_seconds,
        )
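
    # Illustrative only: handing a short-lived document link to a frontend.
    # The key "documents/abc123.pdf" is a placeholder, and the 7-day cap
    # enforced above matches the SigV4 pre-signing limit.
    #
    #   url = backend.get_presigned_url("documents/abc123.pdf", expires_in_seconds=900)
    #   # -> https://<bucket>.s3.<region>.amazonaws.com/documents/abc123.pdf?X-Amz-...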

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to S3.

        Args:
            data: Bytes to upload.
            remote_path: Destination object key.
            overwrite: If True, overwrite existing object.

        Returns:
            The remote path where the data was stored.

        Raises:
            StorageError: If object exists and overwrite is False.
        """
        if not overwrite and self._object_exists(remote_path):
            raise StorageError(f"File already exists: {remote_path}")

        self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data)

        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        """Download an object as bytes.

        Args:
            remote_path: Object key in S3.

        Returns:
            The object contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        from botocore.exceptions import ClientError

        try:
            response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path)
            return response["Body"].read()
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchKey"):
                raise FileNotFoundStorageError(remote_path) from e
            raise
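
A minimal usage sketch for `S3StorageBackend` (illustrative; the bucket name, endpoint, and keys are placeholders, and real-S3 credentials are assumed to come from the standard boto3 credential chain):

```python
from shared.storage.s3 import S3StorageBackend

# Against real S3 (credentials resolved by boto3's credential chain).
storage = S3StorageBackend(bucket_name="invoice-master-docs", region_name="eu-west-1")

# Against an S3-compatible service such as MinIO, via endpoint_url.
local_s3 = S3StorageBackend(
    bucket_name="invoice-master-docs",
    endpoint_url="http://localhost:9000",
    access_key_id="minioadmin",
    secret_access_key="minioadmin",
    create_bucket=True,
)

local_s3.upload_bytes(b"hello", "documents/demo.txt", overwrite=True)
print(local_s3.list_files("documents"))               # ['documents/demo.txt']
print(local_s3.download_bytes("documents/demo.txt"))  # b'hello'
```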