re-structure

Yaojia Wang
2026-02-01 22:55:31 +01:00
parent 400b12a967
commit b602d0a340
176 changed files with 856 additions and 853 deletions

@@ -0,0 +1,25 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1 libglib2.0-0 libpq-dev gcc \
&& rm -rf /var/lib/apt/lists/*
# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared
# Install inference package
COPY packages/inference /app/packages/inference
RUN pip install --no-cache-dir -e /app/packages/inference
# Copy frontend (if needed)
COPY frontend /app/frontend
WORKDIR /app/packages/inference
EXPOSE 8000
CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]

@@ -0,0 +1,105 @@
"""Trigger training jobs on Azure Container Instances."""
import logging
import os
logger = logging.getLogger(__name__)
# Azure SDK is optional; only needed if using ACI trigger
try:
from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
EnvironmentVariable,
GpuResource,
ResourceRequests,
ResourceRequirements,
)
_AZURE_SDK_AVAILABLE = True
except ImportError:
_AZURE_SDK_AVAILABLE = False
def start_training_container(task_id: str) -> str | None:
"""
Start an Azure Container Instance for a training task.
Returns the container group name if successful, None otherwise.
Requires environment variables:
AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
"""
if not _AZURE_SDK_AVAILABLE:
logger.warning(
"Azure SDK not installed. Install azure-mgmt-containerinstance "
"and azure-identity to use ACI trigger."
)
return None
subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
image = os.environ.get(
"AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
)
gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
location = os.environ.get("AZURE_LOCATION", "eastus")
if not subscription_id:
logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
return None
credential = DefaultAzureCredential()
client = ContainerInstanceManagementClient(credential, subscription_id)
container_name = f"training-{task_id[:8]}"
env_vars = [
EnvironmentVariable(name="TASK_ID", value=task_id),
]
# Pass DB connection securely
for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
val = os.environ.get(var, "")
if val:
env_vars.append(EnvironmentVariable(name=var, value=val))
db_password = os.environ.get("DB_PASSWORD", "")
if db_password:
env_vars.append(
EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
)
container = Container(
name=container_name,
image=image,
resources=ResourceRequirements(
requests=ResourceRequests(
cpu=4,
memory_in_gb=16,
gpu=GpuResource(count=1, sku=gpu_sku),
)
),
environment_variables=env_vars,
command=[
"python",
"run_training.py",
"--task-id",
task_id,
],
)
group = ContainerGroup(
location=location,
containers=[container],
os_type="Linux",
restart_policy="Never",
)
logger.info("Creating ACI container group: %s", container_name)
client.container_groups.begin_create_or_update(
resource_group, container_name, group
)
return container_name
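
The ACI call is fire-and-forget: begin_create_or_update returns as soon as the container group request is accepted, so the caller only learns whether the group was created, not whether training succeeded. A minimal usage sketch (the import path is hypothetical; the environment variables are the ones documented above):

import uuid
from backend.training.aci_trigger import start_training_container  # hypothetical import path

task_id = str(uuid.uuid4())
group_name = start_training_container(task_id)
if group_name is None:
    print("Azure SDK missing or AZURE_SUBSCRIPTION_ID not set; ACI trigger skipped")
else:
    print(f"Started container group {group_name} for task {task_id}")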

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Inference CLI
Runs inference on new PDFs to extract invoice data.
"""
import argparse
import json
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI
def main():
parser = argparse.ArgumentParser(
description='Extract invoice data from PDFs using trained model'
)
parser.add_argument(
'--model', '-m',
required=True,
help='Path to trained YOLO model (.pt file)'
)
parser.add_argument(
'--input', '-i',
required=True,
help='Input PDF file or directory'
)
parser.add_argument(
'--output', '-o',
help='Output JSON file (default: stdout)'
)
parser.add_argument(
'--confidence',
type=float,
default=0.5,
help='Detection confidence threshold (default: 0.5)'
)
parser.add_argument(
'--dpi',
type=int,
default=DEFAULT_DPI,
help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
)
parser.add_argument(
'--no-fallback',
action='store_true',
help='Disable fallback OCR'
)
parser.add_argument(
'--lang',
default='en',
help='OCR language (default: en)'
)
parser.add_argument(
'--gpu',
action='store_true',
help='Use GPU'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Verbose output'
)
args = parser.parse_args()
# Validate model
model_path = Path(args.model)
if not model_path.exists():
print(f"Error: Model not found: {model_path}", file=sys.stderr)
sys.exit(1)
# Get input files
input_path = Path(args.input)
if input_path.is_file():
pdf_files = [input_path]
elif input_path.is_dir():
pdf_files = list(input_path.glob('*.pdf'))
else:
print(f"Error: Input not found: {input_path}", file=sys.stderr)
sys.exit(1)
if not pdf_files:
print("Error: No PDF files found", file=sys.stderr)
sys.exit(1)
if args.verbose:
print(f"Processing {len(pdf_files)} PDF file(s)")
print(f"Model: {model_path}")
from backend.pipeline import InferencePipeline
# Initialize pipeline
pipeline = InferencePipeline(
model_path=model_path,
confidence_threshold=args.confidence,
ocr_lang=args.lang,
use_gpu=args.gpu,
dpi=args.dpi,
enable_fallback=not args.no_fallback
)
# Process files
results = []
for pdf_path in pdf_files:
if args.verbose:
print(f"Processing: {pdf_path.name}")
result = pipeline.process_pdf(pdf_path)
results.append(result.to_json())
if args.verbose:
print(f" Success: {result.success}")
print(f" Fields: {len(result.fields)}")
if result.fallback_used:
print(f" Fallback used: Yes")
if result.errors:
print(f" Errors: {result.errors}")
# Output results
if len(results) == 1:
output = results[0]
else:
output = results
json_output = json.dumps(output, indent=2, ensure_ascii=False)
if args.output:
with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output)
if args.verbose:
print(f"\nResults written to: {args.output}")
else:
print(json_output)
if __name__ == '__main__':
main()
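
The same pipeline the CLI wires up can be driven directly from Python. A minimal sketch using only the constructor arguments and result attributes the CLI shows above (model and PDF paths are placeholders):

from pathlib import Path
from backend.pipeline import InferencePipeline
from shared.config import DEFAULT_DPI

pipeline = InferencePipeline(
    model_path=Path("runs/train/invoice_fields/weights/best.pt"),
    confidence_threshold=0.5,
    ocr_lang="en",
    use_gpu=False,
    dpi=DEFAULT_DPI,
    enable_fallback=True,
)
result = pipeline.process_pdf(Path("invoice.pdf"))
print(result.success, len(result.fields))
print(result.to_json())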

@@ -0,0 +1,159 @@
"""
Web Server CLI
Command-line interface for starting the web server.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
# Add project root to path so shared.* imports resolve when running from source
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from shared.config import DEFAULT_DPI
def setup_logging(debug: bool = False) -> None:
"""Configure logging."""
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Start the Invoice Field Extraction web server",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to bind to",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to listen on",
)
parser.add_argument(
"--model",
"-m",
type=Path,
default=Path("runs/train/invoice_fields/weights/best.pt"),
help="Path to YOLO model weights",
)
parser.add_argument(
"--confidence",
type=float,
default=0.5,
help="Detection confidence threshold",
)
parser.add_argument(
"--dpi",
type=int,
default=DEFAULT_DPI,
help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
)
parser.add_argument(
"--no-gpu",
action="store_true",
help="Disable GPU acceleration",
)
parser.add_argument(
"--reload",
action="store_true",
help="Enable auto-reload for development",
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="Number of worker processes",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode",
)
return parser.parse_args()
def main() -> None:
"""Main entry point."""
args = parse_args()
setup_logging(debug=args.debug)
logger = logging.getLogger(__name__)
# Validate model path
if not args.model.exists():
logger.error(f"Model file not found: {args.model}")
sys.exit(1)
logger.info("=" * 60)
logger.info("Invoice Field Extraction Web Server")
logger.info("=" * 60)
logger.info(f"Model: {args.model}")
logger.info(f"Confidence threshold: {args.confidence}")
logger.info(f"GPU enabled: {not args.no_gpu}")
logger.info(f"Server: http://{args.host}:{args.port}")
logger.info("=" * 60)
# Create config
from backend.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig
config = AppConfig(
model=ModelConfig(
model_path=args.model,
confidence_threshold=args.confidence,
use_gpu=not args.no_gpu,
dpi=args.dpi,
),
server=ServerConfig(
host=args.host,
port=args.port,
debug=args.debug,
reload=args.reload,
workers=args.workers,
),
file=FileConfig(),
)
# Create and run app
import uvicorn
from backend.web.app import create_app
app = create_app(config)
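    # NOTE: uvicorn requires the app as an import string (e.g. "pkg.module:app")
    # to enable reload or workers > 1; with an app instance it warns and exits.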
uvicorn.run(
app,
host=config.server.host,
port=config.server.port,
reload=config.server.reload,
workers=config.server.workers if not config.server.reload else 1,
log_level="debug" if config.server.debug else "info",
)
if __name__ == "__main__":
main()

@@ -0,0 +1,437 @@
"""
Admin API SQLModel Database Models
Defines the database schema for admin document management, annotations, and training tasks.
Includes batch upload support, training document links, and annotation history.
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
# Import field mappings from single source of truth
from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS
# =============================================================================
# Core Models
# =============================================================================
class AdminToken(SQLModel, table=True):
"""Admin authentication token."""
__tablename__ = "admin_tokens"
token: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
expires_at: datetime | None = Field(default=None)
class AdminDocument(SQLModel, table=True):
"""Document uploaded for labeling/annotation."""
__tablename__ = "admin_documents"
document_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
file_path: str = Field(max_length=512) # Path to stored file
page_count: int = Field(default=1)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, auto_labeling, labeled, exported
auto_label_status: str | None = Field(default=None, max_length=20)
# Auto-label status: running, completed, failed
auto_label_error: str | None = Field(default=None)
# v2: Upload source tracking
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
group_key: str | None = Field(default=None, max_length=255, index=True)
# User-defined grouping key for document organization
category: str = Field(default="invoice", max_length=100, index=True)
# Document category for training different models (e.g., invoice, letter, receipt)
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
# When auto-label was queued
annotation_lock_until: datetime | None = Field(default=None)
# Lock for manual annotation while auto-label runs
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class AdminAnnotation(SQLModel, table=True):
"""Annotation for a document (bounding box + label)."""
__tablename__ = "admin_annotations"
annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
page_number: int = Field(default=1) # 1-indexed
class_id: int # 0-9 for invoice fields
class_name: str = Field(max_length=50) # e.g., "invoice_number"
# Bounding box (normalized 0-1 coordinates)
x_center: float
y_center: float
width: float
height: float
# Original pixel coordinates (for display)
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
# OCR extracted text (if available)
text_value: str | None = Field(default=None)
confidence: float | None = Field(default=None)
# Source: manual, auto, imported
source: str = Field(default="manual", max_length=20, index=True)
# v2: Verification fields
is_verified: bool = Field(default=False, index=True)
verified_at: datetime | None = Field(default=None)
verified_by: str | None = Field(default=None, max_length=255)
# v2: Override tracking
override_source: str | None = Field(default=None, max_length=20)
# If this annotation overrides another: 'auto' or 'imported'
original_annotation_id: UUID | None = Field(default=None)
# Reference to the annotation this overrides
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingTask(SQLModel, table=True):
"""Training/fine-tuning task."""
__tablename__ = "training_tasks"
task_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, scheduled, running, completed, failed, cancelled
task_type: str = Field(default="train", max_length=20)
# Task type: train, finetune
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
# Training configuration
config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Schedule settings
scheduled_at: datetime | None = Field(default=None)
cron_expression: str | None = Field(default=None, max_length=50)
is_recurring: bool = Field(default=False)
# Execution details
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
error_message: str | None = Field(default=None)
# Result metrics
result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
model_path: str | None = Field(default=None, max_length=512)
# v2: Document count and extracted metrics
document_count: int = Field(default=0)
# Count of documents used in training
metrics_mAP: float | None = Field(default=None, index=True)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
# Extracted metrics for easy querying
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingLog(SQLModel, table=True):
"""Training log entry."""
__tablename__ = "training_logs"
log_id: int | None = Field(default=None, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
level: str = Field(max_length=20) # INFO, WARNING, ERROR
message: str
details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# =============================================================================
# Batch Upload Models (v2)
# =============================================================================
class BatchUpload(SQLModel, table=True):
"""Batch upload of multiple documents via ZIP file."""
__tablename__ = "batch_uploads"
batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255) # ZIP filename
file_size: int
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
status: str = Field(default="processing", max_length=20, index=True)
# Status: processing, completed, partial, failed
total_files: int = Field(default=0)
processed_files: int = Field(default=0)
# Number of files processed so far
successful_files: int = Field(default=0)
failed_files: int = Field(default=0)
csv_filename: str | None = Field(default=None, max_length=255)
# CSV file used for auto-labeling
csv_row_count: int | None = Field(default=None)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
completed_at: datetime | None = Field(default=None)
class BatchUploadFile(SQLModel, table=True):
"""Individual file within a batch upload."""
__tablename__ = "batch_upload_files"
file_id: UUID = Field(default_factory=uuid4, primary_key=True)
batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
filename: str = Field(max_length=255) # PDF filename within ZIP
document_id: UUID | None = Field(default=None)
# Link to created AdminDocument (if successful)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, processing, completed, failed, skipped
error_message: str | None = Field(default=None)
annotation_count: int = Field(default=0)
# Number of annotations created for this file
csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# CSV row data for this file (if available)
created_at: datetime = Field(default_factory=datetime.utcnow)
processed_at: datetime | None = Field(default=None)
# =============================================================================
# Training Document Link (v2)
# =============================================================================
class TrainingDataset(SQLModel, table=True):
"""Training dataset containing selected documents with train/val/test splits."""
__tablename__ = "training_datasets"
dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="building", max_length=20, index=True)
# Status: building, ready, trained, archived, failed
training_status: str | None = Field(default=None, max_length=20, index=True)
# Training status: pending, scheduled, running, completed, failed, cancelled
active_training_task_id: UUID | None = Field(default=None, index=True)
train_ratio: float = Field(default=0.8)
val_ratio: float = Field(default=0.1)
seed: int = Field(default=42)
total_documents: int = Field(default=0)
total_images: int = Field(default=0)
total_annotations: int = Field(default=0)
dataset_path: str | None = Field(default=None, max_length=512)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class DatasetDocument(SQLModel, table=True):
"""Junction table linking datasets to documents with split assignment."""
__tablename__ = "dataset_documents"
id: UUID = Field(default_factory=uuid4, primary_key=True)
dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
split: str = Field(max_length=10) # train, val, test
page_count: int = Field(default=0)
annotation_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingDocumentLink(SQLModel, table=True):
"""Junction table linking training tasks to documents."""
__tablename__ = "training_document_links"
link_id: UUID = Field(default_factory=uuid4, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of annotations at training time (includes count, verified count, etc.)
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Model Version Management
# =============================================================================
class ModelVersion(SQLModel, table=True):
"""Model version for inference deployment."""
__tablename__ = "model_versions"
version_id: UUID = Field(default_factory=uuid4, primary_key=True)
version: str = Field(max_length=50, index=True)
# Semantic version e.g., "1.0.0", "2.1.0"
name: str = Field(max_length=255)
description: str | None = Field(default=None)
model_path: str = Field(max_length=512)
# Path to the model weights file
status: str = Field(default="inactive", max_length=20, index=True)
# Status: active, inactive, archived
is_active: bool = Field(default=False, index=True)
# Only one version can be active at a time for inference
# Training association
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
# Training metrics
metrics_mAP: float | None = Field(default=None)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
document_count: int = Field(default=0)
# Number of documents used in training
# Training configuration snapshot
training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of epochs, batch_size, etc.
# File info
file_size: int | None = Field(default=None)
# Model file size in bytes
# Timestamps
trained_at: datetime | None = Field(default=None)
# When training completed
activated_at: datetime | None = Field(default=None)
# When this version was last activated
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================
class AnnotationHistory(SQLModel, table=True):
"""History of annotation changes (for override tracking)."""
__tablename__ = "annotation_history"
history_id: UUID = Field(default_factory=uuid4, primary_key=True)
annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
# Change action: created, updated, deleted, override
action: str = Field(max_length=20, index=True)
# Previous value (for updates/deletes)
previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# New value (for creates/updates)
new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Change metadata
changed_by: str | None = Field(default=None, max_length=255)
# User/token who made the change
change_reason: str | None = Field(default=None)
# Optional reason for change
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields
# This ensures consistency with the trained YOLO model
# Read-only models for API responses
class AdminDocumentRead(SQLModel):
"""Admin document response model."""
document_id: UUID
filename: str
file_size: int
content_type: str
page_count: int
status: str
auto_label_status: str | None
auto_label_error: str | None
category: str = "invoice"
created_at: datetime
updated_at: datetime
class AdminAnnotationRead(SQLModel):
"""Admin annotation response model."""
annotation_id: UUID
document_id: UUID
page_number: int
class_id: int
class_name: str
x_center: float
y_center: float
width: float
height: float
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
text_value: str | None
confidence: float | None
source: str
created_at: datetime
class TrainingTaskRead(SQLModel):
"""Training task response model."""
task_id: UUID
name: str
description: str | None
status: str
task_type: str
config: dict[str, Any] | None
scheduled_at: datetime | None
is_recurring: bool
started_at: datetime | None
completed_at: datetime | None
error_message: str | None
result_metrics: dict[str, Any] | None
model_path: str | None
dataset_id: UUID | None
created_at: datetime
class TrainingDatasetRead(SQLModel):
"""Training dataset response model."""
dataset_id: UUID
name: str
description: str | None
status: str
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
created_at: datetime
updated_at: datetime
class DatasetDocumentRead(SQLModel):
"""Dataset document response model."""
id: UUID
dataset_id: UUID
document_id: UUID
split: str
page_count: int
annotation_count: int
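
A short usage sketch for these models, relying on the session helpers from backend.data.database elsewhere in this commit (file names and sizes are placeholders, and the tables are assumed to already exist):

from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import AdminDocument

with get_session_context() as session:
    doc = AdminDocument(
        filename="invoice_001.pdf",
        file_size=123_456,
        content_type="application/pdf",
        file_path="/data/uploads/invoice_001.pdf",
        page_count=2,
    )
    session.add(doc)
    session.flush()  # issue the INSERT so the new row is visible to queries below
    pending = session.exec(
        select(AdminDocument).where(AdminDocument.status == "pending")
    ).all()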

@@ -0,0 +1,374 @@
"""
Async Request Database Operations
Database interface for async invoice processing requests using SQLModel.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy import func, text
from sqlmodel import Session, select
from backend.data.database import get_session_context, create_db_and_tables, close_engine
from backend.data.models import ApiKey, AsyncRequest, RateLimitEvent
logger = logging.getLogger(__name__)
# Legacy dataclasses for backward compatibility
from dataclasses import dataclass
@dataclass(frozen=True)
class ApiKeyConfig:
"""API key configuration and limits (legacy compatibility)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestDB:
"""Database interface for async processing requests using SQLModel."""
def __init__(self, connection_string: str | None = None) -> None:
# connection_string is kept for backward compatibility but ignored
# SQLModel uses the global engine from database.py
self._initialized = False
def connect(self):
"""Legacy method - returns self for compatibility."""
return self
def close(self) -> None:
"""Close database connections."""
close_engine()
def __enter__(self) -> "AsyncRequestDB":
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
pass # Sessions are managed per-operation
def create_tables(self) -> None:
"""Create async processing tables if they don't exist."""
create_db_and_tables()
self._initialized = True
# ==========================================================================
# API Key Operations
# ==========================================================================
def is_valid_api_key(self, api_key: str) -> bool:
"""Check if API key exists and is active."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
return result is not None and result.is_active is True
def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
"""Get API key configuration and limits."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
if result is None:
return None
return ApiKeyConfig(
api_key=result.api_key,
name=result.name,
is_active=result.is_active,
requests_per_minute=result.requests_per_minute,
max_concurrent_jobs=result.max_concurrent_jobs,
max_file_size_mb=result.max_file_size_mb,
)
def create_api_key(
self,
api_key: str,
name: str,
requests_per_minute: int = 10,
max_concurrent_jobs: int = 3,
max_file_size_mb: int = 50,
) -> None:
"""Create a new API key."""
with get_session_context() as session:
existing = session.get(ApiKey, api_key)
if existing:
existing.name = name
existing.requests_per_minute = requests_per_minute
existing.max_concurrent_jobs = max_concurrent_jobs
existing.max_file_size_mb = max_file_size_mb
session.add(existing)
else:
new_key = ApiKey(
api_key=api_key,
name=name,
requests_per_minute=requests_per_minute,
max_concurrent_jobs=max_concurrent_jobs,
max_file_size_mb=max_file_size_mb,
)
session.add(new_key)
def update_api_key_usage(self, api_key: str) -> None:
"""Update API key last used timestamp and increment total requests."""
with get_session_context() as session:
key = session.get(ApiKey, api_key)
if key:
key.last_used_at = datetime.utcnow()
key.total_requests += 1
session.add(key)
# ==========================================================================
# Async Request Operations
# ==========================================================================
def create_request(
self,
api_key: str,
filename: str,
file_size: int,
content_type: str,
expires_at: datetime,
request_id: str | None = None,
) -> str:
"""Create a new async request."""
with get_session_context() as session:
request = AsyncRequest(
api_key=api_key,
filename=filename,
file_size=file_size,
content_type=content_type,
expires_at=expires_at,
)
if request_id:
request.request_id = UUID(request_id)
session.add(request)
session.flush() # To get the generated ID
return str(request.request_id)
def get_request(self, request_id: str) -> AsyncRequest | None:
"""Get a single async request by ID."""
with get_session_context() as session:
result = session.get(AsyncRequest, UUID(request_id))
if result:
# Detach from session for use outside context
session.expunge(result)
return result
def get_request_by_api_key(
self,
request_id: str,
api_key: str,
) -> AsyncRequest | None:
"""Get a request only if it belongs to the given API key."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.request_id == UUID(request_id),
AsyncRequest.api_key == api_key,
)
result = session.exec(statement).first()
if result:
session.expunge(result)
return result
def update_status(
self,
request_id: str,
status: str,
error_message: str | None = None,
increment_retry: bool = False,
) -> None:
"""Update request status."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = status
if status == "processing":
request.started_at = datetime.utcnow()
if error_message is not None:
request.error_message = error_message
if increment_retry:
request.retry_count += 1
session.add(request)
def complete_request(
self,
request_id: str,
document_id: str,
result: dict[str, Any],
processing_time_ms: float,
visualization_path: str | None = None,
) -> None:
"""Mark request as completed with result."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = "completed"
request.document_id = document_id
request.result = result
request.processing_time_ms = processing_time_ms
request.visualization_path = visualization_path
request.completed_at = datetime.utcnow()
session.add(request)
def get_requests_by_api_key(
self,
api_key: str,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AsyncRequest], int]:
"""Get paginated requests for an API key."""
with get_session_context() as session:
# Count query
count_stmt = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
count_stmt = count_stmt.where(AsyncRequest.status == status)
total = session.exec(count_stmt).one()
# Fetch query
statement = select(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
statement = statement.where(AsyncRequest.status == status)
statement = statement.order_by(AsyncRequest.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
# Detach results from session
for r in results:
session.expunge(r)
return list(results), total
def count_active_jobs(self, api_key: str) -> int:
"""Count active (pending + processing) jobs for an API key."""
with get_session_context() as session:
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key,
AsyncRequest.status.in_(["pending", "processing"]),
)
return session.exec(statement).one()
def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
"""Get pending requests ordered by creation time."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.status == "pending"
).order_by(AsyncRequest.created_at).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_queue_position(self, request_id: str) -> int | None:
"""Get position of a request in the pending queue."""
with get_session_context() as session:
# Get the request's created_at
request = session.get(AsyncRequest, UUID(request_id))
if not request:
return None
# Count pending requests created before this one
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.status == "pending",
AsyncRequest.created_at < request.created_at,
)
count = session.exec(statement).one()
return count + 1 # 1-based position
# ==========================================================================
# Rate Limit Operations
# ==========================================================================
def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
"""Record a rate limit event."""
with get_session_context() as session:
event = RateLimitEvent(
api_key=api_key,
event_type=event_type,
)
session.add(event)
def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
"""Count requests in the last N seconds."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(seconds=seconds)
statement = select(func.count()).select_from(RateLimitEvent).where(
RateLimitEvent.api_key == api_key,
RateLimitEvent.event_type == "request",
RateLimitEvent.created_at > cutoff,
)
return session.exec(statement).one()
# ==========================================================================
# Cleanup Operations
# ==========================================================================
def delete_expired_requests(self) -> int:
"""Delete requests that have expired. Returns count of deleted rows."""
with get_session_context() as session:
now = datetime.utcnow()
statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
expired = session.exec(statement).all()
count = len(expired)
for request in expired:
session.delete(request)
logger.info(f"Deleted {count} expired async requests")
return count
def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
"""Delete rate limit events older than N hours."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(hours=hours)
statement = select(RateLimitEvent).where(
RateLimitEvent.created_at < cutoff
)
old_events = session.exec(statement).all()
count = len(old_events)
for event in old_events:
session.delete(event)
return count
def reset_stale_processing_requests(
self,
stale_minutes: int = 10,
max_retries: int = 3,
) -> int:
"""
Reset requests stuck in 'processing' status.
Requests that have been processing for more than stale_minutes
are considered stale. They are either reset to 'pending' (if under
max_retries) or set to 'failed'.
"""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
reset_count = 0
# Find stale processing requests
statement = select(AsyncRequest).where(
AsyncRequest.status == "processing",
AsyncRequest.started_at < cutoff,
)
stale_requests = session.exec(statement).all()
for request in stale_requests:
if request.retry_count < max_retries:
request.status = "pending"
request.started_at = None
else:
request.status = "failed"
request.error_message = "Processing timeout after max retries"
session.add(request)
reset_count += 1
if reset_count > 0:
logger.warning(f"Reset {reset_count} stale processing requests")
return reset_count
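
A sketch of the intended request lifecycle using only the methods defined above (the module path is hypothetical; key names, sizes, and timings are placeholders):

from datetime import datetime, timedelta
from backend.data.async_requests import AsyncRequestDB  # hypothetical import path

db = AsyncRequestDB()
db.create_tables()
db.create_api_key("demo-key", name="Demo client")
request_id = db.create_request(
    api_key="demo-key",
    filename="invoice.pdf",
    file_size=123_456,
    content_type="application/pdf",
    expires_at=datetime.utcnow() + timedelta(hours=24),
)
db.update_status(request_id, "processing")
db.complete_request(
    request_id,
    document_id="doc-001",
    result={"fields": {}},
    processing_time_ms=1530.0,
)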

@@ -0,0 +1,318 @@
"""
Database Engine and Session Management
Provides SQLModel database engine and session handling.
"""
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine
import sys
from shared.config import get_db_connection_string
logger = logging.getLogger(__name__)
# Global engine instance
_engine = None
def get_engine():
"""Get or create the database engine."""
global _engine
if _engine is None:
connection_string = get_db_connection_string()
# Convert psycopg2 format to SQLAlchemy format
if connection_string.startswith("postgresql://"):
# Already in correct format
pass
elif "host=" in connection_string:
# Convert DSN format to URL format
parts = dict(item.split("=") for item in connection_string.split())
connection_string = (
f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
f"/{parts.get('dbname', 'docmaster')}"
)
_engine = create_engine(
connection_string,
echo=False, # Set to True for SQL debugging
pool_pre_ping=True, # Verify connections before use
pool_size=5,
max_overflow=10,
)
return _engine
def run_migrations() -> None:
"""Run database migrations for new columns."""
engine = get_engine()
migrations = [
# Migration 004: Training datasets tables and dataset_id on training_tasks
(
"training_datasets_tables",
"""
CREATE TABLE IF NOT EXISTS training_datasets (
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(255) NOT NULL,
description TEXT,
status VARCHAR(20) NOT NULL DEFAULT 'building',
train_ratio FLOAT NOT NULL DEFAULT 0.8,
val_ratio FLOAT NOT NULL DEFAULT 0.1,
seed INTEGER NOT NULL DEFAULT 42,
total_documents INTEGER NOT NULL DEFAULT 0,
total_images INTEGER NOT NULL DEFAULT 0,
total_annotations INTEGER NOT NULL DEFAULT 0,
dataset_path VARCHAR(512),
error_message TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
""",
),
(
"dataset_documents_table",
"""
CREATE TABLE IF NOT EXISTS dataset_documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
split VARCHAR(10) NOT NULL,
page_count INTEGER NOT NULL DEFAULT 0,
annotation_count INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
UNIQUE(dataset_id, document_id)
);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
""",
),
(
"training_tasks_dataset_id",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
""",
),
# Migration 005: Add group_key to admin_documents
(
"admin_documents_group_key",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
""",
),
# Migration 006: Model versions table
(
"model_versions_table",
"""
CREATE TABLE IF NOT EXISTS model_versions (
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
version VARCHAR(50) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
model_path VARCHAR(512) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
is_active BOOLEAN NOT NULL DEFAULT FALSE,
task_id UUID REFERENCES training_tasks(task_id),
dataset_id UUID REFERENCES training_datasets(dataset_id),
metrics_mAP FLOAT,
metrics_precision FLOAT,
metrics_recall FLOAT,
document_count INTEGER NOT NULL DEFAULT 0,
training_config JSONB,
file_size BIGINT,
trained_at TIMESTAMP WITH TIME ZONE,
activated_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
""",
),
# Migration 009: Add category to admin_documents
(
"admin_documents_category",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
""",
),
# Migration 010: Add training_status and active_training_task_id to training_datasets
(
"training_datasets_training_status",
"""
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status);
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id);
""",
),
# Migration 010b: Update existing datasets with completed training to 'trained' status
(
"training_datasets_update_trained_status",
"""
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
);
""",
),
# Migration 007: Add extra columns to training_tasks
(
"training_tasks_name",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
""",
),
(
"training_tasks_description",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
""",
),
(
"training_tasks_admin_token",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
""",
),
(
"training_tasks_task_type",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
""",
),
(
"training_tasks_recurring",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
""",
),
(
"training_tasks_metrics",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
""",
),
(
"training_tasks_updated_at",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
""",
),
# Migration 008: Fix model_versions foreign key constraints
(
"model_versions_fk_fix",
"""
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_dataset_id_fkey
FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_task_id_fkey
FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
""",
),
# Migration 006b: Ensure only one active model at a time
(
"model_versions_single_active",
"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
ON model_versions(is_active) WHERE is_active = TRUE;
""",
),
]
with engine.connect() as conn:
for name, sql in migrations:
try:
conn.execute(text(sql))
conn.commit()
logger.info(f"Migration '{name}' applied successfully")
except Exception as e:
# Log but don't fail - column may already exist
logger.debug(f"Migration '{name}' skipped or failed: {e}")
def create_db_and_tables() -> None:
"""Create all database tables."""
from backend.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
from backend.data.admin_models import ( # noqa: F401
AdminToken,
AdminDocument,
AdminAnnotation,
TrainingTask,
TrainingLog,
)
engine = get_engine()
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
# Run migrations for new columns
run_migrations()
def get_session() -> Session:
"""Get a new database session."""
engine = get_engine()
return Session(engine)
@contextmanager
def get_session_context() -> Generator[Session, None, None]:
"""Context manager for database sessions with auto-commit/rollback."""
session = get_session()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def close_engine() -> None:
"""Close the database engine and release connections."""
global _engine
if _engine is not None:
_engine.dispose()
_engine = None
logger.info("Database engine closed")
def execute_raw_sql(sql: str) -> None:
"""Execute raw SQL (for migrations)."""
engine = get_engine()
with engine.connect() as conn:
conn.execute(text(sql))
conn.commit()
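
Typical lifecycle, as a sketch: create the schema once at startup (which also applies the idempotent migrations above), do scoped work through get_session_context, and dispose the pool on shutdown:

from backend.data.database import (
    close_engine,
    create_db_and_tables,
    execute_raw_sql,
    get_session_context,
)

create_db_and_tables()
execute_raw_sql("SELECT 1")  # sanity-check the connection
with get_session_context() as session:
    # ORM work here is committed on success and rolled back on any exception
    ...
close_engine()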

@@ -0,0 +1,95 @@
"""
SQLModel Database Models
Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
class ApiKey(SQLModel, table=True):
"""API key configuration and limits."""
__tablename__ = "api_keys"
api_key: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
requests_per_minute: int = Field(default=10)
max_concurrent_jobs: int = Field(default=3)
max_file_size_mb: int = Field(default=50)
total_requests: int = Field(default=0)
total_processed: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
class AsyncRequest(SQLModel, table=True):
"""Async request record."""
__tablename__ = "async_requests"
request_id: UUID = Field(default_factory=uuid4, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
status: str = Field(default="pending", max_length=20, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
document_id: str | None = Field(default=None, max_length=100)
error_message: str | None = Field(default=None)
retry_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
expires_at: datetime = Field(index=True)
result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
processing_time_ms: float | None = Field(default=None)
visualization_path: str | None = Field(default=None, max_length=255)
class RateLimitEvent(SQLModel, table=True):
"""Rate limit event record."""
__tablename__ = "rate_limit_events"
id: int | None = Field(default=None, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
event_type: str = Field(max_length=50)
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Read-only models for responses (without table=True)
class ApiKeyRead(SQLModel):
"""API key response model (read-only)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestRead(SQLModel):
"""Async request response model (read-only)."""
request_id: UUID
api_key: str
status: str
filename: str
file_size: int
content_type: str
document_id: str | None
error_message: str | None
retry_count: int
created_at: datetime
started_at: datetime | None
completed_at: datetime | None
expires_at: datetime
result: dict[str, Any] | None
processing_time_ms: float | None
visualization_path: str | None
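
The *Read classes are plain (non-table) SQLModel models intended for API responses. Assuming SQLModel's default from_attributes configuration, a table row can be converted like this (sketch; to_read_model is illustrative):

from backend.data.models import AsyncRequest, AsyncRequestRead

def to_read_model(req: AsyncRequest) -> AsyncRequestRead:
    # Copies matching attributes; table/session state is left behind.
    return AsyncRequestRead.model_validate(req)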

@@ -0,0 +1,26 @@
"""
Repository Pattern Implementation
Provides domain-specific repository classes to replace the monolithic AdminDB.
Each repository handles a single domain following Single Responsibility Principle.
"""
from backend.data.repositories.base import BaseRepository
from backend.data.repositories.token_repository import TokenRepository
from backend.data.repositories.document_repository import DocumentRepository
from backend.data.repositories.annotation_repository import AnnotationRepository
from backend.data.repositories.training_task_repository import TrainingTaskRepository
from backend.data.repositories.dataset_repository import DatasetRepository
from backend.data.repositories.model_version_repository import ModelVersionRepository
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
__all__ = [
"BaseRepository",
"TokenRepository",
"DocumentRepository",
"AnnotationRepository",
"TrainingTaskRepository",
"DatasetRepository",
"ModelVersionRepository",
"BatchUploadRepository",
]

@@ -0,0 +1,357 @@
"""
Annotation Repository
Handles annotation operations following Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import AdminAnnotation, AnnotationHistory
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class AnnotationRepository(BaseRepository[AdminAnnotation]):
"""Repository for annotation management.
Handles:
- Annotation CRUD operations
- Batch annotation creation
- Annotation verification
- Annotation override tracking
"""
def create(
self,
document_id: str,
page_number: int,
class_id: int,
class_name: str,
x_center: float,
y_center: float,
width: float,
height: float,
bbox_x: int,
bbox_y: int,
bbox_width: int,
bbox_height: int,
text_value: str | None = None,
confidence: float | None = None,
source: str = "manual",
) -> str:
"""Create a new annotation.
Returns:
Annotation ID as string
"""
with get_session_context() as session:
annotation = AdminAnnotation(
document_id=UUID(document_id),
page_number=page_number,
class_id=class_id,
class_name=class_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=bbox_x,
bbox_y=bbox_y,
bbox_width=bbox_width,
bbox_height=bbox_height,
text_value=text_value,
confidence=confidence,
source=source,
)
session.add(annotation)
session.flush()
return str(annotation.annotation_id)
def create_batch(
self,
annotations: list[dict[str, Any]],
) -> list[str]:
"""Create multiple annotations in a batch.
Args:
annotations: List of annotation data dicts
Returns:
List of annotation IDs
"""
with get_session_context() as session:
ids = []
for ann_data in annotations:
annotation = AdminAnnotation(
document_id=UUID(ann_data["document_id"]),
page_number=ann_data.get("page_number", 1),
class_id=ann_data["class_id"],
class_name=ann_data["class_name"],
x_center=ann_data["x_center"],
y_center=ann_data["y_center"],
width=ann_data["width"],
height=ann_data["height"],
bbox_x=ann_data["bbox_x"],
bbox_y=ann_data["bbox_y"],
bbox_width=ann_data["bbox_width"],
bbox_height=ann_data["bbox_height"],
text_value=ann_data.get("text_value"),
confidence=ann_data.get("confidence"),
source=ann_data.get("source", "auto"),
)
session.add(annotation)
session.flush()
ids.append(str(annotation.annotation_id))
return ids
def get(self, annotation_id: str) -> AdminAnnotation | None:
"""Get an annotation by ID."""
with get_session_context() as session:
result = session.get(AdminAnnotation, UUID(annotation_id))
if result:
session.expunge(result)
return result
def get_for_document(
self,
document_id: str,
page_number: int | None = None,
) -> list[AdminAnnotation]:
"""Get all annotations for a document."""
with get_session_context() as session:
statement = select(AdminAnnotation).where(
AdminAnnotation.document_id == UUID(document_id)
)
if page_number is not None:
statement = statement.where(AdminAnnotation.page_number == page_number)
statement = statement.order_by(AdminAnnotation.class_id)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def update(
self,
annotation_id: str,
x_center: float | None = None,
y_center: float | None = None,
width: float | None = None,
height: float | None = None,
bbox_x: int | None = None,
bbox_y: int | None = None,
bbox_width: int | None = None,
bbox_height: int | None = None,
text_value: str | None = None,
class_id: int | None = None,
class_name: str | None = None,
) -> bool:
"""Update an annotation.
Returns:
True if updated, False if not found
"""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if annotation:
if x_center is not None:
annotation.x_center = x_center
if y_center is not None:
annotation.y_center = y_center
if width is not None:
annotation.width = width
if height is not None:
annotation.height = height
if bbox_x is not None:
annotation.bbox_x = bbox_x
if bbox_y is not None:
annotation.bbox_y = bbox_y
if bbox_width is not None:
annotation.bbox_width = bbox_width
if bbox_height is not None:
annotation.bbox_height = bbox_height
if text_value is not None:
annotation.text_value = text_value
if class_id is not None:
annotation.class_id = class_id
if class_name is not None:
annotation.class_name = class_name
annotation.updated_at = datetime.utcnow()
session.add(annotation)
return True
return False
def delete(self, annotation_id: str) -> bool:
"""Delete an annotation."""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if annotation:
session.delete(annotation)
session.commit()
return True
return False
def delete_for_document(
self,
document_id: str,
source: str | None = None,
) -> int:
"""Delete all annotations for a document.
Returns:
Count of deleted annotations
"""
with get_session_context() as session:
statement = select(AdminAnnotation).where(
AdminAnnotation.document_id == UUID(document_id)
)
if source:
statement = statement.where(AdminAnnotation.source == source)
annotations = session.exec(statement).all()
count = len(annotations)
for ann in annotations:
session.delete(ann)
session.commit()
return count
def verify(
self,
annotation_id: str,
admin_token: str,
) -> AdminAnnotation | None:
"""Mark an annotation as verified."""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if not annotation:
return None
annotation.is_verified = True
annotation.verified_at = datetime.utcnow()
annotation.verified_by = admin_token
annotation.updated_at = datetime.utcnow()
session.add(annotation)
session.commit()
session.refresh(annotation)
session.expunge(annotation)
return annotation
def override(
self,
annotation_id: str,
admin_token: str,
change_reason: str | None = None,
**updates: Any,
) -> AdminAnnotation | None:
"""Override an auto-generated annotation.
Creates a history record and updates the annotation.
"""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if not annotation:
return None
previous_value = {
"class_id": annotation.class_id,
"class_name": annotation.class_name,
"bbox": {
"x": annotation.bbox_x,
"y": annotation.bbox_y,
"width": annotation.bbox_width,
"height": annotation.bbox_height,
},
"normalized": {
"x_center": annotation.x_center,
"y_center": annotation.y_center,
"width": annotation.width,
"height": annotation.height,
},
"text_value": annotation.text_value,
"confidence": annotation.confidence,
"source": annotation.source,
}
for key, value in updates.items():
if hasattr(annotation, key):
setattr(annotation, key, value)
if annotation.source == "auto":
annotation.override_source = "auto"
annotation.source = "manual"
annotation.updated_at = datetime.utcnow()
session.add(annotation)
history = AnnotationHistory(
annotation_id=UUID(annotation_id),
document_id=annotation.document_id,
action="override",
previous_value=previous_value,
new_value=updates,
changed_by=admin_token,
change_reason=change_reason,
)
session.add(history)
session.commit()
session.refresh(annotation)
session.expunge(annotation)
return annotation
def create_history(
self,
annotation_id: UUID,
document_id: UUID,
action: str,
previous_value: dict[str, Any] | None = None,
new_value: dict[str, Any] | None = None,
changed_by: str | None = None,
change_reason: str | None = None,
) -> AnnotationHistory:
"""Create an annotation history record."""
with get_session_context() as session:
history = AnnotationHistory(
annotation_id=annotation_id,
document_id=document_id,
action=action,
previous_value=previous_value,
new_value=new_value,
changed_by=changed_by,
change_reason=change_reason,
)
session.add(history)
session.commit()
session.refresh(history)
session.expunge(history)
return history
def get_history(self, annotation_id: UUID) -> list[AnnotationHistory]:
"""Get history for a specific annotation."""
with get_session_context() as session:
statement = select(AnnotationHistory).where(
AnnotationHistory.annotation_id == annotation_id
).order_by(AnnotationHistory.created_at.desc())
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_document_history(self, document_id: UUID) -> list[AnnotationHistory]:
"""Get all annotation history for a document."""
with get_session_context() as session:
statement = select(AnnotationHistory).where(
AnnotationHistory.document_id == document_id
).order_by(AnnotationHistory.created_at.desc())
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
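
The override flow above is the one worth seeing end to end: the previous values are snapshotted into AnnotationHistory before the update lands. A minimal sketch, assuming these methods sit on an AnnotationRepository class exported by this module (the class name, import path, and the placeholder UUID are assumptions):

from uuid import UUID

from backend.data.repositories.annotation import AnnotationRepository  # import path assumed

repo = AnnotationRepository()
annotation_id = "00000000-0000-0000-0000-000000000001"  # placeholder

# Override an auto-generated bbox; source flips from "auto" to "manual"
# and a history record captures the previous geometry and text.
updated = repo.override(
    annotation_id,
    admin_token="admin-1",
    change_reason="Tightened bbox around the invoice number",
    bbox_x=120.0,
    bbox_y=80.0,
    bbox_width=210.0,
    bbox_height=32.0,
)

if updated is not None:
    # History is returned newest-first.
    for entry in repo.get_history(UUID(annotation_id)):
        print(entry.action, entry.changed_by, entry.change_reason)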

View File

@@ -0,0 +1,75 @@
"""
Base Repository
Provides common functionality for all repositories.
"""
import logging
from abc import ABC
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Generator, TypeVar, Generic
from uuid import UUID
from sqlmodel import Session
from backend.data.database import get_session_context
logger = logging.getLogger(__name__)
T = TypeVar("T")
class BaseRepository(ABC, Generic[T]):
"""Base class for all repositories.
Provides:
- Session management via context manager
- Logging infrastructure
- Common query patterns
- Utility methods for datetime and UUID handling
"""
@contextmanager
def _session(self) -> Generator[Session, None, None]:
"""Get a database session with auto-commit/rollback."""
with get_session_context() as session:
yield session
def _expunge(self, session: Session, entity: T) -> T:
"""Detach entity from session for safe return."""
session.expunge(entity)
return entity
def _expunge_all(self, session: Session, entities: list[T]) -> list[T]:
"""Detach multiple entities from session."""
for entity in entities:
session.expunge(entity)
return entities
@staticmethod
def _now() -> datetime:
"""Get current UTC time as timezone-aware datetime.
Use this instead of datetime.utcnow() which is deprecated in Python 3.12+.
"""
return datetime.now(timezone.utc)
@staticmethod
def _validate_uuid(value: str, field_name: str = "id") -> UUID:
"""Validate and convert string to UUID.
Args:
value: String to convert to UUID
field_name: Name of field for error message
Returns:
Validated UUID
Raises:
ValueError: If value is not a valid UUID
"""
try:
return UUID(value)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid {field_name}: {value}") from e

View File

@@ -0,0 +1,136 @@
"""
Batch Upload Repository
Handles batch upload operations following Single Responsibility Principle.
"""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import BatchUpload, BatchUploadFile
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class BatchUploadRepository(BaseRepository[BatchUpload]):
"""Repository for batch upload management.
Handles:
- Batch upload CRUD operations
- Batch file tracking
- Progress monitoring
"""
def create(
self,
admin_token: str,
filename: str,
file_size: int,
upload_source: str = "ui",
) -> BatchUpload:
"""Create a new batch upload record."""
with get_session_context() as session:
batch = BatchUpload(
admin_token=admin_token,
filename=filename,
file_size=file_size,
upload_source=upload_source,
)
session.add(batch)
session.commit()
session.refresh(batch)
session.expunge(batch)
return batch
def get(self, batch_id: UUID) -> BatchUpload | None:
"""Get batch upload by ID."""
with get_session_context() as session:
result = session.get(BatchUpload, batch_id)
if result:
session.expunge(result)
return result
def update(
self,
batch_id: UUID,
**kwargs: Any,
) -> None:
"""Update batch upload fields."""
with get_session_context() as session:
batch = session.get(BatchUpload, batch_id)
if batch:
for key, value in kwargs.items():
if hasattr(batch, key):
setattr(batch, key, value)
session.add(batch)
def create_file(
self,
batch_id: UUID,
filename: str,
**kwargs: Any,
) -> BatchUploadFile:
"""Create a batch upload file record."""
with get_session_context() as session:
file_record = BatchUploadFile(
batch_id=batch_id,
filename=filename,
**kwargs,
)
session.add(file_record)
session.commit()
session.refresh(file_record)
session.expunge(file_record)
return file_record
def update_file(
self,
file_id: UUID,
**kwargs: Any,
) -> None:
"""Update batch upload file fields."""
with get_session_context() as session:
file_record = session.get(BatchUploadFile, file_id)
if file_record:
for key, value in kwargs.items():
if hasattr(file_record, key):
setattr(file_record, key, value)
session.add(file_record)
def get_files(self, batch_id: UUID) -> list[BatchUploadFile]:
"""Get all files for a batch upload."""
with get_session_context() as session:
statement = select(BatchUploadFile).where(
BatchUploadFile.batch_id == batch_id
).order_by(BatchUploadFile.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_paginated(
self,
admin_token: str | None = None,
limit: int = 50,
offset: int = 0,
) -> tuple[list[BatchUpload], int]:
"""Get paginated batch uploads."""
with get_session_context() as session:
count_stmt = select(func.count()).select_from(BatchUpload)
total = session.exec(count_stmt).one()
statement = select(BatchUpload).order_by(
BatchUpload.created_at.desc()
).offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
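
A short sketch of the expected call order for a batch upload. The batch_id/file_id attribute names are read off the method signatures above, and the status value passed through **kwargs is an assumption about the BatchUploadFile columns:

from backend.data.repositories.batch_upload import BatchUploadRepository  # import path assumed

repo = BatchUploadRepository()

batch = repo.create(
    admin_token="admin-1",
    filename="invoices_2026_01.zip",
    file_size=52_428_800,
    upload_source="api",
)

# One record per file in the archive; extra columns travel through **kwargs.
record = repo.create_file(batch.batch_id, filename="invoice_001.pdf")
repo.update_file(record.file_id, status="processed")  # "status" column assumed

files = repo.get_files(batch.batch_id)
uploads, total = repo.get_paginated(limit=10)
print(f"{len(files)} file(s) in batch, {total} batch upload(s) total")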

View File

@@ -0,0 +1,216 @@
"""
Dataset Repository
Handles training dataset operations following Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class DatasetRepository(BaseRepository[TrainingDataset]):
"""Repository for training dataset management.
Handles:
- Dataset CRUD operations
- Dataset status management
- Dataset document linking
- Training status tracking
"""
def create(
self,
name: str,
description: str | None = None,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
) -> TrainingDataset:
"""Create a new training dataset."""
with get_session_context() as session:
dataset = TrainingDataset(
name=name,
description=description,
train_ratio=train_ratio,
val_ratio=val_ratio,
seed=seed,
)
session.add(dataset)
session.commit()
session.refresh(dataset)
session.expunge(dataset)
return dataset
def get(self, dataset_id: str | UUID) -> TrainingDataset | None:
"""Get a dataset by ID."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if dataset:
session.expunge(dataset)
return dataset
def get_paginated(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[TrainingDataset], int]:
"""List datasets with optional status filter."""
with get_session_context() as session:
query = select(TrainingDataset)
count_query = select(func.count()).select_from(TrainingDataset)
if status:
query = query.where(TrainingDataset.status == status)
count_query = count_query.where(TrainingDataset.status == status)
total = session.exec(count_query).one()
datasets = session.exec(
query.order_by(TrainingDataset.created_at.desc()).offset(offset).limit(limit)
).all()
for d in datasets:
session.expunge(d)
return list(datasets), total
def get_active_training_tasks(
self, dataset_ids: list[str]
) -> dict[str, dict[str, str]]:
"""Get active training tasks for datasets.
Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
"""
if not dataset_ids:
return {}
valid_uuids = []
for d in dataset_ids:
try:
valid_uuids.append(UUID(d))
except ValueError:
logger.warning("Invalid UUID in get_active_training_tasks: %s", d)
continue
if not valid_uuids:
return {}
with get_session_context() as session:
statement = select(TrainingTask).where(
TrainingTask.dataset_id.in_(valid_uuids),
TrainingTask.status.in_(["pending", "scheduled", "running"]),
)
results = session.exec(statement).all()
return {
str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
for t in results
}
def update_status(
self,
dataset_id: str | UUID,
status: str,
error_message: str | None = None,
total_documents: int | None = None,
total_images: int | None = None,
total_annotations: int | None = None,
dataset_path: str | None = None,
) -> None:
"""Update dataset status and optional totals."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.status = status
dataset.updated_at = datetime.utcnow()
if error_message is not None:
dataset.error_message = error_message
if total_documents is not None:
dataset.total_documents = total_documents
if total_images is not None:
dataset.total_images = total_images
if total_annotations is not None:
dataset.total_annotations = total_annotations
if dataset_path is not None:
dataset.dataset_path = dataset_path
session.add(dataset)
session.commit()
def update_training_status(
self,
dataset_id: str | UUID,
training_status: str | None,
active_training_task_id: str | UUID | None = None,
update_main_status: bool = False,
) -> None:
"""Update dataset training status."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.training_status = training_status
dataset.active_training_task_id = (
UUID(str(active_training_task_id)) if active_training_task_id else None
)
dataset.updated_at = datetime.utcnow()
if update_main_status and training_status == "completed":
dataset.status = "trained"
session.add(dataset)
session.commit()
def add_documents(
self,
dataset_id: str | UUID,
documents: list[dict[str, Any]],
) -> None:
"""Batch insert documents into a dataset.
Each dict: {document_id, split, page_count, annotation_count}
"""
with get_session_context() as session:
for doc in documents:
dd = DatasetDocument(
dataset_id=UUID(str(dataset_id)),
document_id=UUID(str(doc["document_id"])),
split=doc["split"],
page_count=doc.get("page_count", 0),
annotation_count=doc.get("annotation_count", 0),
)
session.add(dd)
session.commit()
def get_documents(self, dataset_id: str | UUID) -> list[DatasetDocument]:
"""Get all documents in a dataset."""
with get_session_context() as session:
results = session.exec(
select(DatasetDocument)
.where(DatasetDocument.dataset_id == UUID(str(dataset_id)))
).all()
for r in results:
session.expunge(r)
return list(results)
def delete(self, dataset_id: str | UUID) -> bool:
"""Delete a dataset and its document links."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return False
# Delete associated document links first
doc_links = session.exec(
select(DatasetDocument).where(
DatasetDocument.dataset_id == UUID(str(dataset_id))
)
).all()
for link in doc_links:
session.delete(link)
session.delete(dataset)
session.commit()
return True
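
A sketch of the dataset lifecycle these methods support: create, link documents with their split, record the build result, then track training. The dataset_id attribute name and the status strings are assumptions based on the fields referenced above:

from uuid import uuid4

from backend.data.repositories.dataset import DatasetRepository  # import path assumed

repo = DatasetRepository()
dataset = repo.create(name="invoices-v3", description="Labeled invoices, Jan 2026")

# Normally these IDs come from DocumentRepository / TrainingTaskRepository.
document_id = uuid4()
task_id = uuid4()

repo.add_documents(
    dataset.dataset_id,
    [{"document_id": document_id, "split": "train", "page_count": 2, "annotation_count": 14}],
)
repo.update_status(
    dataset.dataset_id,
    status="ready",  # status string assumed
    total_documents=1,
    total_images=2,
    total_annotations=14,
)
repo.update_training_status(dataset.dataset_id, "running", active_training_task_id=task_id)

active = repo.get_active_training_tasks([str(dataset.dataset_id)])
print(active)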

View File

@@ -0,0 +1,453 @@
"""
Document Repository
Handles document operations following Single Responsibility Principle.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import AdminDocument, AdminAnnotation
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class DocumentRepository(BaseRepository[AdminDocument]):
"""Repository for document management.
Handles:
- Document CRUD operations
- Document status management
- Document filtering and pagination
- Document category management
"""
def create(
self,
filename: str,
file_size: int,
content_type: str,
file_path: str,
page_count: int = 1,
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
category: str = "invoice",
admin_token: str | None = None,
) -> str:
"""Create a new document record.
Args:
filename: Original filename
file_size: File size in bytes
content_type: MIME type
file_path: Storage path
page_count: Number of pages
upload_source: Upload source (ui/api)
csv_field_values: CSV field values for reference
group_key: User-defined grouping key
category: Document category
admin_token: Deprecated, kept for compatibility
Returns:
Document ID as string
"""
with get_session_context() as session:
document = AdminDocument(
filename=filename,
file_size=file_size,
content_type=content_type,
file_path=file_path,
page_count=page_count,
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
category=category,
)
session.add(document)
session.flush()
return str(document.document_id)
def get(self, document_id: str) -> AdminDocument | None:
"""Get a document by ID.
Args:
document_id: Document UUID as string
Returns:
AdminDocument if found, None otherwise
"""
with get_session_context() as session:
result = session.get(AdminDocument, UUID(document_id))
if result:
session.expunge(result)
return result
def get_by_token(
self,
document_id: str,
admin_token: str | None = None,
) -> AdminDocument | None:
"""Get a document by ID. Token parameter is deprecated."""
return self.get(document_id)
def get_paginated(
self,
admin_token: str | None = None,
status: str | None = None,
upload_source: str | None = None,
has_annotations: bool | None = None,
auto_label_status: str | None = None,
batch_id: str | None = None,
category: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AdminDocument], int]:
"""Get paginated documents with optional filters.
Args:
admin_token: Deprecated, kept for compatibility
status: Filter by status
upload_source: Filter by upload source
has_annotations: Filter by annotation presence
auto_label_status: Filter by auto-label status
batch_id: Filter by batch ID
category: Filter by category
limit: Page size
offset: Pagination offset
Returns:
Tuple of (documents, total_count)
"""
with get_session_context() as session:
where_clauses = []
if status:
where_clauses.append(AdminDocument.status == status)
if upload_source:
where_clauses.append(AdminDocument.upload_source == upload_source)
if auto_label_status:
where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
if batch_id:
where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
if category:
where_clauses.append(AdminDocument.category == category)
count_stmt = select(func.count()).select_from(AdminDocument)
if where_clauses:
count_stmt = count_stmt.where(*where_clauses)
            if has_annotations is not None:
                if has_annotations:
                    # Count distinct documents that have at least one annotation;
                    # a join + group_by here would yield one row per document and
                    # break the .one() call below.
                    count_stmt = (
                        select(func.count(func.distinct(AdminDocument.document_id)))
                        .select_from(AdminDocument)
                        .join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
                    )
                    if where_clauses:
                        count_stmt = count_stmt.where(*where_clauses)
                else:
                    count_stmt = (
                        count_stmt
                        .outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
                        .where(AdminAnnotation.annotation_id.is_(None))
                    )
total = session.exec(count_stmt).one()
statement = select(AdminDocument)
if where_clauses:
statement = statement.where(*where_clauses)
if has_annotations is not None:
if has_annotations:
statement = (
statement
.join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
.group_by(AdminDocument.document_id)
)
else:
statement = (
statement
.outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
.where(AdminAnnotation.annotation_id.is_(None))
)
statement = statement.order_by(AdminDocument.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
def update_status(
self,
document_id: str,
status: str,
auto_label_status: str | None = None,
auto_label_error: str | None = None,
) -> None:
"""Update document status.
Args:
document_id: Document UUID as string
status: New status
auto_label_status: Auto-label status
auto_label_error: Auto-label error message
"""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.status = status
document.updated_at = datetime.now(timezone.utc)
if auto_label_status is not None:
document.auto_label_status = auto_label_status
if auto_label_error is not None:
document.auto_label_error = auto_label_error
session.add(document)
def update_file_path(self, document_id: str, file_path: str) -> None:
"""Update document file path."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.file_path = file_path
document.updated_at = datetime.now(timezone.utc)
session.add(document)
def update_group_key(self, document_id: str, group_key: str | None) -> bool:
"""Update document group key."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.group_key = group_key
document.updated_at = datetime.now(timezone.utc)
session.add(document)
return True
return False
def update_category(self, document_id: str, category: str) -> AdminDocument | None:
"""Update document category."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.category = category
document.updated_at = datetime.now(timezone.utc)
session.add(document)
session.commit()
session.refresh(document)
return document
return None
def delete(self, document_id: str) -> bool:
"""Delete a document and its annotations.
Args:
document_id: Document UUID as string
Returns:
True if deleted, False if not found
"""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
ann_stmt = select(AdminAnnotation).where(
AdminAnnotation.document_id == UUID(document_id)
)
annotations = session.exec(ann_stmt).all()
for ann in annotations:
session.delete(ann)
session.delete(document)
session.commit()
return True
return False
def get_categories(self) -> list[str]:
"""Get list of unique document categories."""
with get_session_context() as session:
statement = (
select(AdminDocument.category)
.distinct()
.order_by(AdminDocument.category)
)
categories = session.exec(statement).all()
return [c for c in categories if c is not None]
def get_labeled_for_export(
self,
admin_token: str | None = None,
) -> list[AdminDocument]:
"""Get all labeled documents ready for export."""
with get_session_context() as session:
statement = select(AdminDocument).where(
AdminDocument.status == "labeled"
)
if admin_token:
statement = statement.where(AdminDocument.admin_token == admin_token)
statement = statement.order_by(AdminDocument.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def count_by_status(
self,
admin_token: str | None = None,
) -> dict[str, int]:
"""Count documents by status."""
with get_session_context() as session:
statement = select(
AdminDocument.status,
func.count(AdminDocument.document_id),
).group_by(AdminDocument.status)
results = session.exec(statement).all()
return {status: count for status, count in results}
def get_by_ids(self, document_ids: list[str]) -> list[AdminDocument]:
"""Get documents by list of IDs."""
with get_session_context() as session:
uuids = [UUID(str(did)) for did in document_ids]
results = session.exec(
select(AdminDocument).where(AdminDocument.document_id.in_(uuids))
).all()
for r in results:
session.expunge(r)
return list(results)
def get_for_training(
self,
admin_token: str | None = None,
status: str = "labeled",
has_annotations: bool = True,
min_annotation_count: int | None = None,
exclude_used_in_training: bool = False,
limit: int = 100,
offset: int = 0,
) -> tuple[list[AdminDocument], int]:
"""Get documents suitable for training with filtering."""
from backend.data.admin_models import TrainingDocumentLink
with get_session_context() as session:
statement = select(AdminDocument).where(
AdminDocument.status == status,
)
if has_annotations or min_annotation_count:
annotation_subq = (
select(func.count(AdminAnnotation.annotation_id))
.where(AdminAnnotation.document_id == AdminDocument.document_id)
.correlate(AdminDocument)
.scalar_subquery()
)
if has_annotations:
statement = statement.where(annotation_subq > 0)
if min_annotation_count:
statement = statement.where(annotation_subq >= min_annotation_count)
if exclude_used_in_training:
from sqlalchemy import exists
training_subq = exists(
select(1)
.select_from(TrainingDocumentLink)
.where(TrainingDocumentLink.document_id == AdminDocument.document_id)
)
statement = statement.where(~training_subq)
count_statement = select(func.count()).select_from(statement.subquery())
total = session.exec(count_statement).one()
statement = statement.order_by(AdminDocument.created_at.desc())
statement = statement.limit(limit).offset(offset)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
def acquire_annotation_lock(
self,
document_id: str,
admin_token: str | None = None,
duration_seconds: int = 300,
) -> AdminDocument | None:
"""Acquire annotation lock for a document."""
from datetime import timedelta
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if not doc:
return None
now = datetime.now(timezone.utc)
lock_until = doc.annotation_lock_until
# Handle PostgreSQL returning offset-naive datetimes
if lock_until and lock_until.tzinfo is None:
lock_until = lock_until.replace(tzinfo=timezone.utc)
if lock_until and lock_until > now:
return None
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
session.add(doc)
session.commit()
session.refresh(doc)
session.expunge(doc)
return doc
def release_annotation_lock(
self,
document_id: str,
admin_token: str | None = None,
force: bool = False,
) -> AdminDocument | None:
"""Release annotation lock for a document."""
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if not doc:
return None
doc.annotation_lock_until = None
session.add(doc)
session.commit()
session.refresh(doc)
session.expunge(doc)
return doc
def extend_annotation_lock(
self,
document_id: str,
admin_token: str | None = None,
additional_seconds: int = 300,
) -> AdminDocument | None:
"""Extend an existing annotation lock."""
from datetime import timedelta
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if not doc:
return None
now = datetime.now(timezone.utc)
lock_until = doc.annotation_lock_until
# Handle PostgreSQL returning offset-naive datetimes
if lock_until and lock_until.tzinfo is None:
lock_until = lock_until.replace(tzinfo=timezone.utc)
if not lock_until or lock_until <= now:
return None
doc.annotation_lock_until = lock_until + timedelta(seconds=additional_seconds)
session.add(doc)
session.commit()
session.refresh(doc)
session.expunge(doc)
return doc
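
The lock helpers at the end are cooperative: acquire_annotation_lock returns None when the document is missing or another session still holds an unexpired lock. A minimal sketch of the intended lifecycle (the document ID is a placeholder):

from backend.data.repositories.document import DocumentRepository  # import path assumed

repo = DocumentRepository()
document_id = "00000000-0000-0000-0000-000000000002"  # placeholder

doc = repo.acquire_annotation_lock(document_id, duration_seconds=300)
if doc is None:
    print("Document not found or locked by another annotator")
else:
    try:
        # ... edit annotations; extend before the lock expires ...
        repo.extend_annotation_lock(document_id, additional_seconds=300)
    finally:
        repo.release_annotation_lock(document_id)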

View File

@@ -0,0 +1,200 @@
"""
Model Version Repository
Handles model version operations following Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import ModelVersion
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class ModelVersionRepository(BaseRepository[ModelVersion]):
"""Repository for model version management.
Handles:
- Model version CRUD operations
- Model activation/deactivation
- Active model resolution
"""
def create(
self,
version: str,
name: str,
model_path: str,
description: str | None = None,
task_id: str | UUID | None = None,
dataset_id: str | UUID | None = None,
metrics_mAP: float | None = None,
metrics_precision: float | None = None,
metrics_recall: float | None = None,
document_count: int = 0,
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
model = ModelVersion(
version=version,
name=name,
model_path=model_path,
description=description,
task_id=UUID(str(task_id)) if task_id else None,
dataset_id=UUID(str(dataset_id)) if dataset_id else None,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
)
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def get(self, version_id: str | UUID) -> ModelVersion | None:
"""Get a model version by ID."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if model:
session.expunge(model)
return model
def get_paginated(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[ModelVersion], int]:
"""List model versions with optional status filter."""
with get_session_context() as session:
query = select(ModelVersion)
count_query = select(func.count()).select_from(ModelVersion)
if status:
query = query.where(ModelVersion.status == status)
count_query = count_query.where(ModelVersion.status == status)
total = session.exec(count_query).one()
models = session.exec(
query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
).all()
for m in models:
session.expunge(m)
return list(models), total
def get_active(self) -> ModelVersion | None:
"""Get the currently active model version for inference."""
with get_session_context() as session:
result = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).first()
if result:
session.expunge(result)
return result
def activate(self, version_id: str | UUID) -> ModelVersion | None:
"""Activate a model version for inference (deactivates all others)."""
with get_session_context() as session:
all_versions = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).all()
for v in all_versions:
v.is_active = False
v.status = "inactive"
v.updated_at = datetime.utcnow()
session.add(v)
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = True
model.status = "active"
model.activated_at = datetime.utcnow()
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def deactivate(self, version_id: str | UUID) -> ModelVersion | None:
"""Deactivate a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = False
model.status = "inactive"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def update(
self,
version_id: str | UUID,
name: str | None = None,
description: str | None = None,
status: str | None = None,
) -> ModelVersion | None:
"""Update model version metadata."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
if name is not None:
model.name = name
if description is not None:
model.description = description
if status is not None:
model.status = status
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def archive(self, version_id: str | UUID) -> ModelVersion | None:
"""Archive a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
if model.is_active:
return None
model.status = "archived"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def delete(self, version_id: str | UUID) -> bool:
"""Delete a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return False
if model.is_active:
return False
session.delete(model)
session.commit()
return True
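
Promoting a freshly trained model might look like the sketch below. activate() deactivates every other version, so at most one version serves inference; the version_id attribute name and the metric values are assumptions:

from backend.data.repositories.model_version import ModelVersionRepository  # import path assumed

repo = ModelVersionRepository()

model = repo.create(
    version="2026.02.01",
    name="yolo-invoices",
    model_path="/models/yolo-invoices-2026-02-01.pt",
    metrics_mAP=0.91,
    metrics_precision=0.93,
    metrics_recall=0.88,
    document_count=1200,
)

repo.activate(model.version_id)  # all other versions become inactive
active = repo.get_active()       # what the inference service should load
print(active.model_path if active else "no active model")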

View File

@@ -0,0 +1,117 @@
"""
Token Repository
Handles admin token operations following Single Responsibility Principle.
"""
import logging
from datetime import datetime
from backend.data.admin_models import AdminToken
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class TokenRepository(BaseRepository[AdminToken]):
"""Repository for admin token management.
Handles:
- Token validation (active status, expiration)
- Token CRUD operations
- Usage tracking
"""
def is_valid(self, token: str) -> bool:
"""Check if admin token exists and is active.
Args:
token: The token string to validate
Returns:
True if token exists, is active, and not expired
"""
with self._session() as session:
result = session.get(AdminToken, token)
if result is None:
return False
if not result.is_active:
return False
if result.expires_at and result.expires_at < self._now():
return False
return True
def get(self, token: str) -> AdminToken | None:
"""Get admin token details.
Args:
token: The token string
Returns:
AdminToken if found, None otherwise
"""
with self._session() as session:
result = session.get(AdminToken, token)
if result:
session.expunge(result)
return result
def create(
self,
token: str,
name: str,
expires_at: datetime | None = None,
) -> None:
"""Create or update an admin token.
If token exists, updates name, expires_at, and reactivates it.
Otherwise creates a new token.
Args:
token: The token string
name: Display name for the token
expires_at: Optional expiration datetime
"""
with self._session() as session:
existing = session.get(AdminToken, token)
if existing:
existing.name = name
existing.expires_at = expires_at
existing.is_active = True
session.add(existing)
else:
new_token = AdminToken(
token=token,
name=name,
expires_at=expires_at,
)
session.add(new_token)
def update_usage(self, token: str) -> None:
"""Update admin token last used timestamp.
Args:
token: The token string
"""
with self._session() as session:
admin_token = session.get(AdminToken, token)
if admin_token:
admin_token.last_used_at = self._now()
session.add(admin_token)
def deactivate(self, token: str) -> bool:
"""Deactivate an admin token.
Args:
token: The token string
Returns:
True if token was deactivated, False if not found
"""
with self._session() as session:
admin_token = session.get(AdminToken, token)
if admin_token:
admin_token.is_active = False
session.add(admin_token)
return True
return False
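
A sketch of how an API layer might use these helpers: provision a token with an expiry, validate it per request, record usage, and revoke it later. The token string is a placeholder:

from datetime import datetime, timedelta, timezone

from backend.data.repositories.token import TokenRepository  # import path assumed

repo = TokenRepository()

# Provision (or reactivate) a token that expires in 30 days.
repo.create(
    token="tok_admin_example",  # placeholder value
    name="CI pipeline",
    expires_at=datetime.now(timezone.utc) + timedelta(days=30),
)

# Per-request check: must exist, be active, and not be expired.
if repo.is_valid("tok_admin_example"):
    repo.update_usage("tok_admin_example")
else:
    raise PermissionError("invalid or expired admin token")

# Revoke when no longer needed.
repo.deactivate("tok_admin_example")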

View File

@@ -0,0 +1,249 @@
"""
Training Task Repository
Handles training task operations following Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class TrainingTaskRepository(BaseRepository[TrainingTask]):
"""Repository for training task management.
Handles:
- Training task CRUD operations
- Task status management
- Training logs
- Training document links
"""
def create(
self,
admin_token: str,
name: str,
task_type: str = "train",
description: str | None = None,
config: dict[str, Any] | None = None,
scheduled_at: datetime | None = None,
cron_expression: str | None = None,
is_recurring: bool = False,
dataset_id: str | None = None,
) -> str:
"""Create a new training task.
Returns:
Task ID as string
"""
with get_session_context() as session:
task = TrainingTask(
admin_token=admin_token,
name=name,
task_type=task_type,
description=description,
config=config,
scheduled_at=scheduled_at,
cron_expression=cron_expression,
is_recurring=is_recurring,
status="scheduled" if scheduled_at else "pending",
dataset_id=dataset_id,
)
session.add(task)
session.flush()
return str(task.task_id)
def get(self, task_id: str) -> TrainingTask | None:
"""Get a training task by ID."""
with get_session_context() as session:
result = session.get(TrainingTask, UUID(task_id))
if result:
session.expunge(result)
return result
def get_by_token(
self,
task_id: str,
admin_token: str | None = None,
) -> TrainingTask | None:
"""Get a training task by ID. Token parameter is deprecated."""
return self.get(task_id)
def get_paginated(
self,
admin_token: str | None = None,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[TrainingTask], int]:
"""Get paginated training tasks."""
with get_session_context() as session:
count_stmt = select(func.count()).select_from(TrainingTask)
if status:
count_stmt = count_stmt.where(TrainingTask.status == status)
total = session.exec(count_stmt).one()
statement = select(TrainingTask)
if status:
statement = statement.where(TrainingTask.status == status)
statement = statement.order_by(TrainingTask.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
def get_pending(self) -> list[TrainingTask]:
"""Get pending training tasks ready to run."""
with get_session_context() as session:
now = datetime.utcnow()
statement = select(TrainingTask).where(
TrainingTask.status.in_(["pending", "scheduled"]),
(TrainingTask.scheduled_at == None) | (TrainingTask.scheduled_at <= now),
).order_by(TrainingTask.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_running(self) -> TrainingTask | None:
"""Get currently running training task.
Returns:
Running task or None if no task is running
"""
with get_session_context() as session:
result = session.exec(
select(TrainingTask)
.where(TrainingTask.status == "running")
.order_by(TrainingTask.started_at.desc())
).first()
if result:
session.expunge(result)
return result
def update_status(
self,
task_id: str,
status: str,
error_message: str | None = None,
result_metrics: dict[str, Any] | None = None,
model_path: str | None = None,
) -> None:
"""Update training task status."""
with get_session_context() as session:
task = session.get(TrainingTask, UUID(task_id))
if task:
task.status = status
task.updated_at = datetime.utcnow()
if status == "running":
task.started_at = datetime.utcnow()
elif status in ("completed", "failed"):
task.completed_at = datetime.utcnow()
if error_message is not None:
task.error_message = error_message
if result_metrics is not None:
task.result_metrics = result_metrics
if model_path is not None:
task.model_path = model_path
session.add(task)
def cancel(self, task_id: str) -> bool:
"""Cancel a training task."""
with get_session_context() as session:
task = session.get(TrainingTask, UUID(task_id))
if task and task.status in ("pending", "scheduled"):
task.status = "cancelled"
task.updated_at = datetime.utcnow()
session.add(task)
return True
return False
def add_log(
self,
task_id: str,
level: str,
message: str,
details: dict[str, Any] | None = None,
) -> None:
"""Add a training log entry."""
with get_session_context() as session:
log = TrainingLog(
task_id=UUID(task_id),
level=level,
message=message,
details=details,
)
session.add(log)
def get_logs(
self,
task_id: str,
limit: int = 100,
offset: int = 0,
) -> list[TrainingLog]:
"""Get training logs for a task."""
with get_session_context() as session:
statement = select(TrainingLog).where(
TrainingLog.task_id == UUID(task_id)
).order_by(TrainingLog.created_at.desc()).offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def create_document_link(
self,
task_id: UUID,
document_id: UUID,
annotation_snapshot: dict[str, Any] | None = None,
) -> TrainingDocumentLink:
"""Create a training document link."""
with get_session_context() as session:
link = TrainingDocumentLink(
task_id=task_id,
document_id=document_id,
annotation_snapshot=annotation_snapshot,
)
session.add(link)
session.commit()
session.refresh(link)
session.expunge(link)
return link
def get_document_links(self, task_id: UUID) -> list[TrainingDocumentLink]:
"""Get all document links for a training task."""
with get_session_context() as session:
statement = select(TrainingDocumentLink).where(
TrainingDocumentLink.task_id == task_id
).order_by(TrainingDocumentLink.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_document_training_tasks(self, document_id: UUID) -> list[TrainingDocumentLink]:
"""Get all training tasks that used this document."""
with get_session_context() as session:
statement = select(TrainingDocumentLink).where(
TrainingDocumentLink.document_id == document_id
).order_by(TrainingDocumentLink.created_at.desc())
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
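
A sketch of the lifecycle a worker is expected to drive with these methods: create a pending task, mark it running, stream logs, and record the outcome. The config keys, log payloads, and model path are illustrative:

from backend.data.repositories.training_task import TrainingTaskRepository  # import path assumed

repo = TrainingTaskRepository()

task_id = repo.create(
    admin_token="admin-1",
    name="Nightly fine-tune",
    task_type="train",
    config={"epochs": 50, "batch_size": 16},  # keys are illustrative
)

# Worker side: claim the task, log progress, then record the result.
repo.update_status(task_id, "running")
repo.add_log(task_id, level="INFO", message="epoch 1/50 done", details={"mAP": 0.71})
repo.update_status(
    task_id,
    "completed",
    result_metrics={"mAP": 0.90},
    model_path="/models/nightly.pt",
)

for log in repo.get_logs(task_id, limit=20):
    print(log.level, log.message)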

View File

@@ -0,0 +1,5 @@
from .pipeline import InferencePipeline, InferenceResult
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor
__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']
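
With these re-exports, callers can import the pipeline pieces from the package root instead of the individual modules. The backend.pipeline package name is inferred from imports used elsewhere in this change, and the constructor arguments follow the usage note in the constants module below:

from backend.pipeline import FieldExtractor, YOLODetector  # package path assumed

detector = YOLODetector(model_path="model.pt", confidence_threshold=0.5)
extractor = FieldExtractor(dpi=300, use_enhanced_parsing=True)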

View File

@@ -0,0 +1,101 @@
"""
Inference Configuration Constants
Centralized configuration values for the inference pipeline.
Extracted from hardcoded values across multiple modules for easier maintenance.
"""
# ============================================================================
# Detection & Model Configuration
# ============================================================================
# YOLO Detection
DEFAULT_CONFIDENCE_THRESHOLD = 0.5 # Default confidence threshold for YOLO detection
DEFAULT_IOU_THRESHOLD = 0.45 # Default IoU threshold for NMS (Non-Maximum Suppression)
# ============================================================================
# Image Processing Configuration
# ============================================================================
# DPI (Dots Per Inch) for PDF rendering
DEFAULT_DPI = 300 # Standard DPI for PDF to image conversion
DPI_TO_POINTS_SCALE = 72 # PDF points per inch (used for bbox conversion)
# ============================================================================
# Customer Number Parser Configuration
# ============================================================================
# Pattern confidence scores (higher = more confident)
CUSTOMER_NUMBER_CONFIDENCE = {
'labeled': 0.98, # Explicit label (e.g., "Kundnummer: ABC 123-X")
'dash_format': 0.95, # Standard format with dash (e.g., "JTY 576-3")
'no_dash': 0.90, # Format without dash (e.g., "Dwq 211X")
'compact': 0.75, # Compact format (e.g., "JTY5763")
'generic_base': 0.5, # Base score for generic alphanumeric pattern
}
# Bonus scores for generic pattern matching
CUSTOMER_NUMBER_BONUS = {
'has_dash': 0.2, # Bonus if contains dash
'typical_format': 0.25, # Bonus for format XXX NNN-X
'medium_length': 0.1, # Bonus for length 6-12 characters
}
# Customer number length constraints
CUSTOMER_NUMBER_LENGTH = {
'min': 6, # Minimum length for medium length bonus
'max': 12, # Maximum length for medium length bonus
}
# ============================================================================
# Field Extraction Confidence Scores
# ============================================================================
# Confidence multipliers and base scores
FIELD_CONFIDENCE = {
'pdf_text': 1.0, # PDF text extraction (always accurate)
'payment_line_high': 0.95, # Payment line parsed successfully
'regex_fallback': 0.5, # Regex-based fallback extraction
'ocr_penalty': 0.5, # Penalty multiplier when OCR fails
}
# ============================================================================
# Payment Line Validation
# ============================================================================
# Account number length thresholds for type detection
ACCOUNT_TYPE_THRESHOLD = {
'bankgiro_min_length': 7, # Minimum digits for Bankgiro (7-8 digits)
'plusgiro_max_length': 6, # Maximum digits for Plusgiro (typically fewer)
}
# ============================================================================
# OCR Configuration
# ============================================================================
# Minimum OCR reference number length
MIN_OCR_LENGTH = 5 # Minimum length for valid OCR number
# ============================================================================
# Pattern Matching
# ============================================================================
# Swedish postal code pattern (to exclude from customer numbers)
SWEDISH_POSTAL_CODE_PATTERN = r'^SE\s+\d{3}\s*\d{2}'
# ============================================================================
# Usage Notes
# ============================================================================
"""
These constants can be overridden at runtime by passing parameters to
constructors or methods. The values here serve as sensible defaults
based on Swedish invoice processing requirements.
Example:
from backend.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD
detector = YOLODetector(
model_path="model.pt",
confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD # or custom value
)
"""

View File

@@ -0,0 +1,390 @@
"""
Swedish Customer Number Parser
Handles extraction and normalization of Swedish customer numbers.
Uses Strategy Pattern with multiple matching patterns.
Common Swedish customer number formats:
- JTY 576-3
- EMM 256-6
- DWQ 211-X
- FFL 019N
"""
import re
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List
from shared.exceptions import CustomerNumberParseError
@dataclass
class CustomerNumberMatch:
"""Customer number match result."""
value: str
"""The normalized customer number"""
pattern_name: str
"""Name of the pattern that matched"""
confidence: float
"""Confidence score (0.0 to 1.0)"""
raw_text: str
"""Original text that was matched"""
position: int = 0
"""Position in text where match was found"""
class CustomerNumberPattern(ABC):
"""Abstract base for customer number patterns."""
@abstractmethod
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""
Try to match pattern in text.
Args:
text: Text to search for customer number
Returns:
CustomerNumberMatch if found, None otherwise
"""
pass
@abstractmethod
def format(self, match: re.Match) -> str:
"""
Format matched groups to standard format.
Args:
match: Regex match object
Returns:
Formatted customer number string
"""
pass
class DashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123-X (with dash)
Examples: JTY 576-3, EMM 256-6, DWQ 211-X
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with dash format."""
match = self.PATTERN.search(text)
if not match:
return None
# Check if it's not a postal code
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="DashFormat",
confidence=0.95,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
# SE 106 43, SE10643, etc.
return bool(
text.upper().startswith('SE ') and
re.match(r'^SE\s+\d{3}\s*\d{2}', text, re.IGNORECASE)
)
class NoDashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123X (no dash)
Examples: Dwq 211X, FFL 019N
Converts to: DWQ 211-X, FFL 019-N
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number without dash."""
match = self.PATTERN.search(text)
if not match:
return None
# Exclude postal codes
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="NoDashFormat",
confidence=0.90,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format (add dash)."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))
class CompactFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC123X (compact, no spaces)
Examples: JTY5763, FFL019N
"""
PATTERN = re.compile(r'\b([A-Z]{2,4})(\d{3,6})([A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match compact customer number format."""
upper_text = text.upper()
match = self.PATTERN.search(upper_text)
if not match:
return None
# Filter out SE postal codes
if match.group(1) == 'SE':
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="CompactFormat",
confidence=0.75,
raw_text=match.group(0),
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to ABC123X or ABC123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
if suffix:
return f"{prefix} {number}-{suffix}"
else:
return f"{prefix}{number}"
class GenericAlphanumericPattern(CustomerNumberPattern):
"""
Generic pattern: Letters + numbers + optional dash/letter
Examples: EMM 256-6, ABC 123, FFL 019
"""
PATTERN = re.compile(r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]?\d{0,2}[A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match generic alphanumeric pattern."""
upper_text = text.upper()
all_matches = []
for match in self.PATTERN.finditer(upper_text):
matched_text = match.group(1)
# Filter out pure numbers
if re.match(r'^\d+$', matched_text):
continue
# Filter out Swedish postal codes
if re.match(r'^SE[\s\-]*\d', matched_text):
continue
# Filter out single letter + digit + space + digit (V4 2)
if re.match(r'^[A-Z]\d\s+\d$', matched_text):
continue
# Calculate confidence based on characteristics
confidence = self._calculate_confidence(matched_text)
all_matches.append((confidence, matched_text, match.start()))
if all_matches:
# Return highest confidence match
best = max(all_matches, key=lambda x: x[0])
return CustomerNumberMatch(
value=best[1].strip(),
pattern_name="GenericAlphanumeric",
confidence=best[0],
raw_text=best[1],
position=best[2]
)
return None
def format(self, match: re.Match) -> str:
"""Return matched text as-is (already uppercase)."""
return match.group(1).strip()
def _calculate_confidence(self, text: str) -> float:
"""Calculate confidence score based on text characteristics."""
# Require letters AND digits
has_letters = bool(re.search(r'[A-Z]', text, re.IGNORECASE))
has_digits = bool(re.search(r'\d', text))
if not (has_letters and has_digits):
return 0.0 # Not a valid customer number
score = 0.5 # Base score
# Bonus for containing dash
if '-' in text:
score += 0.2
# Bonus for typical format XXX NNN-X
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', text):
score += 0.25
# Bonus for medium length
if 6 <= len(text) <= 12:
score += 0.1
return min(score, 1.0)
class LabeledPattern(CustomerNumberPattern):
"""
Pattern: Explicit label + customer number
Examples:
- "Kundnummer: JTY 576-3"
- "Customer No: EMM 256-6"
"""
PATTERN = re.compile(
r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)'
r'\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
re.IGNORECASE
)
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with explicit label."""
match = self.PATTERN.search(text)
if not match:
return None
extracted = match.group(1).strip()
# Remove trailing punctuation
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
if extracted and len(extracted) >= 2:
return CustomerNumberMatch(
value=extracted.upper(),
pattern_name="Labeled",
confidence=0.98, # Very high confidence when labeled
raw_text=match.group(0),
position=match.start()
)
return None
def format(self, match: re.Match) -> str:
"""Return matched customer number."""
extracted = match.group(1).strip()
return re.sub(r'[\s\.\,\:]+$', '', extracted).upper()
class CustomerNumberParser:
"""Parser for Swedish customer numbers."""
def __init__(self):
"""Initialize parser with patterns ordered by specificity."""
self.patterns: List[CustomerNumberPattern] = [
LabeledPattern(), # Highest priority - explicit label
DashFormatPattern(), # Standard format with dash
NoDashFormatPattern(), # Standard format without dash
CompactFormatPattern(), # Compact format
GenericAlphanumericPattern(), # Fallback generic pattern
]
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
"""
Parse customer number from text.
Args:
text: Text to search for customer number
Returns:
Tuple of (customer_number, is_valid, error_message)
"""
if not text or not text.strip():
return None, False, "Empty text"
text = text.strip()
# Try each pattern
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# No matches
if not all_matches:
return None, False, "No customer number found"
# Return highest confidence match
best_match = max(all_matches, key=lambda m: (m.confidence, m.position))
self.logger.debug(
f"Customer number matched: {best_match.value} "
f"(pattern: {best_match.pattern_name}, confidence: {best_match.confidence:.2f})"
)
return best_match.value, True, None
def parse_all(self, text: str) -> List[CustomerNumberMatch]:
"""
Find all customer numbers in text.
Useful for cases with multiple potential matches.
Args:
text: Text to search
Returns:
List of CustomerNumberMatch sorted by confidence (descending)
"""
if not text or not text.strip():
return []
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# Sort by confidence (highest first), then by position (later first)
return sorted(all_matches, key=lambda m: (m.confidence, m.position), reverse=True)
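
End to end, the parser runs every strategy and keeps the highest-confidence match, so callers only see the final tuple. For example (module path assumed):

from backend.pipeline.customer_number_parser import CustomerNumberParser  # module path assumed

parser = CustomerNumberParser()

value, is_valid, error = parser.parse("Kundnummer: JTY 576-3")
print(value, is_valid, error)   # JTY 576-3 True None

value, _, _ = parser.parse("Faktura till Dwq 211X, Stockholm")
print(value)                    # DWQ 211-X (NoDashFormat adds the dash)

# Inspect every candidate with its pattern and confidence.
for match in parser.parse_all("Ert kundnr: EMM 256-6"):
    print(match.pattern_name, match.value, match.confidence)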

View File

@@ -0,0 +1,630 @@
"""
Field Extractor Module
Extracts and validates field values from detected regions.
This module is used during inference to extract values from OCR text.
It uses shared utilities from shared.utils for text cleaning and validation.
Enhanced features:
- Multi-source fusion with confidence weighting
- Smart amount parsing with multiple strategies
- Enhanced date format unification
- OCR error correction integration
Refactored to use modular normalizers for each field type.
"""
from dataclasses import dataclass, field
from collections import defaultdict
import re
import numpy as np
from PIL import Image
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import Detection
# Import shared utilities for text cleaning and validation
from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections
# Import new unified parsers
from .payment_line_parser import PaymentLineParser
from .customer_number_parser import CustomerNumberParser
# Import normalizers
from .normalizers import (
BaseNormalizer,
NormalizationResult,
create_normalizer_registry,
EnhancedAmountNormalizer,
EnhancedDateNormalizer,
)
@dataclass
class ExtractedField:
"""Represents an extracted field value."""
field_name: str
raw_text: str
normalized_value: str | None
confidence: float
detection_confidence: float
ocr_confidence: float
bbox: tuple[float, float, float, float]
page_no: int
is_valid: bool = True
validation_error: str | None = None
# Multi-source fusion fields
alternative_values: list[tuple[str, float]] = field(default_factory=list) # [(value, confidence), ...]
extraction_method: str = 'single' # 'single', 'fused', 'corrected'
ocr_corrections_applied: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary."""
result = {
'field_name': self.field_name,
'value': self.normalized_value,
'raw_text': self.raw_text,
'confidence': self.confidence,
'bbox': list(self.bbox),
'page_no': self.page_no,
'is_valid': self.is_valid,
'validation_error': self.validation_error
}
if self.alternative_values:
result['alternatives'] = self.alternative_values
if self.extraction_method != 'single':
result['extraction_method'] = self.extraction_method
return result
class FieldExtractor:
"""Extracts field values from detected regions using OCR or PDF text."""
def __init__(
self,
ocr_lang: str = 'en',
use_gpu: bool = False,
bbox_padding: float = 0.1,
dpi: int = 300,
use_enhanced_parsing: bool = False
):
"""
Initialize field extractor.
Args:
ocr_lang: Language for OCR
use_gpu: Whether to use GPU for OCR
bbox_padding: Padding to add around bboxes (as fraction)
dpi: DPI used for rendering (for coordinate conversion)
use_enhanced_parsing: Whether to use enhanced normalizers
"""
self.ocr_lang = ocr_lang
self.use_gpu = use_gpu
self.bbox_padding = bbox_padding
self.dpi = dpi
self._ocr_engine = None # Lazy init
self.use_enhanced_parsing = use_enhanced_parsing
# Initialize new unified parsers
self.payment_line_parser = PaymentLineParser()
self.customer_number_parser = CustomerNumberParser()
# Initialize normalizer registry
self._normalizers = create_normalizer_registry(use_enhanced=use_enhanced_parsing)
@property
def ocr_engine(self):
"""Lazy-load OCR engine only when needed."""
if self._ocr_engine is None:
from shared.ocr import OCREngine
self._ocr_engine = OCREngine(lang=self.ocr_lang)
return self._ocr_engine
def extract_from_detection_with_pdf(
self,
detection: Detection,
pdf_tokens: list,
image_width: int,
image_height: int
) -> ExtractedField:
"""
Extract field value using PDF text tokens (faster and more accurate for text PDFs).
Args:
detection: Detection object with bbox in pixel coordinates
pdf_tokens: List of Token objects from PDF text extraction
image_width: Width of rendered image in pixels
image_height: Height of rendered image in pixels
Returns:
ExtractedField object
"""
# Convert detection bbox from pixels to PDF points
scale = 72 / self.dpi # points per pixel
x0_pdf = detection.bbox[0] * scale
y0_pdf = detection.bbox[1] * scale
x1_pdf = detection.bbox[2] * scale
y1_pdf = detection.bbox[3] * scale
# Add padding in points
pad = 3 # Small padding in points
# Find tokens that overlap with detection bbox
matching_tokens = []
for token in pdf_tokens:
if token.page_no != detection.page_no:
continue
tx0, ty0, tx1, ty1 = token.bbox
# Check overlap
if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and
ty0 < y1_pdf + pad and ty1 > y0_pdf - pad):
# Calculate overlap ratio to prioritize better matches
overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
if overlap_x > 0 and overlap_y > 0:
token_area = (tx1 - tx0) * (ty1 - ty0)
overlap_area = overlap_x * overlap_y
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
matching_tokens.append((token, overlap_ratio))
# Sort by overlap ratio and combine text
matching_tokens.sort(key=lambda x: -x[1])
raw_text = ' '.join(t[0].text for t in matching_tokens)
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text
)
return ExtractedField(
field_name=field_name,
raw_text=raw_text,
normalized_value=normalized_value,
confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
detection_confidence=detection.confidence,
ocr_confidence=1.0, # PDF text is always accurate
bbox=detection.bbox,
page_no=detection.page_no,
is_valid=is_valid,
validation_error=validation_error
)
def extract_from_detection(
self,
detection: Detection,
image: np.ndarray | Image.Image
) -> ExtractedField:
"""
Extract field value from a detection region using OCR.
Args:
detection: Detection object
image: Full page image
Returns:
ExtractedField object
"""
if isinstance(image, Image.Image):
image = np.array(image)
# Get padded bbox
h, w = image.shape[:2]
bbox = detection.get_padded_bbox(self.bbox_padding, w, h)
# Crop region
x0, y0, x1, y1 = [int(v) for v in bbox]
region = image[y0:y1, x0:x1]
# Run OCR on region
ocr_tokens = self.ocr_engine.extract_from_image(region)
# Combine all OCR text
raw_text = ' '.join(t.text for t in ocr_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text
)
# Combined confidence
confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5
return ExtractedField(
field_name=field_name,
raw_text=raw_text,
normalized_value=normalized_value,
confidence=confidence,
detection_confidence=detection.confidence,
ocr_confidence=ocr_confidence,
bbox=detection.bbox,
page_no=detection.page_no,
is_valid=is_valid,
validation_error=validation_error
)
def _normalize_and_validate(
self,
field_name: str,
raw_text: str
) -> tuple[str | None, bool, str | None]:
"""
Normalize and validate extracted text for a field.
Uses modular normalizers for each field type.
Falls back to legacy methods for payment_line and customer_number.
Returns:
(normalized_value, is_valid, validation_error)
"""
text = raw_text.strip()
if not text:
return None, False, "Empty text"
# Special handling for payment_line and customer_number (use unified parsers)
if field_name == 'payment_line':
return self._normalize_payment_line(text)
if field_name == 'customer_number':
return self._normalize_customer_number(text)
# Use normalizer registry for other fields
normalizer = self._normalizers.get(field_name)
if normalizer:
result = normalizer.normalize(text)
return result.to_tuple()
# Fallback for unknown fields
return text, True, None
def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize payment line region text using unified PaymentLineParser.
Extracts the machine-readable payment line format from OCR text.
Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Examples:
- "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00
- "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00
Returns normalized format preserving ALL components including Amount.
This allows downstream cross-validation to extract fields properly.
"""
# Use unified payment line parser
return self.payment_line_parser.format_for_field_extractor(
self.payment_line_parser.parse(text)
)
def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize customer number text using unified CustomerNumberParser.
Supports various Swedish customer number formats:
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R'
- Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N'
- Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01'
- Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R'
"""
return self.customer_number_parser.parse(text)
def extract_all_fields(
self,
detections: list[Detection],
image: np.ndarray | Image.Image
) -> list[ExtractedField]:
"""
Extract fields from all detections.
Args:
detections: List of detections
image: Full page image
Returns:
List of ExtractedField objects
"""
fields = []
for detection in detections:
field = self.extract_from_detection(detection, image)
fields.append(field)
return fields
@staticmethod
def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
"""
Infer OCR field from InvoiceNumber if not detected.
In Swedish invoices, OCR reference number is often identical to InvoiceNumber.
When OCR is not detected but InvoiceNumber is, we can infer OCR value.
Args:
fields: Dict of field_name -> normalized_value
Returns:
Updated fields dict with inferred OCR if applicable
"""
# If OCR already exists, no need to infer
if fields.get('OCR'):
return fields
# If InvoiceNumber exists and is numeric, use it as OCR
invoice_number = fields.get('InvoiceNumber')
if invoice_number:
            # Use it only if it is entirely digits (a plausible OCR reference)
digits_only = re.sub(r'\D', '', invoice_number)
if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
fields['OCR'] = invoice_number
return fields
# =========================================================================
# Multi-Source Fusion with Confidence Weighting
# =========================================================================
def fuse_multiple_detections(
self,
extracted_fields: list[ExtractedField]
) -> list[ExtractedField]:
"""
Fuse multiple detections of the same field using confidence-weighted voting.
When YOLO detects the same field type multiple times (e.g., multiple Amount boxes),
this method selects the best value or combines them intelligently.
Strategies:
1. For numeric fields (Amount, OCR): prefer values that pass validation
2. For date fields: prefer values in expected range
3. For giro numbers: prefer values with valid Luhn checksum
4. General: weighted vote by confidence scores
Args:
extracted_fields: List of all extracted fields (may have duplicates)
Returns:
List with duplicates resolved to single best value per field
"""
# Group fields by name
fields_by_name: dict[str, list[ExtractedField]] = defaultdict(list)
for field in extracted_fields:
fields_by_name[field.field_name].append(field)
fused_fields = []
for field_name, candidates in fields_by_name.items():
if len(candidates) == 1:
# No fusion needed
fused_fields.append(candidates[0])
else:
# Multiple candidates - fuse them
fused = self._fuse_field_candidates(field_name, candidates)
fused_fields.append(fused)
return fused_fields
def _fuse_field_candidates(
self,
field_name: str,
candidates: list[ExtractedField]
) -> ExtractedField:
"""
Fuse multiple candidates for a single field.
Returns the best candidate with alternatives recorded.
"""
# Sort by confidence (descending)
sorted_candidates = sorted(candidates, key=lambda x: x.confidence, reverse=True)
# Collect all unique values with their max confidence
value_scores: dict[str, tuple[float, ExtractedField]] = {}
for c in sorted_candidates:
if c.normalized_value:
if c.normalized_value not in value_scores:
value_scores[c.normalized_value] = (c.confidence, c)
else:
# Keep the higher confidence one
if c.confidence > value_scores[c.normalized_value][0]:
value_scores[c.normalized_value] = (c.confidence, c)
if not value_scores:
# No valid values, return the highest confidence candidate
return sorted_candidates[0]
# Field-specific fusion strategy
best_value, best_field = self._select_best_value(field_name, value_scores)
# Record alternatives
alternatives = [
(v, score) for v, (score, _) in value_scores.items()
if v != best_value
]
# Create fused result
result = ExtractedField(
field_name=field_name,
raw_text=best_field.raw_text,
normalized_value=best_value,
confidence=value_scores[best_value][0],
detection_confidence=best_field.detection_confidence,
ocr_confidence=best_field.ocr_confidence,
bbox=best_field.bbox,
page_no=best_field.page_no,
is_valid=best_field.is_valid,
validation_error=best_field.validation_error,
alternative_values=alternatives,
extraction_method='fused' if len(value_scores) > 1 else 'single'
)
return result
def _select_best_value(
self,
field_name: str,
value_scores: dict[str, tuple[float, ExtractedField]]
) -> tuple[str, ExtractedField]:
"""
Select the best value for a field using field-specific logic.
Returns (best_value, best_field)
"""
items = list(value_scores.items())
# Field-specific selection
if field_name in ('Bankgiro', 'Plusgiro', 'OCR'):
# Prefer values with valid Luhn checksum
for value, (score, field) in items:
digits = re.sub(r'\D', '', value)
if FieldValidators.luhn_checksum(digits):
return value, field
elif field_name == 'Amount':
# Prefer larger amounts (usually the total, not subtotals)
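            # e.g. given candidates "1199.00" (a subtotal) and "11699.00" (the total),
            # the larger value 11699.00 is kept regardless of confidence order.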
amounts = []
for value, (score, field) in items:
try:
amt = float(value.replace(',', '.'))
amounts.append((amt, value, field))
except ValueError:
continue
if amounts:
# Return the largest amount
amounts.sort(reverse=True)
return amounts[0][1], amounts[0][2]
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Prefer dates in reasonable range
from datetime import datetime
for value, (score, field) in items:
try:
dt = datetime.strptime(value, '%Y-%m-%d')
# Prefer recent dates (within last 2 years and next 1 year)
now = datetime.now()
if now.year - 2 <= dt.year <= now.year + 1:
return value, field
except ValueError:
continue
# Default: return highest confidence value
best = max(items, key=lambda x: x[1][0])
return best[0], best[1][1]
# =========================================================================
# Apply OCR Corrections to Raw Text
# =========================================================================
def apply_ocr_corrections(
self,
field_name: str,
raw_text: str
) -> tuple[str, list[str]]:
"""
Apply OCR corrections to raw text based on field type.
Returns (corrected_text, list_of_corrections_applied)
"""
corrections_applied = []
if field_name in ('OCR', 'Bankgiro', 'Plusgiro', 'supplier_org_number'):
# Aggressive correction for numeric fields
result = OCRCorrections.correct_digits(raw_text, aggressive=True)
if result.corrections_applied:
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
return result.corrected, corrections_applied
elif field_name == 'Amount':
# Conservative correction for amounts (preserve decimal separators)
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
if result.corrections_applied:
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
return result.corrected, corrections_applied
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Conservative correction for dates
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
if result.corrections_applied:
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
return result.corrected, corrections_applied
# No correction for other fields
return raw_text, []
# =========================================================================
# Extraction with All Enhancements
# =========================================================================
def extract_with_enhancements(
self,
detection: Detection,
pdf_tokens: list,
image_width: int,
image_height: int,
use_enhanced_parsing: bool = True
) -> ExtractedField:
"""
Extract field value with all enhancements enabled.
Combines:
1. OCR error correction
2. Enhanced amount/date parsing
3. Multi-strategy extraction
Args:
detection: Detection object
pdf_tokens: PDF text tokens
image_width: Image width in pixels
image_height: Image height in pixels
use_enhanced_parsing: Whether to use enhanced parsing methods
Returns:
ExtractedField with enhancements applied
"""
# First, extract using standard method
base_result = self.extract_from_detection_with_pdf(
detection, pdf_tokens, image_width, image_height
)
if not use_enhanced_parsing:
return base_result
# Apply OCR corrections
corrected_text, corrections = self.apply_ocr_corrections(
base_result.field_name, base_result.raw_text
)
# Re-normalize with enhanced methods if corrections were applied
if corrections or base_result.normalized_value is None:
# Use enhanced normalizers for Amount and Date fields
if base_result.field_name == 'Amount':
enhanced_normalizer = EnhancedAmountNormalizer()
result = enhanced_normalizer.normalize(corrected_text)
normalized, is_valid, error = result.to_tuple()
elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
enhanced_normalizer = EnhancedDateNormalizer()
result = enhanced_normalizer.normalize(corrected_text)
normalized, is_valid, error = result.to_tuple()
else:
# Re-run standard normalization with corrected text
normalized, is_valid, error = self._normalize_and_validate(
base_result.field_name, corrected_text
)
# Update result if we got a better value
if normalized and (not base_result.normalized_value or is_valid):
base_result.normalized_value = normalized
base_result.is_valid = is_valid
base_result.validation_error = error
base_result.ocr_corrections_applied = corrections
if corrections:
base_result.extraction_method = 'corrected'
return base_result
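A minimal usage sketch of the static OCR-inference helper above, assuming the module is importable as field_extractor (the actual package path may differ):

from field_extractor import FieldExtractor  # hypothetical import path

fields = {"InvoiceNumber": "1234567", "Amount": "1200.00"}
fields = FieldExtractor.infer_ocr_from_invoice_number(fields)
# "1234567" is all digits and at least 5 long, so it is copied into the OCR field
print(fields["OCR"])  # 1234567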

View File

@@ -0,0 +1,60 @@
"""
Normalizers Package
Provides field-specific normalizers for invoice data extraction.
Each normalizer handles a specific field type's normalization and validation.
"""
from .base import BaseNormalizer, NormalizationResult
from .invoice_number import InvoiceNumberNormalizer
from .ocr_number import OcrNumberNormalizer
from .bankgiro import BankgiroNormalizer
from .plusgiro import PlusgiroNormalizer
from .amount import AmountNormalizer, EnhancedAmountNormalizer
from .date import DateNormalizer, EnhancedDateNormalizer
from .supplier_org_number import SupplierOrgNumberNormalizer
__all__ = [
# Base
"BaseNormalizer",
"NormalizationResult",
# Normalizers
"InvoiceNumberNormalizer",
"OcrNumberNormalizer",
"BankgiroNormalizer",
"PlusgiroNormalizer",
"AmountNormalizer",
"EnhancedAmountNormalizer",
"DateNormalizer",
"EnhancedDateNormalizer",
"SupplierOrgNumberNormalizer",
]
# Registry of all normalizers by field name
def create_normalizer_registry(
use_enhanced: bool = False,
) -> dict[str, BaseNormalizer]:
"""
Create a registry mapping field names to normalizer instances.
Args:
use_enhanced: Whether to use enhanced normalizers for amount/date
Returns:
Dictionary mapping field names to normalizer instances
"""
amount_normalizer = EnhancedAmountNormalizer() if use_enhanced else AmountNormalizer()
date_normalizer = EnhancedDateNormalizer() if use_enhanced else DateNormalizer()
return {
"InvoiceNumber": InvoiceNumberNormalizer(),
"OCR": OcrNumberNormalizer(),
"Bankgiro": BankgiroNormalizer(),
"Plusgiro": PlusgiroNormalizer(),
"Amount": amount_normalizer,
"InvoiceDate": date_normalizer,
"InvoiceDueDate": date_normalizer,
# Note: field_name is "supplier_organisation_number" (from CLASS_TO_FIELD mapping)
"supplier_organisation_number": SupplierOrgNumberNormalizer(),
}
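A brief sketch of how a caller might consume the registry; the absolute import path and the sample text are assumptions:

from normalizers import create_normalizer_registry  # hypothetical import path

registry = create_normalizer_registry(use_enhanced=True)
value, is_valid, error = registry["Amount"].normalize("Att betala: 1 234,56 kr").to_tuple()
print(value, is_valid)  # expected roughly: 1234.56 True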

View File

@@ -0,0 +1,185 @@
"""
Amount Normalizer
Handles normalization and validation of monetary amounts.
"""
import re
from shared.utils.text_cleaner import TextCleaner
from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections
from .base import BaseNormalizer, NormalizationResult
class AmountNormalizer(BaseNormalizer):
"""
Normalizes monetary amounts from Swedish invoices.
Handles various Swedish amount formats:
- With decimal: 1 234,56 kr
- With SEK suffix: 1234.56 SEK
- Multiple amounts (returns the last one, usually the total)
"""
@property
def field_name(self) -> str:
return "Amount"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Split by newlines and process line by line to get the last valid amount
lines = text.split("\n")
# Collect all valid amounts from all lines
all_amounts: list[float] = []
# Pattern for Swedish amount format (with decimals)
amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"
for line in lines:
line = line.strip()
if not line:
continue
# Find all amounts in this line
matches = re.findall(amount_pattern, line, re.IGNORECASE)
for match in matches:
amount_str = match.replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
all_amounts.append(amount)
except ValueError:
continue
# Return the last amount found (usually the total)
if all_amounts:
return NormalizationResult.success(f"{all_amounts[-1]:.2f}")
# Fallback: try shared validator on cleaned text
cleaned = TextCleaner.normalize_amount_text(text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
# Try to find any decimal number
simple_pattern = r"(\d+[,\.]\d{2})"
matches = re.findall(simple_pattern, text)
if matches:
amount_str = matches[-1].replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
# Last resort: try to find integer amount (no decimals)
# Look for patterns like "Amount: 11699" or standalone numbers
int_pattern = r"(?:amount|belopp|summa|total)[:\s]*(\d+)"
match = re.search(int_pattern, text, re.IGNORECASE)
if match:
try:
amount = float(match.group(1))
if amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
# Very last resort: find any standalone number >= 3 digits
standalone_pattern = r"\b(\d{3,})\b"
matches = re.findall(standalone_pattern, text)
if matches:
# Take the last/largest number
try:
amount = float(matches[-1])
if amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
return NormalizationResult.failure(f"Cannot parse amount: {text}")
class EnhancedAmountNormalizer(AmountNormalizer):
"""
Enhanced amount parsing with multiple strategies.
Strategies:
1. Pattern matching for Swedish formats
2. Context-aware extraction (look for keywords like "Total", "Summa")
3. OCR error correction for common digit errors
4. Multi-amount handling (prefer last/largest as total)
"""
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Strategy 1: Apply OCR corrections first
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Strategy 2: Look for labeled amounts (highest priority)
labeled_patterns = [
# Swedish patterns
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
(
r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
0.8,
), # Lower priority for VAT
# Generic pattern
(r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
]
candidates: list[tuple[float, float, int]] = []
for pattern, priority in labeled_patterns:
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
amount_str = match.group(1).replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if 0 < amount < 10_000_000: # Reasonable range
candidates.append((amount, priority, match.start()))
except ValueError:
continue
if candidates:
# Sort by priority (desc), then by position (later is usually total)
candidates.sort(key=lambda x: (-x[1], -x[2]))
best_amount = candidates[0][0]
return NormalizationResult.success(f"{best_amount:.2f}")
# Strategy 3: Parse with shared validator
cleaned = TextCleaner.normalize_amount_text(corrected_text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
# Strategy 4: Try to extract any decimal number as fallback
decimal_pattern = r"(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})"
matches = re.findall(decimal_pattern, corrected_text)
if matches:
# Clean and parse each match
amounts: list[float] = []
for m in matches:
                cleaned_m = m.replace(" ", "")
                if "," in cleaned_m:
                    # Swedish format: "1 234,56" / "1.234,56" -> "1234.56"
                    # (spaces and dots are thousands separators, comma is the decimal)
                    cleaned_m = cleaned_m.replace(".", "").replace(",", ".")
                # otherwise any dot is already the decimal separator ("1234.56")
try:
amt = float(cleaned_m)
if 0 < amt < 10_000_000:
amounts.append(amt)
except ValueError:
continue
if amounts:
# Return the last/largest amount (usually the total)
return NormalizationResult.success(f"{max(amounts):.2f}")
return NormalizationResult.failure(f"Cannot parse amount: {text[:50]}")

View File

@@ -0,0 +1,87 @@
"""
Bankgiro Normalizer
Handles normalization and validation of Swedish Bankgiro numbers.
"""
import re
from shared.utils.validators import FieldValidators
from .base import BaseNormalizer, NormalizationResult
class BankgiroNormalizer(BaseNormalizer):
"""
Normalizes Swedish Bankgiro numbers.
Bankgiro rules:
- 7 or 8 digits only
- Last digit is Luhn (Mod10) check digit
- Display format: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits)
Display pattern: ^\\d{3,4}-\\d{4}$
Normalized pattern: ^\\d{7,8}$
Note: Text may contain both BG and PG numbers. We specifically look for
BG display format (XXX-XXXX or XXXX-XXXX) to extract the correct one.
"""
@property
def field_name(self) -> str:
return "Bankgiro"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Look for BG display format pattern: 3-4 digits, dash, 4 digits
# This distinguishes BG from PG which uses X-X format (digits-single digit)
bg_matches = re.findall(r"(\d{3,4})-(\d{4})", text)
if bg_matches:
# Try each match and find one with valid Luhn
for match in bg_matches:
digits = match[0] + match[1]
if len(digits) in (7, 8) and FieldValidators.luhn_checksum(digits):
# Valid BG found
formatted = self._format_bankgiro(digits)
return NormalizationResult.success(formatted)
# No valid Luhn, use first match
digits = bg_matches[0][0] + bg_matches[0][1]
if len(digits) in (7, 8):
formatted = self._format_bankgiro(digits)
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
# Fallback: try to find 7-8 consecutive digits
# But first check if text contains PG format (XXXXXXX-X), if so don't use fallback
# to avoid misinterpreting PG as BG
pg_format_present = re.search(r"(?<![0-9])\d{1,7}-\d(?!\d)", text)
if pg_format_present:
return NormalizationResult.failure("No valid Bankgiro found in text")
digit_match = re.search(r"\b(\d{7,8})\b", text)
if digit_match:
digits = digit_match.group(1)
if len(digits) in (7, 8):
formatted = self._format_bankgiro(digits)
if FieldValidators.luhn_checksum(digits):
return NormalizationResult.success(formatted)
else:
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
return NormalizationResult.failure("No valid Bankgiro found in text")
@staticmethod
def _format_bankgiro(digits: str) -> str:
"""Format Bankgiro number with dash."""
if len(digits) == 8:
return f"{digits[:4]}-{digits[4:]}"
else:
return f"{digits[:3]}-{digits[3:]}"

View File

@@ -0,0 +1,71 @@
"""
Base Normalizer Interface
Defines the contract for all field normalizers.
Each normalizer handles a specific field type's normalization and validation.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass(frozen=True)
class NormalizationResult:
"""Result of a normalization operation."""
value: str | None
is_valid: bool
error: str | None = None
@classmethod
def success(cls, value: str) -> "NormalizationResult":
"""Create a successful result."""
return cls(value=value, is_valid=True, error=None)
@classmethod
def success_with_warning(cls, value: str, warning: str) -> "NormalizationResult":
"""Create a successful result with a warning."""
return cls(value=value, is_valid=True, error=warning)
@classmethod
def failure(cls, error: str) -> "NormalizationResult":
"""Create a failed result."""
return cls(value=None, is_valid=False, error=error)
def to_tuple(self) -> tuple[str | None, bool, str | None]:
"""Convert to legacy tuple format for backward compatibility."""
return (self.value, self.is_valid, self.error)
class BaseNormalizer(ABC):
"""
Abstract base class for field normalizers.
Each normalizer is responsible for:
1. Cleaning and normalizing raw text
2. Validating the normalized value
3. Returning a standardized result
"""
@property
@abstractmethod
def field_name(self) -> str:
"""The field name this normalizer handles."""
pass
@abstractmethod
def normalize(self, text: str) -> NormalizationResult:
"""
Normalize and validate the input text.
Args:
text: Raw text to normalize
Returns:
NormalizationResult with normalized value or error
"""
pass
def __call__(self, text: str) -> NormalizationResult:
"""Allow using the normalizer as a callable."""
return self.normalize(text)
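The contract above can be exercised with a minimal custom normalizer; the subclass below is purely illustrative, not part of the package, and the import path is assumed:

from normalizers.base import BaseNormalizer, NormalizationResult  # hypothetical path


class CurrencyCodeNormalizer(BaseNormalizer):
    """Toy normalizer: accepts a three-letter currency code such as SEK."""

    @property
    def field_name(self) -> str:
        return "Currency"

    def normalize(self, text: str) -> NormalizationResult:
        code = text.strip().upper()
        if len(code) == 3 and code.isalpha():
            return NormalizationResult.success(code)
        return NormalizationResult.failure(f"Not a currency code: {text!r}")


print(CurrencyCodeNormalizer()("sek").to_tuple())  # ('SEK', True, None)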

View File

@@ -0,0 +1,200 @@
"""
Date Normalizer
Handles normalization and validation of invoice dates.
"""
import re
from datetime import datetime
from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections
from .base import BaseNormalizer, NormalizationResult
class DateNormalizer(BaseNormalizer):
"""
Normalizes dates from Swedish invoices.
Handles various date formats:
- 2025-08-29 (ISO format)
- 2025.08.29 (dot separator)
- 29/08/2025 (European format)
- 29.08.2025 (European with dots)
- 20250829 (compact format)
Output format: YYYY-MM-DD (ISO 8601)
"""
# Date patterns with their parsing logic
PATTERNS = [
# ISO format: 2025-08-29
(
r"(\d{4})-(\d{1,2})-(\d{1,2})",
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
),
# Dot format: 2025.08.29 (common in Swedish)
(
r"(\d{4})\.(\d{1,2})\.(\d{1,2})",
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
),
# European slash format: 29/08/2025
(
r"(\d{1,2})/(\d{1,2})/(\d{4})",
lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
),
# European dot format: 29.08.2025
(
r"(\d{1,2})\.(\d{1,2})\.(\d{4})",
lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
),
# Compact format: 20250829
(
r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)",
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
),
]
@property
def field_name(self) -> str:
return "Date"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# First, try using shared validator
iso_date = FieldValidators.format_date_iso(text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return NormalizationResult.success(iso_date)
# Fallback: try original patterns for edge cases
for pattern, extractor in self.PATTERNS:
match = re.search(pattern, text)
if match:
try:
year, month, day = extractor(match)
# Validate date
parsed_date = datetime(year, month, day)
# Sanity check: year should be reasonable (2000-2100)
if 2000 <= parsed_date.year <= 2100:
return NormalizationResult.success(
parsed_date.strftime("%Y-%m-%d")
)
except ValueError:
continue
return NormalizationResult.failure(f"Cannot parse date: {text}")
class EnhancedDateNormalizer(DateNormalizer):
"""
Enhanced date parsing with comprehensive format support.
Additional support for:
- Swedish text: "29 december 2024", "29 dec 2024"
- OCR error correction: 2O24-12-29 -> 2024-12-29
"""
# Swedish month names
SWEDISH_MONTHS = {
"januari": 1,
"jan": 1,
"februari": 2,
"feb": 2,
"mars": 3,
"mar": 3,
"april": 4,
"apr": 4,
"maj": 5,
"juni": 6,
"jun": 6,
"juli": 7,
"jul": 7,
"augusti": 8,
"aug": 8,
"september": 9,
"sep": 9,
"sept": 9,
"oktober": 10,
"okt": 10,
"november": 11,
"nov": 11,
"december": 12,
"dec": 12,
}
# Extended patterns
EXTENDED_PATTERNS = [
# ISO format: 2025-08-29, 2025/08/29
("ymd", r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})"),
# Dot format: 2025.08.29
("ymd", r"(\d{4})\.(\d{1,2})\.(\d{1,2})"),
# European slash: 29/08/2025
("dmy", r"(\d{1,2})/(\d{1,2})/(\d{4})"),
# European dot: 29.08.2025
("dmy", r"(\d{1,2})\.(\d{1,2})\.(\d{4})"),
# European dash: 29-08-2025
("dmy", r"(\d{1,2})-(\d{1,2})-(\d{4})"),
# Compact: 20250829
("ymd_compact", r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)"),
]
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Apply OCR corrections
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Try shared validator first
iso_date = FieldValidators.format_date_iso(corrected_text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return NormalizationResult.success(iso_date)
# Try Swedish text date pattern: "29 december 2024" or "29 dec 2024"
swedish_pattern = r"(\d{1,2})\s+([a-z\u00e5\u00e4\u00f6]+)\s+(\d{4})"
match = re.search(swedish_pattern, corrected_text.lower())
if match:
day = int(match.group(1))
month_name = match.group(2)
year = int(match.group(3))
if month_name in self.SWEDISH_MONTHS:
month = self.SWEDISH_MONTHS[month_name]
try:
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
except ValueError:
pass
# Extended patterns
for fmt, pattern in self.EXTENDED_PATTERNS:
match = re.search(pattern, corrected_text)
if match:
try:
if fmt == "ymd":
year = int(match.group(1))
month = int(match.group(2))
day = int(match.group(3))
elif fmt == "dmy":
day = int(match.group(1))
month = int(match.group(2))
year = int(match.group(3))
elif fmt == "ymd_compact":
year = int(match.group(1))
month = int(match.group(2))
day = int(match.group(3))
else:
continue
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
except ValueError:
continue
return NormalizationResult.failure(f"Cannot parse date: {text[:50]}")

View File

@@ -0,0 +1,84 @@
"""
Invoice Number Normalizer
Handles normalization and validation of invoice numbers.
"""
import re
from .base import BaseNormalizer, NormalizationResult
class InvoiceNumberNormalizer(BaseNormalizer):
"""
Normalizes invoice numbers from Swedish invoices.
Invoice numbers can be:
- Pure digits: 12345678
- Alphanumeric: A3861, INV-2024-001, F12345
- With separators: 2024/001, 2024-001
Strategy:
1. Look for common invoice number patterns
2. Prefer shorter, more specific matches over long digit sequences
"""
@property
def field_name(self) -> str:
return "InvoiceNumber"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Pattern 1: Alphanumeric invoice number (letter + digits or digits + letter)
# Examples: A3861, F12345, INV001
alpha_patterns = [
r"\b([A-Z]{1,3}\d{3,10})\b", # A3861, INV12345
r"\b(\d{3,10}[A-Z]{1,3})\b", # 12345A
r"\b([A-Z]{2,5}[-/]?\d{3,10})\b", # INV-12345, FAK12345
]
for pattern in alpha_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
return NormalizationResult.success(match.group(1).upper())
# Pattern 2: Invoice number with year prefix (2024-001, 2024/12345)
year_pattern = r"\b(20\d{2}[-/]\d{3,8})\b"
match = re.search(year_pattern, text)
if match:
return NormalizationResult.success(match.group(1))
# Pattern 3: Short digit sequence (3-10 digits) - prefer shorter sequences
# This avoids capturing long OCR numbers
digit_sequences = re.findall(r"\b(\d{3,10})\b", text)
if digit_sequences:
# Prefer shorter sequences (more likely to be invoice number)
# Also filter out sequences that look like dates (8 digits starting with 20)
valid_sequences = []
for seq in digit_sequences:
# Skip if it looks like a date (YYYYMMDD)
if len(seq) == 8 and seq.startswith("20"):
continue
# Skip if too long (likely OCR number)
if len(seq) > 10:
continue
valid_sequences.append(seq)
if valid_sequences:
# Return shortest valid sequence
return NormalizationResult.success(min(valid_sequences, key=len))
# Fallback: extract all digits if nothing else works
digits = re.sub(r"\D", "", text)
if len(digits) >= 3:
# Limit to first 15 digits to avoid very long sequences
return NormalizationResult.success_with_warning(
digits[:15], "Fallback extraction"
)
return NormalizationResult.failure(
f"Cannot extract invoice number from: {text[:50]}"
)
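A quick sketch of the pattern priority above; the import path and sample strings are assumptions:

from normalizers.invoice_number import InvoiceNumberNormalizer  # hypothetical path

n = InvoiceNumberNormalizer()
print(n.normalize("Fakturanr: A3861").value)    # expected: A3861 (alphanumeric pattern)
print(n.normalize("Faktura 2024/00123").value)  # expected: 2024/00123 (year-prefixed pattern)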

View File

@@ -0,0 +1,37 @@
"""
OCR Number Normalizer
Handles normalization and validation of OCR reference numbers.
"""
import re
from .base import BaseNormalizer, NormalizationResult
class OcrNumberNormalizer(BaseNormalizer):
"""
    Normalizes OCR reference numbers (Swedish payment references).
OCR numbers in Swedish payment systems:
- Minimum 5 digits
- Used for automated payment matching
"""
@property
def field_name(self) -> str:
return "OCR"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
digits = re.sub(r"\D", "", text)
if len(digits) < 5:
return NormalizationResult.failure(
f"Too few digits for OCR: {len(digits)}"
)
return NormalizationResult.success(digits)
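A minimal sketch, assuming the package import path:

from normalizers.ocr_number import OcrNumberNormalizer  # hypothetical path

print(OcrNumberNormalizer().normalize("OCR: 9422 8110 0159").value)  # 942281100159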

View File

@@ -0,0 +1,90 @@
"""
Plusgiro Normalizer
Handles normalization and validation of Swedish Plusgiro numbers.
"""
import re
from shared.utils.validators import FieldValidators
from .base import BaseNormalizer, NormalizationResult
class PlusgiroNormalizer(BaseNormalizer):
"""
Normalizes Swedish Plusgiro numbers.
Plusgiro rules:
- 2 to 8 digits
- Last digit is Luhn (Mod10) check digit
- Display format: XXXXXXX-X (all digits except last, dash, last digit)
Display pattern: ^\\d{1,7}-\\d$
Normalized pattern: ^\\d{2,8}$
Note: Text may contain both BG and PG numbers. We specifically look for
PG display format (X-X, XX-X, ..., XXXXXXX-X) to extract the correct one.
"""
@property
def field_name(self) -> str:
return "Plusgiro"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# First look for PG display format: 1-7 digits (possibly with spaces), dash, 1 digit
# This is distinct from BG format which has 4 digits after the dash
# Pattern allows spaces within the number like "486 98 63-6"
# (?<![0-9]) ensures we don't start from within another number (like BG)
pg_matches = re.findall(r"(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)", text)
if pg_matches:
# Try each match and find one with valid Luhn
for match in pg_matches:
# Remove spaces from the first part
digits = re.sub(r"\s", "", match[0]) + match[1]
if 2 <= len(digits) <= 8 and FieldValidators.luhn_checksum(digits):
# Valid PG found
formatted = f"{digits[:-1]}-{digits[-1]}"
return NormalizationResult.success(formatted)
# No valid Luhn, use first match with most digits
best_match = max(pg_matches, key=lambda m: len(re.sub(r"\s", "", m[0])))
digits = re.sub(r"\s", "", best_match[0]) + best_match[1]
if 2 <= len(digits) <= 8:
formatted = f"{digits[:-1]}-{digits[-1]}"
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
# If no PG format found, extract all digits and format as PG
# This handles cases where the number might be in BG format or raw digits
all_digits = re.sub(r"\D", "", text)
# Try to find a valid 2-8 digit sequence
if 2 <= len(all_digits) <= 8:
formatted = f"{all_digits[:-1]}-{all_digits[-1]}"
if FieldValidators.luhn_checksum(all_digits):
return NormalizationResult.success(formatted)
else:
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
# Try to find any 2-8 digit sequence in text
digit_match = re.search(r"\b(\d{2,8})\b", text)
if digit_match:
digits = digit_match.group(1)
formatted = f"{digits[:-1]}-{digits[-1]}"
if FieldValidators.luhn_checksum(digits):
return NormalizationResult.success(formatted)
else:
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
return NormalizationResult.failure("No valid Plusgiro found in text")

View File

@@ -0,0 +1,60 @@
"""
Supplier Organization Number Normalizer
Handles normalization and validation of Swedish organization numbers.
"""
import re
from .base import BaseNormalizer, NormalizationResult
class SupplierOrgNumberNormalizer(BaseNormalizer):
"""
Normalizes Swedish supplier organization numbers.
Extracts organization number in format: NNNNNN-NNNN (10 digits)
Also handles VAT numbers: SE + 10 digits + 01
Examples:
'org.nr. 516406-1102, Filialregistret...' -> '516406-1102'
'Momsreg.nr SE556123456701' -> '556123-4567'
"""
@property
def field_name(self) -> str:
return "supplier_org_number"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Pattern 1: Standard org number format: NNNNNN-NNNN
org_pattern = r"\b(\d{6})-?(\d{4})\b"
match = re.search(org_pattern, text)
if match:
org_num = f"{match.group(1)}-{match.group(2)}"
return NormalizationResult.success(org_num)
# Pattern 2: VAT number format: SE + 10 digits + 01
vat_pattern = r"SE\s*(\d{10})01"
match = re.search(vat_pattern, text, re.IGNORECASE)
if match:
digits = match.group(1)
org_num = f"{digits[:6]}-{digits[6:]}"
return NormalizationResult.success(org_num)
# Pattern 3: Just 10 consecutive digits
digits_pattern = r"\b(\d{10})\b"
match = re.search(digits_pattern, text)
if match:
digits = match.group(1)
# Validate: first digit should be 1-9 for Swedish org numbers
if digits[0] in "123456789":
org_num = f"{digits[:6]}-{digits[6:]}"
return NormalizationResult.success(org_num)
return NormalizationResult.failure(
f"Cannot extract org number from: {text[:100]}"
)
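A sketch reusing the docstring's own examples; the import path is assumed:

from normalizers.supplier_org_number import SupplierOrgNumberNormalizer  # hypothetical path

n = SupplierOrgNumberNormalizer()
print(n.normalize("org.nr. 516406-1102, Filialregistret").value)  # 516406-1102
print(n.normalize("Momsreg.nr SE556123456701").value)             # 556123-4567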

View File

@@ -0,0 +1,261 @@
"""
Swedish Payment Line Parser
Handles parsing and validation of Swedish machine-readable payment lines.
Unifies payment line parsing logic that was previously duplicated across multiple modules.
Standard Swedish payment line format:
# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example:
# 94228110015950070 # 15658 00 8 > 48666036#14#
This parser handles common OCR errors:
- Spaces in numbers: "12 0 0" -> "1200"
- Missing symbols: e.g. a dropped ">"
- Spaces in check digits: "#41 #" -> "#41#"
"""
import re
import logging
from dataclasses import dataclass
from typing import Optional
from shared.exceptions import PaymentLineParseError
@dataclass
class PaymentLineData:
"""Parsed payment line data."""
ocr_number: str
"""OCR reference number (payment reference)"""
amount: Optional[str] = None
"""Amount in format KRONOR.ÖRE (e.g., '1200.00'), None if not present"""
account_number: Optional[str] = None
"""Bankgiro or Plusgiro account number"""
record_type: Optional[str] = None
"""Record type digit (usually '5' or '8' or '9')"""
check_digits: Optional[str] = None
"""Check digits for account validation"""
raw_text: str = ""
"""Original raw text that was parsed"""
is_valid: bool = True
"""Whether parsing was successful"""
error: Optional[str] = None
"""Error message if parsing failed"""
parse_method: str = "unknown"
"""Which parsing pattern was used (for debugging)"""
class PaymentLineParser:
"""Parser for Swedish payment lines with OCR error handling."""
# Pattern with amount: # OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#
FULL_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Pattern without amount: # OCR # > ACCOUNT#CHECK#
NO_AMOUNT_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Alternative pattern: look for OCR > ACCOUNT# pattern
ALT_PATTERN = re.compile(
r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Account only pattern: > ACCOUNT#CHECK#
ACCOUNT_ONLY_PATTERN = re.compile(
r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
def __init__(self):
"""Initialize parser with logger."""
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> PaymentLineData:
"""
Parse payment line text.
Handles common OCR errors:
        - Spaces in numbers: "12 0 0" -> "1200"
        - Missing symbols: e.g. a dropped ">"
        - Spaces in check digits: "#41 #" -> "#41#"
Args:
text: Raw payment line text from OCR
Returns:
PaymentLineData with parsed fields or error information
"""
if not text or not text.strip():
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="Empty payment line text",
parse_method="none"
)
text = text.strip()
# Try full pattern with amount
match = self.FULL_PATTERN.search(text)
if match:
return self._parse_full_match(match, text)
# Try pattern without amount
match = self.NO_AMOUNT_PATTERN.search(text)
if match:
return self._parse_no_amount_match(match, text)
# Try alternative pattern
match = self.ALT_PATTERN.search(text)
if match:
return self._parse_alt_match(match, text)
# Try account only pattern
match = self.ACCOUNT_ONLY_PATTERN.search(text)
if match:
return self._parse_account_only_match(match, text)
# No match - return error
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="No valid payment line format found",
parse_method="none"
)
def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse full pattern match (with amount)."""
ocr = self._clean_digits(match.group(1))
kronor = self._clean_digits(match.group(2))
ore = match.group(3)
record_type = match.group(4)
account = self._clean_digits(match.group(5))
check_digits = match.group(6)
amount = f"{kronor}.{ore}"
return PaymentLineData(
ocr_number=ocr,
amount=amount,
account_number=account,
record_type=record_type,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="full"
)
def _parse_no_amount_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse pattern match without amount."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="no_amount"
)
def _parse_alt_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse alternative pattern match."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="alternative"
)
def _parse_account_only_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse account-only pattern match."""
account = self._clean_digits(match.group(1))
check_digits = match.group(2)
return PaymentLineData(
ocr_number="",
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error="Partial payment line (account only)",
parse_method="account_only"
)
def _clean_digits(self, text: str) -> str:
"""Remove spaces from digit string (OCR error correction)."""
return text.replace(' ', '')
def format_machine_readable(self, data: PaymentLineData) -> str:
"""
Format parsed data back to machine-readable format.
Returns:
Formatted string in standard Swedish payment line format
"""
if not data.is_valid:
return data.raw_text
# Full format with amount
if data.amount and data.record_type:
kronor, ore = data.amount.split('.')
return (
f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
f"{data.account_number}#{data.check_digits}#"
)
# Format without amount
if data.ocr_number and data.account_number:
return f"# {data.ocr_number} # > {data.account_number}#{data.check_digits}#"
# Account only
if data.account_number:
return f"> {data.account_number}#{data.check_digits}#"
# Fallback
return data.raw_text
def format_for_field_extractor(self, data: PaymentLineData) -> tuple[Optional[str], bool, Optional[str]]:
"""
Format parsed data for FieldExtractor compatibility.
Returns:
Tuple of (formatted_text, is_valid, error_message) matching FieldExtractor's API
"""
if not data.is_valid:
return None, False, data.error
formatted = self.format_machine_readable(data)
return formatted, True, data.error
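Parsing the example line from the module docstring; the printed attributes follow directly from the FULL_PATTERN capture groups (import path assumed):

from payment_line_parser import PaymentLineParser  # hypothetical import path

data = PaymentLineParser().parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")
print(data.ocr_number)      # 94228110015950070
print(data.amount)          # 15658.00
print(data.account_number)  # 48666036
print(data.record_type, data.check_digits)  # 8 14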

View File

@@ -0,0 +1,499 @@
"""
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import time
import re
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
@dataclass
class CrossValidationResult:
"""Result of cross-validation between payment_line and other fields."""
is_valid: bool = False
ocr_match: bool | None = None # None if not comparable
amount_match: bool | None = None
bankgiro_match: bool | None = None
plusgiro_match: bool | None = None
payment_line_ocr: str | None = None
payment_line_amount: str | None = None
payment_line_account: str | None = None
payment_line_account_type: str | None = None # 'bankgiro' or 'plusgiro'
details: list[str] = field(default_factory=list)
@dataclass
class InferenceResult:
"""Result of invoice processing."""
document_id: str | None = None
success: bool = False
fields: dict[str, Any] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict) # Field bboxes in pixels
raw_detections: list[Detection] = field(default_factory=list)
extracted_fields: list[ExtractedField] = field(default_factory=list)
processing_time_ms: float = 0.0
errors: list[str] = field(default_factory=list)
fallback_used: bool = False
cross_validation: CrossValidationResult | None = None
def to_json(self) -> dict:
"""Convert to JSON-serializable dictionary."""
result = {
'DocumentId': self.document_id,
'InvoiceNumber': self.fields.get('InvoiceNumber'),
'InvoiceDate': self.fields.get('InvoiceDate'),
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
'OCR': self.fields.get('OCR'),
'Bankgiro': self.fields.get('Bankgiro'),
'Plusgiro': self.fields.get('Plusgiro'),
'Amount': self.fields.get('Amount'),
'supplier_org_number': self.fields.get('supplier_org_number'),
'customer_number': self.fields.get('customer_number'),
'payment_line': self.fields.get('payment_line'),
'confidence': self.confidence,
'success': self.success,
'fallback_used': self.fallback_used
}
# Add bboxes if present
if self.bboxes:
result['bboxes'] = {k: list(v) for k, v in self.bboxes.items()}
# Add cross-validation results if present
if self.cross_validation:
result['cross_validation'] = {
'is_valid': self.cross_validation.is_valid,
'ocr_match': self.cross_validation.ocr_match,
'amount_match': self.cross_validation.amount_match,
'bankgiro_match': self.cross_validation.bankgiro_match,
'plusgiro_match': self.cross_validation.plusgiro_match,
'payment_line_ocr': self.cross_validation.payment_line_ocr,
'payment_line_amount': self.cross_validation.payment_line_amount,
'payment_line_account': self.cross_validation.payment_line_account,
'payment_line_account_type': self.cross_validation.payment_line_account_type,
'details': self.cross_validation.details,
}
return result
def get_field(self, field_name: str) -> tuple[Any, float]:
"""Get field value and confidence."""
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
class InferencePipeline:
"""
Complete inference pipeline for invoice data extraction.
Pipeline flow:
1. PDF -> Image rendering
2. YOLO detection of field regions
3. OCR extraction from detected regions
4. Field normalization and validation
5. Fallback to full-page OCR if YOLO fails
"""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
ocr_lang: str = 'en',
use_gpu: bool = False,
dpi: int = 300,
enable_fallback: bool = True
):
"""
Initialize inference pipeline.
Args:
model_path: Path to trained YOLO model
confidence_threshold: Detection confidence threshold
ocr_lang: Language for OCR
use_gpu: Whether to use GPU
dpi: Resolution for PDF rendering
enable_fallback: Enable fallback to full-page OCR
"""
self.detector = YOLODetector(
model_path,
confidence_threshold=confidence_threshold,
device='cuda' if use_gpu else 'cpu'
)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi
self.enable_fallback = enable_fallback
def process_pdf(
self,
pdf_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a PDF and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(pdf_path).stem
)
try:
all_detections = []
all_extracted = []
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
# Convert to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Run YOLO detection
detections = self.detector.detect(image_array, page_no=page_no)
all_detections.extend(detections)
# Extract fields from detections
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
all_extracted.append(extracted)
result.raw_detections = all_detections
result.extracted_fields = all_extracted
# Merge extracted fields (prefer highest confidence)
self._merge_fields(result)
# Fallback if key fields are missing
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field."""
field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields:
if not extracted.is_valid or not extracted.normalized_value:
continue
if extracted.field_name not in field_candidates:
field_candidates[extracted.field_name] = []
field_candidates[extracted.field_name].append(extracted)
# Select best candidate for each field
for field_name, candidates in field_candidates.items():
best = max(candidates, key=lambda x: x.confidence)
result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence
# Store bbox for each field (useful for payment_line and other fields)
result.bboxes[field_name] = best.bbox
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
Returns: (ocr, amount, account) tuple
"""
parsed = self.payment_line_parser.parse(payment_line)
if not parsed.is_valid:
return None, None, None
return parsed.ocr_number, parsed.amount, parsed.account_number
def _cross_validate_payment_line(self, result: InferenceResult) -> None:
"""
Cross-validate payment_line data against other detected fields.
Payment line values take PRIORITY over individually detected fields.
Swedish payment line (Betalningsrad) contains:
- OCR reference number
- Amount (kronor and öre)
- Bankgiro or Plusgiro account number
This method:
1. Parses payment_line to extract OCR, Amount, Account
2. Compares with separately detected fields for validation
3. OVERWRITES detected fields with payment_line values (payment_line is authoritative)
"""
payment_line = result.fields.get('payment_line')
if not payment_line:
return
cv = CrossValidationResult()
cv.details = []
# Parse machine-readable payment line format
ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
cv.payment_line_ocr = ocr
cv.payment_line_amount = amount
# Determine account type based on digit count
if account:
            # Heuristic: treat 7-8 digit accounts as Bankgiro, shorter ones as Plusgiro
            # (Plusgiro numbers can also have 7-8 digits, so this may misclassify)
if len(account) >= 7:
cv.payment_line_account_type = 'bankgiro'
# Format: XXX-XXXX or XXXX-XXXX
if len(account) == 7:
cv.payment_line_account = f"{account[:3]}-{account[3:]}"
else:
cv.payment_line_account = f"{account[:4]}-{account[4:]}"
else:
cv.payment_line_account_type = 'plusgiro'
# Format: XXXXXXX-X
cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
# Cross-validate and OVERRIDE with payment_line values
# OCR: payment_line takes priority
detected_ocr = result.fields.get('OCR')
if cv.payment_line_ocr:
pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
if detected_ocr:
detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
cv.ocr_match = pl_ocr_digits == detected_ocr_digits
if cv.ocr_match:
cv.details.append(f"OCR match: {cv.payment_line_ocr}")
else:
cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
else:
cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
# OVERRIDE: use payment_line OCR
result.fields['OCR'] = cv.payment_line_ocr
result.confidence['OCR'] = 0.95 # High confidence for payment_line
# Amount: payment_line takes priority
detected_amount = result.fields.get('Amount')
if cv.payment_line_amount:
if detected_amount:
pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
det_amount = self._normalize_amount_for_compare(str(detected_amount))
cv.amount_match = pl_amount == det_amount
if cv.amount_match:
cv.details.append(f"Amount match: {cv.payment_line_amount}")
else:
cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
else:
cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
# OVERRIDE: use payment_line Amount
result.fields['Amount'] = cv.payment_line_amount
result.confidence['Amount'] = 0.95
# Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
detected_bankgiro = result.fields.get('Bankgiro')
if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
if detected_bankgiro:
det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
cv.bankgiro_match = pl_bg_digits == det_bg_digits
if cv.bankgiro_match:
cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
else:
cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
# Do NOT override - keep detected value
# Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
detected_plusgiro = result.fields.get('Plusgiro')
if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
if detected_plusgiro:
det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
cv.plusgiro_match = pl_pg_digits == det_pg_digits
if cv.plusgiro_match:
cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
else:
cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
# Do NOT override - keep detected value
# Determine overall validity
# Note: payment_line only contains ONE account (either BG or PG), so when invoice
# has both accounts, the other one cannot be matched - this is expected and OK.
# Only count the account type that payment_line actually has.
matches = [cv.ocr_match, cv.amount_match]
# Only include account match if payment_line has that account type
if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
matches.append(cv.bankgiro_match)
elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
matches.append(cv.plusgiro_match)
valid_matches = [m for m in matches if m is not None]
if valid_matches:
match_count = sum(1 for m in valid_matches if m)
cv.is_valid = match_count >= min(2, len(valid_matches))
cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
else:
# No comparison possible
cv.is_valid = True
cv.details.append("No comparison available from payment_line")
result.cross_validation = cv
def _normalize_amount_for_compare(self, amount: str) -> float | None:
"""Normalize amount string to float for comparison."""
try:
            # Strip spaces (including thousands separators) and convert the comma decimal to a dot
            cleaned = amount.replace(' ', '').replace(',', '.')
return round(float(cleaned), 2)
except (ValueError, AttributeError):
return None
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
# Check for key fields
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
missing = sum(1 for f in key_fields if f not in result.fields)
return missing >= 2 # Fallback if 2+ key fields missing
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
from shared.pdf.renderer import render_pdf_to_images
from shared.ocr import OCREngine
from PIL import Image
import io
import numpy as np
result.fallback_used = True
ocr_engine = OCREngine()
try:
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Full page OCR
tokens = ocr_engine.extract_from_image(image_array, page_no)
full_text = ' '.join(t.text for t in tokens)
# Try to extract missing fields with regex patterns
self._extract_with_patterns(full_text, result)
except Exception as e:
result.errors.append(f"Fallback OCR error: {e}")
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
],
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
}
for field_name, field_patterns in patterns.items():
if field_name in result.fields:
continue
for pattern in field_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
value = match.group(1).strip()
# Normalize the value
if field_name == 'Amount':
value = value.replace(' ', '').replace(',', '.')
try:
value = f"{float(value):.2f}"
except ValueError:
continue
elif field_name == 'Bankgiro':
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex
break
def process_image(
self,
image_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a single image (for pre-rendered pages).
Args:
image_path: Path to image file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from PIL import Image
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(image_path).stem
)
try:
image = Image.open(image_path)
image_array = np.array(image)
# Run detection
detections = self.detector.detect(image_array, page_no=0)
result.raw_detections = detections
# Extract fields
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
result.extracted_fields.append(extracted)
# Merge fields
self._merge_fields(result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
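# --- Illustrative sketch (not part of the original module) ------------------
# A minimal, standalone demo of the regex fallback used by
# _extract_with_patterns above, run against a made-up invoice string. The
# pattern and the normalization mirror the code above; the sample text and the
# __main__ guard are illustration only.
if __name__ == "__main__":
    sample = "Fakturanummer: 12345  Bankgiro 1234-5678  Att betala: 1 234,56 SEK"
    m = re.search(
        r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
        sample,
        re.IGNORECASE,
    )
    if m:
        value = m.group(1).strip().replace(' ', '').replace(',', '.')
        print(f"Fallback Amount: {float(value):.2f}")  # -> 1234.56 (stored with confidence 0.5)
    # Cross-validation rule above: with e.g. ocr_match=True, amount_match=True and
    # bankgiro_match=False, 2 of 3 comparable fields agree and min(2, 3) == 2,
    # so is_valid would be True.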

View File

@@ -0,0 +1,188 @@
"""
YOLO Detection Module
Runs YOLO model inference for field detection.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
# Import field mappings from single source of truth
from shared.fields import CLASS_NAMES, CLASS_TO_FIELD
@dataclass
class Detection:
"""Represents a single YOLO detection."""
class_id: int
class_name: str
confidence: float
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels
page_no: int = 0
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
@property
def width(self) -> float:
return self.x1 - self.x0
@property
def height(self) -> float:
return self.y1 - self.y0
def get_padded_bbox(
self,
padding: float = 0.1,
image_width: float | None = None,
image_height: float | None = None
) -> tuple[float, float, float, float]:
"""Get bbox with padding for OCR extraction."""
pad_x = self.width * padding
pad_y = self.height * padding
x0 = self.x0 - pad_x
y0 = self.y0 - pad_y
x1 = self.x1 + pad_x
y1 = self.y1 + pad_y
if image_width:
x0 = max(0, x0)
x1 = min(image_width, x1)
if image_height:
y0 = max(0, y0)
y1 = min(image_height, y1)
return (x0, y0, x1, y1)
# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields
# This ensures consistency with the trained YOLO model
class YOLODetector:
"""YOLO model wrapper for field detection."""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
iou_threshold: float = 0.45,
device: str = 'auto'
):
"""
Initialize YOLO detector.
Args:
model_path: Path to trained YOLO model (.pt file)
confidence_threshold: Minimum confidence for detections
iou_threshold: IOU threshold for NMS
device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
"""
from ultralytics import YOLO
self.model = YOLO(model_path)
self.confidence_threshold = confidence_threshold
self.iou_threshold = iou_threshold
self.device = device
def detect(
self,
image: str | Path | np.ndarray,
page_no: int = 0
) -> list[Detection]:
"""
Run detection on an image.
Args:
image: Image path or numpy array
page_no: Page number for reference
Returns:
List of Detection objects
"""
results = self.model.predict(
source=image,
conf=self.confidence_threshold,
iou=self.iou_threshold,
device=self.device,
verbose=False
)
detections = []
for result in results:
boxes = result.boxes
if boxes is None:
continue
for i in range(len(boxes)):
class_id = int(boxes.cls[i])
confidence = float(boxes.conf[i])
bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1]
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"
detections.append(Detection(
class_id=class_id,
class_name=class_name,
confidence=confidence,
bbox=tuple(bbox),
page_no=page_no
))
return detections
def detect_pdf(
self,
pdf_path: str | Path,
dpi: int = 300
) -> dict[int, list[Detection]]:
"""
Run detection on all pages of a PDF.
Args:
pdf_path: Path to PDF file
dpi: Resolution for rendering
Returns:
Dict mapping page number to list of detections
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
results = {}
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
# Convert bytes to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
detections = self.detect(image_array, page_no=page_no)
results[page_no] = detections
return results
def get_field_name(self, class_name: str) -> str:
"""Convert class name to field name."""
return CLASS_TO_FIELD.get(class_name, class_name)
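# --- Illustrative sketch (not part of the original module) ------------------
# How the pieces above fit together. The Detection example is fully runnable
# (the dataclass is defined in this module); the detector call is commented
# out because the weights path is a placeholder, not a file in this repo.
if __name__ == "__main__":
    det = Detection(
        class_id=0,
        class_name="invoice_number",  # arbitrary label for the demo
        confidence=0.9,
        bbox=(100.0, 50.0, 300.0, 90.0),
    )
    # 10% padding on each side, clamped to a 2480x3508 px page image:
    print(det.get_padded_bbox(padding=0.1, image_width=2480, image_height=3508))
    # -> (80.0, 46.0, 320.0, 94.0)

    # detector = YOLODetector("models/invoice_fields.pt", confidence_threshold=0.5)
    # for page_no, dets in detector.detect_pdf("invoice.pdf", dpi=300).items():
    #     for d in dets:
    #         print(page_no, detector.get_field_name(d.class_name), d.confidence, d.bbox)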

View File

@@ -0,0 +1,7 @@
"""
Cross-validation module for verifying field extraction using LLM.
"""
from .llm_validator import LLMValidator
__all__ = ['LLMValidator']

View File

@@ -0,0 +1,748 @@
"""
LLM-based cross-validation for invoice field extraction.
Uses a vision LLM to extract fields from invoice PDFs and compare with
the autolabel results to identify potential errors.
"""
import json
import base64
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, asdict
from datetime import datetime
import psycopg2
from psycopg2.extras import execute_values
from shared.config import DEFAULT_DPI
@dataclass
class LLMExtractionResult:
"""Result of LLM field extraction."""
document_id: str
invoice_number: Optional[str] = None
invoice_date: Optional[str] = None
invoice_due_date: Optional[str] = None
ocr_number: Optional[str] = None
bankgiro: Optional[str] = None
plusgiro: Optional[str] = None
amount: Optional[str] = None
supplier_organisation_number: Optional[str] = None
raw_response: Optional[str] = None
model_used: Optional[str] = None
processing_time_ms: Optional[float] = None
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
class LLMValidator:
"""
Cross-validates invoice field extraction using LLM.
Queries documents with failed field matches from the database,
sends the PDF images to an LLM for extraction, and stores
the results for comparison.
"""
# Fields to extract (excluding customer_number as requested)
FIELDS_TO_EXTRACT = [
'InvoiceNumber',
'InvoiceDate',
'InvoiceDueDate',
'OCR',
'Bankgiro',
'Plusgiro',
'Amount',
'supplier_organisation_number',
]
EXTRACTION_PROMPT = """You are an expert at extracting structured data from Swedish invoices.
Analyze this invoice image and extract the following fields. Return ONLY a valid JSON object with these exact keys:
{
"invoice_number": "the invoice number/fakturanummer",
"invoice_date": "the invoice date in YYYY-MM-DD format",
"invoice_due_date": "the due date/förfallodatum in YYYY-MM-DD format",
"ocr_number": "the OCR payment reference number",
"bankgiro": "the bankgiro number (format: XXXX-XXXX or XXXXXXXX)",
"plusgiro": "the plusgiro number",
"amount": "the total amount to pay (just the number, e.g., 1234.56)",
"supplier_organisation_number": "the supplier's organisation number (format: XXXXXX-XXXX)"
}
Rules:
- If a field is not found or not visible, use null
- For dates, convert Swedish month names (januari, februari, etc.) to YYYY-MM-DD
- For amounts, extract just the numeric value without currency symbols
- The OCR number is typically a long number used for payment reference
- Look for "Att betala" or "Summa att betala" for the amount
- Organisation number is 10 digits, often shown as XXXXXX-XXXX
Return ONLY the JSON object, no other text."""
def __init__(self, connection_string: Optional[str] = None, pdf_dir: Optional[str] = None):
"""
Initialize the validator.
Args:
connection_string: PostgreSQL connection string
pdf_dir: Directory containing PDF files
"""
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS
self.connection_string = connection_string or get_db_connection_string()
self.pdf_dir = Path(pdf_dir or PATHS['pdf_dir'])
self.conn = None
def connect(self):
"""Connect to database."""
if self.conn is None:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
self.conn = None
def create_validation_table(self):
"""Create the llm_validation table if it doesn't exist."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS llm_validations (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL,
-- Extracted fields
invoice_number TEXT,
invoice_date TEXT,
invoice_due_date TEXT,
ocr_number TEXT,
bankgiro TEXT,
plusgiro TEXT,
amount TEXT,
supplier_organisation_number TEXT,
-- Metadata
raw_response TEXT,
model_used TEXT,
processing_time_ms REAL,
error TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
-- Comparison results (populated later)
comparison_results JSONB,
UNIQUE(document_id)
);
CREATE INDEX IF NOT EXISTS idx_llm_validations_document_id
ON llm_validations(document_id);
""")
conn.commit()
def get_documents_with_failed_matches(
self,
exclude_customer_number: bool = True,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Get documents that have at least one failed field match.
Args:
exclude_customer_number: If True, ignore customer_number failures
limit: Maximum number of documents to return
Returns:
List of document info with failed fields
"""
conn = self.connect()
with conn.cursor() as cursor:
# Find documents with failed matches (excluding customer_number if requested)
exclude_clause = ""
if exclude_customer_number:
exclude_clause = "AND fr.field_name != 'customer_number'"
query = f"""
SELECT DISTINCT d.document_id, d.pdf_path, d.pdf_type,
d.supplier_name, d.split
FROM documents d
JOIN field_results fr ON d.document_id = fr.document_id
WHERE fr.matched = false
AND fr.field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
AND d.document_id NOT IN (
SELECT document_id FROM llm_validations WHERE error IS NULL
)
ORDER BY d.document_id
"""
if limit:
query += f" LIMIT {limit}"
cursor.execute(query)
results = []
for row in cursor.fetchall():
doc_id = row[0]
# Get failed fields for this document
exclude_clause_inner = ""
if exclude_customer_number:
exclude_clause_inner = "AND field_name != 'customer_number'"
cursor.execute(f"""
SELECT field_name, csv_value, score
FROM field_results
WHERE document_id = %s
AND matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause_inner}
""", (doc_id,))
failed_fields = [
{'field': r[0], 'csv_value': r[1], 'score': r[2]}
for r in cursor.fetchall()
]
results.append({
'document_id': doc_id,
'pdf_path': row[1],
'pdf_type': row[2],
'supplier_name': row[3],
'split': row[4],
'failed_fields': failed_fields,
})
return results
def get_failed_match_stats(self, exclude_customer_number: bool = True) -> Dict[str, Any]:
"""Get statistics about failed matches."""
conn = self.connect()
with conn.cursor() as cursor:
exclude_clause = ""
if exclude_customer_number:
exclude_clause = "AND field_name != 'customer_number'"
# Count by field
cursor.execute(f"""
SELECT field_name, COUNT(*) as cnt
FROM field_results
WHERE matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
GROUP BY field_name
ORDER BY cnt DESC
""")
by_field = {row[0]: row[1] for row in cursor.fetchall()}
# Count documents with failures
cursor.execute(f"""
SELECT COUNT(DISTINCT document_id)
FROM field_results
WHERE matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
""")
doc_count = cursor.fetchone()[0]
# Already validated count
cursor.execute("""
SELECT COUNT(*) FROM llm_validations WHERE error IS NULL
""")
validated_count = cursor.fetchone()[0]
return {
'documents_with_failures': doc_count,
'already_validated': validated_count,
'remaining': doc_count - validated_count,
'failures_by_field': by_field,
}
def render_pdf_to_image(
self,
pdf_path: Path,
page_no: int = 0,
dpi: int = DEFAULT_DPI,
max_size_mb: float = 18.0
) -> bytes:
"""
Render a PDF page to PNG image bytes.
Args:
pdf_path: Path to PDF file
page_no: Page number to render (0-indexed)
dpi: Resolution for rendering
max_size_mb: Maximum image size in MB (Azure OpenAI limit is 20MB)
Returns:
PNG image bytes
"""
import fitz # PyMuPDF
from io import BytesIO
from PIL import Image
doc = fitz.open(pdf_path)
page = doc[page_no]
# Try different DPI values until we get a small enough image
dpi_values = [dpi, 120, 100, 72, 50]
for current_dpi in dpi_values:
mat = fitz.Matrix(current_dpi / 72, current_dpi / 72)
pix = page.get_pixmap(matrix=mat)
png_bytes = pix.tobytes("png")
size_mb = len(png_bytes) / (1024 * 1024)
if size_mb <= max_size_mb:
doc.close()
return png_bytes
# If still too large, use JPEG compression
mat = fitz.Matrix(72 / 72, 72 / 72) # Lowest DPI
pix = page.get_pixmap(matrix=mat)
# Convert to PIL Image and compress as JPEG
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
# Try different JPEG quality levels
for quality in [85, 70, 50, 30]:
buffer = BytesIO()
img.save(buffer, format="JPEG", quality=quality)
jpeg_bytes = buffer.getvalue()
size_mb = len(jpeg_bytes) / (1024 * 1024)
if size_mb <= max_size_mb:
doc.close()
return jpeg_bytes
doc.close()
# Return whatever we have, let the API handle the error
return jpeg_bytes
def extract_with_openai(
self,
image_bytes: bytes,
model: str = "gpt-4o"
) -> LLMExtractionResult:
"""
Extract fields using OpenAI's vision API (supports Azure OpenAI).
Args:
image_bytes: PNG image bytes
model: Model to use (gpt-4o, gpt-4o-mini, etc.)
Returns:
Extraction result
"""
import openai
import time
start_time = time.time()
# Encode image to base64 and detect format
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
# Detect image format (PNG starts with \x89PNG, JPEG with \xFF\xD8)
if image_bytes[:4] == b'\x89PNG':
media_type = "image/png"
else:
media_type = "image/jpeg"
# Check for Azure OpenAI configuration
azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', model)
if azure_endpoint and azure_api_key:
# Use Azure OpenAI
client = openai.AzureOpenAI(
azure_endpoint=azure_endpoint,
api_key=azure_api_key,
api_version="2024-02-15-preview"
)
model = azure_deployment # Use deployment name for Azure
else:
# Use standard OpenAI
client = openai.OpenAI()
try:
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": self.EXTRACTION_PROMPT},
{
"type": "image_url",
"image_url": {
"url": f"data:{media_type};base64,{image_b64}",
"detail": "high"
}
}
]
}
],
max_tokens=1000,
temperature=0,
)
raw_response = response.choices[0].message.content
processing_time = (time.time() - start_time) * 1000
# Parse JSON response
# Try to extract JSON from response (may have markdown code blocks)
json_str = raw_response
if "```json" in json_str:
json_str = json_str.split("```json")[1].split("```")[0]
elif "```" in json_str:
json_str = json_str.split("```")[1].split("```")[0]
data = json.loads(json_str.strip())
return LLMExtractionResult(
document_id="", # Will be set by caller
invoice_number=data.get('invoice_number'),
invoice_date=data.get('invoice_date'),
invoice_due_date=data.get('invoice_due_date'),
ocr_number=data.get('ocr_number'),
bankgiro=data.get('bankgiro'),
plusgiro=data.get('plusgiro'),
amount=data.get('amount'),
supplier_organisation_number=data.get('supplier_organisation_number'),
raw_response=raw_response,
model_used=model,
processing_time_ms=processing_time,
)
except json.JSONDecodeError as e:
return LLMExtractionResult(
document_id="",
raw_response=raw_response if 'raw_response' in locals() else None,
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=f"JSON parse error: {str(e)}"
)
except Exception as e:
return LLMExtractionResult(
document_id="",
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=str(e)
)
def extract_with_anthropic(
self,
image_bytes: bytes,
model: str = "claude-sonnet-4-20250514"
) -> LLMExtractionResult:
"""
Extract fields using Anthropic's Claude API.
Args:
image_bytes: PNG image bytes
model: Model to use
Returns:
Extraction result
"""
import anthropic
import time
start_time = time.time()
# Encode image to base64
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
client = anthropic.Anthropic()
try:
response = client.messages.create(
model=model,
max_tokens=1000,
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": image_b64,
}
},
{
"type": "text",
"text": self.EXTRACTION_PROMPT
}
]
}
],
)
raw_response = response.content[0].text
processing_time = (time.time() - start_time) * 1000
# Parse JSON response
json_str = raw_response
if "```json" in json_str:
json_str = json_str.split("```json")[1].split("```")[0]
elif "```" in json_str:
json_str = json_str.split("```")[1].split("```")[0]
data = json.loads(json_str.strip())
return LLMExtractionResult(
document_id="",
invoice_number=data.get('invoice_number'),
invoice_date=data.get('invoice_date'),
invoice_due_date=data.get('invoice_due_date'),
ocr_number=data.get('ocr_number'),
bankgiro=data.get('bankgiro'),
plusgiro=data.get('plusgiro'),
amount=data.get('amount'),
supplier_organisation_number=data.get('supplier_organisation_number'),
raw_response=raw_response,
model_used=model,
processing_time_ms=processing_time,
)
except json.JSONDecodeError as e:
return LLMExtractionResult(
document_id="",
raw_response=raw_response if 'raw_response' in locals() else None,
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=f"JSON parse error: {str(e)}"
)
except Exception as e:
return LLMExtractionResult(
document_id="",
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=str(e)
)
def save_validation_result(self, result: LLMExtractionResult):
"""Save extraction result to database."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
INSERT INTO llm_validations (
document_id, invoice_number, invoice_date, invoice_due_date,
ocr_number, bankgiro, plusgiro, amount,
supplier_organisation_number, raw_response, model_used,
processing_time_ms, error
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (document_id) DO UPDATE SET
invoice_number = EXCLUDED.invoice_number,
invoice_date = EXCLUDED.invoice_date,
invoice_due_date = EXCLUDED.invoice_due_date,
ocr_number = EXCLUDED.ocr_number,
bankgiro = EXCLUDED.bankgiro,
plusgiro = EXCLUDED.plusgiro,
amount = EXCLUDED.amount,
supplier_organisation_number = EXCLUDED.supplier_organisation_number,
raw_response = EXCLUDED.raw_response,
model_used = EXCLUDED.model_used,
processing_time_ms = EXCLUDED.processing_time_ms,
error = EXCLUDED.error,
created_at = NOW()
""", (
result.document_id,
result.invoice_number,
result.invoice_date,
result.invoice_due_date,
result.ocr_number,
result.bankgiro,
result.plusgiro,
result.amount,
result.supplier_organisation_number,
result.raw_response,
result.model_used,
result.processing_time_ms,
result.error,
))
conn.commit()
def validate_document(
self,
doc_id: str,
provider: str = "openai",
model: Optional[str] = None
) -> LLMExtractionResult:
"""
Validate a single document using LLM.
Args:
doc_id: Document ID
provider: LLM provider ("openai" or "anthropic")
model: Model to use (defaults based on provider)
Returns:
Extraction result
"""
# Get PDF path
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
if not pdf_path.exists():
return LLMExtractionResult(
document_id=doc_id,
error=f"PDF not found: {pdf_path}"
)
# Render first page
try:
image_bytes = self.render_pdf_to_image(pdf_path, page_no=0)
except Exception as e:
return LLMExtractionResult(
document_id=doc_id,
error=f"Failed to render PDF: {str(e)}"
)
# Extract with LLM
if provider == "openai":
model = model or "gpt-4o"
result = self.extract_with_openai(image_bytes, model)
elif provider == "anthropic":
model = model or "claude-sonnet-4-20250514"
result = self.extract_with_anthropic(image_bytes, model)
else:
return LLMExtractionResult(
document_id=doc_id,
error=f"Unknown provider: {provider}"
)
result.document_id = doc_id
# Save to database
self.save_validation_result(result)
return result
def validate_batch(
self,
limit: int = 10,
provider: str = "openai",
model: Optional[str] = None,
verbose: bool = True
) -> List[LLMExtractionResult]:
"""
Validate a batch of documents with failed matches.
Args:
limit: Maximum number of documents to validate
provider: LLM provider
model: Model to use
verbose: Print progress
Returns:
List of extraction results
"""
# Get documents to validate
docs = self.get_documents_with_failed_matches(limit=limit)
if verbose:
print(f"Found {len(docs)} documents with failed matches to validate")
results = []
for i, doc in enumerate(docs):
doc_id = doc['document_id']
if verbose:
failed_fields = [f['field'] for f in doc['failed_fields']]
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
result = self.validate_document(doc_id, provider, model)
results.append(result)
if verbose:
if result.error:
print(f" ERROR: {result.error}")
else:
print(f" OK ({result.processing_time_ms:.0f}ms)")
return results
def compare_results(self, doc_id: str) -> Dict[str, Any]:
"""
Compare LLM extraction with autolabel results.
Args:
doc_id: Document ID
Returns:
Comparison results
"""
conn = self.connect()
with conn.cursor() as cursor:
# Get autolabel results
cursor.execute("""
SELECT field_name, csv_value, matched, matched_text
FROM field_results
WHERE document_id = %s
""", (doc_id,))
autolabel = {}
for row in cursor.fetchall():
autolabel[row[0]] = {
'csv_value': row[1],
'matched': row[2],
'matched_text': row[3],
}
# Get LLM results
cursor.execute("""
SELECT invoice_number, invoice_date, invoice_due_date,
ocr_number, bankgiro, plusgiro, amount,
supplier_organisation_number
FROM llm_validations
WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
if not row:
return {'error': 'No LLM validation found'}
llm = {
'InvoiceNumber': row[0],
'InvoiceDate': row[1],
'InvoiceDueDate': row[2],
'OCR': row[3],
'Bankgiro': row[4],
'Plusgiro': row[5],
'Amount': row[6],
'supplier_organisation_number': row[7],
}
# Compare
comparison = {}
for field in self.FIELDS_TO_EXTRACT:
auto = autolabel.get(field, {})
llm_value = llm.get(field)
comparison[field] = {
'csv_value': auto.get('csv_value'),
'autolabel_matched': auto.get('matched'),
'autolabel_text': auto.get('matched_text'),
'llm_value': llm_value,
'agreement': self._values_match(auto.get('csv_value'), llm_value),
}
return comparison
def _values_match(self, csv_value: str, llm_value: str) -> bool:
"""Check if CSV value matches LLM extracted value."""
if csv_value is None or llm_value is None:
return csv_value == llm_value
# Normalize for comparison
csv_norm = str(csv_value).strip().lower().replace('-', '').replace(' ', '')
llm_norm = str(llm_value).strip().lower().replace('-', '').replace(' ', '')
return csv_norm == llm_norm
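# --- Illustrative sketch (not part of the original module) ------------------
# Two standalone pieces mirrored from the class above: stripping an optional
# ```json fence from the LLM reply before json.loads, and the normalization
# used by _values_match. The sample reply and values are made up. A typical
# driver is also sketched (commented out; it needs a database and API keys):
#     validator = LLMValidator()
#     validator.create_validation_table()
#     validator.validate_batch(limit=10, provider="openai")
if __name__ == "__main__":
    raw_response = '```json\n{"bankgiro": "1234-5678", "amount": "1234.56"}\n```'
    json_str = raw_response
    if "```json" in json_str:
        json_str = json_str.split("```json")[1].split("```")[0]
    elif "```" in json_str:
        json_str = json_str.split("```")[1].split("```")[0]
    data = json.loads(json_str.strip())

    csv_value, llm_value = "12345678", data["bankgiro"]
    csv_norm = str(csv_value).strip().lower().replace('-', '').replace(' ', '')
    llm_norm = str(llm_value).strip().lower().replace('-', '').replace(' ', '')
    print(data, csv_norm == llm_norm)  # -> {'bankgiro': '1234-5678', 'amount': '1234.56'} True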

View File

@@ -0,0 +1,9 @@
"""
Web Application Module
Provides REST API and web interface for invoice field extraction.
"""
from .app import create_app
__all__ = ["create_app"]

View File

@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py
DEPRECATED: Import from backend.web.api.v1.admin.documents instead.
"""
from backend.web.api.v1.admin.documents import *
__all__ = ["create_admin_router"]

View File

@@ -0,0 +1,21 @@
"""
Admin API v1
Document management, annotations, and training endpoints.
"""
from backend.web.api.v1.admin.annotations import create_annotation_router
from backend.web.api.v1.admin.augmentation import create_augmentation_router
from backend.web.api.v1.admin.auth import create_auth_router
from backend.web.api.v1.admin.documents import create_documents_router
from backend.web.api.v1.admin.locks import create_locks_router
from backend.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_augmentation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",
"create_training_router",
]
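# --- Illustrative sketch (not part of the original package) -----------------
# Wiring a couple of the router factories above into a FastAPI app. Only the
# zero-argument factories are shown; others (e.g. the documents router) take
# configuration objects. The "/api/v1" mount prefix is an assumption for
# illustration.
# from fastapi import FastAPI
#
# app = FastAPI()
# app.include_router(create_auth_router(), prefix="/api/v1")
# app.include_router(create_annotation_router(), prefix="/api/v1")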

View File

@@ -0,0 +1,706 @@
"""
Admin Annotation API Routes
FastAPI endpoints for annotation management.
"""
import io
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse
from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
from backend.data.repositories import DocumentRepository, AnnotationRepository
from backend.web.core.auth import AdminTokenDep
from backend.web.services.autolabel import get_auto_label_service
from backend.web.services.storage_helpers import get_storage_helper
from backend.web.schemas.admin import (
AnnotationCreate,
AnnotationItem,
AnnotationListResponse,
AnnotationOverrideRequest,
AnnotationOverrideResponse,
AnnotationResponse,
AnnotationSource,
AnnotationUpdate,
AnnotationVerifyRequest,
AnnotationVerifyResponse,
AutoLabelRequest,
AutoLabelResponse,
BoundingBox,
)
from backend.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
# Global repository instances
_doc_repo: DocumentRepository | None = None
_ann_repo: AnnotationRepository | None = None
def get_doc_repository() -> DocumentRepository:
"""Get the DocumentRepository instance."""
global _doc_repo
if _doc_repo is None:
_doc_repo = DocumentRepository()
return _doc_repo
def get_ann_repository() -> AnnotationRepository:
"""Get the AnnotationRepository instance."""
global _ann_repo
if _ann_repo is None:
_ann_repo = AnnotationRepository()
return _ann_repo
# Type aliases for dependency injection
DocRepoDep = Annotated[DocumentRepository, Depends(get_doc_repository)]
AnnRepoDep = Annotated[AnnotationRepository, Depends(get_ann_repository)]
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_annotation_router() -> APIRouter:
"""Create annotation API router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])
# =========================================================================
# Image Endpoints
# =========================================================================
@router.get(
"/{document_id}/images/{page_number}",
response_model=None,
responses={
200: {"content": {"image/png": {}}, "description": "Page image"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Get page image",
description="Get the image for a specific page.",
)
async def get_page_image(
document_id: str,
page_number: int,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
) -> FileResponse | StreamingResponse:
"""Get page image."""
_validate_uuid(document_id, "document_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Validate page number
if page_number < 1 or page_number > document.page_count:
raise HTTPException(
status_code=404,
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
)
# Get storage helper
storage = get_storage_helper()
# Check if image exists
if not storage.admin_image_exists(document_id, page_number):
raise HTTPException(
status_code=404,
detail=f"Image for page {page_number} not found",
)
# Try to get local path for efficient file serving
local_path = storage.get_admin_image_local_path(document_id, page_number)
if local_path is not None:
return FileResponse(
path=str(local_path),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
)
# Fall back to streaming for cloud storage
image_content = storage.get_admin_image(document_id, page_number)
return StreamingResponse(
io.BytesIO(image_content),
media_type="image/png",
headers={
"Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"'
},
)
# =========================================================================
# Annotation Endpoints
# =========================================================================
@router.get(
"/{document_id}/annotations",
response_model=AnnotationListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="List annotations",
description="Get all annotations for a document.",
)
async def list_annotations(
document_id: str,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
page_number: Annotated[
int | None,
Query(ge=1, description="Filter by page number"),
] = None,
) -> AnnotationListResponse:
"""List annotations for a document."""
_validate_uuid(document_id, "document_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Get annotations
raw_annotations = ann_repo.get_for_document(document_id, page_number)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
return AnnotationListResponse(
document_id=document_id,
page_count=document.page_count,
total_annotations=len(annotations),
annotations=annotations,
)
@router.post(
"/{document_id}/annotations",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Create annotation",
description="Create a new annotation for a document.",
)
async def create_annotation(
document_id: str,
request: AnnotationCreate,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AnnotationResponse:
"""Create a new annotation."""
_validate_uuid(document_id, "document_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Validate page number
if request.page_number > document.page_count:
raise HTTPException(
status_code=400,
detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
)
# Get image dimensions for normalization
storage = get_storage_helper()
dimensions = storage.get_admin_image_dimensions(document_id, request.page_number)
if dimensions is None:
raise HTTPException(
status_code=400,
detail=f"Image for page {request.page_number} not available",
)
image_width, image_height = dimensions
# Calculate normalized coordinates
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
y_center = (request.bbox.y + request.bbox.height / 2) / image_height
width = request.bbox.width / image_width
height = request.bbox.height / image_height
# Get class name
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
# Create annotation
annotation_id = ann_repo.create(
document_id=document_id,
page_number=request.page_number,
class_id=request.class_id,
class_name=class_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=request.bbox.x,
bbox_y=request.bbox.y,
bbox_width=request.bbox.width,
bbox_height=request.bbox.height,
text_value=request.text_value,
source="manual",
)
# Keep status as pending - user must click "Mark Complete" to finalize
# This allows user to add multiple annotations before saving to PostgreSQL
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation created successfully",
)
@router.patch(
"/{document_id}/annotations/{annotation_id}",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Update annotation",
description="Update an existing annotation.",
)
async def update_annotation(
document_id: str,
annotation_id: str,
request: AnnotationUpdate,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AnnotationResponse:
"""Update an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Get existing annotation
annotation = ann_repo.get(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Prepare update data
update_kwargs = {}
if request.class_id is not None:
update_kwargs["class_id"] = request.class_id
update_kwargs["class_name"] = FIELD_CLASSES.get(
request.class_id, f"class_{request.class_id}"
)
if request.text_value is not None:
update_kwargs["text_value"] = request.text_value
if request.bbox is not None:
# Get image dimensions
storage = get_storage_helper()
dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number)
if dimensions is None:
raise HTTPException(
status_code=400,
detail=f"Image for page {annotation.page_number} not available",
)
image_width, image_height = dimensions
# Calculate normalized coordinates
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
update_kwargs["width"] = request.bbox.width / image_width
update_kwargs["height"] = request.bbox.height / image_height
update_kwargs["bbox_x"] = request.bbox.x
update_kwargs["bbox_y"] = request.bbox.y
update_kwargs["bbox_width"] = request.bbox.width
update_kwargs["bbox_height"] = request.bbox.height
# Update annotation
if update_kwargs:
success = ann_repo.update(annotation_id, **update_kwargs)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to update annotation",
)
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation updated successfully",
)
@router.delete(
"/{document_id}/annotations/{annotation_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Delete annotation",
description="Delete an annotation.",
)
async def delete_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> dict:
"""Delete an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Get existing annotation
annotation = ann_repo.get(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Delete annotation
ann_repo.delete(annotation_id)
return {
"status": "deleted",
"annotation_id": annotation_id,
"message": "Annotation deleted successfully",
}
# =========================================================================
# Auto-Labeling Endpoints
# =========================================================================
@router.post(
"/{document_id}/auto-label",
response_model=AutoLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Trigger auto-labeling",
description="Trigger auto-labeling for a document using field values.",
)
async def trigger_auto_label(
document_id: str,
request: AutoLabelRequest,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AutoLabelResponse:
"""Trigger auto-labeling for a document."""
_validate_uuid(document_id, "document_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Validate field values
if not request.field_values:
raise HTTPException(
status_code=400,
detail="At least one field value is required",
)
# Get the actual file path from storage
# document.file_path is a relative storage path like "raw_pdfs/uuid.pdf"
storage = get_storage_helper()
filename = document.file_path.split("/")[-1] if "/" in document.file_path else document.file_path
file_path = storage.get_raw_pdf_local_path(filename)
if file_path is None:
raise HTTPException(
status_code=500,
detail=f"Cannot find PDF file: {document.file_path}",
)
# Run auto-labeling
service = get_auto_label_service()
result = service.auto_label_document(
document_id=document_id,
file_path=str(file_path),
field_values=request.field_values,
doc_repo=doc_repo,
ann_repo=ann_repo,
replace_existing=request.replace_existing,
)
if result["status"] == "failed":
raise HTTPException(
status_code=500,
detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
)
return AutoLabelResponse(
document_id=document_id,
status=result["status"],
annotations_created=result["annotations_created"],
message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
)
@router.delete(
"/{document_id}/annotations",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete all annotations",
description="Delete all annotations for a document (optionally filter by source).",
)
async def delete_all_annotations(
document_id: str,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
source: Annotated[
str | None,
Query(description="Filter by source (manual, auto, imported)"),
] = None,
) -> dict:
"""Delete all annotations for a document."""
_validate_uuid(document_id, "document_id")
# Validate source
if source and source not in ("manual", "auto", "imported"):
raise HTTPException(
status_code=400,
detail=f"Invalid source: {source}",
)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Delete annotations
deleted_count = ann_repo.delete_for_document(document_id, source)
# Update document status if all annotations deleted
remaining = ann_repo.get_for_document(document_id)
if not remaining:
doc_repo.update_status(document_id, "pending")
return {
"status": "deleted",
"document_id": document_id,
"deleted_count": deleted_count,
"message": f"Deleted {deleted_count} annotations",
}
# =========================================================================
# Phase 5: Annotation Enhancement
# =========================================================================
@router.post(
"/{document_id}/annotations/{annotation_id}/verify",
response_model=AnnotationVerifyResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Verify annotation",
description="Mark an annotation as verified by a human reviewer.",
)
async def verify_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
) -> AnnotationVerifyResponse:
"""Verify an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Verify the annotation
annotation = ann_repo.verify(annotation_id, admin_token)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
return AnnotationVerifyResponse(
annotation_id=annotation_id,
is_verified=annotation.is_verified,
verified_at=annotation.verified_at,
verified_by=annotation.verified_by,
message="Annotation verified successfully",
)
@router.patch(
"/{document_id}/annotations/{annotation_id}/override",
response_model=AnnotationOverrideResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Override annotation",
description="Override an auto-generated annotation with manual corrections.",
)
async def override_annotation(
document_id: str,
annotation_id: str,
request: AnnotationOverrideRequest,
admin_token: AdminTokenDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AnnotationOverrideResponse:
"""Override an auto-generated annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found",
)
# Build updates dict from request
updates = {}
if request.text_value is not None:
updates["text_value"] = request.text_value
if request.class_id is not None:
updates["class_id"] = request.class_id
# Update class_name if class_id changed
if request.class_id in FIELD_CLASSES:
updates["class_name"] = FIELD_CLASSES[request.class_id]
if request.class_name is not None:
updates["class_name"] = request.class_name
if request.bbox:
# Update bbox fields
if "x" in request.bbox:
updates["bbox_x"] = request.bbox["x"]
if "y" in request.bbox:
updates["bbox_y"] = request.bbox["y"]
if "width" in request.bbox:
updates["bbox_width"] = request.bbox["width"]
if "height" in request.bbox:
updates["bbox_height"] = request.bbox["height"]
if not updates:
raise HTTPException(
status_code=400,
detail="No updates provided. Specify at least one field to update.",
)
# Override the annotation
annotation = ann_repo.override(
annotation_id=annotation_id,
admin_token=admin_token,
change_reason=request.reason,
**updates,
)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Get history to return history_id
history_records = ann_repo.get_history(UUID(annotation_id))
latest_history = history_records[0] if history_records else None
return AnnotationOverrideResponse(
annotation_id=annotation_id,
source=annotation.source,
override_source=annotation.override_source,
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
message="Annotation overridden successfully",
history_id=str(latest_history.history_id) if latest_history else "",
)
return router
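# --- Illustrative sketch (not part of the original module) ------------------
# The pixel-to-normalized conversion used by create_annotation above, on
# made-up numbers: a 400x100 px box at (100, 200) on a 2480x3508 px page image.
if __name__ == "__main__":
    bbox_x, bbox_y, bbox_width, bbox_height = 100, 200, 400, 100
    image_width, image_height = 2480, 3508
    x_center = (bbox_x + bbox_width / 2) / image_width    # ~0.1210
    y_center = (bbox_y + bbox_height / 2) / image_height  # ~0.0713
    width = bbox_width / image_width                       # ~0.1613
    height = bbox_height / image_height                    # ~0.0285
    print((x_center, y_center, width, height))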

View File

@@ -0,0 +1,15 @@
"""Augmentation API module."""
from fastapi import APIRouter
from .routes import register_augmentation_routes
def create_augmentation_router() -> APIRouter:
"""Create and configure the augmentation router."""
router = APIRouter(prefix="/augmentation", tags=["augmentation"])
register_augmentation_routes(router)
return router
__all__ = ["create_augmentation_router"]

View File

@@ -0,0 +1,160 @@
"""Augmentation API routes."""
from fastapi import APIRouter, Query
from backend.web.core.auth import AdminTokenDep, DocumentRepoDep, DatasetRepoDep
from backend.web.schemas.admin.augmentation import (
AugmentationBatchRequest,
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewRequest,
AugmentationPreviewResponse,
AugmentationTypeInfo,
AugmentationTypesResponse,
AugmentedDatasetListResponse,
PresetInfo,
PresetsResponse,
)
def register_augmentation_routes(router: APIRouter) -> None:
"""Register augmentation endpoints on the router."""
@router.get(
"/types",
response_model=AugmentationTypesResponse,
summary="List available augmentation types",
)
async def list_augmentation_types(
admin_token: AdminTokenDep,
) -> AugmentationTypesResponse:
"""
List all available augmentation types with descriptions and parameters.
"""
from shared.augmentation.pipeline import (
AUGMENTATION_REGISTRY,
AugmentationPipeline,
)
types = []
for name, aug_class in AUGMENTATION_REGISTRY.items():
# Create instance with empty params to get preview params
aug = aug_class({})
types.append(
AugmentationTypeInfo(
name=name,
description=(aug_class.__doc__ or "").strip(),
affects_geometry=aug_class.affects_geometry,
stage=AugmentationPipeline.STAGE_MAPPING[name],
default_params=aug.get_preview_params(),
)
)
return AugmentationTypesResponse(augmentation_types=types)
@router.get(
"/presets",
response_model=PresetsResponse,
summary="Get augmentation presets",
)
async def get_presets(
admin_token: AdminTokenDep,
) -> PresetsResponse:
"""Get predefined augmentation presets for common use cases."""
from shared.augmentation.presets import list_presets
presets = [PresetInfo(**p) for p in list_presets()]
return PresetsResponse(presets=presets)
@router.post(
"/preview/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview augmentation on document image",
)
async def preview_augmentation(
document_id: str,
request: AugmentationPreviewRequest,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.
Returns URLs to original and augmented preview images.
"""
from backend.web.services.augmentation_service import AugmentationService
service = AugmentationService(doc_repo=docs)
return await service.preview_single(
document_id=document_id,
page=page,
augmentation_type=request.augmentation_type,
params=request.params,
)
@router.post(
"/preview-config/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview full augmentation config on document",
)
async def preview_config(
document_id: str,
config: AugmentationConfigSchema,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""Preview complete augmentation pipeline on a document page."""
from backend.web.services.augmentation_service import AugmentationService
service = AugmentationService(doc_repo=docs)
return await service.preview_config(
document_id=document_id,
page=page,
config=config,
)
@router.post(
"/batch",
response_model=AugmentationBatchResponse,
summary="Create augmented dataset (offline preprocessing)",
)
async def create_augmented_dataset(
request: AugmentationBatchRequest,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
datasets: DatasetRepoDep,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.
This runs as a background task. The augmented images are stored
alongside the original dataset for training.
"""
from backend.web.services.augmentation_service import AugmentationService
service = AugmentationService(doc_repo=docs, dataset_repo=datasets)
return await service.create_augmented_dataset(
source_dataset_id=request.dataset_id,
config=request.config,
output_name=request.output_name,
multiplier=request.multiplier,
)
@router.get(
"/datasets",
response_model=AugmentedDatasetListResponse,
summary="List augmented datasets",
)
async def list_augmented_datasets(
admin_token: AdminTokenDep,
datasets: DatasetRepoDep,
limit: int = Query(default=20, ge=1, le=100, description="Page size"),
offset: int = Query(default=0, ge=0, description="Offset"),
) -> AugmentedDatasetListResponse:
"""List all augmented datasets."""
from backend.web.services.augmentation_service import AugmentationService
service = AugmentationService(dataset_repo=datasets)
return await service.list_augmented_datasets(limit=limit, offset=offset)

View File

@@ -0,0 +1,82 @@
"""
Admin Auth Routes
FastAPI endpoints for admin token management.
"""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter
from backend.web.core.auth import AdminTokenDep, TokenRepoDep
from backend.web.schemas.admin import (
AdminTokenCreate,
AdminTokenResponse,
)
from backend.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def create_auth_router() -> APIRouter:
"""Create admin auth router."""
router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])
@router.post(
"/token",
response_model=AdminTokenResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
},
summary="Create admin token",
description="Create a new admin authentication token.",
)
async def create_token(
request: AdminTokenCreate,
tokens: TokenRepoDep,
) -> AdminTokenResponse:
"""Create a new admin token."""
# Generate secure token
token = secrets.token_urlsafe(32)
# Calculate expiration (use timezone-aware datetime)
expires_at = None
if request.expires_in_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=request.expires_in_days)
# Create token in database
tokens.create(
token=token,
name=request.name,
expires_at=expires_at,
)
return AdminTokenResponse(
token=token,
name=request.name,
expires_at=expires_at,
message="Admin token created successfully",
)
@router.delete(
"/token",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Revoke admin token",
description="Revoke the current admin token.",
)
async def revoke_token(
admin_token: AdminTokenDep,
tokens: TokenRepoDep,
) -> dict:
"""Revoke the current admin token."""
tokens.deactivate(admin_token)
return {
"status": "revoked",
"message": "Admin token has been revoked",
}
return router
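# --- Illustrative sketch (not part of the original module) ------------------
# The token generation and expiry computation used by create_token above,
# runnable on its own (no database involved; the 30-day lifetime is just an
# example value).
if __name__ == "__main__":
    demo_token = secrets.token_urlsafe(32)  # URL-safe string, ~43 characters
    demo_expiry = datetime.now(timezone.utc) + timedelta(days=30)
    print(demo_token, demo_expiry.isoformat())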

View File

@@ -0,0 +1,135 @@
"""
Dashboard API Routes
FastAPI endpoints for dashboard statistics and activity.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from backend.web.core.auth import (
AdminTokenDep,
get_model_version_repository,
get_training_task_repository,
ModelVersionRepoDep,
TrainingTaskRepoDep,
)
from backend.web.schemas.admin import (
DashboardStatsResponse,
DashboardActiveModelResponse,
ActiveModelInfo,
RunningTrainingInfo,
RecentActivityResponse,
ActivityItem,
)
from backend.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
)
logger = logging.getLogger(__name__)
def create_dashboard_router() -> APIRouter:
"""Create dashboard API router."""
router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"])
@router.get(
"/stats",
response_model=DashboardStatsResponse,
summary="Get dashboard statistics",
description="Returns document counts and annotation completeness metrics.",
)
async def get_dashboard_stats(
admin_token: AdminTokenDep,
) -> DashboardStatsResponse:
"""Get dashboard statistics."""
service = DashboardStatsService()
stats = service.get_stats()
return DashboardStatsResponse(
total_documents=stats["total_documents"],
annotation_complete=stats["annotation_complete"],
annotation_incomplete=stats["annotation_incomplete"],
pending=stats["pending"],
completeness_rate=stats["completeness_rate"],
)
@router.get(
"/active-model",
response_model=DashboardActiveModelResponse,
summary="Get active model info",
description="Returns current active model and running training status.",
)
async def get_active_model(
admin_token: AdminTokenDep,
model_repo: ModelVersionRepoDep,
task_repo: TrainingTaskRepoDep,
) -> DashboardActiveModelResponse:
"""Get active model and training status."""
# Get active model
active_model = model_repo.get_active()
model_info = None
if active_model:
model_info = ActiveModelInfo(
version_id=str(active_model.version_id),
version=active_model.version,
name=active_model.name,
metrics_mAP=active_model.metrics_mAP,
metrics_precision=active_model.metrics_precision,
metrics_recall=active_model.metrics_recall,
document_count=active_model.document_count,
activated_at=active_model.activated_at,
)
# Get running training task
running_task = task_repo.get_running()
training_info = None
if running_task:
training_info = RunningTrainingInfo(
task_id=str(running_task.task_id),
name=running_task.name,
status=running_task.status,
started_at=running_task.started_at,
progress=running_task.progress or 0,
)
return DashboardActiveModelResponse(
model=model_info,
running_training=training_info,
)
@router.get(
"/activity",
response_model=RecentActivityResponse,
summary="Get recent activity",
description="Returns recent system activities sorted by timestamp.",
)
async def get_recent_activity(
admin_token: AdminTokenDep,
limit: Annotated[
int,
Query(ge=1, le=50, description="Maximum number of activities"),
] = 10,
) -> RecentActivityResponse:
"""Get recent system activity."""
service = DashboardActivityService()
activities = service.get_recent_activities(limit=limit)
return RecentActivityResponse(
activities=[
ActivityItem(
type=act["type"],
description=act["description"],
timestamp=act["timestamp"],
metadata=act["metadata"],
)
for act in activities
]
)
return router

View File

@@ -0,0 +1,699 @@
"""
Admin Document Routes
FastAPI endpoints for admin document management.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from backend.web.config import DEFAULT_DPI, StorageConfig
from backend.web.core.auth import (
AdminTokenDep,
DocumentRepoDep,
AnnotationRepoDep,
TrainingTaskRepoDep,
)
from backend.web.services.storage_helpers import get_storage_helper
from backend.web.schemas.admin import (
AnnotationItem,
AnnotationSource,
AutoLabelStatus,
BoundingBox,
DocumentCategoriesResponse,
DocumentDetailResponse,
DocumentItem,
DocumentListResponse,
DocumentStatus,
DocumentStatsResponse,
DocumentUpdateRequest,
DocumentUploadResponse,
ModelMetrics,
TrainingHistoryItem,
)
from backend.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
"""Convert PDF pages to images for annotation using StorageHelper."""
import fitz
storage = get_storage_helper()
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
# Render at configured DPI for consistency with training
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
# Save to storage using StorageHelper
image_bytes = pix.tobytes("png")
storage.save_admin_image(document_id, page_num + 1, image_bytes)
pdf_doc.close()
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
"""Create admin documents router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])
# Directories are created by StorageConfig.__post_init__
allowed_extensions = storage_config.allowed_extensions
@router.post(
"",
response_model=DocumentUploadResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Upload document",
description="Upload a PDF or image document for labeling.",
)
async def upload_document(
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
file: UploadFile = File(..., description="PDF or image file"),
auto_label: Annotated[
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
group_key: Annotated[
str | None,
Query(description="Optional group key for document organization", max_length=255),
] = None,
category: Annotated[
str,
Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100),
] = "invoice",
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Get page count (for PDF)
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record (token only used for auth, not stored)
document_id = docs.create(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
group_key=group_key,
category=category,
)
# Save file to storage using StorageHelper
storage = get_storage_helper()
filename = f"{document_id}{file_ext}"
try:
storage_path = storage.save_raw_pdf(content, filename)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(status_code=500, detail="Failed to save file")
# Update file path in database (using storage path for reference)
from backend.data.database import get_session_context
from backend.data.admin_models import AdminDocument
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if doc:
# Store the storage path (relative path within storage)
doc.file_path = storage_path
session.add(doc)
# Convert PDF to images for annotation
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling if requested
auto_label_started = False
if auto_label:
# Auto-labeling will be triggered by a background task
docs.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
auto_label_started = True
return DocumentUploadResponse(
document_id=document_id,
filename=file.filename,
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
category=category,
group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@router.get(
"",
response_model=DocumentListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List documents",
description="List all documents for the current admin.",
)
async def list_documents(
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
upload_source: Annotated[
str | None,
Query(description="Filter by upload source (ui or api)"),
] = None,
has_annotations: Annotated[
bool | None,
Query(description="Filter by annotation presence"),
] = None,
auto_label_status: Annotated[
str | None,
Query(description="Filter by auto-label status"),
] = None,
batch_id: Annotated[
str | None,
Query(description="Filter by batch ID"),
] = None,
category: Annotated[
str | None,
Query(description="Filter by document category"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> DocumentListResponse:
"""List documents."""
# Validate status
if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Validate upload_source
if upload_source and upload_source not in ("ui", "api"):
raise HTTPException(
status_code=400,
detail=f"Invalid upload_source: {upload_source}",
)
# Validate auto_label_status
if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid auto_label_status: {auto_label_status}",
)
documents, total = docs.get_paginated(
admin_token=admin_token,
status=status,
upload_source=upload_source,
has_annotations=has_annotations,
auto_label_status=auto_label_status,
batch_id=batch_id,
category=category,
limit=limit,
offset=offset,
)
# Get annotation counts and build items
items = []
for doc in documents:
doc_annotations = annotations.get_for_document(str(doc.document_id))
# Determine if document can be annotated (not locked)
can_annotate = True
if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
from datetime import datetime, timezone
can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)
items.append(
DocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
file_size=doc.file_size,
page_count=doc.page_count,
status=DocumentStatus(doc.status),
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
annotation_count=len(doc_annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
category=doc.category if hasattr(doc, 'category') else "invoice",
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
)
)
return DocumentListResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/stats",
response_model=DocumentStatsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get document statistics",
description="Get document count by status.",
)
async def get_document_stats(
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
) -> DocumentStatsResponse:
"""Get document statistics."""
counts = docs.count_by_status(admin_token)
return DocumentStatsResponse(
total=sum(counts.values()),
pending=counts.get("pending", 0),
auto_labeling=counts.get("auto_labeling", 0),
labeled=counts.get("labeled", 0),
exported=counts.get("exported", 0),
)
@router.get(
"/categories",
response_model=DocumentCategoriesResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get available categories",
description="Get list of all available document categories.",
)
async def get_categories(
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
) -> DocumentCategoriesResponse:
"""Get all available document categories."""
categories = docs.get_categories()
return DocumentCategoriesResponse(
categories=categories,
total=len(categories),
)
@router.get(
"/{document_id}",
response_model=DocumentDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Get document detail",
description="Get document details with annotations.",
)
async def get_document(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
tasks: TrainingTaskRepoDep,
) -> DocumentDetailResponse:
"""Get document details."""
_validate_uuid(document_id, "document_id")
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get annotations
raw_annotations = annotations.get_for_document(document_id)
annotation_items = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
# Generate image URLs
image_urls = []
for page in range(1, document.page_count + 1):
image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")
# Determine if document can be annotated (not locked)
can_annotate = True
annotation_lock_until = None
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
from datetime import datetime, timezone
annotation_lock_until = document.annotation_lock_until
can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)
# Get CSV field values if available
csv_field_values = None
if hasattr(document, 'csv_field_values') and document.csv_field_values:
csv_field_values = document.csv_field_values
# Get training history (Phase 5)
training_history = []
training_links = tasks.get_document_training_tasks(document.document_id)
for link in training_links:
# Get task details
task = tasks.get(str(link.task_id))
if task:
# Build metrics
metrics = None
                # Treat an explicit 0.0 metric as present; skip only when all metrics are missing
                if any(v is not None for v in (task.metrics_mAP, task.metrics_precision, task.metrics_recall)):
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
training_history.append(
TrainingHistoryItem(
task_id=str(link.task_id),
name=task.name,
trained_at=link.created_at,
model_metrics=metrics,
)
)
return DocumentDetailResponse(
document_id=str(document.document_id),
filename=document.filename,
file_size=document.file_size,
content_type=document.content_type,
page_count=document.page_count,
status=DocumentStatus(document.status),
auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
group_key=document.group_key if hasattr(document, 'group_key') else None,
category=document.category if hasattr(document, 'category') else "invoice",
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
annotations=annotation_items,
image_urls=image_urls,
training_history=training_history,
created_at=document.created_at,
updated_at=document.updated_at,
)
@router.delete(
"/{document_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete document",
description="Delete a document and its annotations.",
)
async def delete_document(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
) -> dict:
"""Delete a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Delete file using StorageHelper
storage = get_storage_helper()
# Delete the raw PDF
filename = Path(document.file_path).name
if filename:
try:
storage._storage.delete(document.file_path)
except Exception as e:
logger.warning(f"Failed to delete PDF file: {e}")
# Delete admin images
try:
storage.delete_admin_images(document_id)
except Exception as e:
logger.warning(f"Failed to delete admin images: {e}")
# Delete from database
docs.delete(document_id)
return {
"status": "deleted",
"document_id": document_id,
"message": "Document deleted successfully",
}
@router.patch(
"/{document_id}/status",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document status",
description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
)
async def update_document_status(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
status: Annotated[
str,
Query(description="New status"),
],
) -> dict:
"""Update document status.
When status is set to 'labeled', the annotations are automatically
saved to PostgreSQL documents/field_results tables for consistency
with CLI auto-label workflow.
"""
_validate_uuid(document_id, "document_id")
# Validate status
if status not in ("pending", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Verify ownership
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# If marking as labeled, save annotations to PostgreSQL DocumentDB
db_save_result = None
if status == "labeled":
from backend.web.services.db_autolabel import save_manual_annotations_to_document_db
# Get all annotations for this document
doc_annotations = annotations.get_for_document(document_id)
if doc_annotations:
db_save_result = save_manual_annotations_to_document_db(
document=document,
annotations=doc_annotations,
)
docs.update_status(document_id, status)
response = {
"status": "updated",
"document_id": document_id,
"new_status": status,
"message": "Document status updated",
}
# Include PostgreSQL save result if applicable
if db_save_result:
response["document_db_saved"] = db_save_result.get("success", False)
response["fields_saved"] = db_save_result.get("fields_saved", 0)
return response
@router.patch(
"/{document_id}/group-key",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document group key",
description="Update the group key for a document.",
)
async def update_document_group_key(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
group_key: Annotated[
str | None,
Query(description="New group key (null to clear)"),
] = None,
) -> dict:
"""Update document group key."""
_validate_uuid(document_id, "document_id")
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)
# Verify document exists
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Update group key
docs.update_group_key(document_id, group_key)
return {
"status": "updated",
"document_id": document_id,
"group_key": group_key,
"message": "Document group key updated",
}
@router.patch(
"/{document_id}/category",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document category",
description="Update the category for a document.",
)
async def update_document_category(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
request: DocumentUpdateRequest,
) -> dict:
"""Update document category."""
_validate_uuid(document_id, "document_id")
# Verify document exists
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Update category if provided
if request.category is not None:
docs.update_category(document_id, request.category)
return {
"status": "updated",
"document_id": document_id,
"category": request.category,
"message": "Document category updated",
}
return router
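
A hedged usage sketch for the upload and status endpoints above, using httpx. The /api/v1 prefix matches the image URLs generated in get_document; the X-Admin-Token header name and the local file name are assumptions.

import httpx

BASE = "http://localhost:8000/api/v1/admin/documents"
HEADERS = {"X-Admin-Token": "change-me"}   # assumed header name

# Upload a PDF with auto-labeling enabled and an optional group key
with open("invoice.pdf", "rb") as fh:
    resp = httpx.post(
        BASE,
        headers=HEADERS,
        params={"auto_label": "true", "category": "invoice", "group_key": "2026-01"},
        files={"file": ("invoice.pdf", fh, "application/pdf")},
    )
resp.raise_for_status()
document_id = resp.json()["document_id"]

# Inspect the detail view, then mark the document as labeled once annotations exist
detail = httpx.get(f"{BASE}/{document_id}", headers=HEADERS).json()
if detail["annotations"]:
    httpx.patch(f"{BASE}/{document_id}/status", headers=HEADERS, params={"status": "labeled"})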

View File

@@ -0,0 +1,181 @@
"""
Admin Document Lock Routes
FastAPI endpoints for annotation lock management.
"""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from backend.web.core.auth import AdminTokenDep, DocumentRepoDep
from backend.web.schemas.admin import (
AnnotationLockRequest,
AnnotationLockResponse,
)
from backend.web.schemas.common import ErrorResponse
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_locks_router() -> APIRouter:
"""Create annotation locks router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])
@router.post(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Document already locked"},
},
summary="Acquire annotation lock",
description="Acquire a lock on a document to prevent concurrent annotation edits.",
)
async def acquire_lock(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Acquire annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to acquire lock
updated_doc = docs.acquire_annotation_lock(
document_id=document_id,
admin_token=admin_token,
duration_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Document is already locked. Please try again later.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock acquired for {request.duration_seconds} seconds",
)
@router.delete(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Release annotation lock",
description="Release the annotation lock on a document.",
)
async def release_lock(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
force: Annotated[
bool,
Query(description="Force release (admin override)"),
] = False,
) -> AnnotationLockResponse:
"""Release annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Release lock
updated_doc = docs.release_annotation_lock(
document_id=document_id,
admin_token=admin_token,
force=force,
)
if updated_doc is None:
raise HTTPException(
status_code=404,
detail="Failed to release lock",
)
return AnnotationLockResponse(
document_id=document_id,
locked=False,
lock_expires_at=None,
message="Lock released successfully",
)
@router.patch(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
},
summary="Extend annotation lock",
description="Extend an existing annotation lock.",
)
async def extend_lock(
document_id: str,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Extend annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to extend lock
updated_doc = docs.extend_annotation_lock(
document_id=document_id,
admin_token=admin_token,
additional_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Lock doesn't exist or has expired. Please acquire a new lock.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock extended by {request.duration_seconds} seconds",
)
return router
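
The lock endpoints are meant to wrap an editing session: acquire before editing, extend while work continues, release when done. A minimal sketch, assuming the same /api/v1 prefix and header name as the sketches above and that AnnotationLockRequest carries a duration_seconds field (the handlers read exactly that attribute).

import httpx

BASE = "http://localhost:8000/api/v1/admin/documents"   # assumed mount prefix
HEADERS = {"X-Admin-Token": "change-me"}                 # assumed header name
doc_id = "00000000-0000-0000-0000-000000000000"          # placeholder document UUID

# Acquire a 5-minute lock, extend it mid-session, release it afterwards
lock = httpx.post(f"{BASE}/{doc_id}/lock", headers=HEADERS, json={"duration_seconds": 300})
if lock.status_code == 409:
    print("Document is locked by another session")
else:
    lock.raise_for_status()
    httpx.patch(f"{BASE}/{doc_id}/lock", headers=HEADERS, json={"duration_seconds": 300})
    httpx.delete(f"{BASE}/{doc_id}/lock", headers=HEADERS)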

View File

@@ -0,0 +1,30 @@
"""
Admin Training API Routes
FastAPI endpoints for training task management and scheduling.
"""
from fastapi import APIRouter
from ._utils import _validate_uuid
from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes
def create_training_router() -> APIRouter:
"""Create training API router."""
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
register_task_routes(router)
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)
return router
__all__ = ["create_training_router", "_validate_uuid"]

View File

@@ -0,0 +1,16 @@
"""Shared utilities for training routes."""
from uuid import UUID
from fastapi import HTTPException
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)

View File

@@ -0,0 +1,291 @@
"""Training Dataset Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from backend.web.core.auth import (
AdminTokenDep,
DatasetRepoDep,
DocumentRepoDep,
AnnotationRepoDep,
ModelVersionRepoDep,
TrainingTaskRepoDep,
)
from backend.web.schemas.admin import (
DatasetCreateRequest,
DatasetDetailResponse,
DatasetDocumentItem,
DatasetListItem,
DatasetListResponse,
DatasetResponse,
DatasetTrainRequest,
TrainingStatus,
TrainingTaskResponse,
)
from backend.web.services.storage_helpers import get_storage_helper
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_dataset_routes(router: APIRouter) -> None:
"""Register dataset endpoints on the router."""
@router.post(
"/datasets",
response_model=DatasetResponse,
summary="Create training dataset",
description="Create a dataset from selected documents with train/val/test splits.",
)
async def create_dataset(
request: DatasetCreateRequest,
admin_token: AdminTokenDep,
datasets: DatasetRepoDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
) -> DatasetResponse:
"""Create a training dataset from document IDs."""
from backend.web.services.dataset_builder import DatasetBuilder
# Validate minimum document count for proper train/val/test split
if len(request.document_ids) < 10:
raise HTTPException(
status_code=400,
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
)
dataset = datasets.create(
name=request.name,
description=request.description,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
)
# Get storage paths from StorageHelper
storage = get_storage_helper()
datasets_dir = storage.get_datasets_base_path()
admin_images_dir = storage.get_admin_images_base_path()
if datasets_dir is None or admin_images_dir is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)
builder = DatasetBuilder(
datasets_repo=datasets,
documents_repo=docs,
annotations_repo=annotations,
base_dir=datasets_dir,
)
try:
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=request.document_ids,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
admin_images_dir=admin_images_dir,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return DatasetResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
status="ready",
message="Dataset created successfully",
)
@router.get(
"/datasets",
response_model=DatasetListResponse,
summary="List datasets",
)
async def list_datasets(
admin_token: AdminTokenDep,
datasets_repo: DatasetRepoDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> DatasetListResponse:
"""List training datasets."""
datasets_list, total = datasets_repo.get_paginated(status=status, limit=limit, offset=offset)
# Get active training tasks for each dataset (graceful degradation on error)
dataset_ids = [str(d.dataset_id) for d in datasets_list]
try:
active_tasks = datasets_repo.get_active_training_tasks(dataset_ids)
except Exception:
logger.exception("Failed to fetch active training tasks")
active_tasks = {}
return DatasetListResponse(
total=total,
limit=limit,
offset=offset,
datasets=[
DatasetListItem(
dataset_id=str(d.dataset_id),
name=d.name,
description=d.description,
status=d.status,
training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
created_at=d.created_at,
)
for d in datasets_list
],
)
@router.get(
"/datasets/{dataset_id}",
response_model=DatasetDetailResponse,
summary="Get dataset detail",
)
async def get_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
datasets_repo: DatasetRepoDep,
) -> DatasetDetailResponse:
"""Get dataset details with document list."""
_validate_uuid(dataset_id, "dataset_id")
dataset = datasets_repo.get(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
docs = datasets_repo.get_documents(dataset_id)
return DatasetDetailResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
description=dataset.description,
status=dataset.status,
training_status=dataset.training_status,
active_training_task_id=(
str(dataset.active_training_task_id)
if dataset.active_training_task_id
else None
),
train_ratio=dataset.train_ratio,
val_ratio=dataset.val_ratio,
seed=dataset.seed,
total_documents=dataset.total_documents,
total_images=dataset.total_images,
total_annotations=dataset.total_annotations,
dataset_path=dataset.dataset_path,
error_message=dataset.error_message,
documents=[
DatasetDocumentItem(
document_id=str(d.document_id),
split=d.split,
page_count=d.page_count,
annotation_count=d.annotation_count,
)
for d in docs
],
created_at=dataset.created_at,
updated_at=dataset.updated_at,
)
@router.delete(
"/datasets/{dataset_id}",
summary="Delete dataset",
)
async def delete_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
datasets_repo: DatasetRepoDep,
) -> dict:
"""Delete a dataset and its files."""
import shutil
from pathlib import Path
_validate_uuid(dataset_id, "dataset_id")
dataset = datasets_repo.get(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.dataset_path:
dataset_dir = Path(dataset.dataset_path)
if dataset_dir.exists():
shutil.rmtree(dataset_dir)
datasets_repo.delete(dataset_id)
return {"message": "Dataset deleted"}
@router.post(
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
description="Create a training task. Set base_model_version_id in config for incremental training.",
)
async def train_from_dataset(
dataset_id: str,
request: DatasetTrainRequest,
admin_token: AdminTokenDep,
datasets_repo: DatasetRepoDep,
models: ModelVersionRepoDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset.
For incremental training, set config.base_model_version_id to a model version UUID.
The training will use that model as the starting point instead of a pretrained model.
"""
_validate_uuid(dataset_id, "dataset_id")
dataset = datasets_repo.get(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.status != "ready":
raise HTTPException(
status_code=400,
detail=f"Dataset is not ready (status: {dataset.status})",
)
config_dict = request.config.model_dump()
# Resolve base_model_version_id to actual model path for incremental training
base_model_version_id = config_dict.get("base_model_version_id")
if base_model_version_id:
_validate_uuid(base_model_version_id, "base_model_version_id")
base_model = models.get(base_model_version_id)
if not base_model:
raise HTTPException(
status_code=404,
detail=f"Base model version not found: {base_model_version_id}",
)
# Store the resolved model path for the training worker
config_dict["base_model_path"] = base_model.model_path
config_dict["base_model_version"] = base_model.version
logger.info(
"Incremental training: using model %s (%s) as base",
base_model.version,
base_model.model_path,
)
task_id = tasks.create(
admin_token=admin_token,
name=request.name,
task_type="finetune" if base_model_version_id else "train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
message = (
f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
if base_model_version_id
else "Training task created from dataset"
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message=message,
)
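
A sketch of the dataset-then-train flow, assuming the request bodies carry the fields the handlers read (name, description, document_ids, train_ratio, val_ratio, seed, and config.base_model_version_id). The base URL, header name, and placeholder UUIDs are illustrative only.

import httpx

BASE = "http://localhost:8000/api/v1/admin/training"   # assumed mount prefix
HEADERS = {"X-Admin-Token": "change-me"}                # assumed header name

# Build a dataset from at least 10 labeled documents
ds = httpx.post(
    f"{BASE}/datasets",
    headers=HEADERS,
    json={
        "name": "invoices-jan",
        "description": "January batch",
        "document_ids": [f"<document-uuid-{i}>" for i in range(10)],  # placeholder UUIDs
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
    },
).json()

# Start incremental training from an existing model version
task = httpx.post(
    f"{BASE}/datasets/{ds['dataset_id']}/train",
    headers=HEADERS,
    json={"name": "finetune-jan", "config": {"base_model_version_id": "<model-version-uuid>"}},
).json()
print(task["task_id"], task["message"])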

View File

@@ -0,0 +1,218 @@
"""Training Documents and Models Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from backend.web.core.auth import (
AdminTokenDep,
DocumentRepoDep,
AnnotationRepoDep,
TrainingTaskRepoDep,
)
from backend.web.schemas.admin import (
ModelMetrics,
TrainingDocumentItem,
TrainingDocumentsResponse,
TrainingModelItem,
TrainingModelsResponse,
TrainingStatus,
)
from backend.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_document_routes(router: APIRouter) -> None:
"""Register training document and model endpoints on the router."""
@router.get(
"/documents",
response_model=TrainingDocumentsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get documents for training",
description="Get labeled documents available for training with filtering options.",
)
async def get_training_documents(
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
tasks: TrainingTaskRepoDep,
has_annotations: Annotated[
bool,
Query(description="Only include documents with annotations"),
] = True,
min_annotation_count: Annotated[
int | None,
Query(ge=1, description="Minimum annotation count"),
] = None,
exclude_used_in_training: Annotated[
bool,
Query(description="Exclude documents already used in training"),
] = False,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingDocumentsResponse:
"""Get documents available for training."""
documents, total = docs.get_for_training(
admin_token=admin_token,
status="labeled",
has_annotations=has_annotations,
min_annotation_count=min_annotation_count,
exclude_used_in_training=exclude_used_in_training,
limit=limit,
offset=offset,
)
items = []
for doc in documents:
doc_annotations = annotations.get_for_document(str(doc.document_id))
sources = {"manual": 0, "auto": 0}
for ann in doc_annotations:
if ann.source in sources:
sources[ann.source] += 1
training_links = tasks.get_document_training_tasks(doc.document_id)
used_in_training = [str(link.task_id) for link in training_links]
items.append(
TrainingDocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
annotation_count=len(doc_annotations),
annotation_sources=sources,
used_in_training=used_in_training,
last_modified=doc.updated_at,
)
)
return TrainingDocumentsResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/models/{task_id}/download",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Model not found"},
},
summary="Download trained model",
description="Download trained model weights file.",
)
async def download_model(
task_id: str,
admin_token: AdminTokenDep,
tasks: TrainingTaskRepoDep,
):
"""Download trained model."""
from fastapi.responses import FileResponse
from pathlib import Path
_validate_uuid(task_id, "task_id")
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if not task.model_path:
raise HTTPException(
status_code=404,
detail="Model file not available for this task",
)
model_path = Path(task.model_path)
if not model_path.exists():
raise HTTPException(
status_code=404,
detail="Model file not found on disk",
)
return FileResponse(
path=str(model_path),
media_type="application/octet-stream",
filename=f"{task.name}_model.pt",
)
@router.get(
"/completed-tasks",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get completed training tasks",
description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
)
async def get_completed_training_tasks(
admin_token: AdminTokenDep,
tasks_repo: TrainingTaskRepoDep,
status: Annotated[
str | None,
Query(description="Filter by status (completed, failed, etc.)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingModelsResponse:
"""Get list of trained models."""
task_list, total = tasks_repo.get_paginated(
admin_token=admin_token,
status=status if status else "completed",
limit=limit,
offset=offset,
)
items = []
for task in task_list:
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
download_url = None
if task.model_path and task.status == "completed":
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
items.append(
TrainingModelItem(
task_id=str(task.task_id),
name=task.name,
status=TrainingStatus(task.status),
document_count=task.document_count,
created_at=task.created_at,
completed_at=task.completed_at,
metrics=metrics,
model_path=task.model_path,
download_url=download_url,
)
)
return TrainingModelsResponse(
total=total,
limit=limit,
offset=offset,
models=items,
)
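
A download sketch for the two endpoints above: list completed tasks, then stream the first available weights file to disk. The base URL and header name are assumptions.

import httpx

BASE = "http://localhost:8000/api/v1/admin/training"   # assumed mount prefix
HEADERS = {"X-Admin-Token": "change-me"}                # assumed header name

tasks = httpx.get(f"{BASE}/completed-tasks", headers=HEADERS, params={"limit": 5}).json()
for model in tasks["models"]:
    if not model["download_url"]:
        continue
    url = f"http://localhost:8000{model['download_url']}"
    # Stream the checkpoint instead of loading it fully into memory
    with httpx.stream("GET", url, headers=HEADERS) as resp:
        resp.raise_for_status()
        with open(f"{model['name']}_model.pt", "wb") as fh:
            for chunk in resp.iter_bytes():
                fh.write(chunk)
    break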

View File

@@ -0,0 +1,134 @@
"""Training Export Endpoints."""
import logging
from datetime import datetime
from fastapi import APIRouter, HTTPException
from backend.web.core.auth import AdminTokenDep, DocumentRepoDep, AnnotationRepoDep
from backend.web.schemas.admin import (
ExportRequest,
ExportResponse,
)
from backend.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def register_export_routes(router: APIRouter) -> None:
"""Register export endpoints on the router."""
@router.post(
"/export",
response_model=ExportResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Export annotations",
description="Export annotations in YOLO format for training.",
)
async def export_annotations(
request: ExportRequest,
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
) -> ExportResponse:
"""Export annotations for training."""
from backend.web.services.storage_helpers import get_storage_helper
# Get storage helper for reading images and exports directory
storage = get_storage_helper()
if request.format not in ("yolo", "coco", "voc"):
raise HTTPException(
status_code=400,
detail=f"Unsupported export format: {request.format}",
)
documents = docs.get_labeled_for_export(admin_token)
if not documents:
raise HTTPException(
status_code=400,
detail="No labeled documents available for export",
)
# Get exports directory from StorageHelper
exports_base = storage.get_exports_base_path()
if exports_base is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)
export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
export_dir.mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
total_docs = len(documents)
train_count = int(total_docs * request.split_ratio)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
        # Use a distinct loop name so the DocumentRepoDep parameter `docs` is not shadowed
        for split, split_docs in [("train", train_docs), ("val", val_docs)]:
            for doc in split_docs:
doc_annotations = annotations.get_for_document(str(doc.document_id))
if not doc_annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in doc_annotations if a.page_number == page_num]
if not page_annotations and not request.include_images:
continue
# Get image from storage
doc_id = str(doc.document_id)
if not storage.admin_image_exists(doc_id, page_num):
continue
# Download image and save to export directory
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
image_content = storage.get_admin_image(doc_id, page_num)
dst_image.write_bytes(image_content)
total_images += 1
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
from shared.fields import FIELD_CLASSES
yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
(export_dir / "data.yaml").write_text(yaml_content)
return ExportResponse(
status="completed",
export_path=str(export_dir),
total_images=total_images,
total_annotations=total_annotations,
train_count=len(train_docs),
val_count=len(val_docs),
message=f"Exported {total_images} images with {total_annotations} annotations",
)
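
A worked example of the YOLO label line written above. The repository already stores normalized x_center/y_center/width/height, so the conversion below is illustrative only, assuming pixel boxes normalized by the rendered page size.

def yolo_line(class_id: int, x: float, y: float, w: float, h: float,
              page_w: float, page_h: float) -> str:
    # Normalize a pixel bbox (top-left x/y, width, height) to YOLO center format
    x_center = (x + w / 2) / page_w
    y_center = (y + h / 2) / page_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {w / page_w:.6f} {h / page_h:.6f}"

# e.g. a 400x80 px box at (100, 200) on a 1654x2339 px page (roughly A4 at 200 DPI)
print(yolo_line(0, 100, 200, 400, 80, 1654, 2339))
# -> "0 0.181378 0.102608 0.241838 0.034203"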

View File

@@ -0,0 +1,333 @@
"""Model Version Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query, Request
from backend.web.core.auth import AdminTokenDep, ModelVersionRepoDep
from backend.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
ModelVersionItem,
ModelVersionListResponse,
ModelVersionDetailResponse,
ModelVersionResponse,
ActiveModelResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_model_routes(router: APIRouter) -> None:
"""Register model version endpoints on the router."""
@router.post(
"/models",
response_model=ModelVersionResponse,
summary="Create model version",
description="Register a new model version for deployment.",
)
async def create_model_version(
request: ModelVersionCreateRequest,
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Create a new model version."""
if request.task_id:
_validate_uuid(request.task_id, "task_id")
if request.dataset_id:
_validate_uuid(request.dataset_id, "dataset_id")
model = models.create(
version=request.version,
name=request.name,
model_path=request.model_path,
description=request.description,
task_id=request.task_id,
dataset_id=request.dataset_id,
metrics_mAP=request.metrics_mAP,
metrics_precision=request.metrics_precision,
metrics_recall=request.metrics_recall,
document_count=request.document_count,
training_config=request.training_config,
file_size=request.file_size,
trained_at=request.trained_at,
)
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version created successfully",
)
@router.get(
"/models",
response_model=ModelVersionListResponse,
summary="List model versions",
)
async def list_model_versions(
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> ModelVersionListResponse:
"""List model versions with optional status filter."""
model_list, total = models.get_paginated(status=status, limit=limit, offset=offset)
return ModelVersionListResponse(
total=total,
limit=limit,
offset=offset,
models=[
ModelVersionItem(
version_id=str(m.version_id),
version=m.version,
name=m.name,
status=m.status,
is_active=m.is_active,
metrics_mAP=m.metrics_mAP,
document_count=m.document_count,
trained_at=m.trained_at,
activated_at=m.activated_at,
created_at=m.created_at,
)
for m in model_list
],
)
@router.get(
"/models/active",
response_model=ActiveModelResponse,
summary="Get active model",
description="Get the currently active model for inference.",
)
async def get_active_model(
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> ActiveModelResponse:
"""Get the currently active model version."""
model = models.get_active()
if not model:
return ActiveModelResponse(has_active_model=False, model=None)
return ActiveModelResponse(
has_active_model=True,
model=ModelVersionItem(
version_id=str(model.version_id),
version=model.version,
name=model.name,
status=model.status,
is_active=model.is_active,
metrics_mAP=model.metrics_mAP,
document_count=model.document_count,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
),
)
@router.get(
"/models/{version_id}",
response_model=ModelVersionDetailResponse,
summary="Get model version detail",
)
async def get_model_version(
version_id: str,
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> ModelVersionDetailResponse:
"""Get detailed model version information."""
_validate_uuid(version_id, "version_id")
model = models.get(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionDetailResponse(
version_id=str(model.version_id),
version=model.version,
name=model.name,
description=model.description,
model_path=model.model_path,
status=model.status,
is_active=model.is_active,
task_id=str(model.task_id) if model.task_id else None,
dataset_id=str(model.dataset_id) if model.dataset_id else None,
metrics_mAP=model.metrics_mAP,
metrics_precision=model.metrics_precision,
metrics_recall=model.metrics_recall,
document_count=model.document_count,
training_config=model.training_config,
file_size=model.file_size,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
updated_at=model.updated_at,
)
@router.patch(
"/models/{version_id}",
response_model=ModelVersionResponse,
summary="Update model version",
)
async def update_model_version(
version_id: str,
request: ModelVersionUpdateRequest,
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Update model version metadata."""
_validate_uuid(version_id, "version_id")
model = models.update(
version_id=version_id,
name=request.name,
description=request.description,
status=request.status,
)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version updated successfully",
)
@router.post(
"/models/{version_id}/activate",
response_model=ModelVersionResponse,
summary="Activate model version",
description="Activate a model version for inference (deactivates all others).",
)
async def activate_model_version(
version_id: str,
request: Request,
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Activate a model version for inference."""
_validate_uuid(version_id, "version_id")
model = models.activate(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
# Trigger model reload in inference service
inference_service = getattr(request.app.state, "inference_service", None)
model_reloaded = False
if inference_service:
try:
model_reloaded = inference_service.reload_model()
if model_reloaded:
logger.info(f"Inference model reloaded to version {model.version}")
except Exception as e:
logger.warning(f"Failed to reload inference model: {e}")
message = "Model version activated for inference"
if model_reloaded:
message += " (model reloaded)"
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message=message,
)
@router.post(
"/models/{version_id}/deactivate",
response_model=ModelVersionResponse,
summary="Deactivate model version",
)
async def deactivate_model_version(
version_id: str,
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Deactivate a model version."""
_validate_uuid(version_id, "version_id")
model = models.deactivate(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version deactivated",
)
@router.post(
"/models/{version_id}/archive",
response_model=ModelVersionResponse,
summary="Archive model version",
)
async def archive_model_version(
version_id: str,
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Archive a model version."""
_validate_uuid(version_id, "version_id")
model = models.archive(version_id)
if not model:
raise HTTPException(
status_code=400,
detail="Model version not found or cannot archive active model",
)
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version archived",
)
@router.delete(
"/models/{version_id}",
summary="Delete model version",
)
async def delete_model_version(
version_id: str,
admin_token: AdminTokenDep,
models: ModelVersionRepoDep,
) -> dict:
"""Delete a model version."""
_validate_uuid(version_id, "version_id")
success = models.delete(version_id)
if not success:
raise HTTPException(
status_code=400,
detail="Model version not found or cannot delete active model",
)
return {"message": "Model version deleted"}
@router.post(
"/models/reload",
summary="Reload inference model",
description="Reload the inference model from the currently active model version.",
)
async def reload_inference_model(
request: Request,
admin_token: AdminTokenDep,
) -> dict:
"""Reload the inference model from active version."""
inference_service = getattr(request.app.state, "inference_service", None)
if not inference_service:
raise HTTPException(
status_code=500,
detail="Inference service not available",
)
try:
model_reloaded = inference_service.reload_model()
if model_reloaded:
logger.info("Inference model manually reloaded")
return {"message": "Model reloaded successfully", "reloaded": True}
else:
return {"message": "Model already up to date", "reloaded": False}
except Exception as e:
logger.error(f"Failed to reload model: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to reload model: {e}",
)
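
A sketch of the register-and-activate flow, assuming ModelVersionCreateRequest accepts the fields the handler passes to models.create (version, name, model_path, metrics_mAP, ...). The checkpoint path, base URL, and header name are placeholders.

import httpx

BASE = "http://localhost:8000/api/v1/admin/training"   # assumed mount prefix
HEADERS = {"X-Admin-Token": "change-me"}                # assumed header name

# Register a trained checkpoint as a model version, then make it the active inference model
created = httpx.post(
    f"{BASE}/models",
    headers=HEADERS,
    json={
        "version": "1.4.0",
        "name": "invoice-fields-yolo",
        "model_path": "/models/invoice_fields_v1_4_0.pt",  # placeholder path
        "metrics_mAP": 0.91,
    },
).json()

httpx.post(f"{BASE}/models/{created['version_id']}/activate", headers=HEADERS)

active = httpx.get(f"{BASE}/models/active", headers=HEADERS).json()
print(active["has_active_model"], active["model"]["version"])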

View File

@@ -0,0 +1,263 @@
"""Training Task Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from backend.web.core.auth import AdminTokenDep, TrainingTaskRepoDep
from backend.web.schemas.admin import (
TrainingLogItem,
TrainingLogsResponse,
TrainingStatus,
TrainingTaskCreate,
TrainingTaskDetailResponse,
TrainingTaskItem,
TrainingTaskListResponse,
TrainingTaskResponse,
TrainingType,
)
from backend.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_task_routes(router: APIRouter) -> None:
"""Register training task endpoints on the router."""
@router.post(
"/tasks",
response_model=TrainingTaskResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Create training task",
description="Create a new training task.",
)
async def create_training_task(
request: TrainingTaskCreate,
admin_token: AdminTokenDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskResponse:
"""Create a new training task."""
config_dict = request.config.model_dump() if request.config else {}
task_id = tasks.create(
admin_token=admin_token,
name=request.name,
task_type=request.task_type.value,
description=request.description,
config=config_dict,
scheduled_at=request.scheduled_at,
cron_expression=request.cron_expression,
is_recurring=bool(request.cron_expression),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
message="Training task created successfully",
)
@router.get(
"/tasks",
response_model=TrainingTaskListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List training tasks",
description="List all training tasks.",
)
async def list_training_tasks(
admin_token: AdminTokenDep,
tasks_repo: TrainingTaskRepoDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingTaskListResponse:
"""List training tasks."""
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
if status and status not in valid_statuses:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
)
task_list, total = tasks_repo.get_paginated(
admin_token=admin_token,
status=status,
limit=limit,
offset=offset,
)
items = [
TrainingTaskItem(
task_id=str(task.task_id),
name=task.name,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
scheduled_at=task.scheduled_at,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
created_at=task.created_at,
)
for task in task_list
]
return TrainingTaskListResponse(
total=total,
limit=limit,
offset=offset,
tasks=items,
)
@router.get(
"/tasks/{task_id}",
response_model=TrainingTaskDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training task detail",
description="Get training task details.",
)
async def get_training_task(
task_id: str,
admin_token: AdminTokenDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskDetailResponse:
"""Get training task details."""
_validate_uuid(task_id, "task_id")
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
return TrainingTaskDetailResponse(
task_id=str(task.task_id),
name=task.name,
description=task.description,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
config=task.config,
scheduled_at=task.scheduled_at,
cron_expression=task.cron_expression,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
error_message=task.error_message,
result_metrics=task.result_metrics,
model_path=task.model_path,
created_at=task.created_at,
)
@router.post(
"/tasks/{task_id}/cancel",
response_model=TrainingTaskResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
},
summary="Cancel training task",
description="Cancel a pending or scheduled training task.",
)
async def cancel_training_task(
task_id: str,
admin_token: AdminTokenDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskResponse:
"""Cancel a training task."""
_validate_uuid(task_id, "task_id")
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if task.status not in ("pending", "scheduled"):
raise HTTPException(
status_code=409,
detail=f"Cannot cancel task with status: {task.status}",
)
success = tasks.cancel(task_id)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to cancel training task",
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.CANCELLED,
message="Training task cancelled successfully",
)
@router.get(
"/tasks/{task_id}/logs",
response_model=TrainingLogsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training logs",
description="Get training task logs.",
)
async def get_training_logs(
task_id: str,
admin_token: AdminTokenDep,
tasks: TrainingTaskRepoDep,
limit: Annotated[
int,
Query(ge=1, le=500, description="Maximum logs to return"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingLogsResponse:
"""Get training logs."""
_validate_uuid(task_id, "task_id")
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
logs = tasks.get_logs(task_id, limit, offset)
items = [
TrainingLogItem(
level=log.level,
message=log.message,
details=log.details,
created_at=log.created_at,
)
for log in logs
]
return TrainingLogsResponse(
task_id=task_id,
logs=items,
)
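
A sketch of scheduling and monitoring a task, assuming TrainingTaskCreate accepts the fields the handler reads (name, task_type, cron_expression, ...) and the same base URL and header conventions as the sketches above.

import httpx

BASE = "http://localhost:8000/api/v1/admin/training"   # assumed mount prefix
HEADERS = {"X-Admin-Token": "change-me"}                # assumed header name

# Schedule a recurring nightly training run via its cron expression
task = httpx.post(
    f"{BASE}/tasks",
    headers=HEADERS,
    json={"name": "nightly-train", "task_type": "train", "cron_expression": "0 2 * * *"},
).json()
task_id = task["task_id"]

# Tail the most recent logs for the task
logs = httpx.get(f"{BASE}/tasks/{task_id}/logs", headers=HEADERS, params={"limit": 20}).json()
for entry in logs["logs"]:
    print(entry["created_at"], entry["level"], entry["message"])

# A pending or scheduled task can still be cancelled
httpx.post(f"{BASE}/tasks/{task_id}/cancel", headers=HEADERS)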

View File

@@ -0,0 +1,248 @@
"""
Batch Upload API Routes
Endpoints for batch uploading documents via ZIP files with CSV metadata.
"""
import io
import logging
import zipfile
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse
from backend.data.repositories import BatchUploadRepository
from backend.web.core.auth import validate_admin_token
from backend.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from backend.web.workers.batch_queue import BatchTask, get_batch_queue
logger = logging.getLogger(__name__)
# Global repository instance
_batch_repo: BatchUploadRepository | None = None
def get_batch_repository() -> BatchUploadRepository:
"""Get the BatchUploadRepository instance."""
global _batch_repo
if _batch_repo is None:
_batch_repo = BatchUploadRepository()
return _batch_repo
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
@router.post("/upload")
async def upload_batch(
file: UploadFile = File(...),
upload_source: str = Form(default="ui"),
async_mode: bool = Form(default=True),
auto_label: bool = Form(default=True),
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
) -> dict:
"""Upload a batch of documents via ZIP file.
The ZIP file can contain:
- Multiple PDF files
- Optional CSV file with field values for auto-labeling
CSV format:
- Required column: DocumentId (matches PDF filename without extension)
- Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number
Args:
file: ZIP file upload
upload_source: Upload source (ui or api)
        async_mode: If True, queue the batch for background processing and return 202
        auto_label: Trigger auto-labeling for uploaded documents
        admin_token: Admin authentication token
        batch_repo: Batch upload repository
Returns:
Batch upload result with batch_id and status
"""
    if not file.filename or not file.filename.lower().endswith('.zip'):
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
# Check compressed size
if file.size and file.size > MAX_COMPRESSED_SIZE:
max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
raise HTTPException(
status_code=400,
detail=f"File size exceeds {max_mb:.0f}MB limit"
)
try:
# Read file content
zip_content = await file.read()
# Additional security validation before processing
try:
            with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
                # testzip() returns the name of the first corrupt member, or None if all CRCs pass
                if test_zip.testzip() is not None:
                    raise HTTPException(status_code=400, detail="ZIP archive contains a corrupt file")
except zipfile.BadZipFile:
raise HTTPException(status_code=400, detail="Invalid ZIP file format")
if async_mode:
# Async mode: Queue task and return immediately
from uuid import uuid4
batch_id = uuid4()
# Create batch task for background processing
task = BatchTask(
batch_id=batch_id,
admin_token=admin_token,
zip_content=zip_content,
zip_filename=file.filename,
upload_source=upload_source,
auto_label=auto_label,
created_at=datetime.utcnow(),
)
# Submit to queue
queue = get_batch_queue()
if not queue.submit(task):
raise HTTPException(
status_code=503,
detail="Processing queue is full. Please try again later."
)
logger.info(
f"Batch upload queued: batch_id={batch_id}, "
f"filename={file.filename}, async_mode=True"
)
# Return 202 Accepted with batch_id and status URL
return JSONResponse(
status_code=202,
content={
"status": "accepted",
"batch_id": str(batch_id),
"message": "Batch upload queued for processing",
"status_url": f"/api/v1/admin/batch/status/{batch_id}",
"queue_depth": queue.get_queue_depth(),
}
)
else:
# Sync mode: Process immediately and return results
service = BatchUploadService(batch_repo)
result = service.process_zip_upload(
admin_token=admin_token,
zip_filename=file.filename,
zip_content=zip_content,
upload_source=upload_source,
)
logger.info(
f"Batch upload completed: batch_id={result.get('batch_id')}, "
f"status={result.get('status')}, files={result.get('successful_files')}"
)
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Error processing batch upload: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail="Failed to process batch upload. Please contact support."
)
@router.get("/status/{batch_id}")
async def get_batch_status(
batch_id: str,
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
) -> dict:
"""Get batch upload status and file processing details.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
batch_repo: Batch upload repository
Returns:
Batch status with file processing details
"""
# Validate UUID format
try:
batch_uuid = UUID(batch_id)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid batch ID format")
# Check batch exists and verify ownership
batch = batch_repo.get(batch_uuid)
if not batch:
raise HTTPException(status_code=404, detail="Batch not found")
# CRITICAL: Verify ownership
if batch.admin_token != admin_token:
raise HTTPException(
status_code=403,
detail="You do not have access to this batch"
)
# Now safe to return details
service = BatchUploadService(batch_repo)
result = service.get_batch_status(batch_id)
return result
@router.get("/list")
async def list_batch_uploads(
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
limit: int = 50,
offset: int = 0,
) -> dict:
"""List batch uploads for the current admin token.
Args:
admin_token: Admin authentication token
batch_repo: Batch upload repository
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of batch uploads
"""
# Validate pagination parameters
if limit < 1 or limit > 100:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative")
# Get batch uploads filtered by admin token
batches, total = batch_repo.get_paginated(
admin_token=admin_token,
limit=limit,
offset=offset,
)
return {
"batches": [
{
"batch_id": str(b.batch_id),
"filename": b.filename,
"status": b.status,
"total_files": b.total_files,
"successful_files": b.successful_files,
"failed_files": b.failed_files,
"created_at": b.created_at.isoformat() if b.created_at else None,
"completed_at": b.completed_at.isoformat() if b.completed_at else None,
}
for b in batches
],
"total": total,
"limit": limit,
"offset": offset,
}
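# --- Illustrative client sketch (not part of the API surface) ------------------
# A minimal example of how a client might exercise the batch endpoints above.
# Assumptions: the server runs on http://localhost:8000, "your-admin-token" is a
# valid admin token, invoice_001.pdf exists on disk, and the `requests` package
# is installed. Key names in the status payload follow the service layer and may
# differ from what is shown here.
if __name__ == "__main__":  # pragma: no cover
    import time

    import requests

    BASE_URL = "http://localhost:8000/api/v1/admin/batch"
    HEADERS = {"X-Admin-Token": "your-admin-token"}  # placeholder credential

    # Build a ZIP in memory: one PDF plus a metadata CSV keyed by DocumentId.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.write("invoice_001.pdf", arcname="invoice_001.pdf")
        zf.writestr(
            "metadata.csv",
            "DocumentId,InvoiceNumber,Amount\ninvoice_001,12345,1500.00\n",
        )
    buf.seek(0)

    # Queue the batch; async_mode=true returns HTTP 202 with a status URL.
    resp = requests.post(
        f"{BASE_URL}/upload",
        headers=HEADERS,
        files={"file": ("batch.zip", buf, "application/zip")},
        data={"upload_source": "api", "async_mode": "true", "auto_label": "true"},
        timeout=60,
    )
    resp.raise_for_status()
    batch_id = resp.json()["batch_id"]

    # Poll the status endpoint until processing finishes (status names assumed).
    while True:
        status = requests.get(
            f"{BASE_URL}/status/{batch_id}", headers=HEADERS, timeout=30
        ).json()
        if status.get("status") not in ("pending", "processing"):
            break
        time.sleep(2)
    print(status)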

View File

@@ -0,0 +1,16 @@
"""
Public API v1
Customer-facing endpoints for inference, async processing, and labeling.
"""
from backend.web.api.v1.public.inference import create_inference_router
from backend.web.api.v1.public.async_api import create_async_router, set_async_service
from backend.web.api.v1.public.labeling import create_labeling_router
__all__ = [
"create_inference_router",
"create_async_router",
"set_async_service",
"create_labeling_router",
]

View File

@@ -0,0 +1,372 @@
"""
Async API Routes
FastAPI endpoints for async invoice processing.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from backend.web.dependencies import (
ApiKeyDep,
AsyncDBDep,
PollRateLimitDep,
SubmitRateLimitDep,
)
from backend.web.schemas.inference import (
AsyncRequestItem,
AsyncRequestsListResponse,
AsyncResultResponse,
AsyncStatus,
AsyncStatusResponse,
AsyncSubmitResponse,
DetectionResult,
InferenceResult,
)
from backend.web.schemas.common import ErrorResponse
def _validate_request_id(request_id: str) -> None:
"""Validate that request_id is a valid UUID format."""
try:
UUID(request_id)
except ValueError:
raise HTTPException(
status_code=400,
detail="Invalid request ID format. Must be a valid UUID.",
)
logger = logging.getLogger(__name__)
# Global reference to async processing service (set during app startup)
_async_service = None
def set_async_service(service) -> None:
"""Set the async processing service instance."""
global _async_service
_async_service = service
def get_async_service():
"""Get the async processing service instance."""
if _async_service is None:
raise RuntimeError("AsyncProcessingService not initialized")
return _async_service
def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
"""Create async API router."""
router = APIRouter(prefix="/async", tags=["Async Processing"])
@router.post(
"/submit",
response_model=AsyncSubmitResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid API key"},
429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
503: {"model": ErrorResponse, "description": "Queue full"},
},
summary="Submit PDF for async processing",
description="Submit a PDF or image file for asynchronous processing. "
"Returns a request_id that can be used to poll for results.",
)
async def submit_document(
api_key: SubmitRateLimitDep,
file: UploadFile = File(..., description="PDF or image file to process"),
) -> AsyncSubmitResponse:
"""Submit a document for async processing."""
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate file extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Check file size (get from config via service)
service = get_async_service()
max_size = service._async_config.max_file_size_mb * 1024 * 1024
if len(content) > max_size:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size: "
f"{service._async_config.max_file_size_mb}MB",
)
# Submit request
result = service.submit_request(
api_key=api_key,
file_content=content,
filename=file.filename,
content_type=file.content_type or "application/octet-stream",
)
if not result.success:
if "queue" in (result.error or "").lower():
raise HTTPException(status_code=503, detail=result.error)
raise HTTPException(status_code=500, detail=result.error)
return AsyncSubmitResponse(
status="accepted",
message="Request submitted for processing",
request_id=result.request_id,
estimated_wait_seconds=result.estimated_wait_seconds,
poll_url=f"/api/v1/async/status/{result.request_id}",
)
@router.get(
"/status/{request_id}",
response_model=AsyncStatusResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get request status",
description="Get the current processing status of an async request.",
)
async def get_status(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncStatusResponse:
"""Get the status of an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Get queue position for pending requests
position = None
if request.status == "pending":
position = db.get_queue_position(request_id)
# Build result URL for completed requests
result_url = None
if request.status == "completed":
result_url = f"/api/v1/async/result/{request_id}"
return AsyncStatusResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
filename=request.filename,
created_at=request.created_at,
started_at=request.started_at,
completed_at=request.completed_at,
position_in_queue=position,
error_message=request.error_message,
result_url=result_url,
)
@router.get(
"/result/{request_id}",
response_model=AsyncResultResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Request not completed"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get extraction results",
description="Get the extraction results for a completed async request.",
)
async def get_result(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncResultResponse:
"""Get the results of a completed async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Check if completed or failed
if request.status not in ("completed", "failed"):
raise HTTPException(
status_code=409,
detail=f"Request not yet completed. Current status: {request.status}",
)
# Build inference result from stored data
inference_result = None
if request.result:
# Convert detections to DetectionResult objects
detections = []
for d in request.result.get("detections", []):
detections.append(DetectionResult(
field=d.get("field", ""),
confidence=d.get("confidence", 0.0),
bbox=d.get("bbox", [0, 0, 0, 0]),
))
inference_result = InferenceResult(
document_id=request.result.get("document_id", str(request.request_id)[:8]),
success=request.result.get("success", False),
document_type=request.result.get("document_type", "invoice"),
fields=request.result.get("fields", {}),
confidence=request.result.get("confidence", {}),
detections=detections,
processing_time_ms=request.processing_time_ms or 0.0,
errors=request.result.get("errors", []),
)
# Build visualization URL
viz_url = None
if request.visualization_path:
viz_url = f"/api/v1/results/{request.visualization_path}"
return AsyncResultResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
processing_time_ms=request.processing_time_ms or 0.0,
result=inference_result,
visualization_url=viz_url,
)
@router.get(
"/requests",
response_model=AsyncRequestsListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
},
summary="List requests",
description="List all async requests for the authenticated API key.",
)
async def list_requests(
api_key: ApiKeyDep,
db: AsyncDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (pending, processing, completed, failed)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Maximum number of results"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Pagination offset"),
] = 0,
) -> AsyncRequestsListResponse:
"""List all requests for the authenticated API key."""
# Validate status filter
if status and status not in ("pending", "processing", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid status filter: {status}. "
"Must be one of: pending, processing, completed, failed",
)
# Get requests from database
requests, total = db.get_requests_by_api_key(
api_key=api_key,
status=status,
limit=limit,
offset=offset,
)
# Convert to response items
items = [
AsyncRequestItem(
request_id=str(r.request_id),
status=AsyncStatus(r.status),
filename=r.filename,
file_size=r.file_size,
created_at=r.created_at,
completed_at=r.completed_at,
)
for r in requests
]
return AsyncRequestsListResponse(
total=total,
limit=limit,
offset=offset,
requests=items,
)
@router.delete(
"/requests/{request_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
},
summary="Cancel/delete request",
description="Cancel a pending request or delete a completed/failed request.",
)
async def delete_request(
request_id: str,
api_key: ApiKeyDep,
db: AsyncDBDep,
) -> dict:
"""Delete or cancel an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Cannot delete processing requests
if request.status == "processing":
raise HTTPException(
status_code=409,
detail="Cannot delete a request that is currently processing",
)
# Delete from database (will cascade delete related records)
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM async_requests WHERE request_id = %s",
(request_id,),
)
conn.commit()
return {
"status": "deleted",
"request_id": request_id,
"message": "Request deleted successfully",
}
return router
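# --- Illustrative client sketch (not part of the router) -----------------------
# A minimal submit/poll/fetch loop against the async endpoints defined above.
# Assumptions: the app mounts this router under /api/v1, API keys are sent in an
# "X-API-Key" header (the real header name is defined by the dependencies in
# backend.web.dependencies and may differ), invoice.pdf exists on disk, and the
# `requests` package is installed.
if __name__ == "__main__":  # pragma: no cover
    import time

    import requests

    BASE_URL = "http://localhost:8000/api/v1/async"
    HEADERS = {"X-API-Key": "your-api-key"}  # placeholder credential

    # Submit a PDF for background processing.
    with open("invoice.pdf", "rb") as fh:
        submit = requests.post(
            f"{BASE_URL}/submit",
            headers=HEADERS,
            files={"file": ("invoice.pdf", fh, "application/pdf")},
            timeout=60,
        )
    submit.raise_for_status()
    request_id = submit.json()["request_id"]

    # Poll no faster than the configured minimum poll interval (1s by default).
    while True:
        status = requests.get(
            f"{BASE_URL}/status/{request_id}", headers=HEADERS, timeout=30
        ).json()
        if status["status"] in ("completed", "failed"):
            break
        time.sleep(2)

    # Fetch the extraction result once the request has completed.
    if status["status"] == "completed":
        result = requests.get(
            f"{BASE_URL}/result/{request_id}", headers=HEADERS, timeout=30
        ).json()
        print(result["result"]["fields"])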

View File

@@ -0,0 +1,194 @@
"""
Inference API Routes
FastAPI route definitions for the inference API.
"""
from __future__ import annotations
import logging
import shutil
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from backend.web.schemas.inference import (
DetectionResult,
HealthResponse,
InferenceResponse,
InferenceResult,
)
from backend.web.schemas.common import ErrorResponse
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from backend.web.services import InferenceService
from backend.web.config import StorageConfig
logger = logging.getLogger(__name__)
def create_inference_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with inference endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["inference"])
@router.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
"""Check service health status."""
return HealthResponse(
status="healthy",
model_loaded=inference_service.is_initialized,
gpu_available=inference_service.gpu_available,
version="1.0.0",
)
@router.post(
"/infer",
response_model=InferenceResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
)
async def infer_document(
file: UploadFile = File(..., description="PDF or image file to process"),
) -> InferenceResponse:
"""
Process a document and extract invoice fields.
Accepts PDF or image files (PNG, JPG, JPEG).
Returns extracted field values with confidence scores.
"""
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Generate document ID
doc_id = str(uuid.uuid4())[:8]
# Get storage helper and uploads directory
storage = get_storage_helper()
uploads_dir = storage.get_uploads_base_path(subfolder="inference")
if uploads_dir is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Storage not configured for local access",
)
# Save uploaded file to temporary location for processing
upload_path = uploads_dir / f"{doc_id}{file_ext}"
try:
with open(upload_path, "wb") as f:
shutil.copyfileobj(file.file, f)
except Exception as e:
logger.error(f"Failed to save uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save uploaded file",
)
try:
# Process based on file type
if file_ext == ".pdf":
service_result = inference_service.process_pdf(
upload_path, document_id=doc_id
)
else:
service_result = inference_service.process_image(
upload_path, document_id=doc_id
)
# Build response
viz_url = None
if service_result.visualization_path:
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
inference_result = InferenceResult(
document_id=service_result.document_id,
success=service_result.success,
document_type=service_result.document_type,
fields=service_result.fields,
confidence=service_result.confidence,
detections=[
DetectionResult(**d) for d in service_result.detections
],
processing_time_ms=service_result.processing_time_ms,
visualization_url=viz_url,
errors=service_result.errors,
)
return InferenceResponse(
status="success" if service_result.success else "partial",
message=f"Processed document {doc_id}",
result=inference_result,
)
except Exception as e:
logger.error(f"Error processing document: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
finally:
# Cleanup uploaded file
upload_path.unlink(missing_ok=True)
@router.get("/results/{filename}", response_model=None)
async def get_result_image(filename: str) -> FileResponse:
"""Get visualization result image."""
storage = get_storage_helper()
file_path = storage.get_result_local_path(filename)
if file_path is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
return FileResponse(
path=file_path,
media_type="image/png",
filename=filename,
)
@router.delete("/results/{filename}")
async def delete_result(filename: str) -> dict:
"""Delete a result file."""
storage = get_storage_helper()
if not storage.result_exists(filename):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
storage.delete_result(filename)
return {"status": "deleted", "filename": filename}
return router
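# --- Illustrative client sketch (not part of the router) -----------------------
# A minimal synchronous call to the /api/v1/infer endpoint defined above.
# Assumptions: the service runs on http://localhost:8000, invoice.pdf exists on
# disk, and the `requests` package is installed.
if __name__ == "__main__":  # pragma: no cover
    import requests

    with open("invoice.pdf", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/v1/infer",
            files={"file": ("invoice.pdf", fh, "application/pdf")},
            timeout=120,
        )
    resp.raise_for_status()
    payload = resp.json()
    # Extracted values and per-field confidence scores from InferenceResult.
    print(payload["result"]["fields"])
    print(payload["result"]["confidence"])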

View File

@@ -0,0 +1,197 @@
"""
Labeling API Routes
FastAPI endpoints for pre-labeling documents with expected field values.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from backend.data.repositories import DocumentRepository
from backend.web.schemas.labeling import PreLabelResponse
from backend.web.schemas.common import ErrorResponse
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from backend.web.services import InferenceService
from backend.web.config import StorageConfig
logger = logging.getLogger(__name__)
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
"""Convert PDF pages to images for annotation using StorageHelper."""
import fitz
storage = get_storage_helper()
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
# Save to storage using StorageHelper
image_bytes = pix.tobytes("png")
storage.save_admin_image(document_id, page_num + 1, image_bytes)
pdf_doc.close()
def get_doc_repository() -> DocumentRepository:
"""Get document repository instance."""
return DocumentRepository()
def create_labeling_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with labeling endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["labeling"])
@router.post(
"/pre-label",
response_model=PreLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file or field values"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
summary="Pre-label document with expected values",
description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
)
async def pre_label(
file: UploadFile = File(..., description="PDF or image file to process"),
field_values: str = Form(
...,
description="JSON object with expected field values. "
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
),
doc_repo: DocumentRepository = Depends(get_doc_repository),
) -> PreLabelResponse:
"""
Upload a document with expected field values for pre-labeling.
Returns document_id which can be used to retrieve results later.
Example field_values JSON:
```json
{
"InvoiceNumber": "12345",
"Amount": "1500.00",
"Bankgiro": "123-4567",
"OCR": "1234567890"
}
```
"""
# Parse field_values JSON
try:
expected_values = json.loads(field_values)
if not isinstance(expected_values, dict):
raise ValueError("field_values must be a JSON object")
except json.JSONDecodeError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid JSON in field_values: {e}",
)
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to read file",
)
# Get page count for PDF
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record with field_values
document_id = doc_repo.create(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
upload_source="api",
csv_field_values=expected_values,
)
# Save file to storage using StorageHelper
storage = get_storage_helper()
filename = f"{document_id}{file_ext}"
try:
storage_path = storage.save_raw_pdf(content, filename)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save file",
)
# Update file path in database (using storage path)
doc_repo.update_file_path(document_id, storage_path)
# Convert PDF to images for annotation UI
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling
doc_repo.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="pending",
)
logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")
return PreLabelResponse(document_id=document_id)
return router
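# --- Illustrative client sketch (not part of the router) -----------------------
# Shows how a client might call /api/v1/pre-label with expected field values,
# matching the field_values JSON format documented in the endpoint above.
# Assumptions: the server runs on http://localhost:8000, invoice.pdf exists on
# disk, and the `requests` package is installed.
if __name__ == "__main__":  # pragma: no cover
    import requests

    expected = {
        "InvoiceNumber": "12345",
        "Amount": "1500.00",
        "Bankgiro": "123-4567",
        "OCR": "1234567890",
    }
    with open("invoice.pdf", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/v1/pre-label",
            files={"file": ("invoice.pdf", fh, "application/pdf")},
            data={"field_values": json.dumps(expected)},
            timeout=120,
        )
    resp.raise_for_status()
    print("document_id:", resp.json()["document_id"])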

View File

@@ -0,0 +1,953 @@
"""
FastAPI Application Factory
Creates and configures the FastAPI application.
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from .config import AppConfig, default_config
from backend.web.services import InferenceService
from backend.web.services.storage_helpers import get_storage_helper
# Public API imports
from backend.web.api.v1.public import (
create_inference_router,
create_async_router,
set_async_service,
create_labeling_router,
)
# Async processing imports
from backend.data.async_request_db import AsyncRequestDB
from backend.web.workers.async_queue import AsyncTaskQueue
from backend.web.services.async_processing import AsyncProcessingService
from backend.web.dependencies import init_dependencies
from backend.web.core.rate_limiter import RateLimiter
# Admin API imports
from backend.web.api.v1.admin import (
create_annotation_router,
create_augmentation_router,
create_auth_router,
create_documents_router,
create_locks_router,
create_training_router,
)
from backend.web.api.v1.admin.dashboard import create_dashboard_router
from backend.web.core.scheduler import start_scheduler, stop_scheduler
from backend.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
# Batch upload imports
from backend.web.api.v1.batch.routes import router as batch_upload_router
from backend.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from backend.web.services.batch_upload import BatchUploadService
from backend.data.repositories import ModelVersionRepository
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__)
def create_app(config: AppConfig | None = None) -> FastAPI:
"""
Create and configure FastAPI application.
Args:
config: Application configuration. Uses default if not provided.
Returns:
Configured FastAPI application
"""
config = config or default_config
# Create model path resolver that reads from database
def get_active_model_path():
"""Resolve active model path from database."""
try:
model_repo = ModelVersionRepository()
active_model = model_repo.get_active()
if active_model and active_model.model_path:
return active_model.model_path
except Exception as e:
logger.warning(f"Failed to get active model from database: {e}")
return None
# Create inference service with database model resolver
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
model_path_resolver=get_active_model_path,
)
# Create async processing components
async_db = AsyncRequestDB()
rate_limiter = RateLimiter(async_db)
task_queue = AsyncTaskQueue(
max_size=config.async_processing.queue_max_size,
worker_count=config.async_processing.worker_count,
)
async_service = AsyncProcessingService(
inference_service=inference_service,
db=async_db,
queue=task_queue,
rate_limiter=rate_limiter,
async_config=config.async_processing,
storage_config=config.storage,
)
# Initialize dependencies for FastAPI
init_dependencies(async_db, rate_limiter)
set_async_service(async_service)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan manager."""
logger.info("Starting Invoice Inference API...")
# Initialize async request database tables
try:
async_db.create_tables()
logger.info("Async database tables ready")
except Exception as e:
logger.error(f"Failed to initialize async database: {e}")
# Initialize admin database tables (admin_tokens, admin_documents, training_tasks, etc.)
try:
from backend.data.database import create_db_and_tables
create_db_and_tables()
logger.info("Admin database tables ready")
except Exception as e:
logger.error(f"Failed to initialize admin database: {e}")
# Initialize inference service on startup
try:
inference_service.initialize()
logger.info("Inference service ready")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
# Continue anyway - service will retry on first request
# Start async processing service
try:
async_service.start()
logger.info("Async processing service started")
except Exception as e:
logger.error(f"Failed to start async processing: {e}")
# Start batch upload queue
try:
batch_service = BatchUploadService()
init_batch_queue(batch_service)
logger.info("Batch upload queue started")
except Exception as e:
logger.error(f"Failed to start batch upload queue: {e}")
# Start training scheduler
try:
start_scheduler()
logger.info("Training scheduler started")
except Exception as e:
logger.error(f"Failed to start training scheduler: {e}")
# Start auto-label scheduler
try:
start_autolabel_scheduler()
logger.info("AutoLabel scheduler started")
except Exception as e:
logger.error(f"Failed to start autolabel scheduler: {e}")
yield
logger.info("Shutting down Invoice Inference API...")
# Stop auto-label scheduler
try:
stop_autolabel_scheduler()
logger.info("AutoLabel scheduler stopped")
except Exception as e:
logger.error(f"Error stopping autolabel scheduler: {e}")
# Stop training scheduler
try:
stop_scheduler()
logger.info("Training scheduler stopped")
except Exception as e:
logger.error(f"Error stopping training scheduler: {e}")
# Stop batch upload queue
try:
shutdown_batch_queue()
logger.info("Batch upload queue stopped")
except Exception as e:
logger.error(f"Error stopping batch upload queue: {e}")
# Stop async processing service
try:
async_service.stop(timeout=30.0)
logger.info("Async processing service stopped")
except Exception as e:
logger.error(f"Error stopping async service: {e}")
# Close database connection
try:
async_db.close()
logger.info("Database connection closed")
except Exception as e:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
# Store inference service for access by routes (e.g., model reload)
# This will be set after app creation
app = FastAPI(
title="Invoice Field Extraction API",
description="""
REST API for extracting fields from Swedish invoices.
## Features
- YOLO-based field detection
- OCR text extraction
- Field normalization and validation
- Visualization of detections
## Supported Fields
- InvoiceNumber
- InvoiceDate
- InvoiceDueDate
- OCR (reference number)
- Bankgiro
- Plusgiro
- Amount
- supplier_org_number (Swedish organization number)
- customer_number
- payment_line (machine-readable payment code)
""",
version="1.0.0",
lifespan=lifespan,
)
# Add CORS middleware
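    # NOTE: per the CORS specification, a wildcard allow-origin cannot be combined
    # with credentialed requests; consider restricting allow_origins in production.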
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount static files for results using StorageHelper
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir:
app.mount(
"/static/results",
StaticFiles(directory=str(results_dir)),
name="results",
)
else:
logger.warning("Could not mount static results directory: local storage not available")
# Include public API routes
inference_router = create_inference_router(inference_service, config.storage)
app.include_router(inference_router)
async_router = create_async_router(config.storage.allowed_extensions)
app.include_router(async_router, prefix="/api/v1")
labeling_router = create_labeling_router(inference_service, config.storage)
app.include_router(labeling_router)
# Include admin API routes
auth_router = create_auth_router()
app.include_router(auth_router, prefix="/api/v1")
documents_router = create_documents_router(config.storage)
app.include_router(documents_router, prefix="/api/v1")
locks_router = create_locks_router()
app.include_router(locks_router, prefix="/api/v1")
annotation_router = create_annotation_router()
app.include_router(annotation_router, prefix="/api/v1")
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
augmentation_router = create_augmentation_router()
app.include_router(augmentation_router, prefix="/api/v1/admin")
# Include dashboard routes
dashboard_router = create_dashboard_router()
app.include_router(dashboard_router, prefix="/api/v1")
# Include batch upload routes
app.include_router(batch_upload_router)
# Store inference service in app state for access by routes
app.state.inference_service = inference_service
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:
"""Serve the web UI."""
return get_html_ui()
return app
def get_html_ui() -> str:
"""Generate HTML UI for the web application."""
return """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Invoice Field Extraction</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
header {
text-align: center;
color: white;
margin-bottom: 30px;
}
header h1 {
font-size: 2.5rem;
margin-bottom: 10px;
}
header p {
opacity: 0.9;
font-size: 1.1rem;
}
.main-content {
display: flex;
flex-direction: column;
gap: 20px;
}
.card {
background: white;
border-radius: 16px;
padding: 24px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
font-size: 1.3rem;
display: flex;
align-items: center;
gap: 10px;
}
.upload-card {
display: flex;
align-items: center;
gap: 20px;
flex-wrap: wrap;
}
.upload-card h2 {
margin-bottom: 0;
white-space: nowrap;
}
.upload-area {
border: 2px dashed #ddd;
border-radius: 10px;
padding: 15px 25px;
text-align: center;
cursor: pointer;
transition: all 0.3s;
background: #fafafa;
flex: 1;
min-width: 200px;
}
.upload-area:hover, .upload-area.dragover {
border-color: #667eea;
background: #f0f4ff;
}
.upload-area.has-file {
border-color: #10b981;
background: #ecfdf5;
}
.upload-icon {
font-size: 24px;
display: inline;
margin-right: 8px;
}
.upload-area p {
color: #666;
margin: 0;
display: inline;
}
.upload-area small {
color: #999;
display: block;
margin-top: 5px;
}
#file-input {
display: none;
}
.file-name {
margin-top: 15px;
padding: 10px 15px;
background: #e0f2fe;
border-radius: 8px;
color: #0369a1;
font-weight: 500;
}
.btn {
display: inline-block;
padding: 12px 24px;
border: none;
border-radius: 10px;
font-size: 0.9rem;
font-weight: 600;
cursor: pointer;
transition: all 0.3s;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.btn-primary:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
}
.btn-primary:disabled {
opacity: 0.6;
cursor: not-allowed;
}
.loading {
display: none;
align-items: center;
gap: 10px;
}
.loading.active {
display: flex;
}
.spinner {
width: 24px;
height: 24px;
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.results {
display: none;
}
.results.active {
display: block;
}
.result-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 20px;
padding-bottom: 15px;
border-bottom: 2px solid #eee;
}
.result-status {
padding: 6px 12px;
border-radius: 20px;
font-size: 0.85rem;
font-weight: 600;
}
.result-status.success {
background: #dcfce7;
color: #166534;
}
.result-status.partial {
background: #fef3c7;
color: #92400e;
}
.result-status.error {
background: #fee2e2;
color: #991b1b;
}
.fields-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 12px;
}
.field-item {
padding: 12px;
background: #f8fafc;
border-radius: 10px;
border-left: 4px solid #667eea;
}
.field-item label {
display: block;
font-size: 0.75rem;
color: #64748b;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 4px;
}
.field-item .value {
font-size: 1.1rem;
font-weight: 600;
color: #1e293b;
}
.field-item .confidence {
font-size: 0.75rem;
color: #10b981;
margin-top: 2px;
}
.visualization {
margin-top: 20px;
}
.visualization img {
width: 100%;
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
}
.processing-time {
text-align: center;
color: #64748b;
font-size: 0.9rem;
margin-top: 15px;
}
.cross-validation {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 10px;
padding: 15px;
margin-top: 20px;
}
.cross-validation h3 {
margin: 0 0 10px 0;
color: #334155;
font-size: 1rem;
}
.cv-status {
font-weight: 600;
padding: 8px 12px;
border-radius: 6px;
margin-bottom: 10px;
display: inline-block;
}
.cv-status.valid {
background: #dcfce7;
color: #166534;
}
.cv-status.invalid {
background: #fef3c7;
color: #92400e;
}
.cv-details {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 10px;
}
.cv-item {
background: white;
border: 1px solid #e2e8f0;
border-radius: 6px;
padding: 6px 12px;
font-size: 0.85rem;
display: flex;
align-items: center;
gap: 6px;
}
.cv-item.match {
border-color: #86efac;
background: #f0fdf4;
}
.cv-item.mismatch {
border-color: #fca5a5;
background: #fef2f2;
}
.cv-icon {
font-weight: bold;
}
.cv-item.match .cv-icon {
color: #16a34a;
}
.cv-item.mismatch .cv-icon {
color: #dc2626;
}
.cv-summary {
margin-top: 10px;
font-size: 0.8rem;
color: #64748b;
}
.error-message {
background: #fee2e2;
color: #991b1b;
padding: 15px;
border-radius: 10px;
margin-top: 15px;
}
footer {
text-align: center;
color: white;
opacity: 0.8;
margin-top: 30px;
font-size: 0.9rem;
}
</style>
</head>
<body>
<div class="container">
<header>
<h1>📄 Invoice Field Extraction</h1>
<p>Upload a Swedish invoice (PDF or image) to extract fields automatically</p>
</header>
<div class="main-content">
<!-- Upload Section - Compact -->
<div class="card upload-card">
<h2>📤 Upload</h2>
<div class="upload-area" id="upload-area">
<span class="upload-icon">📁</span>
<p>Drag & drop or <strong>click to browse</strong></p>
<small>PDF, PNG, JPG (max 50MB)</small>
<input type="file" id="file-input" accept=".pdf,.png,.jpg,.jpeg">
</div>
<div class="file-name" id="file-name" style="display: none;"></div>
<button class="btn btn-primary" id="submit-btn" disabled>
🚀 Extract
</button>
<div class="loading" id="loading">
<div class="spinner"></div>
<p>Processing...</p>
</div>
</div>
<!-- Results Section - Full Width -->
<div class="card">
<h2>📊 Extraction Results</h2>
<div id="placeholder" style="text-align: center; padding: 30px; color: #999;">
<div style="font-size: 48px; margin-bottom: 10px;">🔍</div>
<p>Upload a document to see extraction results</p>
</div>
<div class="results" id="results">
<div class="result-header">
<span>Document: <strong id="doc-id"></strong></span>
<span class="result-status" id="result-status"></span>
</div>
<div class="fields-grid" id="fields-grid"></div>
<div class="processing-time" id="processing-time"></div>
<div class="cross-validation" id="cross-validation" style="display: none;"></div>
<div class="error-message" id="error-message" style="display: none;"></div>
<div class="visualization" id="visualization" style="display: none;">
<h3 style="margin-bottom: 10px; color: #333;">🎯 Detection Visualization</h3>
<img id="viz-image" src="" alt="Detection visualization">
</div>
</div>
</div>
</div>
<footer>
<p>Powered by ColaCoder</p>
</footer>
</div>
<script>
const uploadArea = document.getElementById('upload-area');
const fileInput = document.getElementById('file-input');
const fileName = document.getElementById('file-name');
const submitBtn = document.getElementById('submit-btn');
const loading = document.getElementById('loading');
const placeholder = document.getElementById('placeholder');
const results = document.getElementById('results');
let selectedFile = null;
// Drag and drop handlers
uploadArea.addEventListener('click', () => fileInput.click());
uploadArea.addEventListener('dragover', (e) => {
e.preventDefault();
uploadArea.classList.add('dragover');
});
uploadArea.addEventListener('dragleave', () => {
uploadArea.classList.remove('dragover');
});
uploadArea.addEventListener('drop', (e) => {
e.preventDefault();
uploadArea.classList.remove('dragover');
const files = e.dataTransfer.files;
if (files.length > 0) {
handleFile(files[0]);
}
});
fileInput.addEventListener('change', (e) => {
if (e.target.files.length > 0) {
handleFile(e.target.files[0]);
}
});
function handleFile(file) {
const validTypes = ['.pdf', '.png', '.jpg', '.jpeg'];
const ext = '.' + file.name.split('.').pop().toLowerCase();
if (!validTypes.includes(ext)) {
alert('Please upload a PDF, PNG, or JPG file.');
return;
}
selectedFile = file;
fileName.textContent = `📎 ${file.name}`;
fileName.style.display = 'block';
uploadArea.classList.add('has-file');
submitBtn.disabled = false;
}
submitBtn.addEventListener('click', async () => {
if (!selectedFile) return;
// Show loading
submitBtn.disabled = true;
loading.classList.add('active');
placeholder.style.display = 'none';
results.classList.remove('active');
try {
const formData = new FormData();
formData.append('file', selectedFile);
const response = await fetch('/api/v1/infer', {
method: 'POST',
body: formData,
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Processing failed');
}
displayResults(data);
} catch (error) {
console.error('Error:', error);
document.getElementById('error-message').textContent = error.message;
document.getElementById('error-message').style.display = 'block';
results.classList.add('active');
} finally {
loading.classList.remove('active');
submitBtn.disabled = false;
}
});
function displayResults(data) {
const result = data.result;
// Document ID
document.getElementById('doc-id').textContent = result.document_id;
// Status
const statusEl = document.getElementById('result-status');
statusEl.textContent = result.success ? 'Success' : 'Partial';
statusEl.className = 'result-status ' + (result.success ? 'success' : 'partial');
// Fields
const fieldsGrid = document.getElementById('fields-grid');
fieldsGrid.innerHTML = '';
const fieldOrder = [
'InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR',
'Amount', 'Bankgiro', 'Plusgiro',
'supplier_org_number', 'customer_number', 'payment_line'
];
fieldOrder.forEach(field => {
const value = result.fields[field];
const confidence = result.confidence[field];
if (value !== null && value !== undefined) {
const fieldDiv = document.createElement('div');
fieldDiv.className = 'field-item';
fieldDiv.innerHTML = `
<label>${formatFieldName(field)}</label>
<div class="value">${value}</div>
${confidence ? `<div class="confidence">✓ ${(confidence * 100).toFixed(1)}% confident</div>` : ''}
`;
fieldsGrid.appendChild(fieldDiv);
}
});
// Processing time
document.getElementById('processing-time').textContent =
`⏱️ Processed in ${result.processing_time_ms.toFixed(0)}ms`;
// Cross-validation results
const cvDiv = document.getElementById('cross-validation');
if (result.cross_validation) {
const cv = result.cross_validation;
let cvHtml = '<h3>🔍 Cross-Validation (Payment Line)</h3>';
cvHtml += `<div class="cv-status ${cv.is_valid ? 'valid' : 'invalid'}">`;
cvHtml += cv.is_valid ? '✅ Valid' : '⚠️ Mismatch Detected';
cvHtml += '</div>';
cvHtml += '<div class="cv-details">';
if (cv.payment_line_ocr) {
                    const matchIcon = cv.ocr_match === true ? '✓' : (cv.ocr_match === false ? '✗' : '');
cvHtml += `<div class="cv-item ${cv.ocr_match === true ? 'match' : (cv.ocr_match === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> OCR: ${cv.payment_line_ocr}</div>`;
}
if (cv.payment_line_amount) {
                    const matchIcon = cv.amount_match === true ? '✓' : (cv.amount_match === false ? '✗' : '');
cvHtml += `<div class="cv-item ${cv.amount_match === true ? 'match' : (cv.amount_match === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> Amount: ${cv.payment_line_amount}</div>`;
}
if (cv.payment_line_account) {
const accountType = cv.payment_line_account_type === 'bankgiro' ? 'Bankgiro' : 'Plusgiro';
const matchField = cv.payment_line_account_type === 'bankgiro' ? cv.bankgiro_match : cv.plusgiro_match;
                    const matchIcon = matchField === true ? '✓' : (matchField === false ? '✗' : '');
cvHtml += `<div class="cv-item ${matchField === true ? 'match' : (matchField === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> ${accountType}: ${cv.payment_line_account}</div>`;
}
cvHtml += '</div>';
if (cv.details && cv.details.length > 0) {
cvHtml += '<div class="cv-summary">' + cv.details[cv.details.length - 1] + '</div>';
}
cvDiv.innerHTML = cvHtml;
cvDiv.style.display = 'block';
} else {
cvDiv.style.display = 'none';
}
// Visualization
if (result.visualization_url) {
const vizDiv = document.getElementById('visualization');
const vizImg = document.getElementById('viz-image');
vizImg.src = result.visualization_url;
vizDiv.style.display = 'block';
}
// Errors
if (result.errors && result.errors.length > 0) {
document.getElementById('error-message').textContent = result.errors.join(', ');
document.getElementById('error-message').style.display = 'block';
} else {
document.getElementById('error-message').style.display = 'none';
}
results.classList.add('active');
}
function formatFieldName(name) {
const nameMap = {
'InvoiceNumber': 'Invoice Number',
'InvoiceDate': 'Invoice Date',
'InvoiceDueDate': 'Due Date',
'OCR': 'OCR Reference',
'Amount': 'Amount',
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'supplier_org_number': 'Supplier Org Number',
'customer_number': 'Customer Number',
'payment_line': 'Payment Line'
};
return nameMap[name] || name.replace(/([A-Z])/g, ' $1').replace(/_/g, ' ').trim();
}
</script>
</body>
</html>
"""

View File

@@ -0,0 +1,194 @@
"""
Web Application Configuration
Centralized configuration for the web application.
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
from shared.config import DEFAULT_DPI
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
def get_storage_backend(
config_path: Path | str | None = None,
) -> "StorageBackend":
"""Get storage backend for file operations.
Args:
config_path: Optional path to storage configuration file.
If not provided, uses STORAGE_CONFIG_PATH env var or falls back to env vars.
Returns:
Configured StorageBackend instance.
"""
from shared.storage import get_storage_backend as _get_storage_backend
# Check for config file path
if config_path is None:
config_path_str = os.environ.get("STORAGE_CONFIG_PATH")
if config_path_str:
config_path = Path(config_path_str)
return _get_storage_backend(config_path=config_path)
@dataclass(frozen=True)
class ModelConfig:
"""YOLO model configuration.
Note: Model files are stored locally (not in STORAGE_BASE_PATH) because:
- Models need to be accessible by inference service on any platform
- Models may be version-controlled or deployed separately
- Models are part of the application, not user data
"""
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
confidence_threshold: float = 0.5
use_gpu: bool = True
dpi: int = DEFAULT_DPI
@dataclass(frozen=True)
class ServerConfig:
"""Server configuration."""
host: str = "0.0.0.0"
port: int = 8000
debug: bool = False
reload: bool = False
workers: int = 1
@dataclass(frozen=True)
class FileConfig:
"""File handling configuration.
This config holds file handling settings. For file operations,
use the storage backend with PREFIXES from shared.storage.prefixes.
Example:
from shared.storage import PREFIXES, get_storage_backend
storage = get_storage_backend()
path = PREFIXES.document_path(document_id)
storage.upload_bytes(content, path)
Note: The path fields (upload_dir, result_dir, etc.) are deprecated.
They are kept for backward compatibility with existing code and tests.
New code should use the storage backend with PREFIXES instead.
"""
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
dpi: int = DEFAULT_DPI
presigned_url_expiry_seconds: int = 3600
# Deprecated path fields - kept for backward compatibility
# New code should use storage backend with PREFIXES instead
# All paths are now under data/ to match WSL storage layout
upload_dir: Path = field(default_factory=lambda: Path("data/uploads"))
result_dir: Path = field(default_factory=lambda: Path("data/results"))
admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs"))
admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images"))
def __post_init__(self) -> None:
"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
self.upload_dir.mkdir(parents=True, exist_ok=True)
self.result_dir.mkdir(parents=True, exist_ok=True)
self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
# Backward compatibility alias
StorageConfig = FileConfig
@dataclass(frozen=True)
class AsyncConfig:
"""Async processing configuration.
Note: For file paths, use the storage backend with PREFIXES.
Example: PREFIXES.upload_path(filename, "async")
"""
# Queue settings
queue_max_size: int = 100
worker_count: int = 1
task_timeout_seconds: int = 300
# Rate limiting defaults
default_requests_per_minute: int = 10
default_max_concurrent_jobs: int = 3
default_min_poll_interval_ms: int = 1000
# Storage
result_retention_days: int = 7
max_file_size_mb: int = 50
# Deprecated: kept for backward compatibility
# Path under data/ to match WSL storage layout
temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async"))
# Cleanup
cleanup_interval_hours: int = 1
def __post_init__(self) -> None:
"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
@dataclass
class AppConfig:
"""Main application configuration."""
model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
file: FileConfig = field(default_factory=FileConfig)
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
storage_backend: "StorageBackend | None" = None
@property
def storage(self) -> FileConfig:
"""Backward compatibility alias for file config."""
return self.file
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
"""Create config from dictionary."""
file_config = config_dict.get("file", config_dict.get("storage", {}))
return cls(
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
file=FileConfig(**file_config),
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
)
def create_app_config(
storage_config_path: Path | str | None = None,
) -> AppConfig:
"""Create application configuration with storage backend.
Args:
storage_config_path: Optional path to storage configuration file.
Returns:
Configured AppConfig instance with storage backend initialized.
"""
storage_backend = get_storage_backend(config_path=storage_config_path)
return AppConfig(storage_backend=storage_backend)
# Default configuration instance
default_config = AppConfig()
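# --- Illustrative usage sketch ---------------------------------------------------
# How a caller might override selected settings through AppConfig.from_dict().
# The keys mirror the dataclasses above; the values shown are examples only.
if __name__ == "__main__":  # pragma: no cover
    custom = AppConfig.from_dict(
        {
            "model": {"confidence_threshold": 0.6, "use_gpu": False},
            "server": {"port": 9000},
            "async_processing": {"worker_count": 2, "queue_max_size": 50},
        }
    )
    print(custom.model.confidence_threshold, custom.server.port)
    print(custom.async_processing.worker_count, custom.storage.max_file_size_mb)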

View File

@@ -0,0 +1,61 @@
"""
Core Components
Reusable core functionality: authentication, rate limiting, scheduling.
"""
from backend.web.core.auth import (
validate_admin_token,
get_token_repository,
get_document_repository,
get_annotation_repository,
get_dataset_repository,
get_training_task_repository,
get_model_version_repository,
get_batch_upload_repository,
AdminTokenDep,
TokenRepoDep,
DocumentRepoDep,
AnnotationRepoDep,
DatasetRepoDep,
TrainingTaskRepoDep,
ModelVersionRepoDep,
BatchUploadRepoDep,
)
from backend.web.core.rate_limiter import RateLimiter
from backend.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from backend.web.core.autolabel_scheduler import (
start_autolabel_scheduler,
stop_autolabel_scheduler,
get_autolabel_scheduler,
)
from backend.web.core.task_interface import TaskRunner, TaskStatus, TaskManager
__all__ = [
"validate_admin_token",
"get_token_repository",
"get_document_repository",
"get_annotation_repository",
"get_dataset_repository",
"get_training_task_repository",
"get_model_version_repository",
"get_batch_upload_repository",
"AdminTokenDep",
"TokenRepoDep",
"DocumentRepoDep",
"AnnotationRepoDep",
"DatasetRepoDep",
"TrainingTaskRepoDep",
"ModelVersionRepoDep",
"BatchUploadRepoDep",
"RateLimiter",
"start_scheduler",
"stop_scheduler",
"get_training_scheduler",
"start_autolabel_scheduler",
"stop_autolabel_scheduler",
"get_autolabel_scheduler",
"TaskRunner",
"TaskStatus",
"TaskManager",
]

View File

@@ -0,0 +1,115 @@
"""
Admin Authentication
FastAPI dependencies for admin token authentication and repository access.
"""
from functools import lru_cache
from typing import Annotated
from fastapi import Depends, Header, HTTPException
from backend.data.repositories import (
TokenRepository,
DocumentRepository,
AnnotationRepository,
DatasetRepository,
TrainingTaskRepository,
ModelVersionRepository,
BatchUploadRepository,
)
@lru_cache(maxsize=1)
def get_token_repository() -> TokenRepository:
"""Get the TokenRepository instance (thread-safe singleton)."""
return TokenRepository()
def reset_token_repository() -> None:
"""Reset the TokenRepository instance (for testing)."""
get_token_repository.cache_clear()
async def validate_admin_token(
x_admin_token: Annotated[str | None, Header()] = None,
token_repo: TokenRepository = Depends(get_token_repository),
) -> str:
"""Validate admin token from header."""
if not x_admin_token:
raise HTTPException(
status_code=401,
detail="Admin token required. Provide X-Admin-Token header.",
)
if not token_repo.is_valid(x_admin_token):
raise HTTPException(
status_code=401,
detail="Invalid or expired admin token.",
)
# Update last used timestamp
token_repo.update_usage(x_admin_token)
return x_admin_token
# Type alias for dependency injection
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
TokenRepoDep = Annotated[TokenRepository, Depends(get_token_repository)]
@lru_cache(maxsize=1)
def get_document_repository() -> DocumentRepository:
"""Get the DocumentRepository instance (thread-safe singleton)."""
return DocumentRepository()
@lru_cache(maxsize=1)
def get_annotation_repository() -> AnnotationRepository:
"""Get the AnnotationRepository instance (thread-safe singleton)."""
return AnnotationRepository()
@lru_cache(maxsize=1)
def get_dataset_repository() -> DatasetRepository:
"""Get the DatasetRepository instance (thread-safe singleton)."""
return DatasetRepository()
@lru_cache(maxsize=1)
def get_training_task_repository() -> TrainingTaskRepository:
"""Get the TrainingTaskRepository instance (thread-safe singleton)."""
return TrainingTaskRepository()
@lru_cache(maxsize=1)
def get_model_version_repository() -> ModelVersionRepository:
"""Get the ModelVersionRepository instance (thread-safe singleton)."""
return ModelVersionRepository()
@lru_cache(maxsize=1)
def get_batch_upload_repository() -> BatchUploadRepository:
"""Get the BatchUploadRepository instance (thread-safe singleton)."""
return BatchUploadRepository()
def reset_all_repositories() -> None:
"""Reset all repository instances (for testing)."""
get_token_repository.cache_clear()
get_document_repository.cache_clear()
get_annotation_repository.cache_clear()
get_dataset_repository.cache_clear()
get_training_task_repository.cache_clear()
get_model_version_repository.cache_clear()
get_batch_upload_repository.cache_clear()
# Repository dependency type aliases
DocumentRepoDep = Annotated[DocumentRepository, Depends(get_document_repository)]
AnnotationRepoDep = Annotated[AnnotationRepository, Depends(get_annotation_repository)]
DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)]
TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)]
ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)]
BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)]
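# --- Illustrative usage sketch ---------------------------------------------------
# How an admin route might consume these dependencies. The path and response shape
# below are examples only; real admin routes live under backend/web/api/v1/admin.
if __name__ == "__main__":  # pragma: no cover
    from fastapi import APIRouter

    example_router = APIRouter(prefix="/api/v1/admin/example", tags=["example"])

    @example_router.get("/whoami")
    async def whoami(admin_token: AdminTokenDep, doc_repo: DocumentRepoDep) -> dict:
        # The token has already been validated (and its last-used timestamp
        # updated) by validate_admin_token; repositories are cached singletons.
        return {"token_prefix": admin_token[:8], "repository": type(doc_repo).__name__}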

View File

@@ -0,0 +1,202 @@
"""
Auto-Label Scheduler
Background scheduler for processing documents pending auto-labeling.
"""
import logging
import threading
from pathlib import Path
from backend.data.repositories import DocumentRepository, AnnotationRepository
from backend.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
)
from backend.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
class AutoLabelScheduler(TaskRunner):
"""Scheduler for auto-labeling tasks."""
def __init__(
self,
check_interval_seconds: int = 10,
batch_size: int = 5,
output_dir: Path | None = None,
):
"""
Initialize auto-label scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
batch_size: Number of documents to process per batch
output_dir: Output directory for temporary files
"""
self._check_interval = check_interval_seconds
self._batch_size = batch_size
# Get output directory from StorageHelper
if output_dir is None:
storage = get_storage_helper()
output_dir = storage.get_autolabel_output_path()
self._output_dir = output_dir or Path("data/autolabel_output")
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._lock = threading.Lock()
self._doc_repo = DocumentRepository()
self._ann_repo = AnnotationRepository()
@property
def name(self) -> str:
"""Unique identifier for this runner."""
return "autolabel_scheduler"
@property
def is_running(self) -> bool:
"""Check if scheduler is running."""
return self._running
def get_status(self) -> TaskStatus:
"""Get current status of the scheduler."""
try:
pending_docs = get_pending_autolabel_documents(limit=1000)
pending_count = len(pending_docs)
except Exception:
pending_count = 0
return TaskStatus(
name=self.name,
is_running=self._running,
pending_count=pending_count,
processing_count=1 if self._running else 0,
)
def start(self) -> None:
"""Start the scheduler."""
with self._lock:
if self._running:
logger.warning("AutoLabel scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("AutoLabel scheduler started")
def stop(self, timeout: float | None = None) -> None:
"""Stop the scheduler.
Args:
timeout: Maximum time to wait for graceful shutdown.
If None, uses default of 5 seconds.
"""
# Minimize lock scope to avoid potential deadlock
with self._lock:
if not self._running:
return
self._running = False
self._stop_event.set()
thread_to_join = self._thread
effective_timeout = timeout if timeout is not None else 5.0
if thread_to_join:
thread_to_join.join(timeout=effective_timeout)
with self._lock:
self._thread = None
logger.info("AutoLabel scheduler stopped")
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._process_pending_documents()
except Exception as e:
logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _process_pending_documents(self) -> None:
"""Check and process pending auto-label documents."""
try:
documents = get_pending_autolabel_documents(limit=self._batch_size)
if not documents:
return
logger.info(f"Processing {len(documents)} pending autolabel documents")
for doc in documents:
if self._stop_event.is_set():
break
try:
result = process_document_autolabel(
document=doc,
output_dir=self._output_dir,
doc_repo=self._doc_repo,
ann_repo=self._ann_repo,
)
if result.get("success"):
logger.info(
f"AutoLabel completed for document {doc.document_id}"
)
else:
logger.warning(
f"AutoLabel failed for document {doc.document_id}: "
f"{result.get('error', 'Unknown error')}"
)
except Exception as e:
logger.error(
f"Error processing document {doc.document_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error fetching pending documents: {e}", exc_info=True)
# Global scheduler instance
_autolabel_scheduler: AutoLabelScheduler | None = None
_autolabel_lock = threading.Lock()
def get_autolabel_scheduler() -> AutoLabelScheduler:
"""Get the auto-label scheduler instance.
Uses double-checked locking pattern for thread safety.
"""
global _autolabel_scheduler
if _autolabel_scheduler is None:
with _autolabel_lock:
if _autolabel_scheduler is None:
_autolabel_scheduler = AutoLabelScheduler()
return _autolabel_scheduler
def start_autolabel_scheduler() -> None:
"""Start the global auto-label scheduler."""
scheduler = get_autolabel_scheduler()
scheduler.start()
def stop_autolabel_scheduler() -> None:
"""Stop the global auto-label scheduler."""
global _autolabel_scheduler
if _autolabel_scheduler:
_autolabel_scheduler.stop()
_autolabel_scheduler = None
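
# Illustrative wiring (not part of the original file): a rough sketch of how
# the module-level helpers above could be hooked into a FastAPI lifespan so
# the scheduler thread starts and stops with the app. The app module and the
# lifespan placement are assumptions; only start_autolabel_scheduler() and
# stop_autolabel_scheduler() come from this module.
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       start_autolabel_scheduler()   # daemon thread, polls every 10s by default
#       try:
#           yield
#       finally:
#           stop_autolabel_scheduler()
#
#   app = FastAPI(lifespan=lifespan)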

View File

@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from backend.data.async_request_db import AsyncRequestDB
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RateLimitConfig:
"""Rate limit configuration for an API key."""
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000 # Minimum time between status polls
@dataclass
class RateLimitStatus:
"""Current rate limit status."""
allowed: bool
remaining_requests: int
reset_at: datetime
retry_after_seconds: int | None = None
reason: str | None = None
class RateLimiter:
"""
Thread-safe rate limiter with sliding window algorithm.
Tracks:
- Requests per minute (sliding window)
- Concurrent active jobs
- Poll frequency per request_id
"""
def __init__(self, db: "AsyncRequestDB") -> None:
self._db = db
self._lock = Lock()
# In-memory tracking for fast checks
self._request_windows: dict[str, list[float]] = defaultdict(list)
# (api_key, request_id) -> last_poll timestamp
self._poll_timestamps: dict[tuple[str, str], float] = {}
# Cache for API key configs (TTL 60 seconds)
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
self._config_cache_ttl = 60.0
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
"""Check if API key can submit a new request."""
config = self._get_config(api_key)
with self._lock:
now = time.time()
window_start = now - 60 # 1 minute window
# Clean old entries
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
current_count = len(self._request_windows[api_key])
if current_count >= config.requests_per_minute:
oldest = min(self._request_windows[api_key])
retry_after = int(oldest + 60 - now) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
retry_after_seconds=max(1, retry_after),
reason="Rate limit exceeded: too many requests per minute",
)
# Check concurrent jobs (query database) - inside lock for thread safety
active_jobs = self._db.count_active_jobs(api_key)
if active_jobs >= config.max_concurrent_jobs:
return RateLimitStatus(
allowed=False,
remaining_requests=config.requests_per_minute - current_count,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
)
return RateLimitStatus(
allowed=True,
remaining_requests=config.requests_per_minute - current_count - 1,
reset_at=datetime.utcnow() + timedelta(seconds=60),
)
def record_request(self, api_key: str) -> None:
"""Record a successful request submission."""
with self._lock:
self._request_windows[api_key].append(time.time())
# Also record in database for persistence
try:
self._db.record_rate_limit_event(api_key, "request")
except Exception as e:
logger.warning(f"Failed to record rate limit event: {e}")
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
"""Check if polling is allowed (prevent abuse)."""
config = self._get_config(api_key)
key = (api_key, request_id)
with self._lock:
now = time.time()
last_poll = self._poll_timestamps.get(key, 0)
elapsed_ms = (now - last_poll) * 1000
if elapsed_ms < config.min_poll_interval_ms:
# Suggest exponential backoff
wait_ms = min(
config.min_poll_interval_ms * 2,
5000, # Max 5 seconds
)
retry_after = int(wait_ms / 1000) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
retry_after_seconds=retry_after,
reason="Polling too frequently. Please wait before retrying.",
)
# Update poll timestamp
self._poll_timestamps[key] = now
return RateLimitStatus(
allowed=True,
remaining_requests=999, # No limit on poll count, just frequency
reset_at=datetime.utcnow(),
)
def _get_config(self, api_key: str) -> RateLimitConfig:
"""Get rate limit config for API key with caching."""
now = time.time()
# Check cache
if api_key in self._config_cache:
cached_config, cached_at = self._config_cache[api_key]
if now - cached_at < self._config_cache_ttl:
return cached_config
# Query database
db_config = self._db.get_api_key_config(api_key)
if db_config:
config = RateLimitConfig(
requests_per_minute=db_config.requests_per_minute,
max_concurrent_jobs=db_config.max_concurrent_jobs,
)
else:
config = RateLimitConfig() # Default limits
# Cache result
self._config_cache[api_key] = (config, now)
return config
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
"""Clean up old poll timestamps to prevent memory leak."""
with self._lock:
now = time.time()
cutoff = now - max_age_seconds
old_keys = [
k for k, v in self._poll_timestamps.items()
if v < cutoff
]
for key in old_keys:
del self._poll_timestamps[key]
return len(old_keys)
def cleanup_request_windows(self) -> None:
"""Clean up expired entries from request windows."""
with self._lock:
now = time.time()
window_start = now - 60
for api_key in list(self._request_windows.keys()):
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
# Remove empty entries
if not self._request_windows[api_key]:
del self._request_windows[api_key]
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
"""Generate rate limit headers for HTTP response."""
headers = {
"X-RateLimit-Remaining": str(status.remaining_requests),
"X-RateLimit-Reset": status.reset_at.isoformat(),
}
if status.retry_after_seconds:
headers["Retry-After"] = str(status.retry_after_seconds)
return headers
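
if __name__ == "__main__":
    # Smoke-test sketch (not part of the original file): exercises the sliding
    # window with an in-memory stand-in for AsyncRequestDB. _StubDB and its
    # return values are assumptions made only for this demo.
    class _StubDB:
        def count_active_jobs(self, api_key: str) -> int:
            return 0

        def record_rate_limit_event(self, api_key: str, event: str) -> None:
            pass

        def get_api_key_config(self, api_key: str):
            return None  # fall back to RateLimitConfig() defaults (10 req/min)

    limiter = RateLimiter(db=_StubDB())  # type: ignore[arg-type]
    for i in range(12):
        status = limiter.check_submit_limit("demo-key")
        if status.allowed:
            limiter.record_request("demo-key")
        print(i, status.allowed, status.remaining_requests, status.reason)
    # Expected: the first 10 submissions are allowed; the 11th and 12th are
    # rejected with a "too many requests per minute" reason and Retry-After.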

View File

@@ -0,0 +1,571 @@
"""
Admin Training Scheduler
Background scheduler for training tasks using APScheduler.
"""
import logging
import threading
from datetime import datetime
from pathlib import Path
from typing import Any
from backend.data.repositories import (
TrainingTaskRepository,
DatasetRepository,
ModelVersionRepository,
DocumentRepository,
AnnotationRepository,
)
from backend.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
class TrainingScheduler(TaskRunner):
"""Scheduler for training tasks."""
def __init__(
self,
check_interval_seconds: int = 60,
):
"""
Initialize training scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
"""
self._check_interval = check_interval_seconds
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._lock = threading.Lock()
# Repositories
self._training_tasks = TrainingTaskRepository()
self._datasets = DatasetRepository()
self._model_versions = ModelVersionRepository()
self._documents = DocumentRepository()
self._annotations = AnnotationRepository()
@property
def name(self) -> str:
"""Unique identifier for this runner."""
return "training_scheduler"
@property
def is_running(self) -> bool:
"""Check if the scheduler is currently active."""
return self._running
def get_status(self) -> TaskStatus:
"""Get current status of the scheduler."""
try:
pending_tasks = self._training_tasks.get_pending()
pending_count = len(pending_tasks)
except Exception:
pending_count = 0
return TaskStatus(
name=self.name,
is_running=self._running,
pending_count=pending_count,
processing_count=1 if self._running else 0,
)
def start(self) -> None:
"""Start the scheduler."""
with self._lock:
if self._running:
logger.warning("Training scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("Training scheduler started")
def stop(self, timeout: float | None = None) -> None:
"""Stop the scheduler.
Args:
timeout: Maximum time to wait for graceful shutdown.
If None, uses default of 5 seconds.
"""
# Minimize lock scope to avoid potential deadlock
with self._lock:
if not self._running:
return
self._running = False
self._stop_event.set()
thread_to_join = self._thread
effective_timeout = timeout if timeout is not None else 5.0
if thread_to_join:
thread_to_join.join(timeout=effective_timeout)
with self._lock:
self._thread = None
logger.info("Training scheduler stopped")
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._check_pending_tasks()
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _check_pending_tasks(self) -> None:
"""Check and execute pending training tasks."""
try:
tasks = self._training_tasks.get_pending()
for task in tasks:
task_id = str(task.task_id)
# Check if scheduled time has passed
if task.scheduled_at and task.scheduled_at > datetime.utcnow():
continue
logger.info(f"Starting training task: {task_id}")
try:
dataset_id = getattr(task, "dataset_id", None)
self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._training_tasks.update_status(
task_id=task_id,
status="failed",
error_message=str(e),
)
except Exception as e:
logger.error(f"Error checking pending tasks: {e}")
def _execute_task(
self, task_id: str, config: dict[str, Any], dataset_id: str | None = None
) -> None:
"""Execute a training task."""
# Update status to running
self._training_tasks.update_status(task_id, "running")
self._training_tasks.add_log(task_id, "INFO", "Training task started")
# Update dataset training status to running
if dataset_id:
self._datasets.update_training_status(
dataset_id,
training_status="running",
active_training_task_id=task_id,
)
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
base_model_path = config.get("base_model_path") # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
learning_rate = config.get("learning_rate", 0.01)
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
# Get augmentation config if present
augmentation_config = config.get("augmentation")
augmentation_multiplier = config.get("augmentation_multiplier", 2)
# Determine which model to use as base
if base_model_path:
# Incremental training: use existing trained model
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
self._training_tasks.add_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
)
else:
# Train from pretrained model
effective_model = model_name
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._datasets.get(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
dataset_path = Path(dataset.dataset_path)
self._training_tasks.add_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
)
else:
export_result = self._export_training_data(task_id)
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
dataset_path = Path(data_yaml).parent
self._training_tasks.add_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
# Apply augmentation if config is provided
if augmentation_config and self._has_enabled_augmentations(augmentation_config):
aug_result = self._apply_augmentation(
task_id, dataset_path, augmentation_config, augmentation_multiplier
)
if aug_result:
self._training_tasks.add_log(
task_id, "INFO",
f"Augmentation complete: {aug_result['augmented_images']} new images "
f"(total: {aug_result['total_images']})",
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=effective_model, # Use base model or pretrained model
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project_name=project_name,
)
# Update task with results
self._training_tasks.update_status(
task_id=task_id,
status="completed",
result_metrics=result.get("metrics"),
model_path=result.get("model_path"),
)
self._training_tasks.add_log(task_id, "INFO", "Training completed successfully")
# Update dataset training status to completed and main status to trained
if dataset_id:
self._datasets.update_training_status(
dataset_id,
training_status="completed",
active_training_task_id=None,
update_main_status=True, # Set main status to 'trained'
)
# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
config=config,
dataset_id=dataset_id,
result=result,
)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}")
# Update dataset training status to failed
if dataset_id:
self._datasets.update_training_status(
dataset_id,
training_status="failed",
active_training_task_id=None,
)
raise
def _create_model_version_from_training(
self,
task_id: str,
config: dict[str, Any],
dataset_id: str | None,
result: dict[str, Any],
) -> None:
"""Create a model version entry from completed training."""
try:
model_path = result.get("model_path")
if not model_path:
logger.warning(f"No model path in training result for task {task_id}")
return
# Get task info for name
task = self._training_tasks.get(task_id)
task_name = task.name if task else f"Task {task_id[:8]}"
# Generate version number based on existing versions
existing_versions = self._model_versions.get_paginated(limit=1, offset=0)
version_count = existing_versions[1] if existing_versions else 0
version = f"v{version_count + 1}.0"
# Extract metrics from result
metrics = result.get("metrics", {})
metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
metrics_precision = metrics.get("precision")
metrics_recall = metrics.get("recall")
# Get file size if possible
file_size = None
model_file = Path(model_path)
if model_file.exists():
file_size = model_file.stat().st_size
# Get document count from dataset if available
document_count = 0
if dataset_id:
dataset = self._datasets.get(dataset_id)
if dataset:
document_count = dataset.total_documents
# Create model version
model_version = self._model_versions.create(
version=version,
name=task_name,
model_path=str(model_path),
description=f"Auto-created from training task {task_id[:8]}",
task_id=task_id,
dataset_id=dataset_id,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=config,
file_size=file_size,
trained_at=datetime.utcnow(),
)
logger.info(
f"Created model version {version} (ID: {model_version.version_id}) "
f"from training task {task_id}"
)
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
self._training_tasks.add_log(
task_id, "INFO",
f"Model version {version} created (mAP: {mAP_display})",
)
except Exception as e:
logger.error(f"Failed to create model version for task {task_id}: {e}")
self._training_tasks.add_log(
task_id, "WARNING",
f"Failed to auto-create model version: {e}",
)
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
from shared.fields import FIELD_CLASSES
from backend.web.services.storage_helpers import get_storage_helper
# Get storage helper for reading images
storage = get_storage_helper()
# Get all labeled documents
documents = self._documents.get_labeled_for_export()
if not documents:
self._training_tasks.add_log(task_id, "ERROR", "No labeled documents available")
return None
# Create export directory using StorageHelper
training_base = storage.get_training_data_path()
if training_base is None:
self._training_tasks.add_log(task_id, "ERROR", "Storage not configured for local access")
return None
export_dir = training_base / task_id
export_dir.mkdir(parents=True, exist_ok=True)
# YOLO format directories
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
# 80/20 train/val split
total_docs = len(documents)
train_count = int(total_docs * 0.8)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
# Export documents
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = self._annotations.get_for_document(str(doc.document_id))
if not annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
# Get image from storage
doc_id = str(doc.document_id)
if not storage.admin_image_exists(doc_id, page_num):
continue
# Download image and save to export directory
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
image_content = storage.get_admin_image(doc_id, page_num)
dst_image.write_bytes(image_content)
total_images += 1
# Write YOLO label
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
# Create data.yaml
yaml_path = export_dir / "data.yaml"
yaml_content = f"""path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
yaml_path.write_text(yaml_content)
return {
"data_yaml": str(yaml_path),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _run_yolo_training(
self,
task_id: str,
model_name: str,
data_yaml: str,
epochs: int,
batch_size: int,
image_size: int,
learning_rate: float,
device: str,
project_name: str,
) -> dict[str, Any]:
"""Run YOLO training using shared trainer."""
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
# Create log callback that writes to DB
def log_callback(level: str, message: str) -> None:
self._training_tasks.add_log(task_id, level, message)
# Create shared training config
# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
# because models need to be accessible by inference service on any platform
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
config = SharedTrainingConfig(
model_path=model_name,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project="runs/train",
name=f"{project_name}/task_{task_id[:8]}",
workers=0,
)
# Run training using shared trainer
trainer = YOLOTrainer(config=config, log_callback=log_callback)
result = trainer.train()
if not result.success:
raise ValueError(result.error or "Training failed")
return {
"model_path": result.model_path,
"metrics": result.metrics,
}
def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
"""Check if any augmentations are enabled in the config."""
augmentation_fields = [
"perspective_warp", "wrinkle", "edge_damage", "stain",
"lighting_variation", "shadow", "gaussian_blur", "motion_blur",
"gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
]
for field in augmentation_fields:
if field in aug_config:
field_config = aug_config[field]
if isinstance(field_config, dict) and field_config.get("enabled", False):
return True
return False
def _apply_augmentation(
self,
task_id: str,
dataset_path: Path,
aug_config: dict[str, Any],
multiplier: int,
) -> dict[str, int] | None:
"""Apply augmentation to dataset before training."""
try:
from shared.augmentation import DatasetAugmenter
self._training_tasks.add_log(
task_id, "INFO",
f"Applying augmentation with multiplier={multiplier}",
)
augmenter = DatasetAugmenter(aug_config)
result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)
return result
except Exception as e:
logger.error(f"Augmentation failed for task {task_id}: {e}")
self._training_tasks.add_log(
task_id, "WARNING",
f"Augmentation failed: {e}. Continuing with original dataset.",
)
return None
# Global scheduler instance
_scheduler: TrainingScheduler | None = None
_scheduler_lock = threading.Lock()
def get_training_scheduler() -> TrainingScheduler:
"""Get the training scheduler instance.
Uses double-checked locking pattern for thread safety.
"""
global _scheduler
if _scheduler is None:
with _scheduler_lock:
if _scheduler is None:
_scheduler = TrainingScheduler()
return _scheduler
def start_scheduler() -> None:
"""Start the global training scheduler."""
scheduler = get_training_scheduler()
scheduler.start()
def stop_scheduler() -> None:
"""Stop the global training scheduler."""
global _scheduler
if _scheduler:
_scheduler.stop()
_scheduler = None
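
if __name__ == "__main__":
    # Illustrative config (not part of the original file): the keys below are
    # the ones _execute_task() reads; the concrete values are example
    # assumptions. In normal operation a dict like this is stored as the
    # training task's config and picked up by the scheduler loop.
    example_config = {
        "model_name": "yolo11n.pt",      # ignored when base_model_path is set
        "base_model_path": None,         # path to a .pt file for incremental training
        "epochs": 50,
        "batch_size": 8,
        "image_size": 640,
        "learning_rate": 0.01,
        "device": "0",
        "project_name": "invoice_fields",
        "augmentation": {
            "gaussian_blur": {"enabled": True, "probability": 0.3, "params": {}},
        },
        "augmentation_multiplier": 2,
    }
    print(example_config)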

View File

@@ -0,0 +1,161 @@
"""Unified task management interface.
Provides abstract base class for all task runners (schedulers and queues)
and a TaskManager facade for unified lifecycle management.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass(frozen=True)
class TaskStatus:
"""Status of a task runner.
Attributes:
name: Unique identifier for the runner.
is_running: Whether the runner is currently active.
pending_count: Number of tasks waiting to be processed.
processing_count: Number of tasks currently being processed.
error: Optional error message if runner is in error state.
"""
name: str
is_running: bool
pending_count: int
processing_count: int
error: str | None = None
class TaskRunner(ABC):
"""Abstract base class for all task runners.
All schedulers and task queues should implement this interface
to enable unified lifecycle management and monitoring.
Note:
Implementations may have different `start()` signatures based on
their initialization needs (e.g., handler functions, services).
Use the implementation-specific start methods for initialization,
and use TaskManager for unified status monitoring.
"""
@property
@abstractmethod
def name(self) -> str:
"""Unique identifier for this runner."""
pass
@abstractmethod
def start(self, *args, **kwargs) -> None:
"""Start the task runner.
Should be idempotent - calling start on an already running
runner should have no effect.
Note:
Implementations may require additional parameters (handlers,
services). See implementation-specific documentation.
"""
pass
@abstractmethod
def stop(self, timeout: float | None = None) -> None:
"""Stop the task runner gracefully.
Args:
timeout: Maximum time to wait for graceful shutdown in seconds.
If None, use implementation default.
"""
pass
@property
@abstractmethod
def is_running(self) -> bool:
"""Check if the runner is currently active."""
pass
@abstractmethod
def get_status(self) -> TaskStatus:
"""Get current status of the runner.
Returns:
TaskStatus with current state information.
"""
pass
class TaskManager:
"""Unified manager for all task runners.
Provides centralized lifecycle management and monitoring
for all registered task runners.
"""
def __init__(self) -> None:
"""Initialize the task manager."""
self._runners: dict[str, TaskRunner] = {}
def register(self, runner: TaskRunner) -> None:
"""Register a task runner.
Args:
runner: TaskRunner instance to register.
"""
self._runners[runner.name] = runner
def get_runner(self, name: str) -> TaskRunner | None:
"""Get a specific runner by name.
Args:
name: Name of the runner to retrieve.
Returns:
TaskRunner if found, None otherwise.
"""
return self._runners.get(name)
@property
def runner_names(self) -> list[str]:
"""Get names of all registered runners.
Returns:
List of runner names.
"""
return list(self._runners.keys())
def start_all(self) -> None:
"""Start all registered runners that support no-argument start.
Note:
Runners requiring initialization parameters (like AsyncTaskQueue
or BatchTaskQueue) should be started individually before
registering with TaskManager.
"""
for runner in self._runners.values():
try:
runner.start()
except TypeError:
# Runner requires arguments - skip (should be started individually)
pass
def stop_all(self, timeout: float = 30.0) -> None:
"""Stop all registered runners gracefully.
Args:
timeout: Total timeout to distribute across all runners.
"""
if not self._runners:
return
per_runner_timeout = timeout / len(self._runners)
for runner in self._runners.values():
runner.stop(timeout=per_runner_timeout)
def get_all_status(self) -> dict[str, TaskStatus]:
"""Get status of all registered runners.
Returns:
Dict mapping runner names to their status.
"""
return {name: runner.get_status() for name, runner in self._runners.items()}
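
if __name__ == "__main__":
    # Minimal sketch (not part of the original file): a no-op runner that
    # satisfies the TaskRunner interface, registered with a TaskManager.
    # _DemoRunner exists only for this demo.
    class _DemoRunner(TaskRunner):
        def __init__(self) -> None:
            self._active = False

        @property
        def name(self) -> str:
            return "demo_runner"

        def start(self, *args, **kwargs) -> None:
            self._active = True

        def stop(self, timeout: float | None = None) -> None:
            self._active = False

        @property
        def is_running(self) -> bool:
            return self._active

        def get_status(self) -> TaskStatus:
            return TaskStatus(
                name=self.name,
                is_running=self._active,
                pending_count=0,
                processing_count=0,
            )

    manager = TaskManager()
    manager.register(_DemoRunner())
    manager.start_all()
    print(manager.get_all_status())  # demo_runner reports is_running=True
    manager.stop_all()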

View File

@@ -0,0 +1,133 @@
"""
FastAPI Dependencies
Dependency injection for the async API endpoints.
"""
import logging
from typing import Annotated
from fastapi import Depends, Header, HTTPException, Request
from backend.data.async_request_db import AsyncRequestDB
from backend.web.rate_limiter import RateLimiter
logger = logging.getLogger(__name__)
# Global instances (initialized in app startup)
_async_db: AsyncRequestDB | None = None
_rate_limiter: RateLimiter | None = None
def init_dependencies(db: AsyncRequestDB, rate_limiter: RateLimiter) -> None:
"""Initialize global dependency instances."""
global _async_db, _rate_limiter
_async_db = db
_rate_limiter = rate_limiter
def get_async_db() -> AsyncRequestDB:
"""Get async request database instance."""
if _async_db is None:
raise RuntimeError("AsyncRequestDB not initialized")
return _async_db
def get_rate_limiter() -> RateLimiter:
"""Get rate limiter instance."""
if _rate_limiter is None:
raise RuntimeError("RateLimiter not initialized")
return _rate_limiter
async def verify_api_key(
x_api_key: Annotated[str | None, Header()] = None,
) -> str:
"""
Verify API key exists and is active.
Raises:
HTTPException: 401 if API key is missing or invalid
"""
if not x_api_key:
raise HTTPException(
status_code=401,
detail="X-API-Key header is required",
headers={"WWW-Authenticate": "API-Key"},
)
db = get_async_db()
if not db.is_valid_api_key(x_api_key):
raise HTTPException(
status_code=401,
detail="Invalid or inactive API key",
headers={"WWW-Authenticate": "API-Key"},
)
# Update usage tracking
try:
db.update_api_key_usage(x_api_key)
except Exception as e:
logger.warning(f"Failed to update API key usage: {e}")
return x_api_key
async def check_submit_rate_limit(
api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
"""
Check rate limit before processing submit request.
Raises:
HTTPException: 429 if rate limit exceeded
"""
rate_limiter = get_rate_limiter()
status = rate_limiter.check_submit_limit(api_key)
if not status.allowed:
headers = rate_limiter.get_rate_limit_headers(status)
raise HTTPException(
status_code=429,
detail=status.reason or "Rate limit exceeded",
headers=headers,
)
return api_key
async def check_poll_rate_limit(
request: Request,
api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
"""
Check poll rate limit to prevent abuse.
Raises:
HTTPException: 429 if polling too frequently
"""
# Extract request_id from path parameters
request_id = request.path_params.get("request_id")
if not request_id:
return api_key # No request_id, skip poll limit check
rate_limiter = get_rate_limiter()
status = rate_limiter.check_poll_limit(api_key, request_id)
if not status.allowed:
headers = rate_limiter.get_rate_limit_headers(status)
raise HTTPException(
status_code=429,
detail=status.reason or "Polling too frequently",
headers=headers,
)
return api_key
# Type aliases for cleaner route signatures
ApiKeyDep = Annotated[str, Depends(verify_api_key)]
SubmitRateLimitDep = Annotated[str, Depends(check_submit_rate_limit)]
PollRateLimitDep = Annotated[str, Depends(check_poll_rate_limit)]
AsyncDBDep = Annotated[AsyncRequestDB, Depends(get_async_db)]
RateLimiterDep = Annotated[RateLimiter, Depends(get_rate_limiter)]
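
# Illustrative usage (not part of the original file): a rough sketch of how
# the aliases above combine in a route. The router, path, and response body
# are assumptions for this example; the dependency chain
# (verify_api_key -> check_submit_rate_limit) and record_request() are what
# this module and the rate limiter actually provide.
#
#   from fastapi import APIRouter
#
#   router = APIRouter(prefix="/api/v1")
#
#   @router.post("/invoices")
#   async def submit_invoice(
#       api_key: SubmitRateLimitDep,   # 401 on bad key, 429 when over limit
#       rate_limiter: RateLimiterDep,
#       db: AsyncDBDep,
#   ):
#       rate_limiter.record_request(api_key)  # count this submission in the window
#       return {"status": "accepted"}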

View File

@@ -0,0 +1,11 @@
"""
API Schemas
Pydantic models for request/response validation.
"""
# Import everything from sub-modules for backward compatibility
from backend.web.schemas.common import * # noqa: F401, F403
from backend.web.schemas.admin import * # noqa: F401, F403
from backend.web.schemas.inference import * # noqa: F401, F403
from backend.web.schemas.labeling import * # noqa: F401, F403

View File

@@ -0,0 +1,19 @@
"""
Admin API Request/Response Schemas
Pydantic models for admin API validation and serialization.
"""
from .enums import * # noqa: F401, F403
from .auth import * # noqa: F401, F403
from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
from .dashboard import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse
DocumentDetailResponse.model_rebuild()
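
# Illustrative note (not part of the original file): documents.py imports
# AnnotationItem and TrainingHistoryItem only under TYPE_CHECKING, so the
# corresponding annotations on DocumentDetailResponse are plain strings at
# import time. model_rebuild() resolves them against the names brought in by
# the star imports above; without it, validating a DocumentDetailResponse
# would fail with pydantic's "class is not fully defined" error. (The exact
# forward-referenced field names are an assumption based on those imports.)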

View File

@@ -0,0 +1,152 @@
"""Admin Annotation Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .enums import AnnotationSource
class BoundingBox(BaseModel):
"""Bounding box coordinates."""
x: int = Field(..., ge=0, description="X coordinate (pixels)")
y: int = Field(..., ge=0, description="Y coordinate (pixels)")
width: int = Field(..., ge=1, description="Width (pixels)")
height: int = Field(..., ge=1, description="Height (pixels)")
class AnnotationCreate(BaseModel):
"""Request to create an annotation."""
page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
text_value: str | None = Field(None, description="Text value (optional)")
class AnnotationUpdate(BaseModel):
"""Request to update an annotation."""
class_id: int | None = Field(None, ge=0, le=9, description="New class ID")
bbox: BoundingBox | None = Field(None, description="New bounding box")
text_value: str | None = Field(None, description="New text value")
class AnnotationItem(BaseModel):
"""Single annotation item."""
annotation_id: str = Field(..., description="Annotation UUID")
page_number: int = Field(..., ge=1, description="Page number")
class_id: int = Field(..., ge=0, le=9, description="Class ID")
class_name: str = Field(..., description="Class name")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
normalized_bbox: dict[str, float] = Field(
..., description="Normalized bbox (x_center, y_center, width, height)"
)
text_value: str | None = Field(None, description="Text value")
confidence: float | None = Field(None, ge=0, le=1, description="Confidence score")
source: AnnotationSource = Field(..., description="Annotation source")
created_at: datetime = Field(..., description="Creation timestamp")
class AnnotationResponse(BaseModel):
"""Response for annotation operation."""
annotation_id: str = Field(..., description="Annotation UUID")
message: str = Field(..., description="Status message")
class AnnotationListResponse(BaseModel):
"""Response for annotation list."""
document_id: str = Field(..., description="Document UUID")
page_count: int = Field(..., ge=1, description="Total pages")
total_annotations: int = Field(..., ge=0, description="Total annotations")
annotations: list[AnnotationItem] = Field(
default_factory=list, description="Annotation list"
)
class AnnotationLockRequest(BaseModel):
"""Request to acquire annotation lock."""
duration_seconds: int = Field(
default=300,
ge=60,
le=3600,
description="Lock duration in seconds (60-3600)",
)
class AnnotationLockResponse(BaseModel):
"""Response for annotation lock operation."""
document_id: str = Field(..., description="Document UUID")
locked: bool = Field(..., description="Whether lock was acquired/released")
lock_expires_at: datetime | None = Field(
None, description="Lock expiration time"
)
message: str = Field(..., description="Status message")
class AutoLabelRequest(BaseModel):
"""Request to trigger auto-labeling."""
field_values: dict[str, str] = Field(
...,
description="Field values to match (e.g., {'invoice_number': '12345'})",
)
replace_existing: bool = Field(
default=False, description="Replace existing auto annotations"
)
class AutoLabelResponse(BaseModel):
"""Response for auto-labeling."""
document_id: str = Field(..., description="Document UUID")
status: str = Field(..., description="Auto-labeling status")
annotations_created: int = Field(
default=0, ge=0, description="Number of annotations created"
)
message: str = Field(..., description="Status message")
class AnnotationVerifyRequest(BaseModel):
"""Request to verify an annotation."""
pass # No body needed, just POST to verify
class AnnotationVerifyResponse(BaseModel):
"""Response for annotation verification."""
annotation_id: str = Field(..., description="Annotation UUID")
is_verified: bool = Field(..., description="Verification status")
verified_at: datetime = Field(..., description="Verification timestamp")
verified_by: str = Field(..., description="Admin token who verified")
message: str = Field(..., description="Status message")
class AnnotationOverrideRequest(BaseModel):
"""Request to override an annotation."""
bbox: dict[str, int] | None = Field(
None, description="Updated bounding box {x, y, width, height}"
)
text_value: str | None = Field(None, description="Updated text value")
class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID")
class_name: str | None = Field(None, description="Updated class name")
reason: str | None = Field(None, description="Reason for override")
class AnnotationOverrideResponse(BaseModel):
"""Response for annotation override."""
annotation_id: str = Field(..., description="Annotation UUID")
source: str = Field(..., description="New source (manual)")
override_source: str | None = Field(None, description="Original source (auto)")
original_annotation_id: str | None = Field(None, description="Original annotation ID")
message: str = Field(..., description="Status message")
history_id: str = Field(..., description="History record UUID")
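
if __name__ == "__main__":
    # Worked example (not part of the original file): build a create request
    # and derive a normalized YOLO-style bbox of the kind reported back in
    # AnnotationItem.normalized_bbox. The 2480x3508 page size (A4 at 300 DPI)
    # and the field values are assumptions for this example.
    page_w, page_h = 2480, 3508
    req = AnnotationCreate(
        page_number=1,
        class_id=0,
        bbox=BoundingBox(x=1200, y=300, width=400, height=60),
        text_value="INV-2024-0042",
    )
    normalized = {
        "x_center": (req.bbox.x + req.bbox.width / 2) / page_w,
        "y_center": (req.bbox.y + req.bbox.height / 2) / page_h,
        "width": req.bbox.width / page_w,
        "height": req.bbox.height / page_h,
    }
    print(req.model_dump())
    print({k: round(v, 4) for k, v in normalized.items()})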

View File

@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class AugmentationParamsSchema(BaseModel):
"""Single augmentation parameters."""
enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
probability: float = Field(
default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
)
params: dict[str, Any] = Field(
default_factory=dict, description="Type-specific parameters"
)
class AugmentationConfigSchema(BaseModel):
"""Complete augmentation configuration."""
# Geometric transforms
perspective_warp: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Degradation effects
wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
edge_damage: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Lighting effects
lighting_variation: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Blur effects
gaussian_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
motion_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Noise effects
gaussian_noise: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
salt_pepper: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Texture effects
paper_texture: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
scanner_artifacts: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Global settings
preserve_bboxes: bool = Field(
default=True, description="Whether to adjust bboxes for geometric transforms"
)
seed: int | None = Field(default=None, description="Random seed for reproducibility")
class AugmentationTypeInfo(BaseModel):
"""Information about an augmentation type."""
name: str = Field(..., description="Augmentation name")
description: str = Field(..., description="Augmentation description")
affects_geometry: bool = Field(
..., description="Whether this augmentation affects bbox coordinates"
)
stage: str = Field(..., description="Processing stage")
default_params: dict[str, Any] = Field(
default_factory=dict, description="Default parameters"
)
class AugmentationTypesResponse(BaseModel):
"""Response for listing augmentation types."""
augmentation_types: list[AugmentationTypeInfo] = Field(
..., description="Available augmentation types"
)
class PresetInfo(BaseModel):
"""Information about a preset."""
name: str = Field(..., description="Preset name")
description: str = Field(..., description="Preset description")
class PresetsResponse(BaseModel):
"""Response for listing presets."""
presets: list[PresetInfo] = Field(..., description="Available presets")
class AugmentationPreviewRequest(BaseModel):
"""Request to preview augmentation on an image."""
augmentation_type: str = Field(..., description="Type of augmentation to preview")
params: dict[str, Any] = Field(
default_factory=dict, description="Override parameters"
)
class AugmentationPreviewResponse(BaseModel):
"""Response with preview image data."""
preview_url: str = Field(..., description="URL to preview image")
original_url: str = Field(..., description="URL to original image")
applied_params: dict[str, Any] = Field(..., description="Applied parameters")
class AugmentationBatchRequest(BaseModel):
"""Request to augment a dataset offline."""
dataset_id: str = Field(..., description="Source dataset UUID")
config: AugmentationConfigSchema = Field(..., description="Augmentation config")
output_name: str = Field(
..., min_length=1, max_length=255, description="Output dataset name"
)
multiplier: int = Field(
default=2, ge=1, le=10, description="Augmented copies per image"
)
class AugmentationBatchResponse(BaseModel):
"""Response for batch augmentation."""
task_id: str = Field(..., description="Background task UUID")
status: str = Field(..., description="Task status")
message: str = Field(..., description="Status message")
estimated_images: int = Field(..., description="Estimated total images")
class AugmentedDatasetItem(BaseModel):
"""Single augmented dataset in list."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
created_at: datetime = Field(..., description="Creation timestamp")
class AugmentedDatasetListResponse(BaseModel):
"""Response for listing augmented datasets."""
total: int = Field(..., ge=0, description="Total datasets")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
datasets: list[AugmentedDatasetItem] = Field(
default_factory=list, description="Dataset list"
)
class AugmentedDatasetDetailResponse(BaseModel):
"""Detailed augmented dataset response."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
config: AugmentationConfigSchema | None = Field(
None, description="Augmentation config used"
)
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
dataset_path: str | None = Field(None, description="Dataset path on disk")
error_message: str | None = Field(None, description="Error message if failed")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")
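
if __name__ == "__main__":
    # Example payload (not part of the original file): enable two effects and
    # leave everything else at its defaults. The keys inside params are
    # illustrative assumptions; the schema accepts any dict there.
    config = AugmentationConfigSchema(
        gaussian_blur=AugmentationParamsSchema(
            enabled=True, probability=0.3, params={"kernel_size": 5}
        ),
        stain=AugmentationParamsSchema(enabled=True, probability=0.2),
        seed=42,
    )
    enabled = [
        name
        for name, value in config.model_dump().items()
        if isinstance(value, dict) and value.get("enabled")
    ]
    print(enabled)  # ['stain', 'gaussian_blur'] in field declaration order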

View File

@@ -0,0 +1,23 @@
"""Admin Auth Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class AdminTokenCreate(BaseModel):
"""Request to create an admin token."""
name: str = Field(..., min_length=1, max_length=255, description="Token name")
expires_in_days: int | None = Field(
None, ge=1, le=365, description="Token expiration in days (optional)"
)
class AdminTokenResponse(BaseModel):
"""Response with created admin token."""
token: str = Field(..., description="Admin token")
name: str = Field(..., description="Token name")
expires_at: datetime | None = Field(None, description="Token expiration time")
message: str = Field(..., description="Status message")
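
if __name__ == "__main__":
    # Example request body (not part of the original file): a named token that
    # expires in 30 days. The name is an arbitrary example value.
    body = AdminTokenCreate(name="ci-pipeline", expires_in_days=30)
    print(body.model_dump_json())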

View File

@@ -0,0 +1,92 @@
"""
Dashboard API Schemas
Pydantic models for dashboard statistics and activity endpoints.
"""
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, Field
# Activity type literals for type safety
ActivityType = Literal[
"document_uploaded",
"annotation_modified",
"training_completed",
"training_failed",
"model_activated",
]
class DashboardStatsResponse(BaseModel):
"""Response for dashboard statistics."""
total_documents: int = Field(..., description="Total number of documents")
annotation_complete: int = Field(
..., description="Documents with complete annotations"
)
annotation_incomplete: int = Field(
..., description="Documents with incomplete annotations"
)
pending: int = Field(..., description="Documents pending processing")
completeness_rate: float = Field(
..., description="Annotation completeness percentage"
)
class ActiveModelInfo(BaseModel):
"""Active model information."""
version_id: str = Field(..., description="Model version UUID")
version: str = Field(..., description="Model version string")
name: str = Field(..., description="Model name")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
metrics_precision: float | None = Field(None, description="Precision score")
metrics_recall: float | None = Field(None, description="Recall score")
document_count: int = Field(0, description="Number of training documents")
activated_at: datetime | None = Field(None, description="Activation timestamp")
class RunningTrainingInfo(BaseModel):
"""Running training task information."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Training task name")
status: str = Field(..., description="Training status")
started_at: datetime | None = Field(None, description="Start timestamp")
progress: int = Field(0, description="Training progress percentage")
class DashboardActiveModelResponse(BaseModel):
"""Response for dashboard active model endpoint."""
model: ActiveModelInfo | None = Field(
None, description="Active model info, null if none"
)
running_training: RunningTrainingInfo | None = Field(
None, description="Running training task, null if none"
)
class ActivityItem(BaseModel):
"""Single activity item."""
type: ActivityType = Field(
...,
description="Activity type: document_uploaded, annotation_modified, training_completed, training_failed, model_activated",
)
description: str = Field(..., description="Human-readable description")
timestamp: datetime = Field(..., description="Activity timestamp")
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
class RecentActivityResponse(BaseModel):
"""Response for recent activity endpoint."""
activities: list[ActivityItem] = Field(
default_factory=list, description="List of recent activities"
)
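
if __name__ == "__main__":
    # Worked example (not part of the original file): how the stats payload
    # fits together. The counts are invented, and computing completeness_rate
    # as complete / total * 100 is an assumption about how the endpoint
    # derives the percentage.
    total, complete, incomplete, pending = 120, 84, 26, 10
    stats = DashboardStatsResponse(
        total_documents=total,
        annotation_complete=complete,
        annotation_incomplete=incomplete,
        pending=pending,
        completeness_rate=round(complete / total * 100, 1),  # 70.0
    )
    print(stats.model_dump())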

View File

@@ -0,0 +1,90 @@
"""Admin Dataset Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .training import TrainingConfig
class DatasetCreateRequest(BaseModel):
"""Request to create a training dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
description: str | None = Field(None, description="Optional description")
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
category: str | None = Field(None, description="Filter documents by category (optional)")
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
seed: int = Field(42, description="Random seed for split")
class DatasetDocumentItem(BaseModel):
"""Document within a dataset."""
document_id: str
split: str
page_count: int
annotation_count: int
class DatasetResponse(BaseModel):
"""Response after creating a dataset."""
dataset_id: str
name: str
status: str
message: str
class DatasetDetailResponse(BaseModel):
"""Detailed dataset info with documents."""
dataset_id: str
name: str
description: str | None
status: str
training_status: str | None = None
active_training_task_id: str | None = None
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
documents: list[DatasetDocumentItem]
created_at: datetime
updated_at: datetime
class DatasetListItem(BaseModel):
"""Dataset in list view."""
dataset_id: str
name: str
description: str | None
status: str
training_status: str | None = None
active_training_task_id: str | None = None
total_documents: int
total_images: int
total_annotations: int
created_at: datetime
class DatasetListResponse(BaseModel):
"""Paginated dataset list."""
total: int
limit: int
offset: int
datasets: list[DatasetListItem]
class DatasetTrainRequest(BaseModel):
"""Request to start training from a dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Training task name")
config: TrainingConfig = Field(..., description="Training configuration")
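
if __name__ == "__main__":
    # Example request (not part of the original file): an 80/10 train/val
    # split; the remainder (1 - train_ratio - val_ratio = 0.1 here) is
    # presumably the test share, though the schema itself only carries the
    # two ratios. The name and document ID are placeholder values.
    req = DatasetCreateRequest(
        name="invoices-2024-q1",
        description="Verified invoices from Q1",
        document_ids=["7f9c1c2e-0000-0000-0000-000000000001"],
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
    )
    test_ratio = round(1.0 - req.train_ratio - req.val_ratio, 2)
    print(req.model_dump(exclude_none=True), "test_ratio:", test_ratio)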

View File

@@ -0,0 +1,123 @@
"""Admin Document Schemas."""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from .enums import AutoLabelStatus, DocumentStatus
if TYPE_CHECKING:
from .annotations import AnnotationItem
from .training import TrainingHistoryItem
class DocumentUploadResponse(BaseModel):
"""Response for document upload."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
message: str = Field(..., description="Status message")
class DocumentItem(BaseModel):
"""Single document in list."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentListResponse(BaseModel):
"""Response for document list."""
total: int = Field(..., ge=0, description="Total documents")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
documents: list[DocumentItem] = Field(
default_factory=list, description="Document list"
)
class DocumentDetailResponse(BaseModel):
"""Response for document detail."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
content_type: str = Field(..., description="MIME type")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
annotation_lock_until: datetime | None = Field(
None, description="Lock expiration time if document is locked"
)
annotations: list["AnnotationItem"] = Field(
default_factory=list, description="Document annotations"
)
image_urls: list[str] = Field(
default_factory=list, description="URLs to page images"
)
training_history: list["TrainingHistoryItem"] = Field(
default_factory=list, description="Training tasks that used this document"
)
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentStatsResponse(BaseModel):
"""Document statistics response."""
total: int = Field(..., ge=0, description="Total documents")
pending: int = Field(default=0, ge=0, description="Pending documents")
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
labeled: int = Field(default=0, ge=0, description="Labeled documents")
exported: int = Field(default=0, ge=0, description="Exported documents")
class DocumentUpdateRequest(BaseModel):
"""Request for updating document metadata."""
category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)")
group_key: str | None = Field(None, description="User-defined group key")
class DocumentCategoriesResponse(BaseModel):
"""Response for available document categories."""
categories: list[str] = Field(..., description="List of available categories")
total: int = Field(..., ge=0, description="Total number of categories")
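
if __name__ == "__main__":
    # Minimal sketch of a metadata update payload; fields left at None are
    # presumably treated as "unchanged" by the endpoint consuming this model.
    patch = DocumentUpdateRequest(category="receipt")
    print(patch.model_dump_json(exclude_none=True))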

View File

@@ -0,0 +1,46 @@
"""Admin API Enums."""
from enum import Enum
class DocumentStatus(str, Enum):
"""Document status enum."""
PENDING = "pending"
AUTO_LABELING = "auto_labeling"
LABELED = "labeled"
EXPORTED = "exported"
class AutoLabelStatus(str, Enum):
"""Auto-labeling status enum."""
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class TrainingStatus(str, Enum):
"""Training task status enum."""
PENDING = "pending"
SCHEDULED = "scheduled"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TrainingType(str, Enum):
"""Training task type enum."""
TRAIN = "train"
FINETUNE = "finetune"
class AnnotationSource(str, Enum):
"""Annotation source enum."""
MANUAL = "manual"
AUTO = "auto"
IMPORTED = "imported"
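
if __name__ == "__main__":
    # Small sketch: because these enums subclass str, members compare equal to
    # the raw strings stored in the database and serialize as plain strings.
    assert DocumentStatus.LABELED == "labeled"
    assert TrainingStatus("running") is TrainingStatus.RUNNING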

View File

@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class ModelVersionCreateRequest(BaseModel):
"""Request to create a model version."""
version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
name: str = Field(..., min_length=1, max_length=255, description="Model name")
model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
description: str | None = Field(None, description="Optional description")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
document_count: int = Field(0, ge=0, description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
class ModelVersionUpdateRequest(BaseModel):
"""Request to update a model version."""
name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
description: str | None = Field(None, description="Description")
status: str | None = Field(None, description="Status (inactive, archived)")
class ModelVersionItem(BaseModel):
"""Model version in list view."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
document_count: int = Field(..., description="Documents used in training")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
class ModelVersionListResponse(BaseModel):
"""Paginated model version list."""
total: int = Field(..., ge=0, description="Total model versions")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")
class ModelVersionDetailResponse(BaseModel):
"""Detailed model version info."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
description: str | None = Field(None, description="Description")
model_path: str = Field(..., description="Path to model file")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
metrics_precision: float | None = Field(None, description="Precision")
metrics_recall: float | None = Field(None, description="Recall")
document_count: int = Field(..., description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class ModelVersionResponse(BaseModel):
"""Response for model version operation."""
version_id: str = Field(..., description="Version UUID")
status: str = Field(..., description="Model status")
message: str = Field(..., description="Status message")
class ActiveModelResponse(BaseModel):
"""Response for active model query."""
has_active_model: bool = Field(..., description="Whether an active model exists")
model: ModelVersionItem | None = Field(None, description="Active model if exists")
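
if __name__ == "__main__":
    # Minimal sketch of a model-registration payload; the path and metric
    # values are placeholders. Metrics are optional and bounded to [0, 1].
    example = ModelVersionCreateRequest(
        version="1.2.0",
        name="invoice-fields-yolo11n",
        model_path="/models/invoice_fields/1.2.0/best.pt",
        metrics_mAP=0.91,
        metrics_precision=0.93,
        metrics_recall=0.89,
        document_count=240,
    )
    print(example.model_dump_json(indent=2, exclude_none=True))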

View File

@@ -0,0 +1,219 @@
"""Admin Training Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
base_model_version_id: str | None = Field(
default=None,
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
)
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate")
device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
project_name: str = Field(
default="invoice_fields", description="Training project name"
)
# Data augmentation settings
augmentation: AugmentationConfigSchema | None = Field(
default=None,
description="Augmentation configuration. If provided, augments dataset before training.",
)
augmentation_multiplier: int = Field(
default=2,
ge=1,
le=10,
description="Number of augmented copies per original image",
)
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""
name: str = Field(..., min_length=1, max_length=255, description="Task name")
description: str | None = Field(None, max_length=1000, description="Description")
task_type: TrainingType = Field(
default=TrainingType.TRAIN, description="Task type"
)
config: TrainingConfig = Field(
default_factory=TrainingConfig, description="Training configuration"
)
scheduled_at: datetime | None = Field(
None, description="Scheduled execution time"
)
cron_expression: str | None = Field(
None, max_length=50, description="Cron expression for recurring tasks"
)
class TrainingTaskItem(BaseModel):
"""Single training task in list."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskListResponse(BaseModel):
"""Response for training task list."""
total: int = Field(..., ge=0, description="Total tasks")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list")
class TrainingTaskDetailResponse(BaseModel):
"""Response for training task detail."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
description: str | None = Field(None, description="Description")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
config: dict[str, Any] | None = Field(None, description="Training configuration")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
cron_expression: str | None = Field(None, description="Cron expression")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
error_message: str | None = Field(None, description="Error message")
result_metrics: dict[str, Any] | None = Field(None, description="Result metrics")
model_path: str | None = Field(None, description="Trained model path")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskResponse(BaseModel):
"""Response for training task operation."""
task_id: str = Field(..., description="Task UUID")
status: TrainingStatus = Field(..., description="Task status")
message: str = Field(..., description="Status message")
class TrainingLogItem(BaseModel):
"""Single training log entry."""
level: str = Field(..., description="Log level")
message: str = Field(..., description="Log message")
details: dict[str, Any] | None = Field(None, description="Additional details")
created_at: datetime = Field(..., description="Timestamp")
class TrainingLogsResponse(BaseModel):
"""Response for training logs."""
task_id: str = Field(..., description="Task UUID")
logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries")
class ExportRequest(BaseModel):
"""Request to export annotations."""
format: str = Field(
default="yolo", description="Export format (yolo, coco, voc)"
)
include_images: bool = Field(
default=True, description="Include images in export"
)
split_ratio: float = Field(
default=0.8, ge=0.5, le=1.0, description="Train/val split ratio"
)
class ExportResponse(BaseModel):
"""Response for export operation."""
status: str = Field(..., description="Export status")
export_path: str = Field(..., description="Path to exported dataset")
total_images: int = Field(..., ge=0, description="Total images exported")
total_annotations: int = Field(..., ge=0, description="Total annotations")
train_count: int = Field(..., ge=0, description="Training set count")
val_count: int = Field(..., ge=0, description="Validation set count")
message: str = Field(..., description="Status message")
class TrainingDocumentItem(BaseModel):
"""Document item for training page."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Filename")
annotation_count: int = Field(..., ge=0, description="Total annotations")
annotation_sources: dict[str, int] = Field(
..., description="Annotation counts by source (manual, auto)"
)
used_in_training: list[str] = Field(
default_factory=list, description="List of training task IDs that used this document"
)
last_modified: datetime = Field(..., description="Last modification time")
class TrainingDocumentsResponse(BaseModel):
"""Response for GET /admin/training/documents."""
total: int = Field(..., ge=0, description="Total document count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
documents: list[TrainingDocumentItem] = Field(
default_factory=list, description="Documents available for training"
)
class ModelMetrics(BaseModel):
"""Training model metrics."""
mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
class TrainingModelItem(BaseModel):
"""Trained model item for model list."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Model name")
status: TrainingStatus = Field(..., description="Training status")
document_count: int = Field(..., ge=0, description="Documents used in training")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")
metrics: ModelMetrics = Field(..., description="Model metrics")
model_path: str | None = Field(None, description="Path to model weights")
download_url: str | None = Field(None, description="Download URL for model")
class TrainingModelsResponse(BaseModel):
"""Response for GET /admin/training/models."""
total: int = Field(..., ge=0, description="Total model count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
models: list[TrainingModelItem] = Field(
default_factory=list, description="Trained models"
)
class TrainingHistoryItem(BaseModel):
"""Training history for a document."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Training task name")
trained_at: datetime = Field(..., description="Training timestamp")
model_metrics: ModelMetrics | None = Field(None, description="Model metrics")
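
if __name__ == "__main__":
    # Minimal sketch of a training-task request. Config values fall back to the
    # defaults above (yolo11n.pt base model, 100 epochs) unless overridden;
    # setting base_model_version_id instead would select incremental training.
    example = TrainingTaskCreate(
        name="nightly-train",
        config=TrainingConfig(epochs=50, batch_size=8, device="cpu"),
    )
    print(example.model_dump_json(indent=2, exclude_none=True))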

View File

@@ -0,0 +1,15 @@
"""
Common Schemas
Shared Pydantic models used across multiple API modules.
"""
from pydantic import BaseModel, Field
class ErrorResponse(BaseModel):
"""Error response."""
status: str = Field(default="error", description="Error status")
message: str = Field(..., description="Error message")
detail: str | None = Field(None, description="Detailed error information")

View File

@@ -0,0 +1,196 @@
"""
API Request/Response Schemas
Pydantic models for API validation and serialization.
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
# =============================================================================
# Enums
# =============================================================================
class AsyncStatus(str, Enum):
"""Async request status enum."""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
# =============================================================================
# Sync API Schemas (existing)
# =============================================================================
class DetectionResult(BaseModel):
"""Single detection result."""
field: str = Field(..., description="Field type (e.g., invoice_number, amount)")
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
class ExtractedField(BaseModel):
"""Extracted and normalized field value."""
field_name: str = Field(..., description="Field name")
value: str | None = Field(None, description="Extracted value")
confidence: float = Field(..., ge=0, le=1, description="Extraction confidence")
is_valid: bool = Field(True, description="Whether the value passed validation")
class InferenceResult(BaseModel):
"""Complete inference result for a document."""
document_id: str = Field(..., description="Document identifier")
success: bool = Field(..., description="Whether inference succeeded")
document_type: str = Field(
default="invoice", description="Document type: 'invoice' or 'letter'"
)
fields: dict[str, str | None] = Field(
default_factory=dict, description="Extracted field values"
)
confidence: dict[str, float] = Field(
default_factory=dict, description="Confidence scores per field"
)
detections: list[DetectionResult] = Field(
default_factory=list, description="Raw YOLO detections"
)
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
visualization_url: str | None = Field(
None, description="URL to visualization image"
)
errors: list[str] = Field(default_factory=list, description="Error messages")
class InferenceResponse(BaseModel):
"""API response for inference endpoint."""
status: str = Field(..., description="Response status: success or error")
message: str = Field(..., description="Response message")
result: InferenceResult | None = Field(None, description="Inference result")
class BatchInferenceResponse(BaseModel):
"""API response for batch inference endpoint."""
status: str = Field(..., description="Response status")
message: str = Field(..., description="Response message")
total: int = Field(..., description="Total documents processed")
successful: int = Field(..., description="Number of successful extractions")
results: list[InferenceResult] = Field(
default_factory=list, description="Individual results"
)
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Service status")
model_loaded: bool = Field(..., description="Whether model is loaded")
gpu_available: bool = Field(..., description="Whether GPU is available")
version: str = Field(..., description="API version")
class ErrorResponse(BaseModel):
"""Error response."""
status: str = Field(default="error", description="Error status")
message: str = Field(..., description="Error message")
detail: str | None = Field(None, description="Detailed error information")
# =============================================================================
# Async API Schemas
# =============================================================================
class AsyncSubmitResponse(BaseModel):
"""Response for async submit endpoint."""
status: str = Field(default="accepted", description="Response status")
message: str = Field(..., description="Response message")
request_id: str = Field(..., description="Unique request identifier (UUID)")
estimated_wait_seconds: int = Field(
..., ge=0, description="Estimated wait time in seconds"
)
poll_url: str = Field(..., description="URL to poll for status updates")
class AsyncStatusResponse(BaseModel):
"""Response for async status endpoint."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Current processing status")
filename: str = Field(..., description="Original filename")
created_at: datetime = Field(..., description="Request creation timestamp")
started_at: datetime | None = Field(
None, description="Processing start timestamp"
)
completed_at: datetime | None = Field(
None, description="Processing completion timestamp"
)
position_in_queue: int | None = Field(
None, description="Position in queue (for pending status)"
)
error_message: str | None = Field(
None, description="Error message (for failed status)"
)
result_url: str | None = Field(
None, description="URL to fetch results (for completed status)"
)
class AsyncResultResponse(BaseModel):
"""Response for async result endpoint."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Processing status")
processing_time_ms: float = Field(
..., ge=0, description="Total processing time in milliseconds"
)
result: InferenceResult | None = Field(
None, description="Extraction result (when completed)"
)
visualization_url: str | None = Field(
None, description="URL to visualization image"
)
class AsyncRequestItem(BaseModel):
"""Single item in async requests list."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Current processing status")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
created_at: datetime = Field(..., description="Request creation timestamp")
completed_at: datetime | None = Field(
None, description="Processing completion timestamp"
)
class AsyncRequestsListResponse(BaseModel):
"""Response for async requests list endpoint."""
total: int = Field(..., ge=0, description="Total number of requests")
limit: int = Field(..., ge=1, description="Maximum items per page")
offset: int = Field(..., ge=0, description="Current offset")
requests: list[AsyncRequestItem] = Field(
default_factory=list, description="List of requests"
)
class RateLimitInfo(BaseModel):
"""Rate limit information (included in headers)."""
limit: int = Field(..., description="Maximum requests per minute")
remaining: int = Field(..., description="Remaining requests in current window")
reset_at: datetime = Field(..., description="Time when limit resets")
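
if __name__ == "__main__":
    # Minimal sketch of the async flow these schemas describe: submit returns a
    # request_id plus a poll URL, the status endpoint reports AsyncStatus
    # transitions, and the result endpoint wraps an InferenceResult.
    # The UUID and URL below are placeholders.
    submit = AsyncSubmitResponse(
        message="Request accepted",
        request_id="7c0f0a52-0000-0000-0000-000000000000",
        estimated_wait_seconds=10,
        poll_url="/async/requests/7c0f0a52-0000-0000-0000-000000000000/status",
    )
    print(submit.model_dump_json(indent=2))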

View File

@@ -0,0 +1,13 @@
"""
Labeling API Schemas
Pydantic models for pre-labeling and label validation endpoints.
"""
from pydantic import BaseModel, Field
class PreLabelResponse(BaseModel):
"""API response for pre-label endpoint."""
document_id: str = Field(..., description="Document identifier for retrieving results")

View File

@@ -0,0 +1,18 @@
"""
Business Logic Services
Service layer for processing requests and orchestrating data operations.
"""
from backend.web.services.autolabel import AutoLabelService, get_auto_label_service
from backend.web.services.inference import InferenceService
from backend.web.services.async_processing import AsyncProcessingService
from backend.web.services.batch_upload import BatchUploadService
__all__ = [
"AutoLabelService",
"get_auto_label_service",
"InferenceService",
"AsyncProcessingService",
"BatchUploadService",
]

View File

@@ -0,0 +1,386 @@
"""
Async Processing Service
Manages async request lifecycle and background processing.
"""
import logging
import re
import shutil
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from threading import Event, Thread
from typing import TYPE_CHECKING
from backend.data.async_request_db import AsyncRequestDB
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from backend.web.core.rate_limiter import RateLimiter
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from backend.web.config import AsyncConfig, StorageConfig
from backend.web.services.inference import InferenceService
logger = logging.getLogger(__name__)
@dataclass
class AsyncSubmitResult:
"""Result from async submit operation."""
success: bool
request_id: str | None = None
estimated_wait_seconds: int = 0
error: str | None = None
class AsyncProcessingService:
"""
Manages async request lifecycle and processing.
Coordinates between:
- HTTP endpoints (submit/status/result)
- Background task queue
- Database storage
- Rate limiting
"""
def __init__(
self,
inference_service: "InferenceService",
db: AsyncRequestDB,
queue: AsyncTaskQueue,
rate_limiter: RateLimiter,
async_config: "AsyncConfig",
storage_config: "StorageConfig",
) -> None:
self._inference = inference_service
self._db = db
self._queue = queue
self._rate_limiter = rate_limiter
self._async_config = async_config
self._storage_config = storage_config
# Cleanup thread
self._cleanup_stop_event = Event()
self._cleanup_thread: Thread | None = None
def start(self) -> None:
"""Start the async processing service."""
# Start the task queue with our handler
self._queue.start(self._process_task)
# Start cleanup thread
self._cleanup_stop_event.clear()
self._cleanup_thread = Thread(
target=self._cleanup_loop,
name="async-cleanup",
daemon=True,
)
self._cleanup_thread.start()
logger.info("AsyncProcessingService started")
def stop(self, timeout: float = 30.0) -> None:
"""Stop the async processing service."""
# Stop cleanup thread
self._cleanup_stop_event.set()
if self._cleanup_thread and self._cleanup_thread.is_alive():
self._cleanup_thread.join(timeout=5.0)
# Stop task queue
self._queue.stop(timeout=timeout)
logger.info("AsyncProcessingService stopped")
def submit_request(
self,
api_key: str,
file_content: bytes,
filename: str,
content_type: str,
) -> AsyncSubmitResult:
"""
Submit a new async processing request.
Args:
api_key: API key for the request
file_content: File content as bytes
filename: Original filename
content_type: File content type
Returns:
AsyncSubmitResult with request_id and status
"""
# Generate request ID
request_id = str(uuid.uuid4())
# Save file to temp storage
file_path = self._save_upload(request_id, filename, file_content)
file_size = len(file_content)
try:
# Calculate expiration
expires_at = datetime.utcnow() + timedelta(
days=self._async_config.result_retention_days
)
# Create database record
self._db.create_request(
api_key=api_key,
filename=filename,
file_size=file_size,
content_type=content_type,
expires_at=expires_at,
request_id=request_id,
)
# Record rate limit event
self._rate_limiter.record_request(api_key)
# Create and queue task
task = AsyncTask(
request_id=request_id,
api_key=api_key,
file_path=file_path,
filename=filename,
created_at=datetime.utcnow(),
)
if not self._queue.submit(task):
# Queue is full
self._db.update_status(
request_id,
"failed",
error_message="Processing queue is full",
)
# Cleanup file
file_path.unlink(missing_ok=True)
return AsyncSubmitResult(
success=False,
request_id=request_id,
error="Processing queue is full. Please try again later.",
)
# Estimate wait time
estimated_wait = self._estimate_wait()
return AsyncSubmitResult(
success=True,
request_id=request_id,
estimated_wait_seconds=estimated_wait,
)
except Exception as e:
logger.error(f"Failed to submit request: {e}", exc_info=True)
# Cleanup file on error
file_path.unlink(missing_ok=True)
return AsyncSubmitResult(
success=False,
# Return generic error message to avoid leaking implementation details
error="Failed to process request. Please try again later.",
)
# Allowed file extensions whitelist
ALLOWED_EXTENSIONS = frozenset({".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"})
def _save_upload(
self,
request_id: str,
filename: str,
content: bytes,
) -> Path:
"""Save uploaded file to temp storage using StorageHelper."""
# Extract extension from filename
ext = Path(filename).suffix.lower()
# Validate extension: must be alphanumeric only (e.g., .pdf, .png)
if not ext or not re.match(r'^\.[a-z0-9]+$', ext):
ext = ".pdf"
# Validate against whitelist
if ext not in self.ALLOWED_EXTENSIONS:
ext = ".pdf"
# Get upload directory from StorageHelper
storage = get_storage_helper()
upload_dir = storage.get_uploads_base_path(subfolder="async")
if upload_dir is None:
raise ValueError("Storage not configured for local access")
# Build file path - request_id is a UUID so it's safe
file_path = upload_dir / f"{request_id}{ext}"
# Defense in depth: ensure path is within upload_dir
if not file_path.resolve().is_relative_to(upload_dir.resolve()):
raise ValueError("Invalid file path detected")
file_path.write_bytes(content)
return file_path
def _process_task(self, task: AsyncTask) -> None:
"""
Process a single task (called by worker thread).
This method is called by the AsyncTaskQueue worker threads.
"""
start_time = time.time()
try:
# Update status to processing
self._db.update_status(task.request_id, "processing")
# Ensure file exists
if not task.file_path.exists():
raise FileNotFoundError(f"Upload file not found: {task.file_path}")
# Run inference based on file type
file_ext = task.file_path.suffix.lower()
if file_ext == ".pdf":
result = self._inference.process_pdf(
task.file_path,
document_id=task.request_id[:8],
)
else:
result = self._inference.process_image(
task.file_path,
document_id=task.request_id[:8],
)
# Calculate processing time
processing_time_ms = (time.time() - start_time) * 1000
# Prepare result for storage
result_data = {
"document_id": result.document_id,
"success": result.success,
"document_type": result.document_type,
"fields": result.fields,
"confidence": result.confidence,
"detections": result.detections,
"errors": result.errors,
}
# Get visualization path as string
viz_path = None
if result.visualization_path:
viz_path = str(result.visualization_path.name)
# Store result in database
self._db.complete_request(
request_id=task.request_id,
document_id=result.document_id,
result=result_data,
processing_time_ms=processing_time_ms,
visualization_path=viz_path,
)
logger.info(
f"Task {task.request_id} completed successfully "
f"in {processing_time_ms:.0f}ms"
)
except Exception as e:
logger.error(
f"Task {task.request_id} failed: {e}",
exc_info=True,
)
self._db.update_status(
task.request_id,
"failed",
error_message=str(e),
increment_retry=True,
)
finally:
# Cleanup uploaded file
if task.file_path.exists():
task.file_path.unlink(missing_ok=True)
def _estimate_wait(self) -> int:
"""Estimate wait time based on queue depth."""
queue_depth = self._queue.get_queue_depth()
processing_count = self._queue.get_processing_count()
total_pending = queue_depth + processing_count
# Estimate ~5 seconds per document
avg_processing_time = 5
return total_pending * avg_processing_time
def _cleanup_loop(self) -> None:
"""Background cleanup loop."""
logger.info("Cleanup thread started")
cleanup_interval = self._async_config.cleanup_interval_hours * 3600
while not self._cleanup_stop_event.wait(timeout=cleanup_interval):
try:
self._run_cleanup()
except Exception as e:
logger.error(f"Cleanup failed: {e}", exc_info=True)
logger.info("Cleanup thread stopped")
def _run_cleanup(self) -> None:
"""Run cleanup operations."""
logger.info("Running cleanup...")
# Delete expired requests
deleted_requests = self._db.delete_expired_requests()
# Reset stale processing requests
reset_count = self._db.reset_stale_processing_requests(
stale_minutes=self._async_config.task_timeout_seconds // 60,
max_retries=3,
)
# Cleanup old rate limit events
deleted_events = self._db.cleanup_old_rate_limit_events(hours=1)
# Cleanup old poll timestamps
cleaned_polls = self._rate_limiter.cleanup_poll_timestamps()
# Cleanup rate limiter request windows
self._rate_limiter.cleanup_request_windows()
# Cleanup orphaned upload files
orphan_count = self._cleanup_orphan_files()
logger.info(
f"Cleanup complete: {deleted_requests} expired requests, "
f"{reset_count} stale requests reset, "
f"{deleted_events} rate limit events, "
f"{cleaned_polls} poll timestamps, "
f"{orphan_count} orphan files"
)
def _cleanup_orphan_files(self) -> int:
"""Clean up upload files that don't have matching requests."""
storage = get_storage_helper()
upload_dir = storage.get_uploads_base_path(subfolder="async")
if upload_dir is None or not upload_dir.exists():
return 0
count = 0
# Files older than 1 hour without matching request are considered orphans
cutoff = time.time() - 3600
for file_path in upload_dir.iterdir():
if not file_path.is_file():
continue
# Check if file is old enough
if file_path.stat().st_mtime > cutoff:
continue
# Extract request_id from filename
request_id = file_path.stem
# Check if request exists in database
request = self._db.get_request(request_id)
if request is None:
file_path.unlink(missing_ok=True)
count += 1
return count
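
if __name__ == "__main__":
    # Standalone sketch of the wait-time heuristic in _estimate_wait: every
    # queued or in-flight document is assumed to take roughly five seconds.
    def estimate_wait_seconds(queue_depth: int, processing: int, avg_seconds: int = 5) -> int:
        return (queue_depth + processing) * avg_seconds

    print(estimate_wait_seconds(queue_depth=7, processing=2))  # -> 45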

View File

@@ -0,0 +1,322 @@
"""Augmentation service for handling augmentation operations."""
import base64
import io
import re
import uuid
from pathlib import Path
from typing import Any
import numpy as np
from fastapi import HTTPException
from PIL import Image
from backend.data.repositories import DocumentRepository, DatasetRepository
from backend.web.schemas.admin.augmentation import (
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
)
# Constants
PREVIEW_MAX_SIZE = 800
PREVIEW_SEED = 42
UUID_PATTERN = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
re.IGNORECASE,
)
class AugmentationService:
"""Service for augmentation operations."""
def __init__(
self,
doc_repo: DocumentRepository | None = None,
dataset_repo: DatasetRepository | None = None,
) -> None:
"""Initialize service with repository connections."""
self.doc_repo = doc_repo or DocumentRepository()
self.dataset_repo = dataset_repo or DatasetRepository()
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
"""
Validate UUID format to prevent path traversal.
Args:
value: Value to validate.
field_name: Field name for error message.
Raises:
HTTPException: If value is not a valid UUID.
"""
if not UUID_PATTERN.match(value):
raise HTTPException(
status_code=400,
detail=f"Invalid {field_name} format: {value}",
)
async def preview_single(
self,
document_id: str,
page: int,
augmentation_type: str,
params: dict[str, Any],
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
augmentation_type: Name of augmentation to apply.
params: Override parameters.
Returns:
Preview response with image URLs.
Raises:
HTTPException: If document not found or augmentation invalid.
"""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline
# Validate augmentation type
if augmentation_type not in AUGMENTATION_REGISTRY:
raise HTTPException(
status_code=400,
detail=f"Unknown augmentation type: {augmentation_type}. "
f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
)
# Get document and load image
image = await self._load_document_page(document_id, page)
# Create config with only this augmentation enabled
config_kwargs = {
augmentation_type: AugmentationParams(
enabled=True,
probability=1.0, # Always apply for preview
params=params,
),
"seed": PREVIEW_SEED, # Deterministic preview
}
config = AugmentationConfig(**config_kwargs)
pipeline = AugmentationPipeline(config)
# Apply augmentation
result = pipeline.apply(image)
# Convert to base64 URLs
original_url = self._image_to_data_url(image)
preview_url = self._image_to_data_url(result.image)
return AugmentationPreviewResponse(
preview_url=preview_url,
original_url=original_url,
applied_params=params,
)
async def preview_config(
self,
document_id: str,
page: int,
config: AugmentationConfigSchema,
) -> AugmentationPreviewResponse:
"""
Preview full augmentation config on a document page.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
config: Full augmentation configuration.
Returns:
Preview response with image URLs.
"""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
# Load image
image = await self._load_document_page(document_id, page)
# Convert Pydantic model to internal config
config_dict = config.model_dump()
internal_config = AugmentationConfig.from_dict(config_dict)
pipeline = AugmentationPipeline(internal_config)
# Apply augmentation
result = pipeline.apply(image)
# Convert to base64 URLs
original_url = self._image_to_data_url(image)
preview_url = self._image_to_data_url(result.image)
return AugmentationPreviewResponse(
preview_url=preview_url,
original_url=original_url,
applied_params=config_dict,
)
async def create_augmented_dataset(
self,
source_dataset_id: str,
config: AugmentationConfigSchema,
output_name: str,
multiplier: int,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.
Args:
source_dataset_id: Source dataset UUID.
config: Augmentation configuration.
output_name: Name for the new dataset.
multiplier: Number of augmented copies per image.
Returns:
Batch response with task ID.
Raises:
HTTPException: If source dataset not found.
"""
# Validate source dataset exists
try:
source_dataset = self.dataset_repo.get(source_dataset_id)
if source_dataset is None:
raise HTTPException(
status_code=404,
detail=f"Source dataset not found: {source_dataset_id}",
)
except Exception as e:
raise HTTPException(
status_code=404,
detail=f"Source dataset not found: {source_dataset_id}",
) from e
# Create task ID for background processing
task_id = str(uuid.uuid4())
# Estimate total images
estimated_images = (
source_dataset.total_images * multiplier
if hasattr(source_dataset, "total_images")
else 0
)
# TODO: Queue background task for actual augmentation
# For now, return pending status
return AugmentationBatchResponse(
task_id=task_id,
status="pending",
message=f"Augmentation task queued for dataset '{output_name}'",
estimated_images=estimated_images,
)
async def list_augmented_datasets(
self,
limit: int = 20,
offset: int = 0,
) -> AugmentedDatasetListResponse:
"""
List augmented datasets.
Args:
limit: Maximum number of datasets to return.
offset: Number of datasets to skip.
Returns:
List response with datasets.
"""
# TODO: Implement actual database query for augmented datasets
# For now, return empty list
return AugmentedDatasetListResponse(
total=0,
limit=limit,
offset=offset,
datasets=[],
)
async def _load_document_page(
self,
document_id: str,
page: int,
) -> np.ndarray:
"""
Load a document page as numpy array.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
Returns:
Image as numpy array (H, W, C) with dtype uint8.
Raises:
HTTPException: If document or page not found.
"""
# Validate document_id format to prevent path traversal
self._validate_uuid(document_id, "document_id")
# Get document from database
try:
document = self.doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail=f"Document not found: {document_id}",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=404,
detail=f"Document not found: {document_id}",
) from e
# Get image path for page
if hasattr(document, "images_dir"):
images_dir = Path(document.images_dir)
else:
# Fallback to constructed path
from backend.web.core.config import get_settings
settings = get_settings()
images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"
# Find image for page
page_idx = page - 1 # Convert to 0-indexed
image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
if page_idx >= len(image_files):
raise HTTPException(
status_code=404,
detail=f"Page {page} not found for document {document_id}",
)
# Load image
image_path = image_files[page_idx]
pil_image = Image.open(image_path).convert("RGB")
return np.array(pil_image)
def _image_to_data_url(self, image: np.ndarray) -> str:
"""Convert numpy image to base64 data URL."""
pil_image = Image.fromarray(image)
# Resize for preview if too large
max_size = PREVIEW_MAX_SIZE
if max(pil_image.size) > max_size:
ratio = max_size / max(pil_image.size)
new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# Convert to base64
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
return f"data:image/png;base64,{base64_data}"

View File

@@ -0,0 +1,343 @@
"""
Admin Auto-Labeling Service
Uses FieldMatcher to automatically create annotations from field values.
"""
import logging
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from shared.config import DEFAULT_DPI
from backend.data.repositories import DocumentRepository, AnnotationRepository
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken
logger = logging.getLogger(__name__)
class AutoLabelService:
"""Service for automatic document labeling using field matching."""
def __init__(self, ocr_engine: OCREngine | None = None):
"""
Initialize auto-label service.
Args:
ocr_engine: OCR engine instance (creates one if not provided)
"""
self._ocr_engine = ocr_engine
self._field_matcher = FieldMatcher()
@property
def ocr_engine(self) -> OCREngine:
"""Lazy initialization of OCR engine."""
if self._ocr_engine is None:
self._ocr_engine = OCREngine(lang="en")
return self._ocr_engine
def auto_label_document(
self,
document_id: str,
file_path: str,
field_values: dict[str, str],
doc_repo: DocumentRepository | None = None,
ann_repo: AnnotationRepository | None = None,
replace_existing: bool = False,
skip_lock_check: bool = False,
) -> dict[str, Any]:
"""
Auto-label a document using field matching.
Args:
document_id: Document UUID
file_path: Path to document file
field_values: Dict of field_name -> value to match
doc_repo: Document repository (created if None)
ann_repo: Annotation repository (created if None)
replace_existing: Whether to replace existing auto annotations
skip_lock_check: Skip annotation lock check (for batch processing)
Returns:
Dict with status and annotation count
"""
# Initialize repositories if not provided
if doc_repo is None:
doc_repo = DocumentRepository()
if ann_repo is None:
ann_repo = AnnotationRepository()
try:
# Get document info first
document = doc_repo.get(document_id)
if document is None:
raise ValueError(f"Document not found: {document_id}")
# Check annotation lock unless explicitly skipped
if not skip_lock_check:
from datetime import datetime, timezone
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
if document.annotation_lock_until > datetime.now(timezone.utc):
raise ValueError(
f"Document is locked for annotation until {document.annotation_lock_until}. "
"Auto-labeling skipped."
)
# Update status to running
doc_repo.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
# Delete existing auto annotations if requested
if replace_existing:
deleted = ann_repo.delete_for_document(
document_id=document_id,
source="auto",
)
logger.info(f"Deleted {deleted} existing auto annotations")
# Process document
path = Path(file_path)
annotations_created = 0
if path.suffix.lower() == ".pdf":
# Process PDF (all pages)
annotations_created = self._process_pdf(
document_id, path, field_values, ann_repo
)
else:
# Process single image
annotations_created = self._process_image(
document_id, path, field_values, ann_repo, page_number=1
)
# Update document status
status = "labeled" if annotations_created > 0 else "pending"
doc_repo.update_status(
document_id=document_id,
status=status,
auto_label_status="completed",
)
return {
"status": "completed",
"annotations_created": annotations_created,
}
except Exception as e:
logger.error(f"Auto-labeling failed for {document_id}: {e}")
doc_repo.update_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error=str(e),
)
return {
"status": "failed",
"error": str(e),
"annotations_created": 0,
}
def _process_pdf(
self,
document_id: str,
pdf_path: Path,
field_values: dict[str, str],
ann_repo: AnnotationRepository,
) -> int:
"""Process PDF document and create annotations."""
from shared.pdf.renderer import render_pdf_to_images
import io
total_annotations = 0
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
# Convert to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Extract tokens
tokens = self.ocr_engine.extract_from_image(
image_array,
page_no=page_no,
)
# Find matches
annotations = self._find_annotations(
document_id,
tokens,
field_values,
page_number=page_no + 1, # 1-indexed
image_width=image_array.shape[1],
image_height=image_array.shape[0],
)
# Save annotations
if annotations:
ann_repo.create_batch(annotations)
total_annotations += len(annotations)
return total_annotations
def _process_image(
self,
document_id: str,
image_path: Path,
field_values: dict[str, str],
ann_repo: AnnotationRepository,
page_number: int = 1,
) -> int:
"""Process single image and create annotations."""
# Load image
image = Image.open(image_path)
image_array = np.array(image)
# Extract tokens
tokens = self.ocr_engine.extract_from_image(
image_array,
page_no=0,
)
# Find matches
annotations = self._find_annotations(
document_id,
tokens,
field_values,
page_number=page_number,
image_width=image_array.shape[1],
image_height=image_array.shape[0],
)
# Save annotations
if annotations:
ann_repo.create_batch(annotations)
return len(annotations)
def _find_annotations(
self,
document_id: str,
tokens: list[OCRToken],
field_values: dict[str, str],
page_number: int,
image_width: int,
image_height: int,
) -> list[dict[str, Any]]:
"""Find annotations for field values using token matching."""
from shared.normalize import normalize_field
annotations = []
for field_name, value in field_values.items():
if not value or not value.strip():
continue
# Map field name to class ID
class_id = self._get_class_id(field_name)
if class_id is None:
logger.warning(f"Unknown field name: {field_name}")
continue
class_name = FIELD_CLASSES[class_id]
# Normalize value
try:
normalized_values = normalize_field(field_name, value)
except Exception as e:
logger.warning(f"Failed to normalize {field_name}={value}: {e}")
normalized_values = [value]
# Find matches
matches = self._field_matcher.find_matches(
tokens=tokens,
field_name=field_name,
normalized_values=normalized_values,
page_no=page_number - 1, # 0-indexed for matcher
)
# Take best match
if matches:
best_match = matches[0]
bbox = best_match.bbox # (x0, y0, x1, y1)
# Calculate normalized coordinates (YOLO format)
x_center = (bbox[0] + bbox[2]) / 2 / image_width
y_center = (bbox[1] + bbox[3]) / 2 / image_height
width = (bbox[2] - bbox[0]) / image_width
height = (bbox[3] - bbox[1]) / image_height
# Pixel coordinates
bbox_x = int(bbox[0])
bbox_y = int(bbox[1])
bbox_width = int(bbox[2] - bbox[0])
bbox_height = int(bbox[3] - bbox[1])
annotations.append({
"document_id": document_id,
"page_number": page_number,
"class_id": class_id,
"class_name": class_name,
"x_center": x_center,
"y_center": y_center,
"width": width,
"height": height,
"bbox_x": bbox_x,
"bbox_y": bbox_y,
"bbox_width": bbox_width,
"bbox_height": bbox_height,
"text_value": best_match.matched_text,
"confidence": best_match.score,
"source": "auto",
})
return annotations
def _get_class_id(self, field_name: str) -> int | None:
"""Map field name to class ID."""
# Direct match
if field_name in FIELD_CLASS_IDS:
return FIELD_CLASS_IDS[field_name]
# Handle alternative names
name_mapping = {
"InvoiceNumber": "invoice_number",
"InvoiceDate": "invoice_date",
"InvoiceDueDate": "invoice_due_date",
"OCR": "ocr_number",
"Bankgiro": "bankgiro",
"Plusgiro": "plusgiro",
"Amount": "amount",
"supplier_organisation_number": "supplier_organisation_number",
"PaymentLine": "payment_line",
"customer_number": "customer_number",
}
mapped_name = name_mapping.get(field_name)
if mapped_name and mapped_name in FIELD_CLASS_IDS:
return FIELD_CLASS_IDS[mapped_name]
return None
# Global service instance
_auto_label_service: AutoLabelService | None = None
def get_auto_label_service() -> AutoLabelService:
"""Get the auto-label service instance."""
global _auto_label_service
if _auto_label_service is None:
_auto_label_service = AutoLabelService()
return _auto_label_service
def reset_auto_label_service() -> None:
"""Reset the auto-label service (for testing)."""
global _auto_label_service
_auto_label_service = None
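
if __name__ == "__main__":
    # Standalone sketch of the bbox conversion performed in _find_annotations:
    # an absolute pixel box (x0, y0, x1, y1) becomes the normalized YOLO tuple
    # (x_center, y_center, width, height) relative to the page dimensions.
    def to_yolo(bbox: tuple[float, float, float, float], img_w: int, img_h: int):
        x0, y0, x1, y1 = bbox
        return (
            (x0 + x1) / 2 / img_w,
            (y0 + y1) / 2 / img_h,
            (x1 - x0) / img_w,
            (y1 - y0) / img_h,
        )

    print(to_yolo((100, 200, 300, 250), img_w=1000, img_h=1400))
    # -> (0.2, ~0.161, 0.2, ~0.036)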

View File

@@ -0,0 +1,548 @@
"""
Batch Upload Service
Handles ZIP file uploads with multiple PDFs and optional CSV for auto-labeling.
"""
import csv
import io
import logging
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from backend.data.repositories import BatchUploadRepository
from shared.fields import CSV_TO_CLASS_MAPPING
logger = logging.getLogger(__name__)
# Security limits
MAX_COMPRESSED_SIZE = 100 * 1024 * 1024 # 100 MB
MAX_UNCOMPRESSED_SIZE = 200 * 1024 * 1024 # 200 MB
MAX_INDIVIDUAL_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
MAX_FILES_IN_ZIP = 1000
class CSVRowData(BaseModel):
"""Validated CSV row data with security checks."""
document_id: str = Field(..., min_length=1, max_length=255, pattern=r'^[a-zA-Z0-9\-_\.]+$')
invoice_number: str | None = Field(None, max_length=255)
invoice_date: str | None = Field(None, max_length=50)
invoice_due_date: str | None = Field(None, max_length=50)
amount: str | None = Field(None, max_length=100)
ocr: str | None = Field(None, max_length=100)
bankgiro: str | None = Field(None, max_length=50)
plusgiro: str | None = Field(None, max_length=50)
customer_number: str | None = Field(None, max_length=255)
supplier_organisation_number: str | None = Field(None, max_length=50)
@field_validator('*', mode='before')
@classmethod
def strip_whitespace(cls, v):
"""Strip whitespace from all string fields."""
if isinstance(v, str):
return v.strip()
return v
@field_validator('*', mode='before')
@classmethod
def reject_suspicious_patterns(cls, v):
"""Reject values with suspicious characters."""
if isinstance(v, str):
# Reject SQL/shell metacharacters and newlines
dangerous_chars = [';', '|', '&', '`', '$', '\n', '\r', '\x00']
if any(char in v for char in dangerous_chars):
raise ValueError(f"Suspicious characters detected in value")
return v
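
# Illustrative behaviour of the validators above (values are made up): any cell
# containing shell/SQL metacharacters causes the whole row to fail validation.
#
#   CSVRowData(document_id="inv-001", amount="1 024,50")          # accepted
#   CSVRowData(document_id="inv-001", amount="100; DROP TABLE")   # ValidationError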
class BatchUploadService:
"""Service for handling batch uploads of documents via ZIP files."""
def __init__(self, batch_repo: BatchUploadRepository | None = None):
"""Initialize the batch upload service.
Args:
batch_repo: Batch upload repository (created if None)
"""
self.batch_repo = batch_repo or BatchUploadRepository()
def _safe_extract_filename(self, zip_path: str) -> str:
"""Safely extract filename from ZIP path, preventing path traversal.
Args:
zip_path: Path from ZIP file entry
Returns:
Safe filename
Raises:
ValueError: If path contains traversal attempts or is invalid
"""
# Reject absolute paths
if zip_path.startswith('/') or zip_path.startswith('\\'):
raise ValueError(f"Absolute path rejected: {zip_path}")
# Reject path traversal attempts
if '..' in zip_path:
raise ValueError(f"Path traversal rejected: {zip_path}")
# Reject Windows drive letters
if len(zip_path) >= 2 and zip_path[1] == ':':
raise ValueError(f"Windows path rejected: {zip_path}")
# Get only the basename
safe_name = Path(zip_path).name
if not safe_name or safe_name in ['.', '..']:
raise ValueError(f"Invalid filename: {zip_path}")
# Validate filename doesn't contain suspicious characters
if any(char in safe_name for char in ['\\', '/', '\x00', '\n', '\r']):
raise ValueError(f"Invalid characters in filename: {safe_name}")
return safe_name
def _validate_zip_safety(self, zip_file: zipfile.ZipFile) -> None:
"""Validate ZIP file against Zip bomb and other attacks.
Args:
zip_file: Opened ZIP file
Raises:
ValueError: If ZIP file is unsafe
"""
total_uncompressed = 0
file_count = 0
for zip_info in zip_file.infolist():
file_count += 1
# Check file count limit
if file_count > MAX_FILES_IN_ZIP:
raise ValueError(
f"ZIP contains too many files (max {MAX_FILES_IN_ZIP})"
)
# Check individual file size
if zip_info.file_size > MAX_INDIVIDUAL_FILE_SIZE:
max_mb = MAX_INDIVIDUAL_FILE_SIZE / (1024 * 1024)
raise ValueError(
f"File '{zip_info.filename}' exceeds {max_mb:.0f}MB limit"
)
# Accumulate uncompressed size
total_uncompressed += zip_info.file_size
# Check total uncompressed size (Zip bomb protection)
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
max_mb = MAX_UNCOMPRESSED_SIZE / (1024 * 1024)
raise ValueError(
f"Total uncompressed size exceeds {max_mb:.0f}MB limit"
)
# Validate filename safety
try:
self._safe_extract_filename(zip_info.filename)
except ValueError as e:
logger.warning(f"Rejecting malicious ZIP entry: {e}")
raise ValueError(f"Invalid file in ZIP: {zip_info.filename}")
def process_zip_upload(
self,
admin_token: str,
zip_filename: str,
zip_content: bytes,
upload_source: str = "ui",
) -> dict[str, Any]:
"""Process a ZIP file containing PDFs and optional CSV.
Args:
admin_token: Admin authentication token
zip_filename: Name of the ZIP file
zip_content: ZIP file content as bytes
upload_source: Upload source (ui or api)
Returns:
Dictionary with batch upload results
"""
batch = self.batch_repo.create(
admin_token=admin_token,
filename=zip_filename,
file_size=len(zip_content),
upload_source=upload_source,
)
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
# Validate ZIP safety first
self._validate_zip_safety(zip_file)
result = self._process_zip_contents(
batch_id=batch.batch_id,
admin_token=admin_token,
zip_file=zip_file,
)
# Update batch upload status
self.batch_repo.update(
batch_id=batch.batch_id,
status=result["status"],
total_files=result["total_files"],
processed_files=result["processed_files"],
successful_files=result["successful_files"],
failed_files=result["failed_files"],
csv_filename=result.get("csv_filename"),
csv_row_count=result.get("csv_row_count"),
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
**result,
}
except zipfile.BadZipFile as e:
logger.error(f"Invalid ZIP file {zip_filename}: {e}")
self.batch_repo.update(
batch_id=batch.batch_id,
status="failed",
error_message="Invalid ZIP file format",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "Invalid ZIP file format",
}
except ValueError as e:
# Security validation errors
logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
self.batch_repo.update(
batch_id=batch.batch_id,
status="failed",
error_message="ZIP file validation failed",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "ZIP file validation failed",
}
except Exception as e:
logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
self.batch_repo.update(
batch_id=batch.batch_id,
status="failed",
error_message="Processing error",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "Failed to process batch upload",
}
def _process_zip_contents(
self,
batch_id: UUID,
admin_token: str,
zip_file: zipfile.ZipFile,
) -> dict[str, Any]:
"""Process contents of ZIP file.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
zip_file: Opened ZIP file
Returns:
Processing results dictionary
"""
# Extract file lists
pdf_files = []
csv_file = None
csv_data = {}
for file_info in zip_file.filelist:
if file_info.is_dir():
continue
try:
# Use safe filename extraction
filename = self._safe_extract_filename(file_info.filename)
except ValueError as e:
logger.warning(f"Skipping invalid file: {e}")
continue
if filename.lower().endswith('.pdf'):
pdf_files.append(file_info)
elif filename.lower().endswith('.csv'):
if csv_file is None:
csv_file = file_info
# Parse CSV
csv_data = self._parse_csv_file(zip_file, file_info)
else:
logger.warning(f"Multiple CSV files found, using first: {csv_file.filename}")
if not pdf_files:
return {
"status": "failed",
"total_files": 0,
"processed_files": 0,
"successful_files": 0,
"failed_files": 0,
"error": "No PDF files found in ZIP",
}
# Process each PDF file
total_files = len(pdf_files)
successful_files = 0
failed_files = 0
for pdf_info in pdf_files:
file_record = None
try:
# Use safe filename extraction
filename = self._safe_extract_filename(pdf_info.filename)
# Create batch upload file record
file_record = self.batch_repo.create_file(
batch_id=batch_id,
filename=filename,
status="processing",
)
# Get CSV data for this file if available
document_id_base = Path(filename).stem
csv_row_data = csv_data.get(document_id_base)
# Extract PDF content
pdf_content = zip_file.read(pdf_info.filename)
# TODO: Save PDF file and create document
# For now, just mark as completed
self.batch_repo.update_file(
file_id=file_record.file_id,
status="completed",
csv_row_data=csv_row_data,
processed_at=datetime.utcnow(),
)
successful_files += 1
except ValueError as e:
# Path validation error
logger.warning(f"Skipping invalid file: {e}")
if file_record:
self.batch_repo.update_file(
file_id=file_record.file_id,
status="failed",
error_message="Invalid filename",
processed_at=datetime.utcnow(),
)
failed_files += 1
except Exception as e:
logger.error(f"Error processing PDF: {e}", exc_info=True)
if file_record:
self.batch_repo.update_file(
file_id=file_record.file_id,
status="failed",
error_message="Processing error",
processed_at=datetime.utcnow(),
)
failed_files += 1
# Determine overall status
if failed_files == 0:
status = "completed"
elif successful_files == 0:
status = "failed"
else:
status = "partial"
result = {
"status": status,
"total_files": total_files,
"processed_files": total_files,
"successful_files": successful_files,
"failed_files": failed_files,
}
if csv_file:
result["csv_filename"] = Path(csv_file.filename).name
result["csv_row_count"] = len(csv_data)
return result
def _parse_csv_file(
self,
zip_file: zipfile.ZipFile,
csv_file_info: zipfile.ZipInfo,
) -> dict[str, dict[str, Any]]:
"""Parse CSV file and extract field values with validation.
Args:
zip_file: Opened ZIP file
csv_file_info: CSV file info
Returns:
Dictionary mapping DocumentId to validated field values
"""
# Try multiple encodings
csv_bytes = zip_file.read(csv_file_info.filename)
        encodings = ['utf-8-sig', 'utf-8', 'cp1252', 'latin-1']  # latin-1 last: it decodes any bytes, so it is the final fallback
csv_content = None
for encoding in encodings:
try:
csv_content = csv_bytes.decode(encoding)
logger.info(f"CSV decoded with {encoding}")
break
except UnicodeDecodeError:
continue
if csv_content is None:
logger.error("Failed to decode CSV with any encoding")
raise ValueError("Unable to decode CSV file")
csv_reader = csv.DictReader(io.StringIO(csv_content))
result = {}
# Case-insensitive column mapping
field_name_map = {
'DocumentId': ['DocumentId', 'documentid', 'document_id'],
'InvoiceNumber': ['InvoiceNumber', 'invoicenumber', 'invoice_number'],
'InvoiceDate': ['InvoiceDate', 'invoicedate', 'invoice_date'],
'InvoiceDueDate': ['InvoiceDueDate', 'invoiceduedate', 'invoice_due_date'],
'Amount': ['Amount', 'amount'],
'OCR': ['OCR', 'ocr'],
'Bankgiro': ['Bankgiro', 'bankgiro'],
'Plusgiro': ['Plusgiro', 'plusgiro'],
'customer_number': ['customer_number', 'customernumber', 'CustomerNumber'],
'supplier_organisation_number': ['supplier_organisation_number', 'supplierorganisationnumber'],
}
for row_num, row in enumerate(csv_reader, start=2):
try:
# Create case-insensitive lookup
row_lower = {k.lower(): v for k, v in row.items()}
# Find DocumentId with case-insensitive matching
document_id = None
for variant in field_name_map['DocumentId']:
if variant.lower() in row_lower:
document_id = row_lower[variant.lower()]
break
if not document_id:
logger.warning(f"Row {row_num}: No DocumentId found")
continue
# Validate using Pydantic model
csv_row_dict = {'document_id': document_id}
# Map CSV field names to model attribute names
csv_to_model_attr = {
'InvoiceNumber': 'invoice_number',
'InvoiceDate': 'invoice_date',
'InvoiceDueDate': 'invoice_due_date',
'Amount': 'amount',
'OCR': 'ocr',
'Bankgiro': 'bankgiro',
'Plusgiro': 'plusgiro',
'customer_number': 'customer_number',
'supplier_organisation_number': 'supplier_organisation_number',
}
for csv_field in field_name_map.keys():
if csv_field == 'DocumentId':
continue
model_attr = csv_to_model_attr.get(csv_field)
if not model_attr:
continue
for variant in field_name_map[csv_field]:
if variant.lower() in row_lower and row_lower[variant.lower()]:
csv_row_dict[model_attr] = row_lower[variant.lower()]
break
# Validate
validated_row = CSVRowData(**csv_row_dict)
# Extract only the fields we care about (map back to CSV field names)
field_values = {}
model_attr_to_csv = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'amount': 'Amount',
'ocr': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'customer_number': 'customer_number',
'supplier_organisation_number': 'supplier_organisation_number',
}
for model_attr, csv_field in model_attr_to_csv.items():
value = getattr(validated_row, model_attr, None)
if value and csv_field in CSV_TO_CLASS_MAPPING:
field_values[csv_field] = value
if field_values:
result[document_id] = field_values
except Exception as e:
logger.warning(f"Row {row_num}: Validation error - {e}")
continue
return result
def get_batch_status(self, batch_id: str) -> dict[str, Any]:
"""Get batch upload status.
Args:
batch_id: Batch upload ID
Returns:
Batch status dictionary
"""
        try:
            batch = self.batch_repo.get(UUID(batch_id))
        except ValueError:
            return {
                "error": "Invalid batch ID format",
            }
if not batch:
return {
"error": "Batch upload not found",
}
files = self.batch_repo.get_files(batch.batch_id)
return {
"batch_id": str(batch.batch_id),
"filename": batch.filename,
"status": batch.status,
"total_files": batch.total_files,
"processed_files": batch.processed_files,
"successful_files": batch.successful_files,
"failed_files": batch.failed_files,
"csv_filename": batch.csv_filename,
"csv_row_count": batch.csv_row_count,
"error_message": batch.error_message,
"created_at": batch.created_at.isoformat() if batch.created_at else None,
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
"files": [
{
"filename": f.filename,
"status": f.status,
"error_message": f.error_message,
"annotation_count": f.annotation_count,
}
for f in files
],
}
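# Minimal usage sketch (illustrative, not part of this module): drive a ZIP
# upload end-to-end and poll its status. The token and file name below are
# hypothetical placeholders; a configured admin database is assumed.
if __name__ == "__main__":
    import sys

    zip_path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("invoices.zip")
    service = BatchUploadService()
    outcome = service.process_zip_upload(
        admin_token="example-admin-token",  # hypothetical token
        zip_filename=zip_path.name,
        zip_content=zip_path.read_bytes(),
    )
    print(outcome["status"])
    print(service.get_batch_status(outcome["batch_id"]))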

View File

@@ -0,0 +1,276 @@
"""
Dashboard Service
Business logic for dashboard statistics and activity aggregation.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import func, exists, and_, or_
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import (
AdminDocument,
AdminAnnotation,
AnnotationHistory,
TrainingTask,
ModelVersion,
)
logger = logging.getLogger(__name__)
# Field class IDs for completeness calculation
# Identifiers: invoice_number (0) or ocr_number (3)
IDENTIFIER_CLASS_IDS = {0, 3}
# Payment accounts: bankgiro (4) or plusgiro (5)
PAYMENT_CLASS_IDS = {4, 5}
def is_annotation_complete(annotations: list[dict[str, Any]]) -> bool:
"""Check if a document's annotations are complete.
A document is complete if it has:
- At least one identifier field (invoice_number OR ocr_number)
- At least one payment field (bankgiro OR plusgiro)
Args:
annotations: List of annotation dicts with class_id
Returns:
True if document has required fields
"""
class_ids = {ann.get("class_id") for ann in annotations}
has_identifier = bool(class_ids & IDENTIFIER_CLASS_IDS)
has_payment = bool(class_ids & PAYMENT_CLASS_IDS)
return has_identifier and has_payment
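# Illustrative examples (not in the original source), using the class IDs
# defined above:
#
#     is_annotation_complete([{"class_id": 0}, {"class_id": 4}])  # True: invoice_number + bankgiro
#     is_annotation_complete([{"class_id": 0}])                   # False: identifier but no payment account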
class DashboardStatsService:
"""Service for computing dashboard statistics."""
def get_stats(self) -> dict[str, Any]:
"""Get dashboard statistics.
Returns:
Dict with total_documents, annotation_complete, annotation_incomplete,
pending, and completeness_rate
"""
with get_session_context() as session:
# Total documents
total = session.exec(
select(func.count()).select_from(AdminDocument)
).one()
# Pending documents (status in ['pending', 'auto_labeling'])
pending = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(AdminDocument.status.in_(["pending", "auto_labeling"]))
).one()
# Complete annotations: labeled + has identifier + has payment
complete = self._count_complete(session)
# Incomplete: labeled but not complete
labeled_count = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(AdminDocument.status == "labeled")
).one()
incomplete = labeled_count - complete
# Calculate completeness rate
total_assessed = complete + incomplete
completeness_rate = (
round(complete / total_assessed * 100, 2)
if total_assessed > 0
else 0.0
)
return {
"total_documents": total,
"annotation_complete": complete,
"annotation_incomplete": incomplete,
"pending": pending,
"completeness_rate": completeness_rate,
}
def _count_complete(self, session) -> int:
"""Count documents with complete annotations.
A document is complete if it:
1. Has status = 'labeled'
2. Has at least one identifier annotation (class_id 0 or 3)
3. Has at least one payment annotation (class_id 4 or 5)
"""
# Subquery for documents with identifier
has_identifier = exists(
select(1)
.select_from(AdminAnnotation)
.where(
and_(
AdminAnnotation.document_id == AdminDocument.document_id,
AdminAnnotation.class_id.in_(IDENTIFIER_CLASS_IDS),
)
)
)
# Subquery for documents with payment
has_payment = exists(
select(1)
.select_from(AdminAnnotation)
.where(
and_(
AdminAnnotation.document_id == AdminDocument.document_id,
AdminAnnotation.class_id.in_(PAYMENT_CLASS_IDS),
)
)
)
count = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(
and_(
AdminDocument.status == "labeled",
has_identifier,
has_payment,
)
)
).one()
return count
class DashboardActivityService:
"""Service for aggregating recent activities."""
def get_recent_activities(self, limit: int = 10) -> list[dict[str, Any]]:
"""Get recent system activities.
Aggregates from:
- Document uploads
- Annotation modifications
- Training completions/failures
- Model activations
Args:
limit: Maximum number of activities to return
Returns:
List of activity dicts sorted by timestamp DESC
"""
activities = []
with get_session_context() as session:
# Document uploads (recent 10)
uploads = session.exec(
select(AdminDocument)
.order_by(AdminDocument.created_at.desc())
.limit(limit)
).all()
for doc in uploads:
activities.append({
"type": "document_uploaded",
"description": f"Uploaded {doc.filename}",
"timestamp": doc.created_at,
"metadata": {
"document_id": str(doc.document_id),
"filename": doc.filename,
},
})
# Annotation modifications (from history)
modifications = session.exec(
select(AnnotationHistory)
.where(AnnotationHistory.action == "override")
.order_by(AnnotationHistory.created_at.desc())
.limit(limit)
).all()
for mod in modifications:
# Get document filename
doc = session.get(AdminDocument, mod.document_id)
filename = doc.filename if doc else "Unknown"
field_name = ""
if mod.new_value and isinstance(mod.new_value, dict):
field_name = mod.new_value.get("class_name", "")
activities.append({
"type": "annotation_modified",
"description": f"Modified {filename} {field_name}".strip(),
"timestamp": mod.created_at,
"metadata": {
"annotation_id": str(mod.annotation_id),
"document_id": str(mod.document_id),
"field_name": field_name,
},
})
# Training completions and failures
training_tasks = session.exec(
select(TrainingTask)
.where(TrainingTask.status.in_(["completed", "failed"]))
.order_by(TrainingTask.updated_at.desc())
.limit(limit)
).all()
for task in training_tasks:
if task.updated_at is None:
continue
if task.status == "completed":
# Use metrics_mAP field directly
mAP = task.metrics_mAP or 0.0
activities.append({
"type": "training_completed",
"description": f"Training complete: {task.name}, mAP {mAP:.1%}",
"timestamp": task.updated_at,
"metadata": {
"task_id": str(task.task_id),
"task_name": task.name,
"mAP": mAP,
},
})
else:
activities.append({
"type": "training_failed",
"description": f"Training failed: {task.name}",
"timestamp": task.updated_at,
"metadata": {
"task_id": str(task.task_id),
"task_name": task.name,
"error": task.error_message or "",
},
})
# Model activations
model_versions = session.exec(
select(ModelVersion)
.where(ModelVersion.activated_at.is_not(None))
.order_by(ModelVersion.activated_at.desc())
.limit(limit)
).all()
for model in model_versions:
if model.activated_at is None:
continue
activities.append({
"type": "model_activated",
"description": f"Activated model {model.version}",
"timestamp": model.activated_at,
"metadata": {
"version_id": str(model.version_id),
"version": model.version,
},
})
# Sort all activities by timestamp DESC and return top N
activities.sort(key=lambda x: x["timestamp"], reverse=True)
return activities[:limit]
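# Usage sketch (illustrative only): both services are stateless, so the API
# layer can instantiate them per request. Assumes the session factory in
# backend.data.database is configured.
if __name__ == "__main__":
    stats = DashboardStatsService().get_stats()
    print(f"Completeness: {stats['completeness_rate']}% of {stats['total_documents']} documents")
    for activity in DashboardActivityService().get_recent_activities(limit=5):
        print(activity["timestamp"], activity["type"], activity["description"])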

View File

@@ -0,0 +1,265 @@
"""
Dataset Builder Service
Creates training datasets by copying images from admin storage,
generating YOLO label files, and splitting into train/val/test sets.
"""
import logging
import random
import shutil
from pathlib import Path
import yaml
from shared.fields import FIELD_CLASSES
logger = logging.getLogger(__name__)
class DatasetBuilder:
"""Builds YOLO training datasets from admin documents."""
def __init__(
self,
datasets_repo,
documents_repo,
annotations_repo,
base_dir: Path,
):
self._datasets_repo = datasets_repo
self._documents_repo = documents_repo
self._annotations_repo = annotations_repo
self._base_dir = Path(base_dir)
def build_dataset(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
"""Build a complete YOLO dataset from document IDs.
Args:
dataset_id: UUID of the dataset record.
document_ids: List of document UUIDs to include.
train_ratio: Fraction for training set.
val_ratio: Fraction for validation set.
seed: Random seed for reproducible splits.
admin_images_dir: Root directory of admin images.
Returns:
Summary dict with total_documents, total_images, total_annotations.
Raises:
ValueError: If no valid documents found.
"""
try:
return self._do_build(
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
)
except Exception as e:
self._datasets_repo.update_status(
dataset_id=dataset_id,
status="failed",
error_message=str(e),
)
raise
def _do_build(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
# 1. Fetch documents
documents = self._documents_repo.get_by_ids(document_ids)
if not documents:
raise ValueError("No valid documents found for the given IDs")
# 2. Create directory structure
dataset_dir = self._base_dir / dataset_id
for split in ["train", "val", "test"]:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
# 3. Group documents by group_key and assign splits
doc_list = list(documents)
doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
for doc in doc_list:
doc_id = str(doc.document_id)
split = doc_splits[doc_id]
annotations = self._annotations_repo.get_for_document(str(doc.document_id))
# Group annotations by page
page_annotations: dict[int, list] = {}
for ann in annotations:
page_annotations.setdefault(ann.page_number, []).append(ann)
doc_image_count = 0
doc_ann_count = 0
# Copy images and write labels for each page
for page_num in range(1, doc.page_count + 1):
src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
if not src_image.exists():
logger.warning("Image not found: %s", src_image)
continue
dst_name = f"{doc_id}_page{page_num}"
dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
shutil.copy2(src_image, dst_image)
doc_image_count += 1
# Write YOLO label file
page_anns = page_annotations.get(page_num, [])
label_lines = []
for ann in page_anns:
label_lines.append(
f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
f"{ann.width:.6f} {ann.height:.6f}"
)
doc_ann_count += 1
label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
label_path.write_text("\n".join(label_lines))
total_images += doc_image_count
total_annotations += doc_ann_count
dataset_docs.append({
"document_id": doc_id,
"split": split,
"page_count": doc_image_count,
"annotation_count": doc_ann_count,
})
# 5. Record document-split assignments in DB
self._datasets_repo.add_documents(
dataset_id=dataset_id,
documents=dataset_docs,
)
# 6. Generate data.yaml
self._generate_data_yaml(dataset_dir)
# 7. Update dataset status
self._datasets_repo.update_status(
dataset_id=dataset_id,
status="ready",
total_documents=len(doc_list),
total_images=total_images,
total_annotations=total_annotations,
dataset_path=str(dataset_dir),
)
return {
"total_documents": len(doc_list),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _assign_splits_by_group(
self,
documents: list,
train_ratio: float,
val_ratio: float,
seed: int,
) -> dict[str, str]:
"""Assign splits based on group_key.
Logic:
        - Documents with the same group_key stay together in the same split
        - Documents without a group_key are treated as single-document groups
        - All groups are shuffled with the given seed and split by the ratios
Args:
documents: List of AdminDocument objects
train_ratio: Fraction for training set
val_ratio: Fraction for validation set
seed: Random seed for reproducibility
Returns:
Dict mapping document_id (str) -> split ("train"/"val"/"test")
"""
# Group documents by group_key
# None/empty group_key treated as unique (each doc is its own group)
groups: dict[str | None, list] = {}
for doc in documents:
key = doc.group_key if doc.group_key else None
if key is None:
# Treat each ungrouped doc as its own unique group
# Use document_id as pseudo-key
key = f"__ungrouped_{doc.document_id}"
groups.setdefault(key, []).append(doc)
# Separate single-doc groups from multi-doc groups
single_doc_groups: list[tuple[str | None, list]] = []
multi_doc_groups: list[tuple[str | None, list]] = []
for key, docs in groups.items():
if len(docs) == 1:
single_doc_groups.append((key, docs))
else:
multi_doc_groups.append((key, docs))
# Initialize result mapping
doc_splits: dict[str, str] = {}
# Combine all groups for splitting
all_groups = single_doc_groups + multi_doc_groups
# Shuffle all groups and assign splits
if all_groups:
rng = random.Random(seed)
rng.shuffle(all_groups)
n_groups = len(all_groups)
n_train = max(1, round(n_groups * train_ratio))
# Ensure at least 1 in val if we have more than 1 group
n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
for i, (_key, docs) in enumerate(all_groups):
if i < n_train:
split = "train"
elif i < n_train + n_val:
split = "val"
else:
split = "test"
for doc in docs:
doc_splits[str(doc.document_id)] = split
logger.info(
"Split assignment: %d total groups shuffled (train=%d, val=%d)",
len(all_groups),
sum(1 for s in doc_splits.values() if s == "train"),
sum(1 for s in doc_splits.values() if s == "val"),
)
return doc_splits
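    # Worked example (illustrative): with 10 groups, train_ratio=0.7 and
    # val_ratio=0.2, the shuffled groups get 7 assigned to train
    # (round(10 * 0.7)), 2 to val (round(10 * 0.2)) and the remaining 1 to
    # test; every document sharing a group_key lands in the same split.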
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {
"path": str(dataset_dir.absolute()),
"train": "images/train",
"val": "images/val",
"test": "images/test",
"nc": len(FIELD_CLASSES),
"names": FIELD_CLASSES,
}
yaml_path = dataset_dir / "data.yaml"
yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
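# Illustrative wiring sketch (not part of this module). DocumentRepository and
# AnnotationRepository exist in backend.data.repositories; `DatasetRepository`
# and the directory paths below are hypothetical placeholders.
#
#     from backend.data.repositories import DocumentRepository, AnnotationRepository
#
#     builder = DatasetBuilder(
#         datasets_repo=DatasetRepository(),       # hypothetical repository
#         documents_repo=DocumentRepository(),
#         annotations_repo=AnnotationRepository(),
#         base_dir=Path("data/datasets"),
#     )
#     summary = builder.build_dataset(
#         dataset_id="<dataset-uuid>",
#         document_ids=["<doc-uuid-1>", "<doc-uuid-2>"],
#         train_ratio=0.7,
#         val_ratio=0.2,
#         seed=42,
#         admin_images_dir=Path("data/admin_images"),
#     )
#
# The generated data.yaml then points YOLO at images/train, images/val and
# images/test, with nc and names taken from the shared FIELD_CLASSES list.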

View File

@@ -0,0 +1,550 @@
"""
Database-based Auto-labeling Service
Processes documents with field values stored in the database (csv_field_values).
Used by the pre-label API to create annotations from expected values.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from shared.config import DEFAULT_DPI
from shared.fields import CSV_TO_CLASS_MAPPING
from backend.data.admin_models import AdminDocument
from backend.data.repositories import DocumentRepository, AnnotationRepository
from shared.data.db import DocumentDB
from backend.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
# Initialize DocumentDB for saving match reports
_document_db: DocumentDB | None = None
def get_document_db() -> DocumentDB:
"""Get or create DocumentDB instance with connection and tables initialized.
Follows the same pattern as CLI autolabel (src/cli/autolabel.py lines 370-373).
"""
global _document_db
if _document_db is None:
_document_db = DocumentDB()
_document_db.connect()
_document_db.create_tables() # Ensure tables exist
logger.info("Connected to PostgreSQL DocumentDB for match reports")
return _document_db
def convert_csv_field_values_to_row_dict(
document: AdminDocument,
) -> dict[str, Any]:
"""
Convert AdminDocument.csv_field_values to row_dict format for autolabel.
Args:
document: AdminDocument with csv_field_values
Returns:
Dictionary in row_dict format compatible with autolabel_tasks
"""
csv_values = document.csv_field_values or {}
# Build row_dict with DocumentId
row_dict = {
"DocumentId": str(document.document_id),
}
# Map csv_field_values to row_dict format
# csv_field_values uses keys like: InvoiceNumber, InvoiceDate, Amount, OCR, Bankgiro, etc.
# row_dict uses same keys
for key, value in csv_values.items():
if value is not None and value != "":
row_dict[key] = str(value)
return row_dict
def get_pending_autolabel_documents(
limit: int = 10,
) -> list[AdminDocument]:
"""
Get documents pending auto-labeling.
Args:
limit: Maximum number of documents to return
Returns:
List of AdminDocument records with status='auto_labeling' and auto_label_status='pending'
"""
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import AdminDocument
with get_session_context() as session:
statement = select(AdminDocument).where(
AdminDocument.status == "auto_labeling",
AdminDocument.auto_label_status == "pending",
).order_by(AdminDocument.created_at).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def process_document_autolabel(
document: AdminDocument,
output_dir: Path | None = None,
dpi: int = DEFAULT_DPI,
min_confidence: float = 0.5,
doc_repo: DocumentRepository | None = None,
ann_repo: AnnotationRepository | None = None,
) -> dict[str, Any]:
"""
Process a single document for auto-labeling using csv_field_values.
Args:
document: AdminDocument with csv_field_values and file_path
output_dir: Output directory for temp files
dpi: Rendering DPI
min_confidence: Minimum match confidence
doc_repo: Document repository (created if None)
ann_repo: Annotation repository (created if None)
Returns:
Result dictionary with success status and annotations
"""
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
from shared.pdf import PDFDocument
# Initialize repositories if not provided
if doc_repo is None:
doc_repo = DocumentRepository()
if ann_repo is None:
ann_repo = AnnotationRepository()
document_id = str(document.document_id)
file_path = Path(document.file_path)
# Get output directory from StorageHelper
storage = get_storage_helper()
if output_dir is None:
output_dir = storage.get_autolabel_output_path()
if output_dir is None:
output_dir = Path("data/autolabel_output")
output_dir.mkdir(parents=True, exist_ok=True)
# Mark as processing
doc_repo.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
try:
# Check if file exists
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# Convert csv_field_values to row_dict
row_dict = convert_csv_field_values_to_row_dict(document)
if len(row_dict) <= 1: # Only has DocumentId
raise ValueError("No field values to match")
# Determine PDF type (text or scanned)
is_scanned = False
with PDFDocument(file_path) as pdf_doc:
# Check if first page has extractable text
tokens = list(pdf_doc.extract_text_tokens(0))
is_scanned = len(tokens) < 10 # Threshold for "no text"
# Build task data
# Use raw_pdfs base path for pdf_path
# This ensures consistency with CLI autolabel for reprocess_failed.py
raw_pdfs_dir = storage.get_raw_pdfs_base_path()
if raw_pdfs_dir is None:
raise ValueError("Storage not configured for local access")
pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf"
task_data = {
"row_dict": row_dict,
"pdf_path": str(pdf_path_for_report),
"output_dir": str(output_dir),
"dpi": dpi,
"min_confidence": min_confidence,
}
# Process based on PDF type
if is_scanned:
result = process_scanned_pdf(task_data)
else:
result = process_text_pdf(task_data)
# Save report to DocumentDB (same as CLI autolabel)
if result.get("report"):
try:
doc_db = get_document_db()
doc_db.save_document(result["report"])
logger.info(f"Saved match report to DocumentDB for {document_id}")
except Exception as e:
logger.warning(f"Failed to save report to DocumentDB: {e}")
# Save annotations to database
if result.get("success") and result.get("report"):
_save_annotations_to_db(
ann_repo=ann_repo,
document_id=document_id,
report=result["report"],
page_annotations=result.get("pages", []),
dpi=dpi,
)
# Mark as completed
doc_repo.update_status(
document_id=document_id,
status="labeled",
auto_label_status="completed",
)
else:
# Mark as failed
errors = result.get("report", {}).get("errors", ["Unknown error"])
doc_repo.update_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error="; ".join(errors) if errors else "No annotations generated",
)
return result
except Exception as e:
logger.error(f"Error processing document {document_id}: {e}", exc_info=True)
# Mark as failed
doc_repo.update_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error=str(e),
)
return {
"doc_id": document_id,
"success": False,
"error": str(e),
}
def _save_annotations_to_db(
ann_repo: AnnotationRepository,
document_id: str,
report: dict[str, Any],
page_annotations: list[dict[str, Any]],
dpi: int = 200,
) -> int:
"""
Save generated annotations to database.
Args:
ann_repo: Annotation repository instance
document_id: Document ID
report: AutoLabelReport as dict
page_annotations: List of page annotation data
dpi: DPI used for rendering images (for coordinate conversion)
Returns:
Number of annotations saved
"""
from shared.fields import FIELD_CLASS_IDS
from backend.web.services.storage_helpers import get_storage_helper
# Mapping from CSV field names to internal field names
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
"InvoiceNumber": "invoice_number",
"InvoiceDate": "invoice_date",
"InvoiceDueDate": "invoice_due_date",
"OCR": "ocr_number",
"Bankgiro": "bankgiro",
"Plusgiro": "plusgiro",
"Amount": "amount",
"supplier_organisation_number": "supplier_organisation_number",
"customer_number": "customer_number",
"payment_line": "payment_line",
}
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
scale = dpi / 72.0
# Get storage helper for image dimensions
storage = get_storage_helper()
# Cache for image dimensions per page
image_dimensions: dict[int, tuple[int, int]] = {}
def get_image_dimensions(page_no: int) -> tuple[int, int] | None:
"""Get image dimensions for a page (1-indexed)."""
if page_no in image_dimensions:
return image_dimensions[page_no]
# Get dimensions from storage helper
dims = storage.get_admin_image_dimensions(document_id, page_no)
if dims:
image_dimensions[page_no] = dims
return dims
return None
annotation_count = 0
# Get field results from report (list of dicts)
field_results = report.get("field_results", [])
for field_info in field_results:
if not field_info.get("matched"):
continue
csv_field_name = field_info.get("field_name", "")
# Map CSV field name to internal field name
field_name = CSV_TO_INTERNAL_FIELD.get(csv_field_name, csv_field_name)
# Get class_id from field name
class_id = FIELD_CLASS_IDS.get(field_name)
if class_id is None:
logger.warning(f"Unknown field name: {csv_field_name} -> {field_name}")
continue
# Get bbox info (list: [x, y, x2, y2] in PDF points - 72 DPI)
bbox = field_info.get("bbox", [])
if not bbox or len(bbox) < 4:
continue
# Convert PDF points (72 DPI) to pixel coordinates (at configured DPI)
pdf_x1, pdf_y1, pdf_x2, pdf_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
x1 = pdf_x1 * scale
y1 = pdf_y1 * scale
x2 = pdf_x2 * scale
y2 = pdf_y2 * scale
bbox_width = x2 - x1
bbox_height = y2 - y1
# Get page number (convert to 1-indexed)
page_no = field_info.get("page_no", 0) + 1
# Get image dimensions for normalization
dims = get_image_dimensions(page_no)
if dims:
img_width, img_height = dims
# Calculate normalized coordinates
x_center = (x1 + x2) / 2 / img_width
y_center = (y1 + y2) / 2 / img_height
width = bbox_width / img_width
height = bbox_height / img_height
else:
# Fallback: use pixel coordinates as-is for normalization
# (will be slightly off but better than /1000)
logger.warning(f"Could not get image dimensions for page {page_no}, using estimates")
# Estimate A4 at configured DPI: 595 x 842 points * scale
estimated_width = 595 * scale
estimated_height = 842 * scale
x_center = (x1 + x2) / 2 / estimated_width
y_center = (y1 + y2) / 2 / estimated_height
width = bbox_width / estimated_width
height = bbox_height / estimated_height
# Create annotation
try:
ann_repo.create(
document_id=document_id,
page_number=page_no,
class_id=class_id,
class_name=field_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=int(x1),
bbox_y=int(y1),
bbox_width=int(bbox_width),
bbox_height=int(bbox_height),
text_value=field_info.get("matched_text"),
confidence=field_info.get("score"),
source="auto",
)
annotation_count += 1
logger.info(f"Saved annotation for {field_name}: bbox=({int(x1)}, {int(y1)}, {int(bbox_width)}, {int(bbox_height)})")
except Exception as e:
logger.warning(f"Failed to save annotation for {field_name}: {e}")
return annotation_count
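# Worked example of the coordinate conversion above (illustrative numbers):
# a field matched at bbox = [100, 700, 200, 720] in PDF points, rendered at
# dpi=200, gives scale = 200 / 72 ~ 2.778 and a pixel box of roughly
# (278, 1944)-(556, 2000). On an A4 page rendered at 200 DPI
# (about 1653 x 2339 px) this normalises to x_center ~ 0.252,
# y_center ~ 0.843, width ~ 0.168, height ~ 0.024, which is what is stored
# as the YOLO-style annotation.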
def run_pending_autolabel_batch(
batch_size: int = 10,
output_dir: Path | None = None,
doc_repo: DocumentRepository | None = None,
ann_repo: AnnotationRepository | None = None,
) -> dict[str, Any]:
"""
Process a batch of pending auto-label documents.
Args:
batch_size: Number of documents to process
output_dir: Output directory for temp files
doc_repo: Document repository (created if None)
ann_repo: Annotation repository (created if None)
Returns:
Summary of processing results
"""
if doc_repo is None:
doc_repo = DocumentRepository()
if ann_repo is None:
ann_repo = AnnotationRepository()
documents = get_pending_autolabel_documents(limit=batch_size)
results = {
"total": len(documents),
"successful": 0,
"failed": 0,
"documents": [],
}
for doc in documents:
result = process_document_autolabel(
document=doc,
output_dir=output_dir,
doc_repo=doc_repo,
ann_repo=ann_repo,
)
doc_result = {
"document_id": str(doc.document_id),
"success": result.get("success", False),
}
if result.get("success"):
results["successful"] += 1
else:
results["failed"] += 1
doc_result["error"] = result.get("error") or "Unknown error"
results["documents"].append(doc_result)
return results
def save_manual_annotations_to_document_db(
document: AdminDocument,
annotations: list,
) -> dict[str, Any]:
"""
Save manual annotations to PostgreSQL documents and field_results tables.
Called when user marks a document as 'labeled' from the web UI.
This ensures manually labeled documents are also tracked in the same
database as auto-labeled documents for consistency.
Args:
document: AdminDocument instance
annotations: List of AdminAnnotation instances
Returns:
Dict with success status and details
"""
from datetime import datetime
document_id = str(document.document_id)
# Build pdf_path using raw_pdfs base path (same as auto-label)
storage = get_storage_helper()
raw_pdfs_dir = storage.get_raw_pdfs_base_path()
if raw_pdfs_dir is None:
return {
"success": False,
"document_id": document_id,
"error": "Storage not configured for local access",
}
pdf_path = raw_pdfs_dir / f"{document_id}.pdf"
# Build report dict compatible with DocumentDB.save_document()
field_results = []
fields_total = len(annotations)
fields_matched = 0
for ann in annotations:
# All manual annotations are considered "matched" since user verified them
field_result = {
"field_name": ann.class_name,
"csv_value": ann.text_value or "", # Manual annotations may not have CSV value
"matched": True,
"score": ann.confidence or 1.0, # Manual = high confidence
"matched_text": ann.text_value,
"candidate_used": "manual",
"bbox": [ann.bbox_x, ann.bbox_y, ann.bbox_x + ann.bbox_width, ann.bbox_y + ann.bbox_height],
"page_no": ann.page_number - 1, # Convert to 0-indexed
"context_keywords": [],
"error": None,
}
field_results.append(field_result)
fields_matched += 1
# Determine PDF type
pdf_type = "unknown"
if pdf_path.exists():
try:
from shared.pdf import PDFDocument
with PDFDocument(pdf_path) as pdf_doc:
tokens = list(pdf_doc.extract_text_tokens(0))
pdf_type = "scanned" if len(tokens) < 10 else "text"
except Exception as e:
logger.warning(f"Could not determine PDF type: {e}")
# Build report
report = {
"document_id": document_id,
"pdf_path": str(pdf_path),
"pdf_type": pdf_type,
"success": fields_matched > 0,
"total_pages": document.page_count,
"fields_matched": fields_matched,
"fields_total": fields_total,
"annotations_generated": fields_matched,
"processing_time_ms": 0, # Manual labeling - no processing time
"timestamp": datetime.utcnow().isoformat(),
"errors": [],
"field_results": field_results,
# Extended fields (from CSV if available)
"split": None,
"customer_number": document.csv_field_values.get("customer_number") if document.csv_field_values else None,
"supplier_name": document.csv_field_values.get("supplier_name") if document.csv_field_values else None,
"supplier_organisation_number": document.csv_field_values.get("supplier_organisation_number") if document.csv_field_values else None,
"supplier_accounts": document.csv_field_values.get("supplier_accounts") if document.csv_field_values else None,
}
# Save to PostgreSQL DocumentDB
try:
doc_db = get_document_db()
doc_db.save_document(report)
logger.info(f"Saved manual annotations to DocumentDB for {document_id}: {fields_matched} fields")
return {
"success": True,
"document_id": document_id,
"fields_saved": fields_matched,
"message": f"Saved {fields_matched} annotations to DocumentDB",
}
except Exception as e:
logger.error(f"Failed to save manual annotations to DocumentDB: {e}", exc_info=True)
return {
"success": False,
"document_id": document_id,
"error": str(e),
}
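# Usage sketch (illustrative): process a small batch of pending documents,
# e.g. from a cron job or management command. Assumes the admin database and
# storage backend are configured.
if __name__ == "__main__":
    summary = run_pending_autolabel_batch(batch_size=5)
    print(
        f"Auto-labeled {summary['successful']}/{summary['total']} documents "
        f"({summary['failed']} failed)"
    )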

View File

@@ -0,0 +1,217 @@
"""
Document Service for storage-backed file operations.
Provides a unified interface for document upload, download, and serving
using the storage abstraction layer.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from uuid import uuid4
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
@dataclass
class DocumentResult:
"""Result of document operation."""
id: str
file_path: str
filename: str | None = None
class DocumentService:
"""Service for document file operations using storage backend.
Provides upload, download, and URL generation for documents and images.
"""
# Storage path prefixes
DOCUMENTS_PREFIX = "documents"
IMAGES_PREFIX = "images"
def __init__(
self,
storage_backend: "StorageBackend",
admin_db: Any | None = None,
) -> None:
"""Initialize document service.
Args:
storage_backend: Storage backend for file operations.
admin_db: Optional AdminDB instance for database operations.
"""
self._storage = storage_backend
self._admin_db = admin_db
def upload_document(
self,
content: bytes,
filename: str,
dataset_id: str | None = None,
document_id: str | None = None,
) -> DocumentResult:
"""Upload a document to storage.
Args:
content: Document content as bytes.
filename: Original filename.
dataset_id: Optional dataset ID for organization.
document_id: Optional document ID (generated if not provided).
Returns:
DocumentResult with ID and storage path.
"""
if document_id is None:
document_id = str(uuid4())
# Extract extension from filename
ext = ""
if "." in filename:
ext = "." + filename.rsplit(".", 1)[-1].lower()
# Build logical path
remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}"
# Upload via storage backend
self._storage.upload_bytes(content, remote_path, overwrite=True)
return DocumentResult(
id=document_id,
file_path=remote_path,
filename=filename,
)
def download_document(self, remote_path: str) -> bytes:
"""Download a document from storage.
Args:
remote_path: Logical path to the document.
Returns:
Document content as bytes.
"""
return self._storage.download_bytes(remote_path)
def get_document_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get a URL for accessing a document.
Args:
remote_path: Logical path to the document.
expires_in_seconds: URL validity duration.
Returns:
Pre-signed URL for document access.
"""
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
def document_exists(self, remote_path: str) -> bool:
"""Check if a document exists in storage.
Args:
remote_path: Logical path to the document.
Returns:
True if document exists.
"""
return self._storage.exists(remote_path)
def delete_document_files(self, remote_path: str) -> bool:
"""Delete a document from storage.
Args:
remote_path: Logical path to the document.
Returns:
True if document was deleted.
"""
return self._storage.delete(remote_path)
def save_page_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save a page image to storage.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Logical path where image was stored.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
self._storage.upload_bytes(content, remote_path, overwrite=True)
return remote_path
def get_page_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get a URL for accessing a page image.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
expires_in_seconds: URL validity duration.
Returns:
Pre-signed URL for image access.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
def get_page_image(self, document_id: str, page_num: int) -> bytes:
"""Download a page image from storage.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
return self._storage.download_bytes(remote_path)
def delete_document_images(self, document_id: str) -> int:
"""Delete all images for a document.
Args:
document_id: Document ID.
Returns:
Number of images deleted.
"""
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
image_paths = self._storage.list_files(prefix)
deleted_count = 0
for path in image_paths:
if self._storage.delete(path):
deleted_count += 1
return deleted_count
def list_document_images(self, document_id: str) -> list[str]:
"""List all images for a document.
Args:
document_id: Document ID.
Returns:
List of image paths.
"""
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
return self._storage.list_files(prefix)
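# Usage sketch (illustrative): back the service with the configured storage
# backend and upload a PDF. get_storage_backend comes from shared.storage;
# the local file name below is a hypothetical placeholder.
if __name__ == "__main__":
    from pathlib import Path

    from shared.storage import get_storage_backend

    service = DocumentService(get_storage_backend())
    pdf = Path("example-invoice.pdf")  # hypothetical local file
    uploaded = service.upload_document(pdf.read_bytes(), pdf.name)
    print(uploaded.id, uploaded.file_path)
    print(service.get_document_url(uploaded.file_path, expires_in_seconds=600))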

View File

@@ -0,0 +1,360 @@
"""
Inference Service
Business logic for invoice field extraction.
"""
from __future__ import annotations
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
logger = logging.getLogger(__name__)
# Type alias for model path resolver function
ModelPathResolver = Callable[[], Path | None]
@dataclass
class ServiceResult:
"""Result from inference service."""
document_id: str
success: bool = False
document_type: str = "invoice" # "invoice" or "letter"
fields: dict[str, str | None] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
detections: list[dict] = field(default_factory=list)
processing_time_ms: float = 0.0
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
class InferenceService:
"""
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
Supports dynamic model loading from database.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration (default model settings)
storage_config: Storage configuration
model_path_resolver: Optional function to resolve model path from database.
If provided, will be called to get active model path.
If returns None, falls back to model_config.model_path.
"""
self.model_config = model_config
self.storage_config = storage_config
self._model_path_resolver = model_path_resolver
self._pipeline = None
self._detector = None
self._is_initialized = False
self._current_model_path: Path | None = None
def _resolve_model_path(self) -> Path:
"""Resolve the model path to use for inference.
Priority:
1. Active model from database (via resolver)
2. Default model from config
"""
if self._model_path_resolver:
try:
db_model_path = self._model_path_resolver()
if db_model_path and Path(db_model_path).exists():
logger.info(f"Using active model from database: {db_model_path}")
return Path(db_model_path)
elif db_model_path:
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
except Exception as e:
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
return self.model_config.model_path
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
if self._is_initialized:
return
logger.info("Initializing inference service...")
start_time = time.time()
try:
from backend.pipeline.pipeline import InferencePipeline
from backend.pipeline.yolo_detector import YOLODetector
# Resolve model path (from DB or config)
model_path = self._resolve_model_path()
self._current_model_path = model_path
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
model_path=str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
enable_fallback=True,
)
self._is_initialized = True
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
def reload_model(self) -> bool:
"""Reload the model if active model has changed.
Returns:
True if model was reloaded, False if no change needed.
"""
new_model_path = self._resolve_model_path()
if self._current_model_path == new_model_path:
logger.debug("Model unchanged, no reload needed")
return False
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
self._is_initialized = False
self._pipeline = None
self._detector = None
self.initialize()
return True
@property
def current_model_path(self) -> Path | None:
"""Get the currently loaded model path."""
return self._current_model_path
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""
return self._is_initialized
@property
def gpu_available(self) -> bool:
"""Check if GPU is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def process_image(
self,
image_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process an image file and extract invoice fields.
Args:
image_path: Path to image file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections for visualization
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization if requested
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_visualization(image_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def process_pdf(
self,
pdf_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process a PDF file and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization (render first page)
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing PDF {pdf_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
    def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
"""Save visualization image with detections."""
from ultralytics import YOLO
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir is None:
logger.warning("Cannot save visualization: local storage not available")
return None
# Load model and run prediction with visualization
        model = YOLO(str(self._current_model_path or self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
    def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
"""Save visualization for PDF (first page)."""
from shared.pdf.renderer import render_pdf_to_images
from ultralytics import YOLO
import io
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir is None:
logger.warning("Cannot save visualization: local storage not available")
return None
# Render first page
for page_no, image_bytes in render_pdf_to_images(
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
temp_path = results_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
            model = YOLO(str(self._current_model_path or self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
# Cleanup temp file
temp_path.unlink(missing_ok=True)
return output_path
# If no pages rendered
return None
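# Usage sketch (illustrative only): the service is constructed from the app's
# ModelConfig and StorageConfig (see .config) and lazily initialises the
# pipeline on first use. `load_configs` below is a hypothetical helper; only
# the model_path, confidence_threshold, use_gpu and dpi attributes are read
# by this module.
#
#     model_cfg, storage_cfg = load_configs()              # hypothetical
#     service = InferenceService(model_cfg, storage_cfg)
#     result = service.process_pdf(Path("invoice.pdf"))
#     if result.success:
#         print(result.document_type, result.fields)
#     else:
#         print("errors:", result.errors)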

View File

@@ -0,0 +1,830 @@
"""
Storage helpers for web services.
Provides convenience functions for common storage operations,
wrapping the storage backend with proper path handling using prefixes.
"""
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4
from shared.storage import PREFIXES, get_storage_backend
from shared.storage.local import LocalStorageBackend
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
def get_default_storage() -> "StorageBackend":
"""Get the default storage backend.
Returns:
Configured StorageBackend instance.
"""
return get_storage_backend()
class StorageHelper:
"""Helper class for storage operations with prefixes.
Provides high-level operations for document storage, including
upload, download, and URL generation with proper path prefixes.
"""
def __init__(self, storage: "StorageBackend | None" = None) -> None:
"""Initialize storage helper.
Args:
storage: Storage backend to use. If None, creates default.
"""
self._storage = storage or get_default_storage()
@property
def storage(self) -> "StorageBackend":
"""Get the underlying storage backend."""
return self._storage
# Document operations
def upload_document(
self,
content: bytes,
filename: str,
document_id: str | None = None,
) -> tuple[str, str]:
"""Upload a document to storage.
Args:
content: Document content as bytes.
filename: Original filename (used for extension).
document_id: Optional document ID. Generated if not provided.
Returns:
Tuple of (document_id, storage_path).
"""
if document_id is None:
document_id = str(uuid4())
ext = Path(filename).suffix.lower() or ".pdf"
path = PREFIXES.document_path(document_id, ext)
self._storage.upload_bytes(content, path, overwrite=True)
return document_id, path
def download_document(self, document_id: str, extension: str = ".pdf") -> bytes:
"""Download a document from storage.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
Document content as bytes.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.download_bytes(path)
def get_document_url(
self,
document_id: str,
extension: str = ".pdf",
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for a document.
Args:
document_id: Document identifier.
extension: File extension.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.get_presigned_url(path, expires_in_seconds)
def document_exists(self, document_id: str, extension: str = ".pdf") -> bool:
"""Check if a document exists.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
True if document exists.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.exists(path)
def delete_document(self, document_id: str, extension: str = ".pdf") -> bool:
"""Delete a document.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
True if document was deleted.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.delete(path)
# Image operations
def save_page_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save a page image to storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Storage path where image was saved.
"""
path = PREFIXES.image_path(document_id, page_num)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_page_image(self, document_id: str, page_num: int) -> bytes:
"""Download a page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
path = PREFIXES.image_path(document_id, page_num)
return self._storage.download_bytes(path)
def get_page_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for a page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.image_path(document_id, page_num)
return self._storage.get_presigned_url(path, expires_in_seconds)
def delete_document_images(self, document_id: str) -> int:
"""Delete all images for a document.
Args:
document_id: Document identifier.
Returns:
Number of images deleted.
"""
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
images = self._storage.list_files(prefix)
deleted = 0
for img_path in images:
if self._storage.delete(img_path):
deleted += 1
return deleted
def list_document_images(self, document_id: str) -> list[str]:
"""List all images for a document.
Args:
document_id: Document identifier.
Returns:
List of image paths.
"""
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
return self._storage.list_files(prefix)
# Upload staging operations
def save_upload(
self,
content: bytes,
filename: str,
subfolder: str | None = None,
) -> str:
"""Save a file to upload staging area.
Args:
content: File content as bytes.
filename: Filename to save as.
subfolder: Optional subfolder (e.g., "async").
Returns:
Storage path where file was saved.
"""
path = PREFIXES.upload_path(filename, subfolder)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_upload(self, filename: str, subfolder: str | None = None) -> bytes:
"""Get a file from upload staging area.
Args:
filename: Filename to retrieve.
subfolder: Optional subfolder.
Returns:
File content as bytes.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.download_bytes(path)
def delete_upload(self, filename: str, subfolder: str | None = None) -> bool:
"""Delete a file from upload staging area.
Args:
filename: Filename to delete.
subfolder: Optional subfolder.
Returns:
True if file was deleted.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.delete(path)
# Result operations
def save_result(self, content: bytes, filename: str) -> str:
"""Save a result file.
Args:
content: File content as bytes.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = PREFIXES.result_path(filename)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_result(self, filename: str) -> bytes:
"""Get a result file.
Args:
filename: Filename to retrieve.
Returns:
File content as bytes.
"""
path = PREFIXES.result_path(filename)
return self._storage.download_bytes(path)
def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str:
"""Get presigned URL for a result file.
Args:
filename: Filename.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.result_path(filename)
return self._storage.get_presigned_url(path, expires_in_seconds)
def result_exists(self, filename: str) -> bool:
"""Check if a result file exists.
Args:
filename: Filename to check.
Returns:
True if file exists.
"""
path = PREFIXES.result_path(filename)
return self._storage.exists(path)
def delete_result(self, filename: str) -> bool:
"""Delete a result file.
Args:
filename: Filename to delete.
Returns:
True if file was deleted.
"""
path = PREFIXES.result_path(filename)
return self._storage.delete(path)
# Export operations
def save_export(self, content: bytes, export_id: str, filename: str) -> str:
"""Save an export file.
Args:
content: File content as bytes.
export_id: Export identifier.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = PREFIXES.export_path(export_id, filename)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_export_url(
self,
export_id: str,
filename: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for an export file.
Args:
export_id: Export identifier.
filename: Filename.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.export_path(export_id, filename)
return self._storage.get_presigned_url(path, expires_in_seconds)
# Admin image operations
def get_admin_image_path(self, document_id: str, page_num: int) -> str:
"""Get the storage path for an admin image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Storage path like "admin_images/doc123/page_1.png"
"""
return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png"
def save_admin_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save an admin page image to storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Storage path where image was saved.
"""
path = self.get_admin_image_path(document_id, page_num)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_admin_image(self, document_id: str, page_num: int) -> bytes:
"""Download an admin page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.download_bytes(path)
def get_admin_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for an admin page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.get_presigned_url(path, expires_in_seconds)
def admin_image_exists(self, document_id: str, page_num: int) -> bool:
"""Check if an admin page image exists.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
True if image exists.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.exists(path)
def list_admin_images(self, document_id: str) -> list[str]:
"""List all admin images for a document.
Args:
document_id: Document identifier.
Returns:
List of image paths.
"""
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
return self._storage.list_files(prefix)
def delete_admin_images(self, document_id: str) -> int:
"""Delete all admin images for a document.
Args:
document_id: Document identifier.
Returns:
Number of images deleted.
"""
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
images = self._storage.list_files(prefix)
deleted = 0
for img_path in images:
if self._storage.delete(img_path):
deleted += 1
return deleted
def get_admin_image_local_path(
self, document_id: str, page_num: int
) -> Path | None:
"""Get the local filesystem path for an admin image.
This method is useful for serving files via FileResponse.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
# Cloud storage - cannot get local path
return None
remote_path = self.get_admin_image_path(document_id, page_num)
try:
full_path = self._storage._get_full_path(remote_path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_admin_image_dimensions(
self, document_id: str, page_num: int
) -> tuple[int, int] | None:
"""Get the dimensions (width, height) of an admin image.
This method is useful for normalizing bounding box coordinates.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Tuple of (width, height) if image exists, None otherwise.
"""
from PIL import Image
        # Try local path first for efficiency; guard against unreadable files
        # so the method keeps its documented "None on failure" contract,
        # matching the cloud-storage branch below.
        local_path = self.get_admin_image_local_path(document_id, page_num)
        if local_path is not None:
            try:
                with Image.open(local_path) as img:
                    return img.size
            except Exception:
                return None
# Fall back to downloading for cloud storage
if not self.admin_image_exists(document_id, page_num):
return None
try:
import io
image_bytes = self.get_admin_image(document_id, page_num)
with Image.open(io.BytesIO(image_bytes)) as img:
return img.size
except Exception:
return None
# Raw PDF operations (legacy compatibility)
def save_raw_pdf(self, content: bytes, filename: str) -> str:
"""Save a raw PDF for auto-labeling pipeline.
Args:
content: PDF content as bytes.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_raw_pdf(self, filename: str) -> bytes:
"""Get a raw PDF from storage.
Args:
filename: Filename to retrieve.
Returns:
PDF content as bytes.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
return self._storage.download_bytes(path)
def raw_pdf_exists(self, filename: str) -> bool:
"""Check if a raw PDF exists.
Args:
filename: Filename to check.
Returns:
True if file exists.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
return self._storage.exists(path)
def get_raw_pdf_local_path(self, filename: str) -> Path | None:
"""Get the local filesystem path for a raw PDF.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = f"{PREFIXES.RAW_PDFS}/{filename}"
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_raw_pdf_path(self, filename: str) -> str:
"""Get the storage path for a raw PDF (not the local filesystem path).
Args:
filename: Filename.
Returns:
Storage path like "raw_pdfs/filename.pdf"
"""
return f"{PREFIXES.RAW_PDFS}/{filename}"
# Result local path operations
def get_result_local_path(self, filename: str) -> Path | None:
"""Get the local filesystem path for a result file.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = PREFIXES.result_path(filename)
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_results_base_path(self) -> Path | None:
"""Get the base directory path for results (local storage only).
Used for mounting static file directories.
Returns:
Path to results directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.RESULTS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
# Upload local path operations
def get_upload_local_path(
self, filename: str, subfolder: str | None = None
) -> Path | None:
"""Get the local filesystem path for an upload file.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
subfolder: Optional subfolder.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = PREFIXES.upload_path(filename, subfolder)
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None:
"""Get the base directory path for uploads (local storage only).
Args:
subfolder: Optional subfolder (e.g., "async").
Returns:
Path to uploads directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
if subfolder:
base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}")
else:
base_path = self._storage._get_full_path(PREFIXES.UPLOADS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def upload_exists(self, filename: str, subfolder: str | None = None) -> bool:
"""Check if an upload file exists.
Args:
filename: Filename to check.
subfolder: Optional subfolder.
Returns:
True if file exists.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.exists(path)
# Dataset operations
def get_datasets_base_path(self) -> Path | None:
"""Get the base directory path for datasets (local storage only).
Returns:
Path to datasets directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.DATASETS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_admin_images_base_path(self) -> Path | None:
"""Get the base directory path for admin images (local storage only).
Returns:
Path to admin_images directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_raw_pdfs_base_path(self) -> Path | None:
"""Get the base directory path for raw PDFs (local storage only).
Returns:
Path to raw_pdfs directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_autolabel_output_path(self) -> Path | None:
"""Get the directory path for autolabel output (local storage only).
Returns:
Path to autolabel_output directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
            # Autolabel output lives in its own top-level folder (not under results)
base_path = self._storage._get_full_path("autolabel_output")
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_training_data_path(self) -> Path | None:
"""Get the directory path for training data exports (local storage only).
Returns:
Path to training directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path("training")
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_exports_base_path(self) -> Path | None:
"""Get the base directory path for exports (local storage only).
Returns:
Path to exports directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.EXPORTS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
# Default instance for convenience
_default_helper: StorageHelper | None = None
def get_storage_helper() -> StorageHelper:
"""Get the default storage helper instance.
Creates the helper on first call with default storage backend.
Returns:
Default StorageHelper instance.
"""
global _default_helper
if _default_helper is None:
_default_helper = StorageHelper()
return _default_helper
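
For orientation, a minimal usage sketch of the helper API above follows. It is not part of the diff: the import path, the "invoice-123" identifier, the rasterizer call, and the pixel box values are assumptions for illustration, and it presumes a default storage backend has already been configured.

# Hedged usage sketch for StorageHelper. Only method names and signatures shown
# in the class above are used; everything marked "hypothetical" is invented for
# the example.
from storage_helper import get_storage_helper  # hypothetical import path

helper = get_storage_helper()

# Stage an uploaded PDF, render a page image, and hand out a presigned URL.
with open("invoice-123.pdf", "rb") as fh:   # hypothetical local file
    pdf_bytes = fh.read()
helper.save_upload(pdf_bytes, "invoice-123.pdf", subfolder="async")

png_bytes = render_first_page(pdf_bytes)    # hypothetical rasterizer
helper.save_page_image("invoice-123", page_num=1, content=png_bytes)
page_url = helper.get_page_image_url("invoice-123", page_num=1, expires_in_seconds=600)

# Store an admin copy of the page and use its dimensions to normalize a
# bounding box, as suggested by get_admin_image_dimensions(); the pixel
# values are made up.
helper.save_admin_image("invoice-123", page_num=1, content=png_bytes)
dims = helper.get_admin_image_dimensions("invoice-123", page_num=1)
if dims is not None:
    width, height = dims
    x, y, w, h = 120, 80, 300, 40
    norm_box = (x / width, y / height, w / width, h / height)

# Tear down everything associated with the document.
helper.delete_document_images("invoice-123")
helper.delete_admin_images("invoice-123")
helper.delete_document("invoice-123")

Note that the *_local_path and *_base_path helpers only return real filesystem paths on LocalStorageBackend; callers running against cloud storage should branch on the None return and fall back to presigned URLs or byte downloads.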

View File

@@ -0,0 +1,24 @@
"""
Background Task Queues
Worker queues for asynchronous and batch processing.
"""
from backend.web.workers.async_queue import AsyncTaskQueue, AsyncTask
from backend.web.workers.batch_queue import (
BatchTaskQueue,
BatchTask,
init_batch_queue,
shutdown_batch_queue,
get_batch_queue,
)
__all__ = [
"AsyncTaskQueue",
"AsyncTask",
"BatchTaskQueue",
"BatchTask",
"init_batch_queue",
"shutdown_batch_queue",
"get_batch_queue",
]
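
The exports above imply an init/get/shutdown lifecycle for the batch queue. Below is one plausible wiring sketch into a FastAPI lifespan; the argument-free calls and the queue's enqueue method are assumptions, since only the names are visible in this file.

# Hedged lifecycle sketch for the worker-queue exports above. Signatures of
# init_batch_queue()/shutdown_batch_queue() and the enqueue() method are
# assumptions; only the imported names come from backend.web.workers.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from backend.web.workers import (
    get_batch_queue,
    init_batch_queue,
    shutdown_batch_queue,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    init_batch_queue()          # assumed no required arguments
    try:
        yield
    finally:
        shutdown_batch_queue()  # assumed no required arguments


app = FastAPI(lifespan=lifespan)


@app.post("/batch")
async def submit_batch(payload: dict):
    queue = get_batch_queue()
    task_id = queue.enqueue(payload)  # hypothetical method name
    return {"task_id": task_id}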

Some files were not shown because too many files have changed in this diff.