Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -120,7 +120,7 @@ def main() -> None:
logger.info("=" * 60)
# Create config
from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
from inference.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig
config = AppConfig(
model=ModelConfig(
@@ -136,7 +136,7 @@ def main() -> None:
reload=args.reload,
workers=args.workers,
),
storage=StorageConfig(),
file=FileConfig(),
)
# Create and run app

View File

@@ -112,6 +112,7 @@ class AdminDB:
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
category: str = "invoice",
admin_token: str | None = None, # Deprecated, kept for compatibility
) -> str:
"""Create a new document record."""
@@ -125,6 +126,7 @@ class AdminDB:
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
category=category,
)
session.add(document)
session.flush()
@@ -154,6 +156,7 @@ class AdminDB:
has_annotations: bool | None = None,
auto_label_status: str | None = None,
batch_id: str | None = None,
category: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AdminDocument], int]:
@@ -171,6 +174,8 @@ class AdminDB:
where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
if batch_id:
where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
if category:
where_clauses.append(AdminDocument.category == category)
# Count query
count_stmt = select(func.count()).select_from(AdminDocument)
@@ -283,6 +288,32 @@ class AdminDB:
return True
return False
def get_document_categories(self) -> list[str]:
"""Get list of unique document categories."""
with get_session_context() as session:
statement = (
select(AdminDocument.category)
.distinct()
.order_by(AdminDocument.category)
)
categories = session.exec(statement).all()
return [c for c in categories if c is not None]
def update_document_category(
self, document_id: str, category: str
) -> AdminDocument | None:
"""Update document category."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.category = category
document.updated_at = datetime.utcnow()
session.add(document)
session.commit()
session.refresh(document)
return document
return None
# ==========================================================================
# Annotation Operations
# ==========================================================================
@@ -1292,6 +1323,36 @@ class AdminDB:
session.add(dataset)
session.commit()
def update_dataset_training_status(
self,
dataset_id: str | UUID,
training_status: str | None,
active_training_task_id: str | UUID | None = None,
update_main_status: bool = False,
) -> None:
"""Update dataset training status and optionally the main status.
Args:
dataset_id: Dataset UUID
training_status: Training status (pending, running, completed, failed, cancelled)
active_training_task_id: Currently active training task ID
update_main_status: If True and training_status is 'completed', set main status to 'trained'
"""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.training_status = training_status
dataset.active_training_task_id = (
UUID(str(active_training_task_id)) if active_training_task_id else None
)
dataset.updated_at = datetime.utcnow()
# Update main status to 'trained' when training completes
if update_main_status and training_status == "completed":
dataset.status = "trained"
session.add(dataset)
session.commit()
def add_dataset_documents(
self,
dataset_id: str | UUID,
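
A quick sketch of how the two new category helpers on AdminDB might be called; the no-argument AdminDB() construction and the placeholder UUID are assumptions for illustration, not part of this commit.

from inference.data.admin_db import AdminDB

db = AdminDB()  # assumed construction; real wiring may differ

# Distinct categories currently stored, e.g. ["invoice", "letter", "receipt"]
categories = db.get_document_categories()

# Move a document to another category; returns the refreshed row or None
updated = db.update_document_category(
    "00000000-0000-0000-0000-000000000000",  # placeholder document UUID
    "receipt",
)
if updated is not None:
    print(updated.category, updated.updated_at)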

View File

@@ -11,23 +11,8 @@ from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
# =============================================================================
# CSV to Field Class Mapping
# =============================================================================
CSV_TO_CLASS_MAPPING: dict[str, int] = {
"InvoiceNumber": 0, # invoice_number
"InvoiceDate": 1, # invoice_date
"InvoiceDueDate": 2, # invoice_due_date
"OCR": 3, # ocr_number
"Bankgiro": 4, # bankgiro
"Plusgiro": 5, # plusgiro
"Amount": 6, # amount
"supplier_organisation_number": 7, # supplier_organisation_number
# 8: payment_line (derived from OCR/Bankgiro/Amount)
"customer_number": 9, # customer_number
}
# Import field mappings from single source of truth
from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS
# =============================================================================
@@ -72,6 +57,8 @@ class AdminDocument(SQLModel, table=True):
# Link to batch upload (if uploaded via ZIP)
group_key: str | None = Field(default=None, max_length=255, index=True)
# User-defined grouping key for document organization
category: str = Field(default="invoice", max_length=100, index=True)
# Document category for training different models (e.g., invoice, letter, receipt)
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
@@ -237,7 +224,10 @@ class TrainingDataset(SQLModel, table=True):
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="building", max_length=20, index=True)
# Status: building, ready, training, archived, failed
# Status: building, ready, trained, archived, failed
training_status: str | None = Field(default=None, max_length=20, index=True)
# Training status: pending, scheduled, running, completed, failed, cancelled
active_training_task_id: UUID | None = Field(default=None, index=True)
train_ratio: float = Field(default=0.8)
val_ratio: float = Field(default=0.1)
seed: int = Field(default=42)
@@ -354,21 +344,8 @@ class AnnotationHistory(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Field class mapping (same as src/cli/train.py)
FIELD_CLASSES = {
0: "invoice_number",
1: "invoice_date",
2: "invoice_due_date",
3: "ocr_number",
4: "bankgiro",
5: "plusgiro",
6: "amount",
7: "supplier_organisation_number",
8: "payment_line",
9: "customer_number",
}
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields
# This ensures consistency with the trained YOLO model
# Read-only models for API responses
@@ -383,6 +360,7 @@ class AdminDocumentRead(SQLModel):
status: str
auto_label_status: str | None
auto_label_error: str | None
category: str = "invoice"
created_at: datetime
updated_at: datetime

View File

@@ -141,6 +141,40 @@ def run_migrations() -> None:
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
""",
),
# Migration 009: Add category to admin_documents
(
"admin_documents_category",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
""",
),
# Migration 010: Add training_status and active_training_task_id to training_datasets
(
"training_datasets_training_status",
"""
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status);
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id);
""",
),
# Migration 010b: Update existing datasets with completed training to 'trained' status
(
"training_datasets_update_trained_status",
"""
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
);
""",
),
]
with engine.connect() as conn:

View File

@@ -21,7 +21,8 @@ import re
import numpy as np
from PIL import Image
from .yolo_detector import Detection, CLASS_TO_FIELD
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import Detection
# Import shared utilities for text cleaning and validation
from shared.utils.text_cleaner import TextCleaner

View File

@@ -10,7 +10,8 @@ from typing import Any
import time
import re
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser

View File

@@ -9,6 +9,9 @@ from pathlib import Path
from typing import Any
import numpy as np
# Import field mappings from single source of truth
from shared.fields import CLASS_NAMES, CLASS_TO_FIELD
@dataclass
class Detection:
@@ -72,33 +75,8 @@ class Detection:
return (x0, y0, x1, y1)
# Class names (must match training configuration)
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number', # Matches training class name
'customer_number',
'payment_line', # Machine code payment line at bottom of invoice
]
# Mapping from class name to field name
CLASS_TO_FIELD = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'ocr_number': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'amount': 'Amount',
'supplier_org_number': 'supplier_org_number',
'customer_number': 'customer_number',
'payment_line': 'payment_line',
}
# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields
# This ensures consistency with the trained YOLO model
class YOLODetector:
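
Both removed literal blocks are replaced by imports from shared.fields, which is not itself part of this diff. The sketch below reassembles its presumable contents purely from the mappings deleted above (including the differing supplier-field spellings and list ordering), as an illustration of the single source of truth; the real module may differ.

# Hypothetical shape of shared/fields.py, reconstructed from the literals
# removed in this commit.
FIELD_CLASSES: dict[int, str] = {
    0: "invoice_number",
    1: "invoice_date",
    2: "invoice_due_date",
    3: "ocr_number",
    4: "bankgiro",
    5: "plusgiro",
    6: "amount",
    7: "supplier_organisation_number",
    8: "payment_line",
    9: "customer_number",
}
FIELD_CLASS_IDS: dict[str, int] = {v: k for k, v in FIELD_CLASSES.items()}

# CSV column -> class id (class 8, payment_line, is derived rather than mapped)
CSV_TO_CLASS_MAPPING: dict[str, int] = {
    "InvoiceNumber": 0,
    "InvoiceDate": 1,
    "InvoiceDueDate": 2,
    "OCR": 3,
    "Bankgiro": 4,
    "Plusgiro": 5,
    "Amount": 6,
    "supplier_organisation_number": 7,
    "customer_number": 9,
}

# YOLO class names and their mapping back to CSV/report field names,
# kept in the order of the removed detector list (note the shorter
# "supplier_org_number" spelling and customer_number before payment_line)
CLASS_NAMES: list[str] = [
    "invoice_number",
    "invoice_date",
    "invoice_due_date",
    "ocr_number",
    "bankgiro",
    "plusgiro",
    "amount",
    "supplier_org_number",
    "customer_number",
    "payment_line",
]
CLASS_TO_FIELD: dict[str, str] = {
    "invoice_number": "InvoiceNumber",
    "invoice_date": "InvoiceDate",
    "invoice_due_date": "InvoiceDueDate",
    "ocr_number": "OCR",
    "bankgiro": "Bankgiro",
    "plusgiro": "Plusgiro",
    "amount": "Amount",
    "supplier_org_number": "supplier_org_number",
    "customer_number": "customer_number",
    "payment_line": "payment_line",
}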

View File

@@ -4,18 +4,19 @@ Admin Annotation API Routes
FastAPI endpoints for annotation management.
"""
import io
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, StreamingResponse
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.autolabel import get_auto_label_service
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
AnnotationCreate,
AnnotationItem,
@@ -35,9 +36,6 @@ from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
# Image storage directory
ADMIN_IMAGES_DIR = Path("data/admin_images")
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
@@ -60,7 +58,9 @@ def create_annotation_router() -> APIRouter:
@router.get(
"/{document_id}/images/{page_number}",
response_model=None,
responses={
200: {"content": {"image/png": {}}, "description": "Page image"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
@@ -72,7 +72,7 @@ def create_annotation_router() -> APIRouter:
page_number: int,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> FileResponse:
) -> FileResponse | StreamingResponse:
"""Get page image."""
_validate_uuid(document_id, "document_id")
@@ -91,18 +91,33 @@ def create_annotation_router() -> APIRouter:
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
)
# Find image file
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
if not image_path.exists():
# Get storage helper
storage = get_storage_helper()
# Check if image exists
if not storage.admin_image_exists(document_id, page_number):
raise HTTPException(
status_code=404,
detail=f"Image for page {page_number} not found",
)
return FileResponse(
path=str(image_path),
# Try to get local path for efficient file serving
local_path = storage.get_admin_image_local_path(document_id, page_number)
if local_path is not None:
return FileResponse(
path=str(local_path),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
)
# Fall back to streaming for cloud storage
image_content = storage.get_admin_image(document_id, page_number)
return StreamingResponse(
io.BytesIO(image_content),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
headers={
"Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"'
},
)
# =========================================================================
@@ -210,16 +225,14 @@ def create_annotation_router() -> APIRouter:
)
# Get image dimensions for normalization
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
if not image_path.exists():
storage = get_storage_helper()
dimensions = storage.get_admin_image_dimensions(document_id, request.page_number)
if dimensions is None:
raise HTTPException(
status_code=400,
detail=f"Image for page {request.page_number} not available",
)
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
image_width, image_height = dimensions
# Calculate normalized coordinates
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
@@ -315,10 +328,14 @@ def create_annotation_router() -> APIRouter:
if request.bbox is not None:
# Get image dimensions
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
storage = get_storage_helper()
dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number)
if dimensions is None:
raise HTTPException(
status_code=400,
detail=f"Image for page {annotation.page_number} not available",
)
image_width, image_height = dimensions
# Calculate normalized coordinates
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
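
For reference, the admin-image surface these routes now rely on can be summarized as a typing Protocol; this is inferred from the call sites above, not from the actual StorageHelper definition (which lives in inference.web.services.storage_helpers and is only partially shown in this commit).

# Illustration only: the StorageHelper methods used by the annotation routes.
from pathlib import Path
from typing import Protocol


class AdminImageStorage(Protocol):
    def admin_image_exists(self, document_id: str, page_number: int) -> bool: ...

    def get_admin_image_local_path(
        self, document_id: str, page_number: int
    ) -> Path | None: ...

    def get_admin_image(self, document_id: str, page_number: int) -> bytes: ...

    def get_admin_image_dimensions(
        self, document_id: str, page_number: int
    ) -> tuple[int, int] | None: ...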

View File

@@ -13,16 +13,19 @@ from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
AnnotationItem,
AnnotationSource,
AutoLabelStatus,
BoundingBox,
DocumentCategoriesResponse,
DocumentDetailResponse,
DocumentItem,
DocumentListResponse,
DocumentStatus,
DocumentStatsResponse,
DocumentUpdateRequest,
DocumentUploadResponse,
ModelMetrics,
TrainingHistoryItem,
@@ -44,14 +47,12 @@ def _validate_uuid(value: str, name: str = "ID") -> None:
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
"""Convert PDF pages to images for annotation using StorageHelper."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
storage = get_storage_helper()
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
@@ -60,8 +61,9 @@ def _convert_pdf_to_images(
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
# Save to storage using StorageHelper
image_bytes = pix.tobytes("png")
storage.save_admin_image(document_id, page_num + 1, image_bytes)
pdf_doc.close()
@@ -95,6 +97,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
str | None,
Query(description="Optional group key for document organization", max_length=255),
] = None,
category: Annotated[
str,
Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100),
] = "invoice",
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate group_key length
@@ -143,31 +149,33 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_path="", # Will update after saving
page_count=page_count,
group_key=group_key,
category=category,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
# Save file to storage using StorageHelper
storage = get_storage_helper()
filename = f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
storage_path = storage.save_raw_pdf(content, filename)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(status_code=500, detail="Failed to save file")
# Update file path in database
# Update file path in database (using storage path for reference)
from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if doc:
doc.file_path = str(file_path)
# Store the storage path (relative path within storage)
doc.file_path = storage_path
session.add(doc)
# Convert PDF to images for annotation
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
document_id, content, page_count, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
@@ -189,6 +197,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
category=category,
group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
@@ -226,6 +235,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
str | None,
Query(description="Filter by batch ID"),
] = None,
category: Annotated[
str | None,
Query(description="Filter by document category"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
@@ -264,6 +277,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
has_annotations=has_annotations,
auto_label_status=auto_label_status,
batch_id=batch_id,
category=category,
limit=limit,
offset=offset,
)
@@ -291,6 +305,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
category=doc.category if hasattr(doc, 'category') else "invoice",
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
@@ -436,6 +451,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
group_key=document.group_key if hasattr(document, 'group_key') else None,
category=document.category if hasattr(document, 'category') else "invoice",
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
@@ -471,16 +487,22 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
detail="Document not found or does not belong to this token",
)
# Delete file
file_path = Path(document.file_path)
if file_path.exists():
file_path.unlink()
# Delete file using StorageHelper
storage = get_storage_helper()
# Delete images
images_dir = ADMIN_IMAGES_DIR / document_id
if images_dir.exists():
import shutil
shutil.rmtree(images_dir)
# Delete the raw PDF
filename = Path(document.file_path).name
if filename:
try:
storage._storage.delete(document.file_path)
except Exception as e:
logger.warning(f"Failed to delete PDF file: {e}")
# Delete admin images
try:
storage.delete_admin_images(document_id)
except Exception as e:
logger.warning(f"Failed to delete admin images: {e}")
# Delete from database
db.delete_document(document_id)
@@ -609,4 +631,61 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
"message": "Document group key updated",
}
@router.get(
"/categories",
response_model=DocumentCategoriesResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get available categories",
description="Get list of all available document categories.",
)
async def get_categories(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentCategoriesResponse:
"""Get all available document categories."""
categories = db.get_document_categories()
return DocumentCategoriesResponse(
categories=categories,
total=len(categories),
)
@router.patch(
"/{document_id}/category",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document category",
description="Update the category for a document.",
)
async def update_document_category(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: DocumentUpdateRequest,
) -> dict:
"""Update document category."""
_validate_uuid(document_id, "document_id")
# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Update category if provided
if request.category is not None:
db.update_document_category(document_id, request.category)
return {
"status": "updated",
"document_id": document_id,
"category": request.category,
"message": "Document category updated",
}
return router
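
A hedged example of exercising the new category endpoints with httpx; the base URL, router prefix, and admin-token header name are not visible in this diff and are assumptions.

import httpx

BASE = "http://localhost:8000/api/v1/admin/documents"  # assumed prefix
HEADERS = {"X-Admin-Token": "change-me"}                # assumed header name
DOC_ID = "00000000-0000-0000-0000-000000000000"         # placeholder UUID

with httpx.Client(headers=HEADERS) as client:
    # List the distinct categories known to the system
    cats = client.get(f"{BASE}/categories").json()

    # Filter the document list by category
    docs = client.get(BASE, params={"category": "invoice", "limit": 20}).json()

    # Re-categorize one document (body is a DocumentUpdateRequest)
    client.patch(f"{BASE}/{DOC_ID}/category", json={"category": "receipt"})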

View File

@@ -17,6 +17,7 @@ from inference.web.schemas.admin import (
TrainingStatus,
TrainingTaskResponse,
)
from inference.web.services.storage_helpers import get_storage_helper
from ._utils import _validate_uuid
@@ -38,7 +39,6 @@ def register_dataset_routes(router: APIRouter) -> None:
db: AdminDBDep,
) -> DatasetResponse:
"""Create a training dataset from document IDs."""
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
# Validate minimum document count for proper train/val/test split
@@ -56,7 +56,18 @@ def register_dataset_routes(router: APIRouter) -> None:
seed=request.seed,
)
builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
# Get storage paths from StorageHelper
storage = get_storage_helper()
datasets_dir = storage.get_datasets_base_path()
admin_images_dir = storage.get_admin_images_base_path()
if datasets_dir is None or admin_images_dir is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)
builder = DatasetBuilder(db=db, base_dir=datasets_dir)
try:
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
@@ -64,7 +75,7 @@ def register_dataset_routes(router: APIRouter) -> None:
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
admin_images_dir=Path("data/admin_images"),
admin_images_dir=admin_images_dir,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -142,6 +153,12 @@ def register_dataset_routes(router: APIRouter) -> None:
name=dataset.name,
description=dataset.description,
status=dataset.status,
training_status=dataset.training_status,
active_training_task_id=(
str(dataset.active_training_task_id)
if dataset.active_training_task_id
else None
),
train_ratio=dataset.train_ratio,
val_ratio=dataset.val_ratio,
seed=dataset.seed,

View File

@@ -34,8 +34,10 @@ def register_export_routes(router: APIRouter) -> None:
db: AdminDBDep,
) -> ExportResponse:
"""Export annotations for training."""
from pathlib import Path
import shutil
from inference.web.services.storage_helpers import get_storage_helper
# Get storage helper for reading images and exports directory
storage = get_storage_helper()
if request.format not in ("yolo", "coco", "voc"):
raise HTTPException(
@@ -51,7 +53,14 @@ def register_export_routes(router: APIRouter) -> None:
detail="No labeled documents available for export",
)
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
# Get exports directory from StorageHelper
exports_base = storage.get_exports_base_path()
if exports_base is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)
export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
export_dir.mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
@@ -80,13 +89,16 @@ def register_export_routes(router: APIRouter) -> None:
if not page_annotations and not request.include_images:
continue
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
# Get image from storage
doc_id = str(doc.document_id)
if not storage.admin_image_exists(doc_id, page_num):
continue
# Download image and save to export directory
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
image_content = storage.get_admin_image(doc_id, page_num)
dst_image.write_bytes(image_content)
total_images += 1
label_name = f"{doc.document_id}_page{page_num}.txt"
@@ -98,7 +110,7 @@ def register_export_routes(router: APIRouter) -> None:
f.write(line)
total_annotations += 1
from inference.data.admin_models import FIELD_CLASSES
from shared.fields import FIELD_CLASSES
yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}

View File

@@ -22,6 +22,7 @@ from inference.web.schemas.inference import (
InferenceResult,
)
from inference.web.schemas.common import ErrorResponse
from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from inference.web.services import InferenceService
@@ -90,8 +91,17 @@ def create_inference_router(
# Generate document ID
doc_id = str(uuid.uuid4())[:8]
# Save uploaded file
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
# Get storage helper and uploads directory
storage = get_storage_helper()
uploads_dir = storage.get_uploads_base_path(subfolder="inference")
if uploads_dir is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Storage not configured for local access",
)
# Save uploaded file to temporary location for processing
upload_path = uploads_dir / f"{doc_id}{file_ext}"
try:
with open(upload_path, "wb") as f:
shutil.copyfileobj(file.file, f)
@@ -149,12 +159,13 @@ def create_inference_router(
# Cleanup uploaded file
upload_path.unlink(missing_ok=True)
@router.get("/results/{filename}")
@router.get("/results/{filename}", response_model=None)
async def get_result_image(filename: str) -> FileResponse:
"""Get visualization result image."""
file_path = storage_config.result_dir / filename
storage = get_storage_helper()
file_path = storage.get_result_local_path(filename)
if not file_path.exists():
if file_path is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
@@ -169,15 +180,15 @@ def create_inference_router(
@router.delete("/results/{filename}")
async def delete_result(filename: str) -> dict:
"""Delete a result file."""
file_path = storage_config.result_dir / filename
storage = get_storage_helper()
if not file_path.exists():
if not storage.result_exists(filename):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
file_path.unlink()
storage.delete_result(filename)
return {"status": "deleted", "filename": filename}
return router

View File

@@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
from inference.data.admin_db import AdminDB
from inference.web.schemas.labeling import PreLabelResponse
from inference.web.schemas.common import ErrorResponse
from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from inference.web.services import InferenceService
@@ -23,19 +24,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Storage directory for pre-label uploads (legacy, now uses storage_config)
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
"""Convert PDF pages to images for annotation using StorageHelper."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
storage = get_storage_helper()
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
@@ -43,8 +39,9 @@ def _convert_pdf_to_images(
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
# Save to storage using StorageHelper
image_bytes = pix.tobytes("png")
storage.save_admin_image(document_id, page_num + 1, image_bytes)
pdf_doc.close()
@@ -70,9 +67,6 @@ def create_labeling_router(
"""
router = APIRouter(prefix="/api/v1", tags=["labeling"])
# Ensure upload directory exists
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
@router.post(
"/pre-label",
response_model=PreLabelResponse,
@@ -165,10 +159,11 @@ def create_labeling_router(
csv_field_values=expected_values,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
# Save file to storage using StorageHelper
storage = get_storage_helper()
filename = f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
storage_path = storage.save_raw_pdf(content, filename)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(
@@ -176,15 +171,14 @@ def create_labeling_router(
detail="Failed to save file",
)
# Update file path in database
db.update_document_file_path(document_id, str(file_path))
# Update file path in database (using storage path)
db.update_document_file_path(document_id, storage_path)
# Convert PDF to images for annotation UI
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
document_id, content, page_count, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")

View File

@@ -18,6 +18,7 @@ from fastapi.responses import HTMLResponse
from .config import AppConfig, default_config
from inference.web.services import InferenceService
from inference.web.services.storage_helpers import get_storage_helper
# Public API imports
from inference.web.api.v1.public import (
@@ -238,13 +239,17 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
allow_headers=["*"],
)
# Mount static files for results
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
app.mount(
"/static/results",
StaticFiles(directory=str(config.storage.result_dir)),
name="results",
)
# Mount static files for results using StorageHelper
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir:
app.mount(
"/static/results",
StaticFiles(directory=str(results_dir)),
name="results",
)
else:
logger.warning("Could not mount static results directory: local storage not available")
# Include public API routes
inference_router = create_inference_router(inference_service, config.storage)

View File

@@ -4,16 +4,49 @@ Web Application Configuration
Centralized configuration for the web application.
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
from shared.config import DEFAULT_DPI, PATHS
from shared.config import DEFAULT_DPI
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
def get_storage_backend(
config_path: Path | str | None = None,
) -> "StorageBackend":
"""Get storage backend for file operations.
Args:
config_path: Optional path to storage configuration file.
If not provided, uses STORAGE_CONFIG_PATH env var or falls back to env vars.
Returns:
Configured StorageBackend instance.
"""
from shared.storage import get_storage_backend as _get_storage_backend
# Check for config file path
if config_path is None:
config_path_str = os.environ.get("STORAGE_CONFIG_PATH")
if config_path_str:
config_path = Path(config_path_str)
return _get_storage_backend(config_path=config_path)
@dataclass(frozen=True)
class ModelConfig:
"""YOLO model configuration."""
"""YOLO model configuration.
Note: Model files are stored locally (not in STORAGE_BASE_PATH) because:
- Models need to be accessible by inference service on any platform
- Models may be version-controlled or deployed separately
- Models are part of the application, not user data
"""
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
confidence_threshold: float = 0.5
@@ -33,24 +66,39 @@ class ServerConfig:
@dataclass(frozen=True)
class StorageConfig:
"""File storage configuration.
class FileConfig:
"""File handling configuration.
Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
directly in raw_pdfs directory. This ensures consistency with CLI autolabel
and avoids storing duplicate files.
This config holds file handling settings. For file operations,
use the storage backend with PREFIXES from shared.storage.prefixes.
Example:
from shared.storage import PREFIXES, get_storage_backend
storage = get_storage_backend()
path = PREFIXES.document_path(document_id)
storage.upload_bytes(content, path)
Note: The path fields (upload_dir, result_dir, etc.) are deprecated.
They are kept for backward compatibility with existing code and tests.
New code should use the storage backend with PREFIXES instead.
"""
upload_dir: Path = Path("uploads")
result_dir: Path = Path("results")
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
admin_images_dir: Path = Path("data/admin_images")
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
dpi: int = DEFAULT_DPI
presigned_url_expiry_seconds: int = 3600
# Deprecated path fields - kept for backward compatibility
# New code should use storage backend with PREFIXES instead
# All paths are now under data/ to match WSL storage layout
upload_dir: Path = field(default_factory=lambda: Path("data/uploads"))
result_dir: Path = field(default_factory=lambda: Path("data/results"))
admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs"))
admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images"))
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
@@ -61,9 +109,17 @@ class StorageConfig:
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
# Backward compatibility alias
StorageConfig = FileConfig
@dataclass(frozen=True)
class AsyncConfig:
"""Async processing configuration."""
"""Async processing configuration.
Note: For file paths, use the storage backend with PREFIXES.
Example: PREFIXES.upload_path(filename, "async")
"""
# Queue settings
queue_max_size: int = 100
@@ -77,14 +133,17 @@ class AsyncConfig:
# Storage
result_retention_days: int = 7
temp_upload_dir: Path = Path("uploads/async")
max_file_size_mb: int = 50
# Deprecated: kept for backward compatibility
# Path under data/ to match WSL storage layout
temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async"))
# Cleanup
cleanup_interval_hours: int = 1
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
@@ -95,19 +154,41 @@ class AppConfig:
model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
storage: StorageConfig = field(default_factory=StorageConfig)
file: FileConfig = field(default_factory=FileConfig)
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
storage_backend: "StorageBackend | None" = None
@property
def storage(self) -> FileConfig:
"""Backward compatibility alias for file config."""
return self.file
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
"""Create config from dictionary."""
file_config = config_dict.get("file", config_dict.get("storage", {}))
return cls(
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
storage=StorageConfig(**config_dict.get("storage", {})),
file=FileConfig(**file_config),
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
)
def create_app_config(
storage_config_path: Path | str | None = None,
) -> AppConfig:
"""Create application configuration with storage backend.
Args:
storage_config_path: Optional path to storage configuration file.
Returns:
Configured AppConfig instance with storage backend initialized.
"""
storage_backend = get_storage_backend(config_path=storage_config_path)
return AppConfig(storage_backend=storage_backend)
# Default configuration instance
default_config = AppConfig()
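
A minimal sketch of consuming the reworked configuration; the storage-config file path is an assumed example value.

from inference.web.config import AppConfig, FileConfig, create_app_config

# Explicit construction: FileConfig replaces the old StorageConfig
config = AppConfig(file=FileConfig(max_file_size_mb=100))
assert config.storage is config.file  # backward-compatibility property

# Or resolve a storage backend from a config file / environment variables
app_config = create_app_config(storage_config_path="config/storage.yaml")  # assumed path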

View File

@@ -13,6 +13,7 @@ from inference.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
)
from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
@@ -36,7 +37,13 @@ class AutoLabelScheduler:
"""
self._check_interval = check_interval_seconds
self._batch_size = batch_size
# Get output directory from StorageHelper
if output_dir is None:
storage = get_storage_helper()
output_dir = storage.get_autolabel_output_path()
self._output_dir = output_dir or Path("data/autolabel_output")
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Any
from inference.data.admin_db import AdminDB
from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
@@ -107,6 +108,14 @@ class TrainingScheduler:
self._db.update_training_task_status(task_id, "running")
self._db.add_training_log(task_id, "INFO", "Training task started")
# Update dataset training status to running
if dataset_id:
self._db.update_dataset_training_status(
dataset_id,
training_status="running",
active_training_task_id=task_id,
)
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
@@ -192,6 +201,15 @@ class TrainingScheduler:
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
# Update dataset training status to completed and main status to trained
if dataset_id:
self._db.update_dataset_training_status(
dataset_id,
training_status="completed",
active_training_task_id=None,
update_main_status=True, # Set main status to 'trained'
)
# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
@@ -203,6 +221,13 @@ class TrainingScheduler:
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
# Update dataset training status to failed
if dataset_id:
self._db.update_dataset_training_status(
dataset_id,
training_status="failed",
active_training_task_id=None,
)
raise
def _create_model_version_from_training(
@@ -268,9 +293,10 @@ class TrainingScheduler:
f"Created model version {version} (ID: {model_version.version_id}) "
f"from training task {task_id}"
)
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
self._db.add_training_log(
task_id, "INFO",
f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
f"Model version {version} created (mAP: {mAP_display})",
)
except Exception as e:
@@ -283,8 +309,11 @@ class TrainingScheduler:
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
import shutil
from inference.data.admin_models import FIELD_CLASSES
from shared.fields import FIELD_CLASSES
from inference.web.services.storage_helpers import get_storage_helper
# Get storage helper for reading images
storage = get_storage_helper()
# Get all labeled documents
documents = self._db.get_labeled_documents_for_export()
@@ -293,8 +322,12 @@ class TrainingScheduler:
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
return None
# Create export directory
export_dir = Path("data/training") / task_id
# Create export directory using StorageHelper
training_base = storage.get_training_data_path()
if training_base is None:
self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access")
return None
export_dir = training_base / task_id
export_dir.mkdir(parents=True, exist_ok=True)
# YOLO format directories
@@ -323,14 +356,16 @@ class TrainingScheduler:
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
# Copy image
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
# Get image from storage
doc_id = str(doc.document_id)
if not storage.admin_image_exists(doc_id, page_num):
continue
# Download image and save to export directory
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
image_content = storage.get_admin_image(doc_id, page_num)
dst_image.write_bytes(image_content)
total_images += 1
# Write YOLO label
@@ -380,6 +415,8 @@ names: {list(FIELD_CLASSES.values())}
self._db.add_training_log(task_id, level, message)
# Create shared training config
# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
# because models need to be accessible by inference service on any platform
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
config = SharedTrainingConfig(
model_path=model_name,

View File

@@ -13,6 +13,7 @@ class DatasetCreateRequest(BaseModel):
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
description: str | None = Field(None, description="Optional description")
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
category: str | None = Field(None, description="Filter documents by category (optional)")
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
seed: int = Field(42, description="Random seed for split")
@@ -43,6 +44,8 @@ class DatasetDetailResponse(BaseModel):
name: str
description: str | None
status: str
training_status: str | None = None
active_training_task_id: str | None = None
train_ratio: float
val_ratio: float
seed: int

View File

@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
@@ -44,6 +45,7 @@ class DocumentItem(BaseModel):
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@@ -76,6 +78,7 @@ class DocumentDetailResponse(BaseModel):
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
@@ -104,3 +107,17 @@ class DocumentStatsResponse(BaseModel):
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
labeled: int = Field(default=0, ge=0, description="Labeled documents")
exported: int = Field(default=0, ge=0, description="Exported documents")
class DocumentUpdateRequest(BaseModel):
"""Request for updating document metadata."""
category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)")
group_key: str | None = Field(None, description="User-defined group key")
class DocumentCategoriesResponse(BaseModel):
"""Response for available document categories."""
categories: list[str] = Field(..., description="List of available categories")
total: int = Field(..., ge=0, description="Total number of categories")
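
For completeness, a tiny illustration of the new schemas in use (placeholder values, assuming Pydantic v2's model_dump):

req = DocumentUpdateRequest(category="receipt")
resp = DocumentCategoriesResponse(categories=["invoice", "letter", "receipt"], total=3)
print(req.model_dump(exclude_none=True))  # {'category': 'receipt'}
print(resp.total)                         # 3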

View File

@@ -5,6 +5,7 @@ Manages async request lifecycle and background processing.
"""
import logging
import re
import shutil
import time
import uuid
@@ -17,6 +18,7 @@ from typing import TYPE_CHECKING
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.core.rate_limiter import RateLimiter
from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from inference.web.config import AsyncConfig, StorageConfig
@@ -189,9 +191,7 @@ class AsyncProcessingService:
filename: str,
content: bytes,
) -> Path:
"""Save uploaded file to temp storage."""
import re
"""Save uploaded file to temp storage using StorageHelper."""
# Extract extension from filename
ext = Path(filename).suffix.lower()
@@ -203,9 +203,11 @@ class AsyncProcessingService:
if ext not in self.ALLOWED_EXTENSIONS:
ext = ".pdf"
# Create async upload directory
upload_dir = self._async_config.temp_upload_dir
upload_dir.mkdir(parents=True, exist_ok=True)
# Get upload directory from StorageHelper
storage = get_storage_helper()
upload_dir = storage.get_uploads_base_path(subfolder="async")
if upload_dir is None:
raise ValueError("Storage not configured for local access")
# Build file path - request_id is a UUID so it's safe
file_path = upload_dir / f"{request_id}{ext}"
@@ -355,8 +357,9 @@ class AsyncProcessingService:
def _cleanup_orphan_files(self) -> int:
"""Clean up upload files that don't have matching requests."""
upload_dir = self._async_config.temp_upload_dir
if not upload_dir.exists():
storage = get_storage_helper()
upload_dir = storage.get_uploads_base_path(subfolder="async")
if upload_dir is None or not upload_dir.exists():
return 0
count = 0

View File

@@ -13,7 +13,7 @@ from PIL import Image
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken

View File

@@ -16,7 +16,7 @@ from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from inference.data.admin_db import AdminDB
from inference.data.admin_models import CSV_TO_CLASS_MAPPING
from shared.fields import CSV_TO_CLASS_MAPPING
logger = logging.getLogger(__name__)

View File

@@ -12,7 +12,7 @@ from pathlib import Path
import yaml
from inference.data.admin_models import FIELD_CLASSES
from shared.fields import FIELD_CLASSES
logger = logging.getLogger(__name__)

View File

@@ -13,9 +13,10 @@ from typing import Any
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
from shared.fields import CSV_TO_CLASS_MAPPING
from inference.data.admin_models import AdminDocument
from shared.data.db import DocumentDB
from inference.web.config import StorageConfig
from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
@@ -122,8 +123,12 @@ def process_document_autolabel(
document_id = str(document.document_id)
file_path = Path(document.file_path)
# Get output directory from StorageHelper
storage = get_storage_helper()
if output_dir is None:
output_dir = Path("data/autolabel_output")
output_dir = storage.get_autolabel_output_path()
if output_dir is None:
output_dir = Path("data/autolabel_output")
output_dir.mkdir(parents=True, exist_ok=True)
# Mark as processing
@@ -152,10 +157,12 @@ def process_document_autolabel(
is_scanned = len(tokens) < 10 # Threshold for "no text"
# Build task data
# Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
# Use raw_pdfs base path for pdf_path
# This ensures consistency with CLI autolabel for reprocess_failed.py
storage_config = StorageConfig()
pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"
raw_pdfs_dir = storage.get_raw_pdfs_base_path()
if raw_pdfs_dir is None:
raise ValueError("Storage not configured for local access")
pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf"
task_data = {
"row_dict": row_dict,
@@ -246,8 +253,8 @@ def _save_annotations_to_db(
Returns:
Number of annotations saved
"""
from PIL import Image
from inference.data.admin_models import FIELD_CLASS_IDS
from shared.fields import FIELD_CLASS_IDS
from inference.web.services.storage_helpers import get_storage_helper
# Mapping from CSV field names to internal field names
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
@@ -266,6 +273,9 @@ def _save_annotations_to_db(
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
scale = dpi / 72.0
# Get storage helper for image dimensions
storage = get_storage_helper()
# Cache for image dimensions per page
image_dimensions: dict[int, tuple[int, int]] = {}
@@ -274,18 +284,11 @@ def _save_annotations_to_db(
if page_no in image_dimensions:
return image_dimensions[page_no]
# Try to load from admin_images
admin_images_dir = Path("data/admin_images") / document_id
image_path = admin_images_dir / f"page_{page_no}.png"
if image_path.exists():
try:
with Image.open(image_path) as img:
dims = img.size # (width, height)
image_dimensions[page_no] = dims
return dims
except Exception as e:
logger.warning(f"Failed to read image dimensions from {image_path}: {e}")
# Get dimensions from storage helper
dims = storage.get_admin_image_dimensions(document_id, page_no)
if dims:
image_dimensions[page_no] = dims
return dims
return None
@@ -449,10 +452,17 @@ def save_manual_annotations_to_document_db(
from datetime import datetime
document_id = str(document.document_id)
storage_config = StorageConfig()
# Build pdf_path using admin_upload_dir (same as auto-label)
pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"
# Build pdf_path using raw_pdfs base path (same as auto-label)
storage = get_storage_helper()
raw_pdfs_dir = storage.get_raw_pdfs_base_path()
if raw_pdfs_dir is None:
return {
"success": False,
"document_id": document_id,
"error": "Storage not configured for local access",
}
pdf_path = raw_pdfs_dir / f"{document_id}.pdf"
# Build report dict compatible with DocumentDB.save_document()
field_results = []

View File

@@ -0,0 +1,217 @@
"""
Document Service for storage-backed file operations.
Provides a unified interface for document upload, download, and serving
using the storage abstraction layer.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from uuid import uuid4
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
@dataclass
class DocumentResult:
"""Result of document operation."""
id: str
file_path: str
filename: str | None = None
class DocumentService:
"""Service for document file operations using storage backend.
Provides upload, download, and URL generation for documents and images.
"""
# Storage path prefixes
DOCUMENTS_PREFIX = "documents"
IMAGES_PREFIX = "images"
def __init__(
self,
storage_backend: "StorageBackend",
admin_db: Any | None = None,
) -> None:
"""Initialize document service.
Args:
storage_backend: Storage backend for file operations.
admin_db: Optional AdminDB instance for database operations.
"""
self._storage = storage_backend
self._admin_db = admin_db
def upload_document(
self,
content: bytes,
filename: str,
dataset_id: str | None = None,
document_id: str | None = None,
) -> DocumentResult:
"""Upload a document to storage.
Args:
content: Document content as bytes.
filename: Original filename.
dataset_id: Optional dataset ID for organization.
document_id: Optional document ID (generated if not provided).
Returns:
DocumentResult with ID and storage path.
"""
if document_id is None:
document_id = str(uuid4())
# Extract extension from filename
ext = ""
if "." in filename:
ext = "." + filename.rsplit(".", 1)[-1].lower()
# Build logical path
remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}"
# Upload via storage backend
self._storage.upload_bytes(content, remote_path, overwrite=True)
return DocumentResult(
id=document_id,
file_path=remote_path,
filename=filename,
)
def download_document(self, remote_path: str) -> bytes:
"""Download a document from storage.
Args:
remote_path: Logical path to the document.
Returns:
Document content as bytes.
"""
return self._storage.download_bytes(remote_path)
def get_document_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get a URL for accessing a document.
Args:
remote_path: Logical path to the document.
expires_in_seconds: URL validity duration.
Returns:
Pre-signed URL for document access.
"""
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
def document_exists(self, remote_path: str) -> bool:
"""Check if a document exists in storage.
Args:
remote_path: Logical path to the document.
Returns:
True if document exists.
"""
return self._storage.exists(remote_path)
def delete_document_files(self, remote_path: str) -> bool:
"""Delete a document from storage.
Args:
remote_path: Logical path to the document.
Returns:
True if document was deleted.
"""
return self._storage.delete(remote_path)
def save_page_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save a page image to storage.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Logical path where image was stored.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
self._storage.upload_bytes(content, remote_path, overwrite=True)
return remote_path
def get_page_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get a URL for accessing a page image.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
expires_in_seconds: URL validity duration.
Returns:
Pre-signed URL for image access.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
def get_page_image(self, document_id: str, page_num: int) -> bytes:
"""Download a page image from storage.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
return self._storage.download_bytes(remote_path)
def delete_document_images(self, document_id: str) -> int:
"""Delete all images for a document.
Args:
document_id: Document ID.
Returns:
Number of images deleted.
"""
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
image_paths = self._storage.list_files(prefix)
deleted_count = 0
for path in image_paths:
if self._storage.delete(path):
deleted_count += 1
return deleted_count
def list_document_images(self, document_id: str) -> list[str]:
"""List all images for a document.
Args:
document_id: Document ID.
Returns:
List of image paths.
"""
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
return self._storage.list_files(prefix)

View File

@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
@@ -303,12 +305,19 @@ class InferenceService:
"""Save visualization image with detections."""
from ultralytics import YOLO
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir is None:
logger.warning("Cannot save visualization: local storage not available")
return None
# Load model and run prediction with visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
@@ -320,19 +329,26 @@ class InferenceService:
from ultralytics import YOLO
import io
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir is None:
logger.warning("Cannot save visualization: local storage not available")
return None
# Render first page
for page_no, image_bytes in render_pdf_to_images(
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
temp_path = results_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))

View File

@@ -0,0 +1,830 @@
"""
Storage helpers for web services.
Provides convenience functions for common storage operations,
wrapping the storage backend with proper path handling using prefixes.
"""
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4
from shared.storage import PREFIXES, get_storage_backend
from shared.storage.local import LocalStorageBackend
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
def get_default_storage() -> "StorageBackend":
"""Get the default storage backend.
Returns:
Configured StorageBackend instance.
"""
return get_storage_backend()
class StorageHelper:
"""Helper class for storage operations with prefixes.
Provides high-level operations for document storage, including
upload, download, and URL generation with proper path prefixes.
"""
def __init__(self, storage: "StorageBackend | None" = None) -> None:
"""Initialize storage helper.
Args:
storage: Storage backend to use. If None, creates default.
"""
self._storage = storage or get_default_storage()
@property
def storage(self) -> "StorageBackend":
"""Get the underlying storage backend."""
return self._storage
# Document operations
def upload_document(
self,
content: bytes,
filename: str,
document_id: str | None = None,
) -> tuple[str, str]:
"""Upload a document to storage.
Args:
content: Document content as bytes.
filename: Original filename (used for extension).
document_id: Optional document ID. Generated if not provided.
Returns:
Tuple of (document_id, storage_path).
"""
if document_id is None:
document_id = str(uuid4())
ext = Path(filename).suffix.lower() or ".pdf"
path = PREFIXES.document_path(document_id, ext)
self._storage.upload_bytes(content, path, overwrite=True)
return document_id, path
def download_document(self, document_id: str, extension: str = ".pdf") -> bytes:
"""Download a document from storage.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
Document content as bytes.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.download_bytes(path)
def get_document_url(
self,
document_id: str,
extension: str = ".pdf",
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for a document.
Args:
document_id: Document identifier.
extension: File extension.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.get_presigned_url(path, expires_in_seconds)
def document_exists(self, document_id: str, extension: str = ".pdf") -> bool:
"""Check if a document exists.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
True if document exists.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.exists(path)
def delete_document(self, document_id: str, extension: str = ".pdf") -> bool:
"""Delete a document.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
True if document was deleted.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.delete(path)
# Image operations
def save_page_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save a page image to storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Storage path where image was saved.
"""
path = PREFIXES.image_path(document_id, page_num)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_page_image(self, document_id: str, page_num: int) -> bytes:
"""Download a page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
path = PREFIXES.image_path(document_id, page_num)
return self._storage.download_bytes(path)
def get_page_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for a page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.image_path(document_id, page_num)
return self._storage.get_presigned_url(path, expires_in_seconds)
def delete_document_images(self, document_id: str) -> int:
"""Delete all images for a document.
Args:
document_id: Document identifier.
Returns:
Number of images deleted.
"""
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
images = self._storage.list_files(prefix)
deleted = 0
for img_path in images:
if self._storage.delete(img_path):
deleted += 1
return deleted
def list_document_images(self, document_id: str) -> list[str]:
"""List all images for a document.
Args:
document_id: Document identifier.
Returns:
List of image paths.
"""
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
return self._storage.list_files(prefix)
# Upload staging operations
def save_upload(
self,
content: bytes,
filename: str,
subfolder: str | None = None,
) -> str:
"""Save a file to upload staging area.
Args:
content: File content as bytes.
filename: Filename to save as.
subfolder: Optional subfolder (e.g., "async").
Returns:
Storage path where file was saved.
"""
path = PREFIXES.upload_path(filename, subfolder)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_upload(self, filename: str, subfolder: str | None = None) -> bytes:
"""Get a file from upload staging area.
Args:
filename: Filename to retrieve.
subfolder: Optional subfolder.
Returns:
File content as bytes.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.download_bytes(path)
def delete_upload(self, filename: str, subfolder: str | None = None) -> bool:
"""Delete a file from upload staging area.
Args:
filename: Filename to delete.
subfolder: Optional subfolder.
Returns:
True if file was deleted.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.delete(path)
# Result operations
def save_result(self, content: bytes, filename: str) -> str:
"""Save a result file.
Args:
content: File content as bytes.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = PREFIXES.result_path(filename)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_result(self, filename: str) -> bytes:
"""Get a result file.
Args:
filename: Filename to retrieve.
Returns:
File content as bytes.
"""
path = PREFIXES.result_path(filename)
return self._storage.download_bytes(path)
def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str:
"""Get presigned URL for a result file.
Args:
filename: Filename.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.result_path(filename)
return self._storage.get_presigned_url(path, expires_in_seconds)
def result_exists(self, filename: str) -> bool:
"""Check if a result file exists.
Args:
filename: Filename to check.
Returns:
True if file exists.
"""
path = PREFIXES.result_path(filename)
return self._storage.exists(path)
def delete_result(self, filename: str) -> bool:
"""Delete a result file.
Args:
filename: Filename to delete.
Returns:
True if file was deleted.
"""
path = PREFIXES.result_path(filename)
return self._storage.delete(path)
# Export operations
def save_export(self, content: bytes, export_id: str, filename: str) -> str:
"""Save an export file.
Args:
content: File content as bytes.
export_id: Export identifier.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = PREFIXES.export_path(export_id, filename)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_export_url(
self,
export_id: str,
filename: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for an export file.
Args:
export_id: Export identifier.
filename: Filename.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.export_path(export_id, filename)
return self._storage.get_presigned_url(path, expires_in_seconds)
# Admin image operations
def get_admin_image_path(self, document_id: str, page_num: int) -> str:
"""Get the storage path for an admin image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Storage path like "admin_images/doc123/page_1.png"
"""
return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png"
def save_admin_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save an admin page image to storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Storage path where image was saved.
"""
path = self.get_admin_image_path(document_id, page_num)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_admin_image(self, document_id: str, page_num: int) -> bytes:
"""Download an admin page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.download_bytes(path)
def get_admin_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for an admin page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.get_presigned_url(path, expires_in_seconds)
def admin_image_exists(self, document_id: str, page_num: int) -> bool:
"""Check if an admin page image exists.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
True if image exists.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.exists(path)
def list_admin_images(self, document_id: str) -> list[str]:
"""List all admin images for a document.
Args:
document_id: Document identifier.
Returns:
List of image paths.
"""
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
return self._storage.list_files(prefix)
def delete_admin_images(self, document_id: str) -> int:
"""Delete all admin images for a document.
Args:
document_id: Document identifier.
Returns:
Number of images deleted.
"""
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
images = self._storage.list_files(prefix)
deleted = 0
for img_path in images:
if self._storage.delete(img_path):
deleted += 1
return deleted
def get_admin_image_local_path(
self, document_id: str, page_num: int
) -> Path | None:
"""Get the local filesystem path for an admin image.
This method is useful for serving files via FileResponse.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
# Cloud storage - cannot get local path
return None
remote_path = self.get_admin_image_path(document_id, page_num)
try:
full_path = self._storage._get_full_path(remote_path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_admin_image_dimensions(
self, document_id: str, page_num: int
) -> tuple[int, int] | None:
"""Get the dimensions (width, height) of an admin image.
This method is useful for normalizing bounding box coordinates.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Tuple of (width, height) if image exists, None otherwise.
"""
from PIL import Image
# Try local path first for efficiency
local_path = self.get_admin_image_local_path(document_id, page_num)
if local_path is not None:
with Image.open(local_path) as img:
return img.size
# Fall back to downloading for cloud storage
if not self.admin_image_exists(document_id, page_num):
return None
try:
import io
image_bytes = self.get_admin_image(document_id, page_num)
with Image.open(io.BytesIO(image_bytes)) as img:
return img.size
except Exception:
return None
# Raw PDF operations (legacy compatibility)
def save_raw_pdf(self, content: bytes, filename: str) -> str:
"""Save a raw PDF for auto-labeling pipeline.
Args:
content: PDF content as bytes.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_raw_pdf(self, filename: str) -> bytes:
"""Get a raw PDF from storage.
Args:
filename: Filename to retrieve.
Returns:
PDF content as bytes.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
return self._storage.download_bytes(path)
def raw_pdf_exists(self, filename: str) -> bool:
"""Check if a raw PDF exists.
Args:
filename: Filename to check.
Returns:
True if file exists.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
return self._storage.exists(path)
def get_raw_pdf_local_path(self, filename: str) -> Path | None:
"""Get the local filesystem path for a raw PDF.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = f"{PREFIXES.RAW_PDFS}/{filename}"
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_raw_pdf_path(self, filename: str) -> str:
"""Get the storage path for a raw PDF (not the local filesystem path).
Args:
filename: Filename.
Returns:
Storage path like "raw_pdfs/filename.pdf"
"""
return f"{PREFIXES.RAW_PDFS}/{filename}"
# Result local path operations
def get_result_local_path(self, filename: str) -> Path | None:
"""Get the local filesystem path for a result file.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = PREFIXES.result_path(filename)
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_results_base_path(self) -> Path | None:
"""Get the base directory path for results (local storage only).
Used for mounting static file directories.
Returns:
Path to results directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.RESULTS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
# Upload local path operations
def get_upload_local_path(
self, filename: str, subfolder: str | None = None
) -> Path | None:
"""Get the local filesystem path for an upload file.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
subfolder: Optional subfolder.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = PREFIXES.upload_path(filename, subfolder)
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None:
"""Get the base directory path for uploads (local storage only).
Args:
subfolder: Optional subfolder (e.g., "async").
Returns:
Path to uploads directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
if subfolder:
base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}")
else:
base_path = self._storage._get_full_path(PREFIXES.UPLOADS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def upload_exists(self, filename: str, subfolder: str | None = None) -> bool:
"""Check if an upload file exists.
Args:
filename: Filename to check.
subfolder: Optional subfolder.
Returns:
True if file exists.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.exists(path)
# Dataset operations
def get_datasets_base_path(self) -> Path | None:
"""Get the base directory path for datasets (local storage only).
Returns:
Path to datasets directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.DATASETS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_admin_images_base_path(self) -> Path | None:
"""Get the base directory path for admin images (local storage only).
Returns:
Path to admin_images directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_raw_pdfs_base_path(self) -> Path | None:
"""Get the base directory path for raw PDFs (local storage only).
Returns:
Path to raw_pdfs directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_autolabel_output_path(self) -> Path | None:
"""Get the directory path for autolabel output (local storage only).
Returns:
Path to autolabel_output directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
            # Autolabel output goes to its own top-level "autolabel_output" folder under the storage base path
base_path = self._storage._get_full_path("autolabel_output")
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_training_data_path(self) -> Path | None:
"""Get the directory path for training data exports (local storage only).
Returns:
Path to training directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path("training")
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_exports_base_path(self) -> Path | None:
"""Get the base directory path for exports (local storage only).
Returns:
Path to exports directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.EXPORTS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
# Default instance for convenience
_default_helper: StorageHelper | None = None
def get_storage_helper() -> StorageHelper:
"""Get the default storage helper instance.
Creates the helper on first call with default storage backend.
Returns:
Default StorageHelper instance.
"""
global _default_helper
if _default_helper is None:
_default_helper = StorageHelper()
return _default_helper

packages/shared/README.md Normal file
View File

@@ -0,0 +1,205 @@
# Shared Package
Shared utilities and abstractions for the Invoice Master system.
## Storage Abstraction Layer
A unified storage abstraction supporting multiple backends:
- **Local filesystem** - Development and testing
- **Azure Blob Storage** - Azure cloud deployments
- **AWS S3** - AWS cloud deployments
### Installation
```bash
# Basic installation (local storage only)
pip install -e packages/shared
# With Azure support
pip install -e "packages/shared[azure]"
# With S3 support
pip install -e "packages/shared[s3]"
# All cloud providers
pip install -e "packages/shared[all]"
```
### Quick Start
```python
from pathlib import Path
from shared.storage import get_storage_backend
# Option 1: From configuration file
storage = get_storage_backend("storage.yaml")
# Option 2: From environment variables
from shared.storage import create_storage_backend_from_env
storage = create_storage_backend_from_env()
# Upload a file
storage.upload(Path("local/file.pdf"), "documents/file.pdf")
# Download a file
storage.download("documents/file.pdf", Path("local/downloaded.pdf"))
# Get pre-signed URL for frontend access
url = storage.get_presigned_url("documents/file.pdf", expires_in_seconds=3600)
```
### Configuration File Format
Create a `storage.yaml` file; string values support `${VAR}` and `${VAR:-default}` environment variable substitution:
```yaml
# Backend selection: local, azure_blob, or s3
backend: ${STORAGE_BACKEND:-local}
# Default pre-signed URL expiry (seconds)
presigned_url_expiry: 3600
# Local storage configuration
local:
base_path: ${STORAGE_BASE_PATH:-./data/storage}
# Azure Blob Storage configuration
azure:
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
create_container: false
# AWS S3 configuration
s3:
bucket_name: ${AWS_S3_BUCKET}
region_name: ${AWS_REGION:-us-east-1}
access_key_id: ${AWS_ACCESS_KEY_ID}
secret_access_key: ${AWS_SECRET_ACCESS_KEY}
endpoint_url: ${AWS_ENDPOINT_URL} # Optional, for S3-compatible services
create_bucket: false
```
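A file like this can be handed straight to the factory functions described below; a minimal sketch, assuming `storage.yaml` sits in the working directory:
```python
from shared.storage import create_storage_backend_from_file
# Reads the YAML, substitutes environment variables, and builds the selected backend
storage = create_storage_backend_from_file("storage.yaml")
```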
### Environment Variables
| Variable | Backend | Description |
|----------|---------|-------------|
| `STORAGE_BACKEND` | All | Backend type: `local`, `azure_blob`, `s3` |
| `STORAGE_BASE_PATH` | Local | Base directory path |
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | Connection string |
| `AZURE_STORAGE_CONTAINER` | Azure | Container name |
| `AWS_S3_BUCKET` | S3 | Bucket name |
| `AWS_REGION` | S3 | AWS region (default: us-east-1) |
| `AWS_ACCESS_KEY_ID` | S3 | Access key (optional, uses credential chain) |
| `AWS_SECRET_ACCESS_KEY` | S3 | Secret key (optional) |
| `AWS_ENDPOINT_URL` | S3 | Custom endpoint for S3-compatible services |
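For example, the local backend needs only two of these variables. A minimal runnable sketch, setting them in-process purely for illustration (in practice they would come from the deployment environment):
```python
import os
from shared.storage import create_storage_backend_from_env
# Illustrative values; STORAGE_BASE_PATH must be non-empty for the local backend
os.environ["STORAGE_BACKEND"] = "local"
os.environ["STORAGE_BASE_PATH"] = "./data/storage"
storage = create_storage_backend_from_env()  # raises ValueError if a required variable is missing or empty
```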
### API Reference
#### StorageBackend Interface
```python
class StorageBackend(ABC):
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
"""Upload a file to storage."""
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from storage."""
def exists(self, remote_path: str) -> bool:
"""Check if a file exists."""
def list_files(self, prefix: str) -> list[str]:
"""List files with given prefix."""
def delete(self, remote_path: str) -> bool:
"""Delete a file."""
def get_url(self, remote_path: str) -> str:
"""Get URL for a file."""
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
"""Generate a pre-signed URL for temporary access (1-604800 seconds)."""
def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str:
"""Upload bytes directly."""
def download_bytes(self, remote_path: str) -> bytes:
"""Download file as bytes."""
```
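As a quick illustration of the bytes-based helpers, here is a minimal sketch assuming a backend created as in the Quick Start; the `results/summary.json` path is purely illustrative:
```python
from shared.storage import get_storage_backend
storage = get_storage_backend("storage.yaml")
# Round-trip raw bytes without touching the local filesystem
payload = b'{"status": "ok"}'
storage.upload_bytes(payload, "results/summary.json", overwrite=True)
assert storage.exists("results/summary.json")
assert storage.download_bytes("results/summary.json") == payload
```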
#### Factory Functions
```python
from pathlib import Path
from shared.storage import (
    StorageConfig, create_storage_backend, create_storage_backend_from_env,
    create_storage_backend_from_file, get_storage_backend,
)
# Create from configuration file
storage = create_storage_backend_from_file("storage.yaml")
# Create from environment variables
storage = create_storage_backend_from_env()
# Create from StorageConfig object
config = StorageConfig(backend_type="local", base_path=Path("./data"))
storage = create_storage_backend(config)
# Convenience function with fallback chain: config file -> env vars -> local default
storage = get_storage_backend("storage.yaml") # or None for env-only
```
### Pre-signed URLs
Pre-signed URLs provide temporary access to files without exposing credentials:
```python
# Generate URL valid for 1 hour (default)
url = storage.get_presigned_url("documents/invoice.pdf")
# Generate URL valid for 24 hours
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=86400)
# Maximum expiry: 7 days (604800 seconds)
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=604800)
```
**Note:** Local storage returns `file://` URLs that don't actually expire.
### Error Handling
```python
from pathlib import Path
from shared.storage import (
    StorageError,
    FileNotFoundStorageError,
    PresignedUrlNotSupportedError,
)
try:
storage.download("nonexistent.pdf", Path("local.pdf"))
except FileNotFoundStorageError as e:
print(f"File not found: {e}")
except StorageError as e:
print(f"Storage error: {e}")
```
### Testing with MinIO (S3-compatible)
```bash
# Start MinIO locally
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"
# Configure environment
export STORAGE_BACKEND=s3
export AWS_S3_BUCKET=test-bucket
export AWS_ENDPOINT_URL=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
```
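With those variables set, a short smoke test could look like the sketch below; it assumes MinIO is running as above and that the `test-bucket` bucket has already been created:
```python
from shared.storage import create_storage_backend_from_env
storage = create_storage_backend_from_env()
storage.upload_bytes(b"hello", "smoke/test.txt", overwrite=True)
print(storage.exists("smoke/test.txt"))                  # True
print(storage.get_presigned_url("smoke/test.txt", 300))  # temporary URL served by MinIO
```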
### Module Structure
```
shared/storage/
├── __init__.py # Public exports
├── base.py # Abstract interface and exceptions
├── local.py # Local filesystem backend
├── azure.py # Azure Blob Storage backend
├── s3.py # AWS S3 backend
├── config_loader.py # YAML configuration loader
└── factory.py # Backend factory functions
```

View File

@@ -16,4 +16,18 @@ setup(
"pyyaml>=6.0",
"thefuzz>=0.20.0",
],
extras_require={
"azure": [
"azure-storage-blob>=12.19.0",
"azure-identity>=1.15.0",
],
"s3": [
"boto3>=1.34.0",
],
"all": [
"azure-storage-blob>=12.19.0",
"azure-identity>=1.15.0",
"boto3>=1.34.0",
],
},
)

View File

@@ -58,23 +58,16 @@ def get_db_connection_string():
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
# Paths Configuration - auto-detect WSL vs Windows
if _is_wsl():
# WSL: use native Linux filesystem for better I/O performance
PATHS = {
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
'output_dir': os.path.expanduser('~/invoice-data/dataset'),
'reports_dir': 'reports', # Keep reports in project directory
}
else:
# Windows or native Linux: use relative paths
PATHS = {
'csv_dir': 'data/structured_data',
'pdf_dir': 'data/raw_pdfs',
'output_dir': 'data/dataset',
'reports_dir': 'reports',
}
# Paths Configuration - uses STORAGE_BASE_PATH for consistency
# All paths are relative to STORAGE_BASE_PATH (defaults to ~/invoice-data/data)
_storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data'))
PATHS = {
'csv_dir': f'{_storage_base}/structured_data',
'pdf_dir': f'{_storage_base}/raw_pdfs',
'output_dir': f'{_storage_base}/datasets',
'reports_dir': 'reports', # Keep reports in project directory
}
# Auto-labeling Configuration
AUTOLABEL = {

View File

@@ -0,0 +1,46 @@
"""
Shared Field Definitions - Single Source of Truth.
This module provides centralized field class definitions used throughout
the invoice extraction system. All field mappings are derived from
FIELD_DEFINITIONS to ensure consistency.
Usage:
from shared.fields import FIELD_CLASSES, CLASS_NAMES, FIELD_CLASS_IDS
Available exports:
- FieldDefinition: Dataclass for field definition
- FIELD_DEFINITIONS: Tuple of all field definitions (immutable)
- NUM_CLASSES: Total number of field classes (10)
- CLASS_NAMES: List of class names in order [0..9]
- FIELD_CLASSES: dict[int, str] - class_id to class_name
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
"""
from .field_config import FieldDefinition, FIELD_DEFINITIONS, NUM_CLASSES
from .mappings import (
CLASS_NAMES,
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
ACCOUNT_FIELD_MAPPING,
)
__all__ = [
"FieldDefinition",
"FIELD_DEFINITIONS",
"NUM_CLASSES",
"CLASS_NAMES",
"FIELD_CLASSES",
"FIELD_CLASS_IDS",
"CLASS_TO_FIELD",
"CSV_TO_CLASS_MAPPING",
"TRAINING_FIELD_CLASSES",
"ACCOUNT_FIELD_MAPPING",
]

View File

@@ -0,0 +1,58 @@
"""
Field Configuration - Single Source of Truth
This module defines all invoice field classes used throughout the system.
The class IDs are verified against the trained YOLO model (best.pt).
IMPORTANT: Do not modify class_id values without retraining the model!
"""
from dataclasses import dataclass
from typing import Final
@dataclass(frozen=True)
class FieldDefinition:
"""Immutable field definition for invoice extraction.
Attributes:
class_id: YOLO class ID (0-9), must match trained model
class_name: YOLO class name (lowercase_underscore)
field_name: Business field name used in API responses
csv_name: CSV column name for data import/export
is_derived: True if field is derived from other fields (not in CSV)
"""
class_id: int
class_name: str
field_name: str
csv_name: str
is_derived: bool = False
# Verified from model weights (runs/train/invoice_fields/weights/best.pt)
# model.names = {0: 'invoice_number', 1: 'invoice_date', ..., 8: 'customer_number', 9: 'payment_line'}
#
# DO NOT CHANGE THE ORDER - it must match the trained model!
FIELD_DEFINITIONS: Final[tuple[FieldDefinition, ...]] = (
FieldDefinition(0, "invoice_number", "InvoiceNumber", "InvoiceNumber"),
FieldDefinition(1, "invoice_date", "InvoiceDate", "InvoiceDate"),
FieldDefinition(2, "invoice_due_date", "InvoiceDueDate", "InvoiceDueDate"),
FieldDefinition(3, "ocr_number", "OCR", "OCR"),
FieldDefinition(4, "bankgiro", "Bankgiro", "Bankgiro"),
FieldDefinition(5, "plusgiro", "Plusgiro", "Plusgiro"),
FieldDefinition(6, "amount", "Amount", "Amount"),
FieldDefinition(
7,
"supplier_org_number",
"supplier_organisation_number",
"supplier_organisation_number",
),
FieldDefinition(8, "customer_number", "customer_number", "customer_number"),
FieldDefinition(
9, "payment_line", "payment_line", "payment_line", is_derived=True
),
)
# Total number of field classes
NUM_CLASSES: Final[int] = len(FIELD_DEFINITIONS)

View File

@@ -0,0 +1,57 @@
"""
Field Mappings - Auto-generated from FIELD_DEFINITIONS.
All mappings in this file are derived from field_config.FIELD_DEFINITIONS.
This ensures consistency across the entire codebase.
DO NOT hardcode field mappings elsewhere - always import from this module.
"""
from typing import Final
from .field_config import FIELD_DEFINITIONS
# List of class names in order (for YOLO classes.txt generation)
# Index matches class_id: CLASS_NAMES[0] = "invoice_number"
CLASS_NAMES: Final[list[str]] = [fd.class_name for fd in FIELD_DEFINITIONS]
# class_id -> class_name mapping
# Example: {0: "invoice_number", 1: "invoice_date", ...}
FIELD_CLASSES: Final[dict[int, str]] = {
fd.class_id: fd.class_name for fd in FIELD_DEFINITIONS
}
# class_name -> class_id mapping (reverse of FIELD_CLASSES)
# Example: {"invoice_number": 0, "invoice_date": 1, ...}
FIELD_CLASS_IDS: Final[dict[str, int]] = {
fd.class_name: fd.class_id for fd in FIELD_DEFINITIONS
}
# class_name -> field_name mapping (for API responses)
# Example: {"invoice_number": "InvoiceNumber", "ocr_number": "OCR", ...}
CLASS_TO_FIELD: Final[dict[str, str]] = {
fd.class_name: fd.field_name for fd in FIELD_DEFINITIONS
}
# field_name -> class_id mapping (for CSV import)
# Excludes derived fields like payment_line
# Example: {"InvoiceNumber": 0, "InvoiceDate": 1, ...}
CSV_TO_CLASS_MAPPING: Final[dict[str, int]] = {
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS if not fd.is_derived
}
# field_name -> class_id mapping (for training, includes all fields)
# Example: {"InvoiceNumber": 0, ..., "payment_line": 9}
TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
}
# Account field mapping for supplier_accounts special handling
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
"supplier_accounts": {
"BG": "Bankgiro",
"PG": "Plusgiro",
}
}

View File

@@ -0,0 +1,59 @@
"""
Storage abstraction layer for training data.
Provides a unified interface for local filesystem, Azure Blob Storage, and AWS S3.
"""
from shared.storage.base import (
FileNotFoundStorageError,
PresignedUrlNotSupportedError,
StorageBackend,
StorageConfig,
StorageError,
)
from shared.storage.factory import (
create_storage_backend,
create_storage_backend_from_env,
create_storage_backend_from_file,
get_default_storage_config,
get_storage_backend,
)
from shared.storage.local import LocalStorageBackend
from shared.storage.prefixes import PREFIXES, StoragePrefixes
__all__ = [
# Base classes and exceptions
"StorageBackend",
"StorageConfig",
"StorageError",
"FileNotFoundStorageError",
"PresignedUrlNotSupportedError",
# Backends
"LocalStorageBackend",
# Factory functions
"create_storage_backend",
"create_storage_backend_from_env",
"create_storage_backend_from_file",
"get_default_storage_config",
"get_storage_backend",
# Path prefixes
"PREFIXES",
"StoragePrefixes",
]
# Lazy imports to avoid dependencies when not using specific backends
def __getattr__(name: str):
if name == "AzureBlobStorageBackend":
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend
if name == "S3StorageBackend":
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend
if name == "load_storage_config":
from shared.storage.config_loader import load_storage_config
return load_storage_config
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -0,0 +1,335 @@
"""
Azure Blob Storage backend.
Provides storage operations using Azure Blob Storage.
"""
from pathlib import Path
from azure.storage.blob import (
BlobSasPermissions,
BlobServiceClient,
ContainerClient,
generate_blob_sas,
)
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class AzureBlobStorageBackend(StorageBackend):
"""Storage backend using Azure Blob Storage.
Files are stored as blobs in an Azure Blob Storage container.
"""
def __init__(
self,
connection_string: str,
container_name: str,
create_container: bool = False,
) -> None:
"""Initialize Azure Blob Storage backend.
Args:
connection_string: Azure Storage connection string.
container_name: Name of the blob container.
create_container: If True, create the container if it doesn't exist.
"""
self._connection_string = connection_string
self._container_name = container_name
self._blob_service = BlobServiceClient.from_connection_string(connection_string)
self._container = self._blob_service.get_container_client(container_name)
# Extract account key from connection string for SAS token generation
self._account_key = self._extract_account_key(connection_string)
if create_container and not self._container.exists():
self._container.create_container()
@staticmethod
def _extract_account_key(connection_string: str) -> str | None:
"""Extract account key from connection string.
Args:
connection_string: Azure Storage connection string.
Returns:
Account key if found, None otherwise.
"""
for part in connection_string.split(";"):
if part.startswith("AccountKey="):
return part[len("AccountKey=") :]
return None
@property
def container_name(self) -> str:
"""Get the container name for this storage backend."""
return self._container_name
@property
def container_client(self) -> ContainerClient:
"""Get the Azure container client."""
return self._container
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to Azure Blob Storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination blob path.
overwrite: If True, overwrite existing blob.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If blob exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
blob_client = self._container.get_blob_client(remote_path)
if blob_client.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
with open(local_path, "rb") as f:
blob_client.upload_blob(f, overwrite=overwrite)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a blob from Azure Blob Storage.
Args:
remote_path: Blob path in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
stream = blob_client.download_blob()
local_path.write_bytes(stream.readall())
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if a blob exists in storage.
Args:
remote_path: Blob path to check.
Returns:
True if the blob exists, False otherwise.
"""
blob_client = self._container.get_blob_client(remote_path)
return blob_client.exists()
def list_files(self, prefix: str) -> list[str]:
"""List blobs in storage with given prefix.
Args:
prefix: Blob path prefix to filter.
Returns:
List of blob paths matching the prefix.
"""
if prefix:
blobs = self._container.list_blobs(name_starts_with=prefix)
else:
blobs = self._container.list_blobs()
return [blob.name for blob in blobs]
def delete(self, remote_path: str) -> bool:
"""Delete a blob from storage.
Args:
remote_path: Blob path to delete.
Returns:
True if blob was deleted, False if it didn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
return False
blob_client.delete_blob()
return True
def get_url(self, remote_path: str) -> str:
"""Get the URL for a blob.
Args:
remote_path: Blob path in storage.
Returns:
URL to access the blob.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
return blob_client.url
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to Azure Blob Storage.
Args:
data: Bytes to upload.
remote_path: Destination blob path.
overwrite: If True, overwrite existing blob.
Returns:
The remote path where the data was stored.
"""
blob_client = self._container.get_blob_client(remote_path)
if blob_client.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
blob_client.upload_blob(data, overwrite=overwrite)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download a blob as bytes.
Args:
remote_path: Blob path in storage.
Returns:
The blob contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
stream = blob_client.download_blob()
return stream.readall()
def upload_directory(
self, local_dir: Path, remote_prefix: str, overwrite: bool = False
) -> list[str]:
"""Upload all files in a directory to Azure Blob Storage.
Args:
local_dir: Local directory to upload.
remote_prefix: Prefix for remote blob paths.
overwrite: If True, overwrite existing blobs.
Returns:
List of remote paths where files were stored.
"""
uploaded: list[str] = []
for file_path in local_dir.rglob("*"):
if file_path.is_file():
relative_path = file_path.relative_to(local_dir)
remote_path = f"{remote_prefix}{relative_path}".replace("\\", "/")
self.upload(file_path, remote_path, overwrite=overwrite)
uploaded.append(remote_path)
return uploaded
def download_directory(
self, remote_prefix: str, local_dir: Path
) -> list[Path]:
"""Download all blobs with a prefix to a local directory.
Args:
remote_prefix: Blob path prefix to download.
local_dir: Local directory to download to.
Returns:
List of local paths where files were downloaded.
"""
downloaded: list[Path] = []
blobs = self.list_files(remote_prefix)
for blob_path in blobs:
# Remove prefix to get relative path
if remote_prefix:
relative_path = blob_path[len(remote_prefix):]
if relative_path.startswith("/"):
relative_path = relative_path[1:]
else:
relative_path = blob_path
local_path = local_dir / relative_path
self.download(blob_path, local_path)
downloaded.append(local_path)
return downloaded
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a SAS URL for temporary blob access.
Args:
remote_path: Blob path in storage.
expires_in_seconds: SAS token validity duration (1 to 604800 seconds / 7 days).
Returns:
Blob URL with SAS token.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
ValueError: If expires_in_seconds is out of valid range.
"""
if expires_in_seconds < 1 or expires_in_seconds > 604800:
raise ValueError(
"expires_in_seconds must be between 1 and 604800 (7 days)"
)
from datetime import datetime, timedelta, timezone
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
# Generate SAS token
sas_token = generate_blob_sas(
account_name=self._blob_service.account_name,
container_name=self._container_name,
blob_name=remote_path,
account_key=self._account_key,
permission=BlobSasPermissions(read=True),
expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
)
return f"{blob_client.url}?{sas_token}"

View File

@@ -0,0 +1,229 @@
"""
Base classes and interfaces for storage backends.
Defines the abstract StorageBackend interface and common exceptions.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
class StorageError(Exception):
"""Base exception for storage operations."""
pass
class FileNotFoundStorageError(StorageError):
"""Raised when a file is not found in storage."""
def __init__(self, path: str) -> None:
self.path = path
super().__init__(f"File not found in storage: {path}")
class PresignedUrlNotSupportedError(StorageError):
"""Raised when pre-signed URLs are not supported by a backend."""
def __init__(self, backend_type: str) -> None:
self.backend_type = backend_type
super().__init__(f"Pre-signed URLs not supported for backend: {backend_type}")
@dataclass(frozen=True)
class StorageConfig:
"""Configuration for storage backend.
Attributes:
backend_type: Type of storage backend ("local", "azure_blob", or "s3").
connection_string: Azure Blob Storage connection string (for azure_blob).
container_name: Azure Blob Storage container name (for azure_blob).
base_path: Base path for local storage (for local).
bucket_name: S3 bucket name (for s3).
region_name: AWS region name (for s3).
access_key_id: AWS access key ID (for s3).
secret_access_key: AWS secret access key (for s3).
endpoint_url: Custom endpoint URL for S3-compatible services (for s3).
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
"""
backend_type: str
connection_string: str | None = None
container_name: str | None = None
base_path: Path | None = None
bucket_name: str | None = None
region_name: str | None = None
access_key_id: str | None = None
secret_access_key: str | None = None
endpoint_url: str | None = None
presigned_url_expiry: int = 3600
class StorageBackend(ABC):
"""Abstract base class for storage backends.
Provides a unified interface for storing and retrieving files
from different storage systems (local filesystem, Azure Blob, etc.).
"""
@abstractmethod
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If file exists and overwrite is False.
"""
pass
@abstractmethod
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from storage.
Args:
remote_path: Path to the file in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
pass
@abstractmethod
def exists(self, remote_path: str) -> bool:
"""Check if a file exists in storage.
Args:
remote_path: Path to check in storage.
Returns:
True if the file exists, False otherwise.
"""
pass
@abstractmethod
def list_files(self, prefix: str) -> list[str]:
"""List files in storage with given prefix.
Args:
prefix: Path prefix to filter files.
Returns:
List of file paths matching the prefix.
"""
pass
@abstractmethod
def delete(self, remote_path: str) -> bool:
"""Delete a file from storage.
Args:
remote_path: Path to the file to delete.
Returns:
True if file was deleted, False if it didn't exist.
"""
pass
@abstractmethod
def get_url(self, remote_path: str) -> str:
"""Get a URL or path to access a file.
Args:
remote_path: Path to the file in storage.
Returns:
URL or path to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
pass
@abstractmethod
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a pre-signed URL for temporary access.
Args:
remote_path: Path to the file in storage.
expires_in_seconds: URL validity duration (default 1 hour).
Returns:
Pre-signed URL string.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
PresignedUrlNotSupportedError: If backend doesn't support pre-signed URLs.
"""
pass
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to storage.
Default implementation writes to temp file then uploads.
Subclasses may override for more efficient implementation.
Args:
data: Bytes to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the data was stored.
"""
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(data)
temp_path = Path(f.name)
try:
return self.upload(temp_path, remote_path, overwrite=overwrite)
finally:
temp_path.unlink(missing_ok=True)
def download_bytes(self, remote_path: str) -> bytes:
"""Download a file as bytes.
Default implementation downloads to temp file then reads.
Subclasses may override for more efficient implementation.
Args:
remote_path: Path to the file in storage.
Returns:
The file contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as f:
temp_path = Path(f.name)
try:
self.download(remote_path, temp_path)
return temp_path.read_bytes()
finally:
temp_path.unlink(missing_ok=True)

View File

@@ -0,0 +1,242 @@
"""
Configuration file loader for storage backends.
Supports YAML configuration files with environment variable substitution.
"""
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LocalConfig:
"""Local storage backend configuration."""
base_path: Path
@dataclass(frozen=True)
class AzureConfig:
"""Azure Blob Storage configuration."""
connection_string: str
container_name: str
create_container: bool = False
@dataclass(frozen=True)
class S3Config:
"""AWS S3 configuration."""
bucket_name: str
region_name: str | None = None
access_key_id: str | None = None
secret_access_key: str | None = None
endpoint_url: str | None = None
create_bucket: bool = False
@dataclass(frozen=True)
class StorageFileConfig:
"""Extended storage configuration from file.
Attributes:
backend_type: Type of storage backend.
local: Local backend configuration.
azure: Azure Blob configuration.
s3: S3 configuration.
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
"""
backend_type: str
local: LocalConfig | None = None
azure: AzureConfig | None = None
s3: S3Config | None = None
presigned_url_expiry: int = 3600
def substitute_env_vars(value: str) -> str:
"""Substitute environment variables in a string.
Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax.
Args:
value: String potentially containing env var references.
Returns:
String with env vars substituted.
"""
pattern = r"\$\{([A-Z_][A-Z0-9_]*)(?::-([^}]*))?\}"
def replace(match: re.Match[str]) -> str:
var_name = match.group(1)
default = match.group(2)
return os.environ.get(var_name, default or "")
return re.sub(pattern, replace, value)
def _substitute_in_dict(data: dict[str, Any]) -> dict[str, Any]:
"""Recursively substitute env vars in a dictionary.
Args:
data: Dictionary to process.
Returns:
New dictionary with substitutions applied.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, str):
result[key] = substitute_env_vars(value)
elif isinstance(value, dict):
result[key] = _substitute_in_dict(value)
elif isinstance(value, list):
result[key] = [
substitute_env_vars(item) if isinstance(item, str) else item
for item in value
]
else:
result[key] = value
return result
def _parse_local_config(data: dict[str, Any]) -> LocalConfig:
"""Parse local configuration section.
Args:
data: Dictionary containing local config.
Returns:
LocalConfig instance.
Raises:
ValueError: If required fields are missing.
"""
base_path = data.get("base_path")
if not base_path:
raise ValueError("local.base_path is required")
return LocalConfig(base_path=Path(base_path))
def _parse_azure_config(data: dict[str, Any]) -> AzureConfig:
"""Parse Azure configuration section.
Args:
data: Dictionary containing Azure config.
Returns:
AzureConfig instance.
Raises:
ValueError: If required fields are missing.
"""
connection_string = data.get("connection_string")
container_name = data.get("container_name")
if not connection_string:
raise ValueError("azure.connection_string is required")
if not container_name:
raise ValueError("azure.container_name is required")
return AzureConfig(
connection_string=connection_string,
container_name=container_name,
create_container=data.get("create_container", False),
)
def _parse_s3_config(data: dict[str, Any]) -> S3Config:
"""Parse S3 configuration section.
Args:
data: Dictionary containing S3 config.
Returns:
S3Config instance.
Raises:
ValueError: If required fields are missing.
"""
bucket_name = data.get("bucket_name")
if not bucket_name:
raise ValueError("s3.bucket_name is required")
return S3Config(
bucket_name=bucket_name,
region_name=data.get("region_name"),
access_key_id=data.get("access_key_id"),
secret_access_key=data.get("secret_access_key"),
endpoint_url=data.get("endpoint_url"),
create_bucket=data.get("create_bucket", False),
)
def load_storage_config(config_path: Path | str) -> StorageFileConfig:
"""Load storage configuration from YAML file.
Supports environment variable substitution using ${VAR_NAME} or
${VAR_NAME:-default} syntax.
Args:
config_path: Path to configuration file.
Returns:
Parsed StorageFileConfig.
Raises:
FileNotFoundError: If config file doesn't exist.
ValueError: If config is invalid.
"""
config_path = Path(config_path)
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
try:
raw_content = config_path.read_text(encoding="utf-8")
data = yaml.safe_load(raw_content)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in config file: {e}") from e
if not isinstance(data, dict):
raise ValueError("Config file must contain a YAML dictionary")
# Substitute environment variables
data = _substitute_in_dict(data)
# Extract backend type
backend_type = data.get("backend")
if not backend_type:
raise ValueError("'backend' field is required in config file")
# Parse presigned URL expiry
presigned_url_expiry = data.get("presigned_url_expiry", 3600)
# Parse backend-specific configurations
local_config = None
azure_config = None
s3_config = None
if "local" in data:
local_config = _parse_local_config(data["local"])
if "azure" in data:
azure_config = _parse_azure_config(data["azure"])
if "s3" in data:
s3_config = _parse_s3_config(data["s3"])
return StorageFileConfig(
backend_type=backend_type,
local=local_config,
azure=azure_config,
s3=s3_config,
presigned_url_expiry=presigned_url_expiry,
)

View File

@@ -0,0 +1,296 @@
"""
Factory functions for creating storage backends.
Provides convenient functions for creating storage backends from
configuration or environment variables.
"""
import os
from pathlib import Path
from shared.storage.base import StorageBackend, StorageConfig
def create_storage_backend(config: StorageConfig) -> StorageBackend:
"""Create a storage backend from configuration.
Args:
config: Storage configuration.
Returns:
A configured storage backend.
Raises:
ValueError: If configuration is invalid.
"""
if config.backend_type == "local":
if config.base_path is None:
raise ValueError("base_path is required for local storage backend")
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=config.base_path)
elif config.backend_type == "azure_blob":
if config.connection_string is None:
raise ValueError(
"connection_string is required for Azure blob storage backend"
)
if config.container_name is None:
raise ValueError(
"container_name is required for Azure blob storage backend"
)
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=config.connection_string,
container_name=config.container_name,
)
elif config.backend_type == "s3":
if config.bucket_name is None:
raise ValueError("bucket_name is required for S3 storage backend")
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=config.bucket_name,
region_name=config.region_name,
access_key_id=config.access_key_id,
secret_access_key=config.secret_access_key,
endpoint_url=config.endpoint_url,
)
else:
raise ValueError(f"Unknown storage backend type: {config.backend_type}")
def get_default_storage_config() -> StorageConfig:
"""Get storage configuration from environment variables.
Environment variables:
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
STORAGE_BASE_PATH: Base path for local storage.
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
AZURE_STORAGE_CONTAINER: Azure container name.
AWS_S3_BUCKET: S3 bucket name.
AWS_REGION: AWS region name.
AWS_ACCESS_KEY_ID: AWS access key ID.
AWS_SECRET_ACCESS_KEY: AWS secret access key.
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
Returns:
StorageConfig from environment.
"""
backend_type = os.environ.get("STORAGE_BACKEND", "local")
if backend_type == "local":
base_path_str = os.environ.get("STORAGE_BASE_PATH")
# Expand ~ to home directory
base_path = Path(os.path.expanduser(base_path_str)) if base_path_str else None
return StorageConfig(
backend_type="local",
base_path=base_path,
)
elif backend_type == "azure_blob":
return StorageConfig(
backend_type="azure_blob",
connection_string=os.environ.get("AZURE_STORAGE_CONNECTION_STRING"),
container_name=os.environ.get("AZURE_STORAGE_CONTAINER"),
)
elif backend_type == "s3":
return StorageConfig(
backend_type="s3",
bucket_name=os.environ.get("AWS_S3_BUCKET"),
region_name=os.environ.get("AWS_REGION"),
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
endpoint_url=os.environ.get("AWS_ENDPOINT_URL"),
)
else:
return StorageConfig(backend_type=backend_type)
def create_storage_backend_from_env() -> StorageBackend:
"""Create a storage backend from environment variables.
Environment variables:
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
STORAGE_BASE_PATH: Base path for local storage.
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
AZURE_STORAGE_CONTAINER: Azure container name.
AWS_S3_BUCKET: S3 bucket name.
AWS_REGION: AWS region name.
AWS_ACCESS_KEY_ID: AWS access key ID.
AWS_SECRET_ACCESS_KEY: AWS secret access key.
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
Returns:
A configured storage backend.
Raises:
ValueError: If required environment variables are missing or empty.
"""
backend_type = os.environ.get("STORAGE_BACKEND", "local").strip()
if backend_type == "local":
base_path = os.environ.get("STORAGE_BASE_PATH", "").strip()
if not base_path:
raise ValueError(
"STORAGE_BASE_PATH environment variable is required and cannot be empty"
)
# Expand ~ to home directory
base_path_expanded = os.path.expanduser(base_path)
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=Path(base_path_expanded))
elif backend_type == "azure_blob":
connection_string = os.environ.get(
"AZURE_STORAGE_CONNECTION_STRING", ""
).strip()
if not connection_string:
raise ValueError(
"AZURE_STORAGE_CONNECTION_STRING environment variable is required "
"and cannot be empty"
)
container_name = os.environ.get("AZURE_STORAGE_CONTAINER", "").strip()
if not container_name:
raise ValueError(
"AZURE_STORAGE_CONTAINER environment variable is required "
"and cannot be empty"
)
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=connection_string,
container_name=container_name,
)
elif backend_type == "s3":
bucket_name = os.environ.get("AWS_S3_BUCKET", "").strip()
if not bucket_name:
raise ValueError(
"AWS_S3_BUCKET environment variable is required and cannot be empty"
)
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=bucket_name,
region_name=os.environ.get("AWS_REGION", "").strip() or None,
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "").strip() or None,
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "").strip()
or None,
endpoint_url=os.environ.get("AWS_ENDPOINT_URL", "").strip() or None,
)
else:
raise ValueError(f"Unknown storage backend type: {backend_type}")
def create_storage_backend_from_file(config_path: Path | str) -> StorageBackend:
"""Create a storage backend from a configuration file.
Args:
config_path: Path to YAML configuration file.
Returns:
A configured storage backend.
Raises:
FileNotFoundError: If config file doesn't exist.
ValueError: If configuration is invalid.
"""
from shared.storage.config_loader import load_storage_config
file_config = load_storage_config(config_path)
if file_config.backend_type == "local":
if file_config.local is None:
raise ValueError("local configuration section is required")
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=file_config.local.base_path)
elif file_config.backend_type == "azure_blob":
if file_config.azure is None:
raise ValueError("azure configuration section is required")
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=file_config.azure.connection_string,
container_name=file_config.azure.container_name,
create_container=file_config.azure.create_container,
)
elif file_config.backend_type == "s3":
if file_config.s3 is None:
raise ValueError("s3 configuration section is required")
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=file_config.s3.bucket_name,
region_name=file_config.s3.region_name,
access_key_id=file_config.s3.access_key_id,
secret_access_key=file_config.s3.secret_access_key,
endpoint_url=file_config.s3.endpoint_url,
create_bucket=file_config.s3.create_bucket,
)
else:
raise ValueError(f"Unknown storage backend type: {file_config.backend_type}")
def get_storage_backend(config_path: Path | str | None = None) -> StorageBackend:
"""Get storage backend with fallback chain.
Priority:
1. Config file (if provided)
2. Environment variables
Args:
config_path: Optional path to config file.
Returns:
A configured storage backend.
Raises:
ValueError: If configuration is invalid.
FileNotFoundError: If specified config file doesn't exist.
"""
if config_path:
return create_storage_backend_from_file(config_path)
# Fall back to environment variables
return create_storage_backend_from_env()
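
A short usage sketch of the fallback chain above. The module path shared.storage.factory is an assumption (the file name is not shown in this diff), and the paths are placeholders.

import os

from shared.storage.factory import get_storage_backend  # module path assumed

# Environment-driven local backend (no config file supplied).
os.environ["STORAGE_BACKEND"] = "local"
os.environ["STORAGE_BASE_PATH"] = "/var/lib/app/storage"

storage = get_storage_backend()  # falls back to environment variables
storage.upload_bytes(b"hello", "documents/demo.txt", overwrite=True)

# With an explicit config file, the file takes priority over the environment:
# storage = get_storage_backend(config_path="configs/storage.yaml")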

View File

@@ -0,0 +1,262 @@
"""
Local filesystem storage backend.
Provides storage operations using the local filesystem.
"""
import shutil
from pathlib import Path
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class LocalStorageBackend(StorageBackend):
"""Storage backend using local filesystem.
Files are stored relative to a base path on the local filesystem.
"""
def __init__(self, base_path: str | Path) -> None:
"""Initialize local storage backend.
Args:
base_path: Base directory for all storage operations.
Will be created if it doesn't exist.
"""
self._base_path = Path(base_path)
self._base_path.mkdir(parents=True, exist_ok=True)
@property
def base_path(self) -> Path:
"""Get the base path for this storage backend."""
return self._base_path
def _get_full_path(self, remote_path: str) -> Path:
"""Convert a remote path to a full local path with security validation.
Args:
remote_path: The remote path to resolve.
Returns:
The full local path.
Raises:
StorageError: If the path attempts to escape the base directory.
"""
# Reject absolute paths
if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"):
raise StorageError(f"Absolute paths not allowed: {remote_path}")
# Resolve to prevent path traversal attacks
full_path = (self._base_path / remote_path).resolve()
base_resolved = self._base_path.resolve()
# Verify the resolved path is within base_path
try:
full_path.relative_to(base_resolved)
except ValueError:
raise StorageError(f"Path traversal not allowed: {remote_path}")
return full_path
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to local storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If file exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
dest_path = self._get_full_path(remote_path)
if dest_path.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(local_path, dest_path)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from local storage.
Args:
remote_path: Path to the file in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
source_path = self._get_full_path(remote_path)
if not source_path.exists():
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(source_path, local_path)
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if a file exists in storage.
Args:
remote_path: Path to check in storage.
Returns:
True if the file exists, False otherwise.
"""
return self._get_full_path(remote_path).exists()
def list_files(self, prefix: str) -> list[str]:
"""List files in storage with given prefix.
Args:
prefix: Path prefix to filter files.
Returns:
Sorted list of file paths matching the prefix.
"""
        if prefix:
            search_path = self._get_full_path(prefix)
            if not search_path.exists():
                return []
        else:
            search_path = self._base_path
        files: list[str] = []
        if search_path.is_file():
            # A single file matched the prefix; store its path relative to the base.
            relative_path = search_path.relative_to(self._base_path)
            files.append(str(relative_path).replace("\\", "/"))
        elif search_path.is_dir():
            for file_path in search_path.rglob("*"):
                if file_path.is_file():
                    relative_path = file_path.relative_to(self._base_path)
                    files.append(str(relative_path).replace("\\", "/"))
        return sorted(files)
def delete(self, remote_path: str) -> bool:
"""Delete a file from storage.
Args:
remote_path: Path to the file to delete.
Returns:
True if file was deleted, False if it didn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
return False
file_path.unlink()
return True
def get_url(self, remote_path: str) -> str:
"""Get a file:// URL to access a file.
Args:
remote_path: Path to the file in storage.
Returns:
file:// URL to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.as_uri()
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to storage.
Args:
data: Bytes to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the data was stored.
"""
dest_path = self._get_full_path(remote_path)
if dest_path.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
dest_path.parent.mkdir(parents=True, exist_ok=True)
dest_path.write_bytes(data)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download a file as bytes.
Args:
remote_path: Path to the file in storage.
Returns:
The file contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.read_bytes()
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get a file:// URL for local file access.
For local storage, this returns a file:// URI.
Note: Local file:// URLs don't actually expire.
Args:
remote_path: Path to the file in storage.
expires_in_seconds: Ignored for local storage (URLs don't expire).
Returns:
file:// URL to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.as_uri()
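
A quick sketch of exercising the local backend in isolation; the base path and document id are made up for illustration.

from pathlib import Path

from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend

storage = LocalStorageBackend(base_path=Path("/tmp/storage-demo"))
storage.upload_bytes(b"%PDF-1.7 ...", "documents/abc123.pdf", overwrite=True)

assert storage.exists("documents/abc123.pdf")
print(storage.list_files("documents"))                     # ['documents/abc123.pdf']
print(storage.get_presigned_url("documents/abc123.pdf"))   # file:///tmp/storage-demo/documents/abc123.pdf

# Paths that resolve outside the base directory are rejected by _get_full_path.
try:
    storage.download_bytes("../etc/passwd")
except StorageError as exc:
    print(exc)                                              # Path traversal not allowed: ../etc/passwd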

View File

@@ -0,0 +1,158 @@
"""
Storage path prefixes for unified file organization.
Provides standardized path prefixes for organizing files within
the storage backend, ensuring consistent structure across
local, Azure Blob, and S3 storage.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class StoragePrefixes:
"""Standardized storage path prefixes.
All paths are relative to the storage backend root.
These prefixes ensure consistent file organization across
all storage backends (local, Azure, S3).
Usage:
from shared.storage.prefixes import PREFIXES
path = f"{PREFIXES.DOCUMENTS}/{document_id}.pdf"
storage.upload_bytes(content, path)
"""
# Document storage
DOCUMENTS: str = "documents"
"""Original document files (PDFs, etc.)"""
IMAGES: str = "images"
"""Page images extracted from documents"""
# Processing directories
UPLOADS: str = "uploads"
"""Temporary upload staging area"""
RESULTS: str = "results"
"""Inference results and visualizations"""
EXPORTS: str = "exports"
"""Exported datasets and annotations"""
# Training data
DATASETS: str = "datasets"
"""Training dataset files"""
MODELS: str = "models"
"""Trained model weights and checkpoints"""
# Data pipeline directories (legacy compatibility)
RAW_PDFS: str = "raw_pdfs"
"""Raw PDF files for auto-labeling pipeline"""
STRUCTURED_DATA: str = "structured_data"
"""CSV/structured data for matching"""
ADMIN_IMAGES: str = "admin_images"
"""Admin UI page images"""
@staticmethod
def document_path(document_id: str, extension: str = ".pdf") -> str:
"""Get path for a document file.
Args:
document_id: Unique document identifier.
extension: File extension (include leading dot).
Returns:
Storage path like "documents/abc123.pdf"
"""
ext = extension if extension.startswith(".") else f".{extension}"
return f"{PREFIXES.DOCUMENTS}/{document_id}{ext}"
@staticmethod
def image_path(document_id: str, page_num: int, extension: str = ".png") -> str:
"""Get path for a page image file.
Args:
document_id: Unique document identifier.
page_num: Page number (1-indexed).
extension: File extension (include leading dot).
Returns:
Storage path like "images/abc123/page_1.png"
"""
ext = extension if extension.startswith(".") else f".{extension}"
return f"{PREFIXES.IMAGES}/{document_id}/page_{page_num}{ext}"
@staticmethod
def upload_path(filename: str, subfolder: str | None = None) -> str:
"""Get path for a temporary upload file.
Args:
filename: Original filename.
subfolder: Optional subfolder (e.g., "async").
Returns:
Storage path like "uploads/filename.pdf" or "uploads/async/filename.pdf"
"""
if subfolder:
return f"{PREFIXES.UPLOADS}/{subfolder}/{filename}"
return f"{PREFIXES.UPLOADS}/{filename}"
@staticmethod
def result_path(filename: str) -> str:
"""Get path for a result file.
Args:
filename: Result filename.
Returns:
Storage path like "results/filename.json"
"""
return f"{PREFIXES.RESULTS}/{filename}"
@staticmethod
def export_path(export_id: str, filename: str) -> str:
"""Get path for an export file.
Args:
export_id: Unique export identifier.
filename: Export filename.
Returns:
Storage path like "exports/abc123/filename.zip"
"""
return f"{PREFIXES.EXPORTS}/{export_id}/{filename}"
@staticmethod
def dataset_path(dataset_id: str, filename: str) -> str:
"""Get path for a dataset file.
Args:
dataset_id: Unique dataset identifier.
filename: Dataset filename.
Returns:
Storage path like "datasets/abc123/filename.yaml"
"""
return f"{PREFIXES.DATASETS}/{dataset_id}/{filename}"
@staticmethod
def model_path(version: str, filename: str) -> str:
"""Get path for a model file.
Args:
version: Model version string.
filename: Model filename.
Returns:
Storage path like "models/v1.0.0/best.pt"
"""
return f"{PREFIXES.MODELS}/{version}/{filename}"
# Default instance for convenient access
PREFIXES = StoragePrefixes()
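
Example of the path helpers in use; the document id and model version are placeholders.

from shared.storage.prefixes import PREFIXES, StoragePrefixes

doc_id = "abc123"
print(StoragePrefixes.document_path(doc_id))             # documents/abc123.pdf
print(StoragePrefixes.image_path(doc_id, 1))             # images/abc123/page_1.png
print(StoragePrefixes.image_path(doc_id, 2, "jpg"))      # images/abc123/page_2.jpg (missing dot is added)
print(StoragePrefixes.model_path("v1.0.0", "best.pt"))   # models/v1.0.0/best.pt
print(f"{PREFIXES.UPLOADS}/async/scan.pdf")              # uploads/async/scan.pdf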

View File

@@ -0,0 +1,309 @@
"""
AWS S3 Storage backend.
Provides storage operations using AWS S3.
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class S3StorageBackend(StorageBackend):
"""Storage backend using AWS S3.
Files are stored as objects in an S3 bucket.
"""
def __init__(
self,
bucket_name: str,
region_name: str | None = None,
access_key_id: str | None = None,
secret_access_key: str | None = None,
endpoint_url: str | None = None,
create_bucket: bool = False,
) -> None:
"""Initialize S3 storage backend.
Args:
bucket_name: Name of the S3 bucket.
region_name: AWS region name (optional, uses default if not set).
access_key_id: AWS access key ID (optional, uses credentials chain).
secret_access_key: AWS secret access key (optional).
endpoint_url: Custom endpoint URL (for S3-compatible services).
create_bucket: If True, create the bucket if it doesn't exist.
"""
import boto3
self._bucket_name = bucket_name
self._region_name = region_name
# Build client kwargs
client_kwargs: dict[str, Any] = {}
if region_name:
client_kwargs["region_name"] = region_name
if endpoint_url:
client_kwargs["endpoint_url"] = endpoint_url
if access_key_id and secret_access_key:
client_kwargs["aws_access_key_id"] = access_key_id
client_kwargs["aws_secret_access_key"] = secret_access_key
self._s3: "S3Client" = boto3.client("s3", **client_kwargs)
if create_bucket:
self._ensure_bucket_exists()
def _ensure_bucket_exists(self) -> None:
"""Create the bucket if it doesn't exist."""
from botocore.exceptions import ClientError
try:
self._s3.head_bucket(Bucket=self._bucket_name)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchBucket"):
# Bucket doesn't exist, create it
create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
if self._region_name and self._region_name != "us-east-1":
create_kwargs["CreateBucketConfiguration"] = {
"LocationConstraint": self._region_name
}
self._s3.create_bucket(**create_kwargs)
else:
# Re-raise permission errors, network issues, etc.
raise
def _object_exists(self, key: str) -> bool:
"""Check if an object exists in S3.
Args:
key: Object key to check.
Returns:
True if object exists, False otherwise.
"""
from botocore.exceptions import ClientError
try:
self._s3.head_object(Bucket=self._bucket_name, Key=key)
return True
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchKey"):
return False
raise
@property
def bucket_name(self) -> str:
"""Get the bucket name for this storage backend."""
return self._bucket_name
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to S3.
Args:
local_path: Path to the local file to upload.
remote_path: Destination object key.
overwrite: If True, overwrite existing object.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If object exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
if not overwrite and self._object_exists(remote_path):
raise StorageError(f"File already exists: {remote_path}")
self._s3.upload_file(str(local_path), self._bucket_name, remote_path)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download an object from S3.
Args:
remote_path: Object key in S3.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
self._s3.download_file(self._bucket_name, remote_path, str(local_path))
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if an object exists in S3.
Args:
remote_path: Object key to check.
Returns:
True if the object exists, False otherwise.
"""
return self._object_exists(remote_path)
def list_files(self, prefix: str) -> list[str]:
"""List objects in S3 with given prefix.
Handles pagination to return all matching objects (S3 returns max 1000 per request).
Args:
prefix: Object key prefix to filter.
Returns:
List of object keys matching the prefix.
"""
kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
if prefix:
kwargs["Prefix"] = prefix
all_keys: list[str] = []
while True:
response = self._s3.list_objects_v2(**kwargs)
contents = response.get("Contents", [])
all_keys.extend(obj["Key"] for obj in contents)
if not response.get("IsTruncated"):
break
kwargs["ContinuationToken"] = response["NextContinuationToken"]
return all_keys
def delete(self, remote_path: str) -> bool:
"""Delete an object from S3.
Args:
remote_path: Object key to delete.
Returns:
True if object was deleted, False if it didn't exist.
"""
if not self._object_exists(remote_path):
return False
self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path)
return True
def get_url(self, remote_path: str) -> str:
"""Get a URL for an object.
Args:
remote_path: Object key in S3.
Returns:
URL to access the object.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
return self._s3.generate_presigned_url(
"get_object",
Params={"Bucket": self._bucket_name, "Key": remote_path},
ExpiresIn=3600,
)
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a pre-signed URL for temporary object access.
Args:
remote_path: Object key in S3.
expires_in_seconds: URL validity duration (1 to 604800 seconds / 7 days).
Returns:
Pre-signed URL string.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
ValueError: If expires_in_seconds is out of valid range.
"""
if expires_in_seconds < 1 or expires_in_seconds > 604800:
raise ValueError(
"expires_in_seconds must be between 1 and 604800 (7 days)"
)
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
return self._s3.generate_presigned_url(
"get_object",
Params={"Bucket": self._bucket_name, "Key": remote_path},
ExpiresIn=expires_in_seconds,
)
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to S3.
Args:
data: Bytes to upload.
remote_path: Destination object key.
overwrite: If True, overwrite existing object.
Returns:
The remote path where the data was stored.
Raises:
StorageError: If object exists and overwrite is False.
"""
if not overwrite and self._object_exists(remote_path):
raise StorageError(f"File already exists: {remote_path}")
self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download an object as bytes.
Args:
remote_path: Object key in S3.
Returns:
The object contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
from botocore.exceptions import ClientError
try:
response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path)
return response["Body"].read()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchKey"):
raise FileNotFoundStorageError(remote_path) from e
raise
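
A sketch of pointing the S3 backend at an S3-compatible service. The endpoint URL, bucket name, and credentials below are placeholders (e.g. a local MinIO instance), not values from this repository.

from shared.storage.s3 import S3StorageBackend

storage = S3StorageBackend(
    bucket_name="invoice-docs",
    region_name="eu-north-1",
    endpoint_url="http://localhost:9000",   # S3-compatible endpoint, e.g. MinIO
    access_key_id="minioadmin",
    secret_access_key="minioadmin",
    create_bucket=True,                     # create the bucket on first use
)

storage.upload_bytes(b"hello", "documents/demo.txt", overwrite=True)
print(storage.get_presigned_url("documents/demo.txt", expires_in_seconds=600))
print(storage.list_files("documents"))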

View File

@@ -20,7 +20,7 @@ from shared.config import get_db_connection_string
from shared.normalize import normalize_field
from shared.matcher import FieldMatcher
from shared.pdf import is_text_pdf, extract_text_tokens
from training.yolo.annotation_generator import FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from shared.data.db import DocumentDB

View File

@@ -113,7 +113,7 @@ def process_single_document(args_tuple):
# Import inside worker to avoid pickling issues
from training.data.autolabel_report import AutoLabelReport
from shared.pdf import PDFDocument
from training.yolo.annotation_generator import FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from training.processing.document_processor import process_page, record_unmatched_fields
start_time = time.time()
@@ -342,7 +342,8 @@ def main():
from shared.ocr import OCREngine
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from training.yolo.annotation_generator import AnnotationGenerator
# Handle comma-separated CSV paths
csv_input = args.csv

View File

@@ -90,7 +90,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
import shutil
from training.data.autolabel_report import AutoLabelReport
from shared.pdf import PDFDocument
from training.yolo.annotation_generator import FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from training.processing.document_processor import process_page, record_unmatched_fields
row_dict = task_data["row_dict"]
@@ -208,7 +208,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
import shutil
from training.data.autolabel_report import AutoLabelReport
from shared.pdf import PDFDocument
from training.yolo.annotation_generator import FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from training.processing.document_processor import process_page, record_unmatched_fields
row_dict = task_data["row_dict"]

View File

@@ -15,7 +15,8 @@ from training.data.autolabel_report import FieldMatchResult
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
from shared.ocr.machine_code_parser import MachineCodeParser
from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from training.yolo.annotation_generator import AnnotationGenerator
def match_supplier_accounts(

View File

@@ -9,43 +9,12 @@ from pathlib import Path
from typing import Any
import csv
# Field class mapping for YOLO
# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
FIELD_CLASSES = {
'InvoiceNumber': 0,
'InvoiceDate': 1,
'InvoiceDueDate': 2,
'OCR': 3,
'Bankgiro': 4,
'Plusgiro': 5,
'Amount': 6,
'supplier_organisation_number': 7,
'customer_number': 8,
'payment_line': 9, # Machine code payment line at bottom of invoice
}
# Fields that need matching but map to other YOLO classes
# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
ACCOUNT_FIELD_MAPPING = {
'supplier_accounts': {
'BG': 'Bankgiro', # BG:xxx -> Bankgiro class
'PG': 'Plusgiro', # PG:xxx -> Plusgiro class
}
}
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number',
'customer_number',
'payment_line', # Machine code payment line at bottom of invoice
]
# Import field mappings from single source of truth
from shared.fields import (
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
CLASS_NAMES,
ACCOUNT_FIELD_MAPPING,
)
@dataclass
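
The refactor above assumes a shared/fields.py module that now owns these mappings as the single source of truth. A minimal sketch of what it would re-export, with values taken directly from the definitions removed in this hunk:

"""shared/fields.py -- single source of truth for field/class mappings (sketch)."""

TRAINING_FIELD_CLASSES: dict[str, int] = {
    "InvoiceNumber": 0,
    "InvoiceDate": 1,
    "InvoiceDueDate": 2,
    "OCR": 3,
    "Bankgiro": 4,
    "Plusgiro": 5,
    "Amount": 6,
    "supplier_organisation_number": 7,
    "customer_number": 8,
    "payment_line": 9,  # machine code payment line at bottom of invoice
}

# supplier_accounts matches are mapped onto the Bankgiro/Plusgiro classes
ACCOUNT_FIELD_MAPPING: dict[str, dict[str, str]] = {
    "supplier_accounts": {"BG": "Bankgiro", "PG": "Plusgiro"},
}

CLASS_NAMES: list[str] = [
    "invoice_number",
    "invoice_date",
    "invoice_due_date",
    "ocr_number",
    "bankgiro",
    "plusgiro",
    "amount",
    "supplier_org_number",
    "customer_number",
    "payment_line",
]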

View File

@@ -101,7 +101,8 @@ class DatasetBuilder:
Returns:
DatasetStats with build results
"""
from .annotation_generator import AnnotationGenerator, CLASS_NAMES
from shared.fields import CLASS_NAMES
from .annotation_generator import AnnotationGenerator
random.seed(seed)

View File

@@ -18,7 +18,8 @@ import numpy as np
from PIL import Image
from shared.config import DEFAULT_DPI
from .annotation_generator import FIELD_CLASSES, YOLOAnnotation
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from .annotation_generator import YOLOAnnotation
logger = logging.getLogger(__name__)