restructure project

Yaojia Wang
2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions

View File

@@ -0,0 +1,25 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
&& rm -rf /var/lib/apt/lists/*
# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared
# Install inference package
COPY packages/inference /app/packages/inference
RUN pip install --no-cache-dir -e /app/packages/inference
# Copy frontend (if needed)
COPY frontend /app/frontend
WORKDIR /app/packages/inference
EXPOSE 8000
CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]

View File

View File

@@ -0,0 +1,105 @@
"""Trigger training jobs on Azure Container Instances."""
import logging
import os
logger = logging.getLogger(__name__)
# Azure SDK is optional; only needed if using ACI trigger
try:
from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
EnvironmentVariable,
GpuResource,
ResourceRequests,
ResourceRequirements,
)
_AZURE_SDK_AVAILABLE = True
except ImportError:
_AZURE_SDK_AVAILABLE = False
def start_training_container(task_id: str) -> str | None:
"""
Start an Azure Container Instance for a training task.
Returns the container group name if successful, None otherwise.
Requires environment variables:
AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
"""
if not _AZURE_SDK_AVAILABLE:
logger.warning(
"Azure SDK not installed. Install azure-mgmt-containerinstance "
"and azure-identity to use ACI trigger."
)
return None
subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
image = os.environ.get(
"AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
)
gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
location = os.environ.get("AZURE_LOCATION", "eastus")
if not subscription_id:
logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
return None
credential = DefaultAzureCredential()
client = ContainerInstanceManagementClient(credential, subscription_id)
container_name = f"training-{task_id[:8]}"
env_vars = [
EnvironmentVariable(name="TASK_ID", value=task_id),
]
# Pass DB connection securely
for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
val = os.environ.get(var, "")
if val:
env_vars.append(EnvironmentVariable(name=var, value=val))
db_password = os.environ.get("DB_PASSWORD", "")
if db_password:
env_vars.append(
EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
)
container = Container(
name=container_name,
image=image,
resources=ResourceRequirements(
requests=ResourceRequests(
cpu=4,
memory_in_gb=16,
gpu=GpuResource(count=1, sku=gpu_sku),
)
),
environment_variables=env_vars,
command=[
"python",
"run_training.py",
"--task-id",
task_id,
],
)
group = ContainerGroup(
location=location,
containers=[container],
os_type="Linux",
restart_policy="Never",
)
logger.info("Creating ACI container group: %s", container_name)
client.container_groups.begin_create_or_update(
resource_group, container_name, group
)
return container_name
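A minimal, hedged usage sketch follows (it assumes the AZURE_* environment variables described in the docstring are set; the task id is a placeholder):
# Placeholder task id; in practice this would come from a TrainingTask row
group = start_training_container("3f2b9c1e-7c4d-4f6a-9e2b-1a2b3c4d5e6f")
if group:
    print(f"Started ACI container group: {group}")
else:
    print("ACI trigger skipped (Azure SDK missing or AZURE_SUBSCRIPTION_ID not set)")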

View File

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Inference CLI
Runs inference on new PDFs to extract invoice data.
"""
import argparse
import json
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI
def main():
parser = argparse.ArgumentParser(
description='Extract invoice data from PDFs using trained model'
)
parser.add_argument(
'--model', '-m',
required=True,
help='Path to trained YOLO model (.pt file)'
)
parser.add_argument(
'--input', '-i',
required=True,
help='Input PDF file or directory'
)
parser.add_argument(
'--output', '-o',
help='Output JSON file (default: stdout)'
)
parser.add_argument(
'--confidence',
type=float,
default=0.5,
help='Detection confidence threshold (default: 0.5)'
)
parser.add_argument(
'--dpi',
type=int,
default=DEFAULT_DPI,
help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
)
parser.add_argument(
'--no-fallback',
action='store_true',
help='Disable fallback OCR'
)
parser.add_argument(
'--lang',
default='en',
help='OCR language (default: en)'
)
parser.add_argument(
'--gpu',
action='store_true',
help='Use GPU'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Verbose output'
)
args = parser.parse_args()
# Validate model
model_path = Path(args.model)
if not model_path.exists():
print(f"Error: Model not found: {model_path}", file=sys.stderr)
sys.exit(1)
# Get input files
input_path = Path(args.input)
if input_path.is_file():
pdf_files = [input_path]
elif input_path.is_dir():
pdf_files = list(input_path.glob('*.pdf'))
else:
print(f"Error: Input not found: {input_path}", file=sys.stderr)
sys.exit(1)
if not pdf_files:
print("Error: No PDF files found", file=sys.stderr)
sys.exit(1)
if args.verbose:
print(f"Processing {len(pdf_files)} PDF file(s)")
print(f"Model: {model_path}")
from inference.pipeline import InferencePipeline
# Initialize pipeline
pipeline = InferencePipeline(
model_path=model_path,
confidence_threshold=args.confidence,
ocr_lang=args.lang,
use_gpu=args.gpu,
dpi=args.dpi,
enable_fallback=not args.no_fallback
)
# Process files
results = []
for pdf_path in pdf_files:
if args.verbose:
print(f"Processing: {pdf_path.name}")
result = pipeline.process_pdf(pdf_path)
results.append(result.to_json())
if args.verbose:
print(f" Success: {result.success}")
print(f" Fields: {len(result.fields)}")
if result.fallback_used:
print(f" Fallback used: Yes")
if result.errors:
print(f" Errors: {result.errors}")
# Output results
if len(results) == 1:
output = results[0]
else:
output = results
json_output = json.dumps(output, indent=2, ensure_ascii=False)
if args.output:
with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output)
if args.verbose:
print(f"\nResults written to: {args.output}")
else:
print(json_output)
if __name__ == '__main__':
main()
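For reference, a sketch of driving the underlying pipeline directly, using only the constructor parameters and result attributes shown above (the model and PDF paths are placeholders):
from pathlib import Path
from inference.pipeline import InferencePipeline
pipeline = InferencePipeline(
    model_path=Path("runs/train/invoice_fields/weights/best.pt"),  # placeholder path
    confidence_threshold=0.5,
    ocr_lang="en",
    use_gpu=False,
    dpi=300,
    enable_fallback=True,
)
result = pipeline.process_pdf(Path("invoice.pdf"))  # placeholder input
print(result.success, len(result.fields), result.fallback_used)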

View File

@@ -0,0 +1,159 @@
"""
Web Server CLI
Command-line interface for starting the web server.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
# Add project root to path (so "shared" resolves when running from a source checkout)
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from shared.config import DEFAULT_DPI
def setup_logging(debug: bool = False) -> None:
"""Configure logging."""
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Start the Invoice Field Extraction web server",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to bind to",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to listen on",
)
parser.add_argument(
"--model",
"-m",
type=Path,
default=Path("runs/train/invoice_fields/weights/best.pt"),
help="Path to YOLO model weights",
)
parser.add_argument(
"--confidence",
type=float,
default=0.5,
help="Detection confidence threshold",
)
parser.add_argument(
"--dpi",
type=int,
default=DEFAULT_DPI,
help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
)
parser.add_argument(
"--no-gpu",
action="store_true",
help="Disable GPU acceleration",
)
parser.add_argument(
"--reload",
action="store_true",
help="Enable auto-reload for development",
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="Number of worker processes",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode",
)
return parser.parse_args()
def main() -> None:
"""Main entry point."""
args = parse_args()
setup_logging(debug=args.debug)
logger = logging.getLogger(__name__)
# Validate model path
if not args.model.exists():
logger.error(f"Model file not found: {args.model}")
sys.exit(1)
logger.info("=" * 60)
logger.info("Invoice Field Extraction Web Server")
logger.info("=" * 60)
logger.info(f"Model: {args.model}")
logger.info(f"Confidence threshold: {args.confidence}")
logger.info(f"GPU enabled: {not args.no_gpu}")
logger.info(f"Server: http://{args.host}:{args.port}")
logger.info("=" * 60)
# Create config
from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
config = AppConfig(
model=ModelConfig(
model_path=args.model,
confidence_threshold=args.confidence,
use_gpu=not args.no_gpu,
dpi=args.dpi,
),
server=ServerConfig(
host=args.host,
port=args.port,
debug=args.debug,
reload=args.reload,
workers=args.workers,
),
storage=StorageConfig(),
)
# Create and run app
import uvicorn
from inference.web.app import create_app
app = create_app(config)
uvicorn.run(
app,
host=config.server.host,
port=config.server.port,
reload=config.server.reload,
workers=config.server.workers if not config.server.reload else 1,
log_level="debug" if config.server.debug else "info",
)
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@@ -0,0 +1,407 @@
"""
Admin API SQLModel Database Models
Defines the database schema for admin document management, annotations, and training tasks.
Includes batch upload support, training document links, and annotation history.
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
# =============================================================================
# CSV to Field Class Mapping
# =============================================================================
CSV_TO_CLASS_MAPPING: dict[str, int] = {
"InvoiceNumber": 0, # invoice_number
"InvoiceDate": 1, # invoice_date
"InvoiceDueDate": 2, # invoice_due_date
"OCR": 3, # ocr_number
"Bankgiro": 4, # bankgiro
"Plusgiro": 5, # plusgiro
"Amount": 6, # amount
"supplier_organisation_number": 7, # supplier_organisation_number
# 8: payment_line (derived from OCR/Bankgiro/Amount)
"customer_number": 9, # customer_number
}
# =============================================================================
# Core Models
# =============================================================================
class AdminToken(SQLModel, table=True):
"""Admin authentication token."""
__tablename__ = "admin_tokens"
token: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
expires_at: datetime | None = Field(default=None)
class AdminDocument(SQLModel, table=True):
"""Document uploaded for labeling/annotation."""
__tablename__ = "admin_documents"
document_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
file_path: str = Field(max_length=512) # Path to stored file
page_count: int = Field(default=1)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, auto_labeling, labeled, exported
auto_label_status: str | None = Field(default=None, max_length=20)
# Auto-label status: running, completed, failed
auto_label_error: str | None = Field(default=None)
# v2: Upload source tracking
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
# When auto-label was queued
annotation_lock_until: datetime | None = Field(default=None)
# Lock for manual annotation while auto-label runs
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class AdminAnnotation(SQLModel, table=True):
"""Annotation for a document (bounding box + label)."""
__tablename__ = "admin_annotations"
annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
page_number: int = Field(default=1) # 1-indexed
class_id: int # 0-9 for invoice fields
class_name: str = Field(max_length=50) # e.g., "invoice_number"
# Bounding box (normalized 0-1 coordinates)
x_center: float
y_center: float
width: float
height: float
# Original pixel coordinates (for display)
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
# OCR extracted text (if available)
text_value: str | None = Field(default=None)
confidence: float | None = Field(default=None)
# Source: manual, auto, imported
source: str = Field(default="manual", max_length=20, index=True)
# v2: Verification fields
is_verified: bool = Field(default=False, index=True)
verified_at: datetime | None = Field(default=None)
verified_by: str | None = Field(default=None, max_length=255)
# v2: Override tracking
override_source: str | None = Field(default=None, max_length=20)
# If this annotation overrides another: 'auto' or 'imported'
original_annotation_id: UUID | None = Field(default=None)
# Reference to the annotation this overrides
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingTask(SQLModel, table=True):
"""Training/fine-tuning task."""
__tablename__ = "training_tasks"
task_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, scheduled, running, completed, failed, cancelled
task_type: str = Field(default="train", max_length=20)
# Task type: train, finetune
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
# Training configuration
config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Schedule settings
scheduled_at: datetime | None = Field(default=None)
cron_expression: str | None = Field(default=None, max_length=50)
is_recurring: bool = Field(default=False)
# Execution details
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
error_message: str | None = Field(default=None)
# Result metrics
result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
model_path: str | None = Field(default=None, max_length=512)
# v2: Document count and extracted metrics
document_count: int = Field(default=0)
# Count of documents used in training
metrics_mAP: float | None = Field(default=None, index=True)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
# Extracted metrics for easy querying
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingLog(SQLModel, table=True):
"""Training log entry."""
__tablename__ = "training_logs"
log_id: int | None = Field(default=None, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
level: str = Field(max_length=20) # INFO, WARNING, ERROR
message: str
details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# =============================================================================
# Batch Upload Models (v2)
# =============================================================================
class BatchUpload(SQLModel, table=True):
"""Batch upload of multiple documents via ZIP file."""
__tablename__ = "batch_uploads"
batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255) # ZIP filename
file_size: int
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
status: str = Field(default="processing", max_length=20, index=True)
# Status: processing, completed, partial, failed
total_files: int = Field(default=0)
processed_files: int = Field(default=0)
# Number of files processed so far
successful_files: int = Field(default=0)
failed_files: int = Field(default=0)
csv_filename: str | None = Field(default=None, max_length=255)
# CSV file used for auto-labeling
csv_row_count: int | None = Field(default=None)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
completed_at: datetime | None = Field(default=None)
class BatchUploadFile(SQLModel, table=True):
"""Individual file within a batch upload."""
__tablename__ = "batch_upload_files"
file_id: UUID = Field(default_factory=uuid4, primary_key=True)
batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
filename: str = Field(max_length=255) # PDF filename within ZIP
document_id: UUID | None = Field(default=None)
# Link to created AdminDocument (if successful)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, processing, completed, failed, skipped
error_message: str | None = Field(default=None)
annotation_count: int = Field(default=0)
# Number of annotations created for this file
csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# CSV row data for this file (if available)
created_at: datetime = Field(default_factory=datetime.utcnow)
processed_at: datetime | None = Field(default=None)
# =============================================================================
# Training Document Link (v2)
# =============================================================================
class TrainingDataset(SQLModel, table=True):
"""Training dataset containing selected documents with train/val/test splits."""
__tablename__ = "training_datasets"
dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="building", max_length=20, index=True)
# Status: building, ready, training, archived, failed
train_ratio: float = Field(default=0.8)
val_ratio: float = Field(default=0.1)
seed: int = Field(default=42)
total_documents: int = Field(default=0)
total_images: int = Field(default=0)
total_annotations: int = Field(default=0)
dataset_path: str | None = Field(default=None, max_length=512)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class DatasetDocument(SQLModel, table=True):
"""Junction table linking datasets to documents with split assignment."""
__tablename__ = "dataset_documents"
id: UUID = Field(default_factory=uuid4, primary_key=True)
dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
split: str = Field(max_length=10) # train, val, test
page_count: int = Field(default=0)
annotation_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingDocumentLink(SQLModel, table=True):
"""Junction table linking training tasks to documents."""
__tablename__ = "training_document_links"
link_id: UUID = Field(default_factory=uuid4, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of annotations at training time (includes count, verified count, etc.)
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================
class AnnotationHistory(SQLModel, table=True):
"""History of annotation changes (for override tracking)."""
__tablename__ = "annotation_history"
history_id: UUID = Field(default_factory=uuid4, primary_key=True)
annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
# Change action: created, updated, deleted, override
action: str = Field(max_length=20, index=True)
# Previous value (for updates/deletes)
previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# New value (for creates/updates)
new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Change metadata
changed_by: str | None = Field(default=None, max_length=255)
# User/token who made the change
change_reason: str | None = Field(default=None)
# Optional reason for change
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Field class mapping (same as src/cli/train.py)
FIELD_CLASSES = {
0: "invoice_number",
1: "invoice_date",
2: "invoice_due_date",
3: "ocr_number",
4: "bankgiro",
5: "plusgiro",
6: "amount",
7: "supplier_organisation_number",
8: "payment_line",
9: "customer_number",
}
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
# Read-only models for API responses
class AdminDocumentRead(SQLModel):
"""Admin document response model."""
document_id: UUID
filename: str
file_size: int
content_type: str
page_count: int
status: str
auto_label_status: str | None
auto_label_error: str | None
created_at: datetime
updated_at: datetime
class AdminAnnotationRead(SQLModel):
"""Admin annotation response model."""
annotation_id: UUID
document_id: UUID
page_number: int
class_id: int
class_name: str
x_center: float
y_center: float
width: float
height: float
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
text_value: str | None
confidence: float | None
source: str
created_at: datetime
class TrainingTaskRead(SQLModel):
"""Training task response model."""
task_id: UUID
name: str
description: str | None
status: str
task_type: str
config: dict[str, Any] | None
scheduled_at: datetime | None
is_recurring: bool
started_at: datetime | None
completed_at: datetime | None
error_message: str | None
result_metrics: dict[str, Any] | None
model_path: str | None
dataset_id: UUID | None
created_at: datetime
class TrainingDatasetRead(SQLModel):
"""Training dataset response model."""
dataset_id: UUID
name: str
description: str | None
status: str
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
created_at: datetime
updated_at: datetime
class DatasetDocumentRead(SQLModel):
"""Dataset document response model."""
id: UUID
dataset_id: UUID
document_id: UUID
split: str
page_count: int
annotation_count: int
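As a small illustration of how the mappings above fit together (a sketch, not part of the schema):
# Translate a CSV column name to its YOLO class id and canonical field name
csv_column = "InvoiceNumber"
class_id = CSV_TO_CLASS_MAPPING[csv_column]   # 0
class_name = FIELD_CLASSES[class_id]          # "invoice_number"
# Reverse lookup is also available
assert FIELD_CLASS_IDS[class_name] == class_id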

View File

@@ -0,0 +1,374 @@
"""
Async Request Database Operations
Database interface for async invoice processing requests using SQLModel.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy import func, text
from sqlmodel import Session, select
from inference.data.database import get_session_context, create_db_and_tables, close_engine
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent
logger = logging.getLogger(__name__)
# Legacy dataclasses for backward compatibility
from dataclasses import dataclass
@dataclass(frozen=True)
class ApiKeyConfig:
"""API key configuration and limits (legacy compatibility)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestDB:
"""Database interface for async processing requests using SQLModel."""
def __init__(self, connection_string: str | None = None) -> None:
# connection_string is kept for backward compatibility but ignored
# SQLModel uses the global engine from database.py
self._initialized = False
def connect(self):
"""Legacy method - returns self for compatibility."""
return self
def close(self) -> None:
"""Close database connections."""
close_engine()
def __enter__(self) -> "AsyncRequestDB":
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
pass # Sessions are managed per-operation
def create_tables(self) -> None:
"""Create async processing tables if they don't exist."""
create_db_and_tables()
self._initialized = True
# ==========================================================================
# API Key Operations
# ==========================================================================
def is_valid_api_key(self, api_key: str) -> bool:
"""Check if API key exists and is active."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
return result is not None and result.is_active is True
def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
"""Get API key configuration and limits."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
if result is None:
return None
return ApiKeyConfig(
api_key=result.api_key,
name=result.name,
is_active=result.is_active,
requests_per_minute=result.requests_per_minute,
max_concurrent_jobs=result.max_concurrent_jobs,
max_file_size_mb=result.max_file_size_mb,
)
def create_api_key(
self,
api_key: str,
name: str,
requests_per_minute: int = 10,
max_concurrent_jobs: int = 3,
max_file_size_mb: int = 50,
) -> None:
"""Create a new API key."""
with get_session_context() as session:
existing = session.get(ApiKey, api_key)
if existing:
existing.name = name
existing.requests_per_minute = requests_per_minute
existing.max_concurrent_jobs = max_concurrent_jobs
existing.max_file_size_mb = max_file_size_mb
session.add(existing)
else:
new_key = ApiKey(
api_key=api_key,
name=name,
requests_per_minute=requests_per_minute,
max_concurrent_jobs=max_concurrent_jobs,
max_file_size_mb=max_file_size_mb,
)
session.add(new_key)
def update_api_key_usage(self, api_key: str) -> None:
"""Update API key last used timestamp and increment total requests."""
with get_session_context() as session:
key = session.get(ApiKey, api_key)
if key:
key.last_used_at = datetime.utcnow()
key.total_requests += 1
session.add(key)
# ==========================================================================
# Async Request Operations
# ==========================================================================
def create_request(
self,
api_key: str,
filename: str,
file_size: int,
content_type: str,
expires_at: datetime,
request_id: str | None = None,
) -> str:
"""Create a new async request."""
with get_session_context() as session:
request = AsyncRequest(
api_key=api_key,
filename=filename,
file_size=file_size,
content_type=content_type,
expires_at=expires_at,
)
if request_id:
request.request_id = UUID(request_id)
session.add(request)
session.flush() # To get the generated ID
return str(request.request_id)
def get_request(self, request_id: str) -> AsyncRequest | None:
"""Get a single async request by ID."""
with get_session_context() as session:
result = session.get(AsyncRequest, UUID(request_id))
if result:
# Detach from session for use outside context
session.expunge(result)
return result
def get_request_by_api_key(
self,
request_id: str,
api_key: str,
) -> AsyncRequest | None:
"""Get a request only if it belongs to the given API key."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.request_id == UUID(request_id),
AsyncRequest.api_key == api_key,
)
result = session.exec(statement).first()
if result:
session.expunge(result)
return result
def update_status(
self,
request_id: str,
status: str,
error_message: str | None = None,
increment_retry: bool = False,
) -> None:
"""Update request status."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = status
if status == "processing":
request.started_at = datetime.utcnow()
if error_message is not None:
request.error_message = error_message
if increment_retry:
request.retry_count += 1
session.add(request)
def complete_request(
self,
request_id: str,
document_id: str,
result: dict[str, Any],
processing_time_ms: float,
visualization_path: str | None = None,
) -> None:
"""Mark request as completed with result."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = "completed"
request.document_id = document_id
request.result = result
request.processing_time_ms = processing_time_ms
request.visualization_path = visualization_path
request.completed_at = datetime.utcnow()
session.add(request)
def get_requests_by_api_key(
self,
api_key: str,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AsyncRequest], int]:
"""Get paginated requests for an API key."""
with get_session_context() as session:
# Count query
count_stmt = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
count_stmt = count_stmt.where(AsyncRequest.status == status)
total = session.exec(count_stmt).one()
# Fetch query
statement = select(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
statement = statement.where(AsyncRequest.status == status)
statement = statement.order_by(AsyncRequest.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
# Detach results from session
for r in results:
session.expunge(r)
return list(results), total
def count_active_jobs(self, api_key: str) -> int:
"""Count active (pending + processing) jobs for an API key."""
with get_session_context() as session:
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key,
AsyncRequest.status.in_(["pending", "processing"]),
)
return session.exec(statement).one()
def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
"""Get pending requests ordered by creation time."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.status == "pending"
).order_by(AsyncRequest.created_at).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_queue_position(self, request_id: str) -> int | None:
"""Get position of a request in the pending queue."""
with get_session_context() as session:
# Get the request's created_at
request = session.get(AsyncRequest, UUID(request_id))
if not request:
return None
# Count pending requests created before this one
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.status == "pending",
AsyncRequest.created_at < request.created_at,
)
count = session.exec(statement).one()
return count + 1 # 1-based position
# ==========================================================================
# Rate Limit Operations
# ==========================================================================
def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
"""Record a rate limit event."""
with get_session_context() as session:
event = RateLimitEvent(
api_key=api_key,
event_type=event_type,
)
session.add(event)
def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
"""Count requests in the last N seconds."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(seconds=seconds)
statement = select(func.count()).select_from(RateLimitEvent).where(
RateLimitEvent.api_key == api_key,
RateLimitEvent.event_type == "request",
RateLimitEvent.created_at > cutoff,
)
return session.exec(statement).one()
# ==========================================================================
# Cleanup Operations
# ==========================================================================
def delete_expired_requests(self) -> int:
"""Delete requests that have expired. Returns count of deleted rows."""
with get_session_context() as session:
now = datetime.utcnow()
statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
expired = session.exec(statement).all()
count = len(expired)
for request in expired:
session.delete(request)
logger.info(f"Deleted {count} expired async requests")
return count
def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
"""Delete rate limit events older than N hours."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(hours=hours)
statement = select(RateLimitEvent).where(
RateLimitEvent.created_at < cutoff
)
old_events = session.exec(statement).all()
count = len(old_events)
for event in old_events:
session.delete(event)
return count
def reset_stale_processing_requests(
self,
stale_minutes: int = 10,
max_retries: int = 3,
) -> int:
"""
Reset requests stuck in 'processing' status.
Requests that have been processing for more than stale_minutes
are considered stale. They are either reset to 'pending' (if under
max_retries) or set to 'failed'.
"""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
reset_count = 0
# Find stale processing requests
statement = select(AsyncRequest).where(
AsyncRequest.status == "processing",
AsyncRequest.started_at < cutoff,
)
stale_requests = session.exec(statement).all()
for request in stale_requests:
if request.retry_count < max_retries:
request.status = "pending"
request.started_at = None
else:
request.status = "failed"
request.error_message = "Processing timeout after max retries"
session.add(request)
reset_count += 1
if reset_count > 0:
logger.warning(f"Reset {reset_count} stale processing requests")
return reset_count
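A hedged sketch of how a request handler might combine these operations for admission control (the API key value and event type string are illustrative):
db = AsyncRequestDB()
cfg = db.get_api_key_config("demo-key")  # hypothetical key
if cfg is None or not cfg.is_active:
    raise PermissionError("unknown or inactive API key")
if db.count_recent_requests(cfg.api_key, seconds=60) >= cfg.requests_per_minute:
    db.record_rate_limit_event(cfg.api_key, "rate_limited")
    raise RuntimeError("rate limit exceeded")
if db.count_active_jobs(cfg.api_key) >= cfg.max_concurrent_jobs:
    raise RuntimeError("too many concurrent jobs")
db.update_api_key_usage(cfg.api_key)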

View File

@@ -0,0 +1,102 @@
"""
Database Engine and Session Management
Provides SQLModel database engine and session handling.
"""
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine
from shared.config import get_db_connection_string
logger = logging.getLogger(__name__)
# Global engine instance
_engine = None
def get_engine():
"""Get or create the database engine."""
global _engine
if _engine is None:
connection_string = get_db_connection_string()
# Convert psycopg2 format to SQLAlchemy format
if connection_string.startswith("postgresql://"):
# Already in correct format
pass
elif "host=" in connection_string:
# Convert DSN format to URL format
parts = dict(item.split("=") for item in connection_string.split())
connection_string = (
f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
f"/{parts.get('dbname', 'docmaster')}"
)
_engine = create_engine(
connection_string,
echo=False, # Set to True for SQL debugging
pool_pre_ping=True, # Verify connections before use
pool_size=5,
max_overflow=10,
)
return _engine
def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
from inference.data.admin_models import ( # noqa: F401
AdminToken,
AdminDocument,
AdminAnnotation,
TrainingTask,
TrainingLog,
)
engine = get_engine()
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
def get_session() -> Session:
"""Get a new database session."""
engine = get_engine()
return Session(engine)
@contextmanager
def get_session_context() -> Generator[Session, None, None]:
"""Context manager for database sessions with auto-commit/rollback."""
session = get_session()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def close_engine() -> None:
"""Close the database engine and release connections."""
global _engine
if _engine is not None:
_engine.dispose()
_engine = None
logger.info("Database engine closed")
def execute_raw_sql(sql: str) -> None:
"""Execute raw SQL (for migrations)."""
engine = get_engine()
with engine.connect() as conn:
conn.execute(text(sql))
conn.commit()
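A brief usage sketch, assuming the connection settings from shared.config point at a reachable PostgreSQL instance:
from inference.data.database import create_db_and_tables, get_session_context
from inference.data.models import ApiKey
create_db_and_tables()
with get_session_context() as session:
    session.add(ApiKey(api_key="demo-key", name="demo"))  # hypothetical key
# commit, rollback, and close are handled by the context manager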

View File

@@ -0,0 +1,95 @@
"""
SQLModel Database Models
Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
class ApiKey(SQLModel, table=True):
"""API key configuration and limits."""
__tablename__ = "api_keys"
api_key: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
requests_per_minute: int = Field(default=10)
max_concurrent_jobs: int = Field(default=3)
max_file_size_mb: int = Field(default=50)
total_requests: int = Field(default=0)
total_processed: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
class AsyncRequest(SQLModel, table=True):
"""Async request record."""
__tablename__ = "async_requests"
request_id: UUID = Field(default_factory=uuid4, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
status: str = Field(default="pending", max_length=20, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
document_id: str | None = Field(default=None, max_length=100)
error_message: str | None = Field(default=None)
retry_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
expires_at: datetime = Field(index=True)
result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
processing_time_ms: float | None = Field(default=None)
visualization_path: str | None = Field(default=None, max_length=255)
class RateLimitEvent(SQLModel, table=True):
"""Rate limit event record."""
__tablename__ = "rate_limit_events"
id: int | None = Field(default=None, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
event_type: str = Field(max_length=50)
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Read-only models for responses (without table=True)
class ApiKeyRead(SQLModel):
"""API key response model (read-only)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestRead(SQLModel):
"""Async request response model (read-only)."""
request_id: UUID
api_key: str
status: str
filename: str
file_size: int
content_type: str
document_id: str | None
error_message: str | None
retry_count: int
created_at: datetime
started_at: datetime | None
completed_at: datetime | None
expires_at: datetime
result: dict[str, Any] | None
processing_time_ms: float | None
visualization_path: str | None

View File

@@ -0,0 +1,5 @@
from .pipeline import InferencePipeline, InferenceResult
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor
__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']

View File

@@ -0,0 +1,101 @@
"""
Inference Configuration Constants
Centralized configuration values for the inference pipeline.
Extracted from hardcoded values across multiple modules for easier maintenance.
"""
# ============================================================================
# Detection & Model Configuration
# ============================================================================
# YOLO Detection
DEFAULT_CONFIDENCE_THRESHOLD = 0.5 # Default confidence threshold for YOLO detection
DEFAULT_IOU_THRESHOLD = 0.45 # Default IoU threshold for NMS (Non-Maximum Suppression)
# ============================================================================
# Image Processing Configuration
# ============================================================================
# DPI (Dots Per Inch) for PDF rendering
DEFAULT_DPI = 300 # Standard DPI for PDF to image conversion
DPI_TO_POINTS_SCALE = 72 # PDF points per inch (used for bbox conversion)
# ============================================================================
# Customer Number Parser Configuration
# ============================================================================
# Pattern confidence scores (higher = more confident)
CUSTOMER_NUMBER_CONFIDENCE = {
'labeled': 0.98, # Explicit label (e.g., "Kundnummer: ABC 123-X")
'dash_format': 0.95, # Standard format with dash (e.g., "JTY 576-3")
'no_dash': 0.90, # Format without dash (e.g., "Dwq 211X")
'compact': 0.75, # Compact format (e.g., "JTY5763")
'generic_base': 0.5, # Base score for generic alphanumeric pattern
}
# Bonus scores for generic pattern matching
CUSTOMER_NUMBER_BONUS = {
'has_dash': 0.2, # Bonus if contains dash
'typical_format': 0.25, # Bonus for format XXX NNN-X
'medium_length': 0.1, # Bonus for length 6-12 characters
}
# Customer number length constraints
CUSTOMER_NUMBER_LENGTH = {
'min': 6, # Minimum length for medium length bonus
'max': 12, # Maximum length for medium length bonus
}
# ============================================================================
# Field Extraction Confidence Scores
# ============================================================================
# Confidence multipliers and base scores
FIELD_CONFIDENCE = {
'pdf_text': 1.0, # PDF text extraction (always accurate)
'payment_line_high': 0.95, # Payment line parsed successfully
'regex_fallback': 0.5, # Regex-based fallback extraction
'ocr_penalty': 0.5, # Penalty multiplier when OCR fails
}
# ============================================================================
# Payment Line Validation
# ============================================================================
# Account number length thresholds for type detection
ACCOUNT_TYPE_THRESHOLD = {
'bankgiro_min_length': 7, # Minimum digits for Bankgiro (7-8 digits)
'plusgiro_max_length': 6, # Maximum digits for Plusgiro (typically fewer)
}
# ============================================================================
# OCR Configuration
# ============================================================================
# Minimum OCR reference number length
MIN_OCR_LENGTH = 5 # Minimum length for valid OCR number
# ============================================================================
# Pattern Matching
# ============================================================================
# Swedish postal code pattern (to exclude from customer numbers)
SWEDISH_POSTAL_CODE_PATTERN = r'^SE\s+\d{3}\s*\d{2}'
# ============================================================================
# Usage Notes
# ============================================================================
"""
These constants can be overridden at runtime by passing parameters to
constructors or methods. The values here serve as sensible defaults
based on Swedish invoice processing requirements.
Example:
from inference.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD
detector = YOLODetector(
model_path="model.pt",
confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD # or custom value
)
"""

View File

@@ -0,0 +1,390 @@
"""
Swedish Customer Number Parser
Handles extraction and normalization of Swedish customer numbers.
Uses Strategy Pattern with multiple matching patterns.
Common Swedish customer number formats:
- JTY 576-3
- EMM 256-6
- DWQ 211-X
- FFL 019N
"""
import re
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List
from shared.exceptions import CustomerNumberParseError
@dataclass
class CustomerNumberMatch:
"""Customer number match result."""
value: str
"""The normalized customer number"""
pattern_name: str
"""Name of the pattern that matched"""
confidence: float
"""Confidence score (0.0 to 1.0)"""
raw_text: str
"""Original text that was matched"""
position: int = 0
"""Position in text where match was found"""
class CustomerNumberPattern(ABC):
"""Abstract base for customer number patterns."""
@abstractmethod
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""
Try to match pattern in text.
Args:
text: Text to search for customer number
Returns:
CustomerNumberMatch if found, None otherwise
"""
pass
@abstractmethod
def format(self, match: re.Match) -> str:
"""
Format matched groups to standard format.
Args:
match: Regex match object
Returns:
Formatted customer number string
"""
pass
class DashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123-X (with dash)
Examples: JTY 576-3, EMM 256-6, DWQ 211-X
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with dash format."""
match = self.PATTERN.search(text)
if not match:
return None
# Check if it's not a postal code
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="DashFormat",
confidence=0.95,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
# SE 106 43, SE10643, etc.
return bool(
text.upper().startswith('SE ') and
re.match(r'^SE\s+\d{3}\s*\d{2}', text, re.IGNORECASE)
)
class NoDashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123X (no dash)
Examples: Dwq 211X, FFL 019N
Converts to: DWQ 211-X, FFL 019-N
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number without dash."""
match = self.PATTERN.search(text)
if not match:
return None
# Exclude postal codes
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="NoDashFormat",
confidence=0.90,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format (add dash)."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))
class CompactFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC123X (compact, no spaces)
Examples: JTY5763, FFL019N
"""
PATTERN = re.compile(r'\b([A-Z]{2,4})(\d{3,6})([A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match compact customer number format."""
upper_text = text.upper()
match = self.PATTERN.search(upper_text)
if not match:
return None
# Filter out SE postal codes
if match.group(1) == 'SE':
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="CompactFormat",
confidence=0.75,
raw_text=match.group(0),
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to ABC123X or ABC123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
if suffix:
return f"{prefix} {number}-{suffix}"
else:
return f"{prefix}{number}"
class GenericAlphanumericPattern(CustomerNumberPattern):
"""
Generic pattern: Letters + numbers + optional dash/letter
Examples: EMM 256-6, ABC 123, FFL 019
"""
PATTERN = re.compile(r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]?\d{0,2}[A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match generic alphanumeric pattern."""
upper_text = text.upper()
all_matches = []
for match in self.PATTERN.finditer(upper_text):
matched_text = match.group(1)
# Filter out pure numbers
if re.match(r'^\d+$', matched_text):
continue
# Filter out Swedish postal codes
if re.match(r'^SE[\s\-]*\d', matched_text):
continue
# Filter out single letter + digit + space + digit (V4 2)
if re.match(r'^[A-Z]\d\s+\d$', matched_text):
continue
# Calculate confidence based on characteristics
confidence = self._calculate_confidence(matched_text)
all_matches.append((confidence, matched_text, match.start()))
if all_matches:
# Return highest confidence match
best = max(all_matches, key=lambda x: x[0])
return CustomerNumberMatch(
value=best[1].strip(),
pattern_name="GenericAlphanumeric",
confidence=best[0],
raw_text=best[1],
position=best[2]
)
return None
def format(self, match: re.Match) -> str:
"""Return matched text as-is (already uppercase)."""
return match.group(1).strip()
def _calculate_confidence(self, text: str) -> float:
"""Calculate confidence score based on text characteristics."""
# Require letters AND digits
has_letters = bool(re.search(r'[A-Z]', text, re.IGNORECASE))
has_digits = bool(re.search(r'\d', text))
if not (has_letters and has_digits):
return 0.0 # Not a valid customer number
score = 0.5 # Base score
# Bonus for containing dash
if '-' in text:
score += 0.2
# Bonus for typical format XXX NNN-X
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', text):
score += 0.25
# Bonus for medium length
if 6 <= len(text) <= 12:
score += 0.1
return min(score, 1.0)
class LabeledPattern(CustomerNumberPattern):
"""
Pattern: Explicit label + customer number
Examples:
- "Kundnummer: JTY 576-3"
- "Customer No: EMM 256-6"
"""
PATTERN = re.compile(
r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)'
r'\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
re.IGNORECASE
)
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with explicit label."""
match = self.PATTERN.search(text)
if not match:
return None
extracted = match.group(1).strip()
# Remove trailing punctuation
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
if extracted and len(extracted) >= 2:
return CustomerNumberMatch(
value=extracted.upper(),
pattern_name="Labeled",
confidence=0.98, # Very high confidence when labeled
raw_text=match.group(0),
position=match.start()
)
return None
def format(self, match: re.Match) -> str:
"""Return matched customer number."""
extracted = match.group(1).strip()
return re.sub(r'[\s\.\,\:]+$', '', extracted).upper()
class CustomerNumberParser:
"""Parser for Swedish customer numbers."""
def __init__(self):
"""Initialize parser with patterns ordered by specificity."""
self.patterns: List[CustomerNumberPattern] = [
LabeledPattern(), # Highest priority - explicit label
DashFormatPattern(), # Standard format with dash
NoDashFormatPattern(), # Standard format without dash
CompactFormatPattern(), # Compact format
GenericAlphanumericPattern(), # Fallback generic pattern
]
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
"""
Parse customer number from text.
Args:
text: Text to search for customer number
Returns:
Tuple of (customer_number, is_valid, error_message)
"""
if not text or not text.strip():
return None, False, "Empty text"
text = text.strip()
# Try each pattern
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# No matches
if not all_matches:
return None, False, "No customer number found"
# Return highest confidence match
best_match = max(all_matches, key=lambda m: (m.confidence, m.position))
self.logger.debug(
f"Customer number matched: {best_match.value} "
f"(pattern: {best_match.pattern_name}, confidence: {best_match.confidence:.2f})"
)
return best_match.value, True, None
def parse_all(self, text: str) -> List[CustomerNumberMatch]:
"""
Find all customer numbers in text.
Useful for cases with multiple potential matches.
Args:
text: Text to search
Returns:
List of CustomerNumberMatch sorted by confidence (descending)
"""
if not text or not text.strip():
return []
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# Sort by confidence (highest first), then by position (later first)
return sorted(all_matches, key=lambda m: (m.confidence, m.position), reverse=True)
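A short sketch of the two entry points, using one of the labeled examples from the module docstring:
parser = CustomerNumberParser()
value, is_valid, error = parser.parse("Kundnummer: JTY 576-3")
# value == "JTY 576-3", is_valid is True, error is None
for m in parser.parse_all("Kundnummer: JTY 576-3"):
    print(m.pattern_name, m.value, m.confidence)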

File diff suppressed because it is too large

View File

@@ -0,0 +1,261 @@
"""
Swedish Payment Line Parser
Handles parsing and validation of Swedish machine-readable payment lines.
Unifies payment line parsing logic that was previously duplicated across multiple modules.
Standard Swedish payment line format:
# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example:
# 94228110015950070 # 15658 00 8 > 48666036#14#
This parser handles common OCR errors:
- Spaces in numbers: "12 0 0" → "1200"
- Missing symbols: the ">" separator may be absent
- Spaces in check digits: "#41 #" → "#41#"
"""
import re
import logging
from dataclasses import dataclass
from typing import Optional
from shared.exceptions import PaymentLineParseError
@dataclass
class PaymentLineData:
"""Parsed payment line data."""
ocr_number: str
"""OCR reference number (payment reference)"""
amount: Optional[str] = None
"""Amount in format KRONOR.ÖRE (e.g., '1200.00'), None if not present"""
account_number: Optional[str] = None
"""Bankgiro or Plusgiro account number"""
record_type: Optional[str] = None
"""Record type digit (usually '5' or '8' or '9')"""
check_digits: Optional[str] = None
"""Check digits for account validation"""
raw_text: str = ""
"""Original raw text that was parsed"""
is_valid: bool = True
"""Whether parsing was successful"""
error: Optional[str] = None
"""Error message if parsing failed"""
parse_method: str = "unknown"
"""Which parsing pattern was used (for debugging)"""
class PaymentLineParser:
"""Parser for Swedish payment lines with OCR error handling."""
# Pattern with amount: # OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#
FULL_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Pattern without amount: # OCR # > ACCOUNT#CHECK#
NO_AMOUNT_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Alternative pattern: look for OCR > ACCOUNT# pattern
ALT_PATTERN = re.compile(
r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Account only pattern: > ACCOUNT#CHECK#
ACCOUNT_ONLY_PATTERN = re.compile(
r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
def __init__(self):
"""Initialize parser with logger."""
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> PaymentLineData:
"""
Parse payment line text.
Handles common OCR errors:
- Spaces in numbers: "12 0 0" → "1200"
- Missing symbols: the ">" separator may be absent
- Spaces in check digits: "#41 #" → "#41#"
Args:
text: Raw payment line text from OCR
Returns:
PaymentLineData with parsed fields or error information
"""
if not text or not text.strip():
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="Empty payment line text",
parse_method="none"
)
text = text.strip()
# Try full pattern with amount
match = self.FULL_PATTERN.search(text)
if match:
return self._parse_full_match(match, text)
# Try pattern without amount
match = self.NO_AMOUNT_PATTERN.search(text)
if match:
return self._parse_no_amount_match(match, text)
# Try alternative pattern
match = self.ALT_PATTERN.search(text)
if match:
return self._parse_alt_match(match, text)
# Try account only pattern
match = self.ACCOUNT_ONLY_PATTERN.search(text)
if match:
return self._parse_account_only_match(match, text)
# No match - return error
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="No valid payment line format found",
parse_method="none"
)
def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse full pattern match (with amount)."""
ocr = self._clean_digits(match.group(1))
kronor = self._clean_digits(match.group(2))
ore = match.group(3)
record_type = match.group(4)
account = self._clean_digits(match.group(5))
check_digits = match.group(6)
amount = f"{kronor}.{ore}"
return PaymentLineData(
ocr_number=ocr,
amount=amount,
account_number=account,
record_type=record_type,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="full"
)
def _parse_no_amount_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse pattern match without amount."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="no_amount"
)
def _parse_alt_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse alternative pattern match."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="alternative"
)
def _parse_account_only_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse account-only pattern match."""
account = self._clean_digits(match.group(1))
check_digits = match.group(2)
return PaymentLineData(
ocr_number="",
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error="Partial payment line (account only)",
parse_method="account_only"
)
def _clean_digits(self, text: str) -> str:
"""Remove spaces from digit string (OCR error correction)."""
return text.replace(' ', '')
def format_machine_readable(self, data: PaymentLineData) -> str:
"""
Format parsed data back to machine-readable format.
Returns:
Formatted string in standard Swedish payment line format
"""
if not data.is_valid:
return data.raw_text
# Full format with amount
if data.amount and data.record_type:
kronor, ore = data.amount.split('.')
return (
f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
f"{data.account_number}#{data.check_digits}#"
)
# Format without amount
if data.ocr_number and data.account_number:
return f"# {data.ocr_number} # > {data.account_number}#{data.check_digits}#"
# Account only
if data.account_number:
return f"> {data.account_number}#{data.check_digits}#"
# Fallback
return data.raw_text
def format_for_field_extractor(self, data: PaymentLineData) -> tuple[Optional[str], bool, Optional[str]]:
"""
Format parsed data for FieldExtractor compatibility.
Returns:
Tuple of (formatted_text, is_valid, error_message) matching FieldExtractor's API
"""
if not data.is_valid:
return None, False, data.error
formatted = self.format_machine_readable(data)
return formatted, True, data.error
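A minimal usage sketch with the PaymentLineParser above in scope; the input literal is the sample payment line quoted in the pipeline docstring and is used here purely for illustration.
# Sketch only: the sample line below is illustrative; real OCR output will vary in spacing.
parser = PaymentLineParser()
data = parser.parse("# 11000770600242 # 1200 00 5 > 3082963#41#")
if data.is_valid:
    print(data.ocr_number)       # 11000770600242
    print(data.amount)           # 1200.00
    print(data.account_number)   # 3082963
    print(parser.format_machine_readable(data))
else:
    print("Parse failed:", data.error)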

View File

@@ -0,0 +1,498 @@
"""
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import time
import re
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
@dataclass
class CrossValidationResult:
"""Result of cross-validation between payment_line and other fields."""
is_valid: bool = False
ocr_match: bool | None = None # None if not comparable
amount_match: bool | None = None
bankgiro_match: bool | None = None
plusgiro_match: bool | None = None
payment_line_ocr: str | None = None
payment_line_amount: str | None = None
payment_line_account: str | None = None
payment_line_account_type: str | None = None # 'bankgiro' or 'plusgiro'
details: list[str] = field(default_factory=list)
@dataclass
class InferenceResult:
"""Result of invoice processing."""
document_id: str | None = None
success: bool = False
fields: dict[str, Any] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict) # Field bboxes in pixels
raw_detections: list[Detection] = field(default_factory=list)
extracted_fields: list[ExtractedField] = field(default_factory=list)
processing_time_ms: float = 0.0
errors: list[str] = field(default_factory=list)
fallback_used: bool = False
cross_validation: CrossValidationResult | None = None
def to_json(self) -> dict:
"""Convert to JSON-serializable dictionary."""
result = {
'DocumentId': self.document_id,
'InvoiceNumber': self.fields.get('InvoiceNumber'),
'InvoiceDate': self.fields.get('InvoiceDate'),
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
'OCR': self.fields.get('OCR'),
'Bankgiro': self.fields.get('Bankgiro'),
'Plusgiro': self.fields.get('Plusgiro'),
'Amount': self.fields.get('Amount'),
'supplier_org_number': self.fields.get('supplier_org_number'),
'customer_number': self.fields.get('customer_number'),
'payment_line': self.fields.get('payment_line'),
'confidence': self.confidence,
'success': self.success,
'fallback_used': self.fallback_used
}
# Add bboxes if present
if self.bboxes:
result['bboxes'] = {k: list(v) for k, v in self.bboxes.items()}
# Add cross-validation results if present
if self.cross_validation:
result['cross_validation'] = {
'is_valid': self.cross_validation.is_valid,
'ocr_match': self.cross_validation.ocr_match,
'amount_match': self.cross_validation.amount_match,
'bankgiro_match': self.cross_validation.bankgiro_match,
'plusgiro_match': self.cross_validation.plusgiro_match,
'payment_line_ocr': self.cross_validation.payment_line_ocr,
'payment_line_amount': self.cross_validation.payment_line_amount,
'payment_line_account': self.cross_validation.payment_line_account,
'payment_line_account_type': self.cross_validation.payment_line_account_type,
'details': self.cross_validation.details,
}
return result
def get_field(self, field_name: str) -> tuple[Any, float]:
"""Get field value and confidence."""
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
class InferencePipeline:
"""
Complete inference pipeline for invoice data extraction.
Pipeline flow:
1. PDF -> Image rendering
2. YOLO detection of field regions
3. OCR extraction from detected regions
4. Field normalization and validation
5. Fallback to full-page OCR if YOLO fails
"""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
ocr_lang: str = 'en',
use_gpu: bool = False,
dpi: int = 300,
enable_fallback: bool = True
):
"""
Initialize inference pipeline.
Args:
model_path: Path to trained YOLO model
confidence_threshold: Detection confidence threshold
ocr_lang: Language for OCR
use_gpu: Whether to use GPU
dpi: Resolution for PDF rendering
enable_fallback: Enable fallback to full-page OCR
"""
self.detector = YOLODetector(
model_path,
confidence_threshold=confidence_threshold,
device='cuda' if use_gpu else 'cpu'
)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi
self.enable_fallback = enable_fallback
def process_pdf(
self,
pdf_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a PDF and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(pdf_path).stem
)
try:
all_detections = []
all_extracted = []
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
# Convert to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Run YOLO detection
detections = self.detector.detect(image_array, page_no=page_no)
all_detections.extend(detections)
# Extract fields from detections
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
all_extracted.append(extracted)
result.raw_detections = all_detections
result.extracted_fields = all_extracted
# Merge extracted fields (prefer highest confidence)
self._merge_fields(result)
# Fallback if key fields are missing
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field."""
field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields:
if not extracted.is_valid or not extracted.normalized_value:
continue
if extracted.field_name not in field_candidates:
field_candidates[extracted.field_name] = []
field_candidates[extracted.field_name].append(extracted)
# Select best candidate for each field
for field_name, candidates in field_candidates.items():
best = max(candidates, key=lambda x: x.confidence)
result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence
# Store bbox for each field (useful for payment_line and other fields)
result.bboxes[field_name] = best.bbox
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
Returns: (ocr, amount, account) tuple
"""
parsed = self.payment_line_parser.parse(payment_line)
if not parsed.is_valid:
return None, None, None
return parsed.ocr_number, parsed.amount, parsed.account_number
def _cross_validate_payment_line(self, result: InferenceResult) -> None:
"""
Cross-validate payment_line data against other detected fields.
Payment line values take PRIORITY over individually detected fields.
Swedish payment line (Betalningsrad) contains:
- OCR reference number
- Amount (kronor and öre)
- Bankgiro or Plusgiro account number
This method:
1. Parses payment_line to extract OCR, Amount, Account
2. Compares with separately detected fields for validation
3. OVERWRITES detected fields with payment_line values (payment_line is authoritative)
"""
payment_line = result.fields.get('payment_line')
if not payment_line:
return
cv = CrossValidationResult()
cv.details = []
# Parse machine-readable payment line format
ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
cv.payment_line_ocr = ocr
cv.payment_line_amount = amount
# Determine account type based on digit count
if account:
# Bankgiro: 7-8 digits, Plusgiro: typically fewer
if len(account) >= 7:
cv.payment_line_account_type = 'bankgiro'
# Format: XXX-XXXX or XXXX-XXXX
if len(account) == 7:
cv.payment_line_account = f"{account[:3]}-{account[3:]}"
else:
cv.payment_line_account = f"{account[:4]}-{account[4:]}"
else:
cv.payment_line_account_type = 'plusgiro'
# Format: XXXXXXX-X
cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
# Cross-validate and OVERRIDE with payment_line values
# OCR: payment_line takes priority
detected_ocr = result.fields.get('OCR')
if cv.payment_line_ocr:
pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
if detected_ocr:
detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
cv.ocr_match = pl_ocr_digits == detected_ocr_digits
if cv.ocr_match:
cv.details.append(f"OCR match: {cv.payment_line_ocr}")
else:
cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
else:
cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
# OVERRIDE: use payment_line OCR
result.fields['OCR'] = cv.payment_line_ocr
result.confidence['OCR'] = 0.95 # High confidence for payment_line
# Amount: payment_line takes priority
detected_amount = result.fields.get('Amount')
if cv.payment_line_amount:
if detected_amount:
pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
det_amount = self._normalize_amount_for_compare(str(detected_amount))
cv.amount_match = pl_amount == det_amount
if cv.amount_match:
cv.details.append(f"Amount match: {cv.payment_line_amount}")
else:
cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
else:
cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
# OVERRIDE: use payment_line Amount
result.fields['Amount'] = cv.payment_line_amount
result.confidence['Amount'] = 0.95
# Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
detected_bankgiro = result.fields.get('Bankgiro')
if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
if detected_bankgiro:
det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
cv.bankgiro_match = pl_bg_digits == det_bg_digits
if cv.bankgiro_match:
cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
else:
cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
# Do NOT override - keep detected value
# Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
detected_plusgiro = result.fields.get('Plusgiro')
if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
if detected_plusgiro:
det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
cv.plusgiro_match = pl_pg_digits == det_pg_digits
if cv.plusgiro_match:
cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
else:
cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
# Do NOT override - keep detected value
# Determine overall validity
# Note: payment_line only contains ONE account (either BG or PG), so when an invoice
# has both accounts, the other one cannot be matched - this is expected and OK.
# Only count the account type that payment_line actually has.
matches = [cv.ocr_match, cv.amount_match]
# Only include account match if payment_line has that account type
if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
matches.append(cv.bankgiro_match)
elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
matches.append(cv.plusgiro_match)
valid_matches = [m for m in matches if m is not None]
if valid_matches:
match_count = sum(1 for m in valid_matches if m)
cv.is_valid = match_count >= min(2, len(valid_matches))
cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
else:
# No comparison possible
cv.is_valid = True
cv.details.append("No comparison available from payment_line")
result.cross_validation = cv
def _normalize_amount_for_compare(self, amount: str) -> float | None:
"""Normalize amount string to float for comparison."""
try:
# Remove spaces, convert comma to dot
cleaned = amount.replace(' ', '').replace(',', '.')
# Handle Swedish format with space as thousands separator
cleaned = re.sub(r'(\d)\s+(\d)', r'\1\2', cleaned)
return round(float(cleaned), 2)
except (ValueError, AttributeError):
return None
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
# Check for key fields
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
missing = sum(1 for f in key_fields if f not in result.fields)
return missing >= 2 # Fallback if 2+ key fields missing
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
from shared.pdf.renderer import render_pdf_to_images
from shared.ocr import OCREngine
from PIL import Image
import io
import numpy as np
result.fallback_used = True
ocr_engine = OCREngine()
try:
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Full page OCR
tokens = ocr_engine.extract_from_image(image_array, page_no)
full_text = ' '.join(t.text for t in tokens)
# Try to extract missing fields with regex patterns
self._extract_with_patterns(full_text, result)
except Exception as e:
result.errors.append(f"Fallback OCR error: {e}")
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
],
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
}
for field_name, field_patterns in patterns.items():
if field_name in result.fields:
continue
for pattern in field_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
value = match.group(1).strip()
# Normalize the value
if field_name == 'Amount':
value = value.replace(' ', '').replace(',', '.')
try:
value = f"{float(value):.2f}"
except ValueError:
continue
elif field_name == 'Bankgiro':
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex
break
def process_image(
self,
image_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a single image (for pre-rendered pages).
Args:
image_path: Path to image file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from PIL import Image
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(image_path).stem
)
try:
image = Image.open(image_path)
image_array = np.array(image)
# Run detection
detections = self.detector.detect(image_array, page_no=0)
result.raw_detections = detections
# Extract fields
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
result.extracted_fields.append(extracted)
# Merge fields
self._merge_fields(result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
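A hedged end-to-end sketch using the InferencePipeline defined above; the model and PDF paths are placeholders, not files that are assumed to exist in this repository.
# Sketch only: both paths below are hypothetical placeholders.
import json

pipeline = InferencePipeline(
    model_path="models/invoice_fields.pt",   # placeholder for a trained YOLO checkpoint
    confidence_threshold=0.5,
    use_gpu=False,                           # set True only if a CUDA device is available
)
result = pipeline.process_pdf("samples/invoice.pdf", document_id="demo-invoice")
print(json.dumps(result.to_json(), indent=2, ensure_ascii=False))

# When a payment_line region was detected, cross-validation notes are attached.
if result.cross_validation:
    for note in result.cross_validation.details:
        print(note)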

View File

@@ -0,0 +1,210 @@
"""
YOLO Detection Module
Runs YOLO model inference for field detection.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
@dataclass
class Detection:
"""Represents a single YOLO detection."""
class_id: int
class_name: str
confidence: float
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels
page_no: int = 0
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
@property
def width(self) -> float:
return self.x1 - self.x0
@property
def height(self) -> float:
return self.y1 - self.y0
def get_padded_bbox(
self,
padding: float = 0.1,
image_width: float | None = None,
image_height: float | None = None
) -> tuple[float, float, float, float]:
"""Get bbox with padding for OCR extraction."""
pad_x = self.width * padding
pad_y = self.height * padding
x0 = self.x0 - pad_x
y0 = self.y0 - pad_y
x1 = self.x1 + pad_x
y1 = self.y1 + pad_y
if image_width:
x0 = max(0, x0)
x1 = min(image_width, x1)
if image_height:
y0 = max(0, y0)
y1 = min(image_height, y1)
return (x0, y0, x1, y1)
# Class names (must match training configuration)
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number', # Matches training class name
'customer_number',
'payment_line', # Machine code payment line at bottom of invoice
]
# Mapping from class name to field name
CLASS_TO_FIELD = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'ocr_number': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'amount': 'Amount',
'supplier_org_number': 'supplier_org_number',
'customer_number': 'customer_number',
'payment_line': 'payment_line',
}
class YOLODetector:
"""YOLO model wrapper for field detection."""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
iou_threshold: float = 0.45,
device: str = 'auto'
):
"""
Initialize YOLO detector.
Args:
model_path: Path to trained YOLO model (.pt file)
confidence_threshold: Minimum confidence for detections
iou_threshold: IOU threshold for NMS
device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
"""
from ultralytics import YOLO
self.model = YOLO(model_path)
self.confidence_threshold = confidence_threshold
self.iou_threshold = iou_threshold
self.device = device
def detect(
self,
image: str | Path | np.ndarray,
page_no: int = 0
) -> list[Detection]:
"""
Run detection on an image.
Args:
image: Image path or numpy array
page_no: Page number for reference
Returns:
List of Detection objects
"""
results = self.model.predict(
source=image,
conf=self.confidence_threshold,
iou=self.iou_threshold,
device=self.device,
verbose=False
)
detections = []
for result in results:
boxes = result.boxes
if boxes is None:
continue
for i in range(len(boxes)):
class_id = int(boxes.cls[i])
confidence = float(boxes.conf[i])
bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1]
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"
detections.append(Detection(
class_id=class_id,
class_name=class_name,
confidence=confidence,
bbox=tuple(bbox),
page_no=page_no
))
return detections
def detect_pdf(
self,
pdf_path: str | Path,
dpi: int = 300
) -> dict[int, list[Detection]]:
"""
Run detection on all pages of a PDF.
Args:
pdf_path: Path to PDF file
dpi: Resolution for rendering
Returns:
Dict mapping page number to list of detections
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
results = {}
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
# Convert bytes to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
detections = self.detect(image_array, page_no=page_no)
results[page_no] = detections
return results
def get_field_name(self, class_name: str) -> str:
"""Convert class name to field name."""
return CLASS_TO_FIELD.get(class_name, class_name)
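A short sketch of running the detector in isolation; the checkpoint and image paths are placeholders.
# Sketch only: paths are placeholders for a trained checkpoint and a rendered page image.
detector = YOLODetector("models/invoice_fields.pt", confidence_threshold=0.5, device="cpu")
for det in detector.detect("samples/invoice_page_1.png"):
    x0, y0, x1, y1 = det.bbox
    print(
        f"{detector.get_field_name(det.class_name)}: "
        f"conf={det.confidence:.2f}, bbox=({x0:.0f}, {y0:.0f}, {x1:.0f}, {y1:.0f})"
    )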

View File

@@ -0,0 +1,7 @@
"""
Cross-validation module for verifying field extraction using LLM.
"""
from .llm_validator import LLMValidator
__all__ = ['LLMValidator']

View File

@@ -0,0 +1,748 @@
"""
LLM-based cross-validation for invoice field extraction.
Uses a vision LLM to extract fields from invoice PDFs and compare with
the autolabel results to identify potential errors.
"""
import json
import base64
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, asdict
from datetime import datetime
import psycopg2
from psycopg2.extras import execute_values
from shared.config import DEFAULT_DPI
@dataclass
class LLMExtractionResult:
"""Result of LLM field extraction."""
document_id: str
invoice_number: Optional[str] = None
invoice_date: Optional[str] = None
invoice_due_date: Optional[str] = None
ocr_number: Optional[str] = None
bankgiro: Optional[str] = None
plusgiro: Optional[str] = None
amount: Optional[str] = None
supplier_organisation_number: Optional[str] = None
raw_response: Optional[str] = None
model_used: Optional[str] = None
processing_time_ms: Optional[float] = None
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
class LLMValidator:
"""
Cross-validates invoice field extraction using LLM.
Queries documents with failed field matches from the database,
sends the PDF images to an LLM for extraction, and stores
the results for comparison.
"""
# Fields to extract (excluding customer_number as requested)
FIELDS_TO_EXTRACT = [
'InvoiceNumber',
'InvoiceDate',
'InvoiceDueDate',
'OCR',
'Bankgiro',
'Plusgiro',
'Amount',
'supplier_organisation_number',
]
EXTRACTION_PROMPT = """You are an expert at extracting structured data from Swedish invoices.
Analyze this invoice image and extract the following fields. Return ONLY a valid JSON object with these exact keys:
{
"invoice_number": "the invoice number/fakturanummer",
"invoice_date": "the invoice date in YYYY-MM-DD format",
"invoice_due_date": "the due date/förfallodatum in YYYY-MM-DD format",
"ocr_number": "the OCR payment reference number",
"bankgiro": "the bankgiro number (format: XXXX-XXXX or XXXXXXXX)",
"plusgiro": "the plusgiro number",
"amount": "the total amount to pay (just the number, e.g., 1234.56)",
"supplier_organisation_number": "the supplier's organisation number (format: XXXXXX-XXXX)"
}
Rules:
- If a field is not found or not visible, use null
- For dates, convert Swedish month names (januari, februari, etc.) to YYYY-MM-DD
- For amounts, extract just the numeric value without currency symbols
- The OCR number is typically a long number used for payment reference
- Look for "Att betala" or "Summa att betala" for the amount
- Organisation number is 10 digits, often shown as XXXXXX-XXXX
Return ONLY the JSON object, no other text."""
def __init__(self, connection_string: Optional[str] = None, pdf_dir: Optional[str] = None):
"""
Initialize the validator.
Args:
connection_string: PostgreSQL connection string
pdf_dir: Directory containing PDF files
"""
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS
self.connection_string = connection_string or get_db_connection_string()
self.pdf_dir = Path(pdf_dir or PATHS['pdf_dir'])
self.conn = None
def connect(self):
"""Connect to database."""
if self.conn is None:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
self.conn = None
def create_validation_table(self):
"""Create the llm_validation table if it doesn't exist."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS llm_validations (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL,
-- Extracted fields
invoice_number TEXT,
invoice_date TEXT,
invoice_due_date TEXT,
ocr_number TEXT,
bankgiro TEXT,
plusgiro TEXT,
amount TEXT,
supplier_organisation_number TEXT,
-- Metadata
raw_response TEXT,
model_used TEXT,
processing_time_ms REAL,
error TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
-- Comparison results (populated later)
comparison_results JSONB,
UNIQUE(document_id)
);
CREATE INDEX IF NOT EXISTS idx_llm_validations_document_id
ON llm_validations(document_id);
""")
conn.commit()
def get_documents_with_failed_matches(
self,
exclude_customer_number: bool = True,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Get documents that have at least one failed field match.
Args:
exclude_customer_number: If True, ignore customer_number failures
limit: Maximum number of documents to return
Returns:
List of document info with failed fields
"""
conn = self.connect()
with conn.cursor() as cursor:
# Find documents with failed matches (excluding customer_number if requested)
exclude_clause = ""
if exclude_customer_number:
exclude_clause = "AND fr.field_name != 'customer_number'"
query = f"""
SELECT DISTINCT d.document_id, d.pdf_path, d.pdf_type,
d.supplier_name, d.split
FROM documents d
JOIN field_results fr ON d.document_id = fr.document_id
WHERE fr.matched = false
AND fr.field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
AND d.document_id NOT IN (
SELECT document_id FROM llm_validations WHERE error IS NULL
)
ORDER BY d.document_id
"""
if limit:
query += f" LIMIT {limit}"
cursor.execute(query)
results = []
for row in cursor.fetchall():
doc_id = row[0]
# Get failed fields for this document
exclude_clause_inner = ""
if exclude_customer_number:
exclude_clause_inner = "AND field_name != 'customer_number'"
cursor.execute(f"""
SELECT field_name, csv_value, score
FROM field_results
WHERE document_id = %s
AND matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause_inner}
""", (doc_id,))
failed_fields = [
{'field': r[0], 'csv_value': r[1], 'score': r[2]}
for r in cursor.fetchall()
]
results.append({
'document_id': doc_id,
'pdf_path': row[1],
'pdf_type': row[2],
'supplier_name': row[3],
'split': row[4],
'failed_fields': failed_fields,
})
return results
def get_failed_match_stats(self, exclude_customer_number: bool = True) -> Dict[str, Any]:
"""Get statistics about failed matches."""
conn = self.connect()
with conn.cursor() as cursor:
exclude_clause = ""
if exclude_customer_number:
exclude_clause = "AND field_name != 'customer_number'"
# Count by field
cursor.execute(f"""
SELECT field_name, COUNT(*) as cnt
FROM field_results
WHERE matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
GROUP BY field_name
ORDER BY cnt DESC
""")
by_field = {row[0]: row[1] for row in cursor.fetchall()}
# Count documents with failures
cursor.execute(f"""
SELECT COUNT(DISTINCT document_id)
FROM field_results
WHERE matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
""")
doc_count = cursor.fetchone()[0]
# Already validated count
cursor.execute("""
SELECT COUNT(*) FROM llm_validations WHERE error IS NULL
""")
validated_count = cursor.fetchone()[0]
return {
'documents_with_failures': doc_count,
'already_validated': validated_count,
'remaining': doc_count - validated_count,
'failures_by_field': by_field,
}
def render_pdf_to_image(
self,
pdf_path: Path,
page_no: int = 0,
dpi: int = DEFAULT_DPI,
max_size_mb: float = 18.0
) -> bytes:
"""
Render a PDF page to PNG image bytes.
Args:
pdf_path: Path to PDF file
page_no: Page number to render (0-indexed)
dpi: Resolution for rendering
max_size_mb: Maximum image size in MB (Azure OpenAI limit is 20MB)
Returns:
PNG image bytes
"""
import fitz # PyMuPDF
from io import BytesIO
from PIL import Image
doc = fitz.open(pdf_path)
page = doc[page_no]
# Try different DPI values until we get a small enough image
dpi_values = [dpi, 120, 100, 72, 50]
for current_dpi in dpi_values:
mat = fitz.Matrix(current_dpi / 72, current_dpi / 72)
pix = page.get_pixmap(matrix=mat)
png_bytes = pix.tobytes("png")
size_mb = len(png_bytes) / (1024 * 1024)
if size_mb <= max_size_mb:
doc.close()
return png_bytes
# If still too large, use JPEG compression
mat = fitz.Matrix(72 / 72, 72 / 72)  # Fixed 72 DPI for the JPEG fallback
pix = page.get_pixmap(matrix=mat)
# Convert to PIL Image and compress as JPEG
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
# Try different JPEG quality levels
for quality in [85, 70, 50, 30]:
buffer = BytesIO()
img.save(buffer, format="JPEG", quality=quality)
jpeg_bytes = buffer.getvalue()
size_mb = len(jpeg_bytes) / (1024 * 1024)
if size_mb <= max_size_mb:
doc.close()
return jpeg_bytes
doc.close()
# Return whatever we have, let the API handle the error
return jpeg_bytes
def extract_with_openai(
self,
image_bytes: bytes,
model: str = "gpt-4o"
) -> LLMExtractionResult:
"""
Extract fields using OpenAI's vision API (supports Azure OpenAI).
Args:
image_bytes: PNG image bytes
model: Model to use (gpt-4o, gpt-4o-mini, etc.)
Returns:
Extraction result
"""
import openai
import time
start_time = time.time()
# Encode image to base64 and detect format
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
# Detect image format (PNG starts with \x89PNG, JPEG with \xFF\xD8)
if image_bytes[:4] == b'\x89PNG':
media_type = "image/png"
else:
media_type = "image/jpeg"
# Check for Azure OpenAI configuration
azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', model)
if azure_endpoint and azure_api_key:
# Use Azure OpenAI
client = openai.AzureOpenAI(
azure_endpoint=azure_endpoint,
api_key=azure_api_key,
api_version="2024-02-15-preview"
)
model = azure_deployment # Use deployment name for Azure
else:
# Use standard OpenAI
client = openai.OpenAI()
try:
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": self.EXTRACTION_PROMPT},
{
"type": "image_url",
"image_url": {
"url": f"data:{media_type};base64,{image_b64}",
"detail": "high"
}
}
]
}
],
max_tokens=1000,
temperature=0,
)
raw_response = response.choices[0].message.content
processing_time = (time.time() - start_time) * 1000
# Parse JSON response
# Try to extract JSON from response (may have markdown code blocks)
json_str = raw_response
if "```json" in json_str:
json_str = json_str.split("```json")[1].split("```")[0]
elif "```" in json_str:
json_str = json_str.split("```")[1].split("```")[0]
data = json.loads(json_str.strip())
return LLMExtractionResult(
document_id="", # Will be set by caller
invoice_number=data.get('invoice_number'),
invoice_date=data.get('invoice_date'),
invoice_due_date=data.get('invoice_due_date'),
ocr_number=data.get('ocr_number'),
bankgiro=data.get('bankgiro'),
plusgiro=data.get('plusgiro'),
amount=data.get('amount'),
supplier_organisation_number=data.get('supplier_organisation_number'),
raw_response=raw_response,
model_used=model,
processing_time_ms=processing_time,
)
except json.JSONDecodeError as e:
return LLMExtractionResult(
document_id="",
raw_response=raw_response if 'raw_response' in locals() else None,
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=f"JSON parse error: {str(e)}"
)
except Exception as e:
return LLMExtractionResult(
document_id="",
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=str(e)
)
def extract_with_anthropic(
self,
image_bytes: bytes,
model: str = "claude-sonnet-4-20250514"
) -> LLMExtractionResult:
"""
Extract fields using Anthropic's Claude API.
Args:
image_bytes: PNG image bytes
model: Model to use
Returns:
Extraction result
"""
import anthropic
import time
start_time = time.time()
# Encode image to base64
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
client = anthropic.Anthropic()
try:
response = client.messages.create(
model=model,
max_tokens=1000,
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": image_b64,
}
},
{
"type": "text",
"text": self.EXTRACTION_PROMPT
}
]
}
],
)
raw_response = response.content[0].text
processing_time = (time.time() - start_time) * 1000
# Parse JSON response
json_str = raw_response
if "```json" in json_str:
json_str = json_str.split("```json")[1].split("```")[0]
elif "```" in json_str:
json_str = json_str.split("```")[1].split("```")[0]
data = json.loads(json_str.strip())
return LLMExtractionResult(
document_id="",
invoice_number=data.get('invoice_number'),
invoice_date=data.get('invoice_date'),
invoice_due_date=data.get('invoice_due_date'),
ocr_number=data.get('ocr_number'),
bankgiro=data.get('bankgiro'),
plusgiro=data.get('plusgiro'),
amount=data.get('amount'),
supplier_organisation_number=data.get('supplier_organisation_number'),
raw_response=raw_response,
model_used=model,
processing_time_ms=processing_time,
)
except json.JSONDecodeError as e:
return LLMExtractionResult(
document_id="",
raw_response=raw_response if 'raw_response' in locals() else None,
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=f"JSON parse error: {str(e)}"
)
except Exception as e:
return LLMExtractionResult(
document_id="",
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=str(e)
)
def save_validation_result(self, result: LLMExtractionResult):
"""Save extraction result to database."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
INSERT INTO llm_validations (
document_id, invoice_number, invoice_date, invoice_due_date,
ocr_number, bankgiro, plusgiro, amount,
supplier_organisation_number, raw_response, model_used,
processing_time_ms, error
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (document_id) DO UPDATE SET
invoice_number = EXCLUDED.invoice_number,
invoice_date = EXCLUDED.invoice_date,
invoice_due_date = EXCLUDED.invoice_due_date,
ocr_number = EXCLUDED.ocr_number,
bankgiro = EXCLUDED.bankgiro,
plusgiro = EXCLUDED.plusgiro,
amount = EXCLUDED.amount,
supplier_organisation_number = EXCLUDED.supplier_organisation_number,
raw_response = EXCLUDED.raw_response,
model_used = EXCLUDED.model_used,
processing_time_ms = EXCLUDED.processing_time_ms,
error = EXCLUDED.error,
created_at = NOW()
""", (
result.document_id,
result.invoice_number,
result.invoice_date,
result.invoice_due_date,
result.ocr_number,
result.bankgiro,
result.plusgiro,
result.amount,
result.supplier_organisation_number,
result.raw_response,
result.model_used,
result.processing_time_ms,
result.error,
))
conn.commit()
def validate_document(
self,
doc_id: str,
provider: str = "openai",
model: Optional[str] = None
) -> LLMExtractionResult:
"""
Validate a single document using LLM.
Args:
doc_id: Document ID
provider: LLM provider ("openai" or "anthropic")
model: Model to use (defaults based on provider)
Returns:
Extraction result
"""
# Get PDF path
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
if not pdf_path.exists():
return LLMExtractionResult(
document_id=doc_id,
error=f"PDF not found: {pdf_path}"
)
# Render first page
try:
image_bytes = self.render_pdf_to_image(pdf_path, page_no=0)
except Exception as e:
return LLMExtractionResult(
document_id=doc_id,
error=f"Failed to render PDF: {str(e)}"
)
# Extract with LLM
if provider == "openai":
model = model or "gpt-4o"
result = self.extract_with_openai(image_bytes, model)
elif provider == "anthropic":
model = model or "claude-sonnet-4-20250514"
result = self.extract_with_anthropic(image_bytes, model)
else:
return LLMExtractionResult(
document_id=doc_id,
error=f"Unknown provider: {provider}"
)
result.document_id = doc_id
# Save to database
self.save_validation_result(result)
return result
def validate_batch(
self,
limit: int = 10,
provider: str = "openai",
model: Optional[str] = None,
verbose: bool = True
) -> List[LLMExtractionResult]:
"""
Validate a batch of documents with failed matches.
Args:
limit: Maximum number of documents to validate
provider: LLM provider
model: Model to use
verbose: Print progress
Returns:
List of extraction results
"""
# Get documents to validate
docs = self.get_documents_with_failed_matches(limit=limit)
if verbose:
print(f"Found {len(docs)} documents with failed matches to validate")
results = []
for i, doc in enumerate(docs):
doc_id = doc['document_id']
if verbose:
failed_fields = [f['field'] for f in doc['failed_fields']]
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
result = self.validate_document(doc_id, provider, model)
results.append(result)
if verbose:
if result.error:
print(f" ERROR: {result.error}")
else:
print(f" OK ({result.processing_time_ms:.0f}ms)")
return results
def compare_results(self, doc_id: str) -> Dict[str, Any]:
"""
Compare LLM extraction with autolabel results.
Args:
doc_id: Document ID
Returns:
Comparison results
"""
conn = self.connect()
with conn.cursor() as cursor:
# Get autolabel results
cursor.execute("""
SELECT field_name, csv_value, matched, matched_text
FROM field_results
WHERE document_id = %s
""", (doc_id,))
autolabel = {}
for row in cursor.fetchall():
autolabel[row[0]] = {
'csv_value': row[1],
'matched': row[2],
'matched_text': row[3],
}
# Get LLM results
cursor.execute("""
SELECT invoice_number, invoice_date, invoice_due_date,
ocr_number, bankgiro, plusgiro, amount,
supplier_organisation_number
FROM llm_validations
WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
if not row:
return {'error': 'No LLM validation found'}
llm = {
'InvoiceNumber': row[0],
'InvoiceDate': row[1],
'InvoiceDueDate': row[2],
'OCR': row[3],
'Bankgiro': row[4],
'Plusgiro': row[5],
'Amount': row[6],
'supplier_organisation_number': row[7],
}
# Compare
comparison = {}
for field in self.FIELDS_TO_EXTRACT:
auto = autolabel.get(field, {})
llm_value = llm.get(field)
comparison[field] = {
'csv_value': auto.get('csv_value'),
'autolabel_matched': auto.get('matched'),
'autolabel_text': auto.get('matched_text'),
'llm_value': llm_value,
'agreement': self._values_match(auto.get('csv_value'), llm_value),
}
return comparison
def _values_match(self, csv_value: str, llm_value: str) -> bool:
"""Check if CSV value matches LLM extracted value."""
if csv_value is None or llm_value is None:
return csv_value == llm_value
# Normalize for comparison
csv_norm = str(csv_value).strip().lower().replace('-', '').replace(' ', '')
llm_norm = str(llm_value).strip().lower().replace('-', '').replace(' ', '')
return csv_norm == llm_norm
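A hedged sketch of running the validator above as a batch job; it assumes database connectivity via the project config and either OPENAI_API_KEY or the AZURE_OPENAI_* variables that extract_with_openai reads.
# Sketch only: requires DB access plus OpenAI or Azure OpenAI credentials in the environment.
validator = LLMValidator()
validator.create_validation_table()
print(validator.get_failed_match_stats())

results = validator.validate_batch(limit=5, provider="openai", verbose=True)
for r in results:
    if not r.error:
        comparison = validator.compare_results(r.document_id)
        disagreements = [f for f, c in comparison.items() if c["agreement"] is False]
        print(r.document_id, "LLM disagrees on:", disagreements or "nothing")

validator.close()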

View File

@@ -0,0 +1,9 @@
"""
Web Application Module
Provides REST API and web interface for invoice field extraction.
"""
from .app import create_app
__all__ = ["create_app"]

View File

@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py
DEPRECATED: Import from inference.web.api.v1.admin.documents instead.
"""
from inference.web.api.v1.admin.documents import *
__all__ = ["create_admin_router"]

View File

@@ -0,0 +1,19 @@
"""
Admin API v1
Document management, annotations, and training endpoints.
"""
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",
"create_training_router",
]

View File

@@ -0,0 +1,644 @@
"""
Admin Annotation API Routes
FastAPI endpoints for annotation management.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.autolabel import get_auto_label_service
from inference.web.schemas.admin import (
AnnotationCreate,
AnnotationItem,
AnnotationListResponse,
AnnotationOverrideRequest,
AnnotationOverrideResponse,
AnnotationResponse,
AnnotationSource,
AnnotationUpdate,
AnnotationVerifyRequest,
AnnotationVerifyResponse,
AutoLabelRequest,
AutoLabelResponse,
BoundingBox,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
# Image storage directory
ADMIN_IMAGES_DIR = Path("data/admin_images")
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_annotation_router() -> APIRouter:
"""Create annotation API router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])
# =========================================================================
# Image Endpoints
# =========================================================================
@router.get(
"/{document_id}/images/{page_number}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Get page image",
description="Get the image for a specific page.",
)
async def get_page_image(
document_id: str,
page_number: int,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> FileResponse:
"""Get page image."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate page number
if page_number < 1 or page_number > document.page_count:
raise HTTPException(
status_code=404,
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
)
# Find image file
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
if not image_path.exists():
raise HTTPException(
status_code=404,
detail=f"Image for page {page_number} not found",
)
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
)
# =========================================================================
# Annotation Endpoints
# =========================================================================
@router.get(
"/{document_id}/annotations",
response_model=AnnotationListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="List annotations",
description="Get all annotations for a document.",
)
async def list_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
page_number: Annotated[
int | None,
Query(ge=1, description="Filter by page number"),
] = None,
) -> AnnotationListResponse:
"""List annotations for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id, page_number)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
return AnnotationListResponse(
document_id=document_id,
page_count=document.page_count,
total_annotations=len(annotations),
annotations=annotations,
)
@router.post(
"/{document_id}/annotations",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Create annotation",
description="Create a new annotation for a document.",
)
async def create_annotation(
document_id: str,
request: AnnotationCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationResponse:
"""Create a new annotation."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate page number
if request.page_number > document.page_count:
raise HTTPException(
status_code=400,
detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
)
# Get image dimensions for normalization
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
if not image_path.exists():
raise HTTPException(
status_code=400,
detail=f"Image for page {request.page_number} not available",
)
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
# Calculate normalized coordinates
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
y_center = (request.bbox.y + request.bbox.height / 2) / image_height
width = request.bbox.width / image_width
height = request.bbox.height / image_height
# Get class name
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
# Create annotation
annotation_id = db.create_annotation(
document_id=document_id,
page_number=request.page_number,
class_id=request.class_id,
class_name=class_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=request.bbox.x,
bbox_y=request.bbox.y,
bbox_width=request.bbox.width,
bbox_height=request.bbox.height,
text_value=request.text_value,
source="manual",
)
# Keep status as pending - user must click "Mark Complete" to finalize
# This allows user to add multiple annotations before saving to PostgreSQL
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation created successfully",
)
@router.patch(
"/{document_id}/annotations/{annotation_id}",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Update annotation",
description="Update an existing annotation.",
)
async def update_annotation(
document_id: str,
annotation_id: str,
request: AnnotationUpdate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationResponse:
"""Update an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Prepare update data
update_kwargs = {}
if request.class_id is not None:
update_kwargs["class_id"] = request.class_id
update_kwargs["class_name"] = FIELD_CLASSES.get(
request.class_id, f"class_{request.class_id}"
)
if request.text_value is not None:
update_kwargs["text_value"] = request.text_value
if request.bbox is not None:
# Get image dimensions
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
# Calculate normalized coordinates
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
update_kwargs["width"] = request.bbox.width / image_width
update_kwargs["height"] = request.bbox.height / image_height
update_kwargs["bbox_x"] = request.bbox.x
update_kwargs["bbox_y"] = request.bbox.y
update_kwargs["bbox_width"] = request.bbox.width
update_kwargs["bbox_height"] = request.bbox.height
# Update annotation
if update_kwargs:
success = db.update_annotation(annotation_id, **update_kwargs)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to update annotation",
)
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation updated successfully",
)
@router.delete(
"/{document_id}/annotations/{annotation_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Delete annotation",
description="Delete an annotation.",
)
async def delete_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Delete annotation
db.delete_annotation(annotation_id)
return {
"status": "deleted",
"annotation_id": annotation_id,
"message": "Annotation deleted successfully",
}
# =========================================================================
# Auto-Labeling Endpoints
# =========================================================================
@router.post(
"/{document_id}/auto-label",
response_model=AutoLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Trigger auto-labeling",
description="Trigger auto-labeling for a document using field values.",
)
async def trigger_auto_label(
document_id: str,
request: AutoLabelRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AutoLabelResponse:
"""Trigger auto-labeling for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate field values
if not request.field_values:
raise HTTPException(
status_code=400,
detail="At least one field value is required",
)
# Run auto-labeling
service = get_auto_label_service()
result = service.auto_label_document(
document_id=document_id,
file_path=document.file_path,
field_values=request.field_values,
db=db,
replace_existing=request.replace_existing,
)
if result["status"] == "failed":
raise HTTPException(
status_code=500,
detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
)
return AutoLabelResponse(
document_id=document_id,
status=result["status"],
annotations_created=result["annotations_created"],
message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
)
@router.delete(
"/{document_id}/annotations",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete all annotations",
description="Delete all annotations for a document (optionally filter by source).",
)
async def delete_all_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
source: Annotated[
str | None,
Query(description="Filter by source (manual, auto, imported)"),
] = None,
) -> dict:
"""Delete all annotations for a document."""
_validate_uuid(document_id, "document_id")
# Validate source
if source and source not in ("manual", "auto", "imported"):
raise HTTPException(
status_code=400,
detail=f"Invalid source: {source}",
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Delete annotations
deleted_count = db.delete_annotations_for_document(document_id, source)
# Update document status if all annotations deleted
remaining = db.get_annotations_for_document(document_id)
if not remaining:
db.update_document_status(document_id, "pending")
return {
"status": "deleted",
"document_id": document_id,
"deleted_count": deleted_count,
"message": f"Deleted {deleted_count} annotations",
}
# =========================================================================
# Phase 5: Annotation Enhancement
# =========================================================================
@router.post(
"/{document_id}/annotations/{annotation_id}/verify",
response_model=AnnotationVerifyResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Verify annotation",
description="Mark an annotation as verified by a human reviewer.",
)
async def verify_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
) -> AnnotationVerifyResponse:
"""Verify an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Verify the annotation
annotation = db.verify_annotation(annotation_id, admin_token)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
return AnnotationVerifyResponse(
annotation_id=annotation_id,
is_verified=annotation.is_verified,
verified_at=annotation.verified_at,
verified_by=annotation.verified_by,
message="Annotation verified successfully",
)
@router.patch(
"/{document_id}/annotations/{annotation_id}/override",
response_model=AnnotationOverrideResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Override annotation",
description="Override an auto-generated annotation with manual corrections.",
)
async def override_annotation(
document_id: str,
annotation_id: str,
request: AnnotationOverrideRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationOverrideResponse:
"""Override an auto-generated annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Build updates dict from request
updates = {}
if request.text_value is not None:
updates["text_value"] = request.text_value
if request.class_id is not None:
updates["class_id"] = request.class_id
# Update class_name if class_id changed
if request.class_id in FIELD_CLASSES:
updates["class_name"] = FIELD_CLASSES[request.class_id]
if request.class_name is not None:
updates["class_name"] = request.class_name
if request.bbox:
# Update bbox fields
if "x" in request.bbox:
updates["bbox_x"] = request.bbox["x"]
if "y" in request.bbox:
updates["bbox_y"] = request.bbox["y"]
if "width" in request.bbox:
updates["bbox_width"] = request.bbox["width"]
if "height" in request.bbox:
updates["bbox_height"] = request.bbox["height"]
if not updates:
raise HTTPException(
status_code=400,
detail="No updates provided. Specify at least one field to update.",
)
# Override the annotation
annotation = db.override_annotation(
annotation_id=annotation_id,
admin_token=admin_token,
change_reason=request.reason,
**updates,
)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Get history to return history_id
history_records = db.get_annotation_history(UUID(annotation_id))
latest_history = history_records[0] if history_records else None
return AnnotationOverrideResponse(
annotation_id=annotation_id,
source=annotation.source,
override_source=annotation.override_source,
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
message="Annotation overridden successfully",
history_id=str(latest_history.history_id) if latest_history else "",
)
return router
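A minimal client sketch for the verify and override endpoints above, assuming this router is mounted under /api/v1/admin/documents (consistent with the document routes later in this commit) and that AdminTokenDep reads a Bearer-style Authorization header; the header scheme, base URL, and placeholder IDs are all assumptions.
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme

doc_id = "11111111-1111-1111-1111-111111111111"
ann_id = "22222222-2222-2222-2222-222222222222"

# Mark an annotation as human-verified (empty body: AnnotationVerifyRequest defaults).
r = requests.post(
    f"{BASE}/admin/documents/{doc_id}/annotations/{ann_id}/verify",
    headers=HEADERS, json={},
)
print(r.json())

# Override an auto-generated annotation with a corrected value and bbox.
r = requests.patch(
    f"{BASE}/admin/documents/{doc_id}/annotations/{ann_id}/override",
    headers=HEADERS,
    json={
        "text_value": "INV-2024-0042",
        "bbox": {"x": 120, "y": 80, "width": 200, "height": 24},
        "reason": "OCR missed the hyphen",
    },
)
print(r.json())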

View File

@@ -0,0 +1,82 @@
"""
Admin Auth Routes
FastAPI endpoints for admin token management.
"""
import logging
import secrets
from datetime import datetime, timedelta
from fastapi import APIRouter
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
AdminTokenCreate,
AdminTokenResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def create_auth_router() -> APIRouter:
"""Create admin auth router."""
router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])
@router.post(
"/token",
response_model=AdminTokenResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
},
summary="Create admin token",
description="Create a new admin authentication token.",
)
async def create_token(
request: AdminTokenCreate,
db: AdminDBDep,
) -> AdminTokenResponse:
"""Create a new admin token."""
# Generate secure token
token = secrets.token_urlsafe(32)
# Calculate expiration
expires_at = None
if request.expires_in_days:
expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)
# Create token in database
db.create_admin_token(
token=token,
name=request.name,
expires_at=expires_at,
)
return AdminTokenResponse(
token=token,
name=request.name,
expires_at=expires_at,
message="Admin token created successfully",
)
@router.delete(
"/token",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Revoke admin token",
description="Revoke the current admin token.",
)
async def revoke_token(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Revoke the current admin token."""
db.deactivate_admin_token(admin_token)
return {
"status": "revoked",
"message": "Admin token has been revoked",
}
return router
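A short sketch of the token lifecycle above, assuming the auth router is mounted under /api/v1; name and expires_in_days are the fields the handler reads, and the revoke header scheme is an assumption about AdminTokenDep.
import requests

BASE = "http://localhost:8000/api/v1"

# Create a token that expires in 30 days.
resp = requests.post(
    f"{BASE}/admin/auth/token",
    json={"name": "ci-pipeline", "expires_in_days": 30},
)
token = resp.json()["token"]

# Revoke it again.
requests.delete(
    f"{BASE}/admin/auth/token",
    headers={"Authorization": f"Bearer {token}"},  # assumed header scheme
)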

View File

@@ -0,0 +1,551 @@
"""
Admin Document Routes
FastAPI endpoints for admin document management.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
AnnotationItem,
AnnotationSource,
AutoLabelStatus,
BoundingBox,
DocumentDetailResponse,
DocumentItem,
DocumentListResponse,
DocumentStatus,
DocumentStatsResponse,
DocumentUploadResponse,
ModelMetrics,
TrainingHistoryItem,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
# Render at configured DPI for consistency with training
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
pdf_doc.close()
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
"""Create admin documents router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])
# Directories are created by StorageConfig.__post_init__
allowed_extensions = storage_config.allowed_extensions
@router.post(
"",
response_model=DocumentUploadResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Upload document",
description="Upload a PDF or image document for labeling.",
)
async def upload_document(
admin_token: AdminTokenDep,
db: AdminDBDep,
file: UploadFile = File(..., description="PDF or image file"),
auto_label: Annotated[
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Get page count (for PDF)
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record (token only used for auth, not stored)
document_id = db.create_document(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(status_code=500, detail="Failed to save file")
# Update file path in database
from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if doc:
doc.file_path = str(file_path)
session.add(doc)
# Convert PDF to images for annotation
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling if requested
auto_label_started = False
if auto_label:
# Auto-labeling will be triggered by a background task
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
auto_label_started = True
return DocumentUploadResponse(
document_id=document_id,
filename=file.filename,
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@router.get(
"",
response_model=DocumentListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List documents",
description="List all documents for the current admin.",
)
async def list_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
upload_source: Annotated[
str | None,
Query(description="Filter by upload source (ui or api)"),
] = None,
has_annotations: Annotated[
bool | None,
Query(description="Filter by annotation presence"),
] = None,
auto_label_status: Annotated[
str | None,
Query(description="Filter by auto-label status"),
] = None,
batch_id: Annotated[
str | None,
Query(description="Filter by batch ID"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> DocumentListResponse:
"""List documents."""
# Validate status
if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Validate upload_source
if upload_source and upload_source not in ("ui", "api"):
raise HTTPException(
status_code=400,
detail=f"Invalid upload_source: {upload_source}",
)
# Validate auto_label_status
if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid auto_label_status: {auto_label_status}",
)
documents, total = db.get_documents_by_token(
admin_token=admin_token,
status=status,
upload_source=upload_source,
has_annotations=has_annotations,
auto_label_status=auto_label_status,
batch_id=batch_id,
limit=limit,
offset=offset,
)
# Get annotation counts and build items
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
# Determine if document can be annotated (not locked)
can_annotate = True
if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
from datetime import datetime, timezone
can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)
items.append(
DocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
file_size=doc.file_size,
page_count=doc.page_count,
status=DocumentStatus(doc.status),
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
)
)
return DocumentListResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/stats",
response_model=DocumentStatsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get document statistics",
description="Get document count by status.",
)
async def get_document_stats(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentStatsResponse:
"""Get document statistics."""
counts = db.count_documents_by_status(admin_token)
return DocumentStatsResponse(
total=sum(counts.values()),
pending=counts.get("pending", 0),
auto_labeling=counts.get("auto_labeling", 0),
labeled=counts.get("labeled", 0),
exported=counts.get("exported", 0),
)
@router.get(
"/{document_id}",
response_model=DocumentDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Get document detail",
description="Get document details with annotations.",
)
async def get_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentDetailResponse:
"""Get document details."""
_validate_uuid(document_id, "document_id")
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
# Generate image URLs
image_urls = []
for page in range(1, document.page_count + 1):
image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")
# Determine if document can be annotated (not locked)
can_annotate = True
annotation_lock_until = None
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
from datetime import datetime, timezone
annotation_lock_until = document.annotation_lock_until
can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)
# Get CSV field values if available
csv_field_values = None
if hasattr(document, 'csv_field_values') and document.csv_field_values:
csv_field_values = document.csv_field_values
# Get training history (Phase 5)
training_history = []
training_links = db.get_document_training_tasks(document.document_id)
for link in training_links:
# Get task details
task = db.get_training_task(str(link.task_id))
if task:
# Build metrics
metrics = None
if task.metrics_mAP or task.metrics_precision or task.metrics_recall:
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
training_history.append(
TrainingHistoryItem(
task_id=str(link.task_id),
name=task.name,
trained_at=link.created_at,
model_metrics=metrics,
)
)
return DocumentDetailResponse(
document_id=str(document.document_id),
filename=document.filename,
file_size=document.file_size,
content_type=document.content_type,
page_count=document.page_count,
status=DocumentStatus(document.status),
auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
annotations=annotations,
image_urls=image_urls,
training_history=training_history,
created_at=document.created_at,
updated_at=document.updated_at,
)
@router.delete(
"/{document_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete document",
description="Delete a document and its annotations.",
)
async def delete_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Delete file
file_path = Path(document.file_path)
if file_path.exists():
file_path.unlink()
# Delete images
images_dir = storage_config.admin_images_dir / document_id
if images_dir.exists():
import shutil
shutil.rmtree(images_dir)
# Delete from database
db.delete_document(document_id)
return {
"status": "deleted",
"document_id": document_id,
"message": "Document deleted successfully",
}
@router.patch(
"/{document_id}/status",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document status",
description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
)
async def update_document_status(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str,
Query(description="New status"),
],
) -> dict:
"""Update document status.
When status is set to 'labeled', the annotations are automatically
saved to PostgreSQL documents/field_results tables for consistency
with CLI auto-label workflow.
"""
_validate_uuid(document_id, "document_id")
# Validate status
if status not in ("pending", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# If marking as labeled, save annotations to PostgreSQL DocumentDB
db_save_result = None
if status == "labeled":
from inference.web.services.db_autolabel import save_manual_annotations_to_document_db
# Get all annotations for this document
annotations = db.get_annotations_for_document(document_id)
if annotations:
db_save_result = save_manual_annotations_to_document_db(
document=document,
annotations=annotations,
db=db,
)
db.update_document_status(document_id, status)
response = {
"status": "updated",
"document_id": document_id,
"new_status": status,
"message": "Document status updated",
}
# Include PostgreSQL save result if applicable
if db_save_result:
response["document_db_saved"] = db_save_result.get("success", False)
response["fields_saved"] = db_save_result.get("fields_saved", 0)
return response
return router
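A hedged end-to-end sketch for the document routes above (upload, inspect, mark labeled), assuming an /api/v1 mount (consistent with the image URLs built in get_document) and a Bearer-style admin token header.
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme

# Upload a PDF; auto_label=true asks the server to start auto-labeling.
with open("invoice.pdf", "rb") as fh:
    r = requests.post(
        f"{BASE}/admin/documents",
        headers=HEADERS,
        params={"auto_label": "true"},
        files={"file": ("invoice.pdf", fh, "application/pdf")},
    )
doc_id = r.json()["document_id"]

# Inspect the document, its annotations, and page image URLs.
detail = requests.get(f"{BASE}/admin/documents/{doc_id}", headers=HEADERS).json()
print(detail["status"], detail["auto_label_status"], len(detail["annotations"]))

# After review, mark it labeled; per the handler above this also mirrors the
# annotations into the PostgreSQL documents/field_results tables.
requests.patch(
    f"{BASE}/admin/documents/{doc_id}/status",
    headers=HEADERS, params={"status": "labeled"},
)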

View File

@@ -0,0 +1,184 @@
"""
Admin Document Lock Routes
FastAPI endpoints for annotation lock management.
"""
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
AnnotationLockRequest,
AnnotationLockResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_locks_router() -> APIRouter:
"""Create annotation locks router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])
@router.post(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Document already locked"},
},
summary="Acquire annotation lock",
description="Acquire a lock on a document to prevent concurrent annotation edits.",
)
async def acquire_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Acquire annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to acquire lock
updated_doc = db.acquire_annotation_lock(
document_id=document_id,
admin_token=admin_token,
duration_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Document is already locked. Please try again later.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock acquired for {request.duration_seconds} seconds",
)
@router.delete(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Release annotation lock",
description="Release the annotation lock on a document.",
)
async def release_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
force: Annotated[
bool,
Query(description="Force release (admin override)"),
] = False,
) -> AnnotationLockResponse:
"""Release annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Release lock
updated_doc = db.release_annotation_lock(
document_id=document_id,
admin_token=admin_token,
force=force,
)
if updated_doc is None:
raise HTTPException(
status_code=404,
detail="Failed to release lock",
)
return AnnotationLockResponse(
document_id=document_id,
locked=False,
lock_expires_at=None,
message="Lock released successfully",
)
@router.patch(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
},
summary="Extend annotation lock",
description="Extend an existing annotation lock.",
)
async def extend_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Extend annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to extend lock
updated_doc = db.extend_annotation_lock(
document_id=document_id,
admin_token=admin_token,
additional_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Lock doesn't exist or has expired. Please acquire a new lock.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock extended by {request.duration_seconds} seconds",
)
return router
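A sketch of the intended lock protocol (acquire before editing, extend while the editor is open, release when done), with the same mount and auth-header assumptions as the other admin examples; duration_seconds is the field read from AnnotationLockRequest.
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme
doc_id = "11111111-1111-1111-1111-111111111111"

lock = requests.post(
    f"{BASE}/admin/documents/{doc_id}/lock",
    headers=HEADERS, json={"duration_seconds": 300},
)
if lock.status_code == 409:
    raise SystemExit("document is locked by another session")

# ... edit annotations, extending the lock before it expires ...
requests.patch(
    f"{BASE}/admin/documents/{doc_id}/lock",
    headers=HEADERS, json={"duration_seconds": 300},
)

# Release when done.
requests.delete(f"{BASE}/admin/documents/{doc_id}/lock", headers=HEADERS)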

View File

@@ -0,0 +1,28 @@
"""
Admin Training API Routes
FastAPI endpoints for training task management and scheduling.
"""
from fastapi import APIRouter
from ._utils import _validate_uuid
from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
def create_training_router() -> APIRouter:
"""Create training API router."""
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
register_task_routes(router)
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
return router
__all__ = ["create_training_router", "_validate_uuid"]
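A wiring sketch showing how this factory might be mounted; the import path and the /api/v1 prefix are assumptions (the actual app factory is elsewhere in this commit), but the prefix matches the URLs generated by the training routes.
from fastapi import FastAPI

from inference.web.api.v1.admin.training import create_training_router  # assumed module path

app = FastAPI()
app.include_router(create_training_router(), prefix="/api/v1")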

View File

@@ -0,0 +1,16 @@
"""Shared utilities for training routes."""
from uuid import UUID
from fastapi import HTTPException
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)

View File

@@ -0,0 +1,209 @@
"""Training Dataset Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetDetailResponse,
DatasetDocumentItem,
DatasetListItem,
DatasetListResponse,
DatasetResponse,
DatasetTrainRequest,
TrainingStatus,
TrainingTaskResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_dataset_routes(router: APIRouter) -> None:
"""Register dataset endpoints on the router."""
@router.post(
"/datasets",
response_model=DatasetResponse,
summary="Create training dataset",
description="Create a dataset from selected documents with train/val/test splits.",
)
async def create_dataset(
request: DatasetCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DatasetResponse:
"""Create a training dataset from document IDs."""
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
dataset = db.create_dataset(
name=request.name,
description=request.description,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
)
builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
try:
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=request.document_ids,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
admin_images_dir=Path("data/admin_images"),
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return DatasetResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
status="ready",
message="Dataset created successfully",
)
@router.get(
"/datasets",
response_model=DatasetListResponse,
summary="List datasets",
)
async def list_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
return DatasetListResponse(
total=total,
limit=limit,
offset=offset,
datasets=[
DatasetListItem(
dataset_id=str(d.dataset_id),
name=d.name,
description=d.description,
status=d.status,
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
created_at=d.created_at,
)
for d in datasets
],
)
@router.get(
"/datasets/{dataset_id}",
response_model=DatasetDetailResponse,
summary="Get dataset detail",
)
async def get_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DatasetDetailResponse:
"""Get dataset details with document list."""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
docs = db.get_dataset_documents(dataset_id)
return DatasetDetailResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
description=dataset.description,
status=dataset.status,
train_ratio=dataset.train_ratio,
val_ratio=dataset.val_ratio,
seed=dataset.seed,
total_documents=dataset.total_documents,
total_images=dataset.total_images,
total_annotations=dataset.total_annotations,
dataset_path=dataset.dataset_path,
error_message=dataset.error_message,
documents=[
DatasetDocumentItem(
document_id=str(d.document_id),
split=d.split,
page_count=d.page_count,
annotation_count=d.annotation_count,
)
for d in docs
],
created_at=dataset.created_at,
updated_at=dataset.updated_at,
)
@router.delete(
"/datasets/{dataset_id}",
summary="Delete dataset",
)
async def delete_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a dataset and its files."""
import shutil
from pathlib import Path
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.dataset_path:
dataset_dir = Path(dataset.dataset_path)
if dataset_dir.exists():
shutil.rmtree(dataset_dir)
db.delete_dataset(dataset_id)
return {"message": "Dataset deleted"}
@router.post(
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
)
async def train_from_dataset(
dataset_id: str,
request: DatasetTrainRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset."""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.status != "ready":
raise HTTPException(
status_code=400,
detail=f"Dataset is not ready (status: {dataset.status})",
)
config_dict = request.config.model_dump()
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type="train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message="Training task created from dataset",
)
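A sketch of the dataset workflow above: create a dataset from labeled documents, then start training from it. Field names match the handlers; the contents of DatasetTrainRequest.config live in the schemas module, so the epochs/imgsz keys below are illustrative assumptions.
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme

dataset = requests.post(
    f"{BASE}/admin/training/datasets",
    headers=HEADERS,
    json={
        "name": "invoices-v1",
        "description": "first labeled batch",
        "document_ids": ["<doc-uuid-1>", "<doc-uuid-2>"],
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
    },
).json()

task = requests.post(
    f"{BASE}/admin/training/datasets/{dataset['dataset_id']}/train",
    headers=HEADERS,
    json={"name": "yolo-run-1", "config": {"epochs": 50, "imgsz": 1280}},  # config keys assumed
).json()
print(task["task_id"], task["status"])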

View File

@@ -0,0 +1,211 @@
"""Training Documents and Models Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ModelMetrics,
TrainingDocumentItem,
TrainingDocumentsResponse,
TrainingModelItem,
TrainingModelsResponse,
TrainingStatus,
)
from inference.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_document_routes(router: APIRouter) -> None:
"""Register training document and model endpoints on the router."""
@router.get(
"/documents",
response_model=TrainingDocumentsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get documents for training",
description="Get labeled documents available for training with filtering options.",
)
async def get_training_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
has_annotations: Annotated[
bool,
Query(description="Only include documents with annotations"),
] = True,
min_annotation_count: Annotated[
int | None,
Query(ge=1, description="Minimum annotation count"),
] = None,
exclude_used_in_training: Annotated[
bool,
Query(description="Exclude documents already used in training"),
] = False,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingDocumentsResponse:
"""Get documents available for training."""
documents, total = db.get_documents_for_training(
admin_token=admin_token,
status="labeled",
has_annotations=has_annotations,
min_annotation_count=min_annotation_count,
exclude_used_in_training=exclude_used_in_training,
limit=limit,
offset=offset,
)
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
sources = {"manual": 0, "auto": 0}
for ann in annotations:
if ann.source in sources:
sources[ann.source] += 1
training_links = db.get_document_training_tasks(doc.document_id)
used_in_training = [str(link.task_id) for link in training_links]
items.append(
TrainingDocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
annotation_count=len(annotations),
annotation_sources=sources,
used_in_training=used_in_training,
last_modified=doc.updated_at,
)
)
return TrainingDocumentsResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/models/{task_id}/download",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Model not found"},
},
summary="Download trained model",
description="Download trained model weights file.",
)
async def download_model(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
):
"""Download trained model."""
from fastapi.responses import FileResponse
from pathlib import Path
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if not task.model_path:
raise HTTPException(
status_code=404,
detail="Model file not available for this task",
)
model_path = Path(task.model_path)
if not model_path.exists():
raise HTTPException(
status_code=404,
detail="Model file not found on disk",
)
return FileResponse(
path=str(model_path),
media_type="application/octet-stream",
filename=f"{task.name}_model.pt",
)
@router.get(
"/models",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
)
async def get_training_models(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (completed, failed, etc.)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingModelsResponse:
"""Get list of trained models."""
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status if status else "completed",
limit=limit,
offset=offset,
)
items = []
for task in tasks:
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
download_url = None
if task.model_path and task.status == "completed":
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
items.append(
TrainingModelItem(
task_id=str(task.task_id),
name=task.name,
status=TrainingStatus(task.status),
document_count=task.document_count,
created_at=task.created_at,
completed_at=task.completed_at,
metrics=metrics,
model_path=task.model_path,
download_url=download_url,
)
)
return TrainingModelsResponse(
total=total,
limit=limit,
offset=offset,
models=items,
)
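A sketch for listing completed models and downloading one set of weights; the /api/v1/admin/training/models path matches the download_url built above, and the auth header remains an assumption.
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme

models = requests.get(
    f"{BASE}/admin/training/models",
    headers=HEADERS, params={"status": "completed"},
).json()
for m in models["models"]:
    print(m["name"], m["metrics"], m["download_url"])

# Stream the first model's weights to disk.
if models["models"]:
    url = "http://localhost:8000" + models["models"][0]["download_url"]
    with requests.get(url, headers=HEADERS, stream=True) as r, open("model.pt", "wb") as out:
        for chunk in r.iter_content(chunk_size=1 << 20):
            out.write(chunk)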

View File

@@ -0,0 +1,121 @@
"""Training Export Endpoints."""
import logging
from datetime import datetime
from fastapi import APIRouter, HTTPException
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ExportRequest,
ExportResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def register_export_routes(router: APIRouter) -> None:
"""Register export endpoints on the router."""
@router.post(
"/export",
response_model=ExportResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Export annotations",
description="Export annotations in YOLO format for training.",
)
async def export_annotations(
request: ExportRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ExportResponse:
"""Export annotations for training."""
from pathlib import Path
import shutil
if request.format not in ("yolo", "coco", "voc"):
raise HTTPException(
status_code=400,
detail=f"Unsupported export format: {request.format}",
)
documents = db.get_labeled_documents_for_export(admin_token)
if not documents:
raise HTTPException(
status_code=400,
detail="No labeled documents available for export",
)
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
export_dir.mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
total_docs = len(documents)
train_count = int(total_docs * request.split_ratio)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = db.get_annotations_for_document(str(doc.document_id))
if not annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
if not page_annotations and not request.include_images:
continue
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
continue
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
total_images += 1
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
from inference.data.admin_models import FIELD_CLASSES
yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
(export_dir / "data.yaml").write_text(yaml_content)
return ExportResponse(
status="completed",
export_path=str(export_dir),
total_images=total_images,
total_annotations=total_annotations,
train_count=len(train_docs),
val_count=len(val_docs),
message=f"Exported {total_images} images with {total_annotations} annotations",
)
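A sketch of an export call. The request keys (format, split_ratio, include_images) are the ones the handler reads; the response points at a directory laid out as images/{train,val} and labels/{train,val} plus data.yaml, where each label line is "<class_id> <x_center> <y_center> <width> <height>" with coordinates normalized to [0, 1].
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme

r = requests.post(
    f"{BASE}/admin/training/export",
    headers=HEADERS,
    json={"format": "yolo", "split_ratio": 0.8, "include_images": True},
)
export = r.json()
print(export["export_path"], export["train_count"], export["val_count"])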

View File

@@ -0,0 +1,263 @@
"""Training Task Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
TrainingLogItem,
TrainingLogsResponse,
TrainingStatus,
TrainingTaskCreate,
TrainingTaskDetailResponse,
TrainingTaskItem,
TrainingTaskListResponse,
TrainingTaskResponse,
TrainingType,
)
from inference.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_task_routes(router: APIRouter) -> None:
"""Register training task endpoints on the router."""
@router.post(
"/tasks",
response_model=TrainingTaskResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Create training task",
description="Create a new training task.",
)
async def create_training_task(
request: TrainingTaskCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a new training task."""
config_dict = request.config.model_dump() if request.config else {}
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type=request.task_type.value,
description=request.description,
config=config_dict,
scheduled_at=request.scheduled_at,
cron_expression=request.cron_expression,
is_recurring=bool(request.cron_expression),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
message="Training task created successfully",
)
@router.get(
"/tasks",
response_model=TrainingTaskListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List training tasks",
description="List all training tasks.",
)
async def list_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingTaskListResponse:
"""List training tasks."""
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
if status and status not in valid_statuses:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
)
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status,
limit=limit,
offset=offset,
)
items = [
TrainingTaskItem(
task_id=str(task.task_id),
name=task.name,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
scheduled_at=task.scheduled_at,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
created_at=task.created_at,
)
for task in tasks
]
return TrainingTaskListResponse(
total=total,
limit=limit,
offset=offset,
tasks=items,
)
@router.get(
"/tasks/{task_id}",
response_model=TrainingTaskDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training task detail",
description="Get training task details.",
)
async def get_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskDetailResponse:
"""Get training task details."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
return TrainingTaskDetailResponse(
task_id=str(task.task_id),
name=task.name,
description=task.description,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
config=task.config,
scheduled_at=task.scheduled_at,
cron_expression=task.cron_expression,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
error_message=task.error_message,
result_metrics=task.result_metrics,
model_path=task.model_path,
created_at=task.created_at,
)
@router.post(
"/tasks/{task_id}/cancel",
response_model=TrainingTaskResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
},
summary="Cancel training task",
description="Cancel a pending or scheduled training task.",
)
async def cancel_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Cancel a training task."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if task.status not in ("pending", "scheduled"):
raise HTTPException(
status_code=409,
detail=f"Cannot cancel task with status: {task.status}",
)
success = db.cancel_training_task(task_id)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to cancel training task",
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.CANCELLED,
message="Training task cancelled successfully",
)
@router.get(
"/tasks/{task_id}/logs",
response_model=TrainingLogsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training logs",
description="Get training task logs.",
)
async def get_training_logs(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
limit: Annotated[
int,
Query(ge=1, le=500, description="Maximum logs to return"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingLogsResponse:
"""Get training logs."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
logs = db.get_training_logs(task_id, limit, offset)
items = [
TrainingLogItem(
level=log.level,
message=log.message,
details=log.details,
created_at=log.created_at,
)
for log in logs
]
return TrainingLogsResponse(
task_id=task_id,
logs=items,
)
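A sketch for creating and monitoring a training task. Request fields follow TrainingTaskCreate as used above; "train" mirrors the task_type value used by the dataset endpoint, and the cron expression makes the task recurring.
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme

task = requests.post(
    f"{BASE}/admin/training/tasks",
    headers=HEADERS,
    json={
        "name": "nightly-retrain",
        "task_type": "train",
        "cron_expression": "0 2 * * *",  # run every night at 02:00
    },
).json()

detail = requests.get(f"{BASE}/admin/training/tasks/{task['task_id']}", headers=HEADERS).json()
logs = requests.get(
    f"{BASE}/admin/training/tasks/{task['task_id']}/logs",
    headers=HEADERS, params={"limit": 50},
).json()
print(detail["status"], [entry["message"] for entry in logs["logs"]])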

View File

@@ -0,0 +1,236 @@
"""
Batch Upload API Routes
Endpoints for batch uploading documents via ZIP files with CSV metadata.
"""
import io
import logging
import zipfile
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse
from inference.data.admin_db import AdminDB
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from inference.web.workers.batch_queue import BatchTask, get_batch_queue
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
@router.post("/upload")
async def upload_batch(
file: UploadFile = File(...),
upload_source: str = Form(default="ui"),
async_mode: bool = Form(default=True),
auto_label: bool = Form(default=True),
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
"""Upload a batch of documents via ZIP file.
The ZIP file can contain:
- Multiple PDF files
- Optional CSV file with field values for auto-labeling
CSV format:
- Required column: DocumentId (matches PDF filename without extension)
- Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number
Args:
file: ZIP file upload
upload_source: Upload source (ui or api)
async_mode: If true, queue the batch and return 202 with a status URL
auto_label: Whether to trigger auto-labeling for each created document
admin_token: Admin authentication token
admin_db: Admin database interface
Returns:
Batch upload result with batch_id and status
"""
if not file.filename or not file.filename.lower().endswith('.zip'):
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
# Check compressed size
if file.size and file.size > MAX_COMPRESSED_SIZE:
max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
raise HTTPException(
status_code=400,
detail=f"File size exceeds {max_mb:.0f}MB limit"
)
try:
# Read file content
zip_content = await file.read()
# Additional security validation before processing
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
# Quick validation of ZIP structure
test_zip.testzip()
except zipfile.BadZipFile:
raise HTTPException(status_code=400, detail="Invalid ZIP file format")
if async_mode:
# Async mode: Queue task and return immediately
from uuid import uuid4
batch_id = uuid4()
# Create batch task for background processing
task = BatchTask(
batch_id=batch_id,
admin_token=admin_token,
zip_content=zip_content,
zip_filename=file.filename,
upload_source=upload_source,
auto_label=auto_label,
created_at=datetime.utcnow(),
)
# Submit to queue
queue = get_batch_queue()
if not queue.submit(task):
raise HTTPException(
status_code=503,
detail="Processing queue is full. Please try again later."
)
logger.info(
f"Batch upload queued: batch_id={batch_id}, "
f"filename={file.filename}, async_mode=True"
)
# Return 202 Accepted with batch_id and status URL
return JSONResponse(
status_code=202,
content={
"status": "accepted",
"batch_id": str(batch_id),
"message": "Batch upload queued for processing",
"status_url": f"/api/v1/admin/batch/status/{batch_id}",
"queue_depth": queue.get_queue_depth(),
}
)
else:
# Sync mode: Process immediately and return results
service = BatchUploadService(admin_db)
result = service.process_zip_upload(
admin_token=admin_token,
zip_filename=file.filename,
zip_content=zip_content,
upload_source=upload_source,
)
logger.info(
f"Batch upload completed: batch_id={result.get('batch_id')}, "
f"status={result.get('status')}, files={result.get('successful_files')}"
)
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Error processing batch upload: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail="Failed to process batch upload. Please contact support."
)
@router.get("/status/{batch_id}")
async def get_batch_status(
batch_id: str,
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
"""Get batch upload status and file processing details.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
admin_db: Admin database interface
Returns:
Batch status with file processing details
"""
# Validate UUID format
try:
batch_uuid = UUID(batch_id)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid batch ID format")
# Check batch exists and verify ownership
batch = admin_db.get_batch_upload(batch_uuid)
if not batch:
raise HTTPException(status_code=404, detail="Batch not found")
# CRITICAL: Verify ownership
if batch.admin_token != admin_token:
raise HTTPException(
status_code=403,
detail="You do not have access to this batch"
)
# Now safe to return details
service = BatchUploadService(admin_db)
result = service.get_batch_status(batch_id)
return result
@router.get("/list")
async def list_batch_uploads(
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
limit: int = 50,
offset: int = 0,
) -> dict:
"""List batch uploads for the current admin token.
Args:
admin_token: Admin authentication token
admin_db: Admin database interface
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of batch uploads
"""
# Validate pagination parameters
if limit < 1 or limit > 100:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative")
# Get batch uploads filtered by admin token
batches, total = admin_db.get_batch_uploads_by_token(
admin_token=admin_token,
limit=limit,
offset=offset,
)
return {
"batches": [
{
"batch_id": str(b.batch_id),
"filename": b.filename,
"status": b.status,
"total_files": b.total_files,
"successful_files": b.successful_files,
"failed_files": b.failed_files,
"created_at": b.created_at.isoformat() if b.created_at else None,
"completed_at": b.completed_at.isoformat() if b.completed_at else None,
}
for b in batches
],
"total": total,
"limit": limit,
"offset": offset,
}
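A sketch of a batch upload in async mode: build a ZIP with PDFs and a metadata CSV in memory, post it, then poll the returned status URL. The CSV columns come from the docstring above; the CSV filename inside the archive and the terminal status values are assumptions about BatchUploadService.
import io
import time
import zipfile

import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.write("invoice_001.pdf")  # local PDF; DocumentId below matches its stem
    zf.writestr(
        "metadata.csv",  # filename inside the archive is an assumption
        "DocumentId,InvoiceNumber,Amount\ninvoice_001,INV-0042,199.00\n",
    )
buf.seek(0)

r = requests.post(
    f"{BASE}/admin/batch/upload",
    headers=HEADERS,
    files={"file": ("batch.zip", buf, "application/zip")},
    data={"async_mode": "true", "auto_label": "true"},
)
status_url = r.json()["status_url"]

for _ in range(60):
    status = requests.get(f"http://localhost:8000{status_url}", headers=HEADERS).json()
    # Terminal status names are assumptions; see BatchUploadService.get_batch_status.
    if status.get("status") not in ("pending", "processing"):
        break
    time.sleep(5)
print(status)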

View File

@@ -0,0 +1,16 @@
"""
Public API v1
Customer-facing endpoints for inference, async processing, and labeling.
"""
from inference.web.api.v1.public.inference import create_inference_router
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.api.v1.public.labeling import create_labeling_router
__all__ = [
"create_inference_router",
"create_async_router",
"set_async_service",
"create_labeling_router",
]

View File

@@ -0,0 +1,372 @@
"""
Async API Routes
FastAPI endpoints for async invoice processing.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from inference.web.dependencies import (
ApiKeyDep,
AsyncDBDep,
PollRateLimitDep,
SubmitRateLimitDep,
)
from inference.web.schemas.inference import (
AsyncRequestItem,
AsyncRequestsListResponse,
AsyncResultResponse,
AsyncStatus,
AsyncStatusResponse,
AsyncSubmitResponse,
DetectionResult,
InferenceResult,
)
from inference.web.schemas.common import ErrorResponse
def _validate_request_id(request_id: str) -> None:
"""Validate that request_id is a valid UUID format."""
try:
UUID(request_id)
except ValueError:
raise HTTPException(
status_code=400,
detail="Invalid request ID format. Must be a valid UUID.",
)
logger = logging.getLogger(__name__)
# Global reference to async processing service (set during app startup)
_async_service = None
def set_async_service(service) -> None:
"""Set the async processing service instance."""
global _async_service
_async_service = service
def get_async_service():
"""Get the async processing service instance."""
if _async_service is None:
raise RuntimeError("AsyncProcessingService not initialized")
return _async_service
def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
"""Create async API router."""
router = APIRouter(prefix="/async", tags=["Async Processing"])
@router.post(
"/submit",
response_model=AsyncSubmitResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid API key"},
429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
503: {"model": ErrorResponse, "description": "Queue full"},
},
summary="Submit PDF for async processing",
description="Submit a PDF or image file for asynchronous processing. "
"Returns a request_id that can be used to poll for results.",
)
async def submit_document(
api_key: SubmitRateLimitDep,
file: UploadFile = File(..., description="PDF or image file to process"),
) -> AsyncSubmitResponse:
"""Submit a document for async processing."""
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate file extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Check file size (get from config via service)
service = get_async_service()
max_size = service._async_config.max_file_size_mb * 1024 * 1024
if len(content) > max_size:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size: "
f"{service._async_config.max_file_size_mb}MB",
)
# Submit request
result = service.submit_request(
api_key=api_key,
file_content=content,
filename=file.filename,
content_type=file.content_type or "application/octet-stream",
)
if not result.success:
if "queue" in (result.error or "").lower():
raise HTTPException(status_code=503, detail=result.error)
raise HTTPException(status_code=500, detail=result.error)
return AsyncSubmitResponse(
status="accepted",
message="Request submitted for processing",
request_id=result.request_id,
estimated_wait_seconds=result.estimated_wait_seconds,
poll_url=f"/api/v1/async/status/{result.request_id}",
)
@router.get(
"/status/{request_id}",
response_model=AsyncStatusResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get request status",
description="Get the current processing status of an async request.",
)
async def get_status(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncStatusResponse:
"""Get the status of an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Get queue position for pending requests
position = None
if request.status == "pending":
position = db.get_queue_position(request_id)
# Build result URL for completed requests
result_url = None
if request.status == "completed":
result_url = f"/api/v1/async/result/{request_id}"
return AsyncStatusResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
filename=request.filename,
created_at=request.created_at,
started_at=request.started_at,
completed_at=request.completed_at,
position_in_queue=position,
error_message=request.error_message,
result_url=result_url,
)
@router.get(
"/result/{request_id}",
response_model=AsyncResultResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Request not completed"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get extraction results",
description="Get the extraction results for a completed async request.",
)
async def get_result(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncResultResponse:
"""Get the results of a completed async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Check if completed or failed
if request.status not in ("completed", "failed"):
raise HTTPException(
status_code=409,
detail=f"Request not yet completed. Current status: {request.status}",
)
# Build inference result from stored data
inference_result = None
if request.result:
# Convert detections to DetectionResult objects
detections = []
for d in request.result.get("detections", []):
detections.append(DetectionResult(
field=d.get("field", ""),
confidence=d.get("confidence", 0.0),
bbox=d.get("bbox", [0, 0, 0, 0]),
))
inference_result = InferenceResult(
document_id=request.result.get("document_id", str(request.request_id)[:8]),
success=request.result.get("success", False),
document_type=request.result.get("document_type", "invoice"),
fields=request.result.get("fields", {}),
confidence=request.result.get("confidence", {}),
detections=detections,
processing_time_ms=request.processing_time_ms or 0.0,
errors=request.result.get("errors", []),
)
# Build visualization URL
viz_url = None
if request.visualization_path:
viz_url = f"/api/v1/results/{request.visualization_path}"
return AsyncResultResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
processing_time_ms=request.processing_time_ms or 0.0,
result=inference_result,
visualization_url=viz_url,
)
@router.get(
"/requests",
response_model=AsyncRequestsListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
},
summary="List requests",
description="List all async requests for the authenticated API key.",
)
async def list_requests(
api_key: ApiKeyDep,
db: AsyncDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (pending, processing, completed, failed)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Maximum number of results"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Pagination offset"),
] = 0,
) -> AsyncRequestsListResponse:
"""List all requests for the authenticated API key."""
# Validate status filter
if status and status not in ("pending", "processing", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid status filter: {status}. "
"Must be one of: pending, processing, completed, failed",
)
# Get requests from database
requests, total = db.get_requests_by_api_key(
api_key=api_key,
status=status,
limit=limit,
offset=offset,
)
# Convert to response items
items = [
AsyncRequestItem(
request_id=str(r.request_id),
status=AsyncStatus(r.status),
filename=r.filename,
file_size=r.file_size,
created_at=r.created_at,
completed_at=r.completed_at,
)
for r in requests
]
return AsyncRequestsListResponse(
total=total,
limit=limit,
offset=offset,
requests=items,
)
@router.delete(
"/requests/{request_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
},
summary="Cancel/delete request",
description="Cancel a pending request or delete a completed/failed request.",
)
async def delete_request(
request_id: str,
api_key: ApiKeyDep,
db: AsyncDBDep,
) -> dict:
"""Delete or cancel an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Cannot delete processing requests
if request.status == "processing":
raise HTTPException(
status_code=409,
detail="Cannot delete a request that is currently processing",
)
# Delete from database (will cascade delete related records)
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM async_requests WHERE request_id = %s",
(request_id,),
)
conn.commit()
return {
"status": "deleted",
"request_id": request_id,
"message": "Request deleted successfully",
}
return router
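# Usage sketch (client side, illustrative only): polling the status and result
# endpoints defined above. Assumes the API is reachable at localhost:8000, that
# `requests` is installed, and that `request_id` came from an earlier submit
# call; the API key value below is a placeholder.
if __name__ == "__main__":
    import time

    import requests

    BASE = "http://localhost:8000/api/v1/async"   # assumed host/port
    HEADERS = {"X-API-Key": "your-api-key"}       # placeholder key
    request_id = "00000000-0000-0000-0000-000000000000"  # from a prior submit

    while True:
        resp = requests.get(f"{BASE}/status/{request_id}", headers=HEADERS, timeout=30)
        if resp.status_code == 429:
            # Respect the Retry-After header set by the poll rate limiter.
            time.sleep(int(resp.headers.get("Retry-After", "2")))
            continue
        resp.raise_for_status()
        if resp.json()["status"] in ("completed", "failed"):
            break
        time.sleep(2)

    result = requests.get(f"{BASE}/result/{request_id}", headers=HEADERS, timeout=30)
    print(result.json())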

View File

@@ -0,0 +1,183 @@
"""
Inference API Routes
FastAPI route definitions for the inference API.
"""
from __future__ import annotations
import logging
import shutil
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from inference.web.schemas.inference import (
DetectionResult,
HealthResponse,
InferenceResponse,
InferenceResult,
)
from inference.web.schemas.common import ErrorResponse
if TYPE_CHECKING:
from inference.web.services import InferenceService
from inference.web.config import StorageConfig
logger = logging.getLogger(__name__)
def create_inference_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with inference endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["inference"])
@router.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
"""Check service health status."""
return HealthResponse(
status="healthy",
model_loaded=inference_service.is_initialized,
gpu_available=inference_service.gpu_available,
version="1.0.0",
)
@router.post(
"/infer",
response_model=InferenceResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
)
async def infer_document(
file: UploadFile = File(..., description="PDF or image file to process"),
) -> InferenceResponse:
"""
Process a document and extract invoice fields.
Accepts PDF or image files (PNG, JPG, JPEG).
Returns extracted field values with confidence scores.
"""
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Generate document ID
doc_id = str(uuid.uuid4())[:8]
# Save uploaded file
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
try:
with open(upload_path, "wb") as f:
shutil.copyfileobj(file.file, f)
except Exception as e:
logger.error(f"Failed to save uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save uploaded file",
)
try:
# Process based on file type
if file_ext == ".pdf":
service_result = inference_service.process_pdf(
upload_path, document_id=doc_id
)
else:
service_result = inference_service.process_image(
upload_path, document_id=doc_id
)
# Build response
viz_url = None
if service_result.visualization_path:
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
inference_result = InferenceResult(
document_id=service_result.document_id,
success=service_result.success,
document_type=service_result.document_type,
fields=service_result.fields,
confidence=service_result.confidence,
detections=[
DetectionResult(**d) for d in service_result.detections
],
processing_time_ms=service_result.processing_time_ms,
visualization_url=viz_url,
errors=service_result.errors,
)
return InferenceResponse(
status="success" if service_result.success else "partial",
message=f"Processed document {doc_id}",
result=inference_result,
)
except Exception as e:
logger.error(f"Error processing document: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
finally:
# Cleanup uploaded file
upload_path.unlink(missing_ok=True)
@router.get("/results/{filename}")
async def get_result_image(filename: str) -> FileResponse:
"""Get visualization result image."""
file_path = storage_config.result_dir / filename
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
return FileResponse(
path=file_path,
media_type="image/png",
filename=filename,
)
@router.delete("/results/{filename}")
async def delete_result(filename: str) -> dict:
"""Delete a result file."""
file_path = storage_config.result_dir / filename
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
file_path.unlink()
return {"status": "deleted", "filename": filename}
return router
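# Usage sketch (client side, illustrative only): calling the /infer endpoint
# above with a local file. Assumes the server runs at localhost:8000 and that
# `requests` is installed; "invoice.pdf" is a placeholder path.
if __name__ == "__main__":
    import requests

    with open("invoice.pdf", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/v1/infer",
            files={"file": ("invoice.pdf", fh, "application/pdf")},
            timeout=120,
        )
    resp.raise_for_status()
    print(resp.json()["result"]["fields"])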

View File

@@ -0,0 +1,203 @@
"""
Labeling API Routes
FastAPI endpoints for pre-labeling documents with expected field values.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from inference.data.admin_db import AdminDB
from inference.web.schemas.labeling import PreLabelResponse
from inference.web.schemas.common import ErrorResponse
if TYPE_CHECKING:
from inference.web.services import InferenceService
from inference.web.config import StorageConfig
logger = logging.getLogger(__name__)
# Legacy storage directory for pre-label uploads (current code goes through storage_config)
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
pdf_doc.close()
def get_admin_db() -> AdminDB:
"""Get admin database instance."""
return AdminDB()
def create_labeling_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with labeling endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["labeling"])
# Ensure upload directory exists
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
@router.post(
"/pre-label",
response_model=PreLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file or field values"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
summary="Pre-label document with expected values",
description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
)
async def pre_label(
file: UploadFile = File(..., description="PDF or image file to process"),
field_values: str = Form(
...,
description="JSON object with expected field values. "
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
),
db: AdminDB = Depends(get_admin_db),
) -> PreLabelResponse:
"""
Upload a document with expected field values for pre-labeling.
Returns document_id which can be used to retrieve results later.
Example field_values JSON:
```json
{
"InvoiceNumber": "12345",
"Amount": "1500.00",
"Bankgiro": "123-4567",
"OCR": "1234567890"
}
```
"""
# Parse field_values JSON
try:
expected_values = json.loads(field_values)
if not isinstance(expected_values, dict):
raise ValueError("field_values must be a JSON object")
except json.JSONDecodeError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid JSON in field_values: {e}",
)
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to read file",
)
# Get page count for PDF
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record with field_values
document_id = db.create_document(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
upload_source="api",
csv_field_values=expected_values,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save file",
)
# Update file path in database
db.update_document_file_path(document_id, str(file_path))
# Convert PDF to images for annotation UI
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="pending",
)
logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")
return PreLabelResponse(document_id=document_id)
return router
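# Usage sketch (client side, illustrative only): calling /pre-label with
# expected field values. Assumes localhost:8000 and `requests`; the file path
# is a placeholder and the field names mirror the docstring above.
if __name__ == "__main__":
    import json as _json

    import requests

    expected = {"InvoiceNumber": "12345", "Amount": "1500.00"}
    with open("invoice.pdf", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/v1/pre-label",
            files={"file": ("invoice.pdf", fh, "application/pdf")},
            data={"field_values": _json.dumps(expected)},
            timeout=120,
        )
    resp.raise_for_status()
    print(resp.json()["document_id"])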

View File

@@ -0,0 +1,913 @@
"""
FastAPI Application Factory
Creates and configures the FastAPI application.
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from .config import AppConfig, default_config
from inference.web.services import InferenceService
# Public API imports
from inference.web.api.v1.public import (
create_inference_router,
create_async_router,
set_async_service,
create_labeling_router,
)
# Async processing imports
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.dependencies import init_dependencies
from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
create_annotation_router,
create_auth_router,
create_documents_router,
create_locks_router,
create_training_router,
)
from inference.web.core.scheduler import start_scheduler, stop_scheduler
from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
# Batch upload imports
from inference.web.api.v1.batch.routes import router as batch_upload_router
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
from inference.data.admin_db import AdminDB
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__)
def create_app(config: AppConfig | None = None) -> FastAPI:
"""
Create and configure FastAPI application.
Args:
config: Application configuration. Uses default if not provided.
Returns:
Configured FastAPI application
"""
config = config or default_config
# Create inference service
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
)
# Create async processing components
async_db = AsyncRequestDB()
rate_limiter = RateLimiter(async_db)
task_queue = AsyncTaskQueue(
max_size=config.async_processing.queue_max_size,
worker_count=config.async_processing.worker_count,
)
async_service = AsyncProcessingService(
inference_service=inference_service,
db=async_db,
queue=task_queue,
rate_limiter=rate_limiter,
async_config=config.async_processing,
storage_config=config.storage,
)
# Initialize dependencies for FastAPI
init_dependencies(async_db, rate_limiter)
set_async_service(async_service)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan manager."""
logger.info("Starting Invoice Inference API...")
# Initialize database tables
try:
async_db.create_tables()
logger.info("Async database tables ready")
except Exception as e:
logger.error(f"Failed to initialize async database: {e}")
# Initialize inference service on startup
try:
inference_service.initialize()
logger.info("Inference service ready")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
# Continue anyway - service will retry on first request
# Start async processing service
try:
async_service.start()
logger.info("Async processing service started")
except Exception as e:
logger.error(f"Failed to start async processing: {e}")
# Start batch upload queue
try:
admin_db = AdminDB()
batch_service = BatchUploadService(admin_db)
init_batch_queue(batch_service)
logger.info("Batch upload queue started")
except Exception as e:
logger.error(f"Failed to start batch upload queue: {e}")
# Start training scheduler
try:
start_scheduler()
logger.info("Training scheduler started")
except Exception as e:
logger.error(f"Failed to start training scheduler: {e}")
# Start auto-label scheduler
try:
start_autolabel_scheduler()
logger.info("AutoLabel scheduler started")
except Exception as e:
logger.error(f"Failed to start autolabel scheduler: {e}")
yield
logger.info("Shutting down Invoice Inference API...")
# Stop auto-label scheduler
try:
stop_autolabel_scheduler()
logger.info("AutoLabel scheduler stopped")
except Exception as e:
logger.error(f"Error stopping autolabel scheduler: {e}")
# Stop training scheduler
try:
stop_scheduler()
logger.info("Training scheduler stopped")
except Exception as e:
logger.error(f"Error stopping training scheduler: {e}")
# Stop batch upload queue
try:
shutdown_batch_queue()
logger.info("Batch upload queue stopped")
except Exception as e:
logger.error(f"Error stopping batch upload queue: {e}")
# Stop async processing service
try:
async_service.stop(timeout=30.0)
logger.info("Async processing service stopped")
except Exception as e:
logger.error(f"Error stopping async service: {e}")
# Close database connection
try:
async_db.close()
logger.info("Database connection closed")
except Exception as e:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
app = FastAPI(
title="Invoice Field Extraction API",
description="""
REST API for extracting fields from Swedish invoices.
## Features
- YOLO-based field detection
- OCR text extraction
- Field normalization and validation
- Visualization of detections
## Supported Fields
- InvoiceNumber
- InvoiceDate
- InvoiceDueDate
- OCR (reference number)
- Bankgiro
- Plusgiro
- Amount
- supplier_org_number (Swedish organization number)
- customer_number
- payment_line (machine-readable payment code)
""",
version="1.0.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount static files for results
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
app.mount(
"/static/results",
StaticFiles(directory=str(config.storage.result_dir)),
name="results",
)
# Include public API routes
inference_router = create_inference_router(inference_service, config.storage)
app.include_router(inference_router)
async_router = create_async_router(config.storage.allowed_extensions)
app.include_router(async_router, prefix="/api/v1")
labeling_router = create_labeling_router(inference_service, config.storage)
app.include_router(labeling_router)
# Include admin API routes
auth_router = create_auth_router()
app.include_router(auth_router, prefix="/api/v1")
documents_router = create_documents_router(config.storage)
app.include_router(documents_router, prefix="/api/v1")
locks_router = create_locks_router()
app.include_router(locks_router, prefix="/api/v1")
annotation_router = create_annotation_router()
app.include_router(annotation_router, prefix="/api/v1")
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
# Include batch upload routes
app.include_router(batch_upload_router)
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:
"""Serve the web UI."""
return get_html_ui()
return app
def get_html_ui() -> str:
"""Generate HTML UI for the web application."""
return """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Invoice Field Extraction</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
header {
text-align: center;
color: white;
margin-bottom: 30px;
}
header h1 {
font-size: 2.5rem;
margin-bottom: 10px;
}
header p {
opacity: 0.9;
font-size: 1.1rem;
}
.main-content {
display: flex;
flex-direction: column;
gap: 20px;
}
.card {
background: white;
border-radius: 16px;
padding: 24px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
font-size: 1.3rem;
display: flex;
align-items: center;
gap: 10px;
}
.upload-card {
display: flex;
align-items: center;
gap: 20px;
flex-wrap: wrap;
}
.upload-card h2 {
margin-bottom: 0;
white-space: nowrap;
}
.upload-area {
border: 2px dashed #ddd;
border-radius: 10px;
padding: 15px 25px;
text-align: center;
cursor: pointer;
transition: all 0.3s;
background: #fafafa;
flex: 1;
min-width: 200px;
}
.upload-area:hover, .upload-area.dragover {
border-color: #667eea;
background: #f0f4ff;
}
.upload-area.has-file {
border-color: #10b981;
background: #ecfdf5;
}
.upload-icon {
font-size: 24px;
display: inline;
margin-right: 8px;
}
.upload-area p {
color: #666;
margin: 0;
display: inline;
}
.upload-area small {
color: #999;
display: block;
margin-top: 5px;
}
#file-input {
display: none;
}
.file-name {
margin-top: 15px;
padding: 10px 15px;
background: #e0f2fe;
border-radius: 8px;
color: #0369a1;
font-weight: 500;
}
.btn {
display: inline-block;
padding: 12px 24px;
border: none;
border-radius: 10px;
font-size: 0.9rem;
font-weight: 600;
cursor: pointer;
transition: all 0.3s;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.btn-primary:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
}
.btn-primary:disabled {
opacity: 0.6;
cursor: not-allowed;
}
.loading {
display: none;
align-items: center;
gap: 10px;
}
.loading.active {
display: flex;
}
.spinner {
width: 24px;
height: 24px;
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.results {
display: none;
}
.results.active {
display: block;
}
.result-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 20px;
padding-bottom: 15px;
border-bottom: 2px solid #eee;
}
.result-status {
padding: 6px 12px;
border-radius: 20px;
font-size: 0.85rem;
font-weight: 600;
}
.result-status.success {
background: #dcfce7;
color: #166534;
}
.result-status.partial {
background: #fef3c7;
color: #92400e;
}
.result-status.error {
background: #fee2e2;
color: #991b1b;
}
.fields-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 12px;
}
.field-item {
padding: 12px;
background: #f8fafc;
border-radius: 10px;
border-left: 4px solid #667eea;
}
.field-item label {
display: block;
font-size: 0.75rem;
color: #64748b;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 4px;
}
.field-item .value {
font-size: 1.1rem;
font-weight: 600;
color: #1e293b;
}
.field-item .confidence {
font-size: 0.75rem;
color: #10b981;
margin-top: 2px;
}
.visualization {
margin-top: 20px;
}
.visualization img {
width: 100%;
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
}
.processing-time {
text-align: center;
color: #64748b;
font-size: 0.9rem;
margin-top: 15px;
}
.cross-validation {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 10px;
padding: 15px;
margin-top: 20px;
}
.cross-validation h3 {
margin: 0 0 10px 0;
color: #334155;
font-size: 1rem;
}
.cv-status {
font-weight: 600;
padding: 8px 12px;
border-radius: 6px;
margin-bottom: 10px;
display: inline-block;
}
.cv-status.valid {
background: #dcfce7;
color: #166534;
}
.cv-status.invalid {
background: #fef3c7;
color: #92400e;
}
.cv-details {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 10px;
}
.cv-item {
background: white;
border: 1px solid #e2e8f0;
border-radius: 6px;
padding: 6px 12px;
font-size: 0.85rem;
display: flex;
align-items: center;
gap: 6px;
}
.cv-item.match {
border-color: #86efac;
background: #f0fdf4;
}
.cv-item.mismatch {
border-color: #fca5a5;
background: #fef2f2;
}
.cv-icon {
font-weight: bold;
}
.cv-item.match .cv-icon {
color: #16a34a;
}
.cv-item.mismatch .cv-icon {
color: #dc2626;
}
.cv-summary {
margin-top: 10px;
font-size: 0.8rem;
color: #64748b;
}
.error-message {
background: #fee2e2;
color: #991b1b;
padding: 15px;
border-radius: 10px;
margin-top: 15px;
}
footer {
text-align: center;
color: white;
opacity: 0.8;
margin-top: 30px;
font-size: 0.9rem;
}
</style>
</head>
<body>
<div class="container">
<header>
<h1>📄 Invoice Field Extraction</h1>
<p>Upload a Swedish invoice (PDF or image) to extract fields automatically</p>
</header>
<div class="main-content">
<!-- Upload Section - Compact -->
<div class="card upload-card">
<h2>📤 Upload</h2>
<div class="upload-area" id="upload-area">
<span class="upload-icon">📁</span>
<p>Drag & drop or <strong>click to browse</strong></p>
<small>PDF, PNG, JPG (max 50MB)</small>
<input type="file" id="file-input" accept=".pdf,.png,.jpg,.jpeg">
</div>
<div class="file-name" id="file-name" style="display: none;"></div>
<button class="btn btn-primary" id="submit-btn" disabled>
🚀 Extract
</button>
<div class="loading" id="loading">
<div class="spinner"></div>
<p>Processing...</p>
</div>
</div>
<!-- Results Section - Full Width -->
<div class="card">
<h2>📊 Extraction Results</h2>
<div id="placeholder" style="text-align: center; padding: 30px; color: #999;">
<div style="font-size: 48px; margin-bottom: 10px;">🔍</div>
<p>Upload a document to see extraction results</p>
</div>
<div class="results" id="results">
<div class="result-header">
<span>Document: <strong id="doc-id"></strong></span>
<span class="result-status" id="result-status"></span>
</div>
<div class="fields-grid" id="fields-grid"></div>
<div class="processing-time" id="processing-time"></div>
<div class="cross-validation" id="cross-validation" style="display: none;"></div>
<div class="error-message" id="error-message" style="display: none;"></div>
<div class="visualization" id="visualization" style="display: none;">
<h3 style="margin-bottom: 10px; color: #333;">🎯 Detection Visualization</h3>
<img id="viz-image" src="" alt="Detection visualization">
</div>
</div>
</div>
</div>
<footer>
<p>Powered by ColaCoder</p>
</footer>
</div>
<script>
const uploadArea = document.getElementById('upload-area');
const fileInput = document.getElementById('file-input');
const fileName = document.getElementById('file-name');
const submitBtn = document.getElementById('submit-btn');
const loading = document.getElementById('loading');
const placeholder = document.getElementById('placeholder');
const results = document.getElementById('results');
let selectedFile = null;
// Drag and drop handlers
uploadArea.addEventListener('click', () => fileInput.click());
uploadArea.addEventListener('dragover', (e) => {
e.preventDefault();
uploadArea.classList.add('dragover');
});
uploadArea.addEventListener('dragleave', () => {
uploadArea.classList.remove('dragover');
});
uploadArea.addEventListener('drop', (e) => {
e.preventDefault();
uploadArea.classList.remove('dragover');
const files = e.dataTransfer.files;
if (files.length > 0) {
handleFile(files[0]);
}
});
fileInput.addEventListener('change', (e) => {
if (e.target.files.length > 0) {
handleFile(e.target.files[0]);
}
});
function handleFile(file) {
const validTypes = ['.pdf', '.png', '.jpg', '.jpeg'];
const ext = '.' + file.name.split('.').pop().toLowerCase();
if (!validTypes.includes(ext)) {
alert('Please upload a PDF, PNG, or JPG file.');
return;
}
selectedFile = file;
fileName.textContent = `📎 ${file.name}`;
fileName.style.display = 'block';
uploadArea.classList.add('has-file');
submitBtn.disabled = false;
}
submitBtn.addEventListener('click', async () => {
if (!selectedFile) return;
// Show loading
submitBtn.disabled = true;
loading.classList.add('active');
placeholder.style.display = 'none';
results.classList.remove('active');
try {
const formData = new FormData();
formData.append('file', selectedFile);
const response = await fetch('/api/v1/infer', {
method: 'POST',
body: formData,
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Processing failed');
}
displayResults(data);
} catch (error) {
console.error('Error:', error);
document.getElementById('error-message').textContent = error.message;
document.getElementById('error-message').style.display = 'block';
results.classList.add('active');
} finally {
loading.classList.remove('active');
submitBtn.disabled = false;
}
});
function displayResults(data) {
const result = data.result;
// Document ID
document.getElementById('doc-id').textContent = result.document_id;
// Status
const statusEl = document.getElementById('result-status');
statusEl.textContent = result.success ? 'Success' : 'Partial';
statusEl.className = 'result-status ' + (result.success ? 'success' : 'partial');
// Fields
const fieldsGrid = document.getElementById('fields-grid');
fieldsGrid.innerHTML = '';
const fieldOrder = [
'InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR',
'Amount', 'Bankgiro', 'Plusgiro',
'supplier_org_number', 'customer_number', 'payment_line'
];
fieldOrder.forEach(field => {
const value = result.fields[field];
const confidence = result.confidence[field];
if (value !== null && value !== undefined) {
const fieldDiv = document.createElement('div');
fieldDiv.className = 'field-item';
fieldDiv.innerHTML = `
<label>${formatFieldName(field)}</label>
<div class="value">${value}</div>
${confidence ? `<div class="confidence">✓ ${(confidence * 100).toFixed(1)}% confident</div>` : ''}
`;
fieldsGrid.appendChild(fieldDiv);
}
});
// Processing time
document.getElementById('processing-time').textContent =
`⏱️ Processed in ${result.processing_time_ms.toFixed(0)}ms`;
// Cross-validation results
const cvDiv = document.getElementById('cross-validation');
if (result.cross_validation) {
const cv = result.cross_validation;
let cvHtml = '<h3>🔍 Cross-Validation (Payment Line)</h3>';
cvHtml += `<div class="cv-status ${cv.is_valid ? 'valid' : 'invalid'}">`;
cvHtml += cv.is_valid ? '✅ Valid' : '⚠️ Mismatch Detected';
cvHtml += '</div>';
cvHtml += '<div class="cv-details">';
if (cv.payment_line_ocr) {
const matchIcon = cv.ocr_match === true ? '✓' : (cv.ocr_match === false ? '✗' : '');
cvHtml += `<div class="cv-item ${cv.ocr_match === true ? 'match' : (cv.ocr_match === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> OCR: ${cv.payment_line_ocr}</div>`;
}
if (cv.payment_line_amount) {
const matchIcon = cv.amount_match === true ? '✓' : (cv.amount_match === false ? '✗' : '');
cvHtml += `<div class="cv-item ${cv.amount_match === true ? 'match' : (cv.amount_match === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> Amount: ${cv.payment_line_amount}</div>`;
}
if (cv.payment_line_account) {
const accountType = cv.payment_line_account_type === 'bankgiro' ? 'Bankgiro' : 'Plusgiro';
const matchField = cv.payment_line_account_type === 'bankgiro' ? cv.bankgiro_match : cv.plusgiro_match;
const matchIcon = matchField === true ? '✓' : (matchField === false ? '✗' : '');
cvHtml += `<div class="cv-item ${matchField === true ? 'match' : (matchField === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> ${accountType}: ${cv.payment_line_account}</div>`;
}
cvHtml += '</div>';
if (cv.details && cv.details.length > 0) {
cvHtml += '<div class="cv-summary">' + cv.details[cv.details.length - 1] + '</div>';
}
cvDiv.innerHTML = cvHtml;
cvDiv.style.display = 'block';
} else {
cvDiv.style.display = 'none';
}
// Visualization
if (result.visualization_url) {
const vizDiv = document.getElementById('visualization');
const vizImg = document.getElementById('viz-image');
vizImg.src = result.visualization_url;
vizDiv.style.display = 'block';
}
// Errors
if (result.errors && result.errors.length > 0) {
document.getElementById('error-message').textContent = result.errors.join(', ');
document.getElementById('error-message').style.display = 'block';
} else {
document.getElementById('error-message').style.display = 'none';
}
results.classList.add('active');
}
function formatFieldName(name) {
const nameMap = {
'InvoiceNumber': 'Invoice Number',
'InvoiceDate': 'Invoice Date',
'InvoiceDueDate': 'Due Date',
'OCR': 'OCR Reference',
'Amount': 'Amount',
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'supplier_org_number': 'Supplier Org Number',
'customer_number': 'Customer Number',
'payment_line': 'Payment Line'
};
return nameMap[name] || name.replace(/([A-Z])/g, ' $1').replace(/_/g, ' ').trim();
}
</script>
</body>
</html>
"""

View File

@@ -0,0 +1,113 @@
"""
Web Application Configuration
Centralized configuration for the web application.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from shared.config import DEFAULT_DPI, PATHS
@dataclass(frozen=True)
class ModelConfig:
"""YOLO model configuration."""
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
confidence_threshold: float = 0.5
use_gpu: bool = True
dpi: int = DEFAULT_DPI
@dataclass(frozen=True)
class ServerConfig:
"""Server configuration."""
host: str = "0.0.0.0"
port: int = 8000
debug: bool = False
reload: bool = False
workers: int = 1
@dataclass(frozen=True)
class StorageConfig:
"""File storage configuration.
Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
directly in the raw_pdfs directory. This keeps storage consistent with the
CLI autolabel workflow and avoids duplicate copies of the same file.
"""
upload_dir: Path = Path("uploads")
result_dir: Path = Path("results")
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
admin_images_dir: Path = Path("data/admin_images")
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
dpi: int = DEFAULT_DPI
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
self.upload_dir.mkdir(parents=True, exist_ok=True)
self.result_dir.mkdir(parents=True, exist_ok=True)
self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
@dataclass(frozen=True)
class AsyncConfig:
"""Async processing configuration."""
# Queue settings
queue_max_size: int = 100
worker_count: int = 1
task_timeout_seconds: int = 300
# Rate limiting defaults
default_requests_per_minute: int = 10
default_max_concurrent_jobs: int = 3
default_min_poll_interval_ms: int = 1000
# Storage
result_retention_days: int = 7
temp_upload_dir: Path = Path("uploads/async")
max_file_size_mb: int = 50
# Cleanup
cleanup_interval_hours: int = 1
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
@dataclass
class AppConfig:
"""Main application configuration."""
model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
storage: StorageConfig = field(default_factory=StorageConfig)
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
"""Create config from dictionary."""
return cls(
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
storage=StorageConfig(**config_dict.get("storage", {})),
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
)
# Default configuration instance
default_config = AppConfig()
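# Usage sketch (illustrative only): overriding defaults via from_dict. The
# keys mirror the dataclass fields above; the values are arbitrary examples.
if __name__ == "__main__":
    custom = AppConfig.from_dict(
        {
            "model": {"confidence_threshold": 0.4, "use_gpu": False},
            "server": {"port": 9000, "workers": 2},
            "storage": {"max_file_size_mb": 20},
            "async_processing": {"worker_count": 2, "queue_max_size": 50},
        }
    )
    print(custom.server.port, custom.model.confidence_threshold)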

View File

@@ -0,0 +1,28 @@
"""
Core Components
Reusable core functionality: authentication, rate limiting, scheduling.
"""
from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from inference.web.core.rate_limiter import RateLimiter
from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from inference.web.core.autolabel_scheduler import (
start_autolabel_scheduler,
stop_autolabel_scheduler,
get_autolabel_scheduler,
)
__all__ = [
"validate_admin_token",
"get_admin_db",
"AdminTokenDep",
"AdminDBDep",
"RateLimiter",
"start_scheduler",
"stop_scheduler",
"get_training_scheduler",
"start_autolabel_scheduler",
"stop_autolabel_scheduler",
"get_autolabel_scheduler",
]

View File

@@ -0,0 +1,60 @@
"""
Admin Authentication
FastAPI dependencies for admin token authentication.
"""
import logging
from typing import Annotated
from fastapi import Depends, Header, HTTPException
from inference.data.admin_db import AdminDB
logger = logging.getLogger(__name__)
# Global AdminDB instance
_admin_db: AdminDB | None = None
def get_admin_db() -> AdminDB:
"""Get the AdminDB instance."""
global _admin_db
if _admin_db is None:
_admin_db = AdminDB()
return _admin_db
def reset_admin_db() -> None:
"""Reset the AdminDB instance (for testing)."""
global _admin_db
_admin_db = None
async def validate_admin_token(
x_admin_token: Annotated[str | None, Header()] = None,
admin_db: AdminDB = Depends(get_admin_db),
) -> str:
"""Validate admin token from header."""
if not x_admin_token:
raise HTTPException(
status_code=401,
detail="Admin token required. Provide X-Admin-Token header.",
)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(
status_code=401,
detail="Invalid or expired admin token.",
)
# Update last used timestamp
admin_db.update_admin_token_usage(x_admin_token)
return x_admin_token
# Type alias for dependency injection
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
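# Usage sketch (illustrative only): how a router might consume the admin
# dependencies above. The router prefix and path here are hypothetical; the
# real admin routers live in inference.web.api.v1.admin.
if __name__ == "__main__":
    from fastapi import APIRouter

    example_router = APIRouter(prefix="/api/v1/admin", tags=["admin"])

    @example_router.get("/ping")
    async def admin_ping(token: AdminTokenDep, db: AdminDBDep) -> dict:
        """Runs only after a valid X-Admin-Token header has been verified."""
        return {"status": "ok"}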

View File

@@ -0,0 +1,153 @@
"""
Auto-Label Scheduler
Background scheduler for processing documents pending auto-labeling.
"""
import logging
import threading
from pathlib import Path
from inference.data.admin_db import AdminDB
from inference.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
)
logger = logging.getLogger(__name__)
class AutoLabelScheduler:
"""Scheduler for auto-labeling tasks."""
def __init__(
self,
check_interval_seconds: int = 10,
batch_size: int = 5,
output_dir: Path | None = None,
):
"""
Initialize auto-label scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
batch_size: Number of documents to process per batch
output_dir: Output directory for temporary files
"""
self._check_interval = check_interval_seconds
self._batch_size = batch_size
self._output_dir = output_dir or Path("data/autolabel_output")
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("AutoLabel scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("AutoLabel scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("AutoLabel scheduler stopped")
@property
def is_running(self) -> bool:
"""Check if scheduler is running."""
return self._running
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._process_pending_documents()
except Exception as e:
logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _process_pending_documents(self) -> None:
"""Check and process pending auto-label documents."""
try:
documents = get_pending_autolabel_documents(
self._db, limit=self._batch_size
)
if not documents:
return
logger.info(f"Processing {len(documents)} pending autolabel documents")
for doc in documents:
if self._stop_event.is_set():
break
try:
result = process_document_autolabel(
document=doc,
db=self._db,
output_dir=self._output_dir,
)
if result.get("success"):
logger.info(
f"AutoLabel completed for document {doc.document_id}"
)
else:
logger.warning(
f"AutoLabel failed for document {doc.document_id}: "
f"{result.get('error', 'Unknown error')}"
)
except Exception as e:
logger.error(
f"Error processing document {doc.document_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error fetching pending documents: {e}", exc_info=True)
# Global scheduler instance
_autolabel_scheduler: AutoLabelScheduler | None = None
def get_autolabel_scheduler() -> AutoLabelScheduler:
"""Get the auto-label scheduler instance."""
global _autolabel_scheduler
if _autolabel_scheduler is None:
_autolabel_scheduler = AutoLabelScheduler()
return _autolabel_scheduler
def start_autolabel_scheduler() -> None:
"""Start the global auto-label scheduler."""
scheduler = get_autolabel_scheduler()
scheduler.start()
def stop_autolabel_scheduler() -> None:
"""Stop the global auto-label scheduler."""
global _autolabel_scheduler
if _autolabel_scheduler:
_autolabel_scheduler.stop()
_autolabel_scheduler = None
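# Usage sketch (illustrative only): running the scheduler outside the FastAPI
# lifespan, e.g. from a one-off script; the interval, batch size, and output
# directory below are arbitrary examples.
if __name__ == "__main__":
    import time

    scheduler = AutoLabelScheduler(
        check_interval_seconds=5,
        batch_size=10,
        output_dir=Path("data/autolabel_output"),
    )
    scheduler.start()
    try:
        while scheduler.is_running:
            time.sleep(60)
    finally:
        scheduler.stop()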

View File

@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from inference.data.async_request_db import AsyncRequestDB
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RateLimitConfig:
"""Rate limit configuration for an API key."""
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000 # Minimum time between status polls
@dataclass
class RateLimitStatus:
"""Current rate limit status."""
allowed: bool
remaining_requests: int
reset_at: datetime
retry_after_seconds: int | None = None
reason: str | None = None
class RateLimiter:
"""
Thread-safe rate limiter with sliding window algorithm.
Tracks:
- Requests per minute (sliding window)
- Concurrent active jobs
- Poll frequency per request_id
"""
def __init__(self, db: "AsyncRequestDB") -> None:
self._db = db
self._lock = Lock()
# In-memory tracking for fast checks
self._request_windows: dict[str, list[float]] = defaultdict(list)
# (api_key, request_id) -> last_poll timestamp
self._poll_timestamps: dict[tuple[str, str], float] = {}
# Cache for API key configs (TTL 60 seconds)
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
self._config_cache_ttl = 60.0
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
"""Check if API key can submit a new request."""
config = self._get_config(api_key)
with self._lock:
now = time.time()
window_start = now - 60 # 1 minute window
# Clean old entries
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
current_count = len(self._request_windows[api_key])
if current_count >= config.requests_per_minute:
oldest = min(self._request_windows[api_key])
retry_after = int(oldest + 60 - now) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
retry_after_seconds=max(1, retry_after),
reason="Rate limit exceeded: too many requests per minute",
)
# Check concurrent jobs (query database) - inside lock for thread safety
active_jobs = self._db.count_active_jobs(api_key)
if active_jobs >= config.max_concurrent_jobs:
return RateLimitStatus(
allowed=False,
remaining_requests=config.requests_per_minute - current_count,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
)
return RateLimitStatus(
allowed=True,
remaining_requests=config.requests_per_minute - current_count - 1,
reset_at=datetime.utcnow() + timedelta(seconds=60),
)
def record_request(self, api_key: str) -> None:
"""Record a successful request submission."""
with self._lock:
self._request_windows[api_key].append(time.time())
# Also record in database for persistence
try:
self._db.record_rate_limit_event(api_key, "request")
except Exception as e:
logger.warning(f"Failed to record rate limit event: {e}")
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
"""Check if polling is allowed (prevent abuse)."""
config = self._get_config(api_key)
key = (api_key, request_id)
with self._lock:
now = time.time()
last_poll = self._poll_timestamps.get(key, 0)
elapsed_ms = (now - last_poll) * 1000
if elapsed_ms < config.min_poll_interval_ms:
# Suggest exponential backoff
wait_ms = min(
config.min_poll_interval_ms * 2,
5000, # Max 5 seconds
)
retry_after = int(wait_ms / 1000) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
retry_after_seconds=retry_after,
reason="Polling too frequently. Please wait before retrying.",
)
# Update poll timestamp
self._poll_timestamps[key] = now
return RateLimitStatus(
allowed=True,
remaining_requests=999, # No limit on poll count, just frequency
reset_at=datetime.utcnow(),
)
def _get_config(self, api_key: str) -> RateLimitConfig:
"""Get rate limit config for API key with caching."""
now = time.time()
# Check cache
if api_key in self._config_cache:
cached_config, cached_at = self._config_cache[api_key]
if now - cached_at < self._config_cache_ttl:
return cached_config
# Query database
db_config = self._db.get_api_key_config(api_key)
if db_config:
config = RateLimitConfig(
requests_per_minute=db_config.requests_per_minute,
max_concurrent_jobs=db_config.max_concurrent_jobs,
)
else:
config = RateLimitConfig() # Default limits
# Cache result
self._config_cache[api_key] = (config, now)
return config
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
"""Clean up old poll timestamps to prevent memory leak."""
with self._lock:
now = time.time()
cutoff = now - max_age_seconds
old_keys = [
k for k, v in self._poll_timestamps.items()
if v < cutoff
]
for key in old_keys:
del self._poll_timestamps[key]
return len(old_keys)
def cleanup_request_windows(self) -> None:
"""Clean up expired entries from request windows."""
with self._lock:
now = time.time()
window_start = now - 60
for api_key in list(self._request_windows.keys()):
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
# Remove empty entries
if not self._request_windows[api_key]:
del self._request_windows[api_key]
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
"""Generate rate limit headers for HTTP response."""
headers = {
"X-RateLimit-Remaining": str(status.remaining_requests),
"X-RateLimit-Reset": status.reset_at.isoformat(),
}
if status.retry_after_seconds:
headers["Retry-After"] = str(status.retry_after_seconds)
return headers
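# Interface sketch (illustrative only): the subset of AsyncRequestDB that this
# module actually calls, written as a typing.Protocol so tests can stub the
# database. Parameter names are inferred from the calls above; the real config
# object returned by get_api_key_config exposes requests_per_minute and
# max_concurrent_jobs attributes.
from typing import Protocol


class _RateLimitBackend(Protocol):
    def count_active_jobs(self, api_key: str) -> int: ...

    def get_api_key_config(self, api_key: str) -> object | None: ...

    def record_rate_limit_event(self, api_key: str, event_type: str) -> None: ...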

View File

@@ -0,0 +1,340 @@
"""
Admin Training Scheduler
Background scheduler for training tasks using a daemon thread polling loop.
"""
import logging
import threading
from datetime import datetime
from pathlib import Path
from typing import Any
from inference.data.admin_db import AdminDB
logger = logging.getLogger(__name__)
class TrainingScheduler:
"""Scheduler for training tasks."""
def __init__(
self,
check_interval_seconds: int = 60,
):
"""
Initialize training scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
"""
self._check_interval = check_interval_seconds
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("Training scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("Training scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("Training scheduler stopped")
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._check_pending_tasks()
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _check_pending_tasks(self) -> None:
"""Check and execute pending training tasks."""
try:
tasks = self._db.get_pending_training_tasks()
for task in tasks:
task_id = str(task.task_id)
# Check if scheduled time has passed
if task.scheduled_at and task.scheduled_at > datetime.utcnow():
continue
logger.info(f"Starting training task: {task_id}")
try:
dataset_id = getattr(task, "dataset_id", None)
self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.update_training_task_status(
task_id=task_id,
status="failed",
error_message=str(e),
)
except Exception as e:
logger.error(f"Error checking pending tasks: {e}")
def _execute_task(
self, task_id: str, config: dict[str, Any], dataset_id: str | None = None
) -> None:
"""Execute a training task."""
# Update status to running
self._db.update_training_task_status(task_id, "running")
self._db.add_training_log(task_id, "INFO", "Training task started")
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
learning_rate = config.get("learning_rate", 0.01)
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
self._db.add_training_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
)
else:
export_result = self._export_training_data(task_id)
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=model_name,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project_name=project_name,
)
# Update task with results
self._db.update_training_task_status(
task_id=task_id,
status="completed",
result_metrics=result.get("metrics"),
model_path=result.get("model_path"),
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
import shutil
from inference.data.admin_models import FIELD_CLASSES
# Get all labeled documents
documents = self._db.get_labeled_documents_for_export()
if not documents:
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
return None
# Create export directory
export_dir = Path("data/training") / task_id
export_dir.mkdir(parents=True, exist_ok=True)
# YOLO format directories
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
# 80/20 train/val split
total_docs = len(documents)
train_count = int(total_docs * 0.8)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
# Export documents
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = self._db.get_annotations_for_document(str(doc.document_id))
if not annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
# Copy image
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
continue
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
total_images += 1
# Write YOLO label
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
# Create data.yaml
yaml_path = export_dir / "data.yaml"
yaml_content = f"""path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
yaml_path.write_text(yaml_content)
return {
"data_yaml": str(yaml_path),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _run_yolo_training(
self,
task_id: str,
model_name: str,
data_yaml: str,
epochs: int,
batch_size: int,
image_size: int,
learning_rate: float,
device: str,
project_name: str,
) -> dict[str, Any]:
"""Run YOLO training."""
try:
from ultralytics import YOLO
# Log training start
self._db.add_training_log(
task_id, "INFO",
f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
)
# Load model
model = YOLO(model_name)
# Train
results = model.train(
data=data_yaml,
epochs=epochs,
batch=batch_size,
imgsz=image_size,
lr0=learning_rate,
device=device,
project=f"runs/train/{project_name}",
name=f"task_{task_id[:8]}",
exist_ok=True,
verbose=True,
)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
self._db.add_training_log(
task_id, "INFO",
f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
)
return {
"model_path": str(best_model) if best_model.exists() else None,
"metrics": metrics,
}
except ImportError:
self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
raise ValueError("Ultralytics (YOLO) not installed")
except Exception as e:
self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
raise
# Global scheduler instance
_scheduler: TrainingScheduler | None = None
def get_training_scheduler() -> TrainingScheduler:
"""Get the training scheduler instance."""
global _scheduler
if _scheduler is None:
_scheduler = TrainingScheduler()
return _scheduler
def start_scheduler() -> None:
"""Start the global training scheduler."""
scheduler = get_training_scheduler()
scheduler.start()
def stop_scheduler() -> None:
"""Stop the global training scheduler."""
global _scheduler
if _scheduler:
_scheduler.stop()
_scheduler = None
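# Illustrative sketch (not part of the commit): sanity-checking the YOLO dataset
# layout produced by the export step above. The task id is a placeholder; the
# directory structure and data.yaml location follow the export code in this file.
from pathlib import Path

def describe_export(task_id: str) -> None:
    export_dir = Path("data/training") / task_id
    for sub in ("images/train", "images/val", "labels/train", "labels/val"):
        count = len(list((export_dir / sub).glob("*")))
        print(f"{sub}: {count} files")
    print((export_dir / "data.yaml").read_text())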

View File

@@ -0,0 +1,133 @@
"""
FastAPI Dependencies
Dependency injection for the async API endpoints.
"""
import logging
from typing import Annotated
from fastapi import Depends, Header, HTTPException, Request
from inference.data.async_request_db import AsyncRequestDB
from inference.web.rate_limiter import RateLimiter
logger = logging.getLogger(__name__)
# Global instances (initialized in app startup)
_async_db: AsyncRequestDB | None = None
_rate_limiter: RateLimiter | None = None
def init_dependencies(db: AsyncRequestDB, rate_limiter: RateLimiter) -> None:
"""Initialize global dependency instances."""
global _async_db, _rate_limiter
_async_db = db
_rate_limiter = rate_limiter
def get_async_db() -> AsyncRequestDB:
"""Get async request database instance."""
if _async_db is None:
raise RuntimeError("AsyncRequestDB not initialized")
return _async_db
def get_rate_limiter() -> RateLimiter:
"""Get rate limiter instance."""
if _rate_limiter is None:
raise RuntimeError("RateLimiter not initialized")
return _rate_limiter
async def verify_api_key(
x_api_key: Annotated[str | None, Header()] = None,
) -> str:
"""
Verify API key exists and is active.
Raises:
HTTPException: 401 if API key is missing or invalid
"""
if not x_api_key:
raise HTTPException(
status_code=401,
detail="X-API-Key header is required",
headers={"WWW-Authenticate": "API-Key"},
)
db = get_async_db()
if not db.is_valid_api_key(x_api_key):
raise HTTPException(
status_code=401,
detail="Invalid or inactive API key",
headers={"WWW-Authenticate": "API-Key"},
)
# Update usage tracking
try:
db.update_api_key_usage(x_api_key)
except Exception as e:
logger.warning(f"Failed to update API key usage: {e}")
return x_api_key
async def check_submit_rate_limit(
api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
"""
Check the rate limit before processing a submit request.
Raises:
HTTPException: 429 if rate limit exceeded
"""
rate_limiter = get_rate_limiter()
status = rate_limiter.check_submit_limit(api_key)
if not status.allowed:
headers = rate_limiter.get_rate_limit_headers(status)
raise HTTPException(
status_code=429,
detail=status.reason or "Rate limit exceeded",
headers=headers,
)
return api_key
async def check_poll_rate_limit(
request: Request,
api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
"""
Check poll rate limit to prevent abuse.
Raises:
HTTPException: 429 if polling too frequently
"""
# Extract request_id from path parameters
request_id = request.path_params.get("request_id")
if not request_id:
return api_key # No request_id, skip poll limit check
rate_limiter = get_rate_limiter()
status = rate_limiter.check_poll_limit(api_key, request_id)
if not status.allowed:
headers = rate_limiter.get_rate_limit_headers(status)
raise HTTPException(
status_code=429,
detail=status.reason or "Polling too frequently",
headers=headers,
)
return api_key
# Type aliases for cleaner route signatures
ApiKeyDep = Annotated[str, Depends(verify_api_key)]
SubmitRateLimitDep = Annotated[str, Depends(check_submit_rate_limit)]
PollRateLimitDep = Annotated[str, Depends(check_poll_rate_limit)]
AsyncDBDep = Annotated[AsyncRequestDB, Depends(get_async_db)]
RateLimiterDep = Annotated[RateLimiter, Depends(get_rate_limiter)]
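# Illustrative sketch (not part of the commit): how the aliases above keep route
# signatures small. The router prefix and URL path are assumptions; AsyncRequestDB's
# get_request() is used as shown elsewhere in this changeset.
from fastapi import APIRouter

example_router = APIRouter()

@example_router.get("/async/requests/{request_id}/status")
async def example_status(request_id: str, api_key: PollRateLimitDep, db: AsyncDBDep):
    # api_key has already passed verify_api_key and the poll-frequency check
    return db.get_request(request_id)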

View File

@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation
Thread-safe rate limiter using a sliding-window algorithm for per-API-key limiting.
"""
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from inference.data.async_request_db import AsyncRequestDB
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RateLimitConfig:
"""Rate limit configuration for an API key."""
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000 # Minimum time between status polls
@dataclass
class RateLimitStatus:
"""Current rate limit status."""
allowed: bool
remaining_requests: int
reset_at: datetime
retry_after_seconds: int | None = None
reason: str | None = None
class RateLimiter:
"""
Thread-safe rate limiter using a sliding-window algorithm.
Tracks:
- Requests per minute (sliding window)
- Concurrent active jobs
- Poll frequency per request_id
"""
def __init__(self, db: "AsyncRequestDB") -> None:
self._db = db
self._lock = Lock()
# In-memory tracking for fast checks
self._request_windows: dict[str, list[float]] = defaultdict(list)
# (api_key, request_id) -> last_poll timestamp
self._poll_timestamps: dict[tuple[str, str], float] = {}
# Cache for API key configs (TTL 60 seconds)
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
self._config_cache_ttl = 60.0
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
"""Check if API key can submit a new request."""
config = self._get_config(api_key)
with self._lock:
now = time.time()
window_start = now - 60 # 1 minute window
# Clean old entries
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
current_count = len(self._request_windows[api_key])
if current_count >= config.requests_per_minute:
oldest = min(self._request_windows[api_key])
retry_after = int(oldest + 60 - now) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
retry_after_seconds=max(1, retry_after),
reason="Rate limit exceeded: too many requests per minute",
)
# Check concurrent jobs (query database) - inside lock for thread safety
active_jobs = self._db.count_active_jobs(api_key)
if active_jobs >= config.max_concurrent_jobs:
return RateLimitStatus(
allowed=False,
remaining_requests=config.requests_per_minute - current_count,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
)
return RateLimitStatus(
allowed=True,
remaining_requests=config.requests_per_minute - current_count - 1,
reset_at=datetime.utcnow() + timedelta(seconds=60),
)
def record_request(self, api_key: str) -> None:
"""Record a successful request submission."""
with self._lock:
self._request_windows[api_key].append(time.time())
# Also record in database for persistence
try:
self._db.record_rate_limit_event(api_key, "request")
except Exception as e:
logger.warning(f"Failed to record rate limit event: {e}")
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
"""Check if polling is allowed (prevent abuse)."""
config = self._get_config(api_key)
key = (api_key, request_id)
with self._lock:
now = time.time()
last_poll = self._poll_timestamps.get(key, 0)
elapsed_ms = (now - last_poll) * 1000
if elapsed_ms < config.min_poll_interval_ms:
# Suggest exponential backoff
wait_ms = min(
config.min_poll_interval_ms * 2,
5000, # Max 5 seconds
)
retry_after = int(wait_ms / 1000) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
retry_after_seconds=retry_after,
reason="Polling too frequently. Please wait before retrying.",
)
# Update poll timestamp
self._poll_timestamps[key] = now
return RateLimitStatus(
allowed=True,
remaining_requests=999, # No limit on poll count, just frequency
reset_at=datetime.utcnow(),
)
def _get_config(self, api_key: str) -> RateLimitConfig:
"""Get rate limit config for API key with caching."""
now = time.time()
# Check cache
if api_key in self._config_cache:
cached_config, cached_at = self._config_cache[api_key]
if now - cached_at < self._config_cache_ttl:
return cached_config
# Query database
db_config = self._db.get_api_key_config(api_key)
if db_config:
config = RateLimitConfig(
requests_per_minute=db_config.requests_per_minute,
max_concurrent_jobs=db_config.max_concurrent_jobs,
)
else:
config = RateLimitConfig() # Default limits
# Cache result
self._config_cache[api_key] = (config, now)
return config
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
"""Clean up old poll timestamps to prevent memory leak."""
with self._lock:
now = time.time()
cutoff = now - max_age_seconds
old_keys = [
k for k, v in self._poll_timestamps.items()
if v < cutoff
]
for key in old_keys:
del self._poll_timestamps[key]
return len(old_keys)
def cleanup_request_windows(self) -> None:
"""Clean up expired entries from request windows."""
with self._lock:
now = time.time()
window_start = now - 60
for api_key in list(self._request_windows.keys()):
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
# Remove empty entries
if not self._request_windows[api_key]:
del self._request_windows[api_key]
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
"""Generate rate limit headers for HTTP response."""
headers = {
"X-RateLimit-Remaining": str(status.remaining_requests),
"X-RateLimit-Reset": status.reset_at.isoformat(),
}
if status.retry_after_seconds:
headers["Retry-After"] = str(status.retry_after_seconds)
return headers
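# Illustrative sketch (not part of the commit): the expected submit-path flow
# around RateLimiter. `limiter` and `api_key` stand in for a configured instance
# and a validated key.
def example_submit_flow(limiter: RateLimiter, api_key: str) -> dict[str, str]:
    status = limiter.check_submit_limit(api_key)
    if not status.allowed:
        # Surface Retry-After / X-RateLimit-* headers with the 429 response.
        return limiter.get_rate_limit_headers(status)
    limiter.record_request(api_key)  # count this submission in the 60s window
    return limiter.get_rate_limit_headers(status)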

View File

@@ -0,0 +1,11 @@
"""
API Schemas
Pydantic models for request/response validation.
"""
# Import everything from sub-modules for backward compatibility
from inference.web.schemas.common import * # noqa: F401, F403
from inference.web.schemas.admin import * # noqa: F401, F403
from inference.web.schemas.inference import * # noqa: F401, F403
from inference.web.schemas.labeling import * # noqa: F401, F403

View File

@@ -0,0 +1,17 @@
"""
Admin API Request/Response Schemas
Pydantic models for admin API validation and serialization.
"""
from .enums import * # noqa: F401, F403
from .auth import * # noqa: F401, F403
from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse
DocumentDetailResponse.model_rebuild()

View File

@@ -0,0 +1,152 @@
"""Admin Annotation Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .enums import AnnotationSource
class BoundingBox(BaseModel):
"""Bounding box coordinates."""
x: int = Field(..., ge=0, description="X coordinate (pixels)")
y: int = Field(..., ge=0, description="Y coordinate (pixels)")
width: int = Field(..., ge=1, description="Width (pixels)")
height: int = Field(..., ge=1, description="Height (pixels)")
class AnnotationCreate(BaseModel):
"""Request to create an annotation."""
page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
text_value: str | None = Field(None, description="Text value (optional)")
class AnnotationUpdate(BaseModel):
"""Request to update an annotation."""
class_id: int | None = Field(None, ge=0, le=9, description="New class ID")
bbox: BoundingBox | None = Field(None, description="New bounding box")
text_value: str | None = Field(None, description="New text value")
class AnnotationItem(BaseModel):
"""Single annotation item."""
annotation_id: str = Field(..., description="Annotation UUID")
page_number: int = Field(..., ge=1, description="Page number")
class_id: int = Field(..., ge=0, le=9, description="Class ID")
class_name: str = Field(..., description="Class name")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
normalized_bbox: dict[str, float] = Field(
..., description="Normalized bbox (x_center, y_center, width, height)"
)
text_value: str | None = Field(None, description="Text value")
confidence: float | None = Field(None, ge=0, le=1, description="Confidence score")
source: AnnotationSource = Field(..., description="Annotation source")
created_at: datetime = Field(..., description="Creation timestamp")
class AnnotationResponse(BaseModel):
"""Response for annotation operation."""
annotation_id: str = Field(..., description="Annotation UUID")
message: str = Field(..., description="Status message")
class AnnotationListResponse(BaseModel):
"""Response for annotation list."""
document_id: str = Field(..., description="Document UUID")
page_count: int = Field(..., ge=1, description="Total pages")
total_annotations: int = Field(..., ge=0, description="Total annotations")
annotations: list[AnnotationItem] = Field(
default_factory=list, description="Annotation list"
)
class AnnotationLockRequest(BaseModel):
"""Request to acquire annotation lock."""
duration_seconds: int = Field(
default=300,
ge=60,
le=3600,
description="Lock duration in seconds (60-3600)",
)
class AnnotationLockResponse(BaseModel):
"""Response for annotation lock operation."""
document_id: str = Field(..., description="Document UUID")
locked: bool = Field(..., description="Whether lock was acquired/released")
lock_expires_at: datetime | None = Field(
None, description="Lock expiration time"
)
message: str = Field(..., description="Status message")
class AutoLabelRequest(BaseModel):
"""Request to trigger auto-labeling."""
field_values: dict[str, str] = Field(
...,
description="Field values to match (e.g., {'invoice_number': '12345'})",
)
replace_existing: bool = Field(
default=False, description="Replace existing auto annotations"
)
class AutoLabelResponse(BaseModel):
"""Response for auto-labeling."""
document_id: str = Field(..., description="Document UUID")
status: str = Field(..., description="Auto-labeling status")
annotations_created: int = Field(
default=0, ge=0, description="Number of annotations created"
)
message: str = Field(..., description="Status message")
class AnnotationVerifyRequest(BaseModel):
"""Request to verify an annotation."""
pass # No body needed, just POST to verify
class AnnotationVerifyResponse(BaseModel):
"""Response for annotation verification."""
annotation_id: str = Field(..., description="Annotation UUID")
is_verified: bool = Field(..., description="Verification status")
verified_at: datetime = Field(..., description="Verification timestamp")
verified_by: str = Field(..., description="Admin token who verified")
message: str = Field(..., description="Status message")
class AnnotationOverrideRequest(BaseModel):
"""Request to override an annotation."""
bbox: dict[str, int] | None = Field(
None, description="Updated bounding box {x, y, width, height}"
)
text_value: str | None = Field(None, description="Updated text value")
class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID")
class_name: str | None = Field(None, description="Updated class name")
reason: str | None = Field(None, description="Reason for override")
class AnnotationOverrideResponse(BaseModel):
"""Response for annotation override."""
annotation_id: str = Field(..., description="Annotation UUID")
source: str = Field(..., description="New source (manual)")
override_source: str | None = Field(None, description="Original source (auto)")
original_annotation_id: str | None = Field(None, description="Original annotation ID")
message: str = Field(..., description="Status message")
history_id: str = Field(..., description="History record UUID")
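# Illustrative sketch (not part of the commit): a client-side annotation payload;
# coordinates and text value are placeholders.
_example_annotation = AnnotationCreate(
    page_number=1,
    class_id=0,
    bbox=BoundingBox(x=120, y=340, width=220, height=42),
    text_value="INV-2024-0001",
)
# _example_annotation.model_dump() produces the request body for annotation creation.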

View File

@@ -0,0 +1,23 @@
"""Admin Auth Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class AdminTokenCreate(BaseModel):
"""Request to create an admin token."""
name: str = Field(..., min_length=1, max_length=255, description="Token name")
expires_in_days: int | None = Field(
None, ge=1, le=365, description="Token expiration in days (optional)"
)
class AdminTokenResponse(BaseModel):
"""Response with created admin token."""
token: str = Field(..., description="Admin token")
name: str = Field(..., description="Token name")
expires_at: datetime | None = Field(None, description="Token expiration time")
message: str = Field(..., description="Status message")

View File

@@ -0,0 +1,85 @@
"""Admin Dataset Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .training import TrainingConfig
class DatasetCreateRequest(BaseModel):
"""Request to create a training dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
description: str | None = Field(None, description="Optional description")
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
seed: int = Field(42, description="Random seed for split")
class DatasetDocumentItem(BaseModel):
"""Document within a dataset."""
document_id: str
split: str
page_count: int
annotation_count: int
class DatasetResponse(BaseModel):
"""Response after creating a dataset."""
dataset_id: str
name: str
status: str
message: str
class DatasetDetailResponse(BaseModel):
"""Detailed dataset info with documents."""
dataset_id: str
name: str
description: str | None
status: str
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
documents: list[DatasetDocumentItem]
created_at: datetime
updated_at: datetime
class DatasetListItem(BaseModel):
"""Dataset in list view."""
dataset_id: str
name: str
description: str | None
status: str
total_documents: int
total_images: int
total_annotations: int
created_at: datetime
class DatasetListResponse(BaseModel):
"""Paginated dataset list."""
total: int
limit: int
offset: int
datasets: list[DatasetListItem]
class DatasetTrainRequest(BaseModel):
"""Request to start training from a dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Training task name")
config: TrainingConfig = Field(..., description="Training configuration")
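# Illustrative sketch (not part of the commit): a dataset-creation request with
# the default 80/10 split; the document id is a placeholder.
_example_dataset_request = DatasetCreateRequest(
    name="invoices-2024-q1",
    description="Labeled invoices from Q1",
    document_ids=["00000000-0000-0000-0000-000000000000"],
    train_ratio=0.8,
    val_ratio=0.1,
    seed=42,
)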

View File

@@ -0,0 +1,103 @@
"""Admin Document Schemas."""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from .enums import AutoLabelStatus, DocumentStatus
if TYPE_CHECKING:
from .annotations import AnnotationItem
from .training import TrainingHistoryItem
class DocumentUploadResponse(BaseModel):
"""Response for document upload."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
message: str = Field(..., description="Status message")
class DocumentItem(BaseModel):
"""Single document in list."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentListResponse(BaseModel):
"""Response for document list."""
total: int = Field(..., ge=0, description="Total documents")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
documents: list[DocumentItem] = Field(
default_factory=list, description="Document list"
)
class DocumentDetailResponse(BaseModel):
"""Response for document detail."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
content_type: str = Field(..., description="MIME type")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
annotation_lock_until: datetime | None = Field(
None, description="Lock expiration time if document is locked"
)
annotations: list["AnnotationItem"] = Field(
default_factory=list, description="Document annotations"
)
image_urls: list[str] = Field(
default_factory=list, description="URLs to page images"
)
training_history: list["TrainingHistoryItem"] = Field(
default_factory=list, description="Training tasks that used this document"
)
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentStatsResponse(BaseModel):
"""Document statistics response."""
total: int = Field(..., ge=0, description="Total documents")
pending: int = Field(default=0, ge=0, description="Pending documents")
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
labeled: int = Field(default=0, ge=0, description="Labeled documents")
exported: int = Field(default=0, ge=0, description="Exported documents")

View File

@@ -0,0 +1,46 @@
"""Admin API Enums."""
from enum import Enum
class DocumentStatus(str, Enum):
"""Document status enum."""
PENDING = "pending"
AUTO_LABELING = "auto_labeling"
LABELED = "labeled"
EXPORTED = "exported"
class AutoLabelStatus(str, Enum):
"""Auto-labeling status enum."""
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class TrainingStatus(str, Enum):
"""Training task status enum."""
PENDING = "pending"
SCHEDULED = "scheduled"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TrainingType(str, Enum):
"""Training task type enum."""
TRAIN = "train"
FINETUNE = "finetune"
class AnnotationSource(str, Enum):
"""Annotation source enum."""
MANUAL = "manual"
AUTO = "auto"
IMPORTED = "imported"

View File

@@ -0,0 +1,202 @@
"""Admin Training Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name")
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate")
device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
project_name: str = Field(
default="invoice_fields", description="Training project name"
)
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""
name: str = Field(..., min_length=1, max_length=255, description="Task name")
description: str | None = Field(None, max_length=1000, description="Description")
task_type: TrainingType = Field(
default=TrainingType.TRAIN, description="Task type"
)
config: TrainingConfig = Field(
default_factory=TrainingConfig, description="Training configuration"
)
scheduled_at: datetime | None = Field(
None, description="Scheduled execution time"
)
cron_expression: str | None = Field(
None, max_length=50, description="Cron expression for recurring tasks"
)
class TrainingTaskItem(BaseModel):
"""Single training task in list."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskListResponse(BaseModel):
"""Response for training task list."""
total: int = Field(..., ge=0, description="Total tasks")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list")
class TrainingTaskDetailResponse(BaseModel):
"""Response for training task detail."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
description: str | None = Field(None, description="Description")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
config: dict[str, Any] | None = Field(None, description="Training configuration")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
cron_expression: str | None = Field(None, description="Cron expression")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
error_message: str | None = Field(None, description="Error message")
result_metrics: dict[str, Any] | None = Field(None, description="Result metrics")
model_path: str | None = Field(None, description="Trained model path")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskResponse(BaseModel):
"""Response for training task operation."""
task_id: str = Field(..., description="Task UUID")
status: TrainingStatus = Field(..., description="Task status")
message: str = Field(..., description="Status message")
class TrainingLogItem(BaseModel):
"""Single training log entry."""
level: str = Field(..., description="Log level")
message: str = Field(..., description="Log message")
details: dict[str, Any] | None = Field(None, description="Additional details")
created_at: datetime = Field(..., description="Timestamp")
class TrainingLogsResponse(BaseModel):
"""Response for training logs."""
task_id: str = Field(..., description="Task UUID")
logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries")
class ExportRequest(BaseModel):
"""Request to export annotations."""
format: str = Field(
default="yolo", description="Export format (yolo, coco, voc)"
)
include_images: bool = Field(
default=True, description="Include images in export"
)
split_ratio: float = Field(
default=0.8, ge=0.5, le=1.0, description="Train/val split ratio"
)
class ExportResponse(BaseModel):
"""Response for export operation."""
status: str = Field(..., description="Export status")
export_path: str = Field(..., description="Path to exported dataset")
total_images: int = Field(..., ge=0, description="Total images exported")
total_annotations: int = Field(..., ge=0, description="Total annotations")
train_count: int = Field(..., ge=0, description="Training set count")
val_count: int = Field(..., ge=0, description="Validation set count")
message: str = Field(..., description="Status message")
class TrainingDocumentItem(BaseModel):
"""Document item for training page."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Filename")
annotation_count: int = Field(..., ge=0, description="Total annotations")
annotation_sources: dict[str, int] = Field(
..., description="Annotation counts by source (manual, auto)"
)
used_in_training: list[str] = Field(
default_factory=list, description="List of training task IDs that used this document"
)
last_modified: datetime = Field(..., description="Last modification time")
class TrainingDocumentsResponse(BaseModel):
"""Response for GET /admin/training/documents."""
total: int = Field(..., ge=0, description="Total document count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
documents: list[TrainingDocumentItem] = Field(
default_factory=list, description="Documents available for training"
)
class ModelMetrics(BaseModel):
"""Training model metrics."""
mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
class TrainingModelItem(BaseModel):
"""Trained model item for model list."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Model name")
status: TrainingStatus = Field(..., description="Training status")
document_count: int = Field(..., ge=0, description="Documents used in training")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")
metrics: ModelMetrics = Field(..., description="Model metrics")
model_path: str | None = Field(None, description="Path to model weights")
download_url: str | None = Field(None, description="Download URL for model")
class TrainingModelsResponse(BaseModel):
"""Response for GET /admin/training/models."""
total: int = Field(..., ge=0, description="Total model count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
models: list[TrainingModelItem] = Field(
default_factory=list, description="Trained models"
)
class TrainingHistoryItem(BaseModel):
"""Training history for a document."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Training task name")
trained_at: datetime = Field(..., description="Training timestamp")
model_metrics: ModelMetrics | None = Field(None, description="Model metrics")
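# Illustrative sketch (not part of the commit): creating a one-off fine-tuning
# task with a custom configuration; all values are placeholders.
_example_task = TrainingTaskCreate(
    name="yolo11n finetune",
    description="Fine-tune on newly labeled invoices",
    task_type=TrainingType.FINETUNE,
    config=TrainingConfig(epochs=50, batch_size=8, image_size=640, device="0"),
)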

View File

@@ -0,0 +1,15 @@
"""
Common Schemas
Shared Pydantic models used across multiple API modules.
"""
from pydantic import BaseModel, Field
class ErrorResponse(BaseModel):
"""Error response."""
status: str = Field(default="error", description="Error status")
message: str = Field(..., description="Error message")
detail: str | None = Field(None, description="Detailed error information")

View File

@@ -0,0 +1,196 @@
"""
API Request/Response Schemas
Pydantic models for API validation and serialization.
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
# =============================================================================
# Enums
# =============================================================================
class AsyncStatus(str, Enum):
"""Async request status enum."""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
# =============================================================================
# Sync API Schemas (existing)
# =============================================================================
class DetectionResult(BaseModel):
"""Single detection result."""
field: str = Field(..., description="Field type (e.g., invoice_number, amount)")
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
class ExtractedField(BaseModel):
"""Extracted and normalized field value."""
field_name: str = Field(..., description="Field name")
value: str | None = Field(None, description="Extracted value")
confidence: float = Field(..., ge=0, le=1, description="Extraction confidence")
is_valid: bool = Field(True, description="Whether the value passed validation")
class InferenceResult(BaseModel):
"""Complete inference result for a document."""
document_id: str = Field(..., description="Document identifier")
success: bool = Field(..., description="Whether inference succeeded")
document_type: str = Field(
default="invoice", description="Document type: 'invoice' or 'letter'"
)
fields: dict[str, str | None] = Field(
default_factory=dict, description="Extracted field values"
)
confidence: dict[str, float] = Field(
default_factory=dict, description="Confidence scores per field"
)
detections: list[DetectionResult] = Field(
default_factory=list, description="Raw YOLO detections"
)
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
visualization_url: str | None = Field(
None, description="URL to visualization image"
)
errors: list[str] = Field(default_factory=list, description="Error messages")
class InferenceResponse(BaseModel):
"""API response for inference endpoint."""
status: str = Field(..., description="Response status: success or error")
message: str = Field(..., description="Response message")
result: InferenceResult | None = Field(None, description="Inference result")
class BatchInferenceResponse(BaseModel):
"""API response for batch inference endpoint."""
status: str = Field(..., description="Response status")
message: str = Field(..., description="Response message")
total: int = Field(..., description="Total documents processed")
successful: int = Field(..., description="Number of successful extractions")
results: list[InferenceResult] = Field(
default_factory=list, description="Individual results"
)
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Service status")
model_loaded: bool = Field(..., description="Whether model is loaded")
gpu_available: bool = Field(..., description="Whether GPU is available")
version: str = Field(..., description="API version")
class ErrorResponse(BaseModel):
"""Error response."""
status: str = Field(default="error", description="Error status")
message: str = Field(..., description="Error message")
detail: str | None = Field(None, description="Detailed error information")
# =============================================================================
# Async API Schemas
# =============================================================================
class AsyncSubmitResponse(BaseModel):
"""Response for async submit endpoint."""
status: str = Field(default="accepted", description="Response status")
message: str = Field(..., description="Response message")
request_id: str = Field(..., description="Unique request identifier (UUID)")
estimated_wait_seconds: int = Field(
..., ge=0, description="Estimated wait time in seconds"
)
poll_url: str = Field(..., description="URL to poll for status updates")
class AsyncStatusResponse(BaseModel):
"""Response for async status endpoint."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Current processing status")
filename: str = Field(..., description="Original filename")
created_at: datetime = Field(..., description="Request creation timestamp")
started_at: datetime | None = Field(
None, description="Processing start timestamp"
)
completed_at: datetime | None = Field(
None, description="Processing completion timestamp"
)
position_in_queue: int | None = Field(
None, description="Position in queue (for pending status)"
)
error_message: str | None = Field(
None, description="Error message (for failed status)"
)
result_url: str | None = Field(
None, description="URL to fetch results (for completed status)"
)
class AsyncResultResponse(BaseModel):
"""Response for async result endpoint."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Processing status")
processing_time_ms: float = Field(
..., ge=0, description="Total processing time in milliseconds"
)
result: InferenceResult | None = Field(
None, description="Extraction result (when completed)"
)
visualization_url: str | None = Field(
None, description="URL to visualization image"
)
class AsyncRequestItem(BaseModel):
"""Single item in async requests list."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Current processing status")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
created_at: datetime = Field(..., description="Request creation timestamp")
completed_at: datetime | None = Field(
None, description="Processing completion timestamp"
)
class AsyncRequestsListResponse(BaseModel):
"""Response for async requests list endpoint."""
total: int = Field(..., ge=0, description="Total number of requests")
limit: int = Field(..., ge=1, description="Maximum items per page")
offset: int = Field(..., ge=0, description="Current offset")
requests: list[AsyncRequestItem] = Field(
default_factory=list, description="List of requests"
)
class RateLimitInfo(BaseModel):
"""Rate limit information (included in headers)."""
limit: int = Field(..., description="Maximum requests per minute")
remaining: int = Field(..., description="Remaining requests in current window")
reset_at: datetime = Field(..., description="Time when limit resets")
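# Illustrative sketch (not part of the commit): the shape of a submit response a
# client would receive; the request id and poll URL are placeholders, not the
# actual route layout.
_example_submit = AsyncSubmitResponse(
    message="Request accepted for processing",
    request_id="00000000-0000-0000-0000-000000000000",
    estimated_wait_seconds=10,
    poll_url="/async/requests/00000000-0000-0000-0000-000000000000/status",
)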

View File

@@ -0,0 +1,13 @@
"""
Labeling API Schemas
Pydantic models for pre-labeling and label validation endpoints.
"""
from pydantic import BaseModel, Field
class PreLabelResponse(BaseModel):
"""API response for pre-label endpoint."""
document_id: str = Field(..., description="Document identifier for retrieving results")

View File

@@ -0,0 +1,18 @@
"""
Business Logic Services
Service layer for processing requests and orchestrating data operations.
"""
from inference.web.services.autolabel import AutoLabelService, get_auto_label_service
from inference.web.services.inference import InferenceService
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.services.batch_upload import BatchUploadService
__all__ = [
"AutoLabelService",
"get_auto_label_service",
"InferenceService",
"AsyncProcessingService",
"BatchUploadService",
]

View File

@@ -0,0 +1,383 @@
"""
Async Processing Service
Manages async request lifecycle and background processing.
"""
import logging
import shutil
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from threading import Event, Thread
from typing import TYPE_CHECKING
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.core.rate_limiter import RateLimiter
if TYPE_CHECKING:
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.services.inference import InferenceService
logger = logging.getLogger(__name__)
@dataclass
class AsyncSubmitResult:
"""Result from async submit operation."""
success: bool
request_id: str | None = None
estimated_wait_seconds: int = 0
error: str | None = None
class AsyncProcessingService:
"""
Manages async request lifecycle and processing.
Coordinates between:
- HTTP endpoints (submit/status/result)
- Background task queue
- Database storage
- Rate limiting
"""
def __init__(
self,
inference_service: "InferenceService",
db: AsyncRequestDB,
queue: AsyncTaskQueue,
rate_limiter: RateLimiter,
async_config: "AsyncConfig",
storage_config: "StorageConfig",
) -> None:
self._inference = inference_service
self._db = db
self._queue = queue
self._rate_limiter = rate_limiter
self._async_config = async_config
self._storage_config = storage_config
# Cleanup thread
self._cleanup_stop_event = Event()
self._cleanup_thread: Thread | None = None
def start(self) -> None:
"""Start the async processing service."""
# Start the task queue with our handler
self._queue.start(self._process_task)
# Start cleanup thread
self._cleanup_stop_event.clear()
self._cleanup_thread = Thread(
target=self._cleanup_loop,
name="async-cleanup",
daemon=True,
)
self._cleanup_thread.start()
logger.info("AsyncProcessingService started")
def stop(self, timeout: float = 30.0) -> None:
"""Stop the async processing service."""
# Stop cleanup thread
self._cleanup_stop_event.set()
if self._cleanup_thread and self._cleanup_thread.is_alive():
self._cleanup_thread.join(timeout=5.0)
# Stop task queue
self._queue.stop(timeout=timeout)
logger.info("AsyncProcessingService stopped")
def submit_request(
self,
api_key: str,
file_content: bytes,
filename: str,
content_type: str,
) -> AsyncSubmitResult:
"""
Submit a new async processing request.
Args:
api_key: API key for the request
file_content: File content as bytes
filename: Original filename
content_type: File content type
Returns:
AsyncSubmitResult with request_id and status
"""
# Generate request ID
request_id = str(uuid.uuid4())
# Save file to temp storage
file_path = self._save_upload(request_id, filename, file_content)
file_size = len(file_content)
try:
# Calculate expiration
expires_at = datetime.utcnow() + timedelta(
days=self._async_config.result_retention_days
)
# Create database record
self._db.create_request(
api_key=api_key,
filename=filename,
file_size=file_size,
content_type=content_type,
expires_at=expires_at,
request_id=request_id,
)
# Record rate limit event
self._rate_limiter.record_request(api_key)
# Create and queue task
task = AsyncTask(
request_id=request_id,
api_key=api_key,
file_path=file_path,
filename=filename,
created_at=datetime.utcnow(),
)
if not self._queue.submit(task):
# Queue is full
self._db.update_status(
request_id,
"failed",
error_message="Processing queue is full",
)
# Cleanup file
file_path.unlink(missing_ok=True)
return AsyncSubmitResult(
success=False,
request_id=request_id,
error="Processing queue is full. Please try again later.",
)
# Estimate wait time
estimated_wait = self._estimate_wait()
return AsyncSubmitResult(
success=True,
request_id=request_id,
estimated_wait_seconds=estimated_wait,
)
except Exception as e:
logger.error(f"Failed to submit request: {e}", exc_info=True)
# Cleanup file on error
file_path.unlink(missing_ok=True)
return AsyncSubmitResult(
success=False,
# Return generic error message to avoid leaking implementation details
error="Failed to process request. Please try again later.",
)
# Allowed file extensions whitelist
ALLOWED_EXTENSIONS = frozenset({".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"})
def _save_upload(
self,
request_id: str,
filename: str,
content: bytes,
) -> Path:
"""Save uploaded file to temp storage."""
import re
# Extract extension from filename
ext = Path(filename).suffix.lower()
# Validate extension: must be alphanumeric only (e.g., .pdf, .png)
if not ext or not re.match(r'^\.[a-z0-9]+$', ext):
ext = ".pdf"
# Validate against whitelist
if ext not in self.ALLOWED_EXTENSIONS:
ext = ".pdf"
# Create async upload directory
upload_dir = self._async_config.temp_upload_dir
upload_dir.mkdir(parents=True, exist_ok=True)
# Build file path - request_id is a UUID so it's safe
file_path = upload_dir / f"{request_id}{ext}"
# Defense in depth: ensure path is within upload_dir
if not file_path.resolve().is_relative_to(upload_dir.resolve()):
raise ValueError("Invalid file path detected")
file_path.write_bytes(content)
return file_path
def _process_task(self, task: AsyncTask) -> None:
"""
Process a single task (called by worker thread).
This method is called by the AsyncTaskQueue worker threads.
"""
start_time = time.time()
try:
# Update status to processing
self._db.update_status(task.request_id, "processing")
# Ensure file exists
if not task.file_path.exists():
raise FileNotFoundError(f"Upload file not found: {task.file_path}")
# Run inference based on file type
file_ext = task.file_path.suffix.lower()
if file_ext == ".pdf":
result = self._inference.process_pdf(
task.file_path,
document_id=task.request_id[:8],
)
else:
result = self._inference.process_image(
task.file_path,
document_id=task.request_id[:8],
)
# Calculate processing time
processing_time_ms = (time.time() - start_time) * 1000
# Prepare result for storage
result_data = {
"document_id": result.document_id,
"success": result.success,
"document_type": result.document_type,
"fields": result.fields,
"confidence": result.confidence,
"detections": result.detections,
"errors": result.errors,
}
# Get visualization path as string
viz_path = None
if result.visualization_path:
viz_path = str(result.visualization_path.name)
# Store result in database
self._db.complete_request(
request_id=task.request_id,
document_id=result.document_id,
result=result_data,
processing_time_ms=processing_time_ms,
visualization_path=viz_path,
)
logger.info(
f"Task {task.request_id} completed successfully "
f"in {processing_time_ms:.0f}ms"
)
except Exception as e:
logger.error(
f"Task {task.request_id} failed: {e}",
exc_info=True,
)
self._db.update_status(
task.request_id,
"failed",
error_message=str(e),
increment_retry=True,
)
finally:
# Cleanup uploaded file
if task.file_path.exists():
task.file_path.unlink(missing_ok=True)
def _estimate_wait(self) -> int:
"""Estimate wait time based on queue depth."""
queue_depth = self._queue.get_queue_depth()
processing_count = self._queue.get_processing_count()
total_pending = queue_depth + processing_count
# Estimate ~5 seconds per document
avg_processing_time = 5
return total_pending * avg_processing_time
def _cleanup_loop(self) -> None:
"""Background cleanup loop."""
logger.info("Cleanup thread started")
cleanup_interval = self._async_config.cleanup_interval_hours * 3600
while not self._cleanup_stop_event.wait(timeout=cleanup_interval):
try:
self._run_cleanup()
except Exception as e:
logger.error(f"Cleanup failed: {e}", exc_info=True)
logger.info("Cleanup thread stopped")
def _run_cleanup(self) -> None:
"""Run cleanup operations."""
logger.info("Running cleanup...")
# Delete expired requests
deleted_requests = self._db.delete_expired_requests()
# Reset stale processing requests
reset_count = self._db.reset_stale_processing_requests(
stale_minutes=self._async_config.task_timeout_seconds // 60,
max_retries=3,
)
# Cleanup old rate limit events
deleted_events = self._db.cleanup_old_rate_limit_events(hours=1)
# Cleanup old poll timestamps
cleaned_polls = self._rate_limiter.cleanup_poll_timestamps()
# Cleanup rate limiter request windows
self._rate_limiter.cleanup_request_windows()
# Cleanup orphaned upload files
orphan_count = self._cleanup_orphan_files()
logger.info(
f"Cleanup complete: {deleted_requests} expired requests, "
f"{reset_count} stale requests reset, "
f"{deleted_events} rate limit events, "
f"{cleaned_polls} poll timestamps, "
f"{orphan_count} orphan files"
)
def _cleanup_orphan_files(self) -> int:
"""Clean up upload files that don't have matching requests."""
upload_dir = self._async_config.temp_upload_dir
if not upload_dir.exists():
return 0
count = 0
# Files older than 1 hour without matching request are considered orphans
cutoff = time.time() - 3600
for file_path in upload_dir.iterdir():
if not file_path.is_file():
continue
# Check if file is old enough
if file_path.stat().st_mtime > cutoff:
continue
# Extract request_id from filename
request_id = file_path.stem
# Check if request exists in database
request = self._db.get_request(request_id)
if request is None:
file_path.unlink(missing_ok=True)
count += 1
return count
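# Illustrative sketch (not part of the commit): the wait-time estimate used by
# _estimate_wait above, written as a standalone helper (~5 seconds per queued or
# in-flight document).
def estimate_wait_seconds(
    queue_depth: int, processing_count: int, avg_processing_time: int = 5
) -> int:
    return (queue_depth + processing_count) * avg_processing_time

# e.g. 3 queued + 1 processing -> 20 seconds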

View File

@@ -0,0 +1,335 @@
"""
Admin Auto-Labeling Service
Uses FieldMatcher to automatically create annotations from field values.
"""
import logging
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken
logger = logging.getLogger(__name__)
class AutoLabelService:
"""Service for automatic document labeling using field matching."""
def __init__(self, ocr_engine: OCREngine | None = None):
"""
Initialize auto-label service.
Args:
ocr_engine: OCR engine instance (creates one if not provided)
"""
self._ocr_engine = ocr_engine
self._field_matcher = FieldMatcher()
@property
def ocr_engine(self) -> OCREngine:
"""Lazy initialization of OCR engine."""
if self._ocr_engine is None:
self._ocr_engine = OCREngine(lang="en")
return self._ocr_engine
def auto_label_document(
self,
document_id: str,
file_path: str,
field_values: dict[str, str],
db: AdminDB,
replace_existing: bool = False,
skip_lock_check: bool = False,
) -> dict[str, Any]:
"""
Auto-label a document using field matching.
Args:
document_id: Document UUID
file_path: Path to document file
field_values: Dict of field_name -> value to match
db: Admin database instance
replace_existing: Whether to replace existing auto annotations
skip_lock_check: Skip annotation lock check (for batch processing)
Returns:
Dict with status and annotation count
"""
try:
# Get document info first
document = db.get_document(document_id)
if document is None:
raise ValueError(f"Document not found: {document_id}")
# Check annotation lock unless explicitly skipped
if not skip_lock_check:
from datetime import datetime, timezone
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
if document.annotation_lock_until > datetime.now(timezone.utc):
raise ValueError(
f"Document is locked for annotation until {document.annotation_lock_until}. "
"Auto-labeling skipped."
)
# Update status to running
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
# Delete existing auto annotations if requested
if replace_existing:
deleted = db.delete_annotations_for_document(
document_id=document_id,
source="auto",
)
logger.info(f"Deleted {deleted} existing auto annotations")
# Process document
path = Path(file_path)
annotations_created = 0
if path.suffix.lower() == ".pdf":
# Process PDF (all pages)
annotations_created = self._process_pdf(
document_id, path, field_values, db
)
else:
# Process single image
annotations_created = self._process_image(
document_id, path, field_values, db, page_number=1
)
# Update document status
status = "labeled" if annotations_created > 0 else "pending"
db.update_document_status(
document_id=document_id,
status=status,
auto_label_status="completed",
)
return {
"status": "completed",
"annotations_created": annotations_created,
}
except Exception as e:
logger.error(f"Auto-labeling failed for {document_id}: {e}")
db.update_document_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error=str(e),
)
return {
"status": "failed",
"error": str(e),
"annotations_created": 0,
}
def _process_pdf(
self,
document_id: str,
pdf_path: Path,
field_values: dict[str, str],
db: AdminDB,
) -> int:
"""Process PDF document and create annotations."""
from shared.pdf.renderer import render_pdf_to_images
import io
total_annotations = 0
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
# Convert to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Extract tokens
tokens = self.ocr_engine.extract_from_image(
image_array,
page_no=page_no,
)
# Find matches
annotations = self._find_annotations(
document_id,
tokens,
field_values,
page_number=page_no + 1, # 1-indexed
image_width=image_array.shape[1],
image_height=image_array.shape[0],
)
# Save annotations
if annotations:
db.create_annotations_batch(annotations)
total_annotations += len(annotations)
return total_annotations
def _process_image(
self,
document_id: str,
image_path: Path,
field_values: dict[str, str],
db: AdminDB,
page_number: int = 1,
) -> int:
"""Process single image and create annotations."""
# Load image
image = Image.open(image_path)
image_array = np.array(image)
# Extract tokens
tokens = self.ocr_engine.extract_from_image(
image_array,
page_no=0,
)
# Find matches
annotations = self._find_annotations(
document_id,
tokens,
field_values,
page_number=page_number,
image_width=image_array.shape[1],
image_height=image_array.shape[0],
)
# Save annotations
if annotations:
db.create_annotations_batch(annotations)
return len(annotations)
def _find_annotations(
self,
document_id: str,
tokens: list[OCRToken],
field_values: dict[str, str],
page_number: int,
image_width: int,
image_height: int,
) -> list[dict[str, Any]]:
"""Find annotations for field values using token matching."""
from shared.normalize import normalize_field
annotations = []
for field_name, value in field_values.items():
if not value or not value.strip():
continue
# Map field name to class ID
class_id = self._get_class_id(field_name)
if class_id is None:
logger.warning(f"Unknown field name: {field_name}")
continue
class_name = FIELD_CLASSES[class_id]
# Normalize value
try:
normalized_values = normalize_field(field_name, value)
except Exception as e:
logger.warning(f"Failed to normalize {field_name}={value}: {e}")
normalized_values = [value]
# Find matches
matches = self._field_matcher.find_matches(
tokens=tokens,
field_name=field_name,
normalized_values=normalized_values,
page_no=page_number - 1, # 0-indexed for matcher
)
# Take best match
if matches:
best_match = matches[0]
bbox = best_match.bbox # (x0, y0, x1, y1)
# Calculate normalized coordinates (YOLO format)
x_center = (bbox[0] + bbox[2]) / 2 / image_width
y_center = (bbox[1] + bbox[3]) / 2 / image_height
width = (bbox[2] - bbox[0]) / image_width
height = (bbox[3] - bbox[1]) / image_height
# Pixel coordinates
bbox_x = int(bbox[0])
bbox_y = int(bbox[1])
bbox_width = int(bbox[2] - bbox[0])
bbox_height = int(bbox[3] - bbox[1])
annotations.append({
"document_id": document_id,
"page_number": page_number,
"class_id": class_id,
"class_name": class_name,
"x_center": x_center,
"y_center": y_center,
"width": width,
"height": height,
"bbox_x": bbox_x,
"bbox_y": bbox_y,
"bbox_width": bbox_width,
"bbox_height": bbox_height,
"text_value": best_match.matched_value,
"confidence": best_match.score,
"source": "auto",
})
return annotations
def _get_class_id(self, field_name: str) -> int | None:
"""Map field name to class ID."""
# Direct match
if field_name in FIELD_CLASS_IDS:
return FIELD_CLASS_IDS[field_name]
# Handle alternative names
name_mapping = {
"InvoiceNumber": "invoice_number",
"InvoiceDate": "invoice_date",
"InvoiceDueDate": "invoice_due_date",
"OCR": "ocr_number",
"Bankgiro": "bankgiro",
"Plusgiro": "plusgiro",
"Amount": "amount",
"supplier_organisation_number": "supplier_organisation_number",
"PaymentLine": "payment_line",
"customer_number": "customer_number",
}
mapped_name = name_mapping.get(field_name)
if mapped_name and mapped_name in FIELD_CLASS_IDS:
return FIELD_CLASS_IDS[mapped_name]
return None
# Global service instance
_auto_label_service: AutoLabelService | None = None
def get_auto_label_service() -> AutoLabelService:
"""Get the auto-label service instance."""
global _auto_label_service
if _auto_label_service is None:
_auto_label_service = AutoLabelService()
return _auto_label_service
def reset_auto_label_service() -> None:
"""Reset the auto-label service (for testing)."""
global _auto_label_service
_auto_label_service = None
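# Illustrative sketch (not part of the commit): the pixel-to-YOLO conversion
# performed in _find_annotations, shown as a standalone helper.
def to_yolo_bbox(
    x0: float, y0: float, x1: float, y1: float, image_width: int, image_height: int
) -> tuple[float, float, float, float]:
    x_center = (x0 + x1) / 2 / image_width
    y_center = (y0 + y1) / 2 / image_height
    width = (x1 - x0) / image_width
    height = (y1 - y0) / image_height
    return x_center, y_center, width, height

# e.g. a 200x40 px box at (100, 300) on a 1000x1400 px page:
# to_yolo_bbox(100, 300, 300, 340, 1000, 1400) -> (0.2, ~0.229, 0.2, ~0.029)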

View File

@@ -0,0 +1,548 @@
"""
Batch Upload Service
Handles ZIP file uploads containing multiple PDFs and an optional CSV for auto-labeling.
"""
import csv
import io
import logging
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from inference.data.admin_db import AdminDB
from inference.data.admin_models import CSV_TO_CLASS_MAPPING
logger = logging.getLogger(__name__)
# Security limits
MAX_COMPRESSED_SIZE = 100 * 1024 * 1024 # 100 MB
MAX_UNCOMPRESSED_SIZE = 200 * 1024 * 1024 # 200 MB
MAX_INDIVIDUAL_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
MAX_FILES_IN_ZIP = 1000
class CSVRowData(BaseModel):
"""Validated CSV row data with security checks."""
document_id: str = Field(..., min_length=1, max_length=255, pattern=r'^[a-zA-Z0-9\-_\.]+$')
invoice_number: str | None = Field(None, max_length=255)
invoice_date: str | None = Field(None, max_length=50)
invoice_due_date: str | None = Field(None, max_length=50)
amount: str | None = Field(None, max_length=100)
ocr: str | None = Field(None, max_length=100)
bankgiro: str | None = Field(None, max_length=50)
plusgiro: str | None = Field(None, max_length=50)
customer_number: str | None = Field(None, max_length=255)
supplier_organisation_number: str | None = Field(None, max_length=50)
@field_validator('*', mode='before')
@classmethod
def strip_whitespace(cls, v):
"""Strip whitespace from all string fields."""
if isinstance(v, str):
return v.strip()
return v
@field_validator('*', mode='before')
@classmethod
def reject_suspicious_patterns(cls, v):
"""Reject values with suspicious characters."""
if isinstance(v, str):
# Reject SQL/shell metacharacters and newlines
dangerous_chars = [';', '|', '&', '`', '$', '\n', '\r', '\x00']
if any(char in v for char in dangerous_chars):
                raise ValueError("Suspicious characters detected in value")
return v
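# Illustrative sketch (assumed values, kept as comments so nothing runs at
# import time): the validators above strip whitespace and reject shell/SQL
# metacharacters, so
#   CSVRowData(document_id="inv-001", amount=" 1 234,56 ")
# yields amount == "1 234,56", while
#   CSVRowData(document_id="inv-001; rm -rf /")
# raises a pydantic ValidationError (both the ';' metacharacter check and the
# document_id pattern reject it).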
class BatchUploadService:
"""Service for handling batch uploads of documents via ZIP files."""
def __init__(self, admin_db: AdminDB):
"""Initialize the batch upload service.
Args:
admin_db: Admin database interface
"""
self.admin_db = admin_db
def _safe_extract_filename(self, zip_path: str) -> str:
"""Safely extract filename from ZIP path, preventing path traversal.
Args:
zip_path: Path from ZIP file entry
Returns:
Safe filename
Raises:
ValueError: If path contains traversal attempts or is invalid
"""
# Reject absolute paths
if zip_path.startswith('/') or zip_path.startswith('\\'):
raise ValueError(f"Absolute path rejected: {zip_path}")
# Reject path traversal attempts
if '..' in zip_path:
raise ValueError(f"Path traversal rejected: {zip_path}")
# Reject Windows drive letters
if len(zip_path) >= 2 and zip_path[1] == ':':
raise ValueError(f"Windows path rejected: {zip_path}")
# Get only the basename
safe_name = Path(zip_path).name
if not safe_name or safe_name in ['.', '..']:
raise ValueError(f"Invalid filename: {zip_path}")
# Validate filename doesn't contain suspicious characters
if any(char in safe_name for char in ['\\', '/', '\x00', '\n', '\r']):
raise ValueError(f"Invalid characters in filename: {safe_name}")
return safe_name
def _validate_zip_safety(self, zip_file: zipfile.ZipFile) -> None:
"""Validate ZIP file against Zip bomb and other attacks.
Args:
zip_file: Opened ZIP file
Raises:
ValueError: If ZIP file is unsafe
"""
total_uncompressed = 0
file_count = 0
for zip_info in zip_file.infolist():
file_count += 1
# Check file count limit
if file_count > MAX_FILES_IN_ZIP:
raise ValueError(
f"ZIP contains too many files (max {MAX_FILES_IN_ZIP})"
)
# Check individual file size
if zip_info.file_size > MAX_INDIVIDUAL_FILE_SIZE:
max_mb = MAX_INDIVIDUAL_FILE_SIZE / (1024 * 1024)
raise ValueError(
f"File '{zip_info.filename}' exceeds {max_mb:.0f}MB limit"
)
# Accumulate uncompressed size
total_uncompressed += zip_info.file_size
# Check total uncompressed size (Zip bomb protection)
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
max_mb = MAX_UNCOMPRESSED_SIZE / (1024 * 1024)
raise ValueError(
f"Total uncompressed size exceeds {max_mb:.0f}MB limit"
)
# Validate filename safety
try:
self._safe_extract_filename(zip_info.filename)
except ValueError as e:
logger.warning(f"Rejecting malicious ZIP entry: {e}")
raise ValueError(f"Invalid file in ZIP: {zip_info.filename}")
def process_zip_upload(
self,
admin_token: str,
zip_filename: str,
zip_content: bytes,
upload_source: str = "ui",
) -> dict[str, Any]:
"""Process a ZIP file containing PDFs and optional CSV.
Args:
admin_token: Admin authentication token
zip_filename: Name of the ZIP file
zip_content: ZIP file content as bytes
upload_source: Upload source (ui or api)
Returns:
Dictionary with batch upload results
"""
batch = self.admin_db.create_batch_upload(
admin_token=admin_token,
filename=zip_filename,
file_size=len(zip_content),
upload_source=upload_source,
)
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
# Validate ZIP safety first
self._validate_zip_safety(zip_file)
result = self._process_zip_contents(
batch_id=batch.batch_id,
admin_token=admin_token,
zip_file=zip_file,
)
# Update batch upload status
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status=result["status"],
total_files=result["total_files"],
processed_files=result["processed_files"],
successful_files=result["successful_files"],
failed_files=result["failed_files"],
csv_filename=result.get("csv_filename"),
csv_row_count=result.get("csv_row_count"),
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
**result,
}
except zipfile.BadZipFile as e:
logger.error(f"Invalid ZIP file {zip_filename}: {e}")
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status="failed",
error_message="Invalid ZIP file format",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "Invalid ZIP file format",
}
except ValueError as e:
# Security validation errors
logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status="failed",
error_message="ZIP file validation failed",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "ZIP file validation failed",
}
except Exception as e:
logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status="failed",
error_message="Processing error",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "Failed to process batch upload",
}
def _process_zip_contents(
self,
batch_id: UUID,
admin_token: str,
zip_file: zipfile.ZipFile,
) -> dict[str, Any]:
"""Process contents of ZIP file.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
zip_file: Opened ZIP file
Returns:
Processing results dictionary
"""
# Extract file lists
pdf_files = []
csv_file = None
csv_data = {}
for file_info in zip_file.filelist:
if file_info.is_dir():
continue
try:
# Use safe filename extraction
filename = self._safe_extract_filename(file_info.filename)
except ValueError as e:
logger.warning(f"Skipping invalid file: {e}")
continue
if filename.lower().endswith('.pdf'):
pdf_files.append(file_info)
elif filename.lower().endswith('.csv'):
if csv_file is None:
csv_file = file_info
# Parse CSV
csv_data = self._parse_csv_file(zip_file, file_info)
else:
logger.warning(f"Multiple CSV files found, using first: {csv_file.filename}")
if not pdf_files:
return {
"status": "failed",
"total_files": 0,
"processed_files": 0,
"successful_files": 0,
"failed_files": 0,
"error": "No PDF files found in ZIP",
}
# Process each PDF file
total_files = len(pdf_files)
successful_files = 0
failed_files = 0
for pdf_info in pdf_files:
file_record = None
try:
# Use safe filename extraction
filename = self._safe_extract_filename(pdf_info.filename)
# Create batch upload file record
file_record = self.admin_db.create_batch_upload_file(
batch_id=batch_id,
filename=filename,
status="processing",
)
# Get CSV data for this file if available
document_id_base = Path(filename).stem
csv_row_data = csv_data.get(document_id_base)
# Extract PDF content
pdf_content = zip_file.read(pdf_info.filename)
# TODO: Save PDF file and create document
# For now, just mark as completed
self.admin_db.update_batch_upload_file(
file_id=file_record.file_id,
status="completed",
csv_row_data=csv_row_data,
processed_at=datetime.utcnow(),
)
successful_files += 1
except ValueError as e:
# Path validation error
logger.warning(f"Skipping invalid file: {e}")
if file_record:
self.admin_db.update_batch_upload_file(
file_id=file_record.file_id,
status="failed",
error_message="Invalid filename",
processed_at=datetime.utcnow(),
)
failed_files += 1
except Exception as e:
logger.error(f"Error processing PDF: {e}", exc_info=True)
if file_record:
self.admin_db.update_batch_upload_file(
file_id=file_record.file_id,
status="failed",
error_message="Processing error",
processed_at=datetime.utcnow(),
)
failed_files += 1
# Determine overall status
if failed_files == 0:
status = "completed"
elif successful_files == 0:
status = "failed"
else:
status = "partial"
result = {
"status": status,
"total_files": total_files,
"processed_files": total_files,
"successful_files": successful_files,
"failed_files": failed_files,
}
if csv_file:
result["csv_filename"] = Path(csv_file.filename).name
result["csv_row_count"] = len(csv_data)
return result
def _parse_csv_file(
self,
zip_file: zipfile.ZipFile,
csv_file_info: zipfile.ZipInfo,
) -> dict[str, dict[str, Any]]:
"""Parse CSV file and extract field values with validation.
Args:
zip_file: Opened ZIP file
csv_file_info: CSV file info
Returns:
Dictionary mapping DocumentId to validated field values
"""
# Try multiple encodings
csv_bytes = zip_file.read(csv_file_info.filename)
encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'cp1252']
csv_content = None
for encoding in encodings:
try:
csv_content = csv_bytes.decode(encoding)
logger.info(f"CSV decoded with {encoding}")
break
except UnicodeDecodeError:
continue
if csv_content is None:
logger.error("Failed to decode CSV with any encoding")
raise ValueError("Unable to decode CSV file")
csv_reader = csv.DictReader(io.StringIO(csv_content))
result = {}
# Case-insensitive column mapping
field_name_map = {
'DocumentId': ['DocumentId', 'documentid', 'document_id'],
'InvoiceNumber': ['InvoiceNumber', 'invoicenumber', 'invoice_number'],
'InvoiceDate': ['InvoiceDate', 'invoicedate', 'invoice_date'],
'InvoiceDueDate': ['InvoiceDueDate', 'invoiceduedate', 'invoice_due_date'],
'Amount': ['Amount', 'amount'],
'OCR': ['OCR', 'ocr'],
'Bankgiro': ['Bankgiro', 'bankgiro'],
'Plusgiro': ['Plusgiro', 'plusgiro'],
'customer_number': ['customer_number', 'customernumber', 'CustomerNumber'],
'supplier_organisation_number': ['supplier_organisation_number', 'supplierorganisationnumber'],
}
for row_num, row in enumerate(csv_reader, start=2):
try:
# Create case-insensitive lookup
row_lower = {k.lower(): v for k, v in row.items()}
# Find DocumentId with case-insensitive matching
document_id = None
for variant in field_name_map['DocumentId']:
if variant.lower() in row_lower:
document_id = row_lower[variant.lower()]
break
if not document_id:
logger.warning(f"Row {row_num}: No DocumentId found")
continue
# Validate using Pydantic model
csv_row_dict = {'document_id': document_id}
# Map CSV field names to model attribute names
csv_to_model_attr = {
'InvoiceNumber': 'invoice_number',
'InvoiceDate': 'invoice_date',
'InvoiceDueDate': 'invoice_due_date',
'Amount': 'amount',
'OCR': 'ocr',
'Bankgiro': 'bankgiro',
'Plusgiro': 'plusgiro',
'customer_number': 'customer_number',
'supplier_organisation_number': 'supplier_organisation_number',
}
for csv_field in field_name_map.keys():
if csv_field == 'DocumentId':
continue
model_attr = csv_to_model_attr.get(csv_field)
if not model_attr:
continue
for variant in field_name_map[csv_field]:
if variant.lower() in row_lower and row_lower[variant.lower()]:
csv_row_dict[model_attr] = row_lower[variant.lower()]
break
# Validate
validated_row = CSVRowData(**csv_row_dict)
# Extract only the fields we care about (map back to CSV field names)
field_values = {}
model_attr_to_csv = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'amount': 'Amount',
'ocr': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'customer_number': 'customer_number',
'supplier_organisation_number': 'supplier_organisation_number',
}
for model_attr, csv_field in model_attr_to_csv.items():
value = getattr(validated_row, model_attr, None)
if value and csv_field in CSV_TO_CLASS_MAPPING:
field_values[csv_field] = value
if field_values:
result[document_id] = field_values
except Exception as e:
logger.warning(f"Row {row_num}: Validation error - {e}")
continue
return result
def get_batch_status(self, batch_id: str) -> dict[str, Any]:
"""Get batch upload status.
Args:
batch_id: Batch upload ID
Returns:
Batch status dictionary
"""
batch = self.admin_db.get_batch_upload(UUID(batch_id))
if not batch:
return {
"error": "Batch upload not found",
}
files = self.admin_db.get_batch_upload_files(batch.batch_id)
return {
"batch_id": str(batch.batch_id),
"filename": batch.filename,
"status": batch.status,
"total_files": batch.total_files,
"processed_files": batch.processed_files,
"successful_files": batch.successful_files,
"failed_files": batch.failed_files,
"csv_filename": batch.csv_filename,
"csv_row_count": batch.csv_row_count,
"error_message": batch.error_message,
"created_at": batch.created_at.isoformat() if batch.created_at else None,
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
"files": [
{
"filename": f.filename,
"status": f.status,
"error_message": f.error_message,
"annotation_count": f.annotation_count,
}
for f in files
],
}
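# Illustrative usage sketch. AdminDB() without arguments and the ZIP/token
# values below are assumptions; in the application the web layer supplies them.
if __name__ == "__main__":
    service = BatchUploadService(AdminDB())
    zip_bytes = Path("invoices_batch.zip").read_bytes()
    summary = service.process_zip_upload(
        admin_token="<admin-token>",
        zip_filename="invoices_batch.zip",
        zip_content=zip_bytes,
        upload_source="api",
    )
    print(summary["status"], summary.get("successful_files"))
    print(service.get_batch_status(summary["batch_id"]))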

View File

@@ -0,0 +1,188 @@
"""
Dataset Builder Service
Creates training datasets by copying images from admin storage,
generating YOLO label files, and splitting into train/val/test sets.
"""
import logging
import random
import shutil
from pathlib import Path
import yaml
from inference.data.admin_models import FIELD_CLASSES
logger = logging.getLogger(__name__)
class DatasetBuilder:
"""Builds YOLO training datasets from admin documents."""
def __init__(self, db, base_dir: Path):
self._db = db
self._base_dir = Path(base_dir)
def build_dataset(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
"""Build a complete YOLO dataset from document IDs.
Args:
dataset_id: UUID of the dataset record.
document_ids: List of document UUIDs to include.
train_ratio: Fraction for training set.
val_ratio: Fraction for validation set.
seed: Random seed for reproducible splits.
admin_images_dir: Root directory of admin images.
Returns:
Summary dict with total_documents, total_images, total_annotations.
Raises:
ValueError: If no valid documents found.
"""
try:
return self._do_build(
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
)
except Exception as e:
self._db.update_dataset_status(
dataset_id=dataset_id,
status="failed",
error_message=str(e),
)
raise
def _do_build(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
# 1. Fetch documents
documents = self._db.get_documents_by_ids(document_ids)
if not documents:
raise ValueError("No valid documents found for the given IDs")
# 2. Create directory structure
dataset_dir = self._base_dir / dataset_id
for split in ["train", "val", "test"]:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
# 3. Shuffle and split documents
doc_list = list(documents)
rng = random.Random(seed)
rng.shuffle(doc_list)
n = len(doc_list)
n_train = max(1, round(n * train_ratio))
n_val = max(0, round(n * val_ratio))
n_test = n - n_train - n_val
splits = (
["train"] * n_train
+ ["val"] * n_val
+ ["test"] * n_test
)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
for doc, split in zip(doc_list, splits):
doc_id = str(doc.document_id)
annotations = self._db.get_annotations_for_document(doc.document_id)
# Group annotations by page
page_annotations: dict[int, list] = {}
for ann in annotations:
page_annotations.setdefault(ann.page_number, []).append(ann)
doc_image_count = 0
doc_ann_count = 0
# Copy images and write labels for each page
for page_num in range(1, doc.page_count + 1):
src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
if not src_image.exists():
logger.warning("Image not found: %s", src_image)
continue
dst_name = f"{doc_id}_page{page_num}"
dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
shutil.copy2(src_image, dst_image)
doc_image_count += 1
# Write YOLO label file
page_anns = page_annotations.get(page_num, [])
label_lines = []
for ann in page_anns:
label_lines.append(
f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
f"{ann.width:.6f} {ann.height:.6f}"
)
doc_ann_count += 1
label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
label_path.write_text("\n".join(label_lines))
total_images += doc_image_count
total_annotations += doc_ann_count
dataset_docs.append({
"document_id": doc_id,
"split": split,
"page_count": doc_image_count,
"annotation_count": doc_ann_count,
})
# 5. Record document-split assignments in DB
self._db.add_dataset_documents(
dataset_id=dataset_id,
documents=dataset_docs,
)
# 6. Generate data.yaml
self._generate_data_yaml(dataset_dir)
# 7. Update dataset status
self._db.update_dataset_status(
dataset_id=dataset_id,
status="ready",
total_documents=len(doc_list),
total_images=total_images,
total_annotations=total_annotations,
dataset_path=str(dataset_dir),
)
return {
"total_documents": len(doc_list),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {
"path": str(dataset_dir.absolute()),
"train": "images/train",
"val": "images/val",
"test": "images/test",
"nc": len(FIELD_CLASSES),
"names": FIELD_CLASSES,
}
yaml_path = dataset_dir / "data.yaml"
yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
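# Illustrative usage sketch. The db handle only needs the helpers used above
# (get_documents_by_ids, get_annotations_for_document, add_dataset_documents,
# update_dataset_status); AdminDB and the paths/IDs below are assumptions.
if __name__ == "__main__":
    from inference.data.admin_db import AdminDB
    builder = DatasetBuilder(db=AdminDB(), base_dir=Path("data/datasets"))
    summary = builder.build_dataset(
        dataset_id="00000000-0000-0000-0000-000000000001",
        document_ids=["<doc-uuid-1>", "<doc-uuid-2>"],
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
        admin_images_dir=Path("data/admin_images"),
    )
    print(summary)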

View File

@@ -0,0 +1,531 @@
"""
Database-based Auto-labeling Service
Processes documents with field values stored in the database (csv_field_values).
Used by the pre-label API to create annotations from expected values.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
from shared.data.db import DocumentDB
from inference.web.config import StorageConfig
logger = logging.getLogger(__name__)
# Initialize DocumentDB for saving match reports
_document_db: DocumentDB | None = None
def get_document_db() -> DocumentDB:
"""Get or create DocumentDB instance with connection and tables initialized.
Follows the same pattern as CLI autolabel (src/cli/autolabel.py lines 370-373).
"""
global _document_db
if _document_db is None:
_document_db = DocumentDB()
_document_db.connect()
_document_db.create_tables() # Ensure tables exist
logger.info("Connected to PostgreSQL DocumentDB for match reports")
return _document_db
def convert_csv_field_values_to_row_dict(
document: AdminDocument,
) -> dict[str, Any]:
"""
Convert AdminDocument.csv_field_values to row_dict format for autolabel.
Args:
document: AdminDocument with csv_field_values
Returns:
Dictionary in row_dict format compatible with autolabel_tasks
"""
csv_values = document.csv_field_values or {}
# Build row_dict with DocumentId
row_dict = {
"DocumentId": str(document.document_id),
}
# Map csv_field_values to row_dict format
# csv_field_values uses keys like: InvoiceNumber, InvoiceDate, Amount, OCR, Bankgiro, etc.
# row_dict uses same keys
for key, value in csv_values.items():
if value is not None and value != "":
row_dict[key] = str(value)
return row_dict
def get_pending_autolabel_documents(
db: AdminDB,
limit: int = 10,
) -> list[AdminDocument]:
"""
Get documents pending auto-labeling.
Args:
db: AdminDB instance
limit: Maximum number of documents to return
Returns:
List of AdminDocument records with status='auto_labeling' and auto_label_status='pending'
"""
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument
with get_session_context() as session:
statement = select(AdminDocument).where(
AdminDocument.status == "auto_labeling",
AdminDocument.auto_label_status == "pending",
).order_by(AdminDocument.created_at).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def process_document_autolabel(
document: AdminDocument,
db: AdminDB,
output_dir: Path | None = None,
dpi: int = DEFAULT_DPI,
min_confidence: float = 0.5,
) -> dict[str, Any]:
"""
Process a single document for auto-labeling using csv_field_values.
Args:
document: AdminDocument with csv_field_values and file_path
db: AdminDB instance for updating status
output_dir: Output directory for temp files
dpi: Rendering DPI
min_confidence: Minimum match confidence
Returns:
Result dictionary with success status and annotations
"""
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
from shared.pdf import PDFDocument
document_id = str(document.document_id)
file_path = Path(document.file_path)
if output_dir is None:
output_dir = Path("data/autolabel_output")
output_dir.mkdir(parents=True, exist_ok=True)
# Mark as processing
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
try:
# Check if file exists
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# Convert csv_field_values to row_dict
row_dict = convert_csv_field_values_to_row_dict(document)
if len(row_dict) <= 1: # Only has DocumentId
raise ValueError("No field values to match")
# Determine PDF type (text or scanned)
is_scanned = False
with PDFDocument(file_path) as pdf_doc:
# Check if first page has extractable text
tokens = list(pdf_doc.extract_text_tokens(0))
is_scanned = len(tokens) < 10 # Threshold for "no text"
# Build task data
# Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
# This ensures consistency with CLI autolabel for reprocess_failed.py
storage_config = StorageConfig()
pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"
task_data = {
"row_dict": row_dict,
"pdf_path": str(pdf_path_for_report),
"output_dir": str(output_dir),
"dpi": dpi,
"min_confidence": min_confidence,
}
# Process based on PDF type
if is_scanned:
result = process_scanned_pdf(task_data)
else:
result = process_text_pdf(task_data)
# Save report to DocumentDB (same as CLI autolabel)
if result.get("report"):
try:
doc_db = get_document_db()
doc_db.save_document(result["report"])
logger.info(f"Saved match report to DocumentDB for {document_id}")
except Exception as e:
logger.warning(f"Failed to save report to DocumentDB: {e}")
# Save annotations to AdminDB
if result.get("success") and result.get("report"):
_save_annotations_to_db(
db=db,
document_id=document_id,
report=result["report"],
page_annotations=result.get("pages", []),
dpi=dpi,
)
# Mark as completed
db.update_document_status(
document_id=document_id,
status="labeled",
auto_label_status="completed",
)
else:
# Mark as failed
errors = result.get("report", {}).get("errors", ["Unknown error"])
db.update_document_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error="; ".join(errors) if errors else "No annotations generated",
)
return result
except Exception as e:
logger.error(f"Error processing document {document_id}: {e}", exc_info=True)
# Mark as failed
db.update_document_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error=str(e),
)
return {
"doc_id": document_id,
"success": False,
"error": str(e),
}
def _save_annotations_to_db(
db: AdminDB,
document_id: str,
report: dict[str, Any],
page_annotations: list[dict[str, Any]],
dpi: int = 200,
) -> int:
"""
Save generated annotations to database.
Args:
db: AdminDB instance
document_id: Document ID
report: AutoLabelReport as dict
page_annotations: List of page annotation data
dpi: DPI used for rendering images (for coordinate conversion)
Returns:
Number of annotations saved
"""
from PIL import Image
from inference.data.admin_models import FIELD_CLASS_IDS
# Mapping from CSV field names to internal field names
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
"InvoiceNumber": "invoice_number",
"InvoiceDate": "invoice_date",
"InvoiceDueDate": "invoice_due_date",
"OCR": "ocr_number",
"Bankgiro": "bankgiro",
"Plusgiro": "plusgiro",
"Amount": "amount",
"supplier_organisation_number": "supplier_organisation_number",
"customer_number": "customer_number",
"payment_line": "payment_line",
}
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
scale = dpi / 72.0
# Cache for image dimensions per page
image_dimensions: dict[int, tuple[int, int]] = {}
def get_image_dimensions(page_no: int) -> tuple[int, int] | None:
"""Get image dimensions for a page (1-indexed)."""
if page_no in image_dimensions:
return image_dimensions[page_no]
# Try to load from admin_images
admin_images_dir = Path("data/admin_images") / document_id
image_path = admin_images_dir / f"page_{page_no}.png"
if image_path.exists():
try:
with Image.open(image_path) as img:
dims = img.size # (width, height)
image_dimensions[page_no] = dims
return dims
except Exception as e:
logger.warning(f"Failed to read image dimensions from {image_path}: {e}")
return None
annotation_count = 0
# Get field results from report (list of dicts)
field_results = report.get("field_results", [])
for field_info in field_results:
if not field_info.get("matched"):
continue
csv_field_name = field_info.get("field_name", "")
# Map CSV field name to internal field name
field_name = CSV_TO_INTERNAL_FIELD.get(csv_field_name, csv_field_name)
# Get class_id from field name
class_id = FIELD_CLASS_IDS.get(field_name)
if class_id is None:
logger.warning(f"Unknown field name: {csv_field_name} -> {field_name}")
continue
# Get bbox info (list: [x, y, x2, y2] in PDF points - 72 DPI)
bbox = field_info.get("bbox", [])
if not bbox or len(bbox) < 4:
continue
# Convert PDF points (72 DPI) to pixel coordinates (at configured DPI)
pdf_x1, pdf_y1, pdf_x2, pdf_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
x1 = pdf_x1 * scale
y1 = pdf_y1 * scale
x2 = pdf_x2 * scale
y2 = pdf_y2 * scale
bbox_width = x2 - x1
bbox_height = y2 - y1
# Get page number (convert to 1-indexed)
page_no = field_info.get("page_no", 0) + 1
# Get image dimensions for normalization
dims = get_image_dimensions(page_no)
if dims:
img_width, img_height = dims
# Calculate normalized coordinates
x_center = (x1 + x2) / 2 / img_width
y_center = (y1 + y2) / 2 / img_height
width = bbox_width / img_width
height = bbox_height / img_height
else:
# Fallback: use pixel coordinates as-is for normalization
# (will be slightly off but better than /1000)
logger.warning(f"Could not get image dimensions for page {page_no}, using estimates")
# Estimate A4 at configured DPI: 595 x 842 points * scale
estimated_width = 595 * scale
estimated_height = 842 * scale
x_center = (x1 + x2) / 2 / estimated_width
y_center = (y1 + y2) / 2 / estimated_height
width = bbox_width / estimated_width
height = bbox_height / estimated_height
# Create annotation
try:
db.create_annotation(
document_id=document_id,
page_number=page_no,
class_id=class_id,
class_name=field_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=int(x1),
bbox_y=int(y1),
bbox_width=int(bbox_width),
bbox_height=int(bbox_height),
text_value=field_info.get("matched_text"),
confidence=field_info.get("score"),
source="auto",
)
annotation_count += 1
logger.info(f"Saved annotation for {field_name}: bbox=({int(x1)}, {int(y1)}, {int(bbox_width)}, {int(bbox_height)})")
except Exception as e:
logger.warning(f"Failed to save annotation for {field_name}: {e}")
return annotation_count
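# Worked example of the scaling above (assumed numbers): with dpi=200 the
# scale factor is 200 / 72 ~= 2.778, so a PDF-point bbox [72, 144, 216, 180]
# lands at pixels (200, 400, 600, 500). When the rendered page image cannot
# be read, the A4 fallback estimates the page at 595 * 2.778 ~= 1653 px wide
# and 842 * 2.778 ~= 2339 px tall before normalizing.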
def run_pending_autolabel_batch(
db: AdminDB | None = None,
batch_size: int = 10,
output_dir: Path | None = None,
) -> dict[str, Any]:
"""
Process a batch of pending auto-label documents.
Args:
db: AdminDB instance (created if None)
batch_size: Number of documents to process
output_dir: Output directory for temp files
Returns:
Summary of processing results
"""
if db is None:
db = AdminDB()
documents = get_pending_autolabel_documents(db, limit=batch_size)
results = {
"total": len(documents),
"successful": 0,
"failed": 0,
"documents": [],
}
for doc in documents:
result = process_document_autolabel(
document=doc,
db=db,
output_dir=output_dir,
)
doc_result = {
"document_id": str(doc.document_id),
"success": result.get("success", False),
}
if result.get("success"):
results["successful"] += 1
else:
results["failed"] += 1
doc_result["error"] = result.get("error") or "Unknown error"
results["documents"].append(doc_result)
return results
def save_manual_annotations_to_document_db(
document: AdminDocument,
annotations: list,
db: AdminDB,
) -> dict[str, Any]:
"""
Save manual annotations to PostgreSQL documents and field_results tables.
Called when user marks a document as 'labeled' from the web UI.
This ensures manually labeled documents are also tracked in the same
database as auto-labeled documents for consistency.
Args:
document: AdminDocument instance
annotations: List of AdminAnnotation instances
db: AdminDB instance
Returns:
Dict with success status and details
"""
from datetime import datetime
document_id = str(document.document_id)
storage_config = StorageConfig()
# Build pdf_path using admin_upload_dir (same as auto-label)
pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"
# Build report dict compatible with DocumentDB.save_document()
field_results = []
fields_total = len(annotations)
fields_matched = 0
for ann in annotations:
# All manual annotations are considered "matched" since user verified them
field_result = {
"field_name": ann.class_name,
"csv_value": ann.text_value or "", # Manual annotations may not have CSV value
"matched": True,
"score": ann.confidence or 1.0, # Manual = high confidence
"matched_text": ann.text_value,
"candidate_used": "manual",
"bbox": [ann.bbox_x, ann.bbox_y, ann.bbox_x + ann.bbox_width, ann.bbox_y + ann.bbox_height],
"page_no": ann.page_number - 1, # Convert to 0-indexed
"context_keywords": [],
"error": None,
}
field_results.append(field_result)
fields_matched += 1
# Determine PDF type
pdf_type = "unknown"
if pdf_path.exists():
try:
from shared.pdf import PDFDocument
with PDFDocument(pdf_path) as pdf_doc:
tokens = list(pdf_doc.extract_text_tokens(0))
pdf_type = "scanned" if len(tokens) < 10 else "text"
except Exception as e:
logger.warning(f"Could not determine PDF type: {e}")
# Build report
report = {
"document_id": document_id,
"pdf_path": str(pdf_path),
"pdf_type": pdf_type,
"success": fields_matched > 0,
"total_pages": document.page_count,
"fields_matched": fields_matched,
"fields_total": fields_total,
"annotations_generated": fields_matched,
"processing_time_ms": 0, # Manual labeling - no processing time
"timestamp": datetime.utcnow().isoformat(),
"errors": [],
"field_results": field_results,
# Extended fields (from CSV if available)
"split": None,
"customer_number": document.csv_field_values.get("customer_number") if document.csv_field_values else None,
"supplier_name": document.csv_field_values.get("supplier_name") if document.csv_field_values else None,
"supplier_organisation_number": document.csv_field_values.get("supplier_organisation_number") if document.csv_field_values else None,
"supplier_accounts": document.csv_field_values.get("supplier_accounts") if document.csv_field_values else None,
}
# Save to PostgreSQL DocumentDB
try:
doc_db = get_document_db()
doc_db.save_document(report)
logger.info(f"Saved manual annotations to DocumentDB for {document_id}: {fields_matched} fields")
return {
"success": True,
"document_id": document_id,
"fields_saved": fields_matched,
"message": f"Saved {fields_matched} annotations to DocumentDB",
}
except Exception as e:
logger.error(f"Failed to save manual annotations to DocumentDB: {e}", exc_info=True)
return {
"success": False,
"document_id": document_id,
"error": str(e),
}
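# Illustrative usage sketch: drain the pending auto-label queue once. The
# batch size and output directory are assumed values; in the service this is
# normally driven by the pre-label API rather than called directly.
if __name__ == "__main__":
    summary = run_pending_autolabel_batch(
        batch_size=5, output_dir=Path("data/autolabel_output")
    )
    print(
        f"{summary['successful']}/{summary['total']} documents auto-labeled, "
        f"{summary['failed']} failed"
    )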

View File

@@ -0,0 +1,285 @@
"""
Inference Service
Business logic for invoice field extraction.
"""
from __future__ import annotations
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from PIL import Image
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
logger = logging.getLogger(__name__)
@dataclass
class ServiceResult:
"""Result from inference service."""
document_id: str
success: bool = False
document_type: str = "invoice" # "invoice" or "letter"
fields: dict[str, str | None] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
detections: list[dict] = field(default_factory=list)
processing_time_ms: float = 0.0
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
class InferenceService:
"""
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration
storage_config: Storage configuration
"""
self.model_config = model_config
self.storage_config = storage_config
self._pipeline = None
self._detector = None
self._is_initialized = False
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
if self._is_initialized:
return
logger.info("Initializing inference service...")
start_time = time.time()
try:
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(self.model_config.model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
model_path=str(self.model_config.model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
enable_fallback=True,
)
self._is_initialized = True
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""
return self._is_initialized
@property
def gpu_available(self) -> bool:
"""Check if GPU is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def process_image(
self,
image_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process an image file and extract invoice fields.
Args:
image_path: Path to image file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections for visualization
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization if requested
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_visualization(image_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def process_pdf(
self,
pdf_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process a PDF file and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization (render first page)
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing PDF {pdf_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
"""Save visualization image with detections."""
from ultralytics import YOLO
# Load model and run prediction with visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
    def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
"""Save visualization for PDF (first page)."""
from shared.pdf.renderer import render_pdf_to_images
from ultralytics import YOLO
import io
# Render first page
for page_no, image_bytes in render_pdf_to_images(
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
# Cleanup temp file
temp_path.unlink(missing_ok=True)
return output_path
# If no pages rendered
return None
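# Illustrative usage sketch. ModelConfig and StorageConfig live in the web
# config module (imported above only for type checking); constructing them
# with defaults is an assumption here, as is the sample PDF path.
if __name__ == "__main__":
    from inference.web.config import ModelConfig, StorageConfig
    service = InferenceService(ModelConfig(), StorageConfig())
    result = service.process_pdf(Path("samples/invoice.pdf"))
    print(result.document_type, result.fields, f"{result.processing_time_ms:.0f} ms")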

View File

@@ -0,0 +1,24 @@
"""
Background Task Queues
Worker queues for asynchronous and batch processing.
"""
from inference.web.workers.async_queue import AsyncTaskQueue, AsyncTask
from inference.web.workers.batch_queue import (
BatchTaskQueue,
BatchTask,
init_batch_queue,
shutdown_batch_queue,
get_batch_queue,
)
__all__ = [
"AsyncTaskQueue",
"AsyncTask",
"BatchTaskQueue",
"BatchTask",
"init_batch_queue",
"shutdown_batch_queue",
"get_batch_queue",
]

View File

@@ -0,0 +1,181 @@
"""
Async Task Queue
Thread-safe queue for background invoice processing.
"""
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from queue import Empty, Full, Queue
import threading
from threading import Event, Lock, Thread
from typing import Callable
logger = logging.getLogger(__name__)
@dataclass
class AsyncTask:
"""Task queued for background processing."""
request_id: str
api_key: str
file_path: Path
filename: str
created_at: datetime = field(default_factory=datetime.utcnow)
priority: int = 0 # Lower = higher priority (not implemented yet)
class AsyncTaskQueue:
"""Thread-safe queue for async invoice processing."""
def __init__(
self,
max_size: int = 100,
worker_count: int = 1,
) -> None:
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
self._workers: list[Thread] = []
self._stop_event = Event()
self._worker_count = worker_count
self._lock = Lock()
self._processing: set[str] = set() # Currently processing request_ids
self._task_handler: Callable[[AsyncTask], None] | None = None
self._started = False
def start(self, task_handler: Callable[[AsyncTask], None]) -> None:
"""Start background worker threads."""
if self._started:
logger.warning("AsyncTaskQueue already started")
return
self._task_handler = task_handler
self._stop_event.clear()
for i in range(self._worker_count):
worker = Thread(
target=self._worker_loop,
name=f"async-worker-{i}",
daemon=True,
)
worker.start()
self._workers.append(worker)
logger.info(f"Started async worker thread: {worker.name}")
self._started = True
logger.info(f"AsyncTaskQueue started with {self._worker_count} workers")
def stop(self, timeout: float = 30.0) -> None:
"""Gracefully stop all workers."""
if not self._started:
return
logger.info("Stopping AsyncTaskQueue...")
self._stop_event.set()
# Wait for workers to finish
for worker in self._workers:
worker.join(timeout=timeout / self._worker_count)
if worker.is_alive():
logger.warning(f"Worker {worker.name} did not stop gracefully")
self._workers.clear()
self._started = False
logger.info("AsyncTaskQueue stopped")
def submit(self, task: AsyncTask) -> bool:
"""
Submit a task to the queue.
Returns:
True if task was queued, False if queue is full
"""
try:
self._queue.put_nowait(task)
logger.info(f"Task {task.request_id} queued for processing")
return True
except Full:
logger.warning(f"Queue full, task {task.request_id} rejected")
return False
def get_queue_depth(self) -> int:
"""Get current number of tasks in queue."""
return self._queue.qsize()
def get_processing_count(self) -> int:
"""Get number of tasks currently being processed."""
with self._lock:
return len(self._processing)
def is_processing(self, request_id: str) -> bool:
"""Check if a specific request is currently being processed."""
with self._lock:
return request_id in self._processing
@property
def is_running(self) -> bool:
"""Check if the queue is running."""
return self._started and not self._stop_event.is_set()
def _worker_loop(self) -> None:
"""Worker loop that processes tasks from queue."""
thread_name = threading.current_thread().name
logger.info(f"Worker {thread_name} started")
while not self._stop_event.is_set():
try:
# Block for up to 1 second waiting for tasks
task = self._queue.get(timeout=1.0)
except Empty:
continue
try:
with self._lock:
self._processing.add(task.request_id)
logger.info(
f"Worker {thread_name} processing task {task.request_id}"
)
start_time = time.time()
if self._task_handler:
self._task_handler(task)
elapsed = time.time() - start_time
logger.info(
f"Worker {thread_name} completed task {task.request_id} "
f"in {elapsed:.2f}s"
)
except Exception as e:
logger.error(
f"Worker {thread_name} failed to process task "
f"{task.request_id}: {e}",
exc_info=True,
)
finally:
with self._lock:
self._processing.discard(task.request_id)
self._queue.task_done()
logger.info(f"Worker {thread_name} stopped")
    def wait_for_completion(self, timeout: float | None = None) -> bool:
        """
        Wait for all queued tasks to complete.
        Args:
            timeout: Maximum time to wait in seconds (None waits indefinitely)
        Returns:
            True if all tasks completed, False if the timeout expired first
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        # Poll until the queue is drained and no worker is mid-task.
        while not self._queue.empty() or self.get_processing_count() > 0:
            if deadline is not None and time.monotonic() >= deadline:
                return False
            time.sleep(0.1)
        return True
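# Illustrative usage sketch with a stub handler; the file path, API key and
# request ID are assumed placeholder values.
if __name__ == "__main__":
    import uuid
    def handle(task: AsyncTask) -> None:
        print(f"processing {task.filename} for key {task.api_key}")
    queue = AsyncTaskQueue(max_size=10, worker_count=2)
    queue.start(handle)
    queue.submit(AsyncTask(
        request_id=str(uuid.uuid4()),
        api_key="demo-key",
        file_path=Path("uploads/invoice.pdf"),
        filename="invoice.pdf",
    ))
    queue.wait_for_completion(timeout=10.0)
    queue.stop()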

View File

@@ -0,0 +1,225 @@
"""
Batch Upload Processing Queue
Background queue for async batch upload processing.
"""
import logging
import threading
from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Full, Empty
from typing import Any
from uuid import UUID
logger = logging.getLogger(__name__)
@dataclass
class BatchTask:
"""Task for batch upload processing."""
batch_id: UUID
admin_token: str
zip_content: bytes
zip_filename: str
upload_source: str
auto_label: bool
created_at: datetime
class BatchTaskQueue:
"""Thread-safe queue for async batch upload processing."""
def __init__(self, max_size: int = 20, worker_count: int = 2):
"""Initialize the batch task queue.
Args:
max_size: Maximum queue size
worker_count: Number of worker threads
"""
self._queue: Queue[BatchTask] = Queue(maxsize=max_size)
self._workers: list[threading.Thread] = []
self._stop_event = threading.Event()
self._worker_count = worker_count
self._batch_service: Any | None = None
self._running = False
self._lock = threading.Lock()
def start(self, batch_service: Any) -> None:
"""Start worker threads with batch service.
Args:
batch_service: BatchUploadService instance for processing
"""
with self._lock:
if self._running:
logger.warning("Batch queue already running")
return
self._batch_service = batch_service
self._stop_event.clear()
self._running = True
# Start worker threads
for i in range(self._worker_count):
worker = threading.Thread(
target=self._worker_loop,
name=f"BatchWorker-{i}",
daemon=True,
)
worker.start()
self._workers.append(worker)
logger.info(f"Started {self._worker_count} batch workers")
def stop(self, timeout: float = 30.0) -> None:
"""Stop all worker threads gracefully.
Args:
timeout: Maximum time to wait for workers to finish
"""
with self._lock:
if not self._running:
return
logger.info("Stopping batch queue...")
self._stop_event.set()
self._running = False
# Wait for workers to finish
for worker in self._workers:
worker.join(timeout=timeout)
self._workers.clear()
logger.info("Batch queue stopped")
def submit(self, task: BatchTask) -> bool:
"""Submit a batch task to the queue.
Args:
task: Batch task to process
Returns:
True if task was queued, False if queue is full
"""
try:
self._queue.put(task, block=False)
logger.info(f"Queued batch task: batch_id={task.batch_id}")
return True
except Full:
logger.warning(f"Queue full, rejected task: batch_id={task.batch_id}")
return False
def get_queue_depth(self) -> int:
"""Get the number of pending tasks in queue.
Returns:
Number of tasks waiting to be processed
"""
return self._queue.qsize()
@property
def is_running(self) -> bool:
"""Check if queue is running.
Returns:
True if queue is active
"""
return self._running
def _worker_loop(self) -> None:
"""Worker thread main loop."""
worker_name = threading.current_thread().name
logger.info(f"{worker_name} started")
while not self._stop_event.is_set():
try:
# Get task with timeout to check stop event periodically
task = self._queue.get(timeout=1.0)
self._process_task(task)
self._queue.task_done()
except Empty:
# No tasks, continue loop to check stop event
continue
except Exception as e:
logger.error(f"{worker_name} error processing task: {e}", exc_info=True)
logger.info(f"{worker_name} stopped")
def _process_task(self, task: BatchTask) -> None:
"""Process a single batch task.
Args:
task: Batch task to process
"""
if self._batch_service is None:
logger.error("Batch service not initialized, cannot process task")
return
logger.info(
f"Processing batch task: batch_id={task.batch_id}, "
f"filename={task.zip_filename}"
)
try:
# Process the batch upload using the service
result = self._batch_service.process_zip_upload(
admin_token=task.admin_token,
zip_filename=task.zip_filename,
zip_content=task.zip_content,
upload_source=task.upload_source,
)
logger.info(
f"Batch task completed: batch_id={task.batch_id}, "
f"status={result.get('status')}, "
f"successful_files={result.get('successful_files')}, "
f"failed_files={result.get('failed_files')}"
)
except Exception as e:
logger.error(
f"Error processing batch task {task.batch_id}: {e}",
exc_info=True,
)
# Global batch queue instance
_batch_queue: BatchTaskQueue | None = None
_queue_lock = threading.Lock()
def get_batch_queue() -> BatchTaskQueue:
"""Get or create the global batch queue instance.
Returns:
Batch task queue instance
"""
global _batch_queue
if _batch_queue is None:
with _queue_lock:
if _batch_queue is None:
_batch_queue = BatchTaskQueue(max_size=20, worker_count=2)
return _batch_queue
def init_batch_queue(batch_service: Any) -> None:
"""Initialize and start the batch queue.
Args:
batch_service: BatchUploadService instance
"""
queue = get_batch_queue()
if not queue.is_running:
queue.start(batch_service)
def shutdown_batch_queue() -> None:
"""Shutdown the batch queue gracefully."""
global _batch_queue
if _batch_queue is not None:
_batch_queue.stop()
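# Illustrative usage sketch: wire the queue to a BatchUploadService and push
# one task. The service import path, AdminDB() construction and all values
# below are assumptions.
if __name__ == "__main__":
    from uuid import uuid4
    from inference.data.admin_db import AdminDB
    from inference.web.services.batch_upload import BatchUploadService
    init_batch_queue(BatchUploadService(AdminDB()))
    get_batch_queue().submit(BatchTask(
        batch_id=uuid4(),
        admin_token="<admin-token>",
        zip_content=b"",  # placeholder; real callers pass the ZIP bytes
        zip_filename="invoices_batch.zip",
        upload_source="api",
        auto_label=False,
        created_at=datetime.utcnow(),
    ))
    shutdown_batch_queue()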

View File

@@ -0,0 +1,8 @@
-e ../shared
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
sqlmodel>=0.0.22
ultralytics>=8.1.0
httpx>=0.25.0
openai>=1.0.0

View File

@@ -0,0 +1,14 @@
#!/usr/bin/env python
"""
Quick start script for the web server.
Usage:
python run_server.py
python run_server.py --port 8080
python run_server.py --debug --reload
"""
from inference.cli.serve import main
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,17 @@
from setuptools import setup, find_packages
setup(
name="invoice-inference",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
install_requires=[
"invoice-shared",
"fastapi>=0.104.0",
"uvicorn[standard]>=0.24.0",
"python-multipart>=0.0.6",
"sqlmodel>=0.0.22",
"ultralytics>=8.1.0",
"httpx>=0.25.0",
],
)