WIP
@@ -120,7 +120,7 @@ def main() -> None:
logger.info("=" * 60)

# Create config
from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
from inference.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig

config = AppConfig(
model=ModelConfig(
@@ -136,7 +136,7 @@ def main() -> None:
reload=args.reload,
workers=args.workers,
),
storage=StorageConfig(),
file=FileConfig(),
)

# Create and run app
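For orientation, a minimal sketch of how the new config wiring is meant to be used after this change. The class names come from the diff; the specific arguments shown here are illustrative assumptions, not the project's required values.

from inference.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig

# FileConfig replaces the old StorageConfig; AppConfig keeps a `storage`
# property as a backward-compatibility alias (see the config hunk later in this commit).
config = AppConfig(
    model=ModelConfig(confidence_threshold=0.5),  # illustrative value
    server=ServerConfig(),                        # defaults assumed
    file=FileConfig(),
)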

@@ -112,6 +112,7 @@ class AdminDB:
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
category: str = "invoice",
admin_token: str | None = None,  # Deprecated, kept for compatibility
) -> str:
"""Create a new document record."""
@@ -125,6 +126,7 @@ class AdminDB:
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
category=category,
)
session.add(document)
session.flush()
@@ -154,6 +156,7 @@ class AdminDB:
has_annotations: bool | None = None,
auto_label_status: str | None = None,
batch_id: str | None = None,
category: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AdminDocument], int]:
@@ -171,6 +174,8 @@ class AdminDB:
where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
if batch_id:
where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
if category:
where_clauses.append(AdminDocument.category == category)

# Count query
count_stmt = select(func.count()).select_from(AdminDocument)
@@ -283,6 +288,32 @@ class AdminDB:
return True
return False

def get_document_categories(self) -> list[str]:
"""Get list of unique document categories."""
with get_session_context() as session:
statement = (
select(AdminDocument.category)
.distinct()
.order_by(AdminDocument.category)
)
categories = session.exec(statement).all()
return [c for c in categories if c is not None]

def update_document_category(
self, document_id: str, category: str
) -> AdminDocument | None:
"""Update document category."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.category = category
document.updated_at = datetime.utcnow()
session.add(document)
session.commit()
session.refresh(document)
return document
return None

# ==========================================================================
# Annotation Operations
# ==========================================================================
@@ -1292,6 +1323,36 @@ class AdminDB:
session.add(dataset)
session.commit()

def update_dataset_training_status(
self,
dataset_id: str | UUID,
training_status: str | None,
active_training_task_id: str | UUID | None = None,
update_main_status: bool = False,
) -> None:
"""Update dataset training status and optionally the main status.

Args:
dataset_id: Dataset UUID
training_status: Training status (pending, running, completed, failed, cancelled)
active_training_task_id: Currently active training task ID
update_main_status: If True and training_status is 'completed', set main status to 'trained'
"""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.training_status = training_status
dataset.active_training_task_id = (
UUID(str(active_training_task_id)) if active_training_task_id else None
)
dataset.updated_at = datetime.utcnow()
# Update main status to 'trained' when training completes
if update_main_status and training_status == "completed":
dataset.status = "trained"
session.add(dataset)
session.commit()
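A usage sketch of the new method, matching how the training scheduler later in this commit calls it; the `db` construction is assumed, the keyword arguments come from the signature above.

db = AdminDB()  # assumed construction; see inference.data.admin_db

# Mark a dataset as actively training...
db.update_dataset_training_status(
    dataset_id,
    training_status="running",
    active_training_task_id=task_id,
)

# ...and later flip it to completed, promoting the main status to 'trained'.
db.update_dataset_training_status(
    dataset_id,
    training_status="completed",
    active_training_task_id=None,
    update_main_status=True,
)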

def add_dataset_documents(
self,
dataset_id: str | UUID,

@@ -11,23 +11,8 @@ from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel, Column, JSON

# =============================================================================
# CSV to Field Class Mapping
# =============================================================================

CSV_TO_CLASS_MAPPING: dict[str, int] = {
"InvoiceNumber": 0,  # invoice_number
"InvoiceDate": 1,  # invoice_date
"InvoiceDueDate": 2,  # invoice_due_date
"OCR": 3,  # ocr_number
"Bankgiro": 4,  # bankgiro
"Plusgiro": 5,  # plusgiro
"Amount": 6,  # amount
"supplier_organisation_number": 7,  # supplier_organisation_number
# 8: payment_line (derived from OCR/Bankgiro/Amount)
"customer_number": 9,  # customer_number
}
# Import field mappings from single source of truth
from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS

# =============================================================================
@@ -72,6 +57,8 @@ class AdminDocument(SQLModel, table=True):
# Link to batch upload (if uploaded via ZIP)
group_key: str | None = Field(default=None, max_length=255, index=True)
# User-defined grouping key for document organization
category: str = Field(default="invoice", max_length=100, index=True)
# Document category for training different models (e.g., invoice, letter, receipt)
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
@@ -237,7 +224,10 @@ class TrainingDataset(SQLModel, table=True):
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="building", max_length=20, index=True)
# Status: building, ready, training, archived, failed
# Status: building, ready, trained, archived, failed
training_status: str | None = Field(default=None, max_length=20, index=True)
# Training status: pending, scheduled, running, completed, failed, cancelled
active_training_task_id: UUID | None = Field(default=None, index=True)
train_ratio: float = Field(default=0.8)
val_ratio: float = Field(default=0.1)
seed: int = Field(default=42)
@@ -354,21 +344,8 @@ class AnnotationHistory(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)

# Field class mapping (same as src/cli/train.py)
FIELD_CLASSES = {
0: "invoice_number",
1: "invoice_date",
2: "invoice_due_date",
3: "ocr_number",
4: "bankgiro",
5: "plusgiro",
6: "amount",
7: "supplier_organisation_number",
8: "payment_line",
9: "customer_number",
}

FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields
# This ensures consistency with the trained YOLO model
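The removed mappings move into shared.fields. As an illustration of what that single source of truth plausibly contains, reassembled only from the dictionaries visible in this diff (the real module may differ and also exports CLASS_NAMES / CLASS_TO_FIELD, seen in the detector hunk below):

# shared/fields.py (sketch, reconstructed from the mappings removed above)
FIELD_CLASSES = {
    0: "invoice_number", 1: "invoice_date", 2: "invoice_due_date",
    3: "ocr_number", 4: "bankgiro", 5: "plusgiro", 6: "amount",
    7: "supplier_organisation_number", 8: "payment_line", 9: "customer_number",
}
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}

# CSV column -> class id (class 8, payment_line, is derived, not a CSV column)
CSV_TO_CLASS_MAPPING = {
    "InvoiceNumber": 0, "InvoiceDate": 1, "InvoiceDueDate": 2, "OCR": 3,
    "Bankgiro": 4, "Plusgiro": 5, "Amount": 6,
    "supplier_organisation_number": 7, "customer_number": 9,
}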

# Read-only models for API responses
@@ -383,6 +360,7 @@ class AdminDocumentRead(SQLModel):
status: str
auto_label_status: str | None
auto_label_error: str | None
category: str = "invoice"
created_at: datetime
updated_at: datetime

@@ -141,6 +141,40 @@ def run_migrations() -> None:
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
""",
),
# Migration 009: Add category to admin_documents
(
"admin_documents_category",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
""",
),
# Migration 010: Add training_status and active_training_task_id to training_datasets
(
"training_datasets_training_status",
"""
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status);
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id);
""",
),
# Migration 010b: Update existing datasets with completed training to 'trained' status
(
"training_datasets_update_trained_status",
"""
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
);
""",
),
]

with engine.connect() as conn:
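For context, a minimal sketch of how a (name, sql) migration list like the one above is typically applied. The schema_migrations tracking table and function name here are assumptions for illustration, not this project's actual run_migrations implementation.

from sqlalchemy import create_engine, text

def run_migrations_sketch(database_url: str, migrations: list[tuple[str, str]]) -> None:
    """Apply each named migration once; assumes a schema_migrations(name) table exists."""
    engine = create_engine(database_url)
    with engine.begin() as conn:
        applied = {row[0] for row in conn.execute(text("SELECT name FROM schema_migrations"))}
        for name, sql in migrations:
            if name in applied:
                continue  # already recorded; the SQL above is also written to be idempotent
            conn.execute(text(sql))
            conn.execute(text("INSERT INTO schema_migrations (name) VALUES (:n)"), {"n": name})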

@@ -21,7 +21,8 @@ import re
import numpy as np
from PIL import Image

from .yolo_detector import Detection, CLASS_TO_FIELD
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import Detection

# Import shared utilities for text cleaning and validation
from shared.utils.text_cleaner import TextCleaner

@@ -10,7 +10,8 @@ from typing import Any
import time
import re

from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser

@@ -9,6 +9,9 @@ from pathlib import Path
from typing import Any
import numpy as np

# Import field mappings from single source of truth
from shared.fields import CLASS_NAMES, CLASS_TO_FIELD

@dataclass
class Detection:
@@ -72,33 +75,8 @@ class Detection:
return (x0, y0, x1, y1)

# Class names (must match training configuration)
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number',  # Matches training class name
'customer_number',
'payment_line',  # Machine code payment line at bottom of invoice
]

# Mapping from class name to field name
CLASS_TO_FIELD = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'ocr_number': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'amount': 'Amount',
'supplier_org_number': 'supplier_org_number',
'customer_number': 'customer_number',
'payment_line': 'payment_line',
}
# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields
# This ensures consistency with the trained YOLO model

class YOLODetector:

@@ -4,18 +4,19 @@ Admin Annotation API Routes
FastAPI endpoints for annotation management.
"""

import io
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, StreamingResponse

from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.autolabel import get_auto_label_service
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
AnnotationCreate,
AnnotationItem,
@@ -35,9 +36,6 @@ from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

# Image storage directory
ADMIN_IMAGES_DIR = Path("data/admin_images")

def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
@@ -60,7 +58,9 @@ def create_annotation_router() -> APIRouter:

@router.get(
"/{document_id}/images/{page_number}",
response_model=None,
responses={
200: {"content": {"image/png": {}}, "description": "Page image"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
@@ -72,7 +72,7 @@ def create_annotation_router() -> APIRouter:
page_number: int,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> FileResponse:
) -> FileResponse | StreamingResponse:
"""Get page image."""
_validate_uuid(document_id, "document_id")

@@ -91,18 +91,33 @@ def create_annotation_router() -> APIRouter:
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
)

# Find image file
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
if not image_path.exists():
# Get storage helper
storage = get_storage_helper()

# Check if image exists
if not storage.admin_image_exists(document_id, page_number):
raise HTTPException(
status_code=404,
detail=f"Image for page {page_number} not found",
)

return FileResponse(
path=str(image_path),
# Try to get local path for efficient file serving
local_path = storage.get_admin_image_local_path(document_id, page_number)
if local_path is not None:
return FileResponse(
path=str(local_path),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
)

# Fall back to streaming for cloud storage
image_content = storage.get_admin_image(document_id, page_number)
return StreamingResponse(
io.BytesIO(image_content),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
headers={
"Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"'
},
)
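A hedged client-side sketch for the endpoint above. Only the "/{document_id}/images/{page_number}" path segment comes from the diff; the host, router prefix, and auth header name are assumptions.

import requests  # assumed available in the caller's environment

document_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID
page_number = 1
BASE = "http://localhost:8000/api/admin/documents"   # hypothetical prefix
resp = requests.get(
    f"{BASE}/{document_id}/images/{page_number}",
    headers={"X-Admin-Token": "secret"},  # hypothetical header name
    timeout=30,
)
resp.raise_for_status()
# Works the same whether the server answered with FileResponse or StreamingResponse.
with open(f"page_{page_number}.png", "wb") as fh:
    fh.write(resp.content)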

# =========================================================================
@@ -210,16 +225,14 @@ def create_annotation_router() -> APIRouter:
)

# Get image dimensions for normalization
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
if not image_path.exists():
storage = get_storage_helper()
dimensions = storage.get_admin_image_dimensions(document_id, request.page_number)
if dimensions is None:
raise HTTPException(
status_code=400,
detail=f"Image for page {request.page_number} not available",
)

from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
image_width, image_height = dimensions

# Calculate normalized coordinates
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
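A worked example of the normalization the handler performs, applying the same idea to all four YOLO values; the 1000 x 1400 page size and the pixel bbox are made-up numbers.

# Pixel bbox (x, y, width, height) on a 1000 x 1400 px page image.
image_width, image_height = 1000, 1400
x, y, w, h = 200, 350, 100, 70

x_center = (x + w / 2) / image_width    # (200 + 50) / 1000 = 0.25
y_center = (y + h / 2) / image_height   # (350 + 35) / 1400 = 0.275
width = w / image_width                 # 100 / 1000        = 0.1
height = h / image_height               # 70 / 1400         = 0.05
# YOLO label line: "<class_id> 0.250000 0.275000 0.100000 0.050000"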

@@ -315,10 +328,14 @@ def create_annotation_router() -> APIRouter:

if request.bbox is not None:
# Get image dimensions
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
storage = get_storage_helper()
dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number)
if dimensions is None:
raise HTTPException(
status_code=400,
detail=f"Image for page {annotation.page_number} not available",
)
image_width, image_height = dimensions

# Calculate normalized coordinates
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width

@@ -13,16 +13,19 @@ from fastapi import APIRouter, File, HTTPException, Query, UploadFile

from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
AnnotationItem,
AnnotationSource,
AutoLabelStatus,
BoundingBox,
DocumentCategoriesResponse,
DocumentDetailResponse,
DocumentItem,
DocumentListResponse,
DocumentStatus,
DocumentStatsResponse,
DocumentUpdateRequest,
DocumentUploadResponse,
ModelMetrics,
TrainingHistoryItem,
@@ -44,14 +47,12 @@ def _validate_uuid(value: str, name: str = "ID") -> None:

def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
"""Convert PDF pages to images for annotation using StorageHelper."""
import fitz

doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)

storage = get_storage_helper()
pdf_doc = fitz.open(stream=content, filetype="pdf")

for page_num in range(page_count):
@@ -60,8 +61,9 @@ def _convert_pdf_to_images(
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)

image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
# Save to storage using StorageHelper
image_bytes = pix.tobytes("png")
storage.save_admin_image(document_id, page_num + 1, image_bytes)

pdf_doc.close()
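For intuition, a hypothetical sketch of the StorageHelper shape implied by these call sites. The real helper in inference.web.services.storage_helpers may look quite different; key layout and the download method are assumptions, while upload_bytes appears in the config docstring later in this commit.

class StorageHelperSketch:
    """Hypothetical illustration of save_admin_image / get_admin_image."""

    def __init__(self, backend) -> None:
        self._storage = backend  # a shared.storage StorageBackend (assumed interface)

    def _admin_image_key(self, document_id: str, page_number: int) -> str:
        return f"admin_images/{document_id}/page_{page_number}.png"  # assumed layout

    def save_admin_image(self, document_id: str, page_number: int, data: bytes) -> str:
        key = self._admin_image_key(document_id, page_number)
        self._storage.upload_bytes(data, key)  # upload_bytes is shown in the FileConfig docstring
        return key

    def get_admin_image(self, document_id: str, page_number: int) -> bytes:
        # download_bytes is an assumed counterpart to upload_bytes
        return self._storage.download_bytes(self._admin_image_key(document_id, page_number))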

@@ -95,6 +97,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
str | None,
Query(description="Optional group key for document organization", max_length=255),
] = None,
category: Annotated[
str,
Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100),
] = "invoice",
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate group_key length
@@ -143,31 +149,33 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_path="",  # Will update after saving
page_count=page_count,
group_key=group_key,
category=category,
)

# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
# Save file to storage using StorageHelper
storage = get_storage_helper()
filename = f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
storage_path = storage.save_raw_pdf(content, filename)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(status_code=500, detail="Failed to save file")

# Update file path in database
# Update file path in database (using storage path for reference)
from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if doc:
doc.file_path = str(file_path)
# Store the storage path (relative path within storage)
doc.file_path = storage_path
session.add(doc)

# Convert PDF to images for annotation
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
document_id, content, page_count, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
@@ -189,6 +197,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
category=category,
group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
@@ -226,6 +235,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
str | None,
Query(description="Filter by batch ID"),
] = None,
category: Annotated[
str | None,
Query(description="Filter by document category"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
@@ -264,6 +277,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
has_annotations=has_annotations,
auto_label_status=auto_label_status,
batch_id=batch_id,
category=category,
limit=limit,
offset=offset,
)
@@ -291,6 +305,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
category=doc.category if hasattr(doc, 'category') else "invoice",
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
@@ -436,6 +451,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
group_key=document.group_key if hasattr(document, 'group_key') else None,
category=document.category if hasattr(document, 'category') else "invoice",
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
@@ -471,16 +487,22 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
detail="Document not found or does not belong to this token",
)

# Delete file
file_path = Path(document.file_path)
if file_path.exists():
file_path.unlink()
# Delete file using StorageHelper
storage = get_storage_helper()

# Delete images
images_dir = ADMIN_IMAGES_DIR / document_id
if images_dir.exists():
import shutil
shutil.rmtree(images_dir)
# Delete the raw PDF
filename = Path(document.file_path).name
if filename:
try:
storage._storage.delete(document.file_path)
except Exception as e:
logger.warning(f"Failed to delete PDF file: {e}")

# Delete admin images
try:
storage.delete_admin_images(document_id)
except Exception as e:
logger.warning(f"Failed to delete admin images: {e}")

# Delete from database
db.delete_document(document_id)
@@ -609,4 +631,61 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
"message": "Document group key updated",
}

@router.get(
"/categories",
response_model=DocumentCategoriesResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get available categories",
description="Get list of all available document categories.",
)
async def get_categories(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentCategoriesResponse:
"""Get all available document categories."""
categories = db.get_document_categories()
return DocumentCategoriesResponse(
categories=categories,
total=len(categories),
)

@router.patch(
"/{document_id}/category",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document category",
description="Update the category for a document.",
)
async def update_document_category(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: DocumentUpdateRequest,
) -> dict:
"""Update document category."""
_validate_uuid(document_id, "document_id")

# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)

# Update category if provided
if request.category is not None:
db.update_document_category(document_id, request.category)

return {
"status": "updated",
"document_id": document_id,
"category": request.category,
"message": "Document category updated",
}

return router

@@ -17,6 +17,7 @@ from inference.web.schemas.admin import (
TrainingStatus,
TrainingTaskResponse,
)
from inference.web.services.storage_helpers import get_storage_helper

from ._utils import _validate_uuid

@@ -38,7 +39,6 @@ def register_dataset_routes(router: APIRouter) -> None:
db: AdminDBDep,
) -> DatasetResponse:
"""Create a training dataset from document IDs."""
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder

# Validate minimum document count for proper train/val/test split
@@ -56,7 +56,18 @@ def register_dataset_routes(router: APIRouter) -> None:
seed=request.seed,
)

builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
# Get storage paths from StorageHelper
storage = get_storage_helper()
datasets_dir = storage.get_datasets_base_path()
admin_images_dir = storage.get_admin_images_base_path()

if datasets_dir is None or admin_images_dir is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)

builder = DatasetBuilder(db=db, base_dir=datasets_dir)
try:
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
@@ -64,7 +75,7 @@ def register_dataset_routes(router: APIRouter) -> None:
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
admin_images_dir=Path("data/admin_images"),
admin_images_dir=admin_images_dir,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -142,6 +153,12 @@ def register_dataset_routes(router: APIRouter) -> None:
name=dataset.name,
description=dataset.description,
status=dataset.status,
training_status=dataset.training_status,
active_training_task_id=(
str(dataset.active_training_task_id)
if dataset.active_training_task_id
else None
),
train_ratio=dataset.train_ratio,
val_ratio=dataset.val_ratio,
seed=dataset.seed,

@@ -34,8 +34,10 @@ def register_export_routes(router: APIRouter) -> None:
db: AdminDBDep,
) -> ExportResponse:
"""Export annotations for training."""
from pathlib import Path
import shutil
from inference.web.services.storage_helpers import get_storage_helper

# Get storage helper for reading images and exports directory
storage = get_storage_helper()

if request.format not in ("yolo", "coco", "voc"):
raise HTTPException(
@@ -51,7 +53,14 @@ def register_export_routes(router: APIRouter) -> None:
detail="No labeled documents available for export",
)

export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
# Get exports directory from StorageHelper
exports_base = storage.get_exports_base_path()
if exports_base is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)
export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
export_dir.mkdir(parents=True, exist_ok=True)

(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
@@ -80,13 +89,16 @@ def register_export_routes(router: APIRouter) -> None:
if not page_annotations and not request.include_images:
continue

src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
# Get image from storage
doc_id = str(doc.document_id)
if not storage.admin_image_exists(doc_id, page_num):
continue

# Download image and save to export directory
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
image_content = storage.get_admin_image(doc_id, page_num)
dst_image.write_bytes(image_content)
total_images += 1

label_name = f"{doc.document_id}_page{page_num}.txt"
@@ -98,7 +110,7 @@ def register_export_routes(router: APIRouter) -> None:
f.write(line)
total_annotations += 1

from inference.data.admin_models import FIELD_CLASSES
from shared.fields import FIELD_CLASSES

yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}

@@ -22,6 +22,7 @@ from inference.web.schemas.inference import (
InferenceResult,
)
from inference.web.schemas.common import ErrorResponse
from inference.web.services.storage_helpers import get_storage_helper

if TYPE_CHECKING:
from inference.web.services import InferenceService
@@ -90,8 +91,17 @@ def create_inference_router(
# Generate document ID
doc_id = str(uuid.uuid4())[:8]

# Save uploaded file
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
# Get storage helper and uploads directory
storage = get_storage_helper()
uploads_dir = storage.get_uploads_base_path(subfolder="inference")
if uploads_dir is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Storage not configured for local access",
)

# Save uploaded file to temporary location for processing
upload_path = uploads_dir / f"{doc_id}{file_ext}"
try:
with open(upload_path, "wb") as f:
shutil.copyfileobj(file.file, f)
@@ -149,12 +159,13 @@ def create_inference_router(
# Cleanup uploaded file
upload_path.unlink(missing_ok=True)

@router.get("/results/{filename}")
@router.get("/results/{filename}", response_model=None)
async def get_result_image(filename: str) -> FileResponse:
"""Get visualization result image."""
file_path = storage_config.result_dir / filename
storage = get_storage_helper()
file_path = storage.get_result_local_path(filename)

if not file_path.exists():
if file_path is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
@@ -169,15 +180,15 @@ def create_inference_router(
@router.delete("/results/{filename}")
async def delete_result(filename: str) -> dict:
"""Delete a result file."""
file_path = storage_config.result_dir / filename
storage = get_storage_helper()

if not file_path.exists():
if not storage.result_exists(filename):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)

file_path.unlink()
storage.delete_result(filename)
return {"status": "deleted", "filename": filename}

return router

@@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
from inference.data.admin_db import AdminDB
from inference.web.schemas.labeling import PreLabelResponse
from inference.web.schemas.common import ErrorResponse
from inference.web.services.storage_helpers import get_storage_helper

if TYPE_CHECKING:
from inference.web.services import InferenceService
@@ -23,19 +24,14 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

# Storage directory for pre-label uploads (legacy, now uses storage_config)
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")

def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
"""Convert PDF pages to images for annotation using StorageHelper."""
import fitz

doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)

storage = get_storage_helper()
pdf_doc = fitz.open(stream=content, filetype="pdf")

for page_num in range(page_count):
@@ -43,8 +39,9 @@ def _convert_pdf_to_images(
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)

image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
# Save to storage using StorageHelper
image_bytes = pix.tobytes("png")
storage.save_admin_image(document_id, page_num + 1, image_bytes)

pdf_doc.close()

@@ -70,9 +67,6 @@ def create_labeling_router(
"""
router = APIRouter(prefix="/api/v1", tags=["labeling"])

# Ensure upload directory exists
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

@router.post(
"/pre-label",
response_model=PreLabelResponse,
@@ -165,10 +159,11 @@ def create_labeling_router(
csv_field_values=expected_values,
)

# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
# Save file to storage using StorageHelper
storage = get_storage_helper()
filename = f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
storage_path = storage.save_raw_pdf(content, filename)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(
@@ -176,15 +171,14 @@ def create_labeling_router(
detail="Failed to save file",
)

# Update file path in database
db.update_document_file_path(document_id, str(file_path))
# Update file path in database (using storage path)
db.update_document_file_path(document_id, storage_path)

# Convert PDF to images for annotation UI
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
document_id, content, page_count, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")

@@ -18,6 +18,7 @@ from fastapi.responses import HTMLResponse

from .config import AppConfig, default_config
from inference.web.services import InferenceService
from inference.web.services.storage_helpers import get_storage_helper

# Public API imports
from inference.web.api.v1.public import (
@@ -238,13 +239,17 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
allow_headers=["*"],
)

# Mount static files for results
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
app.mount(
"/static/results",
StaticFiles(directory=str(config.storage.result_dir)),
name="results",
)
# Mount static files for results using StorageHelper
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir:
app.mount(
"/static/results",
StaticFiles(directory=str(results_dir)),
name="results",
)
else:
logger.warning("Could not mount static results directory: local storage not available")

# Include public API routes
inference_router = create_inference_router(inference_service, config.storage)

@@ -4,16 +4,49 @@ Web Application Configuration
Centralized configuration for the web application.
"""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

from shared.config import DEFAULT_DPI, PATHS
from shared.config import DEFAULT_DPI

if TYPE_CHECKING:
from shared.storage.base import StorageBackend

def get_storage_backend(
config_path: Path | str | None = None,
) -> "StorageBackend":
"""Get storage backend for file operations.

Args:
config_path: Optional path to storage configuration file.
If not provided, uses STORAGE_CONFIG_PATH env var or falls back to env vars.

Returns:
Configured StorageBackend instance.
"""
from shared.storage import get_storage_backend as _get_storage_backend

# Check for config file path
if config_path is None:
config_path_str = os.environ.get("STORAGE_CONFIG_PATH")
if config_path_str:
config_path = Path(config_path_str)

return _get_storage_backend(config_path=config_path)

@dataclass(frozen=True)
class ModelConfig:
"""YOLO model configuration."""
"""YOLO model configuration.

Note: Model files are stored locally (not in STORAGE_BASE_PATH) because:
- Models need to be accessible by inference service on any platform
- Models may be version-controlled or deployed separately
- Models are part of the application, not user data
"""

model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
confidence_threshold: float = 0.5
@@ -33,24 +66,39 @@ class ServerConfig:

@dataclass(frozen=True)
class StorageConfig:
"""File storage configuration.
class FileConfig:
"""File handling configuration.

Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
directly in raw_pdfs directory. This ensures consistency with CLI autolabel
and avoids storing duplicate files.
This config holds file handling settings. For file operations,
use the storage backend with PREFIXES from shared.storage.prefixes.

Example:
from shared.storage import PREFIXES, get_storage_backend

storage = get_storage_backend()
path = PREFIXES.document_path(document_id)
storage.upload_bytes(content, path)

Note: The path fields (upload_dir, result_dir, etc.) are deprecated.
They are kept for backward compatibility with existing code and tests.
New code should use the storage backend with PREFIXES instead.
"""

upload_dir: Path = Path("uploads")
result_dir: Path = Path("results")
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
admin_images_dir: Path = Path("data/admin_images")
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
dpi: int = DEFAULT_DPI
presigned_url_expiry_seconds: int = 3600

# Deprecated path fields - kept for backward compatibility
# New code should use storage backend with PREFIXES instead
# All paths are now under data/ to match WSL storage layout
upload_dir: Path = field(default_factory=lambda: Path("data/uploads"))
result_dir: Path = field(default_factory=lambda: Path("data/results"))
admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs"))
admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images"))

def __post_init__(self) -> None:
"""Create directories if they don't exist."""
"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
@@ -61,9 +109,17 @@ class StorageConfig:
self.admin_images_dir.mkdir(parents=True, exist_ok=True)

# Backward compatibility alias
StorageConfig = FileConfig
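The alias keeps old imports working. A short check of what that means in practice (plain Python semantics, nothing project-specific):

from inference.web.config import FileConfig, StorageConfig

# Both names now refer to the same class object, so existing code that
# constructs StorageConfig() keeps working and produces a FileConfig.
assert StorageConfig is FileConfig
cfg = StorageConfig()
assert isinstance(cfg, FileConfig)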

@dataclass(frozen=True)
class AsyncConfig:
"""Async processing configuration."""
"""Async processing configuration.

Note: For file paths, use the storage backend with PREFIXES.
Example: PREFIXES.upload_path(filename, "async")
"""

# Queue settings
queue_max_size: int = 100
@@ -77,14 +133,17 @@ class AsyncConfig:

# Storage
result_retention_days: int = 7
temp_upload_dir: Path = Path("uploads/async")
max_file_size_mb: int = 50

# Deprecated: kept for backward compatibility
# Path under data/ to match WSL storage layout
temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async"))

# Cleanup
cleanup_interval_hours: int = 1

def __post_init__(self) -> None:
"""Create directories if they don't exist."""
"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)

@@ -95,19 +154,41 @@ class AppConfig:

model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
storage: StorageConfig = field(default_factory=StorageConfig)
file: FileConfig = field(default_factory=FileConfig)
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
storage_backend: "StorageBackend | None" = None

@property
def storage(self) -> FileConfig:
"""Backward compatibility alias for file config."""
return self.file

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
"""Create config from dictionary."""
file_config = config_dict.get("file", config_dict.get("storage", {}))
return cls(
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
storage=StorageConfig(**config_dict.get("storage", {})),
file=FileConfig(**file_config),
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
)

def create_app_config(
storage_config_path: Path | str | None = None,
) -> AppConfig:
"""Create application configuration with storage backend.

Args:
storage_config_path: Optional path to storage configuration file.

Returns:
Configured AppConfig instance with storage backend initialized.
"""
storage_backend = get_storage_backend(config_path=storage_config_path)
return AppConfig(storage_backend=storage_backend)
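A small usage sketch tying the config factory to app startup. create_app is the factory shown earlier in this commit; its module path and the idea of pointing STORAGE_CONFIG_PATH at a YAML file are assumptions.

import os
from inference.web.config import create_app_config
from inference.web.app import create_app  # module path assumed

os.environ.setdefault("STORAGE_CONFIG_PATH", "config/storage.yaml")  # hypothetical file

config = create_app_config()  # resolves the storage backend once
app = create_app(config)      # FastAPI app; mounts /static/results only if local storage is available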

# Default configuration instance
default_config = AppConfig()

@@ -13,6 +13,7 @@ from inference.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
)
from inference.web.services.storage_helpers import get_storage_helper

logger = logging.getLogger(__name__)

@@ -36,7 +37,13 @@ class AutoLabelScheduler:
"""
self._check_interval = check_interval_seconds
self._batch_size = batch_size

# Get output directory from StorageHelper
if output_dir is None:
storage = get_storage_helper()
output_dir = storage.get_autolabel_output_path()
self._output_dir = output_dir or Path("data/autolabel_output")

self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Any

from inference.data.admin_db import AdminDB
from inference.web.services.storage_helpers import get_storage_helper

logger = logging.getLogger(__name__)

@@ -107,6 +108,14 @@ class TrainingScheduler:
self._db.update_training_task_status(task_id, "running")
self._db.add_training_log(task_id, "INFO", "Training task started")

# Update dataset training status to running
if dataset_id:
self._db.update_dataset_training_status(
dataset_id,
training_status="running",
active_training_task_id=task_id,
)

try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
@@ -192,6 +201,15 @@ class TrainingScheduler:
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")

# Update dataset training status to completed and main status to trained
if dataset_id:
self._db.update_dataset_training_status(
dataset_id,
training_status="completed",
active_training_task_id=None,
update_main_status=True,  # Set main status to 'trained'
)

# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
@@ -203,6 +221,13 @@ class TrainingScheduler:
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
# Update dataset training status to failed
if dataset_id:
self._db.update_dataset_training_status(
dataset_id,
training_status="failed",
active_training_task_id=None,
)
raise

def _create_model_version_from_training(
@@ -268,9 +293,10 @@ class TrainingScheduler:
f"Created model version {version} (ID: {model_version.version_id}) "
f"from training task {task_id}"
)
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
self._db.add_training_log(
task_id, "INFO",
f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
f"Model version {version} created (mAP: {mAP_display})",
)
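The two log lines above are the before/after of a real bug fix: in the old line, everything after the colon is treated as a format spec and handed to __format__, so it fails at runtime instead of printing "N/A". A minimal demonstration:

metrics_mAP = None

# Old form: ".3f if metrics_mAP else 'N/A'" is passed to __format__ as a spec
# and raises at runtime rather than logging "N/A".
# f"mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'}"

# Fixed form: resolve the display string first, then interpolate it.
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
print(f"mAP: {mAP_display}")  # -> "mAP: N/A"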

except Exception as e:

@@ -283,8 +309,11 @@ class TrainingScheduler:
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
import shutil
from inference.data.admin_models import FIELD_CLASSES
from shared.fields import FIELD_CLASSES
from inference.web.services.storage_helpers import get_storage_helper

# Get storage helper for reading images
storage = get_storage_helper()

# Get all labeled documents
documents = self._db.get_labeled_documents_for_export()
@@ -293,8 +322,12 @@ class TrainingScheduler:
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
return None

# Create export directory
export_dir = Path("data/training") / task_id
# Create export directory using StorageHelper
training_base = storage.get_training_data_path()
if training_base is None:
self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access")
return None
export_dir = training_base / task_id
export_dir.mkdir(parents=True, exist_ok=True)

# YOLO format directories
@@ -323,14 +356,16 @@ class TrainingScheduler:
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]

# Copy image
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
# Get image from storage
doc_id = str(doc.document_id)
if not storage.admin_image_exists(doc_id, page_num):
continue

# Download image and save to export directory
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
image_content = storage.get_admin_image(doc_id, page_num)
dst_image.write_bytes(image_content)
total_images += 1

# Write YOLO label
@@ -380,6 +415,8 @@ names: {list(FIELD_CLASSES.values())}
self._db.add_training_log(task_id, level, message)

# Create shared training config
# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
# because models need to be accessible by inference service on any platform
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
config = SharedTrainingConfig(
model_path=model_name,

@@ -13,6 +13,7 @@ class DatasetCreateRequest(BaseModel):
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
description: str | None = Field(None, description="Optional description")
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
category: str | None = Field(None, description="Filter documents by category (optional)")
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
seed: int = Field(42, description="Random seed for split")
@@ -43,6 +44,8 @@ class DatasetDetailResponse(BaseModel):
name: str
description: str | None
status: str
training_status: str | None = None
active_training_task_id: str | None = None
train_ratio: float
val_ratio: float
seed: int

@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
@@ -44,6 +45,7 @@ class DocumentItem(BaseModel):
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@@ -76,6 +78,7 @@ class DocumentDetailResponse(BaseModel):
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
@@ -104,3 +107,17 @@ class DocumentStatsResponse(BaseModel):
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
labeled: int = Field(default=0, ge=0, description="Labeled documents")
exported: int = Field(default=0, ge=0, description="Exported documents")

class DocumentUpdateRequest(BaseModel):
"""Request for updating document metadata."""

category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)")
group_key: str | None = Field(None, description="User-defined group key")

class DocumentCategoriesResponse(BaseModel):
"""Response for available document categories."""

categories: list[str] = Field(..., description="List of available categories")
total: int = Field(..., ge=0, description="Total number of categories")
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
@@ -17,6 +18,7 @@ from typing import TYPE_CHECKING
|
||||
from inference.data.async_request_db import AsyncRequestDB
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from inference.web.core.rate_limiter import RateLimiter
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.web.config import AsyncConfig, StorageConfig
|
||||
@@ -189,9 +191,7 @@ class AsyncProcessingService:
|
||||
filename: str,
|
||||
content: bytes,
|
||||
) -> Path:
|
||||
"""Save uploaded file to temp storage."""
|
||||
import re
|
||||
|
||||
"""Save uploaded file to temp storage using StorageHelper."""
|
||||
# Extract extension from filename
|
||||
ext = Path(filename).suffix.lower()
|
||||
|
||||
@@ -203,9 +203,11 @@ class AsyncProcessingService:
|
||||
if ext not in self.ALLOWED_EXTENSIONS:
|
||||
ext = ".pdf"
|
||||
|
||||
# Create async upload directory
|
||||
upload_dir = self._async_config.temp_upload_dir
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Get upload directory from StorageHelper
|
||||
storage = get_storage_helper()
|
||||
upload_dir = storage.get_uploads_base_path(subfolder="async")
|
||||
if upload_dir is None:
|
||||
raise ValueError("Storage not configured for local access")
|
||||
|
||||
# Build file path - request_id is a UUID so it's safe
|
||||
file_path = upload_dir / f"{request_id}{ext}"
|
||||
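The net effect of this hunk is that temp-upload paths now come from the storage layer instead of AsyncConfig. A standalone sketch of the resulting flow with the service class stripped away; the ALLOWED_EXTENSIONS set and the plain write_bytes call are simplifications, not the exact class internals:

```python
from pathlib import Path

from inference.web.services.storage_helpers import get_storage_helper

# Assumed extension whitelist; the real set lives on the service class.
ALLOWED_EXTENSIONS = {".pdf", ".png", ".jpg", ".jpeg"}


def save_upload_file(request_id: str, filename: str, content: bytes) -> Path:
    """Save an uploaded file under the async uploads folder via StorageHelper."""
    ext = Path(filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        ext = ".pdf"

    storage = get_storage_helper()
    upload_dir = storage.get_uploads_base_path(subfolder="async")
    if upload_dir is None:
        raise ValueError("Storage not configured for local access")

    # request_id is a UUID, so it is safe to use directly as a filename.
    file_path = upload_dir / f"{request_id}{ext}"
    file_path.write_bytes(content)
    return file_path
```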
@@ -355,8 +357,9 @@ class AsyncProcessingService:

    def _cleanup_orphan_files(self) -> int:
        """Clean up upload files that don't have matching requests."""
        upload_dir = self._async_config.temp_upload_dir
        if not upload_dir.exists():
        storage = get_storage_helper()
        upload_dir = storage.get_uploads_base_path(subfolder="async")
        if upload_dir is None or not upload_dir.exists():
            return 0

        count = 0

@@ -13,7 +13,7 @@ from PIL import Image

from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken


@@ -16,7 +16,7 @@ from uuid import UUID
from pydantic import BaseModel, Field, field_validator

from inference.data.admin_db import AdminDB
from inference.data.admin_models import CSV_TO_CLASS_MAPPING
from shared.fields import CSV_TO_CLASS_MAPPING

logger = logging.getLogger(__name__)


@@ -12,7 +12,7 @@ from pathlib import Path

import yaml

from inference.data.admin_models import FIELD_CLASSES
from shared.fields import FIELD_CLASSES

logger = logging.getLogger(__name__)


@@ -13,9 +13,10 @@ from typing import Any

from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
from shared.fields import CSV_TO_CLASS_MAPPING
from inference.data.admin_models import AdminDocument
from shared.data.db import DocumentDB
from inference.web.config import StorageConfig
from inference.web.services.storage_helpers import get_storage_helper

logger = logging.getLogger(__name__)

@@ -122,8 +123,12 @@ def process_document_autolabel(
    document_id = str(document.document_id)
    file_path = Path(document.file_path)

    # Get output directory from StorageHelper
    storage = get_storage_helper()
    if output_dir is None:
        output_dir = Path("data/autolabel_output")
        output_dir = storage.get_autolabel_output_path()
        if output_dir is None:
            output_dir = Path("data/autolabel_output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Mark as processing
@@ -152,10 +157,12 @@
    is_scanned = len(tokens) < 10  # Threshold for "no text"

    # Build task data
    # Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
    # Use raw_pdfs base path for pdf_path
    # This ensures consistency with CLI autolabel for reprocess_failed.py
    storage_config = StorageConfig()
    pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"
    raw_pdfs_dir = storage.get_raw_pdfs_base_path()
    if raw_pdfs_dir is None:
        raise ValueError("Storage not configured for local access")
    pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf"

    task_data = {
        "row_dict": row_dict,
@@ -246,8 +253,8 @@ def _save_annotations_to_db(
    Returns:
        Number of annotations saved
    """
    from PIL import Image
    from inference.data.admin_models import FIELD_CLASS_IDS
    from shared.fields import FIELD_CLASS_IDS
    from inference.web.services.storage_helpers import get_storage_helper

    # Mapping from CSV field names to internal field names
    CSV_TO_INTERNAL_FIELD: dict[str, str] = {
@@ -266,6 +273,9 @@ def _save_annotations_to_db(
    # Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
    scale = dpi / 72.0

    # Get storage helper for image dimensions
    storage = get_storage_helper()

    # Cache for image dimensions per page
    image_dimensions: dict[int, tuple[int, int]] = {}

@@ -274,18 +284,11 @@ def _save_annotations_to_db(
        if page_no in image_dimensions:
            return image_dimensions[page_no]

        # Try to load from admin_images
        admin_images_dir = Path("data/admin_images") / document_id
        image_path = admin_images_dir / f"page_{page_no}.png"

        if image_path.exists():
            try:
                with Image.open(image_path) as img:
                    dims = img.size  # (width, height)
                    image_dimensions[page_no] = dims
                    return dims
            except Exception as e:
                logger.warning(f"Failed to read image dimensions from {image_path}: {e}")
        # Get dimensions from storage helper
        dims = storage.get_admin_image_dimensions(document_id, page_no)
        if dims:
            image_dimensions[page_no] = dims
            return dims

        return None

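The replacement relies on StorageHelper to report page image sizes, which the caller caches per page before converting box coordinates. A rough sketch of that pattern; get_dims mirrors the cached lookup above, while normalize_bbox is a hypothetical helper added only for illustration:

```python
from inference.web.services.storage_helpers import get_storage_helper

storage = get_storage_helper()
image_dimensions: dict[int, tuple[int, int]] = {}


def get_dims(document_id: str, page_no: int) -> tuple[int, int] | None:
    """Return (width, height) for a page, caching the lookup per page."""
    if page_no in image_dimensions:
        return image_dimensions[page_no]
    dims = storage.get_admin_image_dimensions(document_id, page_no)
    if dims:
        image_dimensions[page_no] = dims
    return dims


def normalize_bbox(
    bbox: tuple[float, float, float, float], document_id: str, page_no: int
) -> tuple[float, float, float, float] | None:
    """Convert pixel coordinates to the 0-1 range using the page's image size."""
    dims = get_dims(document_id, page_no)
    if dims is None:
        return None
    width, height = dims
    x0, y0, x1, y1 = bbox
    return (x0 / width, y0 / height, x1 / width, y1 / height)
```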
@@ -449,10 +452,17 @@ def save_manual_annotations_to_document_db(
    from datetime import datetime

    document_id = str(document.document_id)
    storage_config = StorageConfig()

    # Build pdf_path using admin_upload_dir (same as auto-label)
    pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"
    # Build pdf_path using raw_pdfs base path (same as auto-label)
    storage = get_storage_helper()
    raw_pdfs_dir = storage.get_raw_pdfs_base_path()
    if raw_pdfs_dir is None:
        return {
            "success": False,
            "document_id": document_id,
            "error": "Storage not configured for local access",
        }
    pdf_path = raw_pdfs_dir / f"{document_id}.pdf"

    # Build report dict compatible with DocumentDB.save_document()
    field_results = []

packages/inference/inference/web/services/document_service.py (new file, 217 lines)
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Document Service for storage-backed file operations.
|
||||
|
||||
Provides a unified interface for document upload, download, and serving
|
||||
using the storage abstraction layer.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentResult:
|
||||
"""Result of document operation."""
|
||||
|
||||
id: str
|
||||
file_path: str
|
||||
filename: str | None = None
|
||||
|
||||
|
||||
class DocumentService:
|
||||
"""Service for document file operations using storage backend.
|
||||
|
||||
Provides upload, download, and URL generation for documents and images.
|
||||
"""
|
||||
|
||||
# Storage path prefixes
|
||||
DOCUMENTS_PREFIX = "documents"
|
||||
IMAGES_PREFIX = "images"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_backend: "StorageBackend",
|
||||
admin_db: Any | None = None,
|
||||
) -> None:
|
||||
"""Initialize document service.
|
||||
|
||||
Args:
|
||||
storage_backend: Storage backend for file operations.
|
||||
admin_db: Optional AdminDB instance for database operations.
|
||||
"""
|
||||
self._storage = storage_backend
|
||||
self._admin_db = admin_db
|
||||
|
||||
def upload_document(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
dataset_id: str | None = None,
|
||||
document_id: str | None = None,
|
||||
) -> DocumentResult:
|
||||
"""Upload a document to storage.
|
||||
|
||||
Args:
|
||||
content: Document content as bytes.
|
||||
filename: Original filename.
|
||||
dataset_id: Optional dataset ID for organization.
|
||||
document_id: Optional document ID (generated if not provided).
|
||||
|
||||
Returns:
|
||||
DocumentResult with ID and storage path.
|
||||
"""
|
||||
if document_id is None:
|
||||
document_id = str(uuid4())
|
||||
|
||||
# Extract extension from filename
|
||||
ext = ""
|
||||
if "." in filename:
|
||||
ext = "." + filename.rsplit(".", 1)[-1].lower()
|
||||
|
||||
# Build logical path
|
||||
remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}"
|
||||
|
||||
# Upload via storage backend
|
||||
self._storage.upload_bytes(content, remote_path, overwrite=True)
|
||||
|
||||
return DocumentResult(
|
||||
id=document_id,
|
||||
file_path=remote_path,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
def download_document(self, remote_path: str) -> bytes:
|
||||
"""Download a document from storage.
|
||||
|
||||
Args:
|
||||
remote_path: Logical path to the document.
|
||||
|
||||
Returns:
|
||||
Document content as bytes.
|
||||
"""
|
||||
return self._storage.download_bytes(remote_path)
|
||||
|
||||
def get_document_url(
|
||||
self,
|
||||
remote_path: str,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Get a URL for accessing a document.
|
||||
|
||||
Args:
|
||||
remote_path: Logical path to the document.
|
||||
expires_in_seconds: URL validity duration.
|
||||
|
||||
Returns:
|
||||
Pre-signed URL for document access.
|
||||
"""
|
||||
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
|
||||
|
||||
def document_exists(self, remote_path: str) -> bool:
|
||||
"""Check if a document exists in storage.
|
||||
|
||||
Args:
|
||||
remote_path: Logical path to the document.
|
||||
|
||||
Returns:
|
||||
True if document exists.
|
||||
"""
|
||||
return self._storage.exists(remote_path)
|
||||
|
||||
def delete_document_files(self, remote_path: str) -> bool:
|
||||
"""Delete a document from storage.
|
||||
|
||||
Args:
|
||||
remote_path: Logical path to the document.
|
||||
|
||||
Returns:
|
||||
True if document was deleted.
|
||||
"""
|
||||
return self._storage.delete(remote_path)
|
||||
|
||||
def save_page_image(
|
||||
self,
|
||||
document_id: str,
|
||||
page_num: int,
|
||||
content: bytes,
|
||||
) -> str:
|
||||
"""Save a page image to storage.
|
||||
|
||||
Args:
|
||||
document_id: Document ID.
|
||||
page_num: Page number (1-indexed).
|
||||
content: Image content as bytes.
|
||||
|
||||
Returns:
|
||||
Logical path where image was stored.
|
||||
"""
|
||||
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
|
||||
self._storage.upload_bytes(content, remote_path, overwrite=True)
|
||||
return remote_path
|
||||
|
||||
def get_page_image_url(
|
||||
self,
|
||||
document_id: str,
|
||||
page_num: int,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Get a URL for accessing a page image.
|
||||
|
||||
Args:
|
||||
document_id: Document ID.
|
||||
page_num: Page number (1-indexed).
|
||||
expires_in_seconds: URL validity duration.
|
||||
|
||||
Returns:
|
||||
Pre-signed URL for image access.
|
||||
"""
|
||||
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
|
||||
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
|
||||
|
||||
def get_page_image(self, document_id: str, page_num: int) -> bytes:
|
||||
"""Download a page image from storage.
|
||||
|
||||
Args:
|
||||
document_id: Document ID.
|
||||
page_num: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Image content as bytes.
|
||||
"""
|
||||
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
|
||||
return self._storage.download_bytes(remote_path)
|
||||
|
||||
def delete_document_images(self, document_id: str) -> int:
|
||||
"""Delete all images for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document ID.
|
||||
|
||||
Returns:
|
||||
Number of images deleted.
|
||||
"""
|
||||
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
|
||||
image_paths = self._storage.list_files(prefix)
|
||||
|
||||
deleted_count = 0
|
||||
for path in image_paths:
|
||||
if self._storage.delete(path):
|
||||
deleted_count += 1
|
||||
|
||||
return deleted_count
|
||||
|
||||
def list_document_images(self, document_id: str) -> list[str]:
|
||||
"""List all images for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document ID.
|
||||
|
||||
Returns:
|
||||
List of image paths.
|
||||
"""
|
||||
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
|
||||
return self._storage.list_files(prefix)
|
||||
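A possible way to wire this service up, assuming the file path above maps to the module inference.web.services.document_service and that the environment is already configured for one of the supported backends:

```python
from shared.storage import create_storage_backend_from_env

# Assumed import path, derived from the file location shown above.
from inference.web.services.document_service import DocumentService

storage = create_storage_backend_from_env()
service = DocumentService(storage_backend=storage)

with open("invoice.pdf", "rb") as f:
    result = service.upload_document(f.read(), filename="invoice.pdf")

print(result.id, result.file_path)            # e.g. "<uuid>", "documents/<uuid>.pdf"
print(service.document_exists(result.file_path))

# Short-lived URL for the frontend; the path is the logical storage path.
url = service.get_document_url(result.file_path, expires_in_seconds=600)
```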
@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Callable
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import ModelConfig, StorageConfig
|
||||
|
||||
@@ -303,12 +305,19 @@ class InferenceService:
|
||||
"""Save visualization image with detections."""
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
if results_dir is None:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Load model and run prediction with visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(image_path), verbose=False)
|
||||
|
||||
# Save annotated image
|
||||
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
@@ -320,19 +329,26 @@ class InferenceService:
|
||||
from ultralytics import YOLO
|
||||
import io
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
if results_dir is None:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Render first page
|
||||
for page_no, image_bytes in render_pdf_to_images(
|
||||
pdf_path, dpi=self.model_config.dpi
|
||||
):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
image.save(temp_path)
|
||||
|
||||
# Run YOLO and save visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(temp_path), verbose=False)
|
||||
|
||||
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
|
||||
packages/inference/inference/web/services/storage_helpers.py (new file, 830 lines)
@@ -0,0 +1,830 @@
|
||||
"""
|
||||
Storage helpers for web services.
|
||||
|
||||
Provides convenience functions for common storage operations,
|
||||
wrapping the storage backend with proper path handling using prefixes.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.storage import PREFIXES, get_storage_backend
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
|
||||
def get_default_storage() -> "StorageBackend":
|
||||
"""Get the default storage backend.
|
||||
|
||||
Returns:
|
||||
Configured StorageBackend instance.
|
||||
"""
|
||||
return get_storage_backend()
|
||||
|
||||
|
||||
class StorageHelper:
|
||||
"""Helper class for storage operations with prefixes.
|
||||
|
||||
Provides high-level operations for document storage, including
|
||||
upload, download, and URL generation with proper path prefixes.
|
||||
"""
|
||||
|
||||
def __init__(self, storage: "StorageBackend | None" = None) -> None:
|
||||
"""Initialize storage helper.
|
||||
|
||||
Args:
|
||||
storage: Storage backend to use. If None, creates default.
|
||||
"""
|
||||
self._storage = storage or get_default_storage()
|
||||
|
||||
@property
|
||||
def storage(self) -> "StorageBackend":
|
||||
"""Get the underlying storage backend."""
|
||||
return self._storage
|
||||
|
||||
# Document operations
|
||||
|
||||
def upload_document(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
document_id: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Upload a document to storage.
|
||||
|
||||
Args:
|
||||
content: Document content as bytes.
|
||||
filename: Original filename (used for extension).
|
||||
document_id: Optional document ID. Generated if not provided.
|
||||
|
||||
Returns:
|
||||
Tuple of (document_id, storage_path).
|
||||
"""
|
||||
if document_id is None:
|
||||
document_id = str(uuid4())
|
||||
|
||||
ext = Path(filename).suffix.lower() or ".pdf"
|
||||
path = PREFIXES.document_path(document_id, ext)
|
||||
self._storage.upload_bytes(content, path, overwrite=True)
|
||||
|
||||
return document_id, path
|
||||
|
||||
def download_document(self, document_id: str, extension: str = ".pdf") -> bytes:
|
||||
"""Download a document from storage.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
extension: File extension.
|
||||
|
||||
Returns:
|
||||
Document content as bytes.
|
||||
"""
|
||||
path = PREFIXES.document_path(document_id, extension)
|
||||
return self._storage.download_bytes(path)
|
||||
|
||||
def get_document_url(
|
||||
self,
|
||||
document_id: str,
|
||||
extension: str = ".pdf",
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Get presigned URL for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
extension: File extension.
|
||||
expires_in_seconds: URL expiration time.
|
||||
|
||||
Returns:
|
||||
Presigned URL string.
|
||||
"""
|
||||
path = PREFIXES.document_path(document_id, extension)
|
||||
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||
|
||||
def document_exists(self, document_id: str, extension: str = ".pdf") -> bool:
|
||||
"""Check if a document exists.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
extension: File extension.
|
||||
|
||||
Returns:
|
||||
True if document exists.
|
||||
"""
|
||||
path = PREFIXES.document_path(document_id, extension)
|
||||
return self._storage.exists(path)
|
||||
|
||||
def delete_document(self, document_id: str, extension: str = ".pdf") -> bool:
|
||||
"""Delete a document.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
extension: File extension.
|
||||
|
||||
Returns:
|
||||
True if document was deleted.
|
||||
"""
|
||||
path = PREFIXES.document_path(document_id, extension)
|
||||
return self._storage.delete(path)
|
||||
|
||||
# Image operations
|
||||
|
||||
def save_page_image(
|
||||
self,
|
||||
document_id: str,
|
||||
page_num: int,
|
||||
content: bytes,
|
||||
) -> str:
|
||||
"""Save a page image to storage.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
content: Image content as bytes.
|
||||
|
||||
Returns:
|
||||
Storage path where image was saved.
|
||||
"""
|
||||
path = PREFIXES.image_path(document_id, page_num)
|
||||
self._storage.upload_bytes(content, path, overwrite=True)
|
||||
return path
|
||||
|
||||
def get_page_image(self, document_id: str, page_num: int) -> bytes:
|
||||
"""Download a page image.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Image content as bytes.
|
||||
"""
|
||||
path = PREFIXES.image_path(document_id, page_num)
|
||||
return self._storage.download_bytes(path)
|
||||
|
||||
def get_page_image_url(
|
||||
self,
|
||||
document_id: str,
|
||||
page_num: int,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Get presigned URL for a page image.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
expires_in_seconds: URL expiration time.
|
||||
|
||||
Returns:
|
||||
Presigned URL string.
|
||||
"""
|
||||
path = PREFIXES.image_path(document_id, page_num)
|
||||
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||
|
||||
def delete_document_images(self, document_id: str) -> int:
|
||||
"""Delete all images for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
|
||||
Returns:
|
||||
Number of images deleted.
|
||||
"""
|
||||
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
|
||||
images = self._storage.list_files(prefix)
|
||||
deleted = 0
|
||||
for img_path in images:
|
||||
if self._storage.delete(img_path):
|
||||
deleted += 1
|
||||
return deleted
|
||||
|
||||
def list_document_images(self, document_id: str) -> list[str]:
|
||||
"""List all images for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
|
||||
Returns:
|
||||
List of image paths.
|
||||
"""
|
||||
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
|
||||
return self._storage.list_files(prefix)
|
||||
|
||||
# Upload staging operations
|
||||
|
||||
def save_upload(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
subfolder: str | None = None,
|
||||
) -> str:
|
||||
"""Save a file to upload staging area.
|
||||
|
||||
Args:
|
||||
content: File content as bytes.
|
||||
filename: Filename to save as.
|
||||
subfolder: Optional subfolder (e.g., "async").
|
||||
|
||||
Returns:
|
||||
Storage path where file was saved.
|
||||
"""
|
||||
path = PREFIXES.upload_path(filename, subfolder)
|
||||
self._storage.upload_bytes(content, path, overwrite=True)
|
||||
return path
|
||||
|
||||
def get_upload(self, filename: str, subfolder: str | None = None) -> bytes:
|
||||
"""Get a file from upload staging area.
|
||||
|
||||
Args:
|
||||
filename: Filename to retrieve.
|
||||
subfolder: Optional subfolder.
|
||||
|
||||
Returns:
|
||||
File content as bytes.
|
||||
"""
|
||||
path = PREFIXES.upload_path(filename, subfolder)
|
||||
return self._storage.download_bytes(path)
|
||||
|
||||
def delete_upload(self, filename: str, subfolder: str | None = None) -> bool:
|
||||
"""Delete a file from upload staging area.
|
||||
|
||||
Args:
|
||||
filename: Filename to delete.
|
||||
subfolder: Optional subfolder.
|
||||
|
||||
Returns:
|
||||
True if file was deleted.
|
||||
"""
|
||||
path = PREFIXES.upload_path(filename, subfolder)
|
||||
return self._storage.delete(path)
|
||||
|
||||
# Result operations
|
||||
|
||||
def save_result(self, content: bytes, filename: str) -> str:
|
||||
"""Save a result file.
|
||||
|
||||
Args:
|
||||
content: File content as bytes.
|
||||
filename: Filename to save as.
|
||||
|
||||
Returns:
|
||||
Storage path where file was saved.
|
||||
"""
|
||||
path = PREFIXES.result_path(filename)
|
||||
self._storage.upload_bytes(content, path, overwrite=True)
|
||||
return path
|
||||
|
||||
def get_result(self, filename: str) -> bytes:
|
||||
"""Get a result file.
|
||||
|
||||
Args:
|
||||
filename: Filename to retrieve.
|
||||
|
||||
Returns:
|
||||
File content as bytes.
|
||||
"""
|
||||
path = PREFIXES.result_path(filename)
|
||||
return self._storage.download_bytes(path)
|
||||
|
||||
def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str:
|
||||
"""Get presigned URL for a result file.
|
||||
|
||||
Args:
|
||||
filename: Filename.
|
||||
expires_in_seconds: URL expiration time.
|
||||
|
||||
Returns:
|
||||
Presigned URL string.
|
||||
"""
|
||||
path = PREFIXES.result_path(filename)
|
||||
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||
|
||||
def result_exists(self, filename: str) -> bool:
|
||||
"""Check if a result file exists.
|
||||
|
||||
Args:
|
||||
filename: Filename to check.
|
||||
|
||||
Returns:
|
||||
True if file exists.
|
||||
"""
|
||||
path = PREFIXES.result_path(filename)
|
||||
return self._storage.exists(path)
|
||||
|
||||
def delete_result(self, filename: str) -> bool:
|
||||
"""Delete a result file.
|
||||
|
||||
Args:
|
||||
filename: Filename to delete.
|
||||
|
||||
Returns:
|
||||
True if file was deleted.
|
||||
"""
|
||||
path = PREFIXES.result_path(filename)
|
||||
return self._storage.delete(path)
|
||||
|
||||
# Export operations
|
||||
|
||||
def save_export(self, content: bytes, export_id: str, filename: str) -> str:
|
||||
"""Save an export file.
|
||||
|
||||
Args:
|
||||
content: File content as bytes.
|
||||
export_id: Export identifier.
|
||||
filename: Filename to save as.
|
||||
|
||||
Returns:
|
||||
Storage path where file was saved.
|
||||
"""
|
||||
path = PREFIXES.export_path(export_id, filename)
|
||||
self._storage.upload_bytes(content, path, overwrite=True)
|
||||
return path
|
||||
|
||||
def get_export_url(
|
||||
self,
|
||||
export_id: str,
|
||||
filename: str,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Get presigned URL for an export file.
|
||||
|
||||
Args:
|
||||
export_id: Export identifier.
|
||||
filename: Filename.
|
||||
expires_in_seconds: URL expiration time.
|
||||
|
||||
Returns:
|
||||
Presigned URL string.
|
||||
"""
|
||||
path = PREFIXES.export_path(export_id, filename)
|
||||
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||
|
||||
# Admin image operations
|
||||
|
||||
def get_admin_image_path(self, document_id: str, page_num: int) -> str:
|
||||
"""Get the storage path for an admin image.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Storage path like "admin_images/doc123/page_1.png"
|
||||
"""
|
||||
return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png"
|
||||
|
||||
def save_admin_image(
|
||||
self,
|
||||
document_id: str,
|
||||
page_num: int,
|
||||
content: bytes,
|
||||
) -> str:
|
||||
"""Save an admin page image to storage.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
content: Image content as bytes.
|
||||
|
||||
Returns:
|
||||
Storage path where image was saved.
|
||||
"""
|
||||
path = self.get_admin_image_path(document_id, page_num)
|
||||
self._storage.upload_bytes(content, path, overwrite=True)
|
||||
return path
|
||||
|
||||
def get_admin_image(self, document_id: str, page_num: int) -> bytes:
|
||||
"""Download an admin page image.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Image content as bytes.
|
||||
"""
|
||||
path = self.get_admin_image_path(document_id, page_num)
|
||||
return self._storage.download_bytes(path)
|
||||
|
||||
def get_admin_image_url(
|
||||
self,
|
||||
document_id: str,
|
||||
page_num: int,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Get presigned URL for an admin page image.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
expires_in_seconds: URL expiration time.
|
||||
|
||||
Returns:
|
||||
Presigned URL string.
|
||||
"""
|
||||
path = self.get_admin_image_path(document_id, page_num)
|
||||
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||
|
||||
def admin_image_exists(self, document_id: str, page_num: int) -> bool:
|
||||
"""Check if an admin page image exists.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
True if image exists.
|
||||
"""
|
||||
path = self.get_admin_image_path(document_id, page_num)
|
||||
return self._storage.exists(path)
|
||||
|
||||
def list_admin_images(self, document_id: str) -> list[str]:
|
||||
"""List all admin images for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
|
||||
Returns:
|
||||
List of image paths.
|
||||
"""
|
||||
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
|
||||
return self._storage.list_files(prefix)
|
||||
|
||||
def delete_admin_images(self, document_id: str) -> int:
|
||||
"""Delete all admin images for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
|
||||
Returns:
|
||||
Number of images deleted.
|
||||
"""
|
||||
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
|
||||
images = self._storage.list_files(prefix)
|
||||
deleted = 0
|
||||
for img_path in images:
|
||||
if self._storage.delete(img_path):
|
||||
deleted += 1
|
||||
return deleted
|
||||
|
||||
def get_admin_image_local_path(
|
||||
self, document_id: str, page_num: int
|
||||
) -> Path | None:
|
||||
"""Get the local filesystem path for an admin image.
|
||||
|
||||
This method is useful for serving files via FileResponse.
|
||||
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Path object if using local storage and file exists, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
# Cloud storage - cannot get local path
|
||||
return None
|
||||
|
||||
remote_path = self.get_admin_image_path(document_id, page_num)
|
||||
try:
|
||||
full_path = self._storage._get_full_path(remote_path)
|
||||
if full_path.exists():
|
||||
return full_path
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_admin_image_dimensions(
|
||||
self, document_id: str, page_num: int
|
||||
) -> tuple[int, int] | None:
|
||||
"""Get the dimensions (width, height) of an admin image.
|
||||
|
||||
This method is useful for normalizing bounding box coordinates.
|
||||
|
||||
Args:
|
||||
document_id: Document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height) if image exists, None otherwise.
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
# Try local path first for efficiency
|
||||
local_path = self.get_admin_image_local_path(document_id, page_num)
|
||||
if local_path is not None:
|
||||
with Image.open(local_path) as img:
|
||||
return img.size
|
||||
|
||||
# Fall back to downloading for cloud storage
|
||||
if not self.admin_image_exists(document_id, page_num):
|
||||
return None
|
||||
|
||||
try:
|
||||
import io
|
||||
image_bytes = self.get_admin_image(document_id, page_num)
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
return img.size
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Raw PDF operations (legacy compatibility)
|
||||
|
||||
def save_raw_pdf(self, content: bytes, filename: str) -> str:
|
||||
"""Save a raw PDF for auto-labeling pipeline.
|
||||
|
||||
Args:
|
||||
content: PDF content as bytes.
|
||||
filename: Filename to save as.
|
||||
|
||||
Returns:
|
||||
Storage path where file was saved.
|
||||
"""
|
||||
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||
self._storage.upload_bytes(content, path, overwrite=True)
|
||||
return path
|
||||
|
||||
def get_raw_pdf(self, filename: str) -> bytes:
|
||||
"""Get a raw PDF from storage.
|
||||
|
||||
Args:
|
||||
filename: Filename to retrieve.
|
||||
|
||||
Returns:
|
||||
PDF content as bytes.
|
||||
"""
|
||||
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||
return self._storage.download_bytes(path)
|
||||
|
||||
def raw_pdf_exists(self, filename: str) -> bool:
|
||||
"""Check if a raw PDF exists.
|
||||
|
||||
Args:
|
||||
filename: Filename to check.
|
||||
|
||||
Returns:
|
||||
True if file exists.
|
||||
"""
|
||||
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||
return self._storage.exists(path)
|
||||
|
||||
def get_raw_pdf_local_path(self, filename: str) -> Path | None:
|
||||
"""Get the local filesystem path for a raw PDF.
|
||||
|
||||
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||
|
||||
Args:
|
||||
filename: Filename to retrieve.
|
||||
|
||||
Returns:
|
||||
Path object if using local storage and file exists, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||
try:
|
||||
full_path = self._storage._get_full_path(path)
|
||||
if full_path.exists():
|
||||
return full_path
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_raw_pdf_path(self, filename: str) -> str:
|
||||
"""Get the storage path for a raw PDF (not the local filesystem path).
|
||||
|
||||
Args:
|
||||
filename: Filename.
|
||||
|
||||
Returns:
|
||||
Storage path like "raw_pdfs/filename.pdf"
|
||||
"""
|
||||
return f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||
|
||||
# Result local path operations
|
||||
|
||||
def get_result_local_path(self, filename: str) -> Path | None:
|
||||
"""Get the local filesystem path for a result file.
|
||||
|
||||
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||
|
||||
Args:
|
||||
filename: Filename to retrieve.
|
||||
|
||||
Returns:
|
||||
Path object if using local storage and file exists, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
path = PREFIXES.result_path(filename)
|
||||
try:
|
||||
full_path = self._storage._get_full_path(path)
|
||||
if full_path.exists():
|
||||
return full_path
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_results_base_path(self) -> Path | None:
|
||||
"""Get the base directory path for results (local storage only).
|
||||
|
||||
Used for mounting static file directories.
|
||||
|
||||
Returns:
|
||||
Path to results directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
base_path = self._storage._get_full_path(PREFIXES.RESULTS)
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Upload local path operations
|
||||
|
||||
def get_upload_local_path(
|
||||
self, filename: str, subfolder: str | None = None
|
||||
) -> Path | None:
|
||||
"""Get the local filesystem path for an upload file.
|
||||
|
||||
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||
|
||||
Args:
|
||||
filename: Filename to retrieve.
|
||||
subfolder: Optional subfolder.
|
||||
|
||||
Returns:
|
||||
Path object if using local storage and file exists, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
path = PREFIXES.upload_path(filename, subfolder)
|
||||
try:
|
||||
full_path = self._storage._get_full_path(path)
|
||||
if full_path.exists():
|
||||
return full_path
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None:
|
||||
"""Get the base directory path for uploads (local storage only).
|
||||
|
||||
Args:
|
||||
subfolder: Optional subfolder (e.g., "async").
|
||||
|
||||
Returns:
|
||||
Path to uploads directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
if subfolder:
|
||||
base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}")
|
||||
else:
|
||||
base_path = self._storage._get_full_path(PREFIXES.UPLOADS)
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def upload_exists(self, filename: str, subfolder: str | None = None) -> bool:
|
||||
"""Check if an upload file exists.
|
||||
|
||||
Args:
|
||||
filename: Filename to check.
|
||||
subfolder: Optional subfolder.
|
||||
|
||||
Returns:
|
||||
True if file exists.
|
||||
"""
|
||||
path = PREFIXES.upload_path(filename, subfolder)
|
||||
return self._storage.exists(path)
|
||||
|
||||
# Dataset operations
|
||||
|
||||
def get_datasets_base_path(self) -> Path | None:
|
||||
"""Get the base directory path for datasets (local storage only).
|
||||
|
||||
Returns:
|
||||
Path to datasets directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
base_path = self._storage._get_full_path(PREFIXES.DATASETS)
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_admin_images_base_path(self) -> Path | None:
|
||||
"""Get the base directory path for admin images (local storage only).
|
||||
|
||||
Returns:
|
||||
Path to admin_images directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES)
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_raw_pdfs_base_path(self) -> Path | None:
|
||||
"""Get the base directory path for raw PDFs (local storage only).
|
||||
|
||||
Returns:
|
||||
Path to raw_pdfs directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS)
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_autolabel_output_path(self) -> Path | None:
|
||||
"""Get the directory path for autolabel output (local storage only).
|
||||
|
||||
Returns:
|
||||
Path to autolabel_output directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
# Use a subfolder under results for autolabel output
|
||||
base_path = self._storage._get_full_path("autolabel_output")
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_training_data_path(self) -> Path | None:
|
||||
"""Get the directory path for training data exports (local storage only).
|
||||
|
||||
Returns:
|
||||
Path to training directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
base_path = self._storage._get_full_path("training")
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_exports_base_path(self) -> Path | None:
|
||||
"""Get the base directory path for exports (local storage only).
|
||||
|
||||
Returns:
|
||||
Path to exports directory if using local storage, None otherwise.
|
||||
"""
|
||||
if not isinstance(self._storage, LocalStorageBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
base_path = self._storage._get_full_path(PREFIXES.EXPORTS)
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
return base_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# Default instance for convenience
|
||||
_default_helper: StorageHelper | None = None
|
||||
|
||||
|
||||
def get_storage_helper() -> StorageHelper:
|
||||
"""Get the default storage helper instance.
|
||||
|
||||
Creates the helper on first call with default storage backend.
|
||||
|
||||
Returns:
|
||||
Default StorageHelper instance.
|
||||
"""
|
||||
global _default_helper
|
||||
if _default_helper is None:
|
||||
_default_helper = StorageHelper()
|
||||
return _default_helper
|
||||
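A short usage sketch for the helper; every call below maps onto methods defined in this file, and the byte literals are placeholders:

```python
from inference.web.services.storage_helpers import get_storage_helper

helper = get_storage_helper()

# Upload a document and a rendered admin page image under their prefixed paths.
doc_id, doc_path = helper.upload_document(b"%PDF-1.4 ...", "invoice.pdf")
helper.save_admin_image(doc_id, page_num=1, content=b"\x89PNG...")

# Presigned URLs work on any backend.
url = helper.get_document_url(doc_id, expires_in_seconds=900)

# Local-path helpers only resolve for LocalStorageBackend; cloud backends return None.
local_page = helper.get_admin_image_local_path(doc_id, 1)
dims = helper.get_admin_image_dimensions(doc_id, 1)
```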
packages/shared/README.md (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
# Shared Package
|
||||
|
||||
Shared utilities and abstractions for the Invoice Master system.
|
||||
|
||||
## Storage Abstraction Layer
|
||||
|
||||
A unified storage abstraction supporting multiple backends:
|
||||
- **Local filesystem** - Development and testing
|
||||
- **Azure Blob Storage** - Azure cloud deployments
|
||||
- **AWS S3** - AWS cloud deployments
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Basic installation (local storage only)
|
||||
pip install -e packages/shared
|
||||
|
||||
# With Azure support
|
||||
pip install -e "packages/shared[azure]"
|
||||
|
||||
# With S3 support
|
||||
pip install -e "packages/shared[s3]"
|
||||
|
||||
# All cloud providers
|
||||
pip install -e "packages/shared[all]"
|
||||
```
|
||||
|
||||
### Quick Start
|
||||
|
||||
```python
|
||||
from shared.storage import get_storage_backend
|
||||
|
||||
# Option 1: From configuration file
|
||||
storage = get_storage_backend("storage.yaml")
|
||||
|
||||
# Option 2: From environment variables
|
||||
from shared.storage import create_storage_backend_from_env
|
||||
storage = create_storage_backend_from_env()
|
||||
|
||||
# Upload a file
|
||||
storage.upload(Path("local/file.pdf"), "documents/file.pdf")
|
||||
|
||||
# Download a file
|
||||
storage.download("documents/file.pdf", Path("local/downloaded.pdf"))
|
||||
|
||||
# Get pre-signed URL for frontend access
|
||||
url = storage.get_presigned_url("documents/file.pdf", expires_in_seconds=3600)
|
||||
```
|
||||
|
||||
### Configuration File Format
|
||||
|
||||
Create a `storage.yaml` file with environment variable substitution support:
|
||||
|
||||
```yaml
|
||||
# Backend selection: local, azure_blob, or s3
|
||||
backend: ${STORAGE_BACKEND:-local}
|
||||
|
||||
# Default pre-signed URL expiry (seconds)
|
||||
presigned_url_expiry: 3600
|
||||
|
||||
# Local storage configuration
|
||||
local:
|
||||
base_path: ${STORAGE_BASE_PATH:-./data/storage}
|
||||
|
||||
# Azure Blob Storage configuration
|
||||
azure:
|
||||
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
|
||||
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
|
||||
create_container: false
|
||||
|
||||
# AWS S3 configuration
|
||||
s3:
|
||||
bucket_name: ${AWS_S3_BUCKET}
|
||||
region_name: ${AWS_REGION:-us-east-1}
|
||||
access_key_id: ${AWS_ACCESS_KEY_ID}
|
||||
secret_access_key: ${AWS_SECRET_ACCESS_KEY}
|
||||
endpoint_url: ${AWS_ENDPOINT_URL} # Optional, for S3-compatible services
|
||||
create_bucket: false
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Backend | Description |
|
||||
|----------|---------|-------------|
|
||||
| `STORAGE_BACKEND` | All | Backend type: `local`, `azure_blob`, `s3` |
|
||||
| `STORAGE_BASE_PATH` | Local | Base directory path |
|
||||
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | Connection string |
|
||||
| `AZURE_STORAGE_CONTAINER` | Azure | Container name |
|
||||
| `AWS_S3_BUCKET` | S3 | Bucket name |
|
||||
| `AWS_REGION` | S3 | AWS region (default: us-east-1) |
|
||||
| `AWS_ACCESS_KEY_ID` | S3 | Access key (optional, uses credential chain) |
|
||||
| `AWS_SECRET_ACCESS_KEY` | S3 | Secret key (optional) |
|
||||
| `AWS_ENDPOINT_URL` | S3 | Custom endpoint for S3-compatible services |
|
||||
|
||||
### API Reference
|
||||
|
||||
#### StorageBackend Interface
|
||||
|
||||
```python
|
||||
class StorageBackend(ABC):
|
||||
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
|
||||
"""Upload a file to storage."""
|
||||
|
||||
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||
"""Download a file from storage."""
|
||||
|
||||
def exists(self, remote_path: str) -> bool:
|
||||
"""Check if a file exists."""
|
||||
|
||||
def list_files(self, prefix: str) -> list[str]:
|
||||
"""List files with given prefix."""
|
||||
|
||||
def delete(self, remote_path: str) -> bool:
|
||||
"""Delete a file."""
|
||||
|
||||
def get_url(self, remote_path: str) -> str:
|
||||
"""Get URL for a file."""
|
||||
|
||||
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
|
||||
"""Generate a pre-signed URL for temporary access (1-604800 seconds)."""
|
||||
|
||||
def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str:
|
||||
"""Upload bytes directly."""
|
||||
|
||||
def download_bytes(self, remote_path: str) -> bytes:
|
||||
"""Download file as bytes."""
|
||||
```
|
||||
|
||||
#### Factory Functions
|
||||
|
||||
```python
|
||||
# Create from configuration file
|
||||
storage = create_storage_backend_from_file("storage.yaml")
|
||||
|
||||
# Create from environment variables
|
||||
storage = create_storage_backend_from_env()
|
||||
|
||||
# Create from StorageConfig object
|
||||
config = StorageConfig(backend_type="local", base_path=Path("./data"))
|
||||
storage = create_storage_backend(config)
|
||||
|
||||
# Convenience function with fallback chain: config file -> env vars -> local default
|
||||
storage = get_storage_backend("storage.yaml") # or None for env-only
|
||||
```
|
||||
|
||||
### Pre-signed URLs
|
||||
|
||||
Pre-signed URLs provide temporary access to files without exposing credentials:
|
||||
|
||||
```python
|
||||
# Generate URL valid for 1 hour (default)
|
||||
url = storage.get_presigned_url("documents/invoice.pdf")
|
||||
|
||||
# Generate URL valid for 24 hours
|
||||
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=86400)
|
||||
|
||||
# Maximum expiry: 7 days (604800 seconds)
|
||||
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=604800)
|
||||
```
|
||||
|
||||
**Note:** Local storage returns `file://` URLs that don't actually expire.
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
from shared.storage import (
|
||||
StorageError,
|
||||
FileNotFoundStorageError,
|
||||
PresignedUrlNotSupportedError,
|
||||
)
|
||||
|
||||
try:
|
||||
storage.download("nonexistent.pdf", Path("local.pdf"))
|
||||
except FileNotFoundStorageError as e:
|
||||
print(f"File not found: {e}")
|
||||
except StorageError as e:
|
||||
print(f"Storage error: {e}")
|
||||
```
|
||||
|
||||
### Testing with MinIO (S3-compatible)
|
||||
|
||||
```bash
|
||||
# Start MinIO locally
|
||||
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"
|
||||
|
||||
# Configure environment
|
||||
export STORAGE_BACKEND=s3
|
||||
export AWS_S3_BUCKET=test-bucket
|
||||
export AWS_ENDPOINT_URL=http://localhost:9000
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
```
|
||||
|
||||
### Module Structure
|
||||
|
||||
```
|
||||
shared/storage/
|
||||
├── __init__.py # Public exports
|
||||
├── base.py # Abstract interface and exceptions
|
||||
├── local.py # Local filesystem backend
|
||||
├── azure.py # Azure Blob Storage backend
|
||||
├── s3.py # AWS S3 backend
|
||||
├── config_loader.py # YAML configuration loader
|
||||
└── factory.py # Backend factory functions
|
||||
```
|
||||
@@ -16,4 +16,18 @@ setup(
|
||||
"pyyaml>=6.0",
|
||||
"thefuzz>=0.20.0",
|
||||
],
|
||||
extras_require={
|
||||
"azure": [
|
||||
"azure-storage-blob>=12.19.0",
|
||||
"azure-identity>=1.15.0",
|
||||
],
|
||||
"s3": [
|
||||
"boto3>=1.34.0",
|
||||
],
|
||||
"all": [
|
||||
"azure-storage-blob>=12.19.0",
|
||||
"azure-identity>=1.15.0",
|
||||
"boto3>=1.34.0",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -58,23 +58,16 @@ def get_db_connection_string():
    return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"


# Paths Configuration - auto-detect WSL vs Windows
if _is_wsl():
    # WSL: use native Linux filesystem for better I/O performance
    PATHS = {
        'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
        'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
        'output_dir': os.path.expanduser('~/invoice-data/dataset'),
        'reports_dir': 'reports',  # Keep reports in project directory
    }
else:
    # Windows or native Linux: use relative paths
    PATHS = {
        'csv_dir': 'data/structured_data',
        'pdf_dir': 'data/raw_pdfs',
        'output_dir': 'data/dataset',
        'reports_dir': 'reports',
    }
# Paths Configuration - uses STORAGE_BASE_PATH for consistency
# All paths are relative to STORAGE_BASE_PATH (defaults to ~/invoice-data/data)
_storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data'))

PATHS = {
    'csv_dir': f'{_storage_base}/structured_data',
    'pdf_dir': f'{_storage_base}/raw_pdfs',
    'output_dir': f'{_storage_base}/datasets',
    'reports_dir': 'reports',  # Keep reports in project directory
}

# Auto-labeling Configuration
AUTOLABEL = {

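A quick check of how the new PATHS resolution behaves; the /home/user prefix in the comments is only an assumed home directory:

```python
import os

# With STORAGE_BASE_PATH unset, the defaults expand under the user's home.
os.environ.pop("STORAGE_BASE_PATH", None)
_storage_base = os.path.expanduser(os.getenv("STORAGE_BASE_PATH", "~/invoice-data/data"))
print(_storage_base)                  # /home/user/invoice-data/data
print(f"{_storage_base}/raw_pdfs")    # /home/user/invoice-data/data/raw_pdfs

# Overriding the base relocates every derived path consistently.
os.environ["STORAGE_BASE_PATH"] = "/mnt/shared/invoice-data"
print(os.path.expanduser(os.getenv("STORAGE_BASE_PATH", "~/invoice-data/data")))
# /mnt/shared/invoice-data
```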
packages/shared/shared/fields/__init__.py (new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Shared Field Definitions - Single Source of Truth.
|
||||
|
||||
This module provides centralized field class definitions used throughout
|
||||
the invoice extraction system. All field mappings are derived from
|
||||
FIELD_DEFINITIONS to ensure consistency.
|
||||
|
||||
Usage:
|
||||
from shared.fields import FIELD_CLASSES, CLASS_NAMES, FIELD_CLASS_IDS
|
||||
|
||||
Available exports:
|
||||
- FieldDefinition: Dataclass for field definition
|
||||
- FIELD_DEFINITIONS: Tuple of all field definitions (immutable)
|
||||
- NUM_CLASSES: Total number of field classes (10)
|
||||
- CLASS_NAMES: List of class names in order [0..9]
|
||||
- FIELD_CLASSES: dict[int, str] - class_id to class_name
|
||||
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
|
||||
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
|
||||
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
|
||||
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
|
||||
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
|
||||
"""
|
||||
|
||||
from .field_config import FieldDefinition, FIELD_DEFINITIONS, NUM_CLASSES
|
||||
from .mappings import (
|
||||
CLASS_NAMES,
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
ACCOUNT_FIELD_MAPPING,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FieldDefinition",
|
||||
"FIELD_DEFINITIONS",
|
||||
"NUM_CLASSES",
|
||||
"CLASS_NAMES",
|
||||
"FIELD_CLASSES",
|
||||
"FIELD_CLASS_IDS",
|
||||
"CLASS_TO_FIELD",
|
||||
"CSV_TO_CLASS_MAPPING",
|
||||
"TRAINING_FIELD_CLASSES",
|
||||
"ACCOUNT_FIELD_MAPPING",
|
||||
]
|
||||
packages/shared/shared/fields/field_config.py (new file, 58 lines)
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Field Configuration - Single Source of Truth
|
||||
|
||||
This module defines all invoice field classes used throughout the system.
|
||||
The class IDs are verified against the trained YOLO model (best.pt).
|
||||
|
||||
IMPORTANT: Do not modify class_id values without retraining the model!
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FieldDefinition:
|
||||
"""Immutable field definition for invoice extraction.
|
||||
|
||||
Attributes:
|
||||
class_id: YOLO class ID (0-9), must match trained model
|
||||
class_name: YOLO class name (lowercase_underscore)
|
||||
field_name: Business field name used in API responses
|
||||
csv_name: CSV column name for data import/export
|
||||
is_derived: True if field is derived from other fields (not in CSV)
|
||||
"""
|
||||
|
||||
class_id: int
|
||||
class_name: str
|
||||
field_name: str
|
||||
csv_name: str
|
||||
is_derived: bool = False
|
||||
|
||||
|
||||
# Verified from model weights (runs/train/invoice_fields/weights/best.pt)
|
||||
# model.names = {0: 'invoice_number', 1: 'invoice_date', ..., 8: 'customer_number', 9: 'payment_line'}
|
||||
#
|
||||
# DO NOT CHANGE THE ORDER - it must match the trained model!
|
||||
FIELD_DEFINITIONS: Final[tuple[FieldDefinition, ...]] = (
|
||||
FieldDefinition(0, "invoice_number", "InvoiceNumber", "InvoiceNumber"),
|
||||
FieldDefinition(1, "invoice_date", "InvoiceDate", "InvoiceDate"),
|
||||
FieldDefinition(2, "invoice_due_date", "InvoiceDueDate", "InvoiceDueDate"),
|
||||
FieldDefinition(3, "ocr_number", "OCR", "OCR"),
|
||||
FieldDefinition(4, "bankgiro", "Bankgiro", "Bankgiro"),
|
||||
FieldDefinition(5, "plusgiro", "Plusgiro", "Plusgiro"),
|
||||
FieldDefinition(6, "amount", "Amount", "Amount"),
|
||||
FieldDefinition(
|
||||
7,
|
||||
"supplier_org_number",
|
||||
"supplier_organisation_number",
|
||||
"supplier_organisation_number",
|
||||
),
|
||||
FieldDefinition(8, "customer_number", "customer_number", "customer_number"),
|
||||
FieldDefinition(
|
||||
9, "payment_line", "payment_line", "payment_line", is_derived=True
|
||||
),
|
||||
)
|
||||
|
||||
# Total number of field classes
|
||||
NUM_CLASSES: Final[int] = len(FIELD_DEFINITIONS)
|
||||
packages/shared/shared/fields/mappings.py (new file, 57 lines)
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Field Mappings - Auto-generated from FIELD_DEFINITIONS.
|
||||
|
||||
All mappings in this file are derived from field_config.FIELD_DEFINITIONS.
|
||||
This ensures consistency across the entire codebase.
|
||||
|
||||
DO NOT hardcode field mappings elsewhere - always import from this module.
|
||||
"""
|
||||
|
||||
from typing import Final
|
||||
|
||||
from .field_config import FIELD_DEFINITIONS
|
||||
|
||||
|
||||
# List of class names in order (for YOLO classes.txt generation)
|
||||
# Index matches class_id: CLASS_NAMES[0] = "invoice_number"
|
||||
CLASS_NAMES: Final[list[str]] = [fd.class_name for fd in FIELD_DEFINITIONS]
|
||||
|
||||
# class_id -> class_name mapping
|
||||
# Example: {0: "invoice_number", 1: "invoice_date", ...}
|
||||
FIELD_CLASSES: Final[dict[int, str]] = {
|
||||
fd.class_id: fd.class_name for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# class_name -> class_id mapping (reverse of FIELD_CLASSES)
|
||||
# Example: {"invoice_number": 0, "invoice_date": 1, ...}
|
||||
FIELD_CLASS_IDS: Final[dict[str, int]] = {
|
||||
fd.class_name: fd.class_id for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# class_name -> field_name mapping (for API responses)
|
||||
# Example: {"invoice_number": "InvoiceNumber", "ocr_number": "OCR", ...}
|
||||
CLASS_TO_FIELD: Final[dict[str, str]] = {
|
||||
fd.class_name: fd.field_name for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# field_name -> class_id mapping (for CSV import)
|
||||
# Excludes derived fields like payment_line
|
||||
# Example: {"InvoiceNumber": 0, "InvoiceDate": 1, ...}
|
||||
CSV_TO_CLASS_MAPPING: Final[dict[str, int]] = {
|
||||
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS if not fd.is_derived
|
||||
}
|
||||
|
||||
# field_name -> class_id mapping (for training, includes all fields)
|
||||
# Example: {"InvoiceNumber": 0, ..., "payment_line": 9}
|
||||
TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
|
||||
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# Account field mapping for supplier_accounts special handling
|
||||
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
|
||||
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
|
||||
"supplier_accounts": {
|
||||
"BG": "Bankgiro",
|
||||
"PG": "Plusgiro",
|
||||
}
|
||||
}
|
||||
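A small consistency check, runnable as-is against the definitions above, showing how the derived mappings line up with FIELD_DEFINITIONS:

```python
from shared.fields import (
    CLASS_NAMES,
    CSV_TO_CLASS_MAPPING,
    FIELD_CLASSES,
    FIELD_CLASS_IDS,
    NUM_CLASSES,
)

# Ten classes, ordered exactly as the trained model expects.
assert NUM_CLASSES == 10
assert CLASS_NAMES[0] == "invoice_number"
assert FIELD_CLASSES[9] == "payment_line"
assert FIELD_CLASS_IDS["bankgiro"] == 4

# Derived fields are excluded from the CSV import mapping.
assert "payment_line" not in CSV_TO_CLASS_MAPPING
assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0
```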
packages/shared/shared/storage/__init__.py (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Storage abstraction layer for training data.
|
||||
|
||||
Provides a unified interface for local filesystem, Azure Blob Storage, and AWS S3.
|
||||
"""
|
||||
|
||||
from shared.storage.base import (
|
||||
FileNotFoundStorageError,
|
||||
PresignedUrlNotSupportedError,
|
||||
StorageBackend,
|
||||
StorageConfig,
|
||||
StorageError,
|
||||
)
|
||||
from shared.storage.factory import (
|
||||
create_storage_backend,
|
||||
create_storage_backend_from_env,
|
||||
create_storage_backend_from_file,
|
||||
get_default_storage_config,
|
||||
get_storage_backend,
|
||||
)
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
from shared.storage.prefixes import PREFIXES, StoragePrefixes
|
||||
|
||||
__all__ = [
|
||||
# Base classes and exceptions
|
||||
"StorageBackend",
|
||||
"StorageConfig",
|
||||
"StorageError",
|
||||
"FileNotFoundStorageError",
|
||||
"PresignedUrlNotSupportedError",
|
||||
# Backends
|
||||
"LocalStorageBackend",
|
||||
# Factory functions
|
||||
"create_storage_backend",
|
||||
"create_storage_backend_from_env",
|
||||
"create_storage_backend_from_file",
|
||||
"get_default_storage_config",
|
||||
"get_storage_backend",
|
||||
# Path prefixes
|
||||
"PREFIXES",
|
||||
"StoragePrefixes",
|
||||
]
|
||||
|
||||
|
||||
# Lazy imports to avoid dependencies when not using specific backends
|
||||
def __getattr__(name: str):
|
||||
if name == "AzureBlobStorageBackend":
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
return AzureBlobStorageBackend
|
||||
if name == "S3StorageBackend":
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
return S3StorageBackend
|
||||
if name == "load_storage_config":
|
||||
from shared.storage.config_loader import load_storage_config
|
||||
|
||||
return load_storage_config
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
335
packages/shared/shared/storage/azure.py
Normal file
335
packages/shared/shared/storage/azure.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Azure Blob Storage backend.
|
||||
|
||||
Provides storage operations using Azure Blob Storage.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from azure.storage.blob import (
|
||||
BlobSasPermissions,
|
||||
BlobServiceClient,
|
||||
ContainerClient,
|
||||
generate_blob_sas,
|
||||
)
|
||||
|
||||
from shared.storage.base import (
|
||||
FileNotFoundStorageError,
|
||||
StorageBackend,
|
||||
StorageError,
|
||||
)
|
||||
|
||||
|
||||
class AzureBlobStorageBackend(StorageBackend):
|
||||
"""Storage backend using Azure Blob Storage.
|
||||
|
||||
Files are stored as blobs in an Azure Blob Storage container.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_string: str,
|
||||
container_name: str,
|
||||
create_container: bool = False,
|
||||
) -> None:
|
||||
"""Initialize Azure Blob Storage backend.
|
||||
|
||||
Args:
|
||||
connection_string: Azure Storage connection string.
|
||||
container_name: Name of the blob container.
|
||||
create_container: If True, create the container if it doesn't exist.
|
||||
"""
|
||||
self._connection_string = connection_string
|
||||
self._container_name = container_name
|
||||
|
||||
self._blob_service = BlobServiceClient.from_connection_string(connection_string)
|
||||
self._container = self._blob_service.get_container_client(container_name)
|
||||
|
||||
# Extract account key from connection string for SAS token generation
|
||||
self._account_key = self._extract_account_key(connection_string)
|
||||
|
||||
if create_container and not self._container.exists():
|
||||
self._container.create_container()
|
||||
|
||||
@staticmethod
|
||||
def _extract_account_key(connection_string: str) -> str | None:
|
||||
"""Extract account key from connection string.
|
||||
|
||||
Args:
|
||||
connection_string: Azure Storage connection string.
|
||||
|
||||
Returns:
|
||||
Account key if found, None otherwise.
|
||||
"""
|
||||
for part in connection_string.split(";"):
|
||||
if part.startswith("AccountKey="):
|
||||
return part[len("AccountKey=") :]
|
||||
return None
|
||||
|
||||
@property
|
||||
def container_name(self) -> str:
|
||||
"""Get the container name for this storage backend."""
|
||||
return self._container_name
|
||||
|
||||
@property
|
||||
def container_client(self) -> ContainerClient:
|
||||
"""Get the Azure container client."""
|
||||
return self._container
|
||||
|
||||
def upload(
|
||||
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload a file to Azure Blob Storage.
|
||||
|
||||
Args:
|
||||
local_path: Path to the local file to upload.
|
||||
remote_path: Destination blob path.
|
||||
overwrite: If True, overwrite existing blob.
|
||||
|
||||
Returns:
|
||||
The remote path where the file was stored.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If local_path doesn't exist.
|
||||
StorageError: If blob exists and overwrite is False.
|
||||
"""
|
||||
if not local_path.exists():
|
||||
raise FileNotFoundStorageError(str(local_path))
|
||||
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
|
||||
if blob_client.exists() and not overwrite:
|
||||
raise StorageError(f"File already exists: {remote_path}")
|
||||
|
||||
with open(local_path, "rb") as f:
|
||||
blob_client.upload_blob(f, overwrite=overwrite)
|
||||
|
||||
return remote_path
|
||||
|
||||
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||
"""Download a blob from Azure Blob Storage.
|
||||
|
||||
Args:
|
||||
remote_path: Blob path in storage.
|
||||
local_path: Local destination path.
|
||||
|
||||
Returns:
|
||||
The local path where the file was downloaded.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
|
||||
if not blob_client.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stream = blob_client.download_blob()
|
||||
local_path.write_bytes(stream.readall())
|
||||
|
||||
return local_path
|
||||
|
||||
def exists(self, remote_path: str) -> bool:
|
||||
"""Check if a blob exists in storage.
|
||||
|
||||
Args:
|
||||
remote_path: Blob path to check.
|
||||
|
||||
Returns:
|
||||
True if the blob exists, False otherwise.
|
||||
"""
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
return blob_client.exists()
|
||||
|
||||
def list_files(self, prefix: str) -> list[str]:
|
||||
"""List blobs in storage with given prefix.
|
||||
|
||||
Args:
|
||||
prefix: Blob path prefix to filter.
|
||||
|
||||
Returns:
|
||||
List of blob paths matching the prefix.
|
||||
"""
|
||||
if prefix:
|
||||
blobs = self._container.list_blobs(name_starts_with=prefix)
|
||||
else:
|
||||
blobs = self._container.list_blobs()
|
||||
|
||||
return [blob.name for blob in blobs]
|
||||
|
||||
def delete(self, remote_path: str) -> bool:
|
||||
"""Delete a blob from storage.
|
||||
|
||||
Args:
|
||||
remote_path: Blob path to delete.
|
||||
|
||||
Returns:
|
||||
True if blob was deleted, False if it didn't exist.
|
||||
"""
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
|
||||
if not blob_client.exists():
|
||||
return False
|
||||
|
||||
blob_client.delete_blob()
|
||||
return True
|
||||
|
||||
def get_url(self, remote_path: str) -> str:
|
||||
"""Get the URL for a blob.
|
||||
|
||||
Args:
|
||||
remote_path: Blob path in storage.
|
||||
|
||||
Returns:
|
||||
URL to access the blob.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
|
||||
if not blob_client.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
return blob_client.url
|
||||
|
||||
def upload_bytes(
|
||||
self, data: bytes, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload bytes directly to Azure Blob Storage.
|
||||
|
||||
Args:
|
||||
data: Bytes to upload.
|
||||
remote_path: Destination blob path.
|
||||
overwrite: If True, overwrite existing blob.
|
||||
|
||||
Returns:
|
||||
The remote path where the data was stored.
|
||||
"""
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
|
||||
if blob_client.exists() and not overwrite:
|
||||
raise StorageError(f"File already exists: {remote_path}")
|
||||
|
||||
blob_client.upload_blob(data, overwrite=overwrite)
|
||||
|
||||
return remote_path
|
||||
|
||||
def download_bytes(self, remote_path: str) -> bytes:
|
||||
"""Download a blob as bytes.
|
||||
|
||||
Args:
|
||||
remote_path: Blob path in storage.
|
||||
|
||||
Returns:
|
||||
The blob contents as bytes.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
|
||||
if not blob_client.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
stream = blob_client.download_blob()
|
||||
return stream.readall()
|
||||
|
||||
def upload_directory(
|
||||
self, local_dir: Path, remote_prefix: str, overwrite: bool = False
|
||||
) -> list[str]:
|
||||
"""Upload all files in a directory to Azure Blob Storage.
|
||||
|
||||
Args:
|
||||
local_dir: Local directory to upload.
|
||||
remote_prefix: Prefix for remote blob paths.
|
||||
overwrite: If True, overwrite existing blobs.
|
||||
|
||||
Returns:
|
||||
List of remote paths where files were stored.
|
||||
"""
|
||||
uploaded: list[str] = []
|
||||
|
||||
for file_path in local_dir.rglob("*"):
|
||||
if file_path.is_file():
|
||||
relative_path = file_path.relative_to(local_dir)
|
||||
remote_path = f"{remote_prefix}{relative_path}".replace("\\", "/")
|
||||
self.upload(file_path, remote_path, overwrite=overwrite)
|
||||
uploaded.append(remote_path)
|
||||
|
||||
return uploaded
|
||||
|
||||
def download_directory(
|
||||
self, remote_prefix: str, local_dir: Path
|
||||
) -> list[Path]:
|
||||
"""Download all blobs with a prefix to a local directory.
|
||||
|
||||
Args:
|
||||
remote_prefix: Blob path prefix to download.
|
||||
local_dir: Local directory to download to.
|
||||
|
||||
Returns:
|
||||
List of local paths where files were downloaded.
|
||||
"""
|
||||
downloaded: list[Path] = []
|
||||
|
||||
blobs = self.list_files(remote_prefix)
|
||||
|
||||
for blob_path in blobs:
|
||||
# Remove prefix to get relative path
|
||||
if remote_prefix:
|
||||
relative_path = blob_path[len(remote_prefix):]
|
||||
if relative_path.startswith("/"):
|
||||
relative_path = relative_path[1:]
|
||||
else:
|
||||
relative_path = blob_path
|
||||
|
||||
local_path = local_dir / relative_path
|
||||
self.download(blob_path, local_path)
|
||||
downloaded.append(local_path)
|
||||
|
||||
return downloaded
|
||||
|
||||
def get_presigned_url(
|
||||
self,
|
||||
remote_path: str,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Generate a SAS URL for temporary blob access.
|
||||
|
||||
Args:
|
||||
remote_path: Blob path in storage.
|
||||
expires_in_seconds: SAS token validity duration (1 to 604800 seconds / 7 days).
|
||||
|
||||
Returns:
|
||||
Blob URL with SAS token.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
ValueError: If expires_in_seconds is out of valid range.
|
||||
"""
|
||||
if expires_in_seconds < 1 or expires_in_seconds > 604800:
|
||||
raise ValueError(
|
||||
"expires_in_seconds must be between 1 and 604800 (7 days)"
|
||||
)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
blob_client = self._container.get_blob_client(remote_path)
|
||||
|
||||
if not blob_client.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
# Generate SAS token
|
||||
sas_token = generate_blob_sas(
|
||||
account_name=self._blob_service.account_name,
|
||||
container_name=self._container_name,
|
||||
blob_name=remote_path,
|
||||
account_key=self._account_key,
|
||||
permission=BlobSasPermissions(read=True),
|
||||
expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
|
||||
)
|
||||
|
||||
return f"{blob_client.url}?{sas_token}"
|
||||
229
packages/shared/shared/storage/base.py
Normal file
229
packages/shared/shared/storage/base.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Base classes and interfaces for storage backends.
|
||||
|
||||
Defines the abstract StorageBackend interface and common exceptions.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class StorageError(Exception):
|
||||
"""Base exception for storage operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FileNotFoundStorageError(StorageError):
|
||||
"""Raised when a file is not found in storage."""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
super().__init__(f"File not found in storage: {path}")
|
||||
|
||||
|
||||
class PresignedUrlNotSupportedError(StorageError):
|
||||
"""Raised when pre-signed URLs are not supported by a backend."""
|
||||
|
||||
def __init__(self, backend_type: str) -> None:
|
||||
self.backend_type = backend_type
|
||||
super().__init__(f"Pre-signed URLs not supported for backend: {backend_type}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageConfig:
|
||||
"""Configuration for storage backend.
|
||||
|
||||
Attributes:
|
||||
backend_type: Type of storage backend ("local", "azure_blob", or "s3").
|
||||
connection_string: Azure Blob Storage connection string (for azure_blob).
|
||||
container_name: Azure Blob Storage container name (for azure_blob).
|
||||
base_path: Base path for local storage (for local).
|
||||
bucket_name: S3 bucket name (for s3).
|
||||
region_name: AWS region name (for s3).
|
||||
access_key_id: AWS access key ID (for s3).
|
||||
secret_access_key: AWS secret access key (for s3).
|
||||
endpoint_url: Custom endpoint URL for S3-compatible services (for s3).
|
||||
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
|
||||
"""
|
||||
|
||||
backend_type: str
|
||||
connection_string: str | None = None
|
||||
container_name: str | None = None
|
||||
base_path: Path | None = None
|
||||
bucket_name: str | None = None
|
||||
region_name: str | None = None
|
||||
access_key_id: str | None = None
|
||||
secret_access_key: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
presigned_url_expiry: int = 3600
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
"""Abstract base class for storage backends.
|
||||
|
||||
Provides a unified interface for storing and retrieving files
|
||||
from different storage systems (local filesystem, Azure Blob, etc.).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def upload(
|
||||
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload a file to storage.
|
||||
|
||||
Args:
|
||||
local_path: Path to the local file to upload.
|
||||
remote_path: Destination path in storage.
|
||||
overwrite: If True, overwrite existing file.
|
||||
|
||||
Returns:
|
||||
The remote path where the file was stored.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If local_path doesn't exist.
|
||||
StorageError: If file exists and overwrite is False.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||
"""Download a file from storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
local_path: Local destination path.
|
||||
|
||||
Returns:
|
||||
The local path where the file was downloaded.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, remote_path: str) -> bool:
|
||||
"""Check if a file exists in storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to check in storage.
|
||||
|
||||
Returns:
|
||||
True if the file exists, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, prefix: str) -> list[str]:
|
||||
"""List files in storage with given prefix.
|
||||
|
||||
Args:
|
||||
prefix: Path prefix to filter files.
|
||||
|
||||
Returns:
|
||||
List of file paths matching the prefix.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, remote_path: str) -> bool:
|
||||
"""Delete a file from storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file to delete.
|
||||
|
||||
Returns:
|
||||
True if file was deleted, False if it didn't exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, remote_path: str) -> str:
|
||||
"""Get a URL or path to access a file.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
|
||||
Returns:
|
||||
URL or path to access the file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_presigned_url(
|
||||
self,
|
||||
remote_path: str,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Generate a pre-signed URL for temporary access.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
expires_in_seconds: URL validity duration (default 1 hour).
|
||||
|
||||
Returns:
|
||||
Pre-signed URL string.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
PresignedUrlNotSupportedError: If backend doesn't support pre-signed URLs.
|
||||
"""
|
||||
pass
|
||||
|
||||
def upload_bytes(
|
||||
self, data: bytes, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload bytes directly to storage.
|
||||
|
||||
Default implementation writes to temp file then uploads.
|
||||
Subclasses may override for more efficient implementation.
|
||||
|
||||
Args:
|
||||
data: Bytes to upload.
|
||||
remote_path: Destination path in storage.
|
||||
overwrite: If True, overwrite existing file.
|
||||
|
||||
Returns:
|
||||
The remote path where the data was stored.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(data)
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
return self.upload(temp_path, remote_path, overwrite=overwrite)
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
|
||||
def download_bytes(self, remote_path: str) -> bytes:
|
||||
"""Download a file as bytes.
|
||||
|
||||
Default implementation downloads to temp file then reads.
|
||||
Subclasses may override for more efficient implementation.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
|
||||
Returns:
|
||||
The file contents as bytes.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
self.download(remote_path, temp_path)
|
||||
return temp_path.read_bytes()
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
242
packages/shared/shared/storage/config_loader.py
Normal file
242
packages/shared/shared/storage/config_loader.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Configuration file loader for storage backends.
|
||||
|
||||
Supports YAML configuration files with environment variable substitution.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocalConfig:
|
||||
"""Local storage backend configuration."""
|
||||
|
||||
base_path: Path
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AzureConfig:
|
||||
"""Azure Blob Storage configuration."""
|
||||
|
||||
connection_string: str
|
||||
container_name: str
|
||||
create_container: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class S3Config:
|
||||
"""AWS S3 configuration."""
|
||||
|
||||
bucket_name: str
|
||||
region_name: str | None = None
|
||||
access_key_id: str | None = None
|
||||
secret_access_key: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
create_bucket: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageFileConfig:
|
||||
"""Extended storage configuration from file.
|
||||
|
||||
Attributes:
|
||||
backend_type: Type of storage backend.
|
||||
local: Local backend configuration.
|
||||
azure: Azure Blob configuration.
|
||||
s3: S3 configuration.
|
||||
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
|
||||
"""
|
||||
|
||||
backend_type: str
|
||||
local: LocalConfig | None = None
|
||||
azure: AzureConfig | None = None
|
||||
s3: S3Config | None = None
|
||||
presigned_url_expiry: int = 3600
|
||||
|
||||
|
||||
def substitute_env_vars(value: str) -> str:
|
||||
"""Substitute environment variables in a string.
|
||||
|
||||
Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax.
|
||||
|
||||
Args:
|
||||
value: String potentially containing env var references.
|
||||
|
||||
Returns:
|
||||
String with env vars substituted.
|
||||
"""
|
||||
pattern = r"\$\{([A-Z_][A-Z0-9_]*)(?::-([^}]*))?\}"
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
var_name = match.group(1)
|
||||
default = match.group(2)
|
||||
return os.environ.get(var_name, default or "")
|
||||
|
||||
return re.sub(pattern, replace, value)
|
||||
|
||||
|
||||
def _substitute_in_dict(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively substitute env vars in a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary to process.
|
||||
|
||||
Returns:
|
||||
New dictionary with substitutions applied.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
result[key] = substitute_env_vars(value)
|
||||
elif isinstance(value, dict):
|
||||
result[key] = _substitute_in_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [
|
||||
substitute_env_vars(item) if isinstance(item, str) else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def _parse_local_config(data: dict[str, Any]) -> LocalConfig:
|
||||
"""Parse local configuration section.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing local config.
|
||||
|
||||
Returns:
|
||||
LocalConfig instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing.
|
||||
"""
|
||||
base_path = data.get("base_path")
|
||||
if not base_path:
|
||||
raise ValueError("local.base_path is required")
|
||||
return LocalConfig(base_path=Path(base_path))
|
||||
|
||||
|
||||
def _parse_azure_config(data: dict[str, Any]) -> AzureConfig:
|
||||
"""Parse Azure configuration section.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing Azure config.
|
||||
|
||||
Returns:
|
||||
AzureConfig instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing.
|
||||
"""
|
||||
connection_string = data.get("connection_string")
|
||||
container_name = data.get("container_name")
|
||||
|
||||
if not connection_string:
|
||||
raise ValueError("azure.connection_string is required")
|
||||
if not container_name:
|
||||
raise ValueError("azure.container_name is required")
|
||||
|
||||
return AzureConfig(
|
||||
connection_string=connection_string,
|
||||
container_name=container_name,
|
||||
create_container=data.get("create_container", False),
|
||||
)
|
||||
|
||||
|
||||
def _parse_s3_config(data: dict[str, Any]) -> S3Config:
|
||||
"""Parse S3 configuration section.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing S3 config.
|
||||
|
||||
Returns:
|
||||
S3Config instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing.
|
||||
"""
|
||||
bucket_name = data.get("bucket_name")
|
||||
|
||||
if not bucket_name:
|
||||
raise ValueError("s3.bucket_name is required")
|
||||
|
||||
return S3Config(
|
||||
bucket_name=bucket_name,
|
||||
region_name=data.get("region_name"),
|
||||
access_key_id=data.get("access_key_id"),
|
||||
secret_access_key=data.get("secret_access_key"),
|
||||
endpoint_url=data.get("endpoint_url"),
|
||||
create_bucket=data.get("create_bucket", False),
|
||||
)
|
||||
|
||||
|
||||
def load_storage_config(config_path: Path | str) -> StorageFileConfig:
|
||||
"""Load storage configuration from YAML file.
|
||||
|
||||
Supports environment variable substitution using ${VAR_NAME} or
|
||||
${VAR_NAME:-default} syntax.
|
||||
|
||||
Args:
|
||||
config_path: Path to configuration file.
|
||||
|
||||
Returns:
|
||||
Parsed StorageFileConfig.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file doesn't exist.
|
||||
ValueError: If config is invalid.
|
||||
"""
|
||||
config_path = Path(config_path)
|
||||
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
try:
|
||||
raw_content = config_path.read_text(encoding="utf-8")
|
||||
data = yaml.safe_load(raw_content)
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML in config file: {e}") from e
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Config file must contain a YAML dictionary")
|
||||
|
||||
# Substitute environment variables
|
||||
data = _substitute_in_dict(data)
|
||||
|
||||
# Extract backend type
|
||||
backend_type = data.get("backend")
|
||||
if not backend_type:
|
||||
raise ValueError("'backend' field is required in config file")
|
||||
|
||||
# Parse presigned URL expiry
|
||||
presigned_url_expiry = data.get("presigned_url_expiry", 3600)
|
||||
|
||||
# Parse backend-specific configurations
|
||||
local_config = None
|
||||
azure_config = None
|
||||
s3_config = None
|
||||
|
||||
if "local" in data:
|
||||
local_config = _parse_local_config(data["local"])
|
||||
|
||||
if "azure" in data:
|
||||
azure_config = _parse_azure_config(data["azure"])
|
||||
|
||||
if "s3" in data:
|
||||
s3_config = _parse_s3_config(data["s3"])
|
||||
|
||||
return StorageFileConfig(
|
||||
backend_type=backend_type,
|
||||
local=local_config,
|
||||
azure=azure_config,
|
||||
s3=s3_config,
|
||||
presigned_url_expiry=presigned_url_expiry,
|
||||
)
|
||||
296
packages/shared/shared/storage/factory.py
Normal file
296
packages/shared/shared/storage/factory.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Factory functions for creating storage backends.
|
||||
|
||||
Provides convenient functions for creating storage backends from
|
||||
configuration or environment variables.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from shared.storage.base import StorageBackend, StorageConfig
|
||||
|
||||
|
||||
def create_storage_backend(config: StorageConfig) -> StorageBackend:
|
||||
"""Create a storage backend from configuration.
|
||||
|
||||
Args:
|
||||
config: Storage configuration.
|
||||
|
||||
Returns:
|
||||
A configured storage backend.
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid.
|
||||
"""
|
||||
if config.backend_type == "local":
|
||||
if config.base_path is None:
|
||||
raise ValueError("base_path is required for local storage backend")
|
||||
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
return LocalStorageBackend(base_path=config.base_path)
|
||||
|
||||
elif config.backend_type == "azure_blob":
|
||||
if config.connection_string is None:
|
||||
raise ValueError(
|
||||
"connection_string is required for Azure blob storage backend"
|
||||
)
|
||||
if config.container_name is None:
|
||||
raise ValueError(
|
||||
"container_name is required for Azure blob storage backend"
|
||||
)
|
||||
|
||||
# Import here to allow lazy loading of Azure SDK
|
||||
from azure.storage.blob import BlobServiceClient # noqa: F401
|
||||
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
return AzureBlobStorageBackend(
|
||||
connection_string=config.connection_string,
|
||||
container_name=config.container_name,
|
||||
)
|
||||
|
||||
elif config.backend_type == "s3":
|
||||
if config.bucket_name is None:
|
||||
raise ValueError("bucket_name is required for S3 storage backend")
|
||||
|
||||
# Import here to allow lazy loading of boto3
|
||||
import boto3 # noqa: F401
|
||||
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
return S3StorageBackend(
|
||||
bucket_name=config.bucket_name,
|
||||
region_name=config.region_name,
|
||||
access_key_id=config.access_key_id,
|
||||
secret_access_key=config.secret_access_key,
|
||||
endpoint_url=config.endpoint_url,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown storage backend type: {config.backend_type}")
|
||||
|
||||
|
||||
def get_default_storage_config() -> StorageConfig:
|
||||
"""Get storage configuration from environment variables.
|
||||
|
||||
Environment variables:
|
||||
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
|
||||
STORAGE_BASE_PATH: Base path for local storage.
|
||||
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
|
||||
AZURE_STORAGE_CONTAINER: Azure container name.
|
||||
AWS_S3_BUCKET: S3 bucket name.
|
||||
AWS_REGION: AWS region name.
|
||||
AWS_ACCESS_KEY_ID: AWS access key ID.
|
||||
AWS_SECRET_ACCESS_KEY: AWS secret access key.
|
||||
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
|
||||
|
||||
Returns:
|
||||
StorageConfig from environment.
|
||||
"""
|
||||
backend_type = os.environ.get("STORAGE_BACKEND", "local")
|
||||
|
||||
if backend_type == "local":
|
||||
base_path_str = os.environ.get("STORAGE_BASE_PATH")
|
||||
# Expand ~ to home directory
|
||||
base_path = Path(os.path.expanduser(base_path_str)) if base_path_str else None
|
||||
|
||||
return StorageConfig(
|
||||
backend_type="local",
|
||||
base_path=base_path,
|
||||
)
|
||||
|
||||
elif backend_type == "azure_blob":
|
||||
return StorageConfig(
|
||||
backend_type="azure_blob",
|
||||
connection_string=os.environ.get("AZURE_STORAGE_CONNECTION_STRING"),
|
||||
container_name=os.environ.get("AZURE_STORAGE_CONTAINER"),
|
||||
)
|
||||
|
||||
elif backend_type == "s3":
|
||||
return StorageConfig(
|
||||
backend_type="s3",
|
||||
bucket_name=os.environ.get("AWS_S3_BUCKET"),
|
||||
region_name=os.environ.get("AWS_REGION"),
|
||||
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
endpoint_url=os.environ.get("AWS_ENDPOINT_URL"),
|
||||
)
|
||||
|
||||
else:
|
||||
return StorageConfig(backend_type=backend_type)
|
||||
|
||||
|
||||
def create_storage_backend_from_env() -> StorageBackend:
|
||||
"""Create a storage backend from environment variables.
|
||||
|
||||
Environment variables:
|
||||
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
|
||||
STORAGE_BASE_PATH: Base path for local storage.
|
||||
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
|
||||
AZURE_STORAGE_CONTAINER: Azure container name.
|
||||
AWS_S3_BUCKET: S3 bucket name.
|
||||
AWS_REGION: AWS region name.
|
||||
AWS_ACCESS_KEY_ID: AWS access key ID.
|
||||
AWS_SECRET_ACCESS_KEY: AWS secret access key.
|
||||
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
|
||||
|
||||
Returns:
|
||||
A configured storage backend.
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are missing or empty.
|
||||
"""
|
||||
backend_type = os.environ.get("STORAGE_BACKEND", "local").strip()
|
||||
|
||||
if backend_type == "local":
|
||||
base_path = os.environ.get("STORAGE_BASE_PATH", "").strip()
|
||||
if not base_path:
|
||||
raise ValueError(
|
||||
"STORAGE_BASE_PATH environment variable is required and cannot be empty"
|
||||
)
|
||||
|
||||
# Expand ~ to home directory
|
||||
base_path_expanded = os.path.expanduser(base_path)
|
||||
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
return LocalStorageBackend(base_path=Path(base_path_expanded))
|
||||
|
||||
elif backend_type == "azure_blob":
|
||||
connection_string = os.environ.get(
|
||||
"AZURE_STORAGE_CONNECTION_STRING", ""
|
||||
).strip()
|
||||
if not connection_string:
|
||||
raise ValueError(
|
||||
"AZURE_STORAGE_CONNECTION_STRING environment variable is required "
|
||||
"and cannot be empty"
|
||||
)
|
||||
|
||||
container_name = os.environ.get("AZURE_STORAGE_CONTAINER", "").strip()
|
||||
if not container_name:
|
||||
raise ValueError(
|
||||
"AZURE_STORAGE_CONTAINER environment variable is required "
|
||||
"and cannot be empty"
|
||||
)
|
||||
|
||||
# Import here to allow lazy loading of Azure SDK
|
||||
from azure.storage.blob import BlobServiceClient # noqa: F401
|
||||
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
return AzureBlobStorageBackend(
|
||||
connection_string=connection_string,
|
||||
container_name=container_name,
|
||||
)
|
||||
|
||||
elif backend_type == "s3":
|
||||
bucket_name = os.environ.get("AWS_S3_BUCKET", "").strip()
|
||||
if not bucket_name:
|
||||
raise ValueError(
|
||||
"AWS_S3_BUCKET environment variable is required and cannot be empty"
|
||||
)
|
||||
|
||||
# Import here to allow lazy loading of boto3
|
||||
import boto3 # noqa: F401
|
||||
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
return S3StorageBackend(
|
||||
bucket_name=bucket_name,
|
||||
region_name=os.environ.get("AWS_REGION", "").strip() or None,
|
||||
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "").strip() or None,
|
||||
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "").strip()
|
||||
or None,
|
||||
endpoint_url=os.environ.get("AWS_ENDPOINT_URL", "").strip() or None,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown storage backend type: {backend_type}")
|
||||
|
||||
|
||||
def create_storage_backend_from_file(config_path: Path | str) -> StorageBackend:
|
||||
"""Create a storage backend from a configuration file.
|
||||
|
||||
Args:
|
||||
config_path: Path to YAML configuration file.
|
||||
|
||||
Returns:
|
||||
A configured storage backend.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file doesn't exist.
|
||||
ValueError: If configuration is invalid.
|
||||
"""
|
||||
from shared.storage.config_loader import load_storage_config
|
||||
|
||||
file_config = load_storage_config(config_path)
|
||||
|
||||
if file_config.backend_type == "local":
|
||||
if file_config.local is None:
|
||||
raise ValueError("local configuration section is required")
|
||||
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
return LocalStorageBackend(base_path=file_config.local.base_path)
|
||||
|
||||
elif file_config.backend_type == "azure_blob":
|
||||
if file_config.azure is None:
|
||||
raise ValueError("azure configuration section is required")
|
||||
|
||||
# Import here to allow lazy loading of Azure SDK
|
||||
from azure.storage.blob import BlobServiceClient # noqa: F401
|
||||
|
||||
from shared.storage.azure import AzureBlobStorageBackend
|
||||
|
||||
return AzureBlobStorageBackend(
|
||||
connection_string=file_config.azure.connection_string,
|
||||
container_name=file_config.azure.container_name,
|
||||
create_container=file_config.azure.create_container,
|
||||
)
|
||||
|
||||
elif file_config.backend_type == "s3":
|
||||
if file_config.s3 is None:
|
||||
raise ValueError("s3 configuration section is required")
|
||||
|
||||
# Import here to allow lazy loading of boto3
|
||||
import boto3 # noqa: F401
|
||||
|
||||
from shared.storage.s3 import S3StorageBackend
|
||||
|
||||
return S3StorageBackend(
|
||||
bucket_name=file_config.s3.bucket_name,
|
||||
region_name=file_config.s3.region_name,
|
||||
access_key_id=file_config.s3.access_key_id,
|
||||
secret_access_key=file_config.s3.secret_access_key,
|
||||
endpoint_url=file_config.s3.endpoint_url,
|
||||
create_bucket=file_config.s3.create_bucket,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown storage backend type: {file_config.backend_type}")
|
||||
|
||||
|
||||
def get_storage_backend(config_path: Path | str | None = None) -> StorageBackend:
|
||||
"""Get storage backend with fallback chain.
|
||||
|
||||
Priority:
|
||||
1. Config file (if provided)
|
||||
2. Environment variables
|
||||
|
||||
Args:
|
||||
config_path: Optional path to config file.
|
||||
|
||||
Returns:
|
||||
A configured storage backend.
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid.
|
||||
FileNotFoundError: If specified config file doesn't exist.
|
||||
"""
|
||||
if config_path:
|
||||
return create_storage_backend_from_file(config_path)
|
||||
|
||||
# Fall back to environment variables
|
||||
return create_storage_backend_from_env()
|
||||
262
packages/shared/shared/storage/local.py
Normal file
262
packages/shared/shared/storage/local.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Local filesystem storage backend.
|
||||
|
||||
Provides storage operations using the local filesystem.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from shared.storage.base import (
|
||||
FileNotFoundStorageError,
|
||||
StorageBackend,
|
||||
StorageError,
|
||||
)
|
||||
|
||||
|
||||
class LocalStorageBackend(StorageBackend):
|
||||
"""Storage backend using local filesystem.
|
||||
|
||||
Files are stored relative to a base path on the local filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, base_path: str | Path) -> None:
|
||||
"""Initialize local storage backend.
|
||||
|
||||
Args:
|
||||
base_path: Base directory for all storage operations.
|
||||
Will be created if it doesn't exist.
|
||||
"""
|
||||
self._base_path = Path(base_path)
|
||||
self._base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def base_path(self) -> Path:
|
||||
"""Get the base path for this storage backend."""
|
||||
return self._base_path
|
||||
|
||||
def _get_full_path(self, remote_path: str) -> Path:
|
||||
"""Convert a remote path to a full local path with security validation.
|
||||
|
||||
Args:
|
||||
remote_path: The remote path to resolve.
|
||||
|
||||
Returns:
|
||||
The full local path.
|
||||
|
||||
Raises:
|
||||
StorageError: If the path attempts to escape the base directory.
|
||||
"""
|
||||
# Reject absolute paths
|
||||
if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"):
|
||||
raise StorageError(f"Absolute paths not allowed: {remote_path}")
|
||||
|
||||
# Resolve to prevent path traversal attacks
|
||||
full_path = (self._base_path / remote_path).resolve()
|
||||
base_resolved = self._base_path.resolve()
|
||||
|
||||
# Verify the resolved path is within base_path
|
||||
try:
|
||||
full_path.relative_to(base_resolved)
|
||||
except ValueError:
|
||||
raise StorageError(f"Path traversal not allowed: {remote_path}")
|
||||
|
||||
return full_path
|
||||
|
||||
def upload(
|
||||
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload a file to local storage.
|
||||
|
||||
Args:
|
||||
local_path: Path to the local file to upload.
|
||||
remote_path: Destination path in storage.
|
||||
overwrite: If True, overwrite existing file.
|
||||
|
||||
Returns:
|
||||
The remote path where the file was stored.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If local_path doesn't exist.
|
||||
StorageError: If file exists and overwrite is False.
|
||||
"""
|
||||
if not local_path.exists():
|
||||
raise FileNotFoundStorageError(str(local_path))
|
||||
|
||||
dest_path = self._get_full_path(remote_path)
|
||||
|
||||
if dest_path.exists() and not overwrite:
|
||||
raise StorageError(f"File already exists: {remote_path}")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(local_path, dest_path)
|
||||
|
||||
return remote_path
|
||||
|
||||
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||
"""Download a file from local storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
local_path: Local destination path.
|
||||
|
||||
Returns:
|
||||
The local path where the file was downloaded.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
source_path = self._get_full_path(remote_path)
|
||||
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_path, local_path)
|
||||
|
||||
return local_path
|
||||
|
||||
def exists(self, remote_path: str) -> bool:
|
||||
"""Check if a file exists in storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to check in storage.
|
||||
|
||||
Returns:
|
||||
True if the file exists, False otherwise.
|
||||
"""
|
||||
return self._get_full_path(remote_path).exists()
|
||||
|
||||
def list_files(self, prefix: str) -> list[str]:
|
||||
"""List files in storage with given prefix.
|
||||
|
||||
Args:
|
||||
prefix: Path prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Sorted list of file paths matching the prefix.
|
||||
"""
|
||||
if prefix:
|
||||
search_path = self._get_full_path(prefix)
|
||||
if not search_path.exists():
|
||||
return []
|
||||
base_for_relative = self._base_path
|
||||
else:
|
||||
search_path = self._base_path
|
||||
base_for_relative = self._base_path
|
||||
|
||||
files: list[str] = []
|
||||
if search_path.is_file():
|
||||
files.append(str(search_path.relative_to(self._base_path)))
|
||||
elif search_path.is_dir():
|
||||
for file_path in search_path.rglob("*"):
|
||||
if file_path.is_file():
|
||||
relative_path = file_path.relative_to(self._base_path)
|
||||
files.append(str(relative_path).replace("\\", "/"))
|
||||
|
||||
return sorted(files)
|
||||
|
||||
def delete(self, remote_path: str) -> bool:
|
||||
"""Delete a file from storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file to delete.
|
||||
|
||||
Returns:
|
||||
True if file was deleted, False if it didn't exist.
|
||||
"""
|
||||
file_path = self._get_full_path(remote_path)
|
||||
|
||||
if not file_path.exists():
|
||||
return False
|
||||
|
||||
file_path.unlink()
|
||||
return True
|
||||
|
||||
def get_url(self, remote_path: str) -> str:
|
||||
"""Get a file:// URL to access a file.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
|
||||
Returns:
|
||||
file:// URL to access the file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
file_path = self._get_full_path(remote_path)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
return file_path.as_uri()
|
||||
|
||||
def upload_bytes(
|
||||
self, data: bytes, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload bytes directly to storage.
|
||||
|
||||
Args:
|
||||
data: Bytes to upload.
|
||||
remote_path: Destination path in storage.
|
||||
overwrite: If True, overwrite existing file.
|
||||
|
||||
Returns:
|
||||
The remote path where the data was stored.
|
||||
"""
|
||||
dest_path = self._get_full_path(remote_path)
|
||||
|
||||
if dest_path.exists() and not overwrite:
|
||||
raise StorageError(f"File already exists: {remote_path}")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest_path.write_bytes(data)
|
||||
|
||||
return remote_path
|
||||
|
||||
def download_bytes(self, remote_path: str) -> bytes:
|
||||
"""Download a file as bytes.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
|
||||
Returns:
|
||||
The file contents as bytes.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
file_path = self._get_full_path(remote_path)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
return file_path.read_bytes()
|
||||
|
||||
def get_presigned_url(
|
||||
self,
|
||||
remote_path: str,
|
||||
expires_in_seconds: int = 3600,
|
||||
) -> str:
|
||||
"""Get a file:// URL for local file access.
|
||||
|
||||
For local storage, this returns a file:// URI.
|
||||
Note: Local file:// URLs don't actually expire.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file in storage.
|
||||
expires_in_seconds: Ignored for local storage (URLs don't expire).
|
||||
|
||||
Returns:
|
||||
file:// URL to access the file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
file_path = self._get_full_path(remote_path)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
return file_path.as_uri()
|
||||
158
packages/shared/shared/storage/prefixes.py
Normal file
158
packages/shared/shared/storage/prefixes.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Storage path prefixes for unified file organization.
|
||||
|
||||
Provides standardized path prefixes for organizing files within
|
||||
the storage backend, ensuring consistent structure across
|
||||
local, Azure Blob, and S3 storage.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StoragePrefixes:
|
||||
"""Standardized storage path prefixes.
|
||||
|
||||
All paths are relative to the storage backend root.
|
||||
These prefixes ensure consistent file organization across
|
||||
all storage backends (local, Azure, S3).
|
||||
|
||||
Usage:
|
||||
from shared.storage.prefixes import PREFIXES
|
||||
|
||||
path = f"{PREFIXES.DOCUMENTS}/{document_id}.pdf"
|
||||
storage.upload_bytes(content, path)
|
||||
"""
|
||||
|
||||
# Document storage
|
||||
DOCUMENTS: str = "documents"
|
||||
"""Original document files (PDFs, etc.)"""
|
||||
|
||||
IMAGES: str = "images"
|
||||
"""Page images extracted from documents"""
|
||||
|
||||
# Processing directories
|
||||
UPLOADS: str = "uploads"
|
||||
"""Temporary upload staging area"""
|
||||
|
||||
RESULTS: str = "results"
|
||||
"""Inference results and visualizations"""
|
||||
|
||||
EXPORTS: str = "exports"
|
||||
"""Exported datasets and annotations"""
|
||||
|
||||
# Training data
|
||||
DATASETS: str = "datasets"
|
||||
"""Training dataset files"""
|
||||
|
||||
MODELS: str = "models"
|
||||
"""Trained model weights and checkpoints"""
|
||||
|
||||
# Data pipeline directories (legacy compatibility)
|
||||
RAW_PDFS: str = "raw_pdfs"
|
||||
"""Raw PDF files for auto-labeling pipeline"""
|
||||
|
||||
STRUCTURED_DATA: str = "structured_data"
|
||||
"""CSV/structured data for matching"""
|
||||
|
||||
ADMIN_IMAGES: str = "admin_images"
|
||||
"""Admin UI page images"""
|
||||
|
||||
@staticmethod
|
||||
def document_path(document_id: str, extension: str = ".pdf") -> str:
|
||||
"""Get path for a document file.
|
||||
|
||||
Args:
|
||||
document_id: Unique document identifier.
|
||||
extension: File extension (include leading dot).
|
||||
|
||||
Returns:
|
||||
Storage path like "documents/abc123.pdf"
|
||||
"""
|
||||
ext = extension if extension.startswith(".") else f".{extension}"
|
||||
return f"{PREFIXES.DOCUMENTS}/{document_id}{ext}"
|
||||
|
||||
@staticmethod
|
||||
def image_path(document_id: str, page_num: int, extension: str = ".png") -> str:
|
||||
"""Get path for a page image file.
|
||||
|
||||
Args:
|
||||
document_id: Unique document identifier.
|
||||
page_num: Page number (1-indexed).
|
||||
extension: File extension (include leading dot).
|
||||
|
||||
Returns:
|
||||
Storage path like "images/abc123/page_1.png"
|
||||
"""
|
||||
ext = extension if extension.startswith(".") else f".{extension}"
|
||||
return f"{PREFIXES.IMAGES}/{document_id}/page_{page_num}{ext}"
|
||||
|
||||
@staticmethod
|
||||
def upload_path(filename: str, subfolder: str | None = None) -> str:
|
||||
"""Get path for a temporary upload file.
|
||||
|
||||
Args:
|
||||
filename: Original filename.
|
||||
subfolder: Optional subfolder (e.g., "async").
|
||||
|
||||
Returns:
|
||||
Storage path like "uploads/filename.pdf" or "uploads/async/filename.pdf"
|
||||
"""
|
||||
if subfolder:
|
||||
return f"{PREFIXES.UPLOADS}/{subfolder}/{filename}"
|
||||
return f"{PREFIXES.UPLOADS}/{filename}"
|
||||
|
||||
@staticmethod
|
||||
def result_path(filename: str) -> str:
|
||||
"""Get path for a result file.
|
||||
|
||||
Args:
|
||||
filename: Result filename.
|
||||
|
||||
Returns:
|
||||
Storage path like "results/filename.json"
|
||||
"""
|
||||
return f"{PREFIXES.RESULTS}/{filename}"
|
||||
|
||||
@staticmethod
|
||||
def export_path(export_id: str, filename: str) -> str:
|
||||
"""Get path for an export file.
|
||||
|
||||
Args:
|
||||
export_id: Unique export identifier.
|
||||
filename: Export filename.
|
||||
|
||||
Returns:
|
||||
Storage path like "exports/abc123/filename.zip"
|
||||
"""
|
||||
return f"{PREFIXES.EXPORTS}/{export_id}/{filename}"
|
||||
|
||||
@staticmethod
|
||||
def dataset_path(dataset_id: str, filename: str) -> str:
|
||||
"""Get path for a dataset file.
|
||||
|
||||
Args:
|
||||
dataset_id: Unique dataset identifier.
|
||||
filename: Dataset filename.
|
||||
|
||||
Returns:
|
||||
Storage path like "datasets/abc123/filename.yaml"
|
||||
"""
|
||||
return f"{PREFIXES.DATASETS}/{dataset_id}/{filename}"
|
||||
|
||||
@staticmethod
|
||||
def model_path(version: str, filename: str) -> str:
|
||||
"""Get path for a model file.
|
||||
|
||||
Args:
|
||||
version: Model version string.
|
||||
filename: Model filename.
|
||||
|
||||
Returns:
|
||||
Storage path like "models/v1.0.0/best.pt"
|
||||
"""
|
||||
return f"{PREFIXES.MODELS}/{version}/{filename}"
|
||||
|
||||
|
||||
# Default instance for convenient access
|
||||
PREFIXES = StoragePrefixes()
|
||||
309
packages/shared/shared/storage/s3.py
Normal file
309
packages/shared/shared/storage/s3.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
AWS S3 Storage backend.
|
||||
|
||||
Provides storage operations using AWS S3.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_s3 import S3Client
|
||||
|
||||
from shared.storage.base import (
|
||||
FileNotFoundStorageError,
|
||||
StorageBackend,
|
||||
StorageError,
|
||||
)
|
||||
|
||||
|
||||
class S3StorageBackend(StorageBackend):
|
||||
"""Storage backend using AWS S3.
|
||||
|
||||
Files are stored as objects in an S3 bucket.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bucket_name: str,
|
||||
region_name: str | None = None,
|
||||
access_key_id: str | None = None,
|
||||
secret_access_key: str | None = None,
|
||||
endpoint_url: str | None = None,
|
||||
create_bucket: bool = False,
|
||||
) -> None:
|
||||
"""Initialize S3 storage backend.
|
||||
|
||||
Args:
|
||||
bucket_name: Name of the S3 bucket.
|
||||
region_name: AWS region name (optional, uses default if not set).
|
||||
access_key_id: AWS access key ID (optional, uses credentials chain).
|
||||
secret_access_key: AWS secret access key (optional).
|
||||
endpoint_url: Custom endpoint URL (for S3-compatible services).
|
||||
create_bucket: If True, create the bucket if it doesn't exist.
|
||||
"""
|
||||
import boto3
|
||||
|
||||
self._bucket_name = bucket_name
|
||||
self._region_name = region_name
|
||||
|
||||
# Build client kwargs
|
||||
client_kwargs: dict[str, Any] = {}
|
||||
if region_name:
|
||||
client_kwargs["region_name"] = region_name
|
||||
if endpoint_url:
|
||||
client_kwargs["endpoint_url"] = endpoint_url
|
||||
if access_key_id and secret_access_key:
|
||||
client_kwargs["aws_access_key_id"] = access_key_id
|
||||
client_kwargs["aws_secret_access_key"] = secret_access_key
|
||||
|
||||
self._s3: "S3Client" = boto3.client("s3", **client_kwargs)
|
||||
|
||||
if create_bucket:
|
||||
self._ensure_bucket_exists()
|
||||
|
||||
def _ensure_bucket_exists(self) -> None:
|
||||
"""Create the bucket if it doesn't exist."""
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
try:
|
||||
self._s3.head_bucket(Bucket=self._bucket_name)
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code", "")
|
||||
if error_code in ("404", "NoSuchBucket"):
|
||||
# Bucket doesn't exist, create it
|
||||
create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
|
||||
if self._region_name and self._region_name != "us-east-1":
|
||||
create_kwargs["CreateBucketConfiguration"] = {
|
||||
"LocationConstraint": self._region_name
|
||||
}
|
||||
self._s3.create_bucket(**create_kwargs)
|
||||
else:
|
||||
# Re-raise permission errors, network issues, etc.
|
||||
raise
|
||||
|
||||
def _object_exists(self, key: str) -> bool:
|
||||
"""Check if an object exists in S3.
|
||||
|
||||
Args:
|
||||
key: Object key to check.
|
||||
|
||||
Returns:
|
||||
True if object exists, False otherwise.
|
||||
"""
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
try:
|
||||
self._s3.head_object(Bucket=self._bucket_name, Key=key)
|
||||
return True
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code", "")
|
||||
if error_code in ("404", "NoSuchKey"):
|
||||
return False
|
||||
raise
|
||||
|
||||
@property
|
||||
def bucket_name(self) -> str:
|
||||
"""Get the bucket name for this storage backend."""
|
||||
return self._bucket_name
|
||||
|
||||
def upload(
|
||||
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||
) -> str:
|
||||
"""Upload a file to S3.
|
||||
|
||||
Args:
|
||||
local_path: Path to the local file to upload.
|
||||
remote_path: Destination object key.
|
||||
overwrite: If True, overwrite existing object.
|
||||
|
||||
Returns:
|
||||
The remote path where the file was stored.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If local_path doesn't exist.
|
||||
StorageError: If object exists and overwrite is False.
|
||||
"""
|
||||
if not local_path.exists():
|
||||
raise FileNotFoundStorageError(str(local_path))
|
||||
|
||||
if not overwrite and self._object_exists(remote_path):
|
||||
raise StorageError(f"File already exists: {remote_path}")
|
||||
|
||||
self._s3.upload_file(str(local_path), self._bucket_name, remote_path)
|
||||
|
||||
return remote_path
|
||||
|
||||
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||
"""Download an object from S3.
|
||||
|
||||
Args:
|
||||
remote_path: Object key in S3.
|
||||
local_path: Local destination path.
|
||||
|
||||
Returns:
|
||||
The local path where the file was downloaded.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
if not self._object_exists(remote_path):
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._s3.download_file(self._bucket_name, remote_path, str(local_path))
|
||||
|
||||
return local_path
|
||||
|
||||
def exists(self, remote_path: str) -> bool:
|
||||
"""Check if an object exists in S3.
|
||||
|
||||
Args:
|
||||
remote_path: Object key to check.
|
||||
|
||||
Returns:
|
||||
True if the object exists, False otherwise.
|
||||
"""
|
||||
return self._object_exists(remote_path)
|
||||
|
||||
def list_files(self, prefix: str) -> list[str]:
|
||||
"""List objects in S3 with given prefix.
|
||||
|
||||
Handles pagination to return all matching objects (S3 returns max 1000 per request).
|
||||
|
||||
Args:
|
||||
prefix: Object key prefix to filter.
|
||||
|
||||
Returns:
|
||||
List of object keys matching the prefix.
|
||||
"""
|
||||
kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
|
||||
if prefix:
|
||||
kwargs["Prefix"] = prefix
|
||||
|
||||
all_keys: list[str] = []
|
||||
while True:
|
||||
response = self._s3.list_objects_v2(**kwargs)
|
||||
contents = response.get("Contents", [])
|
||||
all_keys.extend(obj["Key"] for obj in contents)
|
||||
|
||||
if not response.get("IsTruncated"):
|
||||
break
|
||||
kwargs["ContinuationToken"] = response["NextContinuationToken"]
|
||||
|
||||
return all_keys
|
||||
|
||||
def delete(self, remote_path: str) -> bool:
|
||||
"""Delete an object from S3.
|
||||
|
||||
Args:
|
||||
remote_path: Object key to delete.
|
||||
|
||||
Returns:
|
||||
True if object was deleted, False if it didn't exist.
|
||||
"""
|
||||
if not self._object_exists(remote_path):
|
||||
return False
|
||||
|
||||
self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path)
|
||||
return True
|
||||
|
||||
def get_url(self, remote_path: str) -> str:
|
||||
"""Get a URL for an object.
|
||||
|
||||
Args:
|
||||
remote_path: Object key in S3.
|
||||
|
||||
Returns:
|
||||
URL to access the object.
|
||||
|
||||
Raises:
|
||||
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||
"""
|
||||
if not self._object_exists(remote_path):
|
||||
raise FileNotFoundStorageError(remote_path)
|
||||
|
||||
return self._s3.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={"Bucket": self._bucket_name, "Key": remote_path},
|
||||
ExpiresIn=3600,
|
||||
)
|
||||
|
||||
    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a pre-signed URL for temporary object access.

        Args:
            remote_path: Object key in S3.
            expires_in_seconds: URL validity duration (1 to 604800 seconds / 7 days).

        Returns:
            Pre-signed URL string.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            ValueError: If expires_in_seconds is out of valid range.
        """
        if expires_in_seconds < 1 or expires_in_seconds > 604800:
            raise ValueError(
                "expires_in_seconds must be between 1 and 604800 (7 days)"
            )

        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)

        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket_name, "Key": remote_path},
            ExpiresIn=expires_in_seconds,
        )

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to S3.

        Args:
            data: Bytes to upload.
            remote_path: Destination object key.
            overwrite: If True, overwrite existing object.

        Returns:
            The remote path where the data was stored.

        Raises:
            StorageError: If object exists and overwrite is False.
        """
        if not overwrite and self._object_exists(remote_path):
            raise StorageError(f"File already exists: {remote_path}")

        self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data)

        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        """Download an object as bytes.

        Args:
            remote_path: Object key in S3.

        Returns:
            The object contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        from botocore.exceptions import ClientError

        try:
            response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path)
            return response["Body"].read()
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchKey"):
                raise FileNotFoundStorageError(remote_path) from e
            raise

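For context, a minimal usage sketch of this storage backend. It only exercises the methods shown above; the `demo` helper and the assumption that `upload` takes `(local_path, remote_path)` like `upload_bytes` are illustrative, not taken from the diff, and the constructor is deliberately left out because it is not part of this hunk:

```python
# Sketch only: "storage" is an already-constructed instance of the S3 backend above.
from pathlib import Path


def demo(storage) -> None:
    # Upload a local file; raises StorageError if the key already exists.
    key = storage.upload(Path("invoice.pdf"), "invoices/2024/invoice.pdf")

    # Temporary link; S3 itself caps validity at 7 days (604800 s).
    url = storage.get_presigned_url(key, expires_in_seconds=900)
    print(url)

    # Round-trip a small payload without touching disk.
    storage.upload_bytes(b"hello", "tmp/hello.txt", overwrite=True)
    assert storage.download_bytes("tmp/hello.txt") == b"hello"

    # delete() reports whether anything was actually removed.
    assert storage.delete("tmp/hello.txt") is True
    assert storage.delete("tmp/hello.txt") is False
```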
@@ -20,7 +20,7 @@ from shared.config import get_db_connection_string
from shared.normalize import normalize_field
from shared.matcher import FieldMatcher
from shared.pdf import is_text_pdf, extract_text_tokens
from training.yolo.annotation_generator import FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from shared.data.db import DocumentDB

@@ -113,7 +113,7 @@ def process_single_document(args_tuple):
    # Import inside worker to avoid pickling issues
    from training.data.autolabel_report import AutoLabelReport
    from shared.pdf import PDFDocument
    from training.yolo.annotation_generator import FIELD_CLASSES
    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
    from training.processing.document_processor import process_page, record_unmatched_fields

    start_time = time.time()

@@ -342,7 +342,8 @@ def main():
    from shared.ocr import OCREngine
    from shared.matcher import FieldMatcher
    from shared.normalize import normalize_field
    from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
    from training.yolo.annotation_generator import AnnotationGenerator

    # Handle comma-separated CSV paths
    csv_input = args.csv

@@ -90,7 +90,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
    import shutil
    from training.data.autolabel_report import AutoLabelReport
    from shared.pdf import PDFDocument
    from training.yolo.annotation_generator import FIELD_CLASSES
    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
    from training.processing.document_processor import process_page, record_unmatched_fields

    row_dict = task_data["row_dict"]

@@ -208,7 +208,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
    import shutil
    from training.data.autolabel_report import AutoLabelReport
    from shared.pdf import PDFDocument
    from training.yolo.annotation_generator import FIELD_CLASSES
    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
    from training.processing.document_processor import process_page, record_unmatched_fields

    row_dict = task_data["row_dict"]

@@ -15,7 +15,8 @@ from training.data.autolabel_report import FieldMatchResult
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
from shared.ocr.machine_code_parser import MachineCodeParser
from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from training.yolo.annotation_generator import AnnotationGenerator


def match_supplier_accounts(

@@ -9,43 +9,12 @@ from pathlib import Path
from typing import Any
import csv


# Field class mapping for YOLO
# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
FIELD_CLASSES = {
    'InvoiceNumber': 0,
    'InvoiceDate': 1,
    'InvoiceDueDate': 2,
    'OCR': 3,
    'Bankgiro': 4,
    'Plusgiro': 5,
    'Amount': 6,
    'supplier_organisation_number': 7,
    'customer_number': 8,
    'payment_line': 9,  # Machine code payment line at bottom of invoice
}

# Fields that need matching but map to other YOLO classes
# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
ACCOUNT_FIELD_MAPPING = {
    'supplier_accounts': {
        'BG': 'Bankgiro',  # BG:xxx -> Bankgiro class
        'PG': 'Plusgiro',  # PG:xxx -> Plusgiro class
    }
}

CLASS_NAMES = [
    'invoice_number',
    'invoice_date',
    'invoice_due_date',
    'ocr_number',
    'bankgiro',
    'plusgiro',
    'amount',
    'supplier_org_number',
    'customer_number',
    'payment_line',  # Machine code payment line at bottom of invoice
]
# Import field mappings from single source of truth
from shared.fields import (
    TRAINING_FIELD_CLASSES as FIELD_CLASSES,
    CLASS_NAMES,
    ACCOUNT_FIELD_MAPPING,
)


@dataclass

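The hunk above removes the local constants in favour of shared.fields as the single source of truth. The shared module itself is not part of this commit; a plausible shape for it, assuming it simply re-exports the same mappings under the names imported here, would be:

```python
# shared/fields.py -- sketch only; the real module is not shown in this diff.
# Field class mapping for YOLO training (same values as the constants removed above).
TRAINING_FIELD_CLASSES: dict[str, int] = {
    'InvoiceNumber': 0,
    'InvoiceDate': 1,
    'InvoiceDueDate': 2,
    'OCR': 3,
    'Bankgiro': 4,
    'Plusgiro': 5,
    'Amount': 6,
    'supplier_organisation_number': 7,
    'customer_number': 8,
    'payment_line': 9,  # machine-code payment line at the bottom of the invoice
}

# supplier_accounts is not a YOLO class of its own; matches are routed to
# Bankgiro/Plusgiro depending on the account type prefix.
ACCOUNT_FIELD_MAPPING: dict[str, dict[str, str]] = {
    'supplier_accounts': {'BG': 'Bankgiro', 'PG': 'Plusgiro'},
}

# Human-readable class names, indexed by the class ids above.
CLASS_NAMES: list[str] = [
    'invoice_number', 'invoice_date', 'invoice_due_date', 'ocr_number',
    'bankgiro', 'plusgiro', 'amount', 'supplier_org_number',
    'customer_number', 'payment_line',
]
```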
@@ -101,7 +101,8 @@ class DatasetBuilder:
        Returns:
            DatasetStats with build results
        """
        from .annotation_generator import AnnotationGenerator, CLASS_NAMES
        from shared.fields import CLASS_NAMES
        from .annotation_generator import AnnotationGenerator

        random.seed(seed)

@@ -18,7 +18,8 @@ import numpy as np
from PIL import Image

from shared.config import DEFAULT_DPI
from .annotation_generator import FIELD_CLASSES, YOLOAnnotation
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
from .annotation_generator import YOLOAnnotation

logger = logging.getLogger(__name__)