"""
AWS S3 Storage backend.
Provides storage operations using AWS S3.
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)


class S3StorageBackend(StorageBackend):
    """Storage backend using AWS S3.

    Files are stored as objects in an S3 bucket.
    """

    def __init__(
        self,
        bucket_name: str,
        region_name: str | None = None,
        access_key_id: str | None = None,
        secret_access_key: str | None = None,
        endpoint_url: str | None = None,
        create_bucket: bool = False,
    ) -> None:
        """Initialize the S3 storage backend.

        Args:
            bucket_name: Name of the S3 bucket.
            region_name: AWS region name (optional, uses default if not set).
            access_key_id: AWS access key ID (optional, uses credentials chain).
            secret_access_key: AWS secret access key (optional).
            endpoint_url: Custom endpoint URL (for S3-compatible services).
            create_bucket: If True, create the bucket if it doesn't exist.
        """
        # Imported lazily so the module can be imported without boto3 installed.
        import boto3

        self._bucket_name = bucket_name
        self._region_name = region_name

        # Build client kwargs, passing only values that were explicitly set so
        # boto3 falls back to its normal resolution order for the rest.
        client_kwargs: dict[str, Any] = {}
        if region_name:
            client_kwargs["region_name"] = region_name
        if endpoint_url:
            client_kwargs["endpoint_url"] = endpoint_url
        if access_key_id and secret_access_key:
            client_kwargs["aws_access_key_id"] = access_key_id
            client_kwargs["aws_secret_access_key"] = secret_access_key

        self._s3: "S3Client" = boto3.client("s3", **client_kwargs)

        if create_bucket:
            self._ensure_bucket_exists()
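
    # Construction sketch. Against real AWS, credentials typically come from
    # the default chain; against an S3-compatible service (e.g. MinIO), pass
    # an endpoint URL. All values below are hypothetical placeholders:
    #
    #   backend = S3StorageBackend(bucket_name="invoices", region_name="eu-west-1")
    #
    #   backend = S3StorageBackend(
    #       bucket_name="invoices",
    #       endpoint_url="http://localhost:9000",
    #       access_key_id="minioadmin",
    #       secret_access_key="minioadmin",
    #       create_bucket=True,
    #   )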

    def _ensure_bucket_exists(self) -> None:
        """Create the bucket if it doesn't exist."""
        from botocore.exceptions import ClientError

        try:
            self._s3.head_bucket(Bucket=self._bucket_name)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchBucket"):
                # Bucket doesn't exist; create it. us-east-1 must not be
                # passed as a LocationConstraint, every other region needs one.
                create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
                if self._region_name and self._region_name != "us-east-1":
                    create_kwargs["CreateBucketConfiguration"] = {
                        "LocationConstraint": self._region_name
                    }
                self._s3.create_bucket(**create_kwargs)
            else:
                # Re-raise permission errors, network issues, etc.
                raise

    def _object_exists(self, key: str) -> bool:
        """Check if an object exists in S3.

        Args:
            key: Object key to check.

        Returns:
            True if the object exists, False otherwise.
        """
        from botocore.exceptions import ClientError

        try:
            self._s3.head_object(Bucket=self._bucket_name, Key=key)
            return True
        except ClientError as e:
            # HEAD responses carry a bare "404" rather than "NoSuchKey",
            # so check for both.
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchKey"):
                return False
            raise

    @property
    def bucket_name(self) -> str:
        """Get the bucket name for this storage backend."""
        return self._bucket_name

    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to S3.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination object key.
            overwrite: If True, overwrite an existing object.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If the object exists and overwrite is False.
        """
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))
        if not overwrite and self._object_exists(remote_path):
            raise StorageError(f"File already exists: {remote_path}")
        self._s3.upload_file(str(local_path), self._bucket_name, remote_path)
        return remote_path
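
    # Usage sketch (the file name and key are hypothetical):
    #
    #   backend.upload(Path("invoice.pdf"), "uploads/invoice.pdf")
    #
    # A second call with the same key raises StorageError unless
    # overwrite=True is passed.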

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download an object from S3.

        Args:
            remote_path: Object key in S3.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)
        # Make sure the destination directory exists before writing.
        local_path.parent.mkdir(parents=True, exist_ok=True)
        self._s3.download_file(self._bucket_name, remote_path, str(local_path))
        return local_path

    def exists(self, remote_path: str) -> bool:
        """Check if an object exists in S3.

        Args:
            remote_path: Object key to check.

        Returns:
            True if the object exists, False otherwise.
        """
        return self._object_exists(remote_path)

    def list_files(self, prefix: str) -> list[str]:
        """List objects in S3 with the given prefix.

        Handles pagination so that all matching objects are returned
        (S3 returns at most 1000 keys per request).

        Args:
            prefix: Object key prefix to filter by.

        Returns:
            List of object keys matching the prefix.
        """
        kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
        if prefix:
            kwargs["Prefix"] = prefix

        all_keys: list[str] = []
        while True:
            response = self._s3.list_objects_v2(**kwargs)
            contents = response.get("Contents", [])
            all_keys.extend(obj["Key"] for obj in contents)
            # IsTruncated means more pages remain; pass the continuation
            # token back to fetch the next one.
            if not response.get("IsTruncated"):
                break
            kwargs["ContinuationToken"] = response["NextContinuationToken"]
        return all_keys
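
    # An equivalent sketch using boto3's built-in paginator, which manages
    # continuation tokens internally:
    #
    #   paginator = self._s3.get_paginator("list_objects_v2")
    #   keys = [
    #       obj["Key"]
    #       for page in paginator.paginate(Bucket=self._bucket_name, Prefix=prefix)
    #       for obj in page.get("Contents", [])
    #   ]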

    def delete(self, remote_path: str) -> bool:
        """Delete an object from S3.

        Args:
            remote_path: Object key to delete.

        Returns:
            True if object was deleted, False if it didn't exist.
        """
        if not self._object_exists(remote_path):
            return False
        self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path)
        return True

    def get_url(self, remote_path: str) -> str:
        """Get a URL for an object.

        Args:
            remote_path: Object key in S3.

        Returns:
            Pre-signed URL to access the object (valid for one hour).

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)
        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket_name, "Key": remote_path},
            ExpiresIn=3600,
        )

    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a pre-signed URL for temporary object access.

        Args:
            remote_path: Object key in S3.
            expires_in_seconds: URL validity duration, from 1 to 604800
                seconds (7 days, the maximum S3 allows).

        Returns:
            Pre-signed URL string.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            ValueError: If expires_in_seconds is out of the valid range.
        """
        if expires_in_seconds < 1 or expires_in_seconds > 604800:
            raise ValueError(
                "expires_in_seconds must be between 1 and 604800 (7 days)"
            )
        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)
        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket_name, "Key": remote_path},
            ExpiresIn=expires_in_seconds,
        )
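
    # Usage sketch: the returned URL can be fetched over plain HTTP without
    # AWS credentials until it expires. The key is hypothetical:
    #
    #   url = backend.get_presigned_url("reports/invoice.pdf", expires_in_seconds=300)
    #   # e.g. requests.get(url).content holds the object bytes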

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to S3.

        Args:
            data: Bytes to upload.
            remote_path: Destination object key.
            overwrite: If True, overwrite an existing object.

        Returns:
            The remote path where the data was stored.

        Raises:
            StorageError: If the object exists and overwrite is False.
        """
        if not overwrite and self._object_exists(remote_path):
            raise StorageError(f"File already exists: {remote_path}")
        self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data)
        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        """Download an object as bytes.

        Args:
            remote_path: Object key in S3.

        Returns:
            The object contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        from botocore.exceptions import ClientError

        try:
            response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path)
            return response["Body"].read()
        except ClientError as e:
            # get_object reports a missing key as "NoSuchKey"; "404" is
            # checked as well for S3-compatible services.
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchKey"):
                raise FileNotFoundStorageError(remote_path) from e
            raise
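

# Minimal end-to-end sketch, assuming an S3-compatible endpoint is reachable.
# The endpoint, credentials, and bucket name are hypothetical placeholders
# (a local MinIO instance here), not part of any real deployment.
if __name__ == "__main__":
    backend = S3StorageBackend(
        bucket_name="demo-bucket",  # hypothetical bucket
        endpoint_url="http://localhost:9000",  # hypothetical MinIO endpoint
        access_key_id="minioadmin",  # placeholder credentials
        secret_access_key="minioadmin",
        create_bucket=True,
    )
    backend.upload_bytes(b"hello", "demo/hello.txt", overwrite=True)
    print(backend.list_files("demo/"))  # -> ["demo/hello.txt"]
    print(backend.download_bytes("demo/hello.txt"))  # -> b"hello"
    backend.delete("demo/hello.txt")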