WIP
This commit is contained in:
@@ -120,7 +120,7 @@ def main() -> None:
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Create config
|
||||
from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
|
||||
from inference.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig
|
||||
|
||||
config = AppConfig(
|
||||
model=ModelConfig(
|
||||
@@ -136,7 +136,7 @@ def main() -> None:
|
||||
reload=args.reload,
|
||||
workers=args.workers,
|
||||
),
|
||||
storage=StorageConfig(),
|
||||
file=FileConfig(),
|
||||
)
|
||||
|
||||
# Create and run app
|
||||
|
||||
@@ -112,6 +112,7 @@ class AdminDB:
|
||||
upload_source: str = "ui",
|
||||
csv_field_values: dict[str, Any] | None = None,
|
||||
group_key: str | None = None,
|
||||
category: str = "invoice",
|
||||
admin_token: str | None = None, # Deprecated, kept for compatibility
|
||||
) -> str:
|
||||
"""Create a new document record."""
|
||||
@@ -125,6 +126,7 @@ class AdminDB:
|
||||
upload_source=upload_source,
|
||||
csv_field_values=csv_field_values,
|
||||
group_key=group_key,
|
||||
category=category,
|
||||
)
|
||||
session.add(document)
|
||||
session.flush()
|
||||
@@ -154,6 +156,7 @@ class AdminDB:
|
||||
has_annotations: bool | None = None,
|
||||
auto_label_status: str | None = None,
|
||||
batch_id: str | None = None,
|
||||
category: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[AdminDocument], int]:
|
||||
@@ -171,6 +174,8 @@ class AdminDB:
|
||||
where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
|
||||
if batch_id:
|
||||
where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
|
||||
if category:
|
||||
where_clauses.append(AdminDocument.category == category)
|
||||
|
||||
# Count query
|
||||
count_stmt = select(func.count()).select_from(AdminDocument)
|
||||
@@ -283,6 +288,32 @@ class AdminDB:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_document_categories(self) -> list[str]:
|
||||
"""Get list of unique document categories."""
|
||||
with get_session_context() as session:
|
||||
statement = (
|
||||
select(AdminDocument.category)
|
||||
.distinct()
|
||||
.order_by(AdminDocument.category)
|
||||
)
|
||||
categories = session.exec(statement).all()
|
||||
return [c for c in categories if c is not None]
|
||||
|
||||
def update_document_category(
|
||||
self, document_id: str, category: str
|
||||
) -> AdminDocument | None:
|
||||
"""Update document category."""
|
||||
with get_session_context() as session:
|
||||
document = session.get(AdminDocument, UUID(document_id))
|
||||
if document:
|
||||
document.category = category
|
||||
document.updated_at = datetime.utcnow()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
session.refresh(document)
|
||||
return document
|
||||
return None
|
||||
|
||||
# ==========================================================================
|
||||
# Annotation Operations
|
||||
# ==========================================================================
|
||||
@@ -1292,6 +1323,36 @@ class AdminDB:
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
|
||||
def update_dataset_training_status(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
training_status: str | None,
|
||||
active_training_task_id: str | UUID | None = None,
|
||||
update_main_status: bool = False,
|
||||
) -> None:
|
||||
"""Update dataset training status and optionally the main status.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset UUID
|
||||
training_status: Training status (pending, running, completed, failed, cancelled)
|
||||
active_training_task_id: Currently active training task ID
|
||||
update_main_status: If True and training_status is 'completed', set main status to 'trained'
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if not dataset:
|
||||
return
|
||||
dataset.training_status = training_status
|
||||
dataset.active_training_task_id = (
|
||||
UUID(str(active_training_task_id)) if active_training_task_id else None
|
||||
)
|
||||
dataset.updated_at = datetime.utcnow()
|
||||
# Update main status to 'trained' when training completes
|
||||
if update_main_status and training_status == "completed":
|
||||
dataset.status = "trained"
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
|
||||
def add_dataset_documents(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
|
||||
@@ -11,23 +11,8 @@ from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, SQLModel, Column, JSON
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CSV to Field Class Mapping
|
||||
# =============================================================================
|
||||
|
||||
CSV_TO_CLASS_MAPPING: dict[str, int] = {
|
||||
"InvoiceNumber": 0, # invoice_number
|
||||
"InvoiceDate": 1, # invoice_date
|
||||
"InvoiceDueDate": 2, # invoice_due_date
|
||||
"OCR": 3, # ocr_number
|
||||
"Bankgiro": 4, # bankgiro
|
||||
"Plusgiro": 5, # plusgiro
|
||||
"Amount": 6, # amount
|
||||
"supplier_organisation_number": 7, # supplier_organisation_number
|
||||
# 8: payment_line (derived from OCR/Bankgiro/Amount)
|
||||
"customer_number": 9, # customer_number
|
||||
}
|
||||
# Import field mappings from single source of truth
|
||||
from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -72,6 +57,8 @@ class AdminDocument(SQLModel, table=True):
|
||||
# Link to batch upload (if uploaded via ZIP)
|
||||
group_key: str | None = Field(default=None, max_length=255, index=True)
|
||||
# User-defined grouping key for document organization
|
||||
category: str = Field(default="invoice", max_length=100, index=True)
|
||||
# Document category for training different models (e.g., invoice, letter, receipt)
|
||||
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# Original CSV values for reference
|
||||
auto_label_queued_at: datetime | None = Field(default=None)
|
||||
@@ -237,7 +224,10 @@ class TrainingDataset(SQLModel, table=True):
|
||||
name: str = Field(max_length=255)
|
||||
description: str | None = Field(default=None)
|
||||
status: str = Field(default="building", max_length=20, index=True)
|
||||
# Status: building, ready, training, archived, failed
|
||||
# Status: building, ready, trained, archived, failed
|
||||
training_status: str | None = Field(default=None, max_length=20, index=True)
|
||||
# Training status: pending, scheduled, running, completed, failed, cancelled
|
||||
active_training_task_id: UUID | None = Field(default=None, index=True)
|
||||
train_ratio: float = Field(default=0.8)
|
||||
val_ratio: float = Field(default=0.1)
|
||||
seed: int = Field(default=42)
|
||||
@@ -354,21 +344,8 @@ class AnnotationHistory(SQLModel, table=True):
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||
|
||||
|
||||
# Field class mapping (same as src/cli/train.py)
|
||||
FIELD_CLASSES = {
|
||||
0: "invoice_number",
|
||||
1: "invoice_date",
|
||||
2: "invoice_due_date",
|
||||
3: "ocr_number",
|
||||
4: "bankgiro",
|
||||
5: "plusgiro",
|
||||
6: "amount",
|
||||
7: "supplier_organisation_number",
|
||||
8: "payment_line",
|
||||
9: "customer_number",
|
||||
}
|
||||
|
||||
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
|
||||
# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields
|
||||
# This ensures consistency with the trained YOLO model
|
||||
|
||||
|
||||
# Read-only models for API responses
|
||||
@@ -383,6 +360,7 @@ class AdminDocumentRead(SQLModel):
|
||||
status: str
|
||||
auto_label_status: str | None
|
||||
auto_label_error: str | None
|
||||
category: str = "invoice"
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@@ -141,6 +141,40 @@ def run_migrations() -> None:
|
||||
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
|
||||
""",
|
||||
),
|
||||
# Migration 009: Add category to admin_documents
|
||||
(
|
||||
"admin_documents_category",
|
||||
"""
|
||||
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
|
||||
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
|
||||
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
|
||||
""",
|
||||
),
|
||||
# Migration 010: Add training_status and active_training_task_id to training_datasets
|
||||
(
|
||||
"training_datasets_training_status",
|
||||
"""
|
||||
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
|
||||
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status);
|
||||
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id);
|
||||
""",
|
||||
),
|
||||
# Migration 010b: Update existing datasets with completed training to 'trained' status
|
||||
(
|
||||
"training_datasets_update_trained_status",
|
||||
"""
|
||||
UPDATE training_datasets d
|
||||
SET status = 'trained'
|
||||
WHERE d.status = 'ready'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM training_tasks t
|
||||
WHERE t.dataset_id = d.dataset_id
|
||||
AND t.status = 'completed'
|
||||
);
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
||||
with engine.connect() as conn:
|
||||
|
||||
@@ -21,7 +21,8 @@ import re
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .yolo_detector import Detection, CLASS_TO_FIELD
|
||||
from shared.fields import CLASS_TO_FIELD
|
||||
from .yolo_detector import Detection
|
||||
|
||||
# Import shared utilities for text cleaning and validation
|
||||
from shared.utils.text_cleaner import TextCleaner
|
||||
|
||||
@@ -10,7 +10,8 @@ from typing import Any
|
||||
import time
|
||||
import re
|
||||
|
||||
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
|
||||
from shared.fields import CLASS_TO_FIELD
|
||||
from .yolo_detector import YOLODetector, Detection
|
||||
from .field_extractor import FieldExtractor, ExtractedField
|
||||
from .payment_line_parser import PaymentLineParser
|
||||
|
||||
|
||||
@@ -9,6 +9,9 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
|
||||
# Import field mappings from single source of truth
|
||||
from shared.fields import CLASS_NAMES, CLASS_TO_FIELD
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
@@ -72,33 +75,8 @@ class Detection:
|
||||
return (x0, y0, x1, y1)
|
||||
|
||||
|
||||
# Class names (must match training configuration)
|
||||
CLASS_NAMES = [
|
||||
'invoice_number',
|
||||
'invoice_date',
|
||||
'invoice_due_date',
|
||||
'ocr_number',
|
||||
'bankgiro',
|
||||
'plusgiro',
|
||||
'amount',
|
||||
'supplier_org_number', # Matches training class name
|
||||
'customer_number',
|
||||
'payment_line', # Machine code payment line at bottom of invoice
|
||||
]
|
||||
|
||||
# Mapping from class name to field name
|
||||
CLASS_TO_FIELD = {
|
||||
'invoice_number': 'InvoiceNumber',
|
||||
'invoice_date': 'InvoiceDate',
|
||||
'invoice_due_date': 'InvoiceDueDate',
|
||||
'ocr_number': 'OCR',
|
||||
'bankgiro': 'Bankgiro',
|
||||
'plusgiro': 'Plusgiro',
|
||||
'amount': 'Amount',
|
||||
'supplier_org_number': 'supplier_org_number',
|
||||
'customer_number': 'customer_number',
|
||||
'payment_line': 'payment_line',
|
||||
}
|
||||
# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields
|
||||
# This ensures consistency with the trained YOLO model
|
||||
|
||||
|
||||
class YOLODetector:
|
||||
|
||||
@@ -4,18 +4,19 @@ Admin Annotation API Routes
|
||||
FastAPI endpoints for annotation management.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
|
||||
from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.services.autolabel import get_auto_label_service
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationCreate,
|
||||
AnnotationItem,
|
||||
@@ -35,9 +36,6 @@ from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Image storage directory
|
||||
ADMIN_IMAGES_DIR = Path("data/admin_images")
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
@@ -60,7 +58,9 @@ def create_annotation_router() -> APIRouter:
|
||||
|
||||
@router.get(
|
||||
"/{document_id}/images/{page_number}",
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {"content": {"image/png": {}}, "description": "Page image"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
@@ -72,7 +72,7 @@ def create_annotation_router() -> APIRouter:
|
||||
page_number: int,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> FileResponse:
|
||||
) -> FileResponse | StreamingResponse:
|
||||
"""Get page image."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
@@ -91,18 +91,33 @@ def create_annotation_router() -> APIRouter:
|
||||
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
|
||||
)
|
||||
|
||||
# Find image file
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
|
||||
if not image_path.exists():
|
||||
# Get storage helper
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Check if image exists
|
||||
if not storage.admin_image_exists(document_id, page_number):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Image for page {page_number} not found",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
# Try to get local path for efficient file serving
|
||||
local_path = storage.get_admin_image_local_path(document_id, page_number)
|
||||
if local_path is not None:
|
||||
return FileResponse(
|
||||
path=str(local_path),
|
||||
media_type="image/png",
|
||||
filename=f"{document.filename}_page_{page_number}.png",
|
||||
)
|
||||
|
||||
# Fall back to streaming for cloud storage
|
||||
image_content = storage.get_admin_image(document_id, page_number)
|
||||
return StreamingResponse(
|
||||
io.BytesIO(image_content),
|
||||
media_type="image/png",
|
||||
filename=f"{document.filename}_page_{page_number}.png",
|
||||
headers={
|
||||
"Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"'
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
@@ -210,16 +225,14 @@ def create_annotation_router() -> APIRouter:
|
||||
)
|
||||
|
||||
# Get image dimensions for normalization
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
|
||||
if not image_path.exists():
|
||||
storage = get_storage_helper()
|
||||
dimensions = storage.get_admin_image_dimensions(document_id, request.page_number)
|
||||
if dimensions is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image for page {request.page_number} not available",
|
||||
)
|
||||
|
||||
from PIL import Image
|
||||
with Image.open(image_path) as img:
|
||||
image_width, image_height = img.size
|
||||
image_width, image_height = dimensions
|
||||
|
||||
# Calculate normalized coordinates
|
||||
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
@@ -315,10 +328,14 @@ def create_annotation_router() -> APIRouter:
|
||||
|
||||
if request.bbox is not None:
|
||||
# Get image dimensions
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
|
||||
from PIL import Image
|
||||
with Image.open(image_path) as img:
|
||||
image_width, image_height = img.size
|
||||
storage = get_storage_helper()
|
||||
dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number)
|
||||
if dimensions is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image for page {annotation.page_number} not available",
|
||||
)
|
||||
image_width, image_height = dimensions
|
||||
|
||||
# Calculate normalized coordinates
|
||||
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
|
||||
@@ -13,16 +13,19 @@ from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||
|
||||
from inference.web.config import DEFAULT_DPI, StorageConfig
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationItem,
|
||||
AnnotationSource,
|
||||
AutoLabelStatus,
|
||||
BoundingBox,
|
||||
DocumentCategoriesResponse,
|
||||
DocumentDetailResponse,
|
||||
DocumentItem,
|
||||
DocumentListResponse,
|
||||
DocumentStatus,
|
||||
DocumentStatsResponse,
|
||||
DocumentUpdateRequest,
|
||||
DocumentUploadResponse,
|
||||
ModelMetrics,
|
||||
TrainingHistoryItem,
|
||||
@@ -44,14 +47,12 @@ def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
|
||||
document_id: str, content: bytes, page_count: int, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation."""
|
||||
"""Convert PDF pages to images for annotation using StorageHelper."""
|
||||
import fitz
|
||||
|
||||
doc_images_dir = images_dir / document_id
|
||||
doc_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
storage = get_storage_helper()
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
@@ -60,8 +61,9 @@ def _convert_pdf_to_images(
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
image_path = doc_images_dir / f"page_{page_num + 1}.png"
|
||||
pix.save(str(image_path))
|
||||
# Save to storage using StorageHelper
|
||||
image_bytes = pix.tobytes("png")
|
||||
storage.save_admin_image(document_id, page_num + 1, image_bytes)
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
@@ -95,6 +97,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
str | None,
|
||||
Query(description="Optional group key for document organization", max_length=255),
|
||||
] = None,
|
||||
category: Annotated[
|
||||
str,
|
||||
Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100),
|
||||
] = "invoice",
|
||||
) -> DocumentUploadResponse:
|
||||
"""Upload a document for labeling."""
|
||||
# Validate group_key length
|
||||
@@ -143,31 +149,33 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
group_key=group_key,
|
||||
category=category,
|
||||
)
|
||||
|
||||
# Save file to admin uploads
|
||||
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
|
||||
# Save file to storage using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
filename = f"{document_id}{file_ext}"
|
||||
try:
|
||||
file_path.write_bytes(content)
|
||||
storage_path = storage.save_raw_pdf(content, filename)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to save file")
|
||||
|
||||
# Update file path in database
|
||||
# Update file path in database (using storage path for reference)
|
||||
from inference.data.database import get_session_context
|
||||
from inference.data.admin_models import AdminDocument
|
||||
with get_session_context() as session:
|
||||
doc = session.get(AdminDocument, UUID(document_id))
|
||||
if doc:
|
||||
doc.file_path = str(file_path)
|
||||
# Store the storage path (relative path within storage)
|
||||
doc.file_path = storage_path
|
||||
session.add(doc)
|
||||
|
||||
# Convert PDF to images for annotation
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count,
|
||||
storage_config.admin_images_dir, storage_config.dpi
|
||||
document_id, content, page_count, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
@@ -189,6 +197,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
file_size=len(content),
|
||||
page_count=page_count,
|
||||
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
|
||||
category=category,
|
||||
group_key=group_key,
|
||||
auto_label_started=auto_label_started,
|
||||
message="Document uploaded successfully",
|
||||
@@ -226,6 +235,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
str | None,
|
||||
Query(description="Filter by batch ID"),
|
||||
] = None,
|
||||
category: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by document category"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
@@ -264,6 +277,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
has_annotations=has_annotations,
|
||||
auto_label_status=auto_label_status,
|
||||
batch_id=batch_id,
|
||||
category=category,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
@@ -291,6 +305,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
|
||||
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
|
||||
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
|
||||
category=doc.category if hasattr(doc, 'category') else "invoice",
|
||||
can_annotate=can_annotate,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
@@ -436,6 +451,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
|
||||
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
|
||||
group_key=document.group_key if hasattr(document, 'group_key') else None,
|
||||
category=document.category if hasattr(document, 'category') else "invoice",
|
||||
csv_field_values=csv_field_values,
|
||||
can_annotate=can_annotate,
|
||||
annotation_lock_until=annotation_lock_until,
|
||||
@@ -471,16 +487,22 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Delete file
|
||||
file_path = Path(document.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
# Delete file using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Delete images
|
||||
images_dir = ADMIN_IMAGES_DIR / document_id
|
||||
if images_dir.exists():
|
||||
import shutil
|
||||
shutil.rmtree(images_dir)
|
||||
# Delete the raw PDF
|
||||
filename = Path(document.file_path).name
|
||||
if filename:
|
||||
try:
|
||||
storage._storage.delete(document.file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete PDF file: {e}")
|
||||
|
||||
# Delete admin images
|
||||
try:
|
||||
storage.delete_admin_images(document_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete admin images: {e}")
|
||||
|
||||
# Delete from database
|
||||
db.delete_document(document_id)
|
||||
@@ -609,4 +631,61 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
"message": "Document group key updated",
|
||||
}
|
||||
|
||||
@router.get(
|
||||
"/categories",
|
||||
response_model=DocumentCategoriesResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get available categories",
|
||||
description="Get list of all available document categories.",
|
||||
)
|
||||
async def get_categories(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> DocumentCategoriesResponse:
|
||||
"""Get all available document categories."""
|
||||
categories = db.get_document_categories()
|
||||
return DocumentCategoriesResponse(
|
||||
categories=categories,
|
||||
total=len(categories),
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/category",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Update document category",
|
||||
description="Update the category for a document.",
|
||||
)
|
||||
async def update_document_category(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
request: DocumentUpdateRequest,
|
||||
) -> dict:
|
||||
"""Update document category."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify document exists
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Update category if provided
|
||||
if request.category is not None:
|
||||
db.update_document_category(document_id, request.category)
|
||||
|
||||
return {
|
||||
"status": "updated",
|
||||
"document_id": document_id,
|
||||
"category": request.category,
|
||||
"message": "Document category updated",
|
||||
}
|
||||
|
||||
return router
|
||||
|
||||
@@ -17,6 +17,7 @@ from inference.web.schemas.admin import (
|
||||
TrainingStatus,
|
||||
TrainingTaskResponse,
|
||||
)
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
@@ -38,7 +39,6 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
db: AdminDBDep,
|
||||
) -> DatasetResponse:
|
||||
"""Create a training dataset from document IDs."""
|
||||
from pathlib import Path
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
# Validate minimum document count for proper train/val/test split
|
||||
@@ -56,7 +56,18 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
seed=request.seed,
|
||||
)
|
||||
|
||||
builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
|
||||
# Get storage paths from StorageHelper
|
||||
storage = get_storage_helper()
|
||||
datasets_dir = storage.get_datasets_base_path()
|
||||
admin_images_dir = storage.get_admin_images_base_path()
|
||||
|
||||
if datasets_dir is None or admin_images_dir is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Storage not configured for local access",
|
||||
)
|
||||
|
||||
builder = DatasetBuilder(db=db, base_dir=datasets_dir)
|
||||
try:
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
@@ -64,7 +75,7 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
train_ratio=request.train_ratio,
|
||||
val_ratio=request.val_ratio,
|
||||
seed=request.seed,
|
||||
admin_images_dir=Path("data/admin_images"),
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -142,6 +153,12 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
status=dataset.status,
|
||||
training_status=dataset.training_status,
|
||||
active_training_task_id=(
|
||||
str(dataset.active_training_task_id)
|
||||
if dataset.active_training_task_id
|
||||
else None
|
||||
),
|
||||
train_ratio=dataset.train_ratio,
|
||||
val_ratio=dataset.val_ratio,
|
||||
seed=dataset.seed,
|
||||
|
||||
@@ -34,8 +34,10 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
db: AdminDBDep,
|
||||
) -> ExportResponse:
|
||||
"""Export annotations for training."""
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
# Get storage helper for reading images and exports directory
|
||||
storage = get_storage_helper()
|
||||
|
||||
if request.format not in ("yolo", "coco", "voc"):
|
||||
raise HTTPException(
|
||||
@@ -51,7 +53,14 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
detail="No labeled documents available for export",
|
||||
)
|
||||
|
||||
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
|
||||
# Get exports directory from StorageHelper
|
||||
exports_base = storage.get_exports_base_path()
|
||||
if exports_base is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Storage not configured for local access",
|
||||
)
|
||||
export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
|
||||
@@ -80,13 +89,16 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
if not page_annotations and not request.include_images:
|
||||
continue
|
||||
|
||||
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
|
||||
if not src_image.exists():
|
||||
# Get image from storage
|
||||
doc_id = str(doc.document_id)
|
||||
if not storage.admin_image_exists(doc_id, page_num):
|
||||
continue
|
||||
|
||||
# Download image and save to export directory
|
||||
image_name = f"{doc.document_id}_page{page_num}.png"
|
||||
dst_image = export_dir / "images" / split / image_name
|
||||
shutil.copy(src_image, dst_image)
|
||||
image_content = storage.get_admin_image(doc_id, page_num)
|
||||
dst_image.write_bytes(image_content)
|
||||
total_images += 1
|
||||
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
@@ -98,7 +110,7 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
from inference.data.admin_models import FIELD_CLASSES
|
||||
from shared.fields import FIELD_CLASSES
|
||||
|
||||
yaml_content = f"""# Auto-generated YOLO dataset config
|
||||
path: {export_dir.absolute()}
|
||||
|
||||
@@ -22,6 +22,7 @@ from inference.web.schemas.inference import (
|
||||
InferenceResult,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.web.services import InferenceService
|
||||
@@ -90,8 +91,17 @@ def create_inference_router(
|
||||
# Generate document ID
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Save uploaded file
|
||||
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
|
||||
# Get storage helper and uploads directory
|
||||
storage = get_storage_helper()
|
||||
uploads_dir = storage.get_uploads_base_path(subfolder="inference")
|
||||
if uploads_dir is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Storage not configured for local access",
|
||||
)
|
||||
|
||||
# Save uploaded file to temporary location for processing
|
||||
upload_path = uploads_dir / f"{doc_id}{file_ext}"
|
||||
try:
|
||||
with open(upload_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
@@ -149,12 +159,13 @@ def create_inference_router(
|
||||
# Cleanup uploaded file
|
||||
upload_path.unlink(missing_ok=True)
|
||||
|
||||
@router.get("/results/{filename}")
|
||||
@router.get("/results/{filename}", response_model=None)
|
||||
async def get_result_image(filename: str) -> FileResponse:
|
||||
"""Get visualization result image."""
|
||||
file_path = storage_config.result_dir / filename
|
||||
storage = get_storage_helper()
|
||||
file_path = storage.get_result_local_path(filename)
|
||||
|
||||
if not file_path.exists():
|
||||
if file_path is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
@@ -169,15 +180,15 @@ def create_inference_router(
|
||||
@router.delete("/results/{filename}")
|
||||
async def delete_result(filename: str) -> dict:
|
||||
"""Delete a result file."""
|
||||
file_path = storage_config.result_dir / filename
|
||||
storage = get_storage_helper()
|
||||
|
||||
if not file_path.exists():
|
||||
if not storage.result_exists(filename):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
)
|
||||
|
||||
file_path.unlink()
|
||||
storage.delete_result(filename)
|
||||
return {"status": "deleted", "filename": filename}
|
||||
|
||||
return router
|
||||
|
||||
@@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.schemas.labeling import PreLabelResponse
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.web.services import InferenceService
|
||||
@@ -23,19 +24,14 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage directory for pre-label uploads (legacy, now uses storage_config)
|
||||
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
|
||||
document_id: str, content: bytes, page_count: int, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation."""
|
||||
"""Convert PDF pages to images for annotation using StorageHelper."""
|
||||
import fitz
|
||||
|
||||
doc_images_dir = images_dir / document_id
|
||||
doc_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
storage = get_storage_helper()
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
@@ -43,8 +39,9 @@ def _convert_pdf_to_images(
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
image_path = doc_images_dir / f"page_{page_num + 1}.png"
|
||||
pix.save(str(image_path))
|
||||
# Save to storage using StorageHelper
|
||||
image_bytes = pix.tobytes("png")
|
||||
storage.save_admin_image(document_id, page_num + 1, image_bytes)
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
@@ -70,9 +67,6 @@ def create_labeling_router(
|
||||
"""
|
||||
router = APIRouter(prefix="/api/v1", tags=["labeling"])
|
||||
|
||||
# Ensure upload directory exists
|
||||
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@router.post(
|
||||
"/pre-label",
|
||||
response_model=PreLabelResponse,
|
||||
@@ -165,10 +159,11 @@ def create_labeling_router(
|
||||
csv_field_values=expected_values,
|
||||
)
|
||||
|
||||
# Save file to admin uploads
|
||||
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
|
||||
# Save file to storage using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
filename = f"{document_id}{file_ext}"
|
||||
try:
|
||||
file_path.write_bytes(content)
|
||||
storage_path = storage.save_raw_pdf(content, filename)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(
|
||||
@@ -176,15 +171,14 @@ def create_labeling_router(
|
||||
detail="Failed to save file",
|
||||
)
|
||||
|
||||
# Update file path in database
|
||||
db.update_document_file_path(document_id, str(file_path))
|
||||
# Update file path in database (using storage path)
|
||||
db.update_document_file_path(document_id, storage_path)
|
||||
|
||||
# Convert PDF to images for annotation UI
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count,
|
||||
storage_config.admin_images_dir, storage_config.dpi
|
||||
document_id, content, page_count, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
@@ -18,6 +18,7 @@ from fastapi.responses import HTMLResponse
|
||||
|
||||
from .config import AppConfig, default_config
|
||||
from inference.web.services import InferenceService
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
# Public API imports
|
||||
from inference.web.api.v1.public import (
|
||||
@@ -238,13 +239,17 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Mount static files for results
|
||||
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
|
||||
app.mount(
|
||||
"/static/results",
|
||||
StaticFiles(directory=str(config.storage.result_dir)),
|
||||
name="results",
|
||||
)
|
||||
# Mount static files for results using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
if results_dir:
|
||||
app.mount(
|
||||
"/static/results",
|
||||
StaticFiles(directory=str(results_dir)),
|
||||
name="results",
|
||||
)
|
||||
else:
|
||||
logger.warning("Could not mount static results directory: local storage not available")
|
||||
|
||||
# Include public API routes
|
||||
inference_router = create_inference_router(inference_service, config.storage)
|
||||
|
||||
@@ -4,16 +4,49 @@ Web Application Configuration
|
||||
Centralized configuration for the web application.
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from shared.config import DEFAULT_DPI, PATHS
|
||||
from shared.config import DEFAULT_DPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
|
||||
def get_storage_backend(
|
||||
config_path: Path | str | None = None,
|
||||
) -> "StorageBackend":
|
||||
"""Get storage backend for file operations.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to storage configuration file.
|
||||
If not provided, uses STORAGE_CONFIG_PATH env var or falls back to env vars.
|
||||
|
||||
Returns:
|
||||
Configured StorageBackend instance.
|
||||
"""
|
||||
from shared.storage import get_storage_backend as _get_storage_backend
|
||||
|
||||
# Check for config file path
|
||||
if config_path is None:
|
||||
config_path_str = os.environ.get("STORAGE_CONFIG_PATH")
|
||||
if config_path_str:
|
||||
config_path = Path(config_path_str)
|
||||
|
||||
return _get_storage_backend(config_path=config_path)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelConfig:
|
||||
"""YOLO model configuration."""
|
||||
"""YOLO model configuration.
|
||||
|
||||
Note: Model files are stored locally (not in STORAGE_BASE_PATH) because:
|
||||
- Models need to be accessible by inference service on any platform
|
||||
- Models may be version-controlled or deployed separately
|
||||
- Models are part of the application, not user data
|
||||
"""
|
||||
|
||||
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
|
||||
confidence_threshold: float = 0.5
|
||||
@@ -33,24 +66,39 @@ class ServerConfig:
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageConfig:
|
||||
"""File storage configuration.
|
||||
class FileConfig:
|
||||
"""File handling configuration.
|
||||
|
||||
Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
|
||||
directly in raw_pdfs directory. This ensures consistency with CLI autolabel
|
||||
and avoids storing duplicate files.
|
||||
This config holds file handling settings. For file operations,
|
||||
use the storage backend with PREFIXES from shared.storage.prefixes.
|
||||
|
||||
Example:
|
||||
from shared.storage import PREFIXES, get_storage_backend
|
||||
|
||||
storage = get_storage_backend()
|
||||
path = PREFIXES.document_path(document_id)
|
||||
storage.upload_bytes(content, path)
|
||||
|
||||
Note: The path fields (upload_dir, result_dir, etc.) are deprecated.
|
||||
They are kept for backward compatibility with existing code and tests.
|
||||
New code should use the storage backend with PREFIXES instead.
|
||||
"""
|
||||
|
||||
upload_dir: Path = Path("uploads")
|
||||
result_dir: Path = Path("results")
|
||||
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
|
||||
admin_images_dir: Path = Path("data/admin_images")
|
||||
max_file_size_mb: int = 50
|
||||
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
|
||||
dpi: int = DEFAULT_DPI
|
||||
presigned_url_expiry_seconds: int = 3600
|
||||
|
||||
# Deprecated path fields - kept for backward compatibility
|
||||
# New code should use storage backend with PREFIXES instead
|
||||
# All paths are now under data/ to match WSL storage layout
|
||||
upload_dir: Path = field(default_factory=lambda: Path("data/uploads"))
|
||||
result_dir: Path = field(default_factory=lambda: Path("data/results"))
|
||||
admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs"))
|
||||
admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images"))
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist."""
|
||||
"""Create directories if they don't exist (for backward compatibility)."""
|
||||
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
|
||||
object.__setattr__(self, "result_dir", Path(self.result_dir))
|
||||
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
|
||||
@@ -61,9 +109,17 @@ class StorageConfig:
|
||||
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
StorageConfig = FileConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AsyncConfig:
|
||||
"""Async processing configuration."""
|
||||
"""Async processing configuration.
|
||||
|
||||
Note: For file paths, use the storage backend with PREFIXES.
|
||||
Example: PREFIXES.upload_path(filename, "async")
|
||||
"""
|
||||
|
||||
# Queue settings
|
||||
queue_max_size: int = 100
|
||||
@@ -77,14 +133,17 @@ class AsyncConfig:
|
||||
|
||||
# Storage
|
||||
result_retention_days: int = 7
|
||||
temp_upload_dir: Path = Path("uploads/async")
|
||||
max_file_size_mb: int = 50
|
||||
|
||||
# Deprecated: kept for backward compatibility
|
||||
# Path under data/ to match WSL storage layout
|
||||
temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async"))
|
||||
|
||||
# Cleanup
|
||||
cleanup_interval_hours: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist."""
|
||||
"""Create directories if they don't exist (for backward compatibility)."""
|
||||
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
|
||||
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -95,19 +154,41 @@ class AppConfig:
|
||||
|
||||
model: ModelConfig = field(default_factory=ModelConfig)
|
||||
server: ServerConfig = field(default_factory=ServerConfig)
|
||||
storage: StorageConfig = field(default_factory=StorageConfig)
|
||||
file: FileConfig = field(default_factory=FileConfig)
|
||||
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
|
||||
storage_backend: "StorageBackend | None" = None
|
||||
|
||||
@property
|
||||
def storage(self) -> FileConfig:
|
||||
"""Backward compatibility alias for file config."""
|
||||
return self.file
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
|
||||
"""Create config from dictionary."""
|
||||
file_config = config_dict.get("file", config_dict.get("storage", {}))
|
||||
return cls(
|
||||
model=ModelConfig(**config_dict.get("model", {})),
|
||||
server=ServerConfig(**config_dict.get("server", {})),
|
||||
storage=StorageConfig(**config_dict.get("storage", {})),
|
||||
file=FileConfig(**file_config),
|
||||
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
|
||||
)
|
||||
|
||||
|
||||
def create_app_config(
|
||||
storage_config_path: Path | str | None = None,
|
||||
) -> AppConfig:
|
||||
"""Create application configuration with storage backend.
|
||||
|
||||
Args:
|
||||
storage_config_path: Optional path to storage configuration file.
|
||||
|
||||
Returns:
|
||||
Configured AppConfig instance with storage backend initialized.
|
||||
"""
|
||||
storage_backend = get_storage_backend(config_path=storage_config_path)
|
||||
return AppConfig(storage_backend=storage_backend)
|
||||
|
||||
|
||||
# Default configuration instance
|
||||
default_config = AppConfig()
|
||||
|
||||
@@ -13,6 +13,7 @@ from inference.web.services.db_autolabel import (
|
||||
get_pending_autolabel_documents,
|
||||
process_document_autolabel,
|
||||
)
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,7 +37,13 @@ class AutoLabelScheduler:
|
||||
"""
|
||||
self._check_interval = check_interval_seconds
|
||||
self._batch_size = batch_size
|
||||
|
||||
# Get output directory from StorageHelper
|
||||
if output_dir is None:
|
||||
storage = get_storage_helper()
|
||||
output_dir = storage.get_autolabel_output_path()
|
||||
self._output_dir = output_dir or Path("data/autolabel_output")
|
||||
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
@@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,6 +108,14 @@ class TrainingScheduler:
|
||||
self._db.update_training_task_status(task_id, "running")
|
||||
self._db.add_training_log(task_id, "INFO", "Training task started")
|
||||
|
||||
# Update dataset training status to running
|
||||
if dataset_id:
|
||||
self._db.update_dataset_training_status(
|
||||
dataset_id,
|
||||
training_status="running",
|
||||
active_training_task_id=task_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Get training configuration
|
||||
model_name = config.get("model_name", "yolo11n.pt")
|
||||
@@ -192,6 +201,15 @@ class TrainingScheduler:
|
||||
)
|
||||
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
|
||||
|
||||
# Update dataset training status to completed and main status to trained
|
||||
if dataset_id:
|
||||
self._db.update_dataset_training_status(
|
||||
dataset_id,
|
||||
training_status="completed",
|
||||
active_training_task_id=None,
|
||||
update_main_status=True, # Set main status to 'trained'
|
||||
)
|
||||
|
||||
# Auto-create model version for the completed training
|
||||
self._create_model_version_from_training(
|
||||
task_id=task_id,
|
||||
@@ -203,6 +221,13 @@ class TrainingScheduler:
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
|
||||
# Update dataset training status to failed
|
||||
if dataset_id:
|
||||
self._db.update_dataset_training_status(
|
||||
dataset_id,
|
||||
training_status="failed",
|
||||
active_training_task_id=None,
|
||||
)
|
||||
raise
|
||||
|
||||
def _create_model_version_from_training(
|
||||
@@ -268,9 +293,10 @@ class TrainingScheduler:
|
||||
f"Created model version {version} (ID: {model_version.version_id}) "
|
||||
f"from training task {task_id}"
|
||||
)
|
||||
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
|
||||
f"Model version {version} created (mAP: {mAP_display})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -283,8 +309,11 @@ class TrainingScheduler:
|
||||
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
|
||||
"""Export training data for a task."""
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from inference.data.admin_models import FIELD_CLASSES
|
||||
from shared.fields import FIELD_CLASSES
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
# Get storage helper for reading images
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Get all labeled documents
|
||||
documents = self._db.get_labeled_documents_for_export()
|
||||
@@ -293,8 +322,12 @@ class TrainingScheduler:
|
||||
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
|
||||
return None
|
||||
|
||||
# Create export directory
|
||||
export_dir = Path("data/training") / task_id
|
||||
# Create export directory using StorageHelper
|
||||
training_base = storage.get_training_data_path()
|
||||
if training_base is None:
|
||||
self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access")
|
||||
return None
|
||||
export_dir = training_base / task_id
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# YOLO format directories
|
||||
@@ -323,14 +356,16 @@ class TrainingScheduler:
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
page_annotations = [a for a in annotations if a.page_number == page_num]
|
||||
|
||||
# Copy image
|
||||
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
|
||||
if not src_image.exists():
|
||||
# Get image from storage
|
||||
doc_id = str(doc.document_id)
|
||||
if not storage.admin_image_exists(doc_id, page_num):
|
||||
continue
|
||||
|
||||
# Download image and save to export directory
|
||||
image_name = f"{doc.document_id}_page{page_num}.png"
|
||||
dst_image = export_dir / "images" / split / image_name
|
||||
shutil.copy(src_image, dst_image)
|
||||
image_content = storage.get_admin_image(doc_id, page_num)
|
||||
dst_image.write_bytes(image_content)
|
||||
total_images += 1
|
||||
|
||||
# Write YOLO label
|
||||
@@ -380,6 +415,8 @@ names: {list(FIELD_CLASSES.values())}
|
||||
self._db.add_training_log(task_id, level, message)
|
||||
|
||||
# Create shared training config
|
||||
# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
|
||||
# because models need to be accessible by inference service on any platform
|
||||
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
|
||||
config = SharedTrainingConfig(
|
||||
model_path=model_name,
|
||||
|
||||
@@ -13,6 +13,7 @@ class DatasetCreateRequest(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
|
||||
description: str | None = Field(None, description="Optional description")
|
||||
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
|
||||
category: str | None = Field(None, description="Filter documents by category (optional)")
|
||||
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
|
||||
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
|
||||
seed: int = Field(42, description="Random seed for split")
|
||||
@@ -43,6 +44,8 @@ class DatasetDetailResponse(BaseModel):
|
||||
name: str
|
||||
description: str | None
|
||||
status: str
|
||||
training_status: str | None = None
|
||||
active_training_task_id: str | None = None
|
||||
train_ratio: float
|
||||
val_ratio: float
|
||||
seed: int
|
||||
|
||||
@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
|
||||
file_size: int = Field(..., ge=0, description="File size in bytes")
|
||||
page_count: int = Field(..., ge=1, description="Number of pages")
|
||||
status: DocumentStatus = Field(..., description="Document status")
|
||||
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
|
||||
group_key: str | None = Field(None, description="User-defined group key")
|
||||
auto_label_started: bool = Field(
|
||||
default=False, description="Whether auto-labeling was started"
|
||||
@@ -44,6 +45,7 @@ class DocumentItem(BaseModel):
|
||||
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
|
||||
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
|
||||
group_key: str | None = Field(None, description="User-defined group key")
|
||||
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
|
||||
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
@@ -76,6 +78,7 @@ class DocumentDetailResponse(BaseModel):
|
||||
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
|
||||
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
|
||||
group_key: str | None = Field(None, description="User-defined group key")
|
||||
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
|
||||
csv_field_values: dict[str, str] | None = Field(
|
||||
None, description="CSV field values if uploaded via batch"
|
||||
)
|
||||
@@ -104,3 +107,17 @@ class DocumentStatsResponse(BaseModel):
|
||||
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
|
||||
labeled: int = Field(default=0, ge=0, description="Labeled documents")
|
||||
exported: int = Field(default=0, ge=0, description="Exported documents")
|
||||
|
||||
|
||||
class DocumentUpdateRequest(BaseModel):
|
||||
"""Request for updating document metadata."""
|
||||
|
||||
category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)")
|
||||
group_key: str | None = Field(None, description="User-defined group key")
|
||||
|
||||
|
||||
class DocumentCategoriesResponse(BaseModel):
|
||||
"""Response for available document categories."""
|
||||
|
||||
categories: list[str] = Field(..., description="List of available categories")
|
||||
total: int = Field(..., ge=0, description="Total number of categories")
|
||||
|
||||
@@ -5,6 +5,7 @@ Manages async request lifecycle and background processing.
"""

import logging
import re
import shutil
import time
import uuid
@@ -17,6 +18,7 @@ from typing import TYPE_CHECKING
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.core.rate_limiter import RateLimiter
from inference.web.services.storage_helpers import get_storage_helper

if TYPE_CHECKING:
    from inference.web.config import AsyncConfig, StorageConfig
@@ -189,9 +191,7 @@ class AsyncProcessingService:
        filename: str,
        content: bytes,
    ) -> Path:
        """Save uploaded file to temp storage."""
        import re

        """Save uploaded file to temp storage using StorageHelper."""
        # Extract extension from filename
        ext = Path(filename).suffix.lower()

@@ -203,9 +203,11 @@ class AsyncProcessingService:
        if ext not in self.ALLOWED_EXTENSIONS:
            ext = ".pdf"

        # Create async upload directory
        upload_dir = self._async_config.temp_upload_dir
        upload_dir.mkdir(parents=True, exist_ok=True)
        # Get upload directory from StorageHelper
        storage = get_storage_helper()
        upload_dir = storage.get_uploads_base_path(subfolder="async")
        if upload_dir is None:
            raise ValueError("Storage not configured for local access")

        # Build file path - request_id is a UUID so it's safe
        file_path = upload_dir / f"{request_id}{ext}"
@@ -355,8 +357,9 @@ class AsyncProcessingService:

    def _cleanup_orphan_files(self) -> int:
        """Clean up upload files that don't have matching requests."""
        upload_dir = self._async_config.temp_upload_dir
        if not upload_dir.exists():
        storage = get_storage_helper()
        upload_dir = storage.get_uploads_base_path(subfolder="async")
        if upload_dir is None or not upload_dir.exists():
            return 0

        count = 0

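A minimal standalone sketch of the path resolution the service now relies on. The helper function below is hypothetical and only illustrates the new flow; the StorageHelper API is the one introduced later in this commit.

# Illustrative only: resolve the async upload directory the way the service does.
from pathlib import Path
from inference.web.services.storage_helpers import get_storage_helper

def resolve_async_upload_path(request_id: str, ext: str = ".pdf") -> Path:
    storage = get_storage_helper()
    upload_dir = storage.get_uploads_base_path(subfolder="async")
    if upload_dir is None:
        # Cloud backends expose no local directory; a caller could instead use
        # storage.save_upload(content, filename, subfolder="async").
        raise ValueError("Storage not configured for local access")
    return upload_dir / f"{request_id}{ext}"
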
@@ -13,7 +13,7 @@ from PIL import Image

from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken


@@ -16,7 +16,7 @@ from uuid import UUID
from pydantic import BaseModel, Field, field_validator

from inference.data.admin_db import AdminDB
from inference.data.admin_models import CSV_TO_CLASS_MAPPING
from shared.fields import CSV_TO_CLASS_MAPPING

logger = logging.getLogger(__name__)


@@ -12,7 +12,7 @@ from pathlib import Path

import yaml

from inference.data.admin_models import FIELD_CLASSES
from shared.fields import FIELD_CLASSES

logger = logging.getLogger(__name__)


@@ -13,9 +13,10 @@ from typing import Any

from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
from shared.fields import CSV_TO_CLASS_MAPPING
from inference.data.admin_models import AdminDocument
from shared.data.db import DocumentDB
from inference.web.config import StorageConfig
from inference.web.services.storage_helpers import get_storage_helper

logger = logging.getLogger(__name__)

@@ -122,8 +123,12 @@ def process_document_autolabel(
    document_id = str(document.document_id)
    file_path = Path(document.file_path)

    # Get output directory from StorageHelper
    storage = get_storage_helper()
    if output_dir is None:
        output_dir = Path("data/autolabel_output")
        output_dir = storage.get_autolabel_output_path()
        if output_dir is None:
            output_dir = Path("data/autolabel_output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Mark as processing
@@ -152,10 +157,12 @@ def process_document_autolabel(
    is_scanned = len(tokens) < 10  # Threshold for "no text"

    # Build task data
    # Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
    # Use raw_pdfs base path for pdf_path
    # This ensures consistency with CLI autolabel for reprocess_failed.py
    storage_config = StorageConfig()
    pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"
    raw_pdfs_dir = storage.get_raw_pdfs_base_path()
    if raw_pdfs_dir is None:
        raise ValueError("Storage not configured for local access")
    pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf"

    task_data = {
        "row_dict": row_dict,
@@ -246,8 +253,8 @@ def _save_annotations_to_db(
    Returns:
        Number of annotations saved
    """
    from PIL import Image
    from inference.data.admin_models import FIELD_CLASS_IDS
    from shared.fields import FIELD_CLASS_IDS
    from inference.web.services.storage_helpers import get_storage_helper

    # Mapping from CSV field names to internal field names
    CSV_TO_INTERNAL_FIELD: dict[str, str] = {
@@ -266,6 +273,9 @@ def _save_annotations_to_db(
    # Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
    scale = dpi / 72.0

    # Get storage helper for image dimensions
    storage = get_storage_helper()

    # Cache for image dimensions per page
    image_dimensions: dict[int, tuple[int, int]] = {}

@@ -274,18 +284,11 @@ def _save_annotations_to_db(
        if page_no in image_dimensions:
            return image_dimensions[page_no]

        # Try to load from admin_images
        admin_images_dir = Path("data/admin_images") / document_id
        image_path = admin_images_dir / f"page_{page_no}.png"

        if image_path.exists():
            try:
                with Image.open(image_path) as img:
                    dims = img.size  # (width, height)
                    image_dimensions[page_no] = dims
                    return dims
            except Exception as e:
                logger.warning(f"Failed to read image dimensions from {image_path}: {e}")
        # Get dimensions from storage helper
        dims = storage.get_admin_image_dimensions(document_id, page_no)
        if dims:
            image_dimensions[page_no] = dims
            return dims

        return None

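A small worked example of the scaling step above: `scale = dpi / 72.0` converts PDF points to pixels, and the page size now comes from `get_admin_image_dimensions`. The numbers are made up, and the final normalization to centre coordinates is only an assumption about how the training data is laid out.

# Illustrative only: PDF-point bbox -> pixels -> normalized coordinates.
dpi = 150
scale = dpi / 72.0                                        # ~2.083

x0_pt, y0_pt, x1_pt, y1_pt = 72.0, 144.0, 216.0, 180.0    # made-up bbox in points
x0_px, y0_px = x0_pt * scale, y0_pt * scale               # 150.0, 300.0
x1_px, y1_px = x1_pt * scale, y1_pt * scale               # 450.0, 375.0

width, height = 1240, 1754  # e.g. dims for an A4 page rendered at 150 DPI
x_center_norm = ((x0_px + x1_px) / 2) / width
y_center_norm = ((y0_px + y1_px) / 2) / height
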
@@ -449,10 +452,17 @@ def save_manual_annotations_to_document_db(
    from datetime import datetime

    document_id = str(document.document_id)
    storage_config = StorageConfig()

    # Build pdf_path using admin_upload_dir (same as auto-label)
    pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"
    # Build pdf_path using raw_pdfs base path (same as auto-label)
    storage = get_storage_helper()
    raw_pdfs_dir = storage.get_raw_pdfs_base_path()
    if raw_pdfs_dir is None:
        return {
            "success": False,
            "document_id": document_id,
            "error": "Storage not configured for local access",
        }
    pdf_path = raw_pdfs_dir / f"{document_id}.pdf"

    # Build report dict compatible with DocumentDB.save_document()
    field_results = []

packages/inference/inference/web/services/document_service.py (new file, 217 lines)
@@ -0,0 +1,217 @@
"""
Document Service for storage-backed file operations.

Provides a unified interface for document upload, download, and serving
using the storage abstraction layer.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from uuid import uuid4

if TYPE_CHECKING:
    from shared.storage.base import StorageBackend


@dataclass
class DocumentResult:
    """Result of document operation."""

    id: str
    file_path: str
    filename: str | None = None


class DocumentService:
    """Service for document file operations using storage backend.

    Provides upload, download, and URL generation for documents and images.
    """

    # Storage path prefixes
    DOCUMENTS_PREFIX = "documents"
    IMAGES_PREFIX = "images"

    def __init__(
        self,
        storage_backend: "StorageBackend",
        admin_db: Any | None = None,
    ) -> None:
        """Initialize document service.

        Args:
            storage_backend: Storage backend for file operations.
            admin_db: Optional AdminDB instance for database operations.
        """
        self._storage = storage_backend
        self._admin_db = admin_db

    def upload_document(
        self,
        content: bytes,
        filename: str,
        dataset_id: str | None = None,
        document_id: str | None = None,
    ) -> DocumentResult:
        """Upload a document to storage.

        Args:
            content: Document content as bytes.
            filename: Original filename.
            dataset_id: Optional dataset ID for organization.
            document_id: Optional document ID (generated if not provided).

        Returns:
            DocumentResult with ID and storage path.
        """
        if document_id is None:
            document_id = str(uuid4())

        # Extract extension from filename
        ext = ""
        if "." in filename:
            ext = "." + filename.rsplit(".", 1)[-1].lower()

        # Build logical path
        remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}"

        # Upload via storage backend
        self._storage.upload_bytes(content, remote_path, overwrite=True)

        return DocumentResult(
            id=document_id,
            file_path=remote_path,
            filename=filename,
        )

    def download_document(self, remote_path: str) -> bytes:
        """Download a document from storage.

        Args:
            remote_path: Logical path to the document.

        Returns:
            Document content as bytes.
        """
        return self._storage.download_bytes(remote_path)

    def get_document_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get a URL for accessing a document.

        Args:
            remote_path: Logical path to the document.
            expires_in_seconds: URL validity duration.

        Returns:
            Pre-signed URL for document access.
        """
        return self._storage.get_presigned_url(remote_path, expires_in_seconds)

    def document_exists(self, remote_path: str) -> bool:
        """Check if a document exists in storage.

        Args:
            remote_path: Logical path to the document.

        Returns:
            True if document exists.
        """
        return self._storage.exists(remote_path)

    def delete_document_files(self, remote_path: str) -> bool:
        """Delete a document from storage.

        Args:
            remote_path: Logical path to the document.

        Returns:
            True if document was deleted.
        """
        return self._storage.delete(remote_path)

    def save_page_image(
        self,
        document_id: str,
        page_num: int,
        content: bytes,
    ) -> str:
        """Save a page image to storage.

        Args:
            document_id: Document ID.
            page_num: Page number (1-indexed).
            content: Image content as bytes.

        Returns:
            Logical path where image was stored.
        """
        remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
        self._storage.upload_bytes(content, remote_path, overwrite=True)
        return remote_path

    def get_page_image_url(
        self,
        document_id: str,
        page_num: int,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get a URL for accessing a page image.

        Args:
            document_id: Document ID.
            page_num: Page number (1-indexed).
            expires_in_seconds: URL validity duration.

        Returns:
            Pre-signed URL for image access.
        """
        remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
        return self._storage.get_presigned_url(remote_path, expires_in_seconds)

    def get_page_image(self, document_id: str, page_num: int) -> bytes:
        """Download a page image from storage.

        Args:
            document_id: Document ID.
            page_num: Page number (1-indexed).

        Returns:
            Image content as bytes.
        """
        remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
        return self._storage.download_bytes(remote_path)

    def delete_document_images(self, document_id: str) -> int:
        """Delete all images for a document.

        Args:
            document_id: Document ID.

        Returns:
            Number of images deleted.
        """
        prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
        image_paths = self._storage.list_files(prefix)

        deleted_count = 0
        for path in image_paths:
            if self._storage.delete(path):
                deleted_count += 1

        return deleted_count

    def list_document_images(self, document_id: str) -> list[str]:
        """List all images for a document.

        Args:
            document_id: Document ID.

        Returns:
            List of image paths.
        """
        prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
        return self._storage.list_files(prefix)
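A short usage sketch of the service above. It assumes `get_storage_backend()` (the factory imported by storage_helpers.py elsewhere in this commit) returns a configured backend; the PDF bytes are placeholders, and presigned URLs are mainly meaningful on cloud backends.

# Illustrative only: uploading and serving a document through DocumentService.
from shared.storage import get_storage_backend
from inference.web.services.document_service import DocumentService

service = DocumentService(storage_backend=get_storage_backend())

result = service.upload_document(content=b"%PDF-1.7 ...", filename="sample.pdf")
print(result.id, result.file_path)   # e.g. "<uuid>", "documents/<uuid>.pdf"

if service.document_exists(result.file_path):
    url = service.get_document_url(result.file_path, expires_in_seconds=600)

service.delete_document_files(result.file_path)
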
@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image

from inference.web.services.storage_helpers import get_storage_helper

if TYPE_CHECKING:
    from .config import ModelConfig, StorageConfig

@@ -303,12 +305,19 @@ class InferenceService:
        """Save visualization image with detections."""
        from ultralytics import YOLO

        # Get storage helper for results directory
        storage = get_storage_helper()
        results_dir = storage.get_results_base_path()
        if results_dir is None:
            logger.warning("Cannot save visualization: local storage not available")
            return None

        # Load model and run prediction with visualization
        model = YOLO(str(self.model_config.model_path))
        results = model.predict(str(image_path), verbose=False)

        # Save annotated image
        output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
        output_path = results_dir / f"{doc_id}_result.png"
        for r in results:
            r.save(filename=str(output_path))

@@ -320,19 +329,26 @@ class InferenceService:
        from ultralytics import YOLO
        import io

        # Get storage helper for results directory
        storage = get_storage_helper()
        results_dir = storage.get_results_base_path()
        if results_dir is None:
            logger.warning("Cannot save visualization: local storage not available")
            return None

        # Render first page
        for page_no, image_bytes in render_pdf_to_images(
            pdf_path, dpi=self.model_config.dpi
        ):
            image = Image.open(io.BytesIO(image_bytes))
            temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
            temp_path = results_dir / f"{doc_id}_temp.png"
            image.save(temp_path)

            # Run YOLO and save visualization
            model = YOLO(str(self.model_config.model_path))
            results = model.predict(str(temp_path), verbose=False)

            output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
            output_path = results_dir / f"{doc_id}_result.png"
            for r in results:
                r.save(filename=str(output_path))

packages/inference/inference/web/services/storage_helpers.py (new file, 830 lines)
@@ -0,0 +1,830 @@
"""
Storage helpers for web services.

Provides convenience functions for common storage operations,
wrapping the storage backend with proper path handling using prefixes.
"""

from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4

from shared.storage import PREFIXES, get_storage_backend
from shared.storage.local import LocalStorageBackend

if TYPE_CHECKING:
    from shared.storage.base import StorageBackend


def get_default_storage() -> "StorageBackend":
    """Get the default storage backend.

    Returns:
        Configured StorageBackend instance.
    """
    return get_storage_backend()


class StorageHelper:
    """Helper class for storage operations with prefixes.

    Provides high-level operations for document storage, including
    upload, download, and URL generation with proper path prefixes.
    """

    def __init__(self, storage: "StorageBackend | None" = None) -> None:
        """Initialize storage helper.

        Args:
            storage: Storage backend to use. If None, creates default.
        """
        self._storage = storage or get_default_storage()

    @property
    def storage(self) -> "StorageBackend":
        """Get the underlying storage backend."""
        return self._storage

    # Document operations

    def upload_document(
        self,
        content: bytes,
        filename: str,
        document_id: str | None = None,
    ) -> tuple[str, str]:
        """Upload a document to storage.

        Args:
            content: Document content as bytes.
            filename: Original filename (used for extension).
            document_id: Optional document ID. Generated if not provided.

        Returns:
            Tuple of (document_id, storage_path).
        """
        if document_id is None:
            document_id = str(uuid4())

        ext = Path(filename).suffix.lower() or ".pdf"
        path = PREFIXES.document_path(document_id, ext)
        self._storage.upload_bytes(content, path, overwrite=True)

        return document_id, path

    def download_document(self, document_id: str, extension: str = ".pdf") -> bytes:
        """Download a document from storage.

        Args:
            document_id: Document identifier.
            extension: File extension.

        Returns:
            Document content as bytes.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.download_bytes(path)

    def get_document_url(
        self,
        document_id: str,
        extension: str = ".pdf",
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for a document.

        Args:
            document_id: Document identifier.
            extension: File extension.
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def document_exists(self, document_id: str, extension: str = ".pdf") -> bool:
        """Check if a document exists.

        Args:
            document_id: Document identifier.
            extension: File extension.

        Returns:
            True if document exists.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.exists(path)

    def delete_document(self, document_id: str, extension: str = ".pdf") -> bool:
        """Delete a document.

        Args:
            document_id: Document identifier.
            extension: File extension.

        Returns:
            True if document was deleted.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.delete(path)

    # Image operations

    def save_page_image(
        self,
        document_id: str,
        page_num: int,
        content: bytes,
    ) -> str:
        """Save a page image to storage.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            content: Image content as bytes.

        Returns:
            Storage path where image was saved.
        """
        path = PREFIXES.image_path(document_id, page_num)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_page_image(self, document_id: str, page_num: int) -> bytes:
        """Download a page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Image content as bytes.
        """
        path = PREFIXES.image_path(document_id, page_num)
        return self._storage.download_bytes(path)

    def get_page_image_url(
        self,
        document_id: str,
        page_num: int,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for a page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.image_path(document_id, page_num)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def delete_document_images(self, document_id: str) -> int:
        """Delete all images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            Number of images deleted.
        """
        prefix = f"{PREFIXES.IMAGES}/{document_id}/"
        images = self._storage.list_files(prefix)
        deleted = 0
        for img_path in images:
            if self._storage.delete(img_path):
                deleted += 1
        return deleted

    def list_document_images(self, document_id: str) -> list[str]:
        """List all images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            List of image paths.
        """
        prefix = f"{PREFIXES.IMAGES}/{document_id}/"
        return self._storage.list_files(prefix)

    # Upload staging operations

    def save_upload(
        self,
        content: bytes,
        filename: str,
        subfolder: str | None = None,
    ) -> str:
        """Save a file to upload staging area.

        Args:
            content: File content as bytes.
            filename: Filename to save as.
            subfolder: Optional subfolder (e.g., "async").

        Returns:
            Storage path where file was saved.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_upload(self, filename: str, subfolder: str | None = None) -> bytes:
        """Get a file from upload staging area.

        Args:
            filename: Filename to retrieve.
            subfolder: Optional subfolder.

        Returns:
            File content as bytes.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        return self._storage.download_bytes(path)

    def delete_upload(self, filename: str, subfolder: str | None = None) -> bool:
        """Delete a file from upload staging area.

        Args:
            filename: Filename to delete.
            subfolder: Optional subfolder.

        Returns:
            True if file was deleted.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        return self._storage.delete(path)

    # Result operations

    def save_result(self, content: bytes, filename: str) -> str:
        """Save a result file.

        Args:
            content: File content as bytes.
            filename: Filename to save as.

        Returns:
            Storage path where file was saved.
        """
        path = PREFIXES.result_path(filename)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_result(self, filename: str) -> bytes:
        """Get a result file.

        Args:
            filename: Filename to retrieve.

        Returns:
            File content as bytes.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.download_bytes(path)

    def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str:
        """Get presigned URL for a result file.

        Args:
            filename: Filename.
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def result_exists(self, filename: str) -> bool:
        """Check if a result file exists.

        Args:
            filename: Filename to check.

        Returns:
            True if file exists.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.exists(path)

    def delete_result(self, filename: str) -> bool:
        """Delete a result file.

        Args:
            filename: Filename to delete.

        Returns:
            True if file was deleted.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.delete(path)

    # Export operations

    def save_export(self, content: bytes, export_id: str, filename: str) -> str:
        """Save an export file.

        Args:
            content: File content as bytes.
            export_id: Export identifier.
            filename: Filename to save as.

        Returns:
            Storage path where file was saved.
        """
        path = PREFIXES.export_path(export_id, filename)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_export_url(
        self,
        export_id: str,
        filename: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for an export file.

        Args:
            export_id: Export identifier.
            filename: Filename.
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.export_path(export_id, filename)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    # Admin image operations

    def get_admin_image_path(self, document_id: str, page_num: int) -> str:
        """Get the storage path for an admin image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Storage path like "admin_images/doc123/page_1.png"
        """
        return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png"

    def save_admin_image(
        self,
        document_id: str,
        page_num: int,
        content: bytes,
    ) -> str:
        """Save an admin page image to storage.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            content: Image content as bytes.

        Returns:
            Storage path where image was saved.
        """
        path = self.get_admin_image_path(document_id, page_num)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_admin_image(self, document_id: str, page_num: int) -> bytes:
        """Download an admin page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Image content as bytes.
        """
        path = self.get_admin_image_path(document_id, page_num)
        return self._storage.download_bytes(path)

    def get_admin_image_url(
        self,
        document_id: str,
        page_num: int,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for an admin page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = self.get_admin_image_path(document_id, page_num)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def admin_image_exists(self, document_id: str, page_num: int) -> bool:
        """Check if an admin page image exists.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            True if image exists.
        """
        path = self.get_admin_image_path(document_id, page_num)
        return self._storage.exists(path)

    def list_admin_images(self, document_id: str) -> list[str]:
        """List all admin images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            List of image paths.
        """
        prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
        return self._storage.list_files(prefix)

    def delete_admin_images(self, document_id: str) -> int:
        """Delete all admin images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            Number of images deleted.
        """
        prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
        images = self._storage.list_files(prefix)
        deleted = 0
        for img_path in images:
            if self._storage.delete(img_path):
                deleted += 1
        return deleted

    def get_admin_image_local_path(
        self, document_id: str, page_num: int
    ) -> Path | None:
        """Get the local filesystem path for an admin image.

        This method is useful for serving files via FileResponse.
        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            # Cloud storage - cannot get local path
            return None

        remote_path = self.get_admin_image_path(document_id, page_num)
        try:
            full_path = self._storage._get_full_path(remote_path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_admin_image_dimensions(
        self, document_id: str, page_num: int
    ) -> tuple[int, int] | None:
        """Get the dimensions (width, height) of an admin image.

        This method is useful for normalizing bounding box coordinates.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Tuple of (width, height) if image exists, None otherwise.
        """
        from PIL import Image

        # Try local path first for efficiency
        local_path = self.get_admin_image_local_path(document_id, page_num)
        if local_path is not None:
            with Image.open(local_path) as img:
                return img.size

        # Fall back to downloading for cloud storage
        if not self.admin_image_exists(document_id, page_num):
            return None

        try:
            import io
            image_bytes = self.get_admin_image(document_id, page_num)
            with Image.open(io.BytesIO(image_bytes)) as img:
                return img.size
        except Exception:
            return None

    # Raw PDF operations (legacy compatibility)

    def save_raw_pdf(self, content: bytes, filename: str) -> str:
        """Save a raw PDF for auto-labeling pipeline.

        Args:
            content: PDF content as bytes.
            filename: Filename to save as.

        Returns:
            Storage path where file was saved.
        """
        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_raw_pdf(self, filename: str) -> bytes:
        """Get a raw PDF from storage.

        Args:
            filename: Filename to retrieve.

        Returns:
            PDF content as bytes.
        """
        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        return self._storage.download_bytes(path)

    def raw_pdf_exists(self, filename: str) -> bool:
        """Check if a raw PDF exists.

        Args:
            filename: Filename to check.

        Returns:
            True if file exists.
        """
        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        return self._storage.exists(path)

    def get_raw_pdf_local_path(self, filename: str) -> Path | None:
        """Get the local filesystem path for a raw PDF.

        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            filename: Filename to retrieve.

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        try:
            full_path = self._storage._get_full_path(path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_raw_pdf_path(self, filename: str) -> str:
        """Get the storage path for a raw PDF (not the local filesystem path).

        Args:
            filename: Filename.

        Returns:
            Storage path like "raw_pdfs/filename.pdf"
        """
        return f"{PREFIXES.RAW_PDFS}/{filename}"

    # Result local path operations

    def get_result_local_path(self, filename: str) -> Path | None:
        """Get the local filesystem path for a result file.

        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            filename: Filename to retrieve.

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        path = PREFIXES.result_path(filename)
        try:
            full_path = self._storage._get_full_path(path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_results_base_path(self) -> Path | None:
        """Get the base directory path for results (local storage only).

        Used for mounting static file directories.

        Returns:
            Path to results directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.RESULTS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    # Upload local path operations

    def get_upload_local_path(
        self, filename: str, subfolder: str | None = None
    ) -> Path | None:
        """Get the local filesystem path for an upload file.

        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            filename: Filename to retrieve.
            subfolder: Optional subfolder.

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        path = PREFIXES.upload_path(filename, subfolder)
        try:
            full_path = self._storage._get_full_path(path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None:
        """Get the base directory path for uploads (local storage only).

        Args:
            subfolder: Optional subfolder (e.g., "async").

        Returns:
            Path to uploads directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            if subfolder:
                base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}")
            else:
                base_path = self._storage._get_full_path(PREFIXES.UPLOADS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def upload_exists(self, filename: str, subfolder: str | None = None) -> bool:
        """Check if an upload file exists.

        Args:
            filename: Filename to check.
            subfolder: Optional subfolder.

        Returns:
            True if file exists.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        return self._storage.exists(path)

    # Dataset operations

    def get_datasets_base_path(self) -> Path | None:
        """Get the base directory path for datasets (local storage only).

        Returns:
            Path to datasets directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.DATASETS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_admin_images_base_path(self) -> Path | None:
        """Get the base directory path for admin images (local storage only).

        Returns:
            Path to admin_images directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_raw_pdfs_base_path(self) -> Path | None:
        """Get the base directory path for raw PDFs (local storage only).

        Returns:
            Path to raw_pdfs directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_autolabel_output_path(self) -> Path | None:
        """Get the directory path for autolabel output (local storage only).

        Returns:
            Path to autolabel_output directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            # Use a subfolder under results for autolabel output
            base_path = self._storage._get_full_path("autolabel_output")
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_training_data_path(self) -> Path | None:
        """Get the directory path for training data exports (local storage only).

        Returns:
            Path to training directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path("training")
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_exports_base_path(self) -> Path | None:
        """Get the base directory path for exports (local storage only).

        Returns:
            Path to exports directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.EXPORTS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None


# Default instance for convenience
_default_helper: StorageHelper | None = None


def get_storage_helper() -> StorageHelper:
    """Get the default storage helper instance.

    Creates the helper on first call with default storage backend.

    Returns:
        Default StorageHelper instance.
    """
    global _default_helper
    if _default_helper is None:
        _default_helper = StorageHelper()
    return _default_helper
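A short usage sketch of the module-level singleton above. The document ID and byte payloads are placeholders; with a cloud backend the local-path helpers return None, which is the documented fallback.

# Illustrative only: exercising the default StorageHelper singleton.
from inference.web.services.storage_helpers import get_storage_helper

storage = get_storage_helper()          # created once, reused on later calls

doc_id, doc_path = storage.upload_document(b"%PDF-1.7 ...", "invoice.pdf")
storage.save_admin_image(doc_id, page_num=1, content=b"...png bytes...")

dims = storage.get_admin_image_dimensions(doc_id, page_num=1)  # (width, height) or None

# Local backend: a real Path suitable for FileResponse-style serving.
# Cloud backend: None, so callers fall back to get_admin_image_url(...).
local_path = storage.get_admin_image_local_path(doc_id, page_num=1)
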