restructure project

Yaojia Wang
2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions

@@ -0,0 +1,25 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
&& rm -rf /var/lib/apt/lists/*
# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared
# Install inference package
COPY packages/inference /app/packages/inference
RUN pip install --no-cache-dir -e /app/packages/inference
# Copy frontend (if needed)
COPY frontend /app/frontend
WORKDIR /app/packages/inference
EXPOSE 8000
CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]

@@ -0,0 +1,105 @@
"""Trigger training jobs on Azure Container Instances."""
import logging
import os
logger = logging.getLogger(__name__)
# Azure SDK is optional; only needed if using ACI trigger
try:
from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
EnvironmentVariable,
GpuResource,
ResourceRequests,
ResourceRequirements,
)
_AZURE_SDK_AVAILABLE = True
except ImportError:
_AZURE_SDK_AVAILABLE = False
def start_training_container(task_id: str) -> str | None:
"""
Start an Azure Container Instance for a training task.
Returns the container group name if successful, None otherwise.
Requires AZURE_SUBSCRIPTION_ID. Optional (fall back to defaults):
AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE, AZURE_GPU_SKU, AZURE_LOCATION
"""
if not _AZURE_SDK_AVAILABLE:
logger.warning(
"Azure SDK not installed. Install azure-mgmt-containerinstance "
"and azure-identity to use ACI trigger."
)
return None
subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
image = os.environ.get(
"AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
)
gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
location = os.environ.get("AZURE_LOCATION", "eastus")
if not subscription_id:
logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
return None
credential = DefaultAzureCredential()
client = ContainerInstanceManagementClient(credential, subscription_id)
container_name = f"training-{task_id[:8]}"
env_vars = [
EnvironmentVariable(name="TASK_ID", value=task_id),
]
# Pass DB connection securely
for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
val = os.environ.get(var, "")
if val:
env_vars.append(EnvironmentVariable(name=var, value=val))
db_password = os.environ.get("DB_PASSWORD", "")
if db_password:
env_vars.append(
EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
)
container = Container(
name=container_name,
image=image,
resources=ResourceRequirements(
requests=ResourceRequests(
cpu=4,
memory_in_gb=16,
gpu=GpuResource(count=1, sku=gpu_sku),
)
),
environment_variables=env_vars,
command=[
"python",
"run_training.py",
"--task-id",
task_id,
],
)
group = ContainerGroup(
location=location,
containers=[container],
os_type="Linux",
restart_policy="Never",
)
logger.info("Creating ACI container group: %s", container_name)
client.container_groups.begin_create_or_update(
resource_group, container_name, group
)
return container_name
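
A minimal usage sketch for the trigger above, assuming the Azure environment variables described in the docstring are exported; the import path and task id are illustrative, not part of this commit.

# Sketch only: module path assumed, task id made up.
import logging
import uuid

from inference.workers.aci_trigger import start_training_container  # assumed path

logging.basicConfig(level=logging.INFO)

task_id = str(uuid.uuid4())
group_name = start_training_container(task_id)
if group_name is None:
    print("ACI not started (Azure SDK missing or AZURE_SUBSCRIPTION_ID unset)")
else:
    print(f"Started container group {group_name}")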

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Inference CLI
Runs inference on new PDFs to extract invoice data.
"""
import argparse
import json
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI
def main():
parser = argparse.ArgumentParser(
description='Extract invoice data from PDFs using trained model'
)
parser.add_argument(
'--model', '-m',
required=True,
help='Path to trained YOLO model (.pt file)'
)
parser.add_argument(
'--input', '-i',
required=True,
help='Input PDF file or directory'
)
parser.add_argument(
'--output', '-o',
help='Output JSON file (default: stdout)'
)
parser.add_argument(
'--confidence',
type=float,
default=0.5,
help='Detection confidence threshold (default: 0.5)'
)
parser.add_argument(
'--dpi',
type=int,
default=DEFAULT_DPI,
help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
)
parser.add_argument(
'--no-fallback',
action='store_true',
help='Disable fallback OCR'
)
parser.add_argument(
'--lang',
default='en',
help='OCR language (default: en)'
)
parser.add_argument(
'--gpu',
action='store_true',
help='Use GPU'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Verbose output'
)
args = parser.parse_args()
# Validate model
model_path = Path(args.model)
if not model_path.exists():
print(f"Error: Model not found: {model_path}", file=sys.stderr)
sys.exit(1)
# Get input files
input_path = Path(args.input)
if input_path.is_file():
pdf_files = [input_path]
elif input_path.is_dir():
pdf_files = list(input_path.glob('*.pdf'))
else:
print(f"Error: Input not found: {input_path}", file=sys.stderr)
sys.exit(1)
if not pdf_files:
print("Error: No PDF files found", file=sys.stderr)
sys.exit(1)
if args.verbose:
print(f"Processing {len(pdf_files)} PDF file(s)")
print(f"Model: {model_path}")
from inference.pipeline import InferencePipeline
# Initialize pipeline
pipeline = InferencePipeline(
model_path=model_path,
confidence_threshold=args.confidence,
ocr_lang=args.lang,
use_gpu=args.gpu,
dpi=args.dpi,
enable_fallback=not args.no_fallback
)
# Process files
results = []
for pdf_path in pdf_files:
if args.verbose:
print(f"Processing: {pdf_path.name}")
result = pipeline.process_pdf(pdf_path)
results.append(result.to_json())
if args.verbose:
print(f" Success: {result.success}")
print(f" Fields: {len(result.fields)}")
if result.fallback_used:
print(f" Fallback used: Yes")
if result.errors:
print(f" Errors: {result.errors}")
# Output results
if len(results) == 1:
output = results[0]
else:
output = results
json_output = json.dumps(output, indent=2, ensure_ascii=False)
if args.output:
with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output)
if args.verbose:
print(f"\nResults written to: {args.output}")
else:
print(json_output)
if __name__ == '__main__':
main()
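
For comparison, a sketch of the programmatic call this CLI wraps, using only the InferencePipeline surface exercised above; the model and PDF paths are placeholders.

# Sketch only: paths are placeholders; arguments mirror the CLI defaults above.
from pathlib import Path

from inference.pipeline import InferencePipeline

pipeline = InferencePipeline(
    model_path=Path("runs/train/invoice_fields/weights/best.pt"),
    confidence_threshold=0.5,
    ocr_lang="en",
    use_gpu=False,
    dpi=300,
    enable_fallback=True,
)
result = pipeline.process_pdf(Path("invoice.pdf"))
print(result.success, len(result.fields), result.fallback_used)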

@@ -0,0 +1,159 @@
"""
Web Server CLI
Command-line interface for starting the web server.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
# Add project root to path so shared.config resolves when run as a script
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from shared.config import DEFAULT_DPI
def setup_logging(debug: bool = False) -> None:
"""Configure logging."""
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Start the Invoice Field Extraction web server",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to bind to",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to listen on",
)
parser.add_argument(
"--model",
"-m",
type=Path,
default=Path("runs/train/invoice_fields/weights/best.pt"),
help="Path to YOLO model weights",
)
parser.add_argument(
"--confidence",
type=float,
default=0.5,
help="Detection confidence threshold",
)
parser.add_argument(
"--dpi",
type=int,
default=DEFAULT_DPI,
help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
)
parser.add_argument(
"--no-gpu",
action="store_true",
help="Disable GPU acceleration",
)
parser.add_argument(
"--reload",
action="store_true",
help="Enable auto-reload for development",
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="Number of worker processes",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode",
)
return parser.parse_args()
def main() -> None:
"""Main entry point."""
args = parse_args()
setup_logging(debug=args.debug)
logger = logging.getLogger(__name__)
# Validate model path
if not args.model.exists():
logger.error(f"Model file not found: {args.model}")
sys.exit(1)
logger.info("=" * 60)
logger.info("Invoice Field Extraction Web Server")
logger.info("=" * 60)
logger.info(f"Model: {args.model}")
logger.info(f"Confidence threshold: {args.confidence}")
logger.info(f"GPU enabled: {not args.no_gpu}")
logger.info(f"Server: http://{args.host}:{args.port}")
logger.info("=" * 60)
# Create config
from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
config = AppConfig(
model=ModelConfig(
model_path=args.model,
confidence_threshold=args.confidence,
use_gpu=not args.no_gpu,
dpi=args.dpi,
),
server=ServerConfig(
host=args.host,
port=args.port,
debug=args.debug,
reload=args.reload,
workers=args.workers,
),
storage=StorageConfig(),
)
# Create and run app
import uvicorn
from inference.web.app import create_app
app = create_app(config)
uvicorn.run(
app,
host=config.server.host,
port=config.server.port,
reload=config.server.reload,
workers=config.server.workers if not config.server.reload else 1,
log_level="debug" if config.server.debug else "info",
)
if __name__ == "__main__":
main()

File diff suppressed because it is too large

@@ -0,0 +1,407 @@
"""
Admin API SQLModel Database Models
Defines the database schema for admin document management, annotations, and training tasks.
Includes batch upload support, training document links, and annotation history.
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
# =============================================================================
# CSV to Field Class Mapping
# =============================================================================
CSV_TO_CLASS_MAPPING: dict[str, int] = {
"InvoiceNumber": 0, # invoice_number
"InvoiceDate": 1, # invoice_date
"InvoiceDueDate": 2, # invoice_due_date
"OCR": 3, # ocr_number
"Bankgiro": 4, # bankgiro
"Plusgiro": 5, # plusgiro
"Amount": 6, # amount
"supplier_organisation_number": 7, # supplier_organisation_number
# 8: payment_line (derived from OCR/Bankgiro/Amount)
"customer_number": 9, # customer_number
}
# =============================================================================
# Core Models
# =============================================================================
class AdminToken(SQLModel, table=True):
"""Admin authentication token."""
__tablename__ = "admin_tokens"
token: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
expires_at: datetime | None = Field(default=None)
class AdminDocument(SQLModel, table=True):
"""Document uploaded for labeling/annotation."""
__tablename__ = "admin_documents"
document_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
file_path: str = Field(max_length=512) # Path to stored file
page_count: int = Field(default=1)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, auto_labeling, labeled, exported
auto_label_status: str | None = Field(default=None, max_length=20)
# Auto-label status: running, completed, failed
auto_label_error: str | None = Field(default=None)
# v2: Upload source tracking
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
# When auto-label was queued
annotation_lock_until: datetime | None = Field(default=None)
# Lock for manual annotation while auto-label runs
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class AdminAnnotation(SQLModel, table=True):
"""Annotation for a document (bounding box + label)."""
__tablename__ = "admin_annotations"
annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
page_number: int = Field(default=1) # 1-indexed
class_id: int # 0-9 for invoice fields
class_name: str = Field(max_length=50) # e.g., "invoice_number"
# Bounding box (normalized 0-1 coordinates)
x_center: float
y_center: float
width: float
height: float
# Original pixel coordinates (for display)
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
# OCR extracted text (if available)
text_value: str | None = Field(default=None)
confidence: float | None = Field(default=None)
# Source: manual, auto, imported
source: str = Field(default="manual", max_length=20, index=True)
# v2: Verification fields
is_verified: bool = Field(default=False, index=True)
verified_at: datetime | None = Field(default=None)
verified_by: str | None = Field(default=None, max_length=255)
# v2: Override tracking
override_source: str | None = Field(default=None, max_length=20)
# If this annotation overrides another: 'auto' or 'imported'
original_annotation_id: UUID | None = Field(default=None)
# Reference to the annotation this overrides
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingTask(SQLModel, table=True):
"""Training/fine-tuning task."""
__tablename__ = "training_tasks"
task_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, scheduled, running, completed, failed, cancelled
task_type: str = Field(default="train", max_length=20)
# Task type: train, finetune
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
# Training configuration
config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Schedule settings
scheduled_at: datetime | None = Field(default=None)
cron_expression: str | None = Field(default=None, max_length=50)
is_recurring: bool = Field(default=False)
# Execution details
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
error_message: str | None = Field(default=None)
# Result metrics
result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
model_path: str | None = Field(default=None, max_length=512)
# v2: Document count and extracted metrics
document_count: int = Field(default=0)
# Count of documents used in training
metrics_mAP: float | None = Field(default=None, index=True)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
# Extracted metrics for easy querying
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingLog(SQLModel, table=True):
"""Training log entry."""
__tablename__ = "training_logs"
log_id: int | None = Field(default=None, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
level: str = Field(max_length=20) # INFO, WARNING, ERROR
message: str
details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# =============================================================================
# Batch Upload Models (v2)
# =============================================================================
class BatchUpload(SQLModel, table=True):
"""Batch upload of multiple documents via ZIP file."""
__tablename__ = "batch_uploads"
batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255) # ZIP filename
file_size: int
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
status: str = Field(default="processing", max_length=20, index=True)
# Status: processing, completed, partial, failed
total_files: int = Field(default=0)
processed_files: int = Field(default=0)
# Number of files processed so far
successful_files: int = Field(default=0)
failed_files: int = Field(default=0)
csv_filename: str | None = Field(default=None, max_length=255)
# CSV file used for auto-labeling
csv_row_count: int | None = Field(default=None)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
completed_at: datetime | None = Field(default=None)
class BatchUploadFile(SQLModel, table=True):
"""Individual file within a batch upload."""
__tablename__ = "batch_upload_files"
file_id: UUID = Field(default_factory=uuid4, primary_key=True)
batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
filename: str = Field(max_length=255) # PDF filename within ZIP
document_id: UUID | None = Field(default=None)
# Link to created AdminDocument (if successful)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, processing, completed, failed, skipped
error_message: str | None = Field(default=None)
annotation_count: int = Field(default=0)
# Number of annotations created for this file
csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# CSV row data for this file (if available)
created_at: datetime = Field(default_factory=datetime.utcnow)
processed_at: datetime | None = Field(default=None)
# =============================================================================
# Training Document Link (v2)
# =============================================================================
class TrainingDataset(SQLModel, table=True):
"""Training dataset containing selected documents with train/val/test splits."""
__tablename__ = "training_datasets"
dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="building", max_length=20, index=True)
# Status: building, ready, training, archived, failed
train_ratio: float = Field(default=0.8)
val_ratio: float = Field(default=0.1)
seed: int = Field(default=42)
total_documents: int = Field(default=0)
total_images: int = Field(default=0)
total_annotations: int = Field(default=0)
dataset_path: str | None = Field(default=None, max_length=512)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class DatasetDocument(SQLModel, table=True):
"""Junction table linking datasets to documents with split assignment."""
__tablename__ = "dataset_documents"
id: UUID = Field(default_factory=uuid4, primary_key=True)
dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
split: str = Field(max_length=10) # train, val, test
page_count: int = Field(default=0)
annotation_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingDocumentLink(SQLModel, table=True):
"""Junction table linking training tasks to documents."""
__tablename__ = "training_document_links"
link_id: UUID = Field(default_factory=uuid4, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of annotations at training time (includes count, verified count, etc.)
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================
class AnnotationHistory(SQLModel, table=True):
"""History of annotation changes (for override tracking)."""
__tablename__ = "annotation_history"
history_id: UUID = Field(default_factory=uuid4, primary_key=True)
annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
# Change action: created, updated, deleted, override
action: str = Field(max_length=20, index=True)
# Previous value (for updates/deletes)
previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# New value (for creates/updates)
new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Change metadata
changed_by: str | None = Field(default=None, max_length=255)
# User/token who made the change
change_reason: str | None = Field(default=None)
# Optional reason for change
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Field class mapping (same as src/cli/train.py)
FIELD_CLASSES = {
0: "invoice_number",
1: "invoice_date",
2: "invoice_due_date",
3: "ocr_number",
4: "bankgiro",
5: "plusgiro",
6: "amount",
7: "supplier_organisation_number",
8: "payment_line",
9: "customer_number",
}
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
# Read-only models for API responses
class AdminDocumentRead(SQLModel):
"""Admin document response model."""
document_id: UUID
filename: str
file_size: int
content_type: str
page_count: int
status: str
auto_label_status: str | None
auto_label_error: str | None
created_at: datetime
updated_at: datetime
class AdminAnnotationRead(SQLModel):
"""Admin annotation response model."""
annotation_id: UUID
document_id: UUID
page_number: int
class_id: int
class_name: str
x_center: float
y_center: float
width: float
height: float
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
text_value: str | None
confidence: float | None
source: str
created_at: datetime
class TrainingTaskRead(SQLModel):
"""Training task response model."""
task_id: UUID
name: str
description: str | None
status: str
task_type: str
config: dict[str, Any] | None
scheduled_at: datetime | None
is_recurring: bool
started_at: datetime | None
completed_at: datetime | None
error_message: str | None
result_metrics: dict[str, Any] | None
model_path: str | None
dataset_id: UUID | None
created_at: datetime
class TrainingDatasetRead(SQLModel):
"""Training dataset response model."""
dataset_id: UUID
name: str
description: str | None
status: str
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
created_at: datetime
updated_at: datetime
class DatasetDocumentRead(SQLModel):
"""Dataset document response model."""
id: UUID
dataset_id: UUID
document_id: UUID
split: str
page_count: int
annotation_count: int
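
A hedged sketch of how these models are used with SQLModel: create the tables, insert a document, and read back the generated id. The connection URL and field values are placeholders.

# Sketch only: placeholder connection URL and metadata.
from sqlmodel import Session, SQLModel, create_engine

from inference.data.admin_models import AdminDocument

engine = create_engine("postgresql://user:pass@localhost:5432/docmaster")  # placeholder
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    doc = AdminDocument(
        filename="invoice_001.pdf",
        file_size=182044,
        content_type="application/pdf",
        file_path="/data/uploads/invoice_001.pdf",
    )
    session.add(doc)
    session.commit()
    session.refresh(doc)
    print(doc.document_id, doc.status)  # status defaults to "pending"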

@@ -0,0 +1,374 @@
"""
Async Request Database Operations
Database interface for async invoice processing requests using SQLModel.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy import func, text
from sqlmodel import Session, select
from inference.data.database import get_session_context, create_db_and_tables, close_engine
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent
logger = logging.getLogger(__name__)
# Legacy dataclasses for backward compatibility
from dataclasses import dataclass
@dataclass(frozen=True)
class ApiKeyConfig:
"""API key configuration and limits (legacy compatibility)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestDB:
"""Database interface for async processing requests using SQLModel."""
def __init__(self, connection_string: str | None = None) -> None:
# connection_string is kept for backward compatibility but ignored
# SQLModel uses the global engine from database.py
self._initialized = False
def connect(self):
"""Legacy method - returns self for compatibility."""
return self
def close(self) -> None:
"""Close database connections."""
close_engine()
def __enter__(self) -> "AsyncRequestDB":
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
pass # Sessions are managed per-operation
def create_tables(self) -> None:
"""Create async processing tables if they don't exist."""
create_db_and_tables()
self._initialized = True
# ==========================================================================
# API Key Operations
# ==========================================================================
def is_valid_api_key(self, api_key: str) -> bool:
"""Check if API key exists and is active."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
return result is not None and result.is_active is True
def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
"""Get API key configuration and limits."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
if result is None:
return None
return ApiKeyConfig(
api_key=result.api_key,
name=result.name,
is_active=result.is_active,
requests_per_minute=result.requests_per_minute,
max_concurrent_jobs=result.max_concurrent_jobs,
max_file_size_mb=result.max_file_size_mb,
)
def create_api_key(
self,
api_key: str,
name: str,
requests_per_minute: int = 10,
max_concurrent_jobs: int = 3,
max_file_size_mb: int = 50,
) -> None:
"""Create a new API key."""
with get_session_context() as session:
existing = session.get(ApiKey, api_key)
if existing:
existing.name = name
existing.requests_per_minute = requests_per_minute
existing.max_concurrent_jobs = max_concurrent_jobs
existing.max_file_size_mb = max_file_size_mb
session.add(existing)
else:
new_key = ApiKey(
api_key=api_key,
name=name,
requests_per_minute=requests_per_minute,
max_concurrent_jobs=max_concurrent_jobs,
max_file_size_mb=max_file_size_mb,
)
session.add(new_key)
def update_api_key_usage(self, api_key: str) -> None:
"""Update API key last used timestamp and increment total requests."""
with get_session_context() as session:
key = session.get(ApiKey, api_key)
if key:
key.last_used_at = datetime.utcnow()
key.total_requests += 1
session.add(key)
# ==========================================================================
# Async Request Operations
# ==========================================================================
def create_request(
self,
api_key: str,
filename: str,
file_size: int,
content_type: str,
expires_at: datetime,
request_id: str | None = None,
) -> str:
"""Create a new async request."""
with get_session_context() as session:
request = AsyncRequest(
api_key=api_key,
filename=filename,
file_size=file_size,
content_type=content_type,
expires_at=expires_at,
)
if request_id:
request.request_id = UUID(request_id)
session.add(request)
session.flush() # To get the generated ID
return str(request.request_id)
def get_request(self, request_id: str) -> AsyncRequest | None:
"""Get a single async request by ID."""
with get_session_context() as session:
result = session.get(AsyncRequest, UUID(request_id))
if result:
# Detach from session for use outside context
session.expunge(result)
return result
def get_request_by_api_key(
self,
request_id: str,
api_key: str,
) -> AsyncRequest | None:
"""Get a request only if it belongs to the given API key."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.request_id == UUID(request_id),
AsyncRequest.api_key == api_key,
)
result = session.exec(statement).first()
if result:
session.expunge(result)
return result
def update_status(
self,
request_id: str,
status: str,
error_message: str | None = None,
increment_retry: bool = False,
) -> None:
"""Update request status."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = status
if status == "processing":
request.started_at = datetime.utcnow()
if error_message is not None:
request.error_message = error_message
if increment_retry:
request.retry_count += 1
session.add(request)
def complete_request(
self,
request_id: str,
document_id: str,
result: dict[str, Any],
processing_time_ms: float,
visualization_path: str | None = None,
) -> None:
"""Mark request as completed with result."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = "completed"
request.document_id = document_id
request.result = result
request.processing_time_ms = processing_time_ms
request.visualization_path = visualization_path
request.completed_at = datetime.utcnow()
session.add(request)
def get_requests_by_api_key(
self,
api_key: str,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AsyncRequest], int]:
"""Get paginated requests for an API key."""
with get_session_context() as session:
# Count query
count_stmt = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
count_stmt = count_stmt.where(AsyncRequest.status == status)
total = session.exec(count_stmt).one()
# Fetch query
statement = select(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
statement = statement.where(AsyncRequest.status == status)
statement = statement.order_by(AsyncRequest.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
# Detach results from session
for r in results:
session.expunge(r)
return list(results), total
def count_active_jobs(self, api_key: str) -> int:
"""Count active (pending + processing) jobs for an API key."""
with get_session_context() as session:
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key,
AsyncRequest.status.in_(["pending", "processing"]),
)
return session.exec(statement).one()
def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
"""Get pending requests ordered by creation time."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.status == "pending"
).order_by(AsyncRequest.created_at).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_queue_position(self, request_id: str) -> int | None:
"""Get position of a request in the pending queue."""
with get_session_context() as session:
# Get the request's created_at
request = session.get(AsyncRequest, UUID(request_id))
if not request:
return None
# Count pending requests created before this one
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.status == "pending",
AsyncRequest.created_at < request.created_at,
)
count = session.exec(statement).one()
return count + 1 # 1-based position
# ==========================================================================
# Rate Limit Operations
# ==========================================================================
def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
"""Record a rate limit event."""
with get_session_context() as session:
event = RateLimitEvent(
api_key=api_key,
event_type=event_type,
)
session.add(event)
def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
"""Count requests in the last N seconds."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(seconds=seconds)
statement = select(func.count()).select_from(RateLimitEvent).where(
RateLimitEvent.api_key == api_key,
RateLimitEvent.event_type == "request",
RateLimitEvent.created_at > cutoff,
)
return session.exec(statement).one()
# ==========================================================================
# Cleanup Operations
# ==========================================================================
def delete_expired_requests(self) -> int:
"""Delete requests that have expired. Returns count of deleted rows."""
with get_session_context() as session:
now = datetime.utcnow()
statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
expired = session.exec(statement).all()
count = len(expired)
for request in expired:
session.delete(request)
logger.info(f"Deleted {count} expired async requests")
return count
def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
"""Delete rate limit events older than N hours."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(hours=hours)
statement = select(RateLimitEvent).where(
RateLimitEvent.created_at < cutoff
)
old_events = session.exec(statement).all()
count = len(old_events)
for event in old_events:
session.delete(event)
return count
def reset_stale_processing_requests(
self,
stale_minutes: int = 10,
max_retries: int = 3,
) -> int:
"""
Reset requests stuck in 'processing' status.
Requests that have been processing for more than stale_minutes
are considered stale. They are either reset to 'pending' (if under
max_retries) or set to 'failed'.
"""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
reset_count = 0
# Find stale processing requests
statement = select(AsyncRequest).where(
AsyncRequest.status == "processing",
AsyncRequest.started_at < cutoff,
)
stale_requests = session.exec(statement).all()
for request in stale_requests:
if request.retry_count < max_retries:
request.status = "pending"
request.started_at = None
else:
request.status = "failed"
request.error_message = "Processing timeout after max retries"
session.add(request)
reset_count += 1
if reset_count > 0:
logger.warning(f"Reset {reset_count} stale processing requests")
return reset_count
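
A hedged sketch of the request lifecycle this class supports; the module path, API key and file metadata are illustrative.

# Sketch only: module path assumed; key and metadata are made up.
from datetime import datetime, timedelta

from inference.data.async_requests import AsyncRequestDB  # assumed path

db = AsyncRequestDB()
db.create_tables()
db.create_api_key("demo-key", name="Demo client")

request_id = db.create_request(
    api_key="demo-key",
    filename="invoice.pdf",
    file_size=182044,
    content_type="application/pdf",
    expires_at=datetime.utcnow() + timedelta(hours=24),
)
db.update_status(request_id, "processing")
db.complete_request(
    request_id,
    document_id="doc-001",
    result={"fields": {}},
    processing_time_ms=1234.5,
)
print(db.get_request(request_id).status)  # "completed"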

@@ -0,0 +1,102 @@
"""
Database Engine and Session Management
Provides SQLModel database engine and session handling.
"""
import logging
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine
from shared.config import get_db_connection_string
logger = logging.getLogger(__name__)
# Global engine instance
_engine = None
def get_engine():
"""Get or create the database engine."""
global _engine
if _engine is None:
connection_string = get_db_connection_string()
# Convert psycopg2 format to SQLAlchemy format
if connection_string.startswith("postgresql://"):
# Already in correct format
pass
elif "host=" in connection_string:
# Convert DSN format to URL format
# Split on the first "=" only so values containing "=" survive
parts = dict(item.split("=", 1) for item in connection_string.split())
connection_string = (
f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
f"/{parts.get('dbname', 'docmaster')}"
)
_engine = create_engine(
connection_string,
echo=False, # Set to True for SQL debugging
pool_pre_ping=True, # Verify connections before use
pool_size=5,
max_overflow=10,
)
return _engine
def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
from inference.data.admin_models import ( # noqa: F401
AdminToken,
AdminDocument,
AdminAnnotation,
TrainingTask,
TrainingLog,
)
engine = get_engine()
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
def get_session() -> Session:
"""Get a new database session."""
engine = get_engine()
return Session(engine)
@contextmanager
def get_session_context() -> Generator[Session, None, None]:
"""Context manager for database sessions with auto-commit/rollback."""
session = get_session()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def close_engine() -> None:
"""Close the database engine and release connections."""
global _engine
if _engine is not None:
_engine.dispose()
_engine = None
logger.info("Database engine closed")
def execute_raw_sql(sql: str) -> None:
"""Execute raw SQL (for migrations)."""
engine = get_engine()
with engine.connect() as conn:
conn.execute(text(sql))
conn.commit()
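
A short usage sketch: get_session_context commits on success and rolls back on any exception.

# Sketch only: counts active API keys inside a managed session.
from sqlmodel import select

from inference.data.database import create_db_and_tables, get_session_context
from inference.data.models import ApiKey

create_db_and_tables()
with get_session_context() as session:
    active = session.exec(select(ApiKey).where(ApiKey.is_active == True)).all()  # noqa: E712
    print(f"{len(active)} active API key(s)")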

@@ -0,0 +1,95 @@
"""
SQLModel Database Models
Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
class ApiKey(SQLModel, table=True):
"""API key configuration and limits."""
__tablename__ = "api_keys"
api_key: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
requests_per_minute: int = Field(default=10)
max_concurrent_jobs: int = Field(default=3)
max_file_size_mb: int = Field(default=50)
total_requests: int = Field(default=0)
total_processed: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
class AsyncRequest(SQLModel, table=True):
"""Async request record."""
__tablename__ = "async_requests"
request_id: UUID = Field(default_factory=uuid4, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
status: str = Field(default="pending", max_length=20, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
document_id: str | None = Field(default=None, max_length=100)
error_message: str | None = Field(default=None)
retry_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
expires_at: datetime = Field(index=True)
result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
processing_time_ms: float | None = Field(default=None)
visualization_path: str | None = Field(default=None, max_length=255)
class RateLimitEvent(SQLModel, table=True):
"""Rate limit event record."""
__tablename__ = "rate_limit_events"
id: int | None = Field(default=None, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
event_type: str = Field(max_length=50)
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Read-only models for responses (without table=True)
class ApiKeyRead(SQLModel):
"""API key response model (read-only)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestRead(SQLModel):
"""Async request response model (read-only)."""
request_id: UUID
api_key: str
status: str
filename: str
file_size: int
content_type: str
document_id: str | None
error_message: str | None
retry_count: int
created_at: datetime
started_at: datetime | None
completed_at: datetime | None
expires_at: datetime
result: dict[str, Any] | None
processing_time_ms: float | None
visualization_path: str | None

@@ -0,0 +1,5 @@
from .pipeline import InferencePipeline, InferenceResult
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor
__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']

@@ -0,0 +1,101 @@
"""
Inference Configuration Constants
Centralized configuration values for the inference pipeline.
Extracted from hardcoded values across multiple modules for easier maintenance.
"""
# ============================================================================
# Detection & Model Configuration
# ============================================================================
# YOLO Detection
DEFAULT_CONFIDENCE_THRESHOLD = 0.5 # Default confidence threshold for YOLO detection
DEFAULT_IOU_THRESHOLD = 0.45 # Default IoU threshold for NMS (Non-Maximum Suppression)
# ============================================================================
# Image Processing Configuration
# ============================================================================
# DPI (Dots Per Inch) for PDF rendering
DEFAULT_DPI = 300 # Standard DPI for PDF to image conversion
DPI_TO_POINTS_SCALE = 72 # PDF points per inch (used for bbox conversion)
# ============================================================================
# Customer Number Parser Configuration
# ============================================================================
# Pattern confidence scores (higher = more confident)
CUSTOMER_NUMBER_CONFIDENCE = {
'labeled': 0.98, # Explicit label (e.g., "Kundnummer: ABC 123-X")
'dash_format': 0.95, # Standard format with dash (e.g., "JTY 576-3")
'no_dash': 0.90, # Format without dash (e.g., "Dwq 211X")
'compact': 0.75, # Compact format (e.g., "JTY5763")
'generic_base': 0.5, # Base score for generic alphanumeric pattern
}
# Bonus scores for generic pattern matching
CUSTOMER_NUMBER_BONUS = {
'has_dash': 0.2, # Bonus if contains dash
'typical_format': 0.25, # Bonus for format XXX NNN-X
'medium_length': 0.1, # Bonus for length 6-12 characters
}
# Customer number length constraints
CUSTOMER_NUMBER_LENGTH = {
'min': 6, # Minimum length for medium length bonus
'max': 12, # Maximum length for medium length bonus
}
# ============================================================================
# Field Extraction Confidence Scores
# ============================================================================
# Confidence multipliers and base scores
FIELD_CONFIDENCE = {
'pdf_text': 1.0, # PDF text extraction (always accurate)
'payment_line_high': 0.95, # Payment line parsed successfully
'regex_fallback': 0.5, # Regex-based fallback extraction
'ocr_penalty': 0.5, # Penalty multiplier when OCR fails
}
# ============================================================================
# Payment Line Validation
# ============================================================================
# Account number length thresholds for type detection
ACCOUNT_TYPE_THRESHOLD = {
'bankgiro_min_length': 7, # Minimum digits for Bankgiro (7-8 digits)
'plusgiro_max_length': 6, # Maximum digits for Plusgiro (typically fewer)
}
# ============================================================================
# OCR Configuration
# ============================================================================
# Minimum OCR reference number length
MIN_OCR_LENGTH = 5 # Minimum length for valid OCR number
# ============================================================================
# Pattern Matching
# ============================================================================
# Swedish postal code pattern (to exclude from customer numbers)
SWEDISH_POSTAL_CODE_PATTERN = r'^SE\s+\d{3}\s*\d{2}'
# ============================================================================
# Usage Notes
# ============================================================================
"""
These constants can be overridden at runtime by passing parameters to
constructors or methods. The values here serve as sensible defaults
based on Swedish invoice processing requirements.
Example:
from inference.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD
detector = YOLODetector(
model_path="model.pt",
confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD # or custom value
)
"""

@@ -0,0 +1,390 @@
"""
Swedish Customer Number Parser
Handles extraction and normalization of Swedish customer numbers.
Uses Strategy Pattern with multiple matching patterns.
Common Swedish customer number formats:
- JTY 576-3
- EMM 256-6
- DWQ 211-X
- FFL 019N
"""
import re
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List
from shared.exceptions import CustomerNumberParseError
@dataclass
class CustomerNumberMatch:
"""Customer number match result."""
value: str
"""The normalized customer number"""
pattern_name: str
"""Name of the pattern that matched"""
confidence: float
"""Confidence score (0.0 to 1.0)"""
raw_text: str
"""Original text that was matched"""
position: int = 0
"""Position in text where match was found"""
class CustomerNumberPattern(ABC):
"""Abstract base for customer number patterns."""
@abstractmethod
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""
Try to match pattern in text.
Args:
text: Text to search for customer number
Returns:
CustomerNumberMatch if found, None otherwise
"""
pass
@abstractmethod
def format(self, match: re.Match) -> str:
"""
Format matched groups to standard format.
Args:
match: Regex match object
Returns:
Formatted customer number string
"""
pass
class DashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123-X (with dash)
Examples: JTY 576-3, EMM 256-6, DWQ 211-X
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with dash format."""
match = self.PATTERN.search(text)
if not match:
return None
# Check if it's not a postal code
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="DashFormat",
confidence=0.95,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
# SE 106 43, SE10643, etc.
return bool(
text.upper().startswith('SE ') and
re.match(r'^SE\s+\d{3}\s*\d{2}', text, re.IGNORECASE)
)
class NoDashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123X (no dash)
Examples: Dwq 211X, FFL 019N
Converts to: DWQ 211-X, FFL 019-N
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number without dash."""
match = self.PATTERN.search(text)
if not match:
return None
# Exclude postal codes
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="NoDashFormat",
confidence=0.90,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format (add dash)."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))
class CompactFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC123X (compact, no spaces)
Examples: JTY5763, FFL019N
"""
PATTERN = re.compile(r'\b([A-Z]{2,4})(\d{3,6})([A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match compact customer number format."""
upper_text = text.upper()
match = self.PATTERN.search(upper_text)
if not match:
return None
# Filter out SE postal codes
if match.group(1) == 'SE':
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="CompactFormat",
confidence=0.75,
raw_text=match.group(0),
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to ABC123X or ABC123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
if suffix:
return f"{prefix} {number}-{suffix}"
else:
return f"{prefix}{number}"
class GenericAlphanumericPattern(CustomerNumberPattern):
"""
Generic pattern: Letters + numbers + optional dash/letter
Examples: EMM 256-6, ABC 123, FFL 019
"""
PATTERN = re.compile(r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]?\d{0,2}[A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match generic alphanumeric pattern."""
upper_text = text.upper()
all_matches = []
for match in self.PATTERN.finditer(upper_text):
matched_text = match.group(1)
# Filter out pure numbers
if re.match(r'^\d+$', matched_text):
continue
# Filter out Swedish postal codes
if re.match(r'^SE[\s\-]*\d', matched_text):
continue
# Filter out single letter + digit + space + digit (V4 2)
if re.match(r'^[A-Z]\d\s+\d$', matched_text):
continue
# Calculate confidence based on characteristics
confidence = self._calculate_confidence(matched_text)
all_matches.append((confidence, matched_text, match.start()))
if all_matches:
# Return highest confidence match
best = max(all_matches, key=lambda x: x[0])
return CustomerNumberMatch(
value=best[1].strip(),
pattern_name="GenericAlphanumeric",
confidence=best[0],
raw_text=best[1],
position=best[2]
)
return None
def format(self, match: re.Match) -> str:
"""Return matched text as-is (already uppercase)."""
return match.group(1).strip()
def _calculate_confidence(self, text: str) -> float:
"""Calculate confidence score based on text characteristics."""
# Require letters AND digits
has_letters = bool(re.search(r'[A-Z]', text, re.IGNORECASE))
has_digits = bool(re.search(r'\d', text))
if not (has_letters and has_digits):
return 0.0 # Not a valid customer number
score = 0.5 # Base score
# Bonus for containing dash
if '-' in text:
score += 0.2
# Bonus for typical format XXX NNN-X
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', text):
score += 0.25
# Bonus for medium length
if 6 <= len(text) <= 12:
score += 0.1
return min(score, 1.0)
class LabeledPattern(CustomerNumberPattern):
"""
Pattern: Explicit label + customer number
Examples:
- "Kundnummer: JTY 576-3"
- "Customer No: EMM 256-6"
"""
PATTERN = re.compile(
r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)'
r'\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
re.IGNORECASE
)
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with explicit label."""
match = self.PATTERN.search(text)
if not match:
return None
extracted = match.group(1).strip()
# Remove trailing punctuation
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
if extracted and len(extracted) >= 2:
return CustomerNumberMatch(
value=extracted.upper(),
pattern_name="Labeled",
confidence=0.98, # Very high confidence when labeled
raw_text=match.group(0),
position=match.start()
)
return None
def format(self, match: re.Match) -> str:
"""Return matched customer number."""
extracted = match.group(1).strip()
return re.sub(r'[\s\.\,\:]+$', '', extracted).upper()
class CustomerNumberParser:
"""Parser for Swedish customer numbers."""
def __init__(self):
"""Initialize parser with patterns ordered by specificity."""
self.patterns: List[CustomerNumberPattern] = [
LabeledPattern(), # Highest priority - explicit label
DashFormatPattern(), # Standard format with dash
NoDashFormatPattern(), # Standard format without dash
CompactFormatPattern(), # Compact format
GenericAlphanumericPattern(), # Fallback generic pattern
]
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
"""
Parse customer number from text.
Args:
text: Text to search for customer number
Returns:
Tuple of (customer_number, is_valid, error_message)
"""
if not text or not text.strip():
return None, False, "Empty text"
text = text.strip()
# Try each pattern
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# No matches
if not all_matches:
return None, False, "No customer number found"
# Return highest confidence match
best_match = max(all_matches, key=lambda m: (m.confidence, m.position))
self.logger.debug(
f"Customer number matched: {best_match.value} "
f"(pattern: {best_match.pattern_name}, confidence: {best_match.confidence:.2f})"
)
return best_match.value, True, None
def parse_all(self, text: str) -> List[CustomerNumberMatch]:
"""
Find all customer numbers in text.
Useful for cases with multiple potential matches.
Args:
text: Text to search
Returns:
List of CustomerNumberMatch sorted by confidence (descending)
"""
if not text or not text.strip():
return []
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# Sort by confidence (highest first), then by position (later first)
return sorted(all_matches, key=lambda m: (m.confidence, m.position), reverse=True)
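
A quick sketch of the parser on inputs shaped like the formats in the module docstring; the import path and sample texts are invented for illustration.

# Sketch only: import path assumed; sample inputs are invented.
from inference.pipeline.customer_number_parser import CustomerNumberParser  # assumed path

parser = CustomerNumberParser()

value, is_valid, error = parser.parse("Kundnummer: JTY 576-3")
print(value, is_valid)   # JTY 576-3 True

value, is_valid, error = parser.parse("Ref Dwq 211X")
print(value)             # DWQ 211-X (the missing dash is normalized in)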

File diff suppressed because it is too large

@@ -0,0 +1,261 @@
"""
Swedish Payment Line Parser
Handles parsing and validation of Swedish machine-readable payment lines.
Unifies payment line parsing logic that was previously duplicated across multiple modules.
Standard Swedish payment line format:
# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example:
# 94228110015950070 # 15658 00 8 > 48666036#14#
This parser handles common OCR errors:
- Spaces in numbers: "12 0 0" → "1200"
- Missing symbols: Missing ">"
- Spaces in check digits: "#41 #" → "#41#"
"""
import re
import logging
from dataclasses import dataclass
from typing import Optional
from shared.exceptions import PaymentLineParseError
@dataclass
class PaymentLineData:
"""Parsed payment line data."""
ocr_number: str
"""OCR reference number (payment reference)"""
amount: Optional[str] = None
"""Amount in format KRONOR.ÖRE (e.g., '1200.00'), None if not present"""
account_number: Optional[str] = None
"""Bankgiro or Plusgiro account number"""
record_type: Optional[str] = None
"""Record type digit (usually '5' or '8' or '9')"""
check_digits: Optional[str] = None
"""Check digits for account validation"""
raw_text: str = ""
"""Original raw text that was parsed"""
is_valid: bool = True
"""Whether parsing was successful"""
error: Optional[str] = None
"""Error message if parsing failed"""
parse_method: str = "unknown"
"""Which parsing pattern was used (for debugging)"""
class PaymentLineParser:
"""Parser for Swedish payment lines with OCR error handling."""
# Pattern with amount: # OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#
FULL_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Pattern without amount: # OCR # > ACCOUNT#CHECK#
NO_AMOUNT_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Alternative pattern: look for OCR > ACCOUNT# pattern
ALT_PATTERN = re.compile(
r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Account only pattern: > ACCOUNT#CHECK#
ACCOUNT_ONLY_PATTERN = re.compile(
r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
def __init__(self):
"""Initialize parser with logger."""
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> PaymentLineData:
"""
Parse payment line text.
Handles common OCR errors:
- Spaces in numbers: "12 0 0" → "1200"
- Missing symbols: Missing ">"
- Spaces in check digits: "#41 #" → "#41#"
Args:
text: Raw payment line text from OCR
Returns:
PaymentLineData with parsed fields or error information
"""
if not text or not text.strip():
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="Empty payment line text",
parse_method="none"
)
text = text.strip()
# Try full pattern with amount
match = self.FULL_PATTERN.search(text)
if match:
return self._parse_full_match(match, text)
# Try pattern without amount
match = self.NO_AMOUNT_PATTERN.search(text)
if match:
return self._parse_no_amount_match(match, text)
# Try alternative pattern
match = self.ALT_PATTERN.search(text)
if match:
return self._parse_alt_match(match, text)
# Try account only pattern
match = self.ACCOUNT_ONLY_PATTERN.search(text)
if match:
return self._parse_account_only_match(match, text)
# No match - return error
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="No valid payment line format found",
parse_method="none"
)
def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse full pattern match (with amount)."""
ocr = self._clean_digits(match.group(1))
kronor = self._clean_digits(match.group(2))
ore = match.group(3)
record_type = match.group(4)
account = self._clean_digits(match.group(5))
check_digits = match.group(6)
amount = f"{kronor}.{ore}"
return PaymentLineData(
ocr_number=ocr,
amount=amount,
account_number=account,
record_type=record_type,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="full"
)
def _parse_no_amount_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse pattern match without amount."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="no_amount"
)
def _parse_alt_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse alternative pattern match."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="alternative"
)
def _parse_account_only_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse account-only pattern match."""
account = self._clean_digits(match.group(1))
check_digits = match.group(2)
return PaymentLineData(
ocr_number="",
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error="Partial payment line (account only)",
parse_method="account_only"
)
def _clean_digits(self, text: str) -> str:
"""Remove spaces from digit string (OCR error correction)."""
return text.replace(' ', '')
def format_machine_readable(self, data: PaymentLineData) -> str:
"""
Format parsed data back to machine-readable format.
Returns:
Formatted string in standard Swedish payment line format
"""
if not data.is_valid:
return data.raw_text
# Full format with amount
if data.amount and data.record_type:
kronor, ore = data.amount.split('.')
return (
f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
f"{data.account_number}#{data.check_digits}#"
)
# Format without amount
if data.ocr_number and data.account_number:
return f"# {data.ocr_number} # > {data.account_number}#{data.check_digits}#"
# Account only
if data.account_number:
return f"> {data.account_number}#{data.check_digits}#"
# Fallback
return data.raw_text
def format_for_field_extractor(self, data: PaymentLineData) -> tuple[Optional[str], bool, Optional[str]]:
"""
Format parsed data for FieldExtractor compatibility.
Returns:
Tuple of (formatted_text, is_valid, error_message) matching FieldExtractor's API
"""
if not data.is_valid:
return None, False, data.error
formatted = self.format_machine_readable(data)
return formatted, True, data.error
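# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of parsing a machine-readable payment line and formatting it
# back. The sample string is the documented example format used elsewhere in the
# pipeline; no other values are assumed.
if __name__ == "__main__":
    parser = PaymentLineParser()
    sample = "# 11000770600242 # 1200 00 5 > 3082963#41#"
    data = parser.parse(sample)
    print(data.ocr_number)      # "11000770600242"
    print(data.amount)          # "1200.00"
    print(data.account_number)  # "3082963"
    print(parser.format_machine_readable(data))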

View File

@@ -0,0 +1,498 @@
"""
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import time
import re
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
@dataclass
class CrossValidationResult:
"""Result of cross-validation between payment_line and other fields."""
is_valid: bool = False
ocr_match: bool | None = None # None if not comparable
amount_match: bool | None = None
bankgiro_match: bool | None = None
plusgiro_match: bool | None = None
payment_line_ocr: str | None = None
payment_line_amount: str | None = None
payment_line_account: str | None = None
payment_line_account_type: str | None = None # 'bankgiro' or 'plusgiro'
details: list[str] = field(default_factory=list)
@dataclass
class InferenceResult:
"""Result of invoice processing."""
document_id: str | None = None
success: bool = False
fields: dict[str, Any] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict) # Field bboxes in pixels
raw_detections: list[Detection] = field(default_factory=list)
extracted_fields: list[ExtractedField] = field(default_factory=list)
processing_time_ms: float = 0.0
errors: list[str] = field(default_factory=list)
fallback_used: bool = False
cross_validation: CrossValidationResult | None = None
def to_json(self) -> dict:
"""Convert to JSON-serializable dictionary."""
result = {
'DocumentId': self.document_id,
'InvoiceNumber': self.fields.get('InvoiceNumber'),
'InvoiceDate': self.fields.get('InvoiceDate'),
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
'OCR': self.fields.get('OCR'),
'Bankgiro': self.fields.get('Bankgiro'),
'Plusgiro': self.fields.get('Plusgiro'),
'Amount': self.fields.get('Amount'),
'supplier_org_number': self.fields.get('supplier_org_number'),
'customer_number': self.fields.get('customer_number'),
'payment_line': self.fields.get('payment_line'),
'confidence': self.confidence,
'success': self.success,
'fallback_used': self.fallback_used
}
# Add bboxes if present
if self.bboxes:
result['bboxes'] = {k: list(v) for k, v in self.bboxes.items()}
# Add cross-validation results if present
if self.cross_validation:
result['cross_validation'] = {
'is_valid': self.cross_validation.is_valid,
'ocr_match': self.cross_validation.ocr_match,
'amount_match': self.cross_validation.amount_match,
'bankgiro_match': self.cross_validation.bankgiro_match,
'plusgiro_match': self.cross_validation.plusgiro_match,
'payment_line_ocr': self.cross_validation.payment_line_ocr,
'payment_line_amount': self.cross_validation.payment_line_amount,
'payment_line_account': self.cross_validation.payment_line_account,
'payment_line_account_type': self.cross_validation.payment_line_account_type,
'details': self.cross_validation.details,
}
return result
def get_field(self, field_name: str) -> tuple[Any, float]:
"""Get field value and confidence."""
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
class InferencePipeline:
"""
Complete inference pipeline for invoice data extraction.
Pipeline flow:
1. PDF -> Image rendering
2. YOLO detection of field regions
3. OCR extraction from detected regions
4. Field normalization and validation
5. Fallback to full-page OCR if YOLO fails
"""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
ocr_lang: str = 'en',
use_gpu: bool = False,
dpi: int = 300,
enable_fallback: bool = True
):
"""
Initialize inference pipeline.
Args:
model_path: Path to trained YOLO model
confidence_threshold: Detection confidence threshold
ocr_lang: Language for OCR
use_gpu: Whether to use GPU
dpi: Resolution for PDF rendering
enable_fallback: Enable fallback to full-page OCR
"""
self.detector = YOLODetector(
model_path,
confidence_threshold=confidence_threshold,
device='cuda' if use_gpu else 'cpu'
)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi
self.enable_fallback = enable_fallback
def process_pdf(
self,
pdf_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a PDF and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(pdf_path).stem
)
try:
all_detections = []
all_extracted = []
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
# Convert to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Run YOLO detection
detections = self.detector.detect(image_array, page_no=page_no)
all_detections.extend(detections)
# Extract fields from detections
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
all_extracted.append(extracted)
result.raw_detections = all_detections
result.extracted_fields = all_extracted
# Merge extracted fields (prefer highest confidence)
self._merge_fields(result)
# Fallback if key fields are missing
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field."""
field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields:
if not extracted.is_valid or not extracted.normalized_value:
continue
if extracted.field_name not in field_candidates:
field_candidates[extracted.field_name] = []
field_candidates[extracted.field_name].append(extracted)
# Select best candidate for each field
for field_name, candidates in field_candidates.items():
best = max(candidates, key=lambda x: x.confidence)
result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence
# Store bbox for each field (useful for payment_line and other fields)
result.bboxes[field_name] = best.bbox
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
Returns: (ocr, amount, account) tuple
"""
parsed = self.payment_line_parser.parse(payment_line)
if not parsed.is_valid:
return None, None, None
return parsed.ocr_number, parsed.amount, parsed.account_number
def _cross_validate_payment_line(self, result: InferenceResult) -> None:
"""
Cross-validate payment_line data against other detected fields.
Payment line values take PRIORITY over individually detected fields.
Swedish payment line (Betalningsrad) contains:
- OCR reference number
- Amount (kronor and öre)
- Bankgiro or Plusgiro account number
This method:
1. Parses payment_line to extract OCR, Amount, Account
2. Compares with separately detected fields for validation
3. OVERWRITES the detected OCR and Amount with payment_line values (payment_line is authoritative for those fields; Bankgiro/Plusgiro are compared only, never overridden)
"""
payment_line = result.fields.get('payment_line')
if not payment_line:
return
cv = CrossValidationResult()
cv.details = []
# Parse machine-readable payment line format
ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
cv.payment_line_ocr = ocr
cv.payment_line_amount = amount
# Determine account type based on digit count
if account:
# Bankgiro: 7-8 digits, Plusgiro: typically fewer
if len(account) >= 7:
cv.payment_line_account_type = 'bankgiro'
# Format: XXX-XXXX or XXXX-XXXX
if len(account) == 7:
cv.payment_line_account = f"{account[:3]}-{account[3:]}"
else:
cv.payment_line_account = f"{account[:4]}-{account[4:]}"
else:
cv.payment_line_account_type = 'plusgiro'
# Format: XXXXXXX-X
cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
# Cross-validate and OVERRIDE with payment_line values
# OCR: payment_line takes priority
detected_ocr = result.fields.get('OCR')
if cv.payment_line_ocr:
pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
if detected_ocr:
detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
cv.ocr_match = pl_ocr_digits == detected_ocr_digits
if cv.ocr_match:
cv.details.append(f"OCR match: {cv.payment_line_ocr}")
else:
cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
else:
cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
# OVERRIDE: use payment_line OCR
result.fields['OCR'] = cv.payment_line_ocr
result.confidence['OCR'] = 0.95 # High confidence for payment_line
# Amount: payment_line takes priority
detected_amount = result.fields.get('Amount')
if cv.payment_line_amount:
if detected_amount:
pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
det_amount = self._normalize_amount_for_compare(str(detected_amount))
cv.amount_match = pl_amount == det_amount
if cv.amount_match:
cv.details.append(f"Amount match: {cv.payment_line_amount}")
else:
cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
else:
cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
# OVERRIDE: use payment_line Amount
result.fields['Amount'] = cv.payment_line_amount
result.confidence['Amount'] = 0.95
# Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
detected_bankgiro = result.fields.get('Bankgiro')
if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
if detected_bankgiro:
det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
cv.bankgiro_match = pl_bg_digits == det_bg_digits
if cv.bankgiro_match:
cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
else:
cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
# Do NOT override - keep detected value
# Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
detected_plusgiro = result.fields.get('Plusgiro')
if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
if detected_plusgiro:
det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
cv.plusgiro_match = pl_pg_digits == det_pg_digits
if cv.plusgiro_match:
cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
else:
cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
# Do NOT override - keep detected value
# Determine overall validity
# Note: payment_line only contains ONE account (either BG or PG), so when invoice
# has both accounts, the other one cannot be matched - this is expected and OK.
# Only count the account type that payment_line actually has.
matches = [cv.ocr_match, cv.amount_match]
# Only include account match if payment_line has that account type
if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
matches.append(cv.bankgiro_match)
elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
matches.append(cv.plusgiro_match)
valid_matches = [m for m in matches if m is not None]
if valid_matches:
match_count = sum(1 for m in valid_matches if m)
cv.is_valid = match_count >= min(2, len(valid_matches))
cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
else:
# No comparison possible
cv.is_valid = True
cv.details.append("No comparison available from payment_line")
result.cross_validation = cv
def _normalize_amount_for_compare(self, amount: str) -> float | None:
"""Normalize amount string to float for comparison."""
try:
# Remove spaces, convert comma to dot
cleaned = amount.replace(' ', '').replace(',', '.')
# Handle Swedish format with space as thousands separator
cleaned = re.sub(r'(\d)\s+(\d)', r'\1\2', cleaned)
return round(float(cleaned), 2)
except (ValueError, AttributeError):
return None
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
# Check for key fields
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
missing = sum(1 for f in key_fields if f not in result.fields)
return missing >= 2 # Fallback if 2+ key fields missing
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
from shared.pdf.renderer import render_pdf_to_images
from shared.ocr import OCREngine
from PIL import Image
import io
import numpy as np
result.fallback_used = True
ocr_engine = OCREngine()
try:
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Full page OCR
tokens = ocr_engine.extract_from_image(image_array, page_no)
full_text = ' '.join(t.text for t in tokens)
# Try to extract missing fields with regex patterns
self._extract_with_patterns(full_text, result)
except Exception as e:
result.errors.append(f"Fallback OCR error: {e}")
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
],
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
}
for field_name, field_patterns in patterns.items():
if field_name in result.fields:
continue
for pattern in field_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
value = match.group(1).strip()
# Normalize the value
if field_name == 'Amount':
value = value.replace(' ', '').replace(',', '.')
try:
value = f"{float(value):.2f}"
except ValueError:
continue
elif field_name == 'Bankgiro':
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex
break
def process_image(
self,
image_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a single image (for pre-rendered pages).
Args:
image_path: Path to image file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from PIL import Image
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(image_path).stem
)
try:
image = Image.open(image_path)
image_array = np.array(image)
# Run detection
detections = self.detector.detect(image_array, page_no=0)
result.raw_detections = detections
# Extract fields
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
result.extracted_fields.append(extracted)
# Merge fields
self._merge_fields(result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
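# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal sketch of running the pipeline end to end and dumping the JSON
# result. The model and PDF paths are hypothetical placeholders; the keyword
# arguments mirror the defaults documented in __init__.
if __name__ == "__main__":
    import json

    pipeline = InferencePipeline(
        model_path="models/invoice_fields.pt",  # hypothetical path
        confidence_threshold=0.5,
        use_gpu=False,
    )
    result = pipeline.process_pdf("sample_invoice.pdf")  # hypothetical file
    print(json.dumps(result.to_json(), indent=2, ensure_ascii=False))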

View File

@@ -0,0 +1,210 @@
"""
YOLO Detection Module
Runs YOLO model inference for field detection.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
@dataclass
class Detection:
"""Represents a single YOLO detection."""
class_id: int
class_name: str
confidence: float
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels
page_no: int = 0
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
@property
def width(self) -> float:
return self.x1 - self.x0
@property
def height(self) -> float:
return self.y1 - self.y0
def get_padded_bbox(
self,
padding: float = 0.1,
image_width: float | None = None,
image_height: float | None = None
) -> tuple[float, float, float, float]:
"""Get bbox with padding for OCR extraction."""
pad_x = self.width * padding
pad_y = self.height * padding
x0 = self.x0 - pad_x
y0 = self.y0 - pad_y
x1 = self.x1 + pad_x
y1 = self.y1 + pad_y
if image_width:
x0 = max(0, x0)
x1 = min(image_width, x1)
if image_height:
y0 = max(0, y0)
y1 = min(image_height, y1)
return (x0, y0, x1, y1)
# Class names (must match training configuration)
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number', # Matches training class name
'customer_number',
'payment_line', # Machine code payment line at bottom of invoice
]
# Mapping from class name to field name
CLASS_TO_FIELD = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'ocr_number': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'amount': 'Amount',
'supplier_org_number': 'supplier_org_number',
'customer_number': 'customer_number',
'payment_line': 'payment_line',
}
class YOLODetector:
"""YOLO model wrapper for field detection."""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
iou_threshold: float = 0.45,
device: str = 'auto'
):
"""
Initialize YOLO detector.
Args:
model_path: Path to trained YOLO model (.pt file)
confidence_threshold: Minimum confidence for detections
iou_threshold: IOU threshold for NMS
device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
"""
from ultralytics import YOLO
self.model = YOLO(model_path)
self.confidence_threshold = confidence_threshold
self.iou_threshold = iou_threshold
self.device = device
def detect(
self,
image: str | Path | np.ndarray,
page_no: int = 0
) -> list[Detection]:
"""
Run detection on an image.
Args:
image: Image path or numpy array
page_no: Page number for reference
Returns:
List of Detection objects
"""
results = self.model.predict(
source=image,
conf=self.confidence_threshold,
iou=self.iou_threshold,
device=self.device,
verbose=False
)
detections = []
for result in results:
boxes = result.boxes
if boxes is None:
continue
for i in range(len(boxes)):
class_id = int(boxes.cls[i])
confidence = float(boxes.conf[i])
bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1]
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"
detections.append(Detection(
class_id=class_id,
class_name=class_name,
confidence=confidence,
bbox=tuple(bbox),
page_no=page_no
))
return detections
def detect_pdf(
self,
pdf_path: str | Path,
dpi: int = 300
) -> dict[int, list[Detection]]:
"""
Run detection on all pages of a PDF.
Args:
pdf_path: Path to PDF file
dpi: Resolution for rendering
Returns:
Dict mapping page number to list of detections
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
results = {}
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
# Convert bytes to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
detections = self.detect(image_array, page_no=page_no)
results[page_no] = detections
return results
def get_field_name(self, class_name: str) -> str:
"""Convert class name to field name."""
return CLASS_TO_FIELD.get(class_name, class_name)
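# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal sketch of running detection on every page of a rendered PDF and
# mapping class names to field names. The model and PDF paths are hypothetical.
if __name__ == "__main__":
    detector = YOLODetector("models/invoice_fields.pt", confidence_threshold=0.5)
    for page_no, detections in detector.detect_pdf("sample_invoice.pdf", dpi=300).items():
        for det in detections:
            field = detector.get_field_name(det.class_name)
            print(f"page {page_no}: {field} conf={det.confidence:.2f} bbox={det.bbox}")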

View File

@@ -0,0 +1,7 @@
"""
Cross-validation module for verifying field extraction using LLM.
"""
from .llm_validator import LLMValidator
__all__ = ['LLMValidator']

View File

@@ -0,0 +1,748 @@
"""
LLM-based cross-validation for invoice field extraction.
Uses a vision LLM to extract fields from invoice PDFs and compare with
the autolabel results to identify potential errors.
"""
import json
import base64
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, asdict
from datetime import datetime
import psycopg2
from psycopg2.extras import execute_values
from shared.config import DEFAULT_DPI
@dataclass
class LLMExtractionResult:
"""Result of LLM field extraction."""
document_id: str
invoice_number: Optional[str] = None
invoice_date: Optional[str] = None
invoice_due_date: Optional[str] = None
ocr_number: Optional[str] = None
bankgiro: Optional[str] = None
plusgiro: Optional[str] = None
amount: Optional[str] = None
supplier_organisation_number: Optional[str] = None
raw_response: Optional[str] = None
model_used: Optional[str] = None
processing_time_ms: Optional[float] = None
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
class LLMValidator:
"""
Cross-validates invoice field extraction using LLM.
Queries documents with failed field matches from the database,
sends the PDF images to an LLM for extraction, and stores
the results for comparison.
"""
# Fields to extract (excluding customer_number as requested)
FIELDS_TO_EXTRACT = [
'InvoiceNumber',
'InvoiceDate',
'InvoiceDueDate',
'OCR',
'Bankgiro',
'Plusgiro',
'Amount',
'supplier_organisation_number',
]
EXTRACTION_PROMPT = """You are an expert at extracting structured data from Swedish invoices.
Analyze this invoice image and extract the following fields. Return ONLY a valid JSON object with these exact keys:
{
"invoice_number": "the invoice number/fakturanummer",
"invoice_date": "the invoice date in YYYY-MM-DD format",
"invoice_due_date": "the due date/förfallodatum in YYYY-MM-DD format",
"ocr_number": "the OCR payment reference number",
"bankgiro": "the bankgiro number (format: XXXX-XXXX or XXXXXXXX)",
"plusgiro": "the plusgiro number",
"amount": "the total amount to pay (just the number, e.g., 1234.56)",
"supplier_organisation_number": "the supplier's organisation number (format: XXXXXX-XXXX)"
}
Rules:
- If a field is not found or not visible, use null
- For dates, convert Swedish month names (januari, februari, etc.) to YYYY-MM-DD
- For amounts, extract just the numeric value without currency symbols
- The OCR number is typically a long number used for payment reference
- Look for "Att betala" or "Summa att betala" for the amount
- Organisation number is 10 digits, often shown as XXXXXX-XXXX
Return ONLY the JSON object, no other text."""
def __init__(self, connection_string: Optional[str] = None, pdf_dir: Optional[str] = None):
"""
Initialize the validator.
Args:
connection_string: PostgreSQL connection string
pdf_dir: Directory containing PDF files
"""
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS
self.connection_string = connection_string or get_db_connection_string()
self.pdf_dir = Path(pdf_dir or PATHS['pdf_dir'])
self.conn = None
def connect(self):
"""Connect to database."""
if self.conn is None:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
self.conn = None
def create_validation_table(self):
"""Create the llm_validation table if it doesn't exist."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS llm_validations (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL,
-- Extracted fields
invoice_number TEXT,
invoice_date TEXT,
invoice_due_date TEXT,
ocr_number TEXT,
bankgiro TEXT,
plusgiro TEXT,
amount TEXT,
supplier_organisation_number TEXT,
-- Metadata
raw_response TEXT,
model_used TEXT,
processing_time_ms REAL,
error TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
-- Comparison results (populated later)
comparison_results JSONB,
UNIQUE(document_id)
);
CREATE INDEX IF NOT EXISTS idx_llm_validations_document_id
ON llm_validations(document_id);
""")
conn.commit()
def get_documents_with_failed_matches(
self,
exclude_customer_number: bool = True,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Get documents that have at least one failed field match.
Args:
exclude_customer_number: If True, ignore customer_number failures
limit: Maximum number of documents to return
Returns:
List of document info with failed fields
"""
conn = self.connect()
with conn.cursor() as cursor:
# Find documents with failed matches (excluding customer_number if requested)
exclude_clause = ""
if exclude_customer_number:
exclude_clause = "AND fr.field_name != 'customer_number'"
query = f"""
SELECT DISTINCT d.document_id, d.pdf_path, d.pdf_type,
d.supplier_name, d.split
FROM documents d
JOIN field_results fr ON d.document_id = fr.document_id
WHERE fr.matched = false
AND fr.field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
AND d.document_id NOT IN (
SELECT document_id FROM llm_validations WHERE error IS NULL
)
ORDER BY d.document_id
"""
if limit:
query += f" LIMIT {limit}"
cursor.execute(query)
results = []
for row in cursor.fetchall():
doc_id = row[0]
# Get failed fields for this document
exclude_clause_inner = ""
if exclude_customer_number:
exclude_clause_inner = "AND field_name != 'customer_number'"
cursor.execute(f"""
SELECT field_name, csv_value, score
FROM field_results
WHERE document_id = %s
AND matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause_inner}
""", (doc_id,))
failed_fields = [
{'field': r[0], 'csv_value': r[1], 'score': r[2]}
for r in cursor.fetchall()
]
results.append({
'document_id': doc_id,
'pdf_path': row[1],
'pdf_type': row[2],
'supplier_name': row[3],
'split': row[4],
'failed_fields': failed_fields,
})
return results
def get_failed_match_stats(self, exclude_customer_number: bool = True) -> Dict[str, Any]:
"""Get statistics about failed matches."""
conn = self.connect()
with conn.cursor() as cursor:
exclude_clause = ""
if exclude_customer_number:
exclude_clause = "AND field_name != 'customer_number'"
# Count by field
cursor.execute(f"""
SELECT field_name, COUNT(*) as cnt
FROM field_results
WHERE matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
GROUP BY field_name
ORDER BY cnt DESC
""")
by_field = {row[0]: row[1] for row in cursor.fetchall()}
# Count documents with failures
cursor.execute(f"""
SELECT COUNT(DISTINCT document_id)
FROM field_results
WHERE matched = false
AND field_name NOT LIKE 'supplier_accounts%%'
{exclude_clause}
""")
doc_count = cursor.fetchone()[0]
# Already validated count
cursor.execute("""
SELECT COUNT(*) FROM llm_validations WHERE error IS NULL
""")
validated_count = cursor.fetchone()[0]
return {
'documents_with_failures': doc_count,
'already_validated': validated_count,
'remaining': doc_count - validated_count,
'failures_by_field': by_field,
}
def render_pdf_to_image(
self,
pdf_path: Path,
page_no: int = 0,
dpi: int = DEFAULT_DPI,
max_size_mb: float = 18.0
) -> bytes:
"""
Render a PDF page to PNG image bytes.
Args:
pdf_path: Path to PDF file
page_no: Page number to render (0-indexed)
dpi: Resolution for rendering
max_size_mb: Maximum image size in MB (Azure OpenAI limit is 20MB)
Returns:
PNG image bytes
"""
import fitz # PyMuPDF
from io import BytesIO
from PIL import Image
doc = fitz.open(pdf_path)
page = doc[page_no]
# Try different DPI values until we get a small enough image
dpi_values = [dpi, 120, 100, 72, 50]
for current_dpi in dpi_values:
mat = fitz.Matrix(current_dpi / 72, current_dpi / 72)
pix = page.get_pixmap(matrix=mat)
png_bytes = pix.tobytes("png")
size_mb = len(png_bytes) / (1024 * 1024)
if size_mb <= max_size_mb:
doc.close()
return png_bytes
# If still too large, use JPEG compression
mat = fitz.Matrix(72 / 72, 72 / 72)  # 72 DPI base before JPEG compression
pix = page.get_pixmap(matrix=mat)
# Convert to PIL Image and compress as JPEG
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
# Try different JPEG quality levels
for quality in [85, 70, 50, 30]:
buffer = BytesIO()
img.save(buffer, format="JPEG", quality=quality)
jpeg_bytes = buffer.getvalue()
size_mb = len(jpeg_bytes) / (1024 * 1024)
if size_mb <= max_size_mb:
doc.close()
return jpeg_bytes
doc.close()
# Return whatever we have, let the API handle the error
return jpeg_bytes
def extract_with_openai(
self,
image_bytes: bytes,
model: str = "gpt-4o"
) -> LLMExtractionResult:
"""
Extract fields using OpenAI's vision API (supports Azure OpenAI).
Args:
image_bytes: PNG image bytes
model: Model to use (gpt-4o, gpt-4o-mini, etc.)
Returns:
Extraction result
"""
import openai
import time
start_time = time.time()
# Encode image to base64 and detect format
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
# Detect image format (PNG starts with \x89PNG, JPEG with \xFF\xD8)
if image_bytes[:4] == b'\x89PNG':
media_type = "image/png"
else:
media_type = "image/jpeg"
# Check for Azure OpenAI configuration
azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', model)
if azure_endpoint and azure_api_key:
# Use Azure OpenAI
client = openai.AzureOpenAI(
azure_endpoint=azure_endpoint,
api_key=azure_api_key,
api_version="2024-02-15-preview"
)
model = azure_deployment # Use deployment name for Azure
else:
# Use standard OpenAI
client = openai.OpenAI()
try:
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": self.EXTRACTION_PROMPT},
{
"type": "image_url",
"image_url": {
"url": f"data:{media_type};base64,{image_b64}",
"detail": "high"
}
}
]
}
],
max_tokens=1000,
temperature=0,
)
raw_response = response.choices[0].message.content
processing_time = (time.time() - start_time) * 1000
# Parse JSON response
# Try to extract JSON from response (may have markdown code blocks)
json_str = raw_response
if "```json" in json_str:
json_str = json_str.split("```json")[1].split("```")[0]
elif "```" in json_str:
json_str = json_str.split("```")[1].split("```")[0]
data = json.loads(json_str.strip())
return LLMExtractionResult(
document_id="", # Will be set by caller
invoice_number=data.get('invoice_number'),
invoice_date=data.get('invoice_date'),
invoice_due_date=data.get('invoice_due_date'),
ocr_number=data.get('ocr_number'),
bankgiro=data.get('bankgiro'),
plusgiro=data.get('plusgiro'),
amount=data.get('amount'),
supplier_organisation_number=data.get('supplier_organisation_number'),
raw_response=raw_response,
model_used=model,
processing_time_ms=processing_time,
)
except json.JSONDecodeError as e:
return LLMExtractionResult(
document_id="",
raw_response=raw_response if 'raw_response' in locals() else None,
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=f"JSON parse error: {str(e)}"
)
except Exception as e:
return LLMExtractionResult(
document_id="",
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=str(e)
)
def extract_with_anthropic(
self,
image_bytes: bytes,
model: str = "claude-sonnet-4-20250514"
) -> LLMExtractionResult:
"""
Extract fields using Anthropic's Claude API.
Args:
image_bytes: PNG image bytes
model: Model to use
Returns:
Extraction result
"""
import anthropic
import time
start_time = time.time()
# Encode image to base64
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
client = anthropic.Anthropic()
try:
response = client.messages.create(
model=model,
max_tokens=1000,
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": image_b64,
}
},
{
"type": "text",
"text": self.EXTRACTION_PROMPT
}
]
}
],
)
raw_response = response.content[0].text
processing_time = (time.time() - start_time) * 1000
# Parse JSON response
json_str = raw_response
if "```json" in json_str:
json_str = json_str.split("```json")[1].split("```")[0]
elif "```" in json_str:
json_str = json_str.split("```")[1].split("```")[0]
data = json.loads(json_str.strip())
return LLMExtractionResult(
document_id="",
invoice_number=data.get('invoice_number'),
invoice_date=data.get('invoice_date'),
invoice_due_date=data.get('invoice_due_date'),
ocr_number=data.get('ocr_number'),
bankgiro=data.get('bankgiro'),
plusgiro=data.get('plusgiro'),
amount=data.get('amount'),
supplier_organisation_number=data.get('supplier_organisation_number'),
raw_response=raw_response,
model_used=model,
processing_time_ms=processing_time,
)
except json.JSONDecodeError as e:
return LLMExtractionResult(
document_id="",
raw_response=raw_response if 'raw_response' in locals() else None,
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=f"JSON parse error: {str(e)}"
)
except Exception as e:
return LLMExtractionResult(
document_id="",
model_used=model,
processing_time_ms=(time.time() - start_time) * 1000,
error=str(e)
)
def save_validation_result(self, result: LLMExtractionResult):
"""Save extraction result to database."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
INSERT INTO llm_validations (
document_id, invoice_number, invoice_date, invoice_due_date,
ocr_number, bankgiro, plusgiro, amount,
supplier_organisation_number, raw_response, model_used,
processing_time_ms, error
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (document_id) DO UPDATE SET
invoice_number = EXCLUDED.invoice_number,
invoice_date = EXCLUDED.invoice_date,
invoice_due_date = EXCLUDED.invoice_due_date,
ocr_number = EXCLUDED.ocr_number,
bankgiro = EXCLUDED.bankgiro,
plusgiro = EXCLUDED.plusgiro,
amount = EXCLUDED.amount,
supplier_organisation_number = EXCLUDED.supplier_organisation_number,
raw_response = EXCLUDED.raw_response,
model_used = EXCLUDED.model_used,
processing_time_ms = EXCLUDED.processing_time_ms,
error = EXCLUDED.error,
created_at = NOW()
""", (
result.document_id,
result.invoice_number,
result.invoice_date,
result.invoice_due_date,
result.ocr_number,
result.bankgiro,
result.plusgiro,
result.amount,
result.supplier_organisation_number,
result.raw_response,
result.model_used,
result.processing_time_ms,
result.error,
))
conn.commit()
def validate_document(
self,
doc_id: str,
provider: str = "openai",
model: str = None
) -> LLMExtractionResult:
"""
Validate a single document using LLM.
Args:
doc_id: Document ID
provider: LLM provider ("openai" or "anthropic")
model: Model to use (defaults based on provider)
Returns:
Extraction result
"""
# Get PDF path
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
if not pdf_path.exists():
return LLMExtractionResult(
document_id=doc_id,
error=f"PDF not found: {pdf_path}"
)
# Render first page
try:
image_bytes = self.render_pdf_to_image(pdf_path, page_no=0)
except Exception as e:
return LLMExtractionResult(
document_id=doc_id,
error=f"Failed to render PDF: {str(e)}"
)
# Extract with LLM
if provider == "openai":
model = model or "gpt-4o"
result = self.extract_with_openai(image_bytes, model)
elif provider == "anthropic":
model = model or "claude-sonnet-4-20250514"
result = self.extract_with_anthropic(image_bytes, model)
else:
return LLMExtractionResult(
document_id=doc_id,
error=f"Unknown provider: {provider}"
)
result.document_id = doc_id
# Save to database
self.save_validation_result(result)
return result
def validate_batch(
self,
limit: int = 10,
provider: str = "openai",
model: str = None,
verbose: bool = True
) -> List[LLMExtractionResult]:
"""
Validate a batch of documents with failed matches.
Args:
limit: Maximum number of documents to validate
provider: LLM provider
model: Model to use
verbose: Print progress
Returns:
List of extraction results
"""
# Get documents to validate
docs = self.get_documents_with_failed_matches(limit=limit)
if verbose:
print(f"Found {len(docs)} documents with failed matches to validate")
results = []
for i, doc in enumerate(docs):
doc_id = doc['document_id']
if verbose:
failed_fields = [f['field'] for f in doc['failed_fields']]
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
result = self.validate_document(doc_id, provider, model)
results.append(result)
if verbose:
if result.error:
print(f" ERROR: {result.error}")
else:
print(f" OK ({result.processing_time_ms:.0f}ms)")
return results
def compare_results(self, doc_id: str) -> Dict[str, Any]:
"""
Compare LLM extraction with autolabel results.
Args:
doc_id: Document ID
Returns:
Comparison results
"""
conn = self.connect()
with conn.cursor() as cursor:
# Get autolabel results
cursor.execute("""
SELECT field_name, csv_value, matched, matched_text
FROM field_results
WHERE document_id = %s
""", (doc_id,))
autolabel = {}
for row in cursor.fetchall():
autolabel[row[0]] = {
'csv_value': row[1],
'matched': row[2],
'matched_text': row[3],
}
# Get LLM results
cursor.execute("""
SELECT invoice_number, invoice_date, invoice_due_date,
ocr_number, bankgiro, plusgiro, amount,
supplier_organisation_number
FROM llm_validations
WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
if not row:
return {'error': 'No LLM validation found'}
llm = {
'InvoiceNumber': row[0],
'InvoiceDate': row[1],
'InvoiceDueDate': row[2],
'OCR': row[3],
'Bankgiro': row[4],
'Plusgiro': row[5],
'Amount': row[6],
'supplier_organisation_number': row[7],
}
# Compare
comparison = {}
for field in self.FIELDS_TO_EXTRACT:
auto = autolabel.get(field, {})
llm_value = llm.get(field)
comparison[field] = {
'csv_value': auto.get('csv_value'),
'autolabel_matched': auto.get('matched'),
'autolabel_text': auto.get('matched_text'),
'llm_value': llm_value,
'agreement': self._values_match(auto.get('csv_value'), llm_value),
}
return comparison
def _values_match(self, csv_value: str, llm_value: str) -> bool:
"""Check if CSV value matches LLM extracted value."""
if csv_value is None or llm_value is None:
return csv_value == llm_value
# Normalize for comparison
csv_norm = str(csv_value).strip().lower().replace('-', '').replace(' ', '')
llm_norm = str(llm_value).strip().lower().replace('-', '').replace(' ', '')
return csv_norm == llm_norm
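# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal sketch of validating a small batch of failed-match documents and
# printing the per-field comparison. Requires the project database plus either
# OPENAI_API_KEY or the Azure OpenAI variables referenced above; the batch size
# of 5 is an arbitrary example value.
if __name__ == "__main__":
    validator = LLMValidator()
    validator.create_validation_table()
    print(validator.get_failed_match_stats())
    for result in validator.validate_batch(limit=5, provider="openai"):
        if not result.error:
            print(result.document_id, validator.compare_results(result.document_id))
    validator.close()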

View File

@@ -0,0 +1,9 @@
"""
Web Application Module
Provides REST API and web interface for invoice field extraction.
"""
from .app import create_app
__all__ = ["create_app"]

View File

@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py
DEPRECATED: Import from inference.web.api.v1.admin.documents instead.
"""
from inference.web.api.v1.admin.documents import *
__all__ = ["create_admin_router"]

View File

@@ -0,0 +1,19 @@
"""
Admin API v1
Document management, annotations, and training endpoints.
"""
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",
"create_training_router",
]

View File

@@ -0,0 +1,644 @@
"""
Admin Annotation API Routes
FastAPI endpoints for annotation management.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.autolabel import get_auto_label_service
from inference.web.schemas.admin import (
AnnotationCreate,
AnnotationItem,
AnnotationListResponse,
AnnotationOverrideRequest,
AnnotationOverrideResponse,
AnnotationResponse,
AnnotationSource,
AnnotationUpdate,
AnnotationVerifyRequest,
AnnotationVerifyResponse,
AutoLabelRequest,
AutoLabelResponse,
BoundingBox,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
# Image storage directory
ADMIN_IMAGES_DIR = Path("data/admin_images")
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_annotation_router() -> APIRouter:
"""Create annotation API router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])
# =========================================================================
# Image Endpoints
# =========================================================================
@router.get(
"/{document_id}/images/{page_number}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Get page image",
description="Get the image for a specific page.",
)
async def get_page_image(
document_id: str,
page_number: int,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> FileResponse:
"""Get page image."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate page number
if page_number < 1 or page_number > document.page_count:
raise HTTPException(
status_code=404,
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
)
# Find image file
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
if not image_path.exists():
raise HTTPException(
status_code=404,
detail=f"Image for page {page_number} not found",
)
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
)
# =========================================================================
# Annotation Endpoints
# =========================================================================
@router.get(
"/{document_id}/annotations",
response_model=AnnotationListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="List annotations",
description="Get all annotations for a document.",
)
async def list_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
page_number: Annotated[
int | None,
Query(ge=1, description="Filter by page number"),
] = None,
) -> AnnotationListResponse:
"""List annotations for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id, page_number)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
return AnnotationListResponse(
document_id=document_id,
page_count=document.page_count,
total_annotations=len(annotations),
annotations=annotations,
)
@router.post(
"/{document_id}/annotations",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Create annotation",
description="Create a new annotation for a document.",
)
async def create_annotation(
document_id: str,
request: AnnotationCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationResponse:
"""Create a new annotation."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate page number
if request.page_number > document.page_count:
raise HTTPException(
status_code=400,
detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
)
# Get image dimensions for normalization
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
if not image_path.exists():
raise HTTPException(
status_code=400,
detail=f"Image for page {request.page_number} not available",
)
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
# Calculate normalized coordinates
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
y_center = (request.bbox.y + request.bbox.height / 2) / image_height
width = request.bbox.width / image_width
height = request.bbox.height / image_height
# Get class name
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
# Create annotation
annotation_id = db.create_annotation(
document_id=document_id,
page_number=request.page_number,
class_id=request.class_id,
class_name=class_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=request.bbox.x,
bbox_y=request.bbox.y,
bbox_width=request.bbox.width,
bbox_height=request.bbox.height,
text_value=request.text_value,
source="manual",
)
# Keep status as pending - user must click "Mark Complete" to finalize
# This allows user to add multiple annotations before saving to PostgreSQL
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation created successfully",
)
@router.patch(
"/{document_id}/annotations/{annotation_id}",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Update annotation",
description="Update an existing annotation.",
)
async def update_annotation(
document_id: str,
annotation_id: str,
request: AnnotationUpdate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationResponse:
"""Update an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Prepare update data
update_kwargs = {}
if request.class_id is not None:
update_kwargs["class_id"] = request.class_id
update_kwargs["class_name"] = FIELD_CLASSES.get(
request.class_id, f"class_{request.class_id}"
)
if request.text_value is not None:
update_kwargs["text_value"] = request.text_value
if request.bbox is not None:
# Get image dimensions
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
# Calculate normalized coordinates
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
update_kwargs["width"] = request.bbox.width / image_width
update_kwargs["height"] = request.bbox.height / image_height
update_kwargs["bbox_x"] = request.bbox.x
update_kwargs["bbox_y"] = request.bbox.y
update_kwargs["bbox_width"] = request.bbox.width
update_kwargs["bbox_height"] = request.bbox.height
# Update annotation
if update_kwargs:
success = db.update_annotation(annotation_id, **update_kwargs)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to update annotation",
)
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation updated successfully",
)
@router.delete(
"/{document_id}/annotations/{annotation_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Delete annotation",
description="Delete an annotation.",
)
async def delete_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Delete annotation
db.delete_annotation(annotation_id)
return {
"status": "deleted",
"annotation_id": annotation_id,
"message": "Annotation deleted successfully",
}
# =========================================================================
# Auto-Labeling Endpoints
# =========================================================================
@router.post(
"/{document_id}/auto-label",
response_model=AutoLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Trigger auto-labeling",
description="Trigger auto-labeling for a document using field values.",
)
async def trigger_auto_label(
document_id: str,
request: AutoLabelRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AutoLabelResponse:
"""Trigger auto-labeling for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate field values
if not request.field_values:
raise HTTPException(
status_code=400,
detail="At least one field value is required",
)
# Run auto-labeling
service = get_auto_label_service()
result = service.auto_label_document(
document_id=document_id,
file_path=document.file_path,
field_values=request.field_values,
db=db,
replace_existing=request.replace_existing,
)
if result["status"] == "failed":
raise HTTPException(
status_code=500,
detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
)
return AutoLabelResponse(
document_id=document_id,
status=result["status"],
annotations_created=result["annotations_created"],
message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
)
@router.delete(
"/{document_id}/annotations",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete all annotations",
description="Delete all annotations for a document (optionally filter by source).",
)
async def delete_all_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
source: Annotated[
str | None,
Query(description="Filter by source (manual, auto, imported)"),
] = None,
) -> dict:
"""Delete all annotations for a document."""
_validate_uuid(document_id, "document_id")
# Validate source
if source and source not in ("manual", "auto", "imported"):
raise HTTPException(
status_code=400,
detail=f"Invalid source: {source}",
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Delete annotations
deleted_count = db.delete_annotations_for_document(document_id, source)
# Update document status if all annotations deleted
remaining = db.get_annotations_for_document(document_id)
if not remaining:
db.update_document_status(document_id, "pending")
return {
"status": "deleted",
"document_id": document_id,
"deleted_count": deleted_count,
"message": f"Deleted {deleted_count} annotations",
}
# =========================================================================
# Phase 5: Annotation Enhancement
# =========================================================================
@router.post(
"/{document_id}/annotations/{annotation_id}/verify",
response_model=AnnotationVerifyResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Verify annotation",
description="Mark an annotation as verified by a human reviewer.",
)
async def verify_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
) -> AnnotationVerifyResponse:
"""Verify an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Verify the annotation
annotation = db.verify_annotation(annotation_id, admin_token)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
return AnnotationVerifyResponse(
annotation_id=annotation_id,
is_verified=annotation.is_verified,
verified_at=annotation.verified_at,
verified_by=annotation.verified_by,
message="Annotation verified successfully",
)
@router.patch(
"/{document_id}/annotations/{annotation_id}/override",
response_model=AnnotationOverrideResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Override annotation",
description="Override an auto-generated annotation with manual corrections.",
)
async def override_annotation(
document_id: str,
annotation_id: str,
request: AnnotationOverrideRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationOverrideResponse:
"""Override an auto-generated annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Build updates dict from request
updates = {}
if request.text_value is not None:
updates["text_value"] = request.text_value
if request.class_id is not None:
updates["class_id"] = request.class_id
# Update class_name if class_id changed
if request.class_id in FIELD_CLASSES:
updates["class_name"] = FIELD_CLASSES[request.class_id]
if request.class_name is not None:
updates["class_name"] = request.class_name
if request.bbox:
# Update bbox fields
if "x" in request.bbox:
updates["bbox_x"] = request.bbox["x"]
if "y" in request.bbox:
updates["bbox_y"] = request.bbox["y"]
if "width" in request.bbox:
updates["bbox_width"] = request.bbox["width"]
if "height" in request.bbox:
updates["bbox_height"] = request.bbox["height"]
if not updates:
raise HTTPException(
status_code=400,
detail="No updates provided. Specify at least one field to update.",
)
# Override the annotation
annotation = db.override_annotation(
annotation_id=annotation_id,
admin_token=admin_token,
change_reason=request.reason,
**updates,
)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Get history to return history_id
history_records = db.get_annotation_history(UUID(annotation_id))
latest_history = history_records[0] if history_records else None
return AnnotationOverrideResponse(
annotation_id=annotation_id,
source=annotation.source,
override_source=annotation.override_source,
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
message="Annotation overridden successfully",
history_id=str(latest_history.history_id) if latest_history else "",
)
return router
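# Illustrative request shapes for the endpoints above (not executed; field names
# follow AnnotationVerifyRequest / AnnotationOverrideRequest, and the /api/v1
# mount prefix is assumed from the image URLs used elsewhere in this package):
#
#   POST  /api/v1/admin/documents/{document_id}/annotations/{annotation_id}/verify
#   {}                                  # optional body; marks the annotation verified
#
#   PATCH /api/v1/admin/documents/{document_id}/annotations/{annotation_id}/override
#   {
#     "text_value": "INV-2024-001",     # made-up value
#     "bbox": {"x": 120, "y": 80, "width": 240, "height": 32},
#     "reason": "OCR mis-read the invoice number"
#   }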

View File

@@ -0,0 +1,82 @@
"""
Admin Auth Routes
FastAPI endpoints for admin token management.
"""
import logging
import secrets
from datetime import datetime, timedelta
from fastapi import APIRouter
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
AdminTokenCreate,
AdminTokenResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def create_auth_router() -> APIRouter:
"""Create admin auth router."""
router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])
@router.post(
"/token",
response_model=AdminTokenResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
},
summary="Create admin token",
description="Create a new admin authentication token.",
)
async def create_token(
request: AdminTokenCreate,
db: AdminDBDep,
) -> AdminTokenResponse:
"""Create a new admin token."""
# Generate secure token
token = secrets.token_urlsafe(32)
# Calculate expiration
expires_at = None
if request.expires_in_days:
expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)
# Create token in database
db.create_admin_token(
token=token,
name=request.name,
expires_at=expires_at,
)
return AdminTokenResponse(
token=token,
name=request.name,
expires_at=expires_at,
message="Admin token created successfully",
)
@router.delete(
"/token",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Revoke admin token",
description="Revoke the current admin token.",
)
async def revoke_token(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Revoke the current admin token."""
db.deactivate_admin_token(admin_token)
return {
"status": "revoked",
"message": "Admin token has been revoked",
}
return router
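# Illustrative usage (not executed): creating and revoking a token, assuming the
# router is mounted under /api/v1. Field names follow AdminTokenCreate above;
# values are made up.
#
#   POST   /api/v1/admin/auth/token     {"name": "ci-pipeline", "expires_in_days": 30}
#   ->     {"token": "<urlsafe token>", "name": "ci-pipeline", "expires_at": "...",
#           "message": "Admin token created successfully"}
#
#   DELETE /api/v1/admin/auth/token     (authenticated with the token being revoked)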

View File

@@ -0,0 +1,551 @@
"""
Admin Document Routes
FastAPI endpoints for admin document management.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
AnnotationItem,
AnnotationSource,
AutoLabelStatus,
BoundingBox,
DocumentDetailResponse,
DocumentItem,
DocumentListResponse,
DocumentStatus,
DocumentStatsResponse,
DocumentUploadResponse,
ModelMetrics,
TrainingHistoryItem,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
# Render at configured DPI for consistency with training
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
pdf_doc.close()
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
"""Create admin documents router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])
# Directories are created by StorageConfig.__post_init__
allowed_extensions = storage_config.allowed_extensions
@router.post(
"",
response_model=DocumentUploadResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Upload document",
description="Upload a PDF or image document for labeling.",
)
async def upload_document(
admin_token: AdminTokenDep,
db: AdminDBDep,
file: UploadFile = File(..., description="PDF or image file"),
auto_label: Annotated[
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Get page count (for PDF)
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record (token only used for auth, not stored)
document_id = db.create_document(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(status_code=500, detail="Failed to save file")
# Update file path in database
from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if doc:
doc.file_path = str(file_path)
session.add(doc)
# Convert PDF to images for annotation
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling if requested
auto_label_started = False
if auto_label:
# Auto-labeling will be triggered by a background task
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
auto_label_started = True
return DocumentUploadResponse(
document_id=document_id,
filename=file.filename,
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@router.get(
"",
response_model=DocumentListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List documents",
description="List all documents for the current admin.",
)
async def list_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
upload_source: Annotated[
str | None,
Query(description="Filter by upload source (ui or api)"),
] = None,
has_annotations: Annotated[
bool | None,
Query(description="Filter by annotation presence"),
] = None,
auto_label_status: Annotated[
str | None,
Query(description="Filter by auto-label status"),
] = None,
batch_id: Annotated[
str | None,
Query(description="Filter by batch ID"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> DocumentListResponse:
"""List documents."""
# Validate status
if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Validate upload_source
if upload_source and upload_source not in ("ui", "api"):
raise HTTPException(
status_code=400,
detail=f"Invalid upload_source: {upload_source}",
)
# Validate auto_label_status
if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid auto_label_status: {auto_label_status}",
)
documents, total = db.get_documents_by_token(
admin_token=admin_token,
status=status,
upload_source=upload_source,
has_annotations=has_annotations,
auto_label_status=auto_label_status,
batch_id=batch_id,
limit=limit,
offset=offset,
)
# Get annotation counts and build items
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
# Determine if document can be annotated (not locked)
can_annotate = True
if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
from datetime import datetime, timezone
can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)
items.append(
DocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
file_size=doc.file_size,
page_count=doc.page_count,
status=DocumentStatus(doc.status),
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
)
)
return DocumentListResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/stats",
response_model=DocumentStatsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get document statistics",
description="Get document count by status.",
)
async def get_document_stats(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentStatsResponse:
"""Get document statistics."""
counts = db.count_documents_by_status(admin_token)
return DocumentStatsResponse(
total=sum(counts.values()),
pending=counts.get("pending", 0),
auto_labeling=counts.get("auto_labeling", 0),
labeled=counts.get("labeled", 0),
exported=counts.get("exported", 0),
)
@router.get(
"/{document_id}",
response_model=DocumentDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Get document detail",
description="Get document details with annotations.",
)
async def get_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentDetailResponse:
"""Get document details."""
_validate_uuid(document_id, "document_id")
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
# Generate image URLs
image_urls = []
for page in range(1, document.page_count + 1):
image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")
# Determine if document can be annotated (not locked)
can_annotate = True
annotation_lock_until = None
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
from datetime import datetime, timezone
annotation_lock_until = document.annotation_lock_until
can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)
# Get CSV field values if available
csv_field_values = None
if hasattr(document, 'csv_field_values') and document.csv_field_values:
csv_field_values = document.csv_field_values
# Get training history (Phase 5)
training_history = []
training_links = db.get_document_training_tasks(document.document_id)
for link in training_links:
# Get task details
task = db.get_training_task(str(link.task_id))
if task:
# Build metrics
metrics = None
if task.metrics_mAP or task.metrics_precision or task.metrics_recall:
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
training_history.append(
TrainingHistoryItem(
task_id=str(link.task_id),
name=task.name,
trained_at=link.created_at,
model_metrics=metrics,
)
)
return DocumentDetailResponse(
document_id=str(document.document_id),
filename=document.filename,
file_size=document.file_size,
content_type=document.content_type,
page_count=document.page_count,
status=DocumentStatus(document.status),
auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
annotations=annotations,
image_urls=image_urls,
training_history=training_history,
created_at=document.created_at,
updated_at=document.updated_at,
)
@router.delete(
"/{document_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete document",
description="Delete a document and its annotations.",
)
async def delete_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Delete file
file_path = Path(document.file_path)
if file_path.exists():
file_path.unlink()
# Delete images
images_dir = storage_config.admin_images_dir / document_id
if images_dir.exists():
import shutil
shutil.rmtree(images_dir)
# Delete from database
db.delete_document(document_id)
return {
"status": "deleted",
"document_id": document_id,
"message": "Document deleted successfully",
}
@router.patch(
"/{document_id}/status",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document status",
description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
)
async def update_document_status(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str,
Query(description="New status"),
],
) -> dict:
"""Update document status.
When status is set to 'labeled', the annotations are automatically
saved to PostgreSQL documents/field_results tables for consistency
with CLI auto-label workflow.
"""
_validate_uuid(document_id, "document_id")
# Validate status
if status not in ("pending", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# If marking as labeled, save annotations to PostgreSQL DocumentDB
db_save_result = None
if status == "labeled":
from inference.web.services.db_autolabel import save_manual_annotations_to_document_db
# Get all annotations for this document
annotations = db.get_annotations_for_document(document_id)
if annotations:
db_save_result = save_manual_annotations_to_document_db(
document=document,
annotations=annotations,
db=db,
)
db.update_document_status(document_id, status)
response = {
"status": "updated",
"document_id": document_id,
"new_status": status,
"message": "Document status updated",
}
# Include PostgreSQL save result if applicable
if db_save_result:
response["document_db_saved"] = db_save_result.get("success", False)
response["fields_saved"] = db_save_result.get("fields_saved", 0)
return response
return router
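# Illustrative upload call (not executed; assumes the router is mounted under
# /api/v1 and that curl is available on the client):
#
#   curl -X POST "http://localhost:8000/api/v1/admin/documents?auto_label=true" \
#        -H "Authorization: Bearer <admin-token>" \
#        -F "file=@invoice.pdf"
#
# The Bearer scheme and host/port are assumptions; authentication is whatever
# AdminTokenDep expects in a given deployment.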

View File

@@ -0,0 +1,184 @@
"""
Admin Document Lock Routes
FastAPI endpoints for annotation lock management.
"""
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
AnnotationLockRequest,
AnnotationLockResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_locks_router() -> APIRouter:
"""Create annotation locks router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])
@router.post(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Document already locked"},
},
summary="Acquire annotation lock",
description="Acquire a lock on a document to prevent concurrent annotation edits.",
)
async def acquire_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Acquire annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to acquire lock
updated_doc = db.acquire_annotation_lock(
document_id=document_id,
admin_token=admin_token,
duration_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Document is already locked. Please try again later.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock acquired for {request.duration_seconds} seconds",
)
@router.delete(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Release annotation lock",
description="Release the annotation lock on a document.",
)
async def release_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
force: Annotated[
bool,
Query(description="Force release (admin override)"),
] = False,
) -> AnnotationLockResponse:
"""Release annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Release lock
updated_doc = db.release_annotation_lock(
document_id=document_id,
admin_token=admin_token,
force=force,
)
if updated_doc is None:
raise HTTPException(
status_code=404,
detail="Failed to release lock",
)
return AnnotationLockResponse(
document_id=document_id,
locked=False,
lock_expires_at=None,
message="Lock released successfully",
)
@router.patch(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
},
summary="Extend annotation lock",
description="Extend an existing annotation lock.",
)
async def extend_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Extend annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to extend lock
updated_doc = db.extend_annotation_lock(
document_id=document_id,
admin_token=admin_token,
additional_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Lock doesn't exist or has expired. Please acquire a new lock.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock extended by {request.duration_seconds} seconds",
)
return router
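# Illustrative lock lifecycle (not executed; paths assume an /api/v1 mount, and
# duration_seconds follows AnnotationLockRequest above):
#
#   POST   /api/v1/admin/documents/{document_id}/lock   {"duration_seconds": 300}
#   PATCH  /api/v1/admin/documents/{document_id}/lock   {"duration_seconds": 300}   # extend
#   DELETE /api/v1/admin/documents/{document_id}/lock?force=false                   # release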

View File

@@ -0,0 +1,28 @@
"""
Admin Training API Routes
FastAPI endpoints for training task management and scheduling.
"""
from fastapi import APIRouter
from ._utils import _validate_uuid
from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
def create_training_router() -> APIRouter:
"""Create training API router."""
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
register_task_routes(router)
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
return router
__all__ = ["create_training_router", "_validate_uuid"]

View File

@@ -0,0 +1,16 @@
"""Shared utilities for training routes."""
from uuid import UUID
from fastapi import HTTPException
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)

View File

@@ -0,0 +1,209 @@
"""Training Dataset Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetDetailResponse,
DatasetDocumentItem,
DatasetListItem,
DatasetListResponse,
DatasetResponse,
DatasetTrainRequest,
TrainingStatus,
TrainingTaskResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_dataset_routes(router: APIRouter) -> None:
"""Register dataset endpoints on the router."""
@router.post(
"/datasets",
response_model=DatasetResponse,
summary="Create training dataset",
description="Create a dataset from selected documents with train/val/test splits.",
)
async def create_dataset(
request: DatasetCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DatasetResponse:
"""Create a training dataset from document IDs."""
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
dataset = db.create_dataset(
name=request.name,
description=request.description,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
)
builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
try:
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=request.document_ids,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
admin_images_dir=Path("data/admin_images"),
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return DatasetResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
status="ready",
message="Dataset created successfully",
)
@router.get(
"/datasets",
response_model=DatasetListResponse,
summary="List datasets",
)
async def list_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
return DatasetListResponse(
total=total,
limit=limit,
offset=offset,
datasets=[
DatasetListItem(
dataset_id=str(d.dataset_id),
name=d.name,
description=d.description,
status=d.status,
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
created_at=d.created_at,
)
for d in datasets
],
)
@router.get(
"/datasets/{dataset_id}",
response_model=DatasetDetailResponse,
summary="Get dataset detail",
)
async def get_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DatasetDetailResponse:
"""Get dataset details with document list."""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
docs = db.get_dataset_documents(dataset_id)
return DatasetDetailResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
description=dataset.description,
status=dataset.status,
train_ratio=dataset.train_ratio,
val_ratio=dataset.val_ratio,
seed=dataset.seed,
total_documents=dataset.total_documents,
total_images=dataset.total_images,
total_annotations=dataset.total_annotations,
dataset_path=dataset.dataset_path,
error_message=dataset.error_message,
documents=[
DatasetDocumentItem(
document_id=str(d.document_id),
split=d.split,
page_count=d.page_count,
annotation_count=d.annotation_count,
)
for d in docs
],
created_at=dataset.created_at,
updated_at=dataset.updated_at,
)
@router.delete(
"/datasets/{dataset_id}",
summary="Delete dataset",
)
async def delete_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a dataset and its files."""
import shutil
from pathlib import Path
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.dataset_path:
dataset_dir = Path(dataset.dataset_path)
if dataset_dir.exists():
shutil.rmtree(dataset_dir)
db.delete_dataset(dataset_id)
return {"message": "Dataset deleted"}
@router.post(
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
)
async def train_from_dataset(
dataset_id: str,
request: DatasetTrainRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset."""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.status != "ready":
raise HTTPException(
status_code=400,
detail=f"Dataset is not ready (status: {dataset.status})",
)
config_dict = request.config.model_dump()
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type="train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message="Training task created from dataset",
)
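# Illustrative payloads (not executed; field names follow DatasetCreateRequest /
# DatasetTrainRequest, while document IDs, ratios, and config contents are made up):
#
#   POST /api/v1/admin/training/datasets
#   {"name": "invoices-v1", "document_ids": ["<uuid>", "<uuid>"],
#    "train_ratio": 0.8, "val_ratio": 0.1, "seed": 42}
#
#   POST /api/v1/admin/training/datasets/{dataset_id}/train
#   {"name": "yolo-run-1", "config": { ... training config ... }}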

View File

@@ -0,0 +1,211 @@
"""Training Documents and Models Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ModelMetrics,
TrainingDocumentItem,
TrainingDocumentsResponse,
TrainingModelItem,
TrainingModelsResponse,
TrainingStatus,
)
from inference.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_document_routes(router: APIRouter) -> None:
"""Register training document and model endpoints on the router."""
@router.get(
"/documents",
response_model=TrainingDocumentsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get documents for training",
description="Get labeled documents available for training with filtering options.",
)
async def get_training_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
has_annotations: Annotated[
bool,
Query(description="Only include documents with annotations"),
] = True,
min_annotation_count: Annotated[
int | None,
Query(ge=1, description="Minimum annotation count"),
] = None,
exclude_used_in_training: Annotated[
bool,
Query(description="Exclude documents already used in training"),
] = False,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingDocumentsResponse:
"""Get documents available for training."""
documents, total = db.get_documents_for_training(
admin_token=admin_token,
status="labeled",
has_annotations=has_annotations,
min_annotation_count=min_annotation_count,
exclude_used_in_training=exclude_used_in_training,
limit=limit,
offset=offset,
)
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
sources = {"manual": 0, "auto": 0}
for ann in annotations:
if ann.source in sources:
sources[ann.source] += 1
training_links = db.get_document_training_tasks(doc.document_id)
used_in_training = [str(link.task_id) for link in training_links]
items.append(
TrainingDocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
annotation_count=len(annotations),
annotation_sources=sources,
used_in_training=used_in_training,
last_modified=doc.updated_at,
)
)
return TrainingDocumentsResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/models/{task_id}/download",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Model not found"},
},
summary="Download trained model",
description="Download trained model weights file.",
)
async def download_model(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
):
"""Download trained model."""
from fastapi.responses import FileResponse
from pathlib import Path
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if not task.model_path:
raise HTTPException(
status_code=404,
detail="Model file not available for this task",
)
model_path = Path(task.model_path)
if not model_path.exists():
raise HTTPException(
status_code=404,
detail="Model file not found on disk",
)
return FileResponse(
path=str(model_path),
media_type="application/octet-stream",
filename=f"{task.name}_model.pt",
)
@router.get(
"/models",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
)
async def get_training_models(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (completed, failed, etc.)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingModelsResponse:
"""Get list of trained models."""
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status if status else "completed",
limit=limit,
offset=offset,
)
items = []
for task in tasks:
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
download_url = None
if task.model_path and task.status == "completed":
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
items.append(
TrainingModelItem(
task_id=str(task.task_id),
name=task.name,
status=TrainingStatus(task.status),
document_count=task.document_count,
created_at=task.created_at,
completed_at=task.completed_at,
metrics=metrics,
model_path=task.model_path,
download_url=download_url,
)
)
return TrainingModelsResponse(
total=total,
limit=limit,
offset=offset,
models=items,
)

View File

@@ -0,0 +1,121 @@
"""Training Export Endpoints."""
import logging
from datetime import datetime
from fastapi import APIRouter, HTTPException
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ExportRequest,
ExportResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def register_export_routes(router: APIRouter) -> None:
"""Register export endpoints on the router."""
@router.post(
"/export",
response_model=ExportResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Export annotations",
description="Export annotations in YOLO format for training.",
)
async def export_annotations(
request: ExportRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ExportResponse:
"""Export annotations for training."""
from pathlib import Path
import shutil
if request.format not in ("yolo", "coco", "voc"):
raise HTTPException(
status_code=400,
detail=f"Unsupported export format: {request.format}",
)
documents = db.get_labeled_documents_for_export(admin_token)
if not documents:
raise HTTPException(
status_code=400,
detail="No labeled documents available for export",
)
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
export_dir.mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
total_docs = len(documents)
train_count = int(total_docs * request.split_ratio)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = db.get_annotations_for_document(str(doc.document_id))
if not annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
if not page_annotations and not request.include_images:
continue
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
continue
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
total_images += 1
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
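# One label line per annotation, in YOLO format:
# "<class_id> <x_center> <y_center> <width> <height>", coordinates normalized to [0, 1].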
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
from inference.data.admin_models import FIELD_CLASSES
yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
(export_dir / "data.yaml").write_text(yaml_content)
return ExportResponse(
status="completed",
export_path=str(export_dir),
total_images=total_images,
total_annotations=total_annotations,
train_count=len(train_docs),
val_count=len(val_docs),
message=f"Exported {total_images} images with {total_annotations} annotations",
)

View File

@@ -0,0 +1,263 @@
"""Training Task Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
TrainingLogItem,
TrainingLogsResponse,
TrainingStatus,
TrainingTaskCreate,
TrainingTaskDetailResponse,
TrainingTaskItem,
TrainingTaskListResponse,
TrainingTaskResponse,
TrainingType,
)
from inference.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_task_routes(router: APIRouter) -> None:
"""Register training task endpoints on the router."""
@router.post(
"/tasks",
response_model=TrainingTaskResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Create training task",
description="Create a new training task.",
)
async def create_training_task(
request: TrainingTaskCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a new training task."""
config_dict = request.config.model_dump() if request.config else {}
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type=request.task_type.value,
description=request.description,
config=config_dict,
scheduled_at=request.scheduled_at,
cron_expression=request.cron_expression,
is_recurring=bool(request.cron_expression),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
message="Training task created successfully",
)
@router.get(
"/tasks",
response_model=TrainingTaskListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List training tasks",
description="List all training tasks.",
)
async def list_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingTaskListResponse:
"""List training tasks."""
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
if status and status not in valid_statuses:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
)
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status,
limit=limit,
offset=offset,
)
items = [
TrainingTaskItem(
task_id=str(task.task_id),
name=task.name,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
scheduled_at=task.scheduled_at,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
created_at=task.created_at,
)
for task in tasks
]
return TrainingTaskListResponse(
total=total,
limit=limit,
offset=offset,
tasks=items,
)
@router.get(
"/tasks/{task_id}",
response_model=TrainingTaskDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training task detail",
description="Get training task details.",
)
async def get_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskDetailResponse:
"""Get training task details."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
return TrainingTaskDetailResponse(
task_id=str(task.task_id),
name=task.name,
description=task.description,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
config=task.config,
scheduled_at=task.scheduled_at,
cron_expression=task.cron_expression,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
error_message=task.error_message,
result_metrics=task.result_metrics,
model_path=task.model_path,
created_at=task.created_at,
)
@router.post(
"/tasks/{task_id}/cancel",
response_model=TrainingTaskResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
},
summary="Cancel training task",
description="Cancel a pending or scheduled training task.",
)
async def cancel_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Cancel a training task."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if task.status not in ("pending", "scheduled"):
raise HTTPException(
status_code=409,
detail=f"Cannot cancel task with status: {task.status}",
)
success = db.cancel_training_task(task_id)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to cancel training task",
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.CANCELLED,
message="Training task cancelled successfully",
)
@router.get(
"/tasks/{task_id}/logs",
response_model=TrainingLogsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training logs",
description="Get training task logs.",
)
async def get_training_logs(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
limit: Annotated[
int,
Query(ge=1, le=500, description="Maximum logs to return"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingLogsResponse:
"""Get training logs."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
logs = db.get_training_logs(task_id, limit, offset)
items = [
TrainingLogItem(
level=log.level,
message=log.message,
details=log.details,
created_at=log.created_at,
)
for log in logs
]
return TrainingLogsResponse(
task_id=task_id,
logs=items,
)
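# Illustrative task creation (not executed; field names follow TrainingTaskCreate,
# while the cron string and config contents are made up):
#
#   POST /api/v1/admin/training/tasks
#   {"name": "nightly-train", "task_type": "train",
#    "cron_expression": "0 2 * * *", "config": { ... }}
#
# A cron_expression marks the task as recurring; scheduled_at alone creates a
# one-off scheduled task.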

View File

@@ -0,0 +1,236 @@
"""
Batch Upload API Routes
Endpoints for batch uploading documents via ZIP files with CSV metadata.
"""
import io
import logging
import zipfile
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse
from inference.data.admin_db import AdminDB
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from inference.web.workers.batch_queue import BatchTask, get_batch_queue
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
@router.post("/upload")
async def upload_batch(
file: UploadFile = File(...),
upload_source: str = Form(default="ui"),
async_mode: bool = Form(default=True),
auto_label: bool = Form(default=True),
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
"""Upload a batch of documents via ZIP file.
The ZIP file can contain:
- Multiple PDF files
- Optional CSV file with field values for auto-labeling
CSV format:
- Required column: DocumentId (matches PDF filename without extension)
- Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number
Args:
file: ZIP file upload
upload_source: Upload source (ui or api)
admin_token: Admin authentication token
admin_db: Admin database interface
Returns:
Batch upload result with batch_id and status
"""
if not file.filename or not file.filename.lower().endswith('.zip'):
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
# Check compressed size
if file.size and file.size > MAX_COMPRESSED_SIZE:
max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
raise HTTPException(
status_code=400,
detail=f"File size exceeds {max_mb:.0f}MB limit"
)
try:
# Read file content
zip_content = await file.read()
# Additional security validation before processing
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
# Quick validation of ZIP structure
test_zip.testzip()
except zipfile.BadZipFile:
raise HTTPException(status_code=400, detail="Invalid ZIP file format")
if async_mode:
# Async mode: Queue task and return immediately
from uuid import uuid4
batch_id = uuid4()
# Create batch task for background processing
task = BatchTask(
batch_id=batch_id,
admin_token=admin_token,
zip_content=zip_content,
zip_filename=file.filename,
upload_source=upload_source,
auto_label=auto_label,
created_at=datetime.utcnow(),
)
# Submit to queue
queue = get_batch_queue()
if not queue.submit(task):
raise HTTPException(
status_code=503,
detail="Processing queue is full. Please try again later."
)
logger.info(
f"Batch upload queued: batch_id={batch_id}, "
f"filename={file.filename}, async_mode=True"
)
# Return 202 Accepted with batch_id and status URL
return JSONResponse(
status_code=202,
content={
"status": "accepted",
"batch_id": str(batch_id),
"message": "Batch upload queued for processing",
"status_url": f"/api/v1/admin/batch/status/{batch_id}",
"queue_depth": queue.get_queue_depth(),
}
)
else:
# Sync mode: Process immediately and return results
service = BatchUploadService(admin_db)
result = service.process_zip_upload(
admin_token=admin_token,
zip_filename=file.filename,
zip_content=zip_content,
upload_source=upload_source,
)
logger.info(
f"Batch upload completed: batch_id={result.get('batch_id')}, "
f"status={result.get('status')}, files={result.get('successful_files')}"
)
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Error processing batch upload: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail="Failed to process batch upload. Please contact support."
)
@router.get("/status/{batch_id}")
async def get_batch_status(
batch_id: str,
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
"""Get batch upload status and file processing details.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
admin_db: Admin database interface
Returns:
Batch status with file processing details
"""
"""
# Validate UUID format
try:
batch_uuid = UUID(batch_id)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid batch ID format")
# Check batch exists and verify ownership
batch = admin_db.get_batch_upload(batch_uuid)
if not batch:
raise HTTPException(status_code=404, detail="Batch not found")
# CRITICAL: Verify ownership
if batch.admin_token != admin_token:
raise HTTPException(
status_code=403,
detail="You do not have access to this batch"
)
# Now safe to return details
service = BatchUploadService(admin_db)
result = service.get_batch_status(batch_id)
return result
@router.get("/list")
async def list_batch_uploads(
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
limit: int = 50,
offset: int = 0,
) -> dict:
"""List batch uploads for the current admin token.
Args:
admin_token: Admin authentication token
admin_db: Admin database interface
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of batch uploads
"""
# Validate pagination parameters
if limit < 1 or limit > 100:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative")
# Get batch uploads filtered by admin token
batches, total = admin_db.get_batch_uploads_by_token(
admin_token=admin_token,
limit=limit,
offset=offset,
)
return {
"batches": [
{
"batch_id": str(b.batch_id),
"filename": b.filename,
"status": b.status,
"total_files": b.total_files,
"successful_files": b.successful_files,
"failed_files": b.failed_files,
"created_at": b.created_at.isoformat() if b.created_at else None,
"completed_at": b.completed_at.isoformat() if b.completed_at else None,
}
for b in batches
],
"total": total,
"limit": limit,
"offset": offset,
}
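# Minimal client-side sketch (illustrative only): upload a ZIP asynchronously and
# poll its status. The `requests` package, the localhost:8000 base URL, and the
# Bearer-style Authorization header are all assumptions, not part of this module.
if __name__ == "__main__":
    import time

    import requests  # hypothetical client dependency

    base = "http://localhost:8000/api/v1/admin/batch"
    headers = {"Authorization": "Bearer <admin-token>"}  # placeholder token

    with open("invoices.zip", "rb") as fh:
        resp = requests.post(
            f"{base}/upload",
            headers=headers,
            files={"file": ("invoices.zip", fh, "application/zip")},
            data={"async_mode": "true", "auto_label": "true"},
        )
    resp.raise_for_status()
    batch_id = resp.json()["batch_id"]

    # Poll the status endpoint a few times; the exact response fields depend on
    # BatchUploadService.get_batch_status and are not assumed here.
    for _ in range(10):
        print(requests.get(f"{base}/status/{batch_id}", headers=headers).json())
        time.sleep(5)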

View File

@@ -0,0 +1,16 @@
"""
Public API v1
Customer-facing endpoints for inference, async processing, and labeling.
"""
from inference.web.api.v1.public.inference import create_inference_router
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.api.v1.public.labeling import create_labeling_router
__all__ = [
"create_inference_router",
"create_async_router",
"set_async_service",
"create_labeling_router",
]

View File

@@ -0,0 +1,372 @@
"""
Async API Routes
FastAPI endpoints for async invoice processing.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from inference.web.dependencies import (
ApiKeyDep,
AsyncDBDep,
PollRateLimitDep,
SubmitRateLimitDep,
)
from inference.web.schemas.inference import (
AsyncRequestItem,
AsyncRequestsListResponse,
AsyncResultResponse,
AsyncStatus,
AsyncStatusResponse,
AsyncSubmitResponse,
DetectionResult,
InferenceResult,
)
from inference.web.schemas.common import ErrorResponse
def _validate_request_id(request_id: str) -> None:
"""Validate that request_id is a valid UUID format."""
try:
UUID(request_id)
except ValueError:
raise HTTPException(
status_code=400,
detail="Invalid request ID format. Must be a valid UUID.",
)
logger = logging.getLogger(__name__)
# Global reference to async processing service (set during app startup)
_async_service = None
def set_async_service(service) -> None:
"""Set the async processing service instance."""
global _async_service
_async_service = service
def get_async_service():
"""Get the async processing service instance."""
if _async_service is None:
raise RuntimeError("AsyncProcessingService not initialized")
return _async_service
def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
"""Create async API router."""
router = APIRouter(prefix="/async", tags=["Async Processing"])
@router.post(
"/submit",
response_model=AsyncSubmitResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid API key"},
429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
503: {"model": ErrorResponse, "description": "Queue full"},
},
summary="Submit PDF for async processing",
description="Submit a PDF or image file for asynchronous processing. "
"Returns a request_id that can be used to poll for results.",
)
async def submit_document(
api_key: SubmitRateLimitDep,
file: UploadFile = File(..., description="PDF or image file to process"),
) -> AsyncSubmitResponse:
"""Submit a document for async processing."""
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate file extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Check file size (get from config via service)
service = get_async_service()
max_size = service._async_config.max_file_size_mb * 1024 * 1024
if len(content) > max_size:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size: "
f"{service._async_config.max_file_size_mb}MB",
)
# Submit request
result = service.submit_request(
api_key=api_key,
file_content=content,
filename=file.filename,
content_type=file.content_type or "application/octet-stream",
)
if not result.success:
if "queue" in (result.error or "").lower():
raise HTTPException(status_code=503, detail=result.error)
raise HTTPException(status_code=500, detail=result.error)
return AsyncSubmitResponse(
status="accepted",
message="Request submitted for processing",
request_id=result.request_id,
estimated_wait_seconds=result.estimated_wait_seconds,
poll_url=f"/api/v1/async/status/{result.request_id}",
)
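# Typical client flow for the async API (illustrative; paths assume an /api/v1 mount):
#   1. POST /api/v1/async/submit with the PDF             -> request_id + poll_url
#   2. GET  /api/v1/async/status/{request_id}             -> status / queue position
#   3. GET  /api/v1/async/result/{request_id} when done   -> extraction results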
@router.get(
"/status/{request_id}",
response_model=AsyncStatusResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get request status",
description="Get the current processing status of an async request.",
)
async def get_status(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncStatusResponse:
"""Get the status of an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Get queue position for pending requests
position = None
if request.status == "pending":
position = db.get_queue_position(request_id)
# Build result URL for completed requests
result_url = None
if request.status == "completed":
result_url = f"/api/v1/async/result/{request_id}"
return AsyncStatusResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
filename=request.filename,
created_at=request.created_at,
started_at=request.started_at,
completed_at=request.completed_at,
position_in_queue=position,
error_message=request.error_message,
result_url=result_url,
)
@router.get(
"/result/{request_id}",
response_model=AsyncResultResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Request not completed"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get extraction results",
description="Get the extraction results for a completed async request.",
)
async def get_result(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncResultResponse:
"""Get the results of a completed async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Check if completed or failed
if request.status not in ("completed", "failed"):
raise HTTPException(
status_code=409,
detail=f"Request not yet completed. Current status: {request.status}",
)
# Build inference result from stored data
inference_result = None
if request.result:
# Convert detections to DetectionResult objects
detections = []
for d in request.result.get("detections", []):
detections.append(DetectionResult(
field=d.get("field", ""),
confidence=d.get("confidence", 0.0),
bbox=d.get("bbox", [0, 0, 0, 0]),
))
inference_result = InferenceResult(
document_id=request.result.get("document_id", str(request.request_id)[:8]),
success=request.result.get("success", False),
document_type=request.result.get("document_type", "invoice"),
fields=request.result.get("fields", {}),
confidence=request.result.get("confidence", {}),
detections=detections,
processing_time_ms=request.processing_time_ms or 0.0,
errors=request.result.get("errors", []),
)
# Build visualization URL
viz_url = None
if request.visualization_path:
viz_url = f"/api/v1/results/{request.visualization_path}"
return AsyncResultResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
processing_time_ms=request.processing_time_ms or 0.0,
result=inference_result,
visualization_url=viz_url,
)
@router.get(
"/requests",
response_model=AsyncRequestsListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
},
summary="List requests",
description="List all async requests for the authenticated API key.",
)
async def list_requests(
api_key: ApiKeyDep,
db: AsyncDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (pending, processing, completed, failed)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Maximum number of results"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Pagination offset"),
] = 0,
) -> AsyncRequestsListResponse:
"""List all requests for the authenticated API key."""
# Validate status filter
if status and status not in ("pending", "processing", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid status filter: {status}. "
"Must be one of: pending, processing, completed, failed",
)
# Get requests from database
requests, total = db.get_requests_by_api_key(
api_key=api_key,
status=status,
limit=limit,
offset=offset,
)
# Convert to response items
items = [
AsyncRequestItem(
request_id=str(r.request_id),
status=AsyncStatus(r.status),
filename=r.filename,
file_size=r.file_size,
created_at=r.created_at,
completed_at=r.completed_at,
)
for r in requests
]
return AsyncRequestsListResponse(
total=total,
limit=limit,
offset=offset,
requests=items,
)
@router.delete(
"/requests/{request_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
},
summary="Cancel/delete request",
description="Cancel a pending request or delete a completed/failed request.",
)
async def delete_request(
request_id: str,
api_key: ApiKeyDep,
db: AsyncDBDep,
) -> dict:
"""Delete or cancel an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Cannot delete processing requests
if request.status == "processing":
raise HTTPException(
status_code=409,
detail="Cannot delete a request that is currently processing",
)
# Delete from database (will cascade delete related records)
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM async_requests WHERE request_id = %s",
(request_id,),
)
conn.commit()
return {
"status": "deleted",
"request_id": request_id,
"message": "Request deleted successfully",
}
return router
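A minimal client sketch for the submit, poll, and retrieve flow defined above, assuming the `requests` package; the base URL, API key, and the `/submit` path are assumptions, while the status and result paths match the `poll_url` and `result_url` the routes build.

# Sketch only: BASE_URL, API_KEY and the /submit path are assumptions,
# not confirmed by this module; status/result paths match poll_url/result_url above.
import time

import requests

BASE_URL = "http://localhost:8000/api/v1/async"
HEADERS = {"X-API-Key": "YOUR_API_KEY"}


def extract_async(pdf_path: str) -> dict:
    with open(pdf_path, "rb") as f:
        resp = requests.post(
            f"{BASE_URL}/submit",
            headers=HEADERS,
            files={"file": (pdf_path, f, "application/pdf")},
        )
    resp.raise_for_status()
    request_id = resp.json()["request_id"]

    # Respect the minimum poll interval (default 1000 ms) to avoid 429 responses.
    while True:
        status = requests.get(
            f"{BASE_URL}/status/{request_id}", headers=HEADERS
        ).json()
        if status["status"] in ("completed", "failed"):
            break
        time.sleep(2)

    return requests.get(f"{BASE_URL}/result/{request_id}", headers=HEADERS).json()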

View File

@@ -0,0 +1,183 @@
"""
Inference API Routes
FastAPI route definitions for the inference API.
"""
from __future__ import annotations
import logging
import shutil
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from inference.web.schemas.inference import (
DetectionResult,
HealthResponse,
InferenceResponse,
InferenceResult,
)
from inference.web.schemas.common import ErrorResponse
if TYPE_CHECKING:
from inference.web.services import InferenceService
from inference.web.config import StorageConfig
logger = logging.getLogger(__name__)
def create_inference_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with inference endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["inference"])
@router.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
"""Check service health status."""
return HealthResponse(
status="healthy",
model_loaded=inference_service.is_initialized,
gpu_available=inference_service.gpu_available,
version="1.0.0",
)
@router.post(
"/infer",
response_model=InferenceResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
)
async def infer_document(
file: UploadFile = File(..., description="PDF or image file to process"),
) -> InferenceResponse:
"""
Process a document and extract invoice fields.
Accepts PDF or image files (PNG, JPG, JPEG).
Returns extracted field values with confidence scores.
"""
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Generate document ID
doc_id = str(uuid.uuid4())[:8]
# Save uploaded file
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
try:
with open(upload_path, "wb") as f:
shutil.copyfileobj(file.file, f)
except Exception as e:
logger.error(f"Failed to save uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save uploaded file",
)
try:
# Process based on file type
if file_ext == ".pdf":
service_result = inference_service.process_pdf(
upload_path, document_id=doc_id
)
else:
service_result = inference_service.process_image(
upload_path, document_id=doc_id
)
# Build response
viz_url = None
if service_result.visualization_path:
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
inference_result = InferenceResult(
document_id=service_result.document_id,
success=service_result.success,
document_type=service_result.document_type,
fields=service_result.fields,
confidence=service_result.confidence,
detections=[
DetectionResult(**d) for d in service_result.detections
],
processing_time_ms=service_result.processing_time_ms,
visualization_url=viz_url,
errors=service_result.errors,
)
return InferenceResponse(
status="success" if service_result.success else "partial",
message=f"Processed document {doc_id}",
result=inference_result,
)
except Exception as e:
logger.error(f"Error processing document: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
finally:
# Cleanup uploaded file
upload_path.unlink(missing_ok=True)
@router.get("/results/{filename}")
async def get_result_image(filename: str) -> FileResponse:
"""Get visualization result image."""
file_path = storage_config.result_dir / filename
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
return FileResponse(
path=file_path,
media_type="image/png",
filename=filename,
)
@router.delete("/results/{filename}")
async def delete_result(filename: str) -> dict:
"""Delete a result file."""
file_path = storage_config.result_dir / filename
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
file_path.unlink()
return {"status": "deleted", "filename": filename}
return router
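A short usage sketch for the `/api/v1/infer` endpoint above, assuming the `requests` package and a locally running server; the host, port, and sample file are placeholders.

# Sketch only: host, port and the sample file are placeholders.
import requests

with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/infer",
        files={"file": ("invoice.pdf", f, "application/pdf")},
    )
resp.raise_for_status()
result = resp.json()["result"]
print(result["fields"])
print(result["confidence"])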

View File

@@ -0,0 +1,203 @@
"""
Labeling API Routes
FastAPI endpoints for pre-labeling documents with expected field values.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from inference.data.admin_db import AdminDB
from inference.web.schemas.labeling import PreLabelResponse
from inference.web.schemas.common import ErrorResponse
if TYPE_CHECKING:
from inference.web.services import InferenceService
from inference.web.config import StorageConfig
logger = logging.getLogger(__name__)
# Storage directory for pre-label uploads (legacy, now uses storage_config)
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
pdf_doc.close()
def get_admin_db() -> AdminDB:
"""Get admin database instance."""
return AdminDB()
def create_labeling_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with labeling endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["labeling"])
# Ensure upload directory exists
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
@router.post(
"/pre-label",
response_model=PreLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file or field values"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
summary="Pre-label document with expected values",
description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
)
async def pre_label(
file: UploadFile = File(..., description="PDF or image file to process"),
field_values: str = Form(
...,
description="JSON object with expected field values. "
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
),
db: AdminDB = Depends(get_admin_db),
) -> PreLabelResponse:
"""
Upload a document with expected field values for pre-labeling.
Returns document_id which can be used to retrieve results later.
Example field_values JSON:
```json
{
"InvoiceNumber": "12345",
"Amount": "1500.00",
"Bankgiro": "123-4567",
"OCR": "1234567890"
}
```
"""
# Parse field_values JSON
try:
expected_values = json.loads(field_values)
if not isinstance(expected_values, dict):
raise ValueError("field_values must be a JSON object")
except json.JSONDecodeError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid JSON in field_values: {e}",
)
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to read file",
)
# Get page count for PDF
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record with field_values
document_id = db.create_document(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
upload_source="api",
csv_field_values=expected_values,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save file",
)
# Update file path in database
db.update_document_file_path(document_id, str(file_path))
# Convert PDF to images for annotation UI
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="pending",
)
logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")
return PreLabelResponse(document_id=document_id)
return router
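A usage sketch for `/api/v1/pre-label`, assuming the `requests` package; the host and the expected field values are illustrative.

# Sketch only: host and the expected field values are illustrative.
import json

import requests

field_values = {"InvoiceNumber": "12345", "Amount": "1500.00", "Bankgiro": "123-4567"}
with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/pre-label",
        files={"file": ("invoice.pdf", f, "application/pdf")},
        data={"field_values": json.dumps(field_values)},
    )
resp.raise_for_status()
print(resp.json()["document_id"])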

View File

@@ -0,0 +1,913 @@
"""
FastAPI Application Factory
Creates and configures the FastAPI application.
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from .config import AppConfig, default_config
from inference.web.services import InferenceService
# Public API imports
from inference.web.api.v1.public import (
create_inference_router,
create_async_router,
set_async_service,
create_labeling_router,
)
# Async processing imports
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.dependencies import init_dependencies
from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
create_annotation_router,
create_auth_router,
create_documents_router,
create_locks_router,
create_training_router,
)
from inference.web.core.scheduler import start_scheduler, stop_scheduler
from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
# Batch upload imports
from inference.web.api.v1.batch.routes import router as batch_upload_router
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
from inference.data.admin_db import AdminDB
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__)
def create_app(config: AppConfig | None = None) -> FastAPI:
"""
Create and configure FastAPI application.
Args:
config: Application configuration. Uses default if not provided.
Returns:
Configured FastAPI application
"""
config = config or default_config
# Create inference service
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
)
# Create async processing components
async_db = AsyncRequestDB()
rate_limiter = RateLimiter(async_db)
task_queue = AsyncTaskQueue(
max_size=config.async_processing.queue_max_size,
worker_count=config.async_processing.worker_count,
)
async_service = AsyncProcessingService(
inference_service=inference_service,
db=async_db,
queue=task_queue,
rate_limiter=rate_limiter,
async_config=config.async_processing,
storage_config=config.storage,
)
# Initialize dependencies for FastAPI
init_dependencies(async_db, rate_limiter)
set_async_service(async_service)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan manager."""
logger.info("Starting Invoice Inference API...")
# Initialize database tables
try:
async_db.create_tables()
logger.info("Async database tables ready")
except Exception as e:
logger.error(f"Failed to initialize async database: {e}")
# Initialize inference service on startup
try:
inference_service.initialize()
logger.info("Inference service ready")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
# Continue anyway - service will retry on first request
# Start async processing service
try:
async_service.start()
logger.info("Async processing service started")
except Exception as e:
logger.error(f"Failed to start async processing: {e}")
# Start batch upload queue
try:
admin_db = AdminDB()
batch_service = BatchUploadService(admin_db)
init_batch_queue(batch_service)
logger.info("Batch upload queue started")
except Exception as e:
logger.error(f"Failed to start batch upload queue: {e}")
# Start training scheduler
try:
start_scheduler()
logger.info("Training scheduler started")
except Exception as e:
logger.error(f"Failed to start training scheduler: {e}")
# Start auto-label scheduler
try:
start_autolabel_scheduler()
logger.info("AutoLabel scheduler started")
except Exception as e:
logger.error(f"Failed to start autolabel scheduler: {e}")
yield
logger.info("Shutting down Invoice Inference API...")
# Stop auto-label scheduler
try:
stop_autolabel_scheduler()
logger.info("AutoLabel scheduler stopped")
except Exception as e:
logger.error(f"Error stopping autolabel scheduler: {e}")
# Stop training scheduler
try:
stop_scheduler()
logger.info("Training scheduler stopped")
except Exception as e:
logger.error(f"Error stopping training scheduler: {e}")
# Stop batch upload queue
try:
shutdown_batch_queue()
logger.info("Batch upload queue stopped")
except Exception as e:
logger.error(f"Error stopping batch upload queue: {e}")
# Stop async processing service
try:
async_service.stop(timeout=30.0)
logger.info("Async processing service stopped")
except Exception as e:
logger.error(f"Error stopping async service: {e}")
# Close database connection
try:
async_db.close()
logger.info("Database connection closed")
except Exception as e:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
app = FastAPI(
title="Invoice Field Extraction API",
description="""
REST API for extracting fields from Swedish invoices.
## Features
- YOLO-based field detection
- OCR text extraction
- Field normalization and validation
- Visualization of detections
## Supported Fields
- InvoiceNumber
- InvoiceDate
- InvoiceDueDate
- OCR (reference number)
- Bankgiro
- Plusgiro
- Amount
- supplier_org_number (Swedish organization number)
- customer_number
- payment_line (machine-readable payment code)
""",
version="1.0.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount static files for results
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
app.mount(
"/static/results",
StaticFiles(directory=str(config.storage.result_dir)),
name="results",
)
# Include public API routes
inference_router = create_inference_router(inference_service, config.storage)
app.include_router(inference_router)
async_router = create_async_router(config.storage.allowed_extensions)
app.include_router(async_router, prefix="/api/v1")
labeling_router = create_labeling_router(inference_service, config.storage)
app.include_router(labeling_router)
# Include admin API routes
auth_router = create_auth_router()
app.include_router(auth_router, prefix="/api/v1")
documents_router = create_documents_router(config.storage)
app.include_router(documents_router, prefix="/api/v1")
locks_router = create_locks_router()
app.include_router(locks_router, prefix="/api/v1")
annotation_router = create_annotation_router()
app.include_router(annotation_router, prefix="/api/v1")
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
# Include batch upload routes
app.include_router(batch_upload_router)
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:
"""Serve the web UI."""
return get_html_ui()
return app
def get_html_ui() -> str:
"""Generate HTML UI for the web application."""
return """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Invoice Field Extraction</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
header {
text-align: center;
color: white;
margin-bottom: 30px;
}
header h1 {
font-size: 2.5rem;
margin-bottom: 10px;
}
header p {
opacity: 0.9;
font-size: 1.1rem;
}
.main-content {
display: flex;
flex-direction: column;
gap: 20px;
}
.card {
background: white;
border-radius: 16px;
padding: 24px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
font-size: 1.3rem;
display: flex;
align-items: center;
gap: 10px;
}
.upload-card {
display: flex;
align-items: center;
gap: 20px;
flex-wrap: wrap;
}
.upload-card h2 {
margin-bottom: 0;
white-space: nowrap;
}
.upload-area {
border: 2px dashed #ddd;
border-radius: 10px;
padding: 15px 25px;
text-align: center;
cursor: pointer;
transition: all 0.3s;
background: #fafafa;
flex: 1;
min-width: 200px;
}
.upload-area:hover, .upload-area.dragover {
border-color: #667eea;
background: #f0f4ff;
}
.upload-area.has-file {
border-color: #10b981;
background: #ecfdf5;
}
.upload-icon {
font-size: 24px;
display: inline;
margin-right: 8px;
}
.upload-area p {
color: #666;
margin: 0;
display: inline;
}
.upload-area small {
color: #999;
display: block;
margin-top: 5px;
}
#file-input {
display: none;
}
.file-name {
margin-top: 15px;
padding: 10px 15px;
background: #e0f2fe;
border-radius: 8px;
color: #0369a1;
font-weight: 500;
}
.btn {
display: inline-block;
padding: 12px 24px;
border: none;
border-radius: 10px;
font-size: 0.9rem;
font-weight: 600;
cursor: pointer;
transition: all 0.3s;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.btn-primary:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
}
.btn-primary:disabled {
opacity: 0.6;
cursor: not-allowed;
}
.loading {
display: none;
align-items: center;
gap: 10px;
}
.loading.active {
display: flex;
}
.spinner {
width: 24px;
height: 24px;
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.results {
display: none;
}
.results.active {
display: block;
}
.result-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 20px;
padding-bottom: 15px;
border-bottom: 2px solid #eee;
}
.result-status {
padding: 6px 12px;
border-radius: 20px;
font-size: 0.85rem;
font-weight: 600;
}
.result-status.success {
background: #dcfce7;
color: #166534;
}
.result-status.partial {
background: #fef3c7;
color: #92400e;
}
.result-status.error {
background: #fee2e2;
color: #991b1b;
}
.fields-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 12px;
}
.field-item {
padding: 12px;
background: #f8fafc;
border-radius: 10px;
border-left: 4px solid #667eea;
}
.field-item label {
display: block;
font-size: 0.75rem;
color: #64748b;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 4px;
}
.field-item .value {
font-size: 1.1rem;
font-weight: 600;
color: #1e293b;
}
.field-item .confidence {
font-size: 0.75rem;
color: #10b981;
margin-top: 2px;
}
.visualization {
margin-top: 20px;
}
.visualization img {
width: 100%;
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
}
.processing-time {
text-align: center;
color: #64748b;
font-size: 0.9rem;
margin-top: 15px;
}
.cross-validation {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 10px;
padding: 15px;
margin-top: 20px;
}
.cross-validation h3 {
margin: 0 0 10px 0;
color: #334155;
font-size: 1rem;
}
.cv-status {
font-weight: 600;
padding: 8px 12px;
border-radius: 6px;
margin-bottom: 10px;
display: inline-block;
}
.cv-status.valid {
background: #dcfce7;
color: #166534;
}
.cv-status.invalid {
background: #fef3c7;
color: #92400e;
}
.cv-details {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 10px;
}
.cv-item {
background: white;
border: 1px solid #e2e8f0;
border-radius: 6px;
padding: 6px 12px;
font-size: 0.85rem;
display: flex;
align-items: center;
gap: 6px;
}
.cv-item.match {
border-color: #86efac;
background: #f0fdf4;
}
.cv-item.mismatch {
border-color: #fca5a5;
background: #fef2f2;
}
.cv-icon {
font-weight: bold;
}
.cv-item.match .cv-icon {
color: #16a34a;
}
.cv-item.mismatch .cv-icon {
color: #dc2626;
}
.cv-summary {
margin-top: 10px;
font-size: 0.8rem;
color: #64748b;
}
.error-message {
background: #fee2e2;
color: #991b1b;
padding: 15px;
border-radius: 10px;
margin-top: 15px;
}
footer {
text-align: center;
color: white;
opacity: 0.8;
margin-top: 30px;
font-size: 0.9rem;
}
</style>
</head>
<body>
<div class="container">
<header>
<h1>📄 Invoice Field Extraction</h1>
<p>Upload a Swedish invoice (PDF or image) to extract fields automatically</p>
</header>
<div class="main-content">
<!-- Upload Section - Compact -->
<div class="card upload-card">
<h2>📤 Upload</h2>
<div class="upload-area" id="upload-area">
<span class="upload-icon">📁</span>
<p>Drag & drop or <strong>click to browse</strong></p>
<small>PDF, PNG, JPG (max 50MB)</small>
<input type="file" id="file-input" accept=".pdf,.png,.jpg,.jpeg">
</div>
<div class="file-name" id="file-name" style="display: none;"></div>
<button class="btn btn-primary" id="submit-btn" disabled>
🚀 Extract
</button>
<div class="loading" id="loading">
<div class="spinner"></div>
<p>Processing...</p>
</div>
</div>
<!-- Results Section - Full Width -->
<div class="card">
<h2>📊 Extraction Results</h2>
<div id="placeholder" style="text-align: center; padding: 30px; color: #999;">
<div style="font-size: 48px; margin-bottom: 10px;">🔍</div>
<p>Upload a document to see extraction results</p>
</div>
<div class="results" id="results">
<div class="result-header">
<span>Document: <strong id="doc-id"></strong></span>
<span class="result-status" id="result-status"></span>
</div>
<div class="fields-grid" id="fields-grid"></div>
<div class="processing-time" id="processing-time"></div>
<div class="cross-validation" id="cross-validation" style="display: none;"></div>
<div class="error-message" id="error-message" style="display: none;"></div>
<div class="visualization" id="visualization" style="display: none;">
<h3 style="margin-bottom: 10px; color: #333;">🎯 Detection Visualization</h3>
<img id="viz-image" src="" alt="Detection visualization">
</div>
</div>
</div>
</div>
<footer>
<p>Powered by ColaCoder</p>
</footer>
</div>
<script>
const uploadArea = document.getElementById('upload-area');
const fileInput = document.getElementById('file-input');
const fileName = document.getElementById('file-name');
const submitBtn = document.getElementById('submit-btn');
const loading = document.getElementById('loading');
const placeholder = document.getElementById('placeholder');
const results = document.getElementById('results');
let selectedFile = null;
// Drag and drop handlers
uploadArea.addEventListener('click', () => fileInput.click());
uploadArea.addEventListener('dragover', (e) => {
e.preventDefault();
uploadArea.classList.add('dragover');
});
uploadArea.addEventListener('dragleave', () => {
uploadArea.classList.remove('dragover');
});
uploadArea.addEventListener('drop', (e) => {
e.preventDefault();
uploadArea.classList.remove('dragover');
const files = e.dataTransfer.files;
if (files.length > 0) {
handleFile(files[0]);
}
});
fileInput.addEventListener('change', (e) => {
if (e.target.files.length > 0) {
handleFile(e.target.files[0]);
}
});
function handleFile(file) {
const validTypes = ['.pdf', '.png', '.jpg', '.jpeg'];
const ext = '.' + file.name.split('.').pop().toLowerCase();
if (!validTypes.includes(ext)) {
alert('Please upload a PDF, PNG, or JPG file.');
return;
}
selectedFile = file;
fileName.textContent = `📎 ${file.name}`;
fileName.style.display = 'block';
uploadArea.classList.add('has-file');
submitBtn.disabled = false;
}
submitBtn.addEventListener('click', async () => {
if (!selectedFile) return;
// Show loading
submitBtn.disabled = true;
loading.classList.add('active');
placeholder.style.display = 'none';
results.classList.remove('active');
try {
const formData = new FormData();
formData.append('file', selectedFile);
const response = await fetch('/api/v1/infer', {
method: 'POST',
body: formData,
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Processing failed');
}
displayResults(data);
} catch (error) {
console.error('Error:', error);
document.getElementById('error-message').textContent = error.message;
document.getElementById('error-message').style.display = 'block';
results.classList.add('active');
} finally {
loading.classList.remove('active');
submitBtn.disabled = false;
}
});
function displayResults(data) {
const result = data.result;
// Document ID
document.getElementById('doc-id').textContent = result.document_id;
// Status
const statusEl = document.getElementById('result-status');
statusEl.textContent = result.success ? 'Success' : 'Partial';
statusEl.className = 'result-status ' + (result.success ? 'success' : 'partial');
// Fields
const fieldsGrid = document.getElementById('fields-grid');
fieldsGrid.innerHTML = '';
const fieldOrder = [
'InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR',
'Amount', 'Bankgiro', 'Plusgiro',
'supplier_org_number', 'customer_number', 'payment_line'
];
fieldOrder.forEach(field => {
const value = result.fields[field];
const confidence = result.confidence[field];
if (value !== null && value !== undefined) {
const fieldDiv = document.createElement('div');
fieldDiv.className = 'field-item';
fieldDiv.innerHTML = `
<label>${formatFieldName(field)}</label>
<div class="value">${value}</div>
${confidence ? `<div class="confidence">✓ ${(confidence * 100).toFixed(1)}% confident</div>` : ''}
`;
fieldsGrid.appendChild(fieldDiv);
}
});
// Processing time
document.getElementById('processing-time').textContent =
`⏱️ Processed in ${result.processing_time_ms.toFixed(0)}ms`;
// Cross-validation results
const cvDiv = document.getElementById('cross-validation');
if (result.cross_validation) {
const cv = result.cross_validation;
let cvHtml = '<h3>🔍 Cross-Validation (Payment Line)</h3>';
cvHtml += `<div class="cv-status ${cv.is_valid ? 'valid' : 'invalid'}">`;
cvHtml += cv.is_valid ? '✅ Valid' : '⚠️ Mismatch Detected';
cvHtml += '</div>';
cvHtml += '<div class="cv-details">';
if (cv.payment_line_ocr) {
const matchIcon = cv.ocr_match === true ? '✓' : (cv.ocr_match === false ? '✗' : '');
cvHtml += `<div class="cv-item ${cv.ocr_match === true ? 'match' : (cv.ocr_match === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> OCR: ${cv.payment_line_ocr}</div>`;
}
if (cv.payment_line_amount) {
const matchIcon = cv.amount_match === true ? '✓' : (cv.amount_match === false ? '✗' : '');
cvHtml += `<div class="cv-item ${cv.amount_match === true ? 'match' : (cv.amount_match === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> Amount: ${cv.payment_line_amount}</div>`;
}
if (cv.payment_line_account) {
const accountType = cv.payment_line_account_type === 'bankgiro' ? 'Bankgiro' : 'Plusgiro';
const matchField = cv.payment_line_account_type === 'bankgiro' ? cv.bankgiro_match : cv.plusgiro_match;
const matchIcon = matchField === true ? '✓' : (matchField === false ? '✗' : '');
cvHtml += `<div class="cv-item ${matchField === true ? 'match' : (matchField === false ? 'mismatch' : '')}">`;
cvHtml += `<span class="cv-icon">${matchIcon}</span> ${accountType}: ${cv.payment_line_account}</div>`;
}
cvHtml += '</div>';
if (cv.details && cv.details.length > 0) {
cvHtml += '<div class="cv-summary">' + cv.details[cv.details.length - 1] + '</div>';
}
cvDiv.innerHTML = cvHtml;
cvDiv.style.display = 'block';
} else {
cvDiv.style.display = 'none';
}
// Visualization
if (result.visualization_url) {
const vizDiv = document.getElementById('visualization');
const vizImg = document.getElementById('viz-image');
vizImg.src = result.visualization_url;
vizDiv.style.display = 'block';
}
// Errors
if (result.errors && result.errors.length > 0) {
document.getElementById('error-message').textContent = result.errors.join(', ');
document.getElementById('error-message').style.display = 'block';
} else {
document.getElementById('error-message').style.display = 'none';
}
results.classList.add('active');
}
function formatFieldName(name) {
const nameMap = {
'InvoiceNumber': 'Invoice Number',
'InvoiceDate': 'Invoice Date',
'InvoiceDueDate': 'Due Date',
'OCR': 'OCR Reference',
'Amount': 'Amount',
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'supplier_org_number': 'Supplier Org Number',
'customer_number': 'Customer Number',
'payment_line': 'Payment Line'
};
return nameMap[name] || name.replace(/([A-Z])/g, ' $1').replace(/_/g, ' ').trim();
}
</script>
</body>
</html>
"""

View File

@@ -0,0 +1,113 @@
"""
Web Application Configuration
Centralized configuration for the web application.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from shared.config import DEFAULT_DPI, PATHS
@dataclass(frozen=True)
class ModelConfig:
"""YOLO model configuration."""
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
confidence_threshold: float = 0.5
use_gpu: bool = True
dpi: int = DEFAULT_DPI
@dataclass(frozen=True)
class ServerConfig:
"""Server configuration."""
host: str = "0.0.0.0"
port: int = 8000
debug: bool = False
reload: bool = False
workers: int = 1
@dataclass(frozen=True)
class StorageConfig:
"""File storage configuration.
Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
directly in raw_pdfs directory. This ensures consistency with CLI autolabel
and avoids storing duplicate files.
"""
upload_dir: Path = Path("uploads")
result_dir: Path = Path("results")
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
admin_images_dir: Path = Path("data/admin_images")
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
dpi: int = DEFAULT_DPI
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
self.upload_dir.mkdir(parents=True, exist_ok=True)
self.result_dir.mkdir(parents=True, exist_ok=True)
self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
@dataclass(frozen=True)
class AsyncConfig:
"""Async processing configuration."""
# Queue settings
queue_max_size: int = 100
worker_count: int = 1
task_timeout_seconds: int = 300
# Rate limiting defaults
default_requests_per_minute: int = 10
default_max_concurrent_jobs: int = 3
default_min_poll_interval_ms: int = 1000
# Storage
result_retention_days: int = 7
temp_upload_dir: Path = Path("uploads/async")
max_file_size_mb: int = 50
# Cleanup
cleanup_interval_hours: int = 1
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
@dataclass
class AppConfig:
"""Main application configuration."""
model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
storage: StorageConfig = field(default_factory=StorageConfig)
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
"""Create config from dictionary."""
return cls(
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
storage=StorageConfig(**config_dict.get("storage", {})),
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
)
# Default configuration instance
default_config = AppConfig()
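An illustrative override of the defaults via AppConfig.from_dict; the values shown are examples only.

# Example values only; any omitted section keeps its dataclass defaults.
custom_config = AppConfig.from_dict(
    {
        "model": {"confidence_threshold": 0.6, "use_gpu": False},
        "server": {"port": 9000},
        "async_processing": {"worker_count": 2, "queue_max_size": 200},
    }
)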

View File

@@ -0,0 +1,28 @@
"""
Core Components
Reusable core functionality: authentication, rate limiting, scheduling.
"""
from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from inference.web.core.rate_limiter import RateLimiter
from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from inference.web.core.autolabel_scheduler import (
start_autolabel_scheduler,
stop_autolabel_scheduler,
get_autolabel_scheduler,
)
__all__ = [
"validate_admin_token",
"get_admin_db",
"AdminTokenDep",
"AdminDBDep",
"RateLimiter",
"start_scheduler",
"stop_scheduler",
"get_training_scheduler",
"start_autolabel_scheduler",
"stop_autolabel_scheduler",
"get_autolabel_scheduler",
]

View File

@@ -0,0 +1,60 @@
"""
Admin Authentication
FastAPI dependencies for admin token authentication.
"""
import logging
from typing import Annotated
from fastapi import Depends, Header, HTTPException
from inference.data.admin_db import AdminDB
from inference.data.database import get_session_context
logger = logging.getLogger(__name__)
# Global AdminDB instance
_admin_db: AdminDB | None = None
def get_admin_db() -> AdminDB:
"""Get the AdminDB instance."""
global _admin_db
if _admin_db is None:
_admin_db = AdminDB()
return _admin_db
def reset_admin_db() -> None:
"""Reset the AdminDB instance (for testing)."""
global _admin_db
_admin_db = None
async def validate_admin_token(
x_admin_token: Annotated[str | None, Header()] = None,
admin_db: AdminDB = Depends(get_admin_db),
) -> str:
"""Validate admin token from header."""
if not x_admin_token:
raise HTTPException(
status_code=401,
detail="Admin token required. Provide X-Admin-Token header.",
)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(
status_code=401,
detail="Invalid or expired admin token.",
)
# Update last used timestamp
admin_db.update_admin_token_usage(x_admin_token)
return x_admin_token
# Type alias for dependency injection
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
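A sketch of how the dependency aliases above are meant to be used; the router and endpoint here are illustrative, not part of this module.

# Illustrative admin-only endpoint using the aliases above.
from fastapi import APIRouter

example_router = APIRouter(prefix="/api/v1/admin", tags=["admin"])


@example_router.get("/ping")
async def admin_ping(token: AdminTokenDep, db: AdminDBDep) -> dict:
    # Reaching this point means X-Admin-Token was present, valid, and its
    # last-used timestamp has been updated.
    return {"status": "ok"}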

View File

@@ -0,0 +1,153 @@
"""
Auto-Label Scheduler
Background scheduler for processing documents pending auto-labeling.
"""
import logging
import threading
from pathlib import Path
from inference.data.admin_db import AdminDB
from inference.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
)
logger = logging.getLogger(__name__)
class AutoLabelScheduler:
"""Scheduler for auto-labeling tasks."""
def __init__(
self,
check_interval_seconds: int = 10,
batch_size: int = 5,
output_dir: Path | None = None,
):
"""
Initialize auto-label scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
batch_size: Number of documents to process per batch
output_dir: Output directory for temporary files
"""
self._check_interval = check_interval_seconds
self._batch_size = batch_size
self._output_dir = output_dir or Path("data/autolabel_output")
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("AutoLabel scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("AutoLabel scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("AutoLabel scheduler stopped")
@property
def is_running(self) -> bool:
"""Check if scheduler is running."""
return self._running
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._process_pending_documents()
except Exception as e:
logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _process_pending_documents(self) -> None:
"""Check and process pending auto-label documents."""
try:
documents = get_pending_autolabel_documents(
self._db, limit=self._batch_size
)
if not documents:
return
logger.info(f"Processing {len(documents)} pending autolabel documents")
for doc in documents:
if self._stop_event.is_set():
break
try:
result = process_document_autolabel(
document=doc,
db=self._db,
output_dir=self._output_dir,
)
if result.get("success"):
logger.info(
f"AutoLabel completed for document {doc.document_id}"
)
else:
logger.warning(
f"AutoLabel failed for document {doc.document_id}: "
f"{result.get('error', 'Unknown error')}"
)
except Exception as e:
logger.error(
f"Error processing document {doc.document_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error fetching pending documents: {e}", exc_info=True)
# Global scheduler instance
_autolabel_scheduler: AutoLabelScheduler | None = None
def get_autolabel_scheduler() -> AutoLabelScheduler:
"""Get the auto-label scheduler instance."""
global _autolabel_scheduler
if _autolabel_scheduler is None:
_autolabel_scheduler = AutoLabelScheduler()
return _autolabel_scheduler
def start_autolabel_scheduler() -> None:
"""Start the global auto-label scheduler."""
scheduler = get_autolabel_scheduler()
scheduler.start()
def stop_autolabel_scheduler() -> None:
"""Stop the global auto-label scheduler."""
global _autolabel_scheduler
if _autolabel_scheduler:
_autolabel_scheduler.stop()
_autolabel_scheduler = None

View File

@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from inference.data.async_request_db import AsyncRequestDB
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RateLimitConfig:
"""Rate limit configuration for an API key."""
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000 # Minimum time between status polls
@dataclass
class RateLimitStatus:
"""Current rate limit status."""
allowed: bool
remaining_requests: int
reset_at: datetime
retry_after_seconds: int | None = None
reason: str | None = None
class RateLimiter:
"""
Thread-safe rate limiter with sliding window algorithm.
Tracks:
- Requests per minute (sliding window)
- Concurrent active jobs
- Poll frequency per request_id
"""
def __init__(self, db: "AsyncRequestDB") -> None:
self._db = db
self._lock = Lock()
# In-memory tracking for fast checks
self._request_windows: dict[str, list[float]] = defaultdict(list)
# (api_key, request_id) -> last_poll timestamp
self._poll_timestamps: dict[tuple[str, str], float] = {}
# Cache for API key configs (TTL 60 seconds)
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
self._config_cache_ttl = 60.0
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
"""Check if API key can submit a new request."""
config = self._get_config(api_key)
with self._lock:
now = time.time()
window_start = now - 60 # 1 minute window
# Clean old entries
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
current_count = len(self._request_windows[api_key])
if current_count >= config.requests_per_minute:
oldest = min(self._request_windows[api_key])
retry_after = int(oldest + 60 - now) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
retry_after_seconds=max(1, retry_after),
reason="Rate limit exceeded: too many requests per minute",
)
# Check concurrent jobs (query database) - inside lock for thread safety
active_jobs = self._db.count_active_jobs(api_key)
if active_jobs >= config.max_concurrent_jobs:
return RateLimitStatus(
allowed=False,
remaining_requests=config.requests_per_minute - current_count,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
)
return RateLimitStatus(
allowed=True,
remaining_requests=config.requests_per_minute - current_count - 1,
reset_at=datetime.utcnow() + timedelta(seconds=60),
)
def record_request(self, api_key: str) -> None:
"""Record a successful request submission."""
with self._lock:
self._request_windows[api_key].append(time.time())
# Also record in database for persistence
try:
self._db.record_rate_limit_event(api_key, "request")
except Exception as e:
logger.warning(f"Failed to record rate limit event: {e}")
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
"""Check if polling is allowed (prevent abuse)."""
config = self._get_config(api_key)
key = (api_key, request_id)
with self._lock:
now = time.time()
last_poll = self._poll_timestamps.get(key, 0)
elapsed_ms = (now - last_poll) * 1000
if elapsed_ms < config.min_poll_interval_ms:
# Suggest exponential backoff
wait_ms = min(
config.min_poll_interval_ms * 2,
5000, # Max 5 seconds
)
retry_after = int(wait_ms / 1000) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
retry_after_seconds=retry_after,
reason="Polling too frequently. Please wait before retrying.",
)
# Update poll timestamp
self._poll_timestamps[key] = now
return RateLimitStatus(
allowed=True,
remaining_requests=999, # No limit on poll count, just frequency
reset_at=datetime.utcnow(),
)
def _get_config(self, api_key: str) -> RateLimitConfig:
"""Get rate limit config for API key with caching."""
now = time.time()
# Check cache
if api_key in self._config_cache:
cached_config, cached_at = self._config_cache[api_key]
if now - cached_at < self._config_cache_ttl:
return cached_config
# Query database
db_config = self._db.get_api_key_config(api_key)
if db_config:
config = RateLimitConfig(
requests_per_minute=db_config.requests_per_minute,
max_concurrent_jobs=db_config.max_concurrent_jobs,
)
else:
config = RateLimitConfig() # Default limits
# Cache result
self._config_cache[api_key] = (config, now)
return config
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
"""Clean up old poll timestamps to prevent memory leak."""
with self._lock:
now = time.time()
cutoff = now - max_age_seconds
old_keys = [
k for k, v in self._poll_timestamps.items()
if v < cutoff
]
for key in old_keys:
del self._poll_timestamps[key]
return len(old_keys)
def cleanup_request_windows(self) -> None:
"""Clean up expired entries from request windows."""
with self._lock:
now = time.time()
window_start = now - 60
for api_key in list(self._request_windows.keys()):
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
# Remove empty entries
if not self._request_windows[api_key]:
del self._request_windows[api_key]
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
"""Generate rate limit headers for HTTP response."""
headers = {
"X-RateLimit-Remaining": str(status.remaining_requests),
"X-RateLimit-Reset": status.reset_at.isoformat(),
}
if status.retry_after_seconds:
headers["Retry-After"] = str(status.retry_after_seconds)
return headers
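A sketch of the check-then-record pattern the dependency layer follows, assuming a reachable async-requests database; the API key is a placeholder.

# Sketch only: requires a reachable async-requests database; the key is a placeholder.
from inference.data.async_request_db import AsyncRequestDB

async_db = AsyncRequestDB()
limiter = RateLimiter(async_db)

status = limiter.check_submit_limit("example-api-key")
if status.allowed:
    limiter.record_request("example-api-key")
else:
    # Callers typically surface these as HTTP 429 responses with these headers.
    print(status.reason, limiter.get_rate_limit_headers(status))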

View File

@@ -0,0 +1,340 @@
"""
Admin Training Scheduler
Background scheduler for training tasks using a background thread.
"""
import logging
import threading
from datetime import datetime
from pathlib import Path
from typing import Any
from inference.data.admin_db import AdminDB
logger = logging.getLogger(__name__)
class TrainingScheduler:
"""Scheduler for training tasks."""
def __init__(
self,
check_interval_seconds: int = 60,
):
"""
Initialize training scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
"""
self._check_interval = check_interval_seconds
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("Training scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("Training scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("Training scheduler stopped")
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._check_pending_tasks()
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _check_pending_tasks(self) -> None:
"""Check and execute pending training tasks."""
try:
tasks = self._db.get_pending_training_tasks()
for task in tasks:
task_id = str(task.task_id)
# Check if scheduled time has passed
if task.scheduled_at and task.scheduled_at > datetime.utcnow():
continue
logger.info(f"Starting training task: {task_id}")
try:
dataset_id = getattr(task, "dataset_id", None)
self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.update_training_task_status(
task_id=task_id,
status="failed",
error_message=str(e),
)
except Exception as e:
logger.error(f"Error checking pending tasks: {e}")
def _execute_task(
self, task_id: str, config: dict[str, Any], dataset_id: str | None = None
) -> None:
"""Execute a training task."""
# Update status to running
self._db.update_training_task_status(task_id, "running")
self._db.add_training_log(task_id, "INFO", "Training task started")
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
learning_rate = config.get("learning_rate", 0.01)
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
self._db.add_training_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
)
else:
export_result = self._export_training_data(task_id)
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=model_name,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project_name=project_name,
)
# Update task with results
self._db.update_training_task_status(
task_id=task_id,
status="completed",
result_metrics=result.get("metrics"),
model_path=result.get("model_path"),
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
import shutil
from inference.data.admin_models import FIELD_CLASSES
# Get all labeled documents
documents = self._db.get_labeled_documents_for_export()
if not documents:
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
return None
# Create export directory
export_dir = Path("data/training") / task_id
export_dir.mkdir(parents=True, exist_ok=True)
# YOLO format directories
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
# 80/20 train/val split
total_docs = len(documents)
train_count = int(total_docs * 0.8)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
# Export documents
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = self._db.get_annotations_for_document(str(doc.document_id))
if not annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
# Copy image
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
continue
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
total_images += 1
# Write YOLO label
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
# Create data.yaml
yaml_path = export_dir / "data.yaml"
yaml_content = f"""path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
yaml_path.write_text(yaml_content)
return {
"data_yaml": str(yaml_path),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _run_yolo_training(
self,
task_id: str,
model_name: str,
data_yaml: str,
epochs: int,
batch_size: int,
image_size: int,
learning_rate: float,
device: str,
project_name: str,
) -> dict[str, Any]:
"""Run YOLO training."""
try:
from ultralytics import YOLO
# Log training start
self._db.add_training_log(
task_id, "INFO",
f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
)
# Load model
model = YOLO(model_name)
# Train
results = model.train(
data=data_yaml,
epochs=epochs,
batch=batch_size,
imgsz=image_size,
lr0=learning_rate,
device=device,
project=f"runs/train/{project_name}",
name=f"task_{task_id[:8]}",
exist_ok=True,
verbose=True,
)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
self._db.add_training_log(
task_id, "INFO",
f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
)
return {
"model_path": str(best_model) if best_model.exists() else None,
"metrics": metrics,
}
except ImportError:
self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
raise ValueError("Ultralytics (YOLO) not installed")
except Exception as e:
self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
raise
# Global scheduler instance
_scheduler: TrainingScheduler | None = None
def get_training_scheduler() -> TrainingScheduler:
"""Get the training scheduler instance."""
global _scheduler
if _scheduler is None:
_scheduler = TrainingScheduler()
return _scheduler
def start_scheduler() -> None:
"""Start the global training scheduler."""
scheduler = get_training_scheduler()
scheduler.start()
def stop_scheduler() -> None:
"""Stop the global training scheduler."""
global _scheduler
if _scheduler:
_scheduler.stop()
_scheduler = None
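
A minimal wiring sketch for the helpers above: one way an API process could start and stop the scheduler around its lifetime. The FastAPI lifespan hook and the app object are assumptions for illustration; start_scheduler and stop_scheduler are the functions defined directly above.

from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    start_scheduler()  # start the background training scheduler
    try:
        yield
    finally:
        stop_scheduler()  # stop and clear the global scheduler on shutdown


app = FastAPI(lifespan=lifespan)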

View File

@@ -0,0 +1,133 @@
"""
FastAPI Dependencies
Dependency injection for the async API endpoints.
"""
import logging
from typing import Annotated
from fastapi import Depends, Header, HTTPException, Request
from inference.data.async_request_db import AsyncRequestDB
from inference.web.rate_limiter import RateLimiter
logger = logging.getLogger(__name__)
# Global instances (initialized in app startup)
_async_db: AsyncRequestDB | None = None
_rate_limiter: RateLimiter | None = None
def init_dependencies(db: AsyncRequestDB, rate_limiter: RateLimiter) -> None:
"""Initialize global dependency instances."""
global _async_db, _rate_limiter
_async_db = db
_rate_limiter = rate_limiter
def get_async_db() -> AsyncRequestDB:
"""Get async request database instance."""
if _async_db is None:
raise RuntimeError("AsyncRequestDB not initialized")
return _async_db
def get_rate_limiter() -> RateLimiter:
"""Get rate limiter instance."""
if _rate_limiter is None:
raise RuntimeError("RateLimiter not initialized")
return _rate_limiter
async def verify_api_key(
x_api_key: Annotated[str | None, Header()] = None,
) -> str:
"""
Verify API key exists and is active.
Raises:
HTTPException: 401 if API key is missing or invalid
"""
if not x_api_key:
raise HTTPException(
status_code=401,
detail="X-API-Key header is required",
headers={"WWW-Authenticate": "API-Key"},
)
db = get_async_db()
if not db.is_valid_api_key(x_api_key):
raise HTTPException(
status_code=401,
detail="Invalid or inactive API key",
headers={"WWW-Authenticate": "API-Key"},
)
# Update usage tracking
try:
db.update_api_key_usage(x_api_key)
except Exception as e:
logger.warning(f"Failed to update API key usage: {e}")
return x_api_key
async def check_submit_rate_limit(
api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
"""
Check rate limit before processing submit request.
Raises:
HTTPException: 429 if rate limit exceeded
"""
rate_limiter = get_rate_limiter()
status = rate_limiter.check_submit_limit(api_key)
if not status.allowed:
headers = rate_limiter.get_rate_limit_headers(status)
raise HTTPException(
status_code=429,
detail=status.reason or "Rate limit exceeded",
headers=headers,
)
return api_key
async def check_poll_rate_limit(
request: Request,
api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
"""
Check poll rate limit to prevent abuse.
Raises:
HTTPException: 429 if polling too frequently
"""
# Extract request_id from path parameters
request_id = request.path_params.get("request_id")
if not request_id:
return api_key # No request_id, skip poll limit check
rate_limiter = get_rate_limiter()
status = rate_limiter.check_poll_limit(api_key, request_id)
if not status.allowed:
headers = rate_limiter.get_rate_limit_headers(status)
raise HTTPException(
status_code=429,
detail=status.reason or "Polling too frequently",
headers=headers,
)
return api_key
# Type aliases for cleaner route signatures
ApiKeyDep = Annotated[str, Depends(verify_api_key)]
SubmitRateLimitDep = Annotated[str, Depends(check_submit_rate_limit)]
PollRateLimitDep = Annotated[str, Depends(check_poll_rate_limit)]
AsyncDBDep = Annotated[AsyncRequestDB, Depends(get_async_db)]
RateLimiterDep = Annotated[RateLimiter, Depends(get_rate_limiter)]
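
A sketch of how the aliases above keep route signatures small. The route path and handler body are illustrative only; PollRateLimitDep and AsyncDBDep are the aliases defined above, and db.get_request is the lookup used elsewhere in this change.

from fastapi import APIRouter

router = APIRouter()


@router.get("/api/v1/async/status/{request_id}")  # illustrative path
async def get_status(request_id: str, api_key: PollRateLimitDep, db: AsyncDBDep):
    # verify_api_key and check_poll_rate_limit have already run via the
    # Annotated dependencies; the handler only covers the lookup itself.
    record = db.get_request(request_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Unknown request_id")
    return record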

View File

@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from inference.data.async_request_db import AsyncRequestDB
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RateLimitConfig:
"""Rate limit configuration for an API key."""
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000 # Minimum time between status polls
@dataclass
class RateLimitStatus:
"""Current rate limit status."""
allowed: bool
remaining_requests: int
reset_at: datetime
retry_after_seconds: int | None = None
reason: str | None = None
class RateLimiter:
"""
Thread-safe rate limiter with sliding window algorithm.
Tracks:
- Requests per minute (sliding window)
- Concurrent active jobs
- Poll frequency per request_id
"""
def __init__(self, db: "AsyncRequestDB") -> None:
self._db = db
self._lock = Lock()
# In-memory tracking for fast checks
self._request_windows: dict[str, list[float]] = defaultdict(list)
# (api_key, request_id) -> last_poll timestamp
self._poll_timestamps: dict[tuple[str, str], float] = {}
# Cache for API key configs (TTL 60 seconds)
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
self._config_cache_ttl = 60.0
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
"""Check if API key can submit a new request."""
config = self._get_config(api_key)
with self._lock:
now = time.time()
window_start = now - 60 # 1 minute window
# Clean old entries
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
current_count = len(self._request_windows[api_key])
if current_count >= config.requests_per_minute:
oldest = min(self._request_windows[api_key])
retry_after = int(oldest + 60 - now) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
retry_after_seconds=max(1, retry_after),
reason="Rate limit exceeded: too many requests per minute",
)
# Check concurrent jobs (query database) - inside lock for thread safety
active_jobs = self._db.count_active_jobs(api_key)
if active_jobs >= config.max_concurrent_jobs:
return RateLimitStatus(
allowed=False,
remaining_requests=config.requests_per_minute - current_count,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
)
return RateLimitStatus(
allowed=True,
remaining_requests=config.requests_per_minute - current_count - 1,
reset_at=datetime.utcnow() + timedelta(seconds=60),
)
def record_request(self, api_key: str) -> None:
"""Record a successful request submission."""
with self._lock:
self._request_windows[api_key].append(time.time())
# Also record in database for persistence
try:
self._db.record_rate_limit_event(api_key, "request")
except Exception as e:
logger.warning(f"Failed to record rate limit event: {e}")
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
"""Check if polling is allowed (prevent abuse)."""
config = self._get_config(api_key)
key = (api_key, request_id)
with self._lock:
now = time.time()
last_poll = self._poll_timestamps.get(key, 0)
elapsed_ms = (now - last_poll) * 1000
if elapsed_ms < config.min_poll_interval_ms:
                # Suggest a longer wait: double the minimum poll interval, capped at 5 seconds
wait_ms = min(
config.min_poll_interval_ms * 2,
5000, # Max 5 seconds
)
retry_after = int(wait_ms / 1000) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
retry_after_seconds=retry_after,
reason="Polling too frequently. Please wait before retrying.",
)
# Update poll timestamp
self._poll_timestamps[key] = now
return RateLimitStatus(
allowed=True,
remaining_requests=999, # No limit on poll count, just frequency
reset_at=datetime.utcnow(),
)
def _get_config(self, api_key: str) -> RateLimitConfig:
"""Get rate limit config for API key with caching."""
now = time.time()
# Check cache
if api_key in self._config_cache:
cached_config, cached_at = self._config_cache[api_key]
if now - cached_at < self._config_cache_ttl:
return cached_config
# Query database
db_config = self._db.get_api_key_config(api_key)
if db_config:
config = RateLimitConfig(
requests_per_minute=db_config.requests_per_minute,
max_concurrent_jobs=db_config.max_concurrent_jobs,
)
else:
config = RateLimitConfig() # Default limits
# Cache result
self._config_cache[api_key] = (config, now)
return config
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
"""Clean up old poll timestamps to prevent memory leak."""
with self._lock:
now = time.time()
cutoff = now - max_age_seconds
old_keys = [
k for k, v in self._poll_timestamps.items()
if v < cutoff
]
for key in old_keys:
del self._poll_timestamps[key]
return len(old_keys)
def cleanup_request_windows(self) -> None:
"""Clean up expired entries from request windows."""
with self._lock:
now = time.time()
window_start = now - 60
for api_key in list(self._request_windows.keys()):
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
# Remove empty entries
if not self._request_windows[api_key]:
del self._request_windows[api_key]
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
"""Generate rate limit headers for HTTP response."""
headers = {
"X-RateLimit-Remaining": str(status.remaining_requests),
"X-RateLimit-Reset": status.reset_at.isoformat(),
}
if status.retry_after_seconds:
headers["Retry-After"] = str(status.retry_after_seconds)
return headers
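
A self-contained sketch of the sliding-window behavior of RateLimiter (defined above), using a small stub in place of AsyncRequestDB. The stub, the key value, and the loop are illustrative only; with the default RateLimitConfig the eleventh submission inside one minute is rejected.

class _StubDB:
    """Minimal stand-in exposing only the methods RateLimiter calls."""

    def get_api_key_config(self, api_key):
        return None  # fall back to the RateLimitConfig defaults (10 requests/minute)

    def count_active_jobs(self, api_key):
        return 0

    def record_rate_limit_event(self, api_key, event_type):
        pass


limiter = RateLimiter(_StubDB())
for _ in range(12):
    result = limiter.check_submit_limit("demo-key")
    if result.allowed:
        limiter.record_request("demo-key")
    else:
        # The eleventh call within the window exceeds the default limit.
        print(result.reason, limiter.get_rate_limit_headers(result))
        break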

View File

@@ -0,0 +1,11 @@
"""
API Schemas
Pydantic models for request/response validation.
"""
# Import everything from sub-modules for backward compatibility
from inference.web.schemas.common import * # noqa: F401, F403
from inference.web.schemas.admin import * # noqa: F401, F403
from inference.web.schemas.inference import * # noqa: F401, F403
from inference.web.schemas.labeling import * # noqa: F401, F403

View File

@@ -0,0 +1,17 @@
"""
Admin API Request/Response Schemas
Pydantic models for admin API validation and serialization.
"""
from .enums import * # noqa: F401, F403
from .auth import * # noqa: F401, F403
from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse
DocumentDetailResponse.model_rebuild()

View File

@@ -0,0 +1,152 @@
"""Admin Annotation Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .enums import AnnotationSource
class BoundingBox(BaseModel):
"""Bounding box coordinates."""
x: int = Field(..., ge=0, description="X coordinate (pixels)")
y: int = Field(..., ge=0, description="Y coordinate (pixels)")
width: int = Field(..., ge=1, description="Width (pixels)")
height: int = Field(..., ge=1, description="Height (pixels)")
class AnnotationCreate(BaseModel):
"""Request to create an annotation."""
page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
text_value: str | None = Field(None, description="Text value (optional)")
class AnnotationUpdate(BaseModel):
"""Request to update an annotation."""
class_id: int | None = Field(None, ge=0, le=9, description="New class ID")
bbox: BoundingBox | None = Field(None, description="New bounding box")
text_value: str | None = Field(None, description="New text value")
class AnnotationItem(BaseModel):
"""Single annotation item."""
annotation_id: str = Field(..., description="Annotation UUID")
page_number: int = Field(..., ge=1, description="Page number")
class_id: int = Field(..., ge=0, le=9, description="Class ID")
class_name: str = Field(..., description="Class name")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
normalized_bbox: dict[str, float] = Field(
..., description="Normalized bbox (x_center, y_center, width, height)"
)
text_value: str | None = Field(None, description="Text value")
confidence: float | None = Field(None, ge=0, le=1, description="Confidence score")
source: AnnotationSource = Field(..., description="Annotation source")
created_at: datetime = Field(..., description="Creation timestamp")
class AnnotationResponse(BaseModel):
"""Response for annotation operation."""
annotation_id: str = Field(..., description="Annotation UUID")
message: str = Field(..., description="Status message")
class AnnotationListResponse(BaseModel):
"""Response for annotation list."""
document_id: str = Field(..., description="Document UUID")
page_count: int = Field(..., ge=1, description="Total pages")
total_annotations: int = Field(..., ge=0, description="Total annotations")
annotations: list[AnnotationItem] = Field(
default_factory=list, description="Annotation list"
)
class AnnotationLockRequest(BaseModel):
"""Request to acquire annotation lock."""
duration_seconds: int = Field(
default=300,
ge=60,
le=3600,
description="Lock duration in seconds (60-3600)",
)
class AnnotationLockResponse(BaseModel):
"""Response for annotation lock operation."""
document_id: str = Field(..., description="Document UUID")
locked: bool = Field(..., description="Whether lock was acquired/released")
lock_expires_at: datetime | None = Field(
None, description="Lock expiration time"
)
message: str = Field(..., description="Status message")
class AutoLabelRequest(BaseModel):
"""Request to trigger auto-labeling."""
field_values: dict[str, str] = Field(
...,
description="Field values to match (e.g., {'invoice_number': '12345'})",
)
replace_existing: bool = Field(
default=False, description="Replace existing auto annotations"
)
class AutoLabelResponse(BaseModel):
"""Response for auto-labeling."""
document_id: str = Field(..., description="Document UUID")
status: str = Field(..., description="Auto-labeling status")
annotations_created: int = Field(
default=0, ge=0, description="Number of annotations created"
)
message: str = Field(..., description="Status message")
class AnnotationVerifyRequest(BaseModel):
"""Request to verify an annotation."""
pass # No body needed, just POST to verify
class AnnotationVerifyResponse(BaseModel):
"""Response for annotation verification."""
annotation_id: str = Field(..., description="Annotation UUID")
is_verified: bool = Field(..., description="Verification status")
verified_at: datetime = Field(..., description="Verification timestamp")
verified_by: str = Field(..., description="Admin token who verified")
message: str = Field(..., description="Status message")
class AnnotationOverrideRequest(BaseModel):
"""Request to override an annotation."""
bbox: dict[str, int] | None = Field(
None, description="Updated bounding box {x, y, width, height}"
)
text_value: str | None = Field(None, description="Updated text value")
class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID")
class_name: str | None = Field(None, description="Updated class name")
reason: str | None = Field(None, description="Reason for override")
class AnnotationOverrideResponse(BaseModel):
"""Response for annotation override."""
annotation_id: str = Field(..., description="Annotation UUID")
source: str = Field(..., description="New source (manual)")
override_source: str | None = Field(None, description="Original source (auto)")
original_annotation_id: str | None = Field(None, description="Original annotation ID")
message: str = Field(..., description="Status message")
history_id: str = Field(..., description="History record UUID")
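
A sketch of a create payload these models validate; the values are illustrative and the models are the ones defined above, so no imports are shown.

payload = AnnotationCreate(
    page_number=1,
    class_id=0,
    bbox=BoundingBox(x=400, y=300, width=300, height=40),
    text_value="INV-2024-0042",
)
# Serialized, this is the JSON body a client would send; the server replies
# with an AnnotationResponse carrying the new annotation UUID.
print(payload.model_dump_json())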

View File

@@ -0,0 +1,23 @@
"""Admin Auth Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class AdminTokenCreate(BaseModel):
"""Request to create an admin token."""
name: str = Field(..., min_length=1, max_length=255, description="Token name")
expires_in_days: int | None = Field(
None, ge=1, le=365, description="Token expiration in days (optional)"
)
class AdminTokenResponse(BaseModel):
"""Response with created admin token."""
token: str = Field(..., description="Admin token")
name: str = Field(..., description="Token name")
expires_at: datetime | None = Field(None, description="Token expiration time")
message: str = Field(..., description="Status message")

View File

@@ -0,0 +1,85 @@
"""Admin Dataset Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .training import TrainingConfig
class DatasetCreateRequest(BaseModel):
"""Request to create a training dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
description: str | None = Field(None, description="Optional description")
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
seed: int = Field(42, description="Random seed for split")
class DatasetDocumentItem(BaseModel):
"""Document within a dataset."""
document_id: str
split: str
page_count: int
annotation_count: int
class DatasetResponse(BaseModel):
"""Response after creating a dataset."""
dataset_id: str
name: str
status: str
message: str
class DatasetDetailResponse(BaseModel):
"""Detailed dataset info with documents."""
dataset_id: str
name: str
description: str | None
status: str
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
documents: list[DatasetDocumentItem]
created_at: datetime
updated_at: datetime
class DatasetListItem(BaseModel):
"""Dataset in list view."""
dataset_id: str
name: str
description: str | None
status: str
total_documents: int
total_images: int
total_annotations: int
created_at: datetime
class DatasetListResponse(BaseModel):
"""Paginated dataset list."""
total: int
limit: int
offset: int
datasets: list[DatasetListItem]
class DatasetTrainRequest(BaseModel):
"""Request to start training from a dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Training task name")
config: TrainingConfig = Field(..., description="Training configuration")

View File

@@ -0,0 +1,103 @@
"""Admin Document Schemas."""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from .enums import AutoLabelStatus, DocumentStatus
if TYPE_CHECKING:
from .annotations import AnnotationItem
from .training import TrainingHistoryItem
class DocumentUploadResponse(BaseModel):
"""Response for document upload."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
message: str = Field(..., description="Status message")
class DocumentItem(BaseModel):
"""Single document in list."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentListResponse(BaseModel):
"""Response for document list."""
total: int = Field(..., ge=0, description="Total documents")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
documents: list[DocumentItem] = Field(
default_factory=list, description="Document list"
)
class DocumentDetailResponse(BaseModel):
"""Response for document detail."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
content_type: str = Field(..., description="MIME type")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
annotation_lock_until: datetime | None = Field(
None, description="Lock expiration time if document is locked"
)
annotations: list["AnnotationItem"] = Field(
default_factory=list, description="Document annotations"
)
image_urls: list[str] = Field(
default_factory=list, description="URLs to page images"
)
training_history: list["TrainingHistoryItem"] = Field(
default_factory=list, description="Training tasks that used this document"
)
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentStatsResponse(BaseModel):
"""Document statistics response."""
total: int = Field(..., ge=0, description="Total documents")
pending: int = Field(default=0, ge=0, description="Pending documents")
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
labeled: int = Field(default=0, ge=0, description="Labeled documents")
exported: int = Field(default=0, ge=0, description="Exported documents")

View File

@@ -0,0 +1,46 @@
"""Admin API Enums."""
from enum import Enum
class DocumentStatus(str, Enum):
"""Document status enum."""
PENDING = "pending"
AUTO_LABELING = "auto_labeling"
LABELED = "labeled"
EXPORTED = "exported"
class AutoLabelStatus(str, Enum):
"""Auto-labeling status enum."""
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class TrainingStatus(str, Enum):
"""Training task status enum."""
PENDING = "pending"
SCHEDULED = "scheduled"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TrainingType(str, Enum):
"""Training task type enum."""
TRAIN = "train"
FINETUNE = "finetune"
class AnnotationSource(str, Enum):
"""Annotation source enum."""
MANUAL = "manual"
AUTO = "auto"
IMPORTED = "imported"

View File

@@ -0,0 +1,202 @@
"""Admin Training Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name")
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate")
device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
project_name: str = Field(
default="invoice_fields", description="Training project name"
)
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""
name: str = Field(..., min_length=1, max_length=255, description="Task name")
description: str | None = Field(None, max_length=1000, description="Description")
task_type: TrainingType = Field(
default=TrainingType.TRAIN, description="Task type"
)
config: TrainingConfig = Field(
default_factory=TrainingConfig, description="Training configuration"
)
scheduled_at: datetime | None = Field(
None, description="Scheduled execution time"
)
cron_expression: str | None = Field(
None, max_length=50, description="Cron expression for recurring tasks"
)
class TrainingTaskItem(BaseModel):
"""Single training task in list."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskListResponse(BaseModel):
"""Response for training task list."""
total: int = Field(..., ge=0, description="Total tasks")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list")
class TrainingTaskDetailResponse(BaseModel):
"""Response for training task detail."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
description: str | None = Field(None, description="Description")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
config: dict[str, Any] | None = Field(None, description="Training configuration")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
cron_expression: str | None = Field(None, description="Cron expression")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
error_message: str | None = Field(None, description="Error message")
result_metrics: dict[str, Any] | None = Field(None, description="Result metrics")
model_path: str | None = Field(None, description="Trained model path")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskResponse(BaseModel):
"""Response for training task operation."""
task_id: str = Field(..., description="Task UUID")
status: TrainingStatus = Field(..., description="Task status")
message: str = Field(..., description="Status message")
class TrainingLogItem(BaseModel):
"""Single training log entry."""
level: str = Field(..., description="Log level")
message: str = Field(..., description="Log message")
details: dict[str, Any] | None = Field(None, description="Additional details")
created_at: datetime = Field(..., description="Timestamp")
class TrainingLogsResponse(BaseModel):
"""Response for training logs."""
task_id: str = Field(..., description="Task UUID")
logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries")
class ExportRequest(BaseModel):
"""Request to export annotations."""
format: str = Field(
default="yolo", description="Export format (yolo, coco, voc)"
)
include_images: bool = Field(
default=True, description="Include images in export"
)
split_ratio: float = Field(
default=0.8, ge=0.5, le=1.0, description="Train/val split ratio"
)
class ExportResponse(BaseModel):
"""Response for export operation."""
status: str = Field(..., description="Export status")
export_path: str = Field(..., description="Path to exported dataset")
total_images: int = Field(..., ge=0, description="Total images exported")
total_annotations: int = Field(..., ge=0, description="Total annotations")
train_count: int = Field(..., ge=0, description="Training set count")
val_count: int = Field(..., ge=0, description="Validation set count")
message: str = Field(..., description="Status message")
class TrainingDocumentItem(BaseModel):
"""Document item for training page."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Filename")
annotation_count: int = Field(..., ge=0, description="Total annotations")
annotation_sources: dict[str, int] = Field(
..., description="Annotation counts by source (manual, auto)"
)
used_in_training: list[str] = Field(
default_factory=list, description="List of training task IDs that used this document"
)
last_modified: datetime = Field(..., description="Last modification time")
class TrainingDocumentsResponse(BaseModel):
"""Response for GET /admin/training/documents."""
total: int = Field(..., ge=0, description="Total document count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
documents: list[TrainingDocumentItem] = Field(
default_factory=list, description="Documents available for training"
)
class ModelMetrics(BaseModel):
"""Training model metrics."""
mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
class TrainingModelItem(BaseModel):
"""Trained model item for model list."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Model name")
status: TrainingStatus = Field(..., description="Training status")
document_count: int = Field(..., ge=0, description="Documents used in training")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")
metrics: ModelMetrics = Field(..., description="Model metrics")
model_path: str | None = Field(None, description="Path to model weights")
download_url: str | None = Field(None, description="Download URL for model")
class TrainingModelsResponse(BaseModel):
"""Response for GET /admin/training/models."""
total: int = Field(..., ge=0, description="Total model count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
models: list[TrainingModelItem] = Field(
default_factory=list, description="Trained models"
)
class TrainingHistoryItem(BaseModel):
"""Training history for a document."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Training task name")
trained_at: datetime = Field(..., description="Training timestamp")
model_metrics: ModelMetrics | None = Field(None, description="Model metrics")
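
A sketch of the payload TrainingTaskCreate validates, built from the models and enums defined above. All values are illustrative.

task = TrainingTaskCreate(
    name="invoice-fields-nightly",
    description="Retrain on all labeled documents",
    task_type=TrainingType.FINETUNE,
    config=TrainingConfig(epochs=50, batch_size=8, device="0"),
    cron_expression="0 2 * * *",  # nightly at 02:00
)
# model_dump_json() gives the JSON body a client would POST; the remaining
# TrainingConfig fields are filled from their declared defaults.
print(task.model_dump_json(indent=2))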

View File

@@ -0,0 +1,15 @@
"""
Common Schemas
Shared Pydantic models used across multiple API modules.
"""
from pydantic import BaseModel, Field
class ErrorResponse(BaseModel):
"""Error response."""
status: str = Field(default="error", description="Error status")
message: str = Field(..., description="Error message")
detail: str | None = Field(None, description="Detailed error information")

View File

@@ -0,0 +1,196 @@
"""
API Request/Response Schemas
Pydantic models for API validation and serialization.
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
# =============================================================================
# Enums
# =============================================================================
class AsyncStatus(str, Enum):
"""Async request status enum."""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
# =============================================================================
# Sync API Schemas (existing)
# =============================================================================
class DetectionResult(BaseModel):
"""Single detection result."""
field: str = Field(..., description="Field type (e.g., invoice_number, amount)")
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
class ExtractedField(BaseModel):
"""Extracted and normalized field value."""
field_name: str = Field(..., description="Field name")
value: str | None = Field(None, description="Extracted value")
confidence: float = Field(..., ge=0, le=1, description="Extraction confidence")
is_valid: bool = Field(True, description="Whether the value passed validation")
class InferenceResult(BaseModel):
"""Complete inference result for a document."""
document_id: str = Field(..., description="Document identifier")
success: bool = Field(..., description="Whether inference succeeded")
document_type: str = Field(
default="invoice", description="Document type: 'invoice' or 'letter'"
)
fields: dict[str, str | None] = Field(
default_factory=dict, description="Extracted field values"
)
confidence: dict[str, float] = Field(
default_factory=dict, description="Confidence scores per field"
)
detections: list[DetectionResult] = Field(
default_factory=list, description="Raw YOLO detections"
)
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
visualization_url: str | None = Field(
None, description="URL to visualization image"
)
errors: list[str] = Field(default_factory=list, description="Error messages")
class InferenceResponse(BaseModel):
"""API response for inference endpoint."""
status: str = Field(..., description="Response status: success or error")
message: str = Field(..., description="Response message")
result: InferenceResult | None = Field(None, description="Inference result")
class BatchInferenceResponse(BaseModel):
"""API response for batch inference endpoint."""
status: str = Field(..., description="Response status")
message: str = Field(..., description="Response message")
total: int = Field(..., description="Total documents processed")
successful: int = Field(..., description="Number of successful extractions")
results: list[InferenceResult] = Field(
default_factory=list, description="Individual results"
)
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Service status")
model_loaded: bool = Field(..., description="Whether model is loaded")
gpu_available: bool = Field(..., description="Whether GPU is available")
version: str = Field(..., description="API version")
class ErrorResponse(BaseModel):
"""Error response."""
status: str = Field(default="error", description="Error status")
message: str = Field(..., description="Error message")
detail: str | None = Field(None, description="Detailed error information")
# =============================================================================
# Async API Schemas
# =============================================================================
class AsyncSubmitResponse(BaseModel):
"""Response for async submit endpoint."""
status: str = Field(default="accepted", description="Response status")
message: str = Field(..., description="Response message")
request_id: str = Field(..., description="Unique request identifier (UUID)")
estimated_wait_seconds: int = Field(
..., ge=0, description="Estimated wait time in seconds"
)
poll_url: str = Field(..., description="URL to poll for status updates")
class AsyncStatusResponse(BaseModel):
"""Response for async status endpoint."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Current processing status")
filename: str = Field(..., description="Original filename")
created_at: datetime = Field(..., description="Request creation timestamp")
started_at: datetime | None = Field(
None, description="Processing start timestamp"
)
completed_at: datetime | None = Field(
None, description="Processing completion timestamp"
)
position_in_queue: int | None = Field(
None, description="Position in queue (for pending status)"
)
error_message: str | None = Field(
None, description="Error message (for failed status)"
)
result_url: str | None = Field(
None, description="URL to fetch results (for completed status)"
)
class AsyncResultResponse(BaseModel):
"""Response for async result endpoint."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Processing status")
processing_time_ms: float = Field(
..., ge=0, description="Total processing time in milliseconds"
)
result: InferenceResult | None = Field(
None, description="Extraction result (when completed)"
)
visualization_url: str | None = Field(
None, description="URL to visualization image"
)
class AsyncRequestItem(BaseModel):
"""Single item in async requests list."""
request_id: str = Field(..., description="Unique request identifier")
status: AsyncStatus = Field(..., description="Current processing status")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
created_at: datetime = Field(..., description="Request creation timestamp")
completed_at: datetime | None = Field(
None, description="Processing completion timestamp"
)
class AsyncRequestsListResponse(BaseModel):
"""Response for async requests list endpoint."""
total: int = Field(..., ge=0, description="Total number of requests")
limit: int = Field(..., ge=1, description="Maximum items per page")
offset: int = Field(..., ge=0, description="Current offset")
requests: list[AsyncRequestItem] = Field(
default_factory=list, description="List of requests"
)
class RateLimitInfo(BaseModel):
"""Rate limit information (included in headers)."""
limit: int = Field(..., description="Maximum requests per minute")
remaining: int = Field(..., description="Remaining requests in current window")
reset_at: datetime = Field(..., description="Time when limit resets")
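
A client-side sketch of the submit/poll/result flow these schemas describe. The base URL, endpoint path, API key, local file, and the httpx library are assumptions for illustration; the X-API-Key header and Retry-After handling follow the dependency and rate-limiter modules earlier in this change, and the JSON field names come from the response models above.

import time

import httpx  # illustrative client library; not a dependency declared in this commit

BASE_URL = "http://localhost:8000"       # illustrative
HEADERS = {"X-API-Key": "your-api-key"}  # header name from the dependency module

with httpx.Client(base_url=BASE_URL, headers=HEADERS, timeout=30.0) as client:
    # Submit a document for background processing (endpoint path is illustrative).
    with open("invoice.pdf", "rb") as fh:
        submit = client.post(
            "/api/v1/async/submit",
            files={"file": ("invoice.pdf", fh, "application/pdf")},
        )
    submit.raise_for_status()
    accepted = submit.json()  # AsyncSubmitResponse
    poll_url = accepted["poll_url"]

    # Poll until the request is completed or failed, honoring Retry-After on 429s.
    while True:
        status = client.get(poll_url)
        if status.status_code == 429:
            time.sleep(int(status.headers.get("Retry-After", "1")))
            continue
        status.raise_for_status()
        body = status.json()  # AsyncStatusResponse
        if body["status"] in ("completed", "failed"):
            break
        time.sleep(2)

    if body["status"] == "completed":
        result = client.get(body["result_url"]).json()  # AsyncResultResponse
        print(result["result"]["fields"])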

View File

@@ -0,0 +1,13 @@
"""
Labeling API Schemas
Pydantic models for pre-labeling and label validation endpoints.
"""
from pydantic import BaseModel, Field
class PreLabelResponse(BaseModel):
"""API response for pre-label endpoint."""
document_id: str = Field(..., description="Document identifier for retrieving results")

View File

@@ -0,0 +1,18 @@
"""
Business Logic Services
Service layer for processing requests and orchestrating data operations.
"""
from inference.web.services.autolabel import AutoLabelService, get_auto_label_service
from inference.web.services.inference import InferenceService
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.services.batch_upload import BatchUploadService
__all__ = [
"AutoLabelService",
"get_auto_label_service",
"InferenceService",
"AsyncProcessingService",
"BatchUploadService",
]

View File

@@ -0,0 +1,383 @@
"""
Async Processing Service
Manages async request lifecycle and background processing.
"""
import logging
import shutil
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from threading import Event, Thread
from typing import TYPE_CHECKING
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.core.rate_limiter import RateLimiter
if TYPE_CHECKING:
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.services.inference import InferenceService
logger = logging.getLogger(__name__)
@dataclass
class AsyncSubmitResult:
"""Result from async submit operation."""
success: bool
request_id: str | None = None
estimated_wait_seconds: int = 0
error: str | None = None
class AsyncProcessingService:
"""
Manages async request lifecycle and processing.
Coordinates between:
- HTTP endpoints (submit/status/result)
- Background task queue
- Database storage
- Rate limiting
"""
def __init__(
self,
inference_service: "InferenceService",
db: AsyncRequestDB,
queue: AsyncTaskQueue,
rate_limiter: RateLimiter,
async_config: "AsyncConfig",
storage_config: "StorageConfig",
) -> None:
self._inference = inference_service
self._db = db
self._queue = queue
self._rate_limiter = rate_limiter
self._async_config = async_config
self._storage_config = storage_config
# Cleanup thread
self._cleanup_stop_event = Event()
self._cleanup_thread: Thread | None = None
def start(self) -> None:
"""Start the async processing service."""
# Start the task queue with our handler
self._queue.start(self._process_task)
# Start cleanup thread
self._cleanup_stop_event.clear()
self._cleanup_thread = Thread(
target=self._cleanup_loop,
name="async-cleanup",
daemon=True,
)
self._cleanup_thread.start()
logger.info("AsyncProcessingService started")
def stop(self, timeout: float = 30.0) -> None:
"""Stop the async processing service."""
# Stop cleanup thread
self._cleanup_stop_event.set()
if self._cleanup_thread and self._cleanup_thread.is_alive():
self._cleanup_thread.join(timeout=5.0)
# Stop task queue
self._queue.stop(timeout=timeout)
logger.info("AsyncProcessingService stopped")
def submit_request(
self,
api_key: str,
file_content: bytes,
filename: str,
content_type: str,
) -> AsyncSubmitResult:
"""
Submit a new async processing request.
Args:
api_key: API key for the request
file_content: File content as bytes
filename: Original filename
content_type: File content type
Returns:
AsyncSubmitResult with request_id and status
"""
# Generate request ID
request_id = str(uuid.uuid4())
# Save file to temp storage
file_path = self._save_upload(request_id, filename, file_content)
file_size = len(file_content)
try:
# Calculate expiration
expires_at = datetime.utcnow() + timedelta(
days=self._async_config.result_retention_days
)
# Create database record
self._db.create_request(
api_key=api_key,
filename=filename,
file_size=file_size,
content_type=content_type,
expires_at=expires_at,
request_id=request_id,
)
# Record rate limit event
self._rate_limiter.record_request(api_key)
# Create and queue task
task = AsyncTask(
request_id=request_id,
api_key=api_key,
file_path=file_path,
filename=filename,
created_at=datetime.utcnow(),
)
if not self._queue.submit(task):
# Queue is full
self._db.update_status(
request_id,
"failed",
error_message="Processing queue is full",
)
# Cleanup file
file_path.unlink(missing_ok=True)
return AsyncSubmitResult(
success=False,
request_id=request_id,
error="Processing queue is full. Please try again later.",
)
# Estimate wait time
estimated_wait = self._estimate_wait()
return AsyncSubmitResult(
success=True,
request_id=request_id,
estimated_wait_seconds=estimated_wait,
)
except Exception as e:
logger.error(f"Failed to submit request: {e}", exc_info=True)
# Cleanup file on error
file_path.unlink(missing_ok=True)
return AsyncSubmitResult(
success=False,
# Return generic error message to avoid leaking implementation details
error="Failed to process request. Please try again later.",
)
# Allowed file extensions whitelist
ALLOWED_EXTENSIONS = frozenset({".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"})
def _save_upload(
self,
request_id: str,
filename: str,
content: bytes,
) -> Path:
"""Save uploaded file to temp storage."""
import re
# Extract extension from filename
ext = Path(filename).suffix.lower()
# Validate extension: must be alphanumeric only (e.g., .pdf, .png)
if not ext or not re.match(r'^\.[a-z0-9]+$', ext):
ext = ".pdf"
# Validate against whitelist
if ext not in self.ALLOWED_EXTENSIONS:
ext = ".pdf"
# Create async upload directory
upload_dir = self._async_config.temp_upload_dir
upload_dir.mkdir(parents=True, exist_ok=True)
# Build file path - request_id is a UUID so it's safe
file_path = upload_dir / f"{request_id}{ext}"
# Defense in depth: ensure path is within upload_dir
if not file_path.resolve().is_relative_to(upload_dir.resolve()):
raise ValueError("Invalid file path detected")
file_path.write_bytes(content)
return file_path
def _process_task(self, task: AsyncTask) -> None:
"""
Process a single task (called by worker thread).
This method is called by the AsyncTaskQueue worker threads.
"""
start_time = time.time()
try:
# Update status to processing
self._db.update_status(task.request_id, "processing")
# Ensure file exists
if not task.file_path.exists():
raise FileNotFoundError(f"Upload file not found: {task.file_path}")
# Run inference based on file type
file_ext = task.file_path.suffix.lower()
if file_ext == ".pdf":
result = self._inference.process_pdf(
task.file_path,
document_id=task.request_id[:8],
)
else:
result = self._inference.process_image(
task.file_path,
document_id=task.request_id[:8],
)
# Calculate processing time
processing_time_ms = (time.time() - start_time) * 1000
# Prepare result for storage
result_data = {
"document_id": result.document_id,
"success": result.success,
"document_type": result.document_type,
"fields": result.fields,
"confidence": result.confidence,
"detections": result.detections,
"errors": result.errors,
}
# Get visualization path as string
viz_path = None
if result.visualization_path:
viz_path = str(result.visualization_path.name)
# Store result in database
self._db.complete_request(
request_id=task.request_id,
document_id=result.document_id,
result=result_data,
processing_time_ms=processing_time_ms,
visualization_path=viz_path,
)
logger.info(
f"Task {task.request_id} completed successfully "
f"in {processing_time_ms:.0f}ms"
)
except Exception as e:
logger.error(
f"Task {task.request_id} failed: {e}",
exc_info=True,
)
self._db.update_status(
task.request_id,
"failed",
error_message=str(e),
increment_retry=True,
)
finally:
# Cleanup uploaded file
if task.file_path.exists():
task.file_path.unlink(missing_ok=True)
def _estimate_wait(self) -> int:
"""Estimate wait time based on queue depth."""
queue_depth = self._queue.get_queue_depth()
processing_count = self._queue.get_processing_count()
total_pending = queue_depth + processing_count
# Estimate ~5 seconds per document
avg_processing_time = 5
return total_pending * avg_processing_time
def _cleanup_loop(self) -> None:
"""Background cleanup loop."""
logger.info("Cleanup thread started")
cleanup_interval = self._async_config.cleanup_interval_hours * 3600
while not self._cleanup_stop_event.wait(timeout=cleanup_interval):
try:
self._run_cleanup()
except Exception as e:
logger.error(f"Cleanup failed: {e}", exc_info=True)
logger.info("Cleanup thread stopped")
def _run_cleanup(self) -> None:
"""Run cleanup operations."""
logger.info("Running cleanup...")
# Delete expired requests
deleted_requests = self._db.delete_expired_requests()
# Reset stale processing requests
reset_count = self._db.reset_stale_processing_requests(
stale_minutes=self._async_config.task_timeout_seconds // 60,
max_retries=3,
)
# Cleanup old rate limit events
deleted_events = self._db.cleanup_old_rate_limit_events(hours=1)
# Cleanup old poll timestamps
cleaned_polls = self._rate_limiter.cleanup_poll_timestamps()
# Cleanup rate limiter request windows
self._rate_limiter.cleanup_request_windows()
# Cleanup orphaned upload files
orphan_count = self._cleanup_orphan_files()
logger.info(
f"Cleanup complete: {deleted_requests} expired requests, "
f"{reset_count} stale requests reset, "
f"{deleted_events} rate limit events, "
f"{cleaned_polls} poll timestamps, "
f"{orphan_count} orphan files"
)
def _cleanup_orphan_files(self) -> int:
"""Clean up upload files that don't have matching requests."""
upload_dir = self._async_config.temp_upload_dir
if not upload_dir.exists():
return 0
count = 0
        # Files older than 1 hour without a matching request are considered orphans
cutoff = time.time() - 3600
for file_path in upload_dir.iterdir():
if not file_path.is_file():
continue
# Check if file is old enough
if file_path.stat().st_mtime > cutoff:
continue
# Extract request_id from filename
request_id = file_path.stem
# Check if request exists in database
request = self._db.get_request(request_id)
if request is None:
file_path.unlink(missing_ok=True)
count += 1
return count

View File

@@ -0,0 +1,335 @@
"""
Admin Auto-Labeling Service
Uses FieldMatcher to automatically create annotations from field values.
"""
import logging
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken
logger = logging.getLogger(__name__)
class AutoLabelService:
"""Service for automatic document labeling using field matching."""
def __init__(self, ocr_engine: OCREngine | None = None):
"""
Initialize auto-label service.
Args:
ocr_engine: OCR engine instance (creates one if not provided)
"""
self._ocr_engine = ocr_engine
self._field_matcher = FieldMatcher()
@property
def ocr_engine(self) -> OCREngine:
"""Lazy initialization of OCR engine."""
if self._ocr_engine is None:
self._ocr_engine = OCREngine(lang="en")
return self._ocr_engine
def auto_label_document(
self,
document_id: str,
file_path: str,
field_values: dict[str, str],
db: AdminDB,
replace_existing: bool = False,
skip_lock_check: bool = False,
) -> dict[str, Any]:
"""
Auto-label a document using field matching.
Args:
document_id: Document UUID
file_path: Path to document file
field_values: Dict of field_name -> value to match
db: Admin database instance
replace_existing: Whether to replace existing auto annotations
skip_lock_check: Skip annotation lock check (for batch processing)
Returns:
Dict with status and annotation count
"""
try:
# Get document info first
document = db.get_document(document_id)
if document is None:
raise ValueError(f"Document not found: {document_id}")
# Check annotation lock unless explicitly skipped
if not skip_lock_check:
from datetime import datetime, timezone
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
if document.annotation_lock_until > datetime.now(timezone.utc):
raise ValueError(
f"Document is locked for annotation until {document.annotation_lock_until}. "
"Auto-labeling skipped."
)
# Update status to running
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
# Delete existing auto annotations if requested
if replace_existing:
deleted = db.delete_annotations_for_document(
document_id=document_id,
source="auto",
)
logger.info(f"Deleted {deleted} existing auto annotations")
# Process document
path = Path(file_path)
annotations_created = 0
if path.suffix.lower() == ".pdf":
# Process PDF (all pages)
annotations_created = self._process_pdf(
document_id, path, field_values, db
)
else:
# Process single image
annotations_created = self._process_image(
document_id, path, field_values, db, page_number=1
)
# Update document status
status = "labeled" if annotations_created > 0 else "pending"
db.update_document_status(
document_id=document_id,
status=status,
auto_label_status="completed",
)
return {
"status": "completed",
"annotations_created": annotations_created,
}
except Exception as e:
logger.error(f"Auto-labeling failed for {document_id}: {e}")
db.update_document_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error=str(e),
)
return {
"status": "failed",
"error": str(e),
"annotations_created": 0,
}
def _process_pdf(
self,
document_id: str,
pdf_path: Path,
field_values: dict[str, str],
db: AdminDB,
) -> int:
"""Process PDF document and create annotations."""
from shared.pdf.renderer import render_pdf_to_images
import io
total_annotations = 0
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
# Convert to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Extract tokens
tokens = self.ocr_engine.extract_from_image(
image_array,
page_no=page_no,
)
# Find matches
annotations = self._find_annotations(
document_id,
tokens,
field_values,
page_number=page_no + 1, # 1-indexed
image_width=image_array.shape[1],
image_height=image_array.shape[0],
)
# Save annotations
if annotations:
db.create_annotations_batch(annotations)
total_annotations += len(annotations)
return total_annotations
def _process_image(
self,
document_id: str,
image_path: Path,
field_values: dict[str, str],
db: AdminDB,
page_number: int = 1,
) -> int:
"""Process single image and create annotations."""
# Load image
image = Image.open(image_path)
image_array = np.array(image)
# Extract tokens
tokens = self.ocr_engine.extract_from_image(
image_array,
page_no=0,
)
# Find matches
annotations = self._find_annotations(
document_id,
tokens,
field_values,
page_number=page_number,
image_width=image_array.shape[1],
image_height=image_array.shape[0],
)
# Save annotations
if annotations:
db.create_annotations_batch(annotations)
return len(annotations)
def _find_annotations(
self,
document_id: str,
tokens: list[OCRToken],
field_values: dict[str, str],
page_number: int,
image_width: int,
image_height: int,
) -> list[dict[str, Any]]:
"""Find annotations for field values using token matching."""
from shared.normalize import normalize_field
annotations = []
for field_name, value in field_values.items():
if not value or not value.strip():
continue
# Map field name to class ID
class_id = self._get_class_id(field_name)
if class_id is None:
logger.warning(f"Unknown field name: {field_name}")
continue
class_name = FIELD_CLASSES[class_id]
# Normalize value
try:
normalized_values = normalize_field(field_name, value)
except Exception as e:
logger.warning(f"Failed to normalize {field_name}={value}: {e}")
normalized_values = [value]
# Find matches
matches = self._field_matcher.find_matches(
tokens=tokens,
field_name=field_name,
normalized_values=normalized_values,
page_no=page_number - 1, # 0-indexed for matcher
)
# Take best match
if matches:
best_match = matches[0]
bbox = best_match.bbox # (x0, y0, x1, y1)
# Calculate normalized coordinates (YOLO format)
x_center = (bbox[0] + bbox[2]) / 2 / image_width
y_center = (bbox[1] + bbox[3]) / 2 / image_height
width = (bbox[2] - bbox[0]) / image_width
height = (bbox[3] - bbox[1]) / image_height
# Pixel coordinates
bbox_x = int(bbox[0])
bbox_y = int(bbox[1])
bbox_width = int(bbox[2] - bbox[0])
bbox_height = int(bbox[3] - bbox[1])
annotations.append({
"document_id": document_id,
"page_number": page_number,
"class_id": class_id,
"class_name": class_name,
"x_center": x_center,
"y_center": y_center,
"width": width,
"height": height,
"bbox_x": bbox_x,
"bbox_y": bbox_y,
"bbox_width": bbox_width,
"bbox_height": bbox_height,
"text_value": best_match.matched_value,
"confidence": best_match.score,
"source": "auto",
})
return annotations
def _get_class_id(self, field_name: str) -> int | None:
"""Map field name to class ID."""
# Direct match
if field_name in FIELD_CLASS_IDS:
return FIELD_CLASS_IDS[field_name]
# Handle alternative names
name_mapping = {
"InvoiceNumber": "invoice_number",
"InvoiceDate": "invoice_date",
"InvoiceDueDate": "invoice_due_date",
"OCR": "ocr_number",
"Bankgiro": "bankgiro",
"Plusgiro": "plusgiro",
"Amount": "amount",
"supplier_organisation_number": "supplier_organisation_number",
"PaymentLine": "payment_line",
"customer_number": "customer_number",
}
mapped_name = name_mapping.get(field_name)
if mapped_name and mapped_name in FIELD_CLASS_IDS:
return FIELD_CLASS_IDS[mapped_name]
return None
# Global service instance
_auto_label_service: AutoLabelService | None = None
def get_auto_label_service() -> AutoLabelService:
"""Get the auto-label service instance."""
global _auto_label_service
if _auto_label_service is None:
_auto_label_service = AutoLabelService()
return _auto_label_service
def reset_auto_label_service() -> None:
"""Reset the auto-label service (for testing)."""
global _auto_label_service
_auto_label_service = None
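
A worked numeric example of the pixel-to-YOLO conversion performed in _find_annotations above. The page size and box coordinates are illustrative.

# A matched token box (x0, y0, x1, y1) = (400, 300, 700, 340) on a page
# rendered at 1654 x 2339 pixels:
image_width, image_height = 1654, 2339
x0, y0, x1, y1 = 400, 300, 700, 340

x_center = (x0 + x1) / 2 / image_width    # 550 / 1654  ~= 0.3325
y_center = (y0 + y1) / 2 / image_height   # 320 / 2339  ~= 0.1368
width = (x1 - x0) / image_width           # 300 / 1654  ~= 0.1814
height = (y1 - y0) / image_height         # 40 / 2339   ~= 0.0171

# Together with the class id, these four normalized values are exactly what the
# dataset export writes on each line of a YOLO label file.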

View File

@@ -0,0 +1,548 @@
"""
Batch Upload Service
Handles ZIP file uploads with multiple PDFs and optional CSV for auto-labeling.
"""
import csv
import io
import logging
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from inference.data.admin_db import AdminDB
from inference.data.admin_models import CSV_TO_CLASS_MAPPING
logger = logging.getLogger(__name__)
# Security limits
MAX_COMPRESSED_SIZE = 100 * 1024 * 1024 # 100 MB
MAX_UNCOMPRESSED_SIZE = 200 * 1024 * 1024 # 200 MB
MAX_INDIVIDUAL_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
MAX_FILES_IN_ZIP = 1000
class CSVRowData(BaseModel):
"""Validated CSV row data with security checks."""
document_id: str = Field(..., min_length=1, max_length=255, pattern=r'^[a-zA-Z0-9\-_\.]+$')
invoice_number: str | None = Field(None, max_length=255)
invoice_date: str | None = Field(None, max_length=50)
invoice_due_date: str | None = Field(None, max_length=50)
amount: str | None = Field(None, max_length=100)
ocr: str | None = Field(None, max_length=100)
bankgiro: str | None = Field(None, max_length=50)
plusgiro: str | None = Field(None, max_length=50)
customer_number: str | None = Field(None, max_length=255)
supplier_organisation_number: str | None = Field(None, max_length=50)
@field_validator('*', mode='before')
@classmethod
def strip_whitespace(cls, v):
"""Strip whitespace from all string fields."""
if isinstance(v, str):
return v.strip()
return v
@field_validator('*', mode='before')
@classmethod
def reject_suspicious_patterns(cls, v):
"""Reject values with suspicious characters."""
if isinstance(v, str):
# Reject SQL/shell metacharacters and newlines
dangerous_chars = [';', '|', '&', '`', '$', '\n', '\r', '\x00']
if any(char in v for char in dangerous_chars):
raise ValueError(f"Suspicious characters detected in value")
return v
class BatchUploadService:
"""Service for handling batch uploads of documents via ZIP files."""
def __init__(self, admin_db: AdminDB):
"""Initialize the batch upload service.
Args:
admin_db: Admin database interface
"""
self.admin_db = admin_db
def _safe_extract_filename(self, zip_path: str) -> str:
"""Safely extract filename from ZIP path, preventing path traversal.
Args:
zip_path: Path from ZIP file entry
Returns:
Safe filename
Raises:
ValueError: If path contains traversal attempts or is invalid
"""
# Reject absolute paths
if zip_path.startswith('/') or zip_path.startswith('\\'):
raise ValueError(f"Absolute path rejected: {zip_path}")
# Reject path traversal attempts
if '..' in zip_path:
raise ValueError(f"Path traversal rejected: {zip_path}")
# Reject Windows drive letters
if len(zip_path) >= 2 and zip_path[1] == ':':
raise ValueError(f"Windows path rejected: {zip_path}")
# Get only the basename
safe_name = Path(zip_path).name
if not safe_name or safe_name in ['.', '..']:
raise ValueError(f"Invalid filename: {zip_path}")
# Validate filename doesn't contain suspicious characters
if any(char in safe_name for char in ['\\', '/', '\x00', '\n', '\r']):
raise ValueError(f"Invalid characters in filename: {safe_name}")
return safe_name
def _validate_zip_safety(self, zip_file: zipfile.ZipFile) -> None:
"""Validate ZIP file against Zip bomb and other attacks.
Args:
zip_file: Opened ZIP file
Raises:
ValueError: If ZIP file is unsafe
"""
total_uncompressed = 0
file_count = 0
for zip_info in zip_file.infolist():
file_count += 1
# Check file count limit
if file_count > MAX_FILES_IN_ZIP:
raise ValueError(
f"ZIP contains too many files (max {MAX_FILES_IN_ZIP})"
)
# Check individual file size
if zip_info.file_size > MAX_INDIVIDUAL_FILE_SIZE:
max_mb = MAX_INDIVIDUAL_FILE_SIZE / (1024 * 1024)
raise ValueError(
f"File '{zip_info.filename}' exceeds {max_mb:.0f}MB limit"
)
# Accumulate uncompressed size
total_uncompressed += zip_info.file_size
# Check total uncompressed size (Zip bomb protection)
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
max_mb = MAX_UNCOMPRESSED_SIZE / (1024 * 1024)
raise ValueError(
f"Total uncompressed size exceeds {max_mb:.0f}MB limit"
)
# Validate filename safety
try:
self._safe_extract_filename(zip_info.filename)
except ValueError as e:
logger.warning(f"Rejecting malicious ZIP entry: {e}")
raise ValueError(f"Invalid file in ZIP: {zip_info.filename}")
def process_zip_upload(
self,
admin_token: str,
zip_filename: str,
zip_content: bytes,
upload_source: str = "ui",
) -> dict[str, Any]:
"""Process a ZIP file containing PDFs and optional CSV.
Args:
admin_token: Admin authentication token
zip_filename: Name of the ZIP file
zip_content: ZIP file content as bytes
upload_source: Upload source (ui or api)
Returns:
Dictionary with batch upload results
"""
batch = self.admin_db.create_batch_upload(
admin_token=admin_token,
filename=zip_filename,
file_size=len(zip_content),
upload_source=upload_source,
)
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
# Validate ZIP safety first
self._validate_zip_safety(zip_file)
result = self._process_zip_contents(
batch_id=batch.batch_id,
admin_token=admin_token,
zip_file=zip_file,
)
# Update batch upload status
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status=result["status"],
total_files=result["total_files"],
processed_files=result["processed_files"],
successful_files=result["successful_files"],
failed_files=result["failed_files"],
csv_filename=result.get("csv_filename"),
csv_row_count=result.get("csv_row_count"),
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
**result,
}
except zipfile.BadZipFile as e:
logger.error(f"Invalid ZIP file {zip_filename}: {e}")
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status="failed",
error_message="Invalid ZIP file format",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "Invalid ZIP file format",
}
except ValueError as e:
# Security validation errors
logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status="failed",
error_message="ZIP file validation failed",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "ZIP file validation failed",
}
except Exception as e:
logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
self.admin_db.update_batch_upload(
batch_id=batch.batch_id,
status="failed",
error_message="Processing error",
completed_at=datetime.utcnow(),
)
return {
"batch_id": str(batch.batch_id),
"status": "failed",
"error": "Failed to process batch upload",
}
def _process_zip_contents(
self,
batch_id: UUID,
admin_token: str,
zip_file: zipfile.ZipFile,
) -> dict[str, Any]:
"""Process contents of ZIP file.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
zip_file: Opened ZIP file
Returns:
Processing results dictionary
"""
# Extract file lists
pdf_files = []
csv_file = None
csv_data = {}
for file_info in zip_file.filelist:
if file_info.is_dir():
continue
try:
# Use safe filename extraction
filename = self._safe_extract_filename(file_info.filename)
except ValueError as e:
logger.warning(f"Skipping invalid file: {e}")
continue
if filename.lower().endswith('.pdf'):
pdf_files.append(file_info)
elif filename.lower().endswith('.csv'):
if csv_file is None:
csv_file = file_info
# Parse CSV
csv_data = self._parse_csv_file(zip_file, file_info)
else:
logger.warning(f"Multiple CSV files found, using first: {csv_file.filename}")
if not pdf_files:
return {
"status": "failed",
"total_files": 0,
"processed_files": 0,
"successful_files": 0,
"failed_files": 0,
"error": "No PDF files found in ZIP",
}
# Process each PDF file
total_files = len(pdf_files)
successful_files = 0
failed_files = 0
for pdf_info in pdf_files:
file_record = None
try:
# Use safe filename extraction
filename = self._safe_extract_filename(pdf_info.filename)
# Create batch upload file record
file_record = self.admin_db.create_batch_upload_file(
batch_id=batch_id,
filename=filename,
status="processing",
)
# Get CSV data for this file if available
document_id_base = Path(filename).stem
csv_row_data = csv_data.get(document_id_base)
# Extract PDF content
pdf_content = zip_file.read(pdf_info.filename)
# TODO: Save PDF file and create document
# For now, just mark as completed
self.admin_db.update_batch_upload_file(
file_id=file_record.file_id,
status="completed",
csv_row_data=csv_row_data,
processed_at=datetime.utcnow(),
)
successful_files += 1
except ValueError as e:
# Path validation error
logger.warning(f"Skipping invalid file: {e}")
if file_record:
self.admin_db.update_batch_upload_file(
file_id=file_record.file_id,
status="failed",
error_message="Invalid filename",
processed_at=datetime.utcnow(),
)
failed_files += 1
except Exception as e:
logger.error(f"Error processing PDF: {e}", exc_info=True)
if file_record:
self.admin_db.update_batch_upload_file(
file_id=file_record.file_id,
status="failed",
error_message="Processing error",
processed_at=datetime.utcnow(),
)
failed_files += 1
# Determine overall status
if failed_files == 0:
status = "completed"
elif successful_files == 0:
status = "failed"
else:
status = "partial"
result = {
"status": status,
"total_files": total_files,
"processed_files": total_files,
"successful_files": successful_files,
"failed_files": failed_files,
}
if csv_file:
result["csv_filename"] = Path(csv_file.filename).name
result["csv_row_count"] = len(csv_data)
return result
def _parse_csv_file(
self,
zip_file: zipfile.ZipFile,
csv_file_info: zipfile.ZipInfo,
) -> dict[str, dict[str, Any]]:
"""Parse CSV file and extract field values with validation.
Args:
zip_file: Opened ZIP file
csv_file_info: CSV file info
Returns:
Dictionary mapping DocumentId to validated field values
"""
# Try multiple encodings
csv_bytes = zip_file.read(csv_file_info.filename)
encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'cp1252']
csv_content = None
for encoding in encodings:
try:
csv_content = csv_bytes.decode(encoding)
logger.info(f"CSV decoded with {encoding}")
break
except UnicodeDecodeError:
continue
if csv_content is None:
logger.error("Failed to decode CSV with any encoding")
raise ValueError("Unable to decode CSV file")
csv_reader = csv.DictReader(io.StringIO(csv_content))
result = {}
# Case-insensitive column mapping
field_name_map = {
'DocumentId': ['DocumentId', 'documentid', 'document_id'],
'InvoiceNumber': ['InvoiceNumber', 'invoicenumber', 'invoice_number'],
'InvoiceDate': ['InvoiceDate', 'invoicedate', 'invoice_date'],
'InvoiceDueDate': ['InvoiceDueDate', 'invoiceduedate', 'invoice_due_date'],
'Amount': ['Amount', 'amount'],
'OCR': ['OCR', 'ocr'],
'Bankgiro': ['Bankgiro', 'bankgiro'],
'Plusgiro': ['Plusgiro', 'plusgiro'],
'customer_number': ['customer_number', 'customernumber', 'CustomerNumber'],
'supplier_organisation_number': ['supplier_organisation_number', 'supplierorganisationnumber'],
}
for row_num, row in enumerate(csv_reader, start=2):
try:
# Create case-insensitive lookup
row_lower = {k.lower(): v for k, v in row.items()}
# Find DocumentId with case-insensitive matching
document_id = None
for variant in field_name_map['DocumentId']:
if variant.lower() in row_lower:
document_id = row_lower[variant.lower()]
break
if not document_id:
logger.warning(f"Row {row_num}: No DocumentId found")
continue
# Validate using Pydantic model
csv_row_dict = {'document_id': document_id}
# Map CSV field names to model attribute names
csv_to_model_attr = {
'InvoiceNumber': 'invoice_number',
'InvoiceDate': 'invoice_date',
'InvoiceDueDate': 'invoice_due_date',
'Amount': 'amount',
'OCR': 'ocr',
'Bankgiro': 'bankgiro',
'Plusgiro': 'plusgiro',
'customer_number': 'customer_number',
'supplier_organisation_number': 'supplier_organisation_number',
}
for csv_field in field_name_map.keys():
if csv_field == 'DocumentId':
continue
model_attr = csv_to_model_attr.get(csv_field)
if not model_attr:
continue
for variant in field_name_map[csv_field]:
if variant.lower() in row_lower and row_lower[variant.lower()]:
csv_row_dict[model_attr] = row_lower[variant.lower()]
break
# Validate
validated_row = CSVRowData(**csv_row_dict)
# Extract only the fields we care about (map back to CSV field names)
field_values = {}
model_attr_to_csv = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'amount': 'Amount',
'ocr': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'customer_number': 'customer_number',
'supplier_organisation_number': 'supplier_organisation_number',
}
for model_attr, csv_field in model_attr_to_csv.items():
value = getattr(validated_row, model_attr, None)
if value and csv_field in CSV_TO_CLASS_MAPPING:
field_values[csv_field] = value
if field_values:
result[document_id] = field_values
except Exception as e:
logger.warning(f"Row {row_num}: Validation error - {e}")
continue
return result
def get_batch_status(self, batch_id: str) -> dict[str, Any]:
"""Get batch upload status.
Args:
batch_id: Batch upload ID
Returns:
Batch status dictionary
"""
batch = self.admin_db.get_batch_upload(UUID(batch_id))
if not batch:
return {
"error": "Batch upload not found",
}
files = self.admin_db.get_batch_upload_files(batch.batch_id)
return {
"batch_id": str(batch.batch_id),
"filename": batch.filename,
"status": batch.status,
"total_files": batch.total_files,
"processed_files": batch.processed_files,
"successful_files": batch.successful_files,
"failed_files": batch.failed_files,
"csv_filename": batch.csv_filename,
"csv_row_count": batch.csv_row_count,
"error_message": batch.error_message,
"created_at": batch.created_at.isoformat() if batch.created_at else None,
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
"files": [
{
"filename": f.filename,
"status": f.status,
"error_message": f.error_message,
"annotation_count": f.annotation_count,
}
for f in files
],
}

View File

@@ -0,0 +1,188 @@
"""
Dataset Builder Service
Creates training datasets by copying images from admin storage,
generating YOLO label files, and splitting into train/val/test sets.
"""
import logging
import random
import shutil
from pathlib import Path
import yaml
from inference.data.admin_models import FIELD_CLASSES
logger = logging.getLogger(__name__)
class DatasetBuilder:
"""Builds YOLO training datasets from admin documents."""
def __init__(self, db, base_dir: Path):
self._db = db
self._base_dir = Path(base_dir)
def build_dataset(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
"""Build a complete YOLO dataset from document IDs.
Args:
dataset_id: UUID of the dataset record.
document_ids: List of document UUIDs to include.
train_ratio: Fraction for training set.
val_ratio: Fraction for validation set.
seed: Random seed for reproducible splits.
admin_images_dir: Root directory of admin images.
Returns:
Summary dict with total_documents, total_images, total_annotations.
Raises:
ValueError: If no valid documents found.
"""
try:
return self._do_build(
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
)
except Exception as e:
self._db.update_dataset_status(
dataset_id=dataset_id,
status="failed",
error_message=str(e),
)
raise
def _do_build(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
# 1. Fetch documents
documents = self._db.get_documents_by_ids(document_ids)
if not documents:
raise ValueError("No valid documents found for the given IDs")
# 2. Create directory structure
dataset_dir = self._base_dir / dataset_id
for split in ["train", "val", "test"]:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
# 3. Shuffle and split documents
doc_list = list(documents)
rng = random.Random(seed)
rng.shuffle(doc_list)
n = len(doc_list)
n_train = max(1, round(n * train_ratio))
n_val = max(0, round(n * val_ratio))
n_test = n - n_train - n_val
splits = (
["train"] * n_train
+ ["val"] * n_val
+ ["test"] * n_test
)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
for doc, split in zip(doc_list, splits):
doc_id = str(doc.document_id)
annotations = self._db.get_annotations_for_document(doc.document_id)
# Group annotations by page
page_annotations: dict[int, list] = {}
for ann in annotations:
page_annotations.setdefault(ann.page_number, []).append(ann)
doc_image_count = 0
doc_ann_count = 0
# Copy images and write labels for each page
for page_num in range(1, doc.page_count + 1):
src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
if not src_image.exists():
logger.warning("Image not found: %s", src_image)
continue
dst_name = f"{doc_id}_page{page_num}"
dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
shutil.copy2(src_image, dst_image)
doc_image_count += 1
# Write YOLO label file
page_anns = page_annotations.get(page_num, [])
label_lines = []
for ann in page_anns:
label_lines.append(
f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
f"{ann.width:.6f} {ann.height:.6f}"
)
doc_ann_count += 1
label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
label_path.write_text("\n".join(label_lines))
total_images += doc_image_count
total_annotations += doc_ann_count
dataset_docs.append({
"document_id": doc_id,
"split": split,
"page_count": doc_image_count,
"annotation_count": doc_ann_count,
})
# 5. Record document-split assignments in DB
self._db.add_dataset_documents(
dataset_id=dataset_id,
documents=dataset_docs,
)
# 6. Generate data.yaml
self._generate_data_yaml(dataset_dir)
# 7. Update dataset status
self._db.update_dataset_status(
dataset_id=dataset_id,
status="ready",
total_documents=len(doc_list),
total_images=total_images,
total_annotations=total_annotations,
dataset_path=str(dataset_dir),
)
return {
"total_documents": len(doc_list),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {
"path": str(dataset_dir.absolute()),
"train": "images/train",
"val": "images/val",
"test": "images/test",
"nc": len(FIELD_CLASSES),
"names": FIELD_CLASSES,
}
yaml_path = dataset_dir / "data.yaml"
yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
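# For reference, the generated data.yaml follows the standard Ultralytics dataset
# layout; a rough sketch of its shape (class names and count come from
# FIELD_CLASSES, the path is the absolute dataset directory):
#
#   path: /abs/path/to/<dataset_id>
#   train: images/train
#   val: images/val
#   test: images/test
#   nc: <len(FIELD_CLASSES)>
#   names: <FIELD_CLASSES>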

View File

@@ -0,0 +1,531 @@
"""
Database-based Auto-labeling Service
Processes documents with field values stored in the database (csv_field_values).
Used by the pre-label API to create annotations from expected values.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
from shared.data.db import DocumentDB
from inference.web.config import StorageConfig
logger = logging.getLogger(__name__)
# Initialize DocumentDB for saving match reports
_document_db: DocumentDB | None = None
def get_document_db() -> DocumentDB:
"""Get or create DocumentDB instance with connection and tables initialized.
Follows the same pattern as CLI autolabel (src/cli/autolabel.py lines 370-373).
"""
global _document_db
if _document_db is None:
_document_db = DocumentDB()
_document_db.connect()
_document_db.create_tables() # Ensure tables exist
logger.info("Connected to PostgreSQL DocumentDB for match reports")
return _document_db
def convert_csv_field_values_to_row_dict(
document: AdminDocument,
) -> dict[str, Any]:
"""
Convert AdminDocument.csv_field_values to row_dict format for autolabel.
Args:
document: AdminDocument with csv_field_values
Returns:
Dictionary in row_dict format compatible with autolabel_tasks
"""
csv_values = document.csv_field_values or {}
# Build row_dict with DocumentId
row_dict = {
"DocumentId": str(document.document_id),
}
# Map csv_field_values to row_dict format
# csv_field_values uses keys like: InvoiceNumber, InvoiceDate, Amount, OCR, Bankgiro, etc.
# row_dict uses same keys
for key, value in csv_values.items():
if value is not None and value != "":
row_dict[key] = str(value)
return row_dict
def get_pending_autolabel_documents(
db: AdminDB,
limit: int = 10,
) -> list[AdminDocument]:
"""
Get documents pending auto-labeling.
Args:
db: AdminDB instance
limit: Maximum number of documents to return
Returns:
List of AdminDocument records with status='auto_labeling' and auto_label_status='pending'
"""
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument
with get_session_context() as session:
statement = select(AdminDocument).where(
AdminDocument.status == "auto_labeling",
AdminDocument.auto_label_status == "pending",
).order_by(AdminDocument.created_at).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def process_document_autolabel(
document: AdminDocument,
db: AdminDB,
output_dir: Path | None = None,
dpi: int = DEFAULT_DPI,
min_confidence: float = 0.5,
) -> dict[str, Any]:
"""
Process a single document for auto-labeling using csv_field_values.
Args:
document: AdminDocument with csv_field_values and file_path
db: AdminDB instance for updating status
output_dir: Output directory for temp files
dpi: Rendering DPI
min_confidence: Minimum match confidence
Returns:
Result dictionary with success status and annotations
"""
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
from shared.pdf import PDFDocument
document_id = str(document.document_id)
file_path = Path(document.file_path)
if output_dir is None:
output_dir = Path("data/autolabel_output")
output_dir.mkdir(parents=True, exist_ok=True)
# Mark as processing
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
try:
# Check if file exists
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# Convert csv_field_values to row_dict
row_dict = convert_csv_field_values_to_row_dict(document)
if len(row_dict) <= 1: # Only has DocumentId
raise ValueError("No field values to match")
# Determine PDF type (text or scanned)
is_scanned = False
with PDFDocument(file_path) as pdf_doc:
# Check if first page has extractable text
tokens = list(pdf_doc.extract_text_tokens(0))
is_scanned = len(tokens) < 10 # Threshold for "no text"
# Build task data
# Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
# This ensures consistency with CLI autolabel for reprocess_failed.py
storage_config = StorageConfig()
pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"
task_data = {
"row_dict": row_dict,
"pdf_path": str(pdf_path_for_report),
"output_dir": str(output_dir),
"dpi": dpi,
"min_confidence": min_confidence,
}
# Process based on PDF type
if is_scanned:
result = process_scanned_pdf(task_data)
else:
result = process_text_pdf(task_data)
# Save report to DocumentDB (same as CLI autolabel)
if result.get("report"):
try:
doc_db = get_document_db()
doc_db.save_document(result["report"])
logger.info(f"Saved match report to DocumentDB for {document_id}")
except Exception as e:
logger.warning(f"Failed to save report to DocumentDB: {e}")
# Save annotations to AdminDB
if result.get("success") and result.get("report"):
_save_annotations_to_db(
db=db,
document_id=document_id,
report=result["report"],
page_annotations=result.get("pages", []),
dpi=dpi,
)
# Mark as completed
db.update_document_status(
document_id=document_id,
status="labeled",
auto_label_status="completed",
)
else:
# Mark as failed
errors = result.get("report", {}).get("errors", ["Unknown error"])
db.update_document_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error="; ".join(errors) if errors else "No annotations generated",
)
return result
except Exception as e:
logger.error(f"Error processing document {document_id}: {e}", exc_info=True)
# Mark as failed
db.update_document_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
auto_label_error=str(e),
)
return {
"doc_id": document_id,
"success": False,
"error": str(e),
}
def _save_annotations_to_db(
db: AdminDB,
document_id: str,
report: dict[str, Any],
page_annotations: list[dict[str, Any]],
dpi: int = 200,
) -> int:
"""
Save generated annotations to database.
Args:
db: AdminDB instance
document_id: Document ID
report: AutoLabelReport as dict
page_annotations: List of page annotation data
dpi: DPI used for rendering images (for coordinate conversion)
Returns:
Number of annotations saved
"""
from PIL import Image
from inference.data.admin_models import FIELD_CLASS_IDS
# Mapping from CSV field names to internal field names
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
"InvoiceNumber": "invoice_number",
"InvoiceDate": "invoice_date",
"InvoiceDueDate": "invoice_due_date",
"OCR": "ocr_number",
"Bankgiro": "bankgiro",
"Plusgiro": "plusgiro",
"Amount": "amount",
"supplier_organisation_number": "supplier_organisation_number",
"customer_number": "customer_number",
"payment_line": "payment_line",
}
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
scale = dpi / 72.0
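    # For example, at dpi=150 the scale is 150/72 ~ 2.08 (a bbox x of 100 pt maps
    # to pixel x ~ 208); at dpi=200 it is ~ 2.78.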
# Cache for image dimensions per page
image_dimensions: dict[int, tuple[int, int]] = {}
def get_image_dimensions(page_no: int) -> tuple[int, int] | None:
"""Get image dimensions for a page (1-indexed)."""
if page_no in image_dimensions:
return image_dimensions[page_no]
# Try to load from admin_images
admin_images_dir = Path("data/admin_images") / document_id
image_path = admin_images_dir / f"page_{page_no}.png"
if image_path.exists():
try:
with Image.open(image_path) as img:
dims = img.size # (width, height)
image_dimensions[page_no] = dims
return dims
except Exception as e:
logger.warning(f"Failed to read image dimensions from {image_path}: {e}")
return None
annotation_count = 0
# Get field results from report (list of dicts)
field_results = report.get("field_results", [])
for field_info in field_results:
if not field_info.get("matched"):
continue
csv_field_name = field_info.get("field_name", "")
# Map CSV field name to internal field name
field_name = CSV_TO_INTERNAL_FIELD.get(csv_field_name, csv_field_name)
# Get class_id from field name
class_id = FIELD_CLASS_IDS.get(field_name)
if class_id is None:
logger.warning(f"Unknown field name: {csv_field_name} -> {field_name}")
continue
# Get bbox info (list: [x, y, x2, y2] in PDF points - 72 DPI)
bbox = field_info.get("bbox", [])
if not bbox or len(bbox) < 4:
continue
# Convert PDF points (72 DPI) to pixel coordinates (at configured DPI)
pdf_x1, pdf_y1, pdf_x2, pdf_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
x1 = pdf_x1 * scale
y1 = pdf_y1 * scale
x2 = pdf_x2 * scale
y2 = pdf_y2 * scale
bbox_width = x2 - x1
bbox_height = y2 - y1
# Get page number (convert to 1-indexed)
page_no = field_info.get("page_no", 0) + 1
# Get image dimensions for normalization
dims = get_image_dimensions(page_no)
if dims:
img_width, img_height = dims
# Calculate normalized coordinates
x_center = (x1 + x2) / 2 / img_width
y_center = (y1 + y2) / 2 / img_height
width = bbox_width / img_width
height = bbox_height / img_height
else:
# Fallback: use pixel coordinates as-is for normalization
# (will be slightly off but better than /1000)
logger.warning(f"Could not get image dimensions for page {page_no}, using estimates")
# Estimate A4 at configured DPI: 595 x 842 points * scale
estimated_width = 595 * scale
estimated_height = 842 * scale
x_center = (x1 + x2) / 2 / estimated_width
y_center = (y1 + y2) / 2 / estimated_height
width = bbox_width / estimated_width
height = bbox_height / estimated_height
# Create annotation
try:
db.create_annotation(
document_id=document_id,
page_number=page_no,
class_id=class_id,
class_name=field_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=int(x1),
bbox_y=int(y1),
bbox_width=int(bbox_width),
bbox_height=int(bbox_height),
text_value=field_info.get("matched_text"),
confidence=field_info.get("score"),
source="auto",
)
annotation_count += 1
logger.info(f"Saved annotation for {field_name}: bbox=({int(x1)}, {int(y1)}, {int(bbox_width)}, {int(bbox_height)})")
except Exception as e:
logger.warning(f"Failed to save annotation for {field_name}: {e}")
return annotation_count
def run_pending_autolabel_batch(
db: AdminDB | None = None,
batch_size: int = 10,
output_dir: Path | None = None,
) -> dict[str, Any]:
"""
Process a batch of pending auto-label documents.
Args:
db: AdminDB instance (created if None)
batch_size: Number of documents to process
output_dir: Output directory for temp files
Returns:
Summary of processing results
"""
if db is None:
db = AdminDB()
documents = get_pending_autolabel_documents(db, limit=batch_size)
results = {
"total": len(documents),
"successful": 0,
"failed": 0,
"documents": [],
}
for doc in documents:
result = process_document_autolabel(
document=doc,
db=db,
output_dir=output_dir,
)
doc_result = {
"document_id": str(doc.document_id),
"success": result.get("success", False),
}
if result.get("success"):
results["successful"] += 1
else:
results["failed"] += 1
doc_result["error"] = result.get("error") or "Unknown error"
results["documents"].append(doc_result)
return results
def save_manual_annotations_to_document_db(
document: AdminDocument,
annotations: list,
db: AdminDB,
) -> dict[str, Any]:
"""
Save manual annotations to PostgreSQL documents and field_results tables.
Called when user marks a document as 'labeled' from the web UI.
This ensures manually labeled documents are also tracked in the same
database as auto-labeled documents for consistency.
Args:
document: AdminDocument instance
annotations: List of AdminAnnotation instances
db: AdminDB instance
Returns:
Dict with success status and details
"""
from datetime import datetime
document_id = str(document.document_id)
storage_config = StorageConfig()
# Build pdf_path using admin_upload_dir (same as auto-label)
pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"
# Build report dict compatible with DocumentDB.save_document()
field_results = []
fields_total = len(annotations)
fields_matched = 0
for ann in annotations:
# All manual annotations are considered "matched" since user verified them
field_result = {
"field_name": ann.class_name,
"csv_value": ann.text_value or "", # Manual annotations may not have CSV value
"matched": True,
"score": ann.confidence or 1.0, # Manual = high confidence
"matched_text": ann.text_value,
"candidate_used": "manual",
"bbox": [ann.bbox_x, ann.bbox_y, ann.bbox_x + ann.bbox_width, ann.bbox_y + ann.bbox_height],
"page_no": ann.page_number - 1, # Convert to 0-indexed
"context_keywords": [],
"error": None,
}
field_results.append(field_result)
fields_matched += 1
# Determine PDF type
pdf_type = "unknown"
if pdf_path.exists():
try:
from shared.pdf import PDFDocument
with PDFDocument(pdf_path) as pdf_doc:
tokens = list(pdf_doc.extract_text_tokens(0))
pdf_type = "scanned" if len(tokens) < 10 else "text"
except Exception as e:
logger.warning(f"Could not determine PDF type: {e}")
# Build report
report = {
"document_id": document_id,
"pdf_path": str(pdf_path),
"pdf_type": pdf_type,
"success": fields_matched > 0,
"total_pages": document.page_count,
"fields_matched": fields_matched,
"fields_total": fields_total,
"annotations_generated": fields_matched,
"processing_time_ms": 0, # Manual labeling - no processing time
"timestamp": datetime.utcnow().isoformat(),
"errors": [],
"field_results": field_results,
# Extended fields (from CSV if available)
"split": None,
"customer_number": document.csv_field_values.get("customer_number") if document.csv_field_values else None,
"supplier_name": document.csv_field_values.get("supplier_name") if document.csv_field_values else None,
"supplier_organisation_number": document.csv_field_values.get("supplier_organisation_number") if document.csv_field_values else None,
"supplier_accounts": document.csv_field_values.get("supplier_accounts") if document.csv_field_values else None,
}
# Save to PostgreSQL DocumentDB
try:
doc_db = get_document_db()
doc_db.save_document(report)
logger.info(f"Saved manual annotations to DocumentDB for {document_id}: {fields_matched} fields")
return {
"success": True,
"document_id": document_id,
"fields_saved": fields_matched,
"message": f"Saved {fields_matched} annotations to DocumentDB",
}
except Exception as e:
logger.error(f"Failed to save manual annotations to DocumentDB: {e}", exc_info=True)
return {
"success": False,
"document_id": document_id,
"error": str(e),
}

View File

@@ -0,0 +1,285 @@
"""
Inference Service
Business logic for invoice field extraction.
"""
from __future__ import annotations
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from PIL import Image
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
logger = logging.getLogger(__name__)
@dataclass
class ServiceResult:
"""Result from inference service."""
document_id: str
success: bool = False
document_type: str = "invoice" # "invoice" or "letter"
fields: dict[str, str | None] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
detections: list[dict] = field(default_factory=list)
processing_time_ms: float = 0.0
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
class InferenceService:
"""
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration
storage_config: Storage configuration
"""
self.model_config = model_config
self.storage_config = storage_config
self._pipeline = None
self._detector = None
self._is_initialized = False
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
if self._is_initialized:
return
logger.info("Initializing inference service...")
start_time = time.time()
try:
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(self.model_config.model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
model_path=str(self.model_config.model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
enable_fallback=True,
)
self._is_initialized = True
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""
return self._is_initialized
@property
def gpu_available(self) -> bool:
"""Check if GPU is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def process_image(
self,
image_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process an image file and extract invoice fields.
Args:
image_path: Path to image file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections for visualization
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization if requested
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_visualization(image_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def process_pdf(
self,
pdf_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process a PDF file and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization (render first page)
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing PDF {pdf_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
"""Save visualization image with detections."""
from ultralytics import YOLO
# Load model and run prediction with visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
    def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
"""Save visualization for PDF (first page)."""
from shared.pdf.renderer import render_pdf_to_images
from ultralytics import YOLO
import io
# Render first page
for page_no, image_bytes in render_pdf_to_images(
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
# Cleanup temp file
temp_path.unlink(missing_ok=True)
return output_path
# If no pages rendered
return None
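# Rough usage sketch (the configuration objects are assumed to be built by the
# web app; the PDF path here is a placeholder):
#
#   service = InferenceService(model_config, storage_config)
#   service.initialize()                       # optional; process_* lazily initializes
#   result = service.process_pdf(Path("invoice.pdf"))
#   if result.success:
#       print(result.document_type, result.fields)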

View File

@@ -0,0 +1,24 @@
"""
Background Task Queues
Worker queues for asynchronous and batch processing.
"""
from inference.web.workers.async_queue import AsyncTaskQueue, AsyncTask
from inference.web.workers.batch_queue import (
BatchTaskQueue,
BatchTask,
init_batch_queue,
shutdown_batch_queue,
get_batch_queue,
)
__all__ = [
"AsyncTaskQueue",
"AsyncTask",
"BatchTaskQueue",
"BatchTask",
"init_batch_queue",
"shutdown_batch_queue",
"get_batch_queue",
]
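# Rough lifecycle sketch (the concrete wiring in the web app's startup/shutdown
# hooks is assumed, not shown here):
#
#   init_batch_queue(batch_service)      # start worker threads once at startup
#   get_batch_queue().submit(task)       # enqueue a BatchTask from a request handler
#   shutdown_batch_queue()               # join workers on shutdown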

View File

@@ -0,0 +1,181 @@
"""
Async Task Queue
Thread-safe queue for background invoice processing.
"""
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from queue import Empty, Full, Queue
import threading
from threading import Event, Lock, Thread
from typing import Callable
logger = logging.getLogger(__name__)
@dataclass
class AsyncTask:
"""Task queued for background processing."""
request_id: str
api_key: str
file_path: Path
filename: str
created_at: datetime = field(default_factory=datetime.utcnow)
priority: int = 0 # Lower = higher priority (not implemented yet)
class AsyncTaskQueue:
"""Thread-safe queue for async invoice processing."""
def __init__(
self,
max_size: int = 100,
worker_count: int = 1,
) -> None:
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
self._workers: list[Thread] = []
self._stop_event = Event()
self._worker_count = worker_count
self._lock = Lock()
self._processing: set[str] = set() # Currently processing request_ids
self._task_handler: Callable[[AsyncTask], None] | None = None
self._started = False
def start(self, task_handler: Callable[[AsyncTask], None]) -> None:
"""Start background worker threads."""
if self._started:
logger.warning("AsyncTaskQueue already started")
return
self._task_handler = task_handler
self._stop_event.clear()
for i in range(self._worker_count):
worker = Thread(
target=self._worker_loop,
name=f"async-worker-{i}",
daemon=True,
)
worker.start()
self._workers.append(worker)
logger.info(f"Started async worker thread: {worker.name}")
self._started = True
logger.info(f"AsyncTaskQueue started with {self._worker_count} workers")
def stop(self, timeout: float = 30.0) -> None:
"""Gracefully stop all workers."""
if not self._started:
return
logger.info("Stopping AsyncTaskQueue...")
self._stop_event.set()
# Wait for workers to finish
for worker in self._workers:
worker.join(timeout=timeout / self._worker_count)
if worker.is_alive():
logger.warning(f"Worker {worker.name} did not stop gracefully")
self._workers.clear()
self._started = False
logger.info("AsyncTaskQueue stopped")
def submit(self, task: AsyncTask) -> bool:
"""
Submit a task to the queue.
Returns:
True if task was queued, False if queue is full
"""
try:
self._queue.put_nowait(task)
logger.info(f"Task {task.request_id} queued for processing")
return True
except Full:
logger.warning(f"Queue full, task {task.request_id} rejected")
return False
def get_queue_depth(self) -> int:
"""Get current number of tasks in queue."""
return self._queue.qsize()
def get_processing_count(self) -> int:
"""Get number of tasks currently being processed."""
with self._lock:
return len(self._processing)
def is_processing(self, request_id: str) -> bool:
"""Check if a specific request is currently being processed."""
with self._lock:
return request_id in self._processing
@property
def is_running(self) -> bool:
"""Check if the queue is running."""
return self._started and not self._stop_event.is_set()
def _worker_loop(self) -> None:
"""Worker loop that processes tasks from queue."""
thread_name = threading.current_thread().name
logger.info(f"Worker {thread_name} started")
while not self._stop_event.is_set():
try:
# Block for up to 1 second waiting for tasks
task = self._queue.get(timeout=1.0)
except Empty:
continue
try:
with self._lock:
self._processing.add(task.request_id)
logger.info(
f"Worker {thread_name} processing task {task.request_id}"
)
start_time = time.time()
if self._task_handler:
self._task_handler(task)
elapsed = time.time() - start_time
logger.info(
f"Worker {thread_name} completed task {task.request_id} "
f"in {elapsed:.2f}s"
)
except Exception as e:
logger.error(
f"Worker {thread_name} failed to process task "
f"{task.request_id}: {e}",
exc_info=True,
)
finally:
with self._lock:
self._processing.discard(task.request_id)
self._queue.task_done()
logger.info(f"Worker {thread_name} stopped")
    def wait_for_completion(self, timeout: float | None = None) -> bool:
        """
        Wait for all queued tasks to complete.

        Args:
            timeout: Maximum time to wait in seconds (None waits indefinitely)

        Returns:
            True if all tasks completed, False if the timeout expired first
        """
        if timeout is None:
            self._queue.join()
            return True
        # Queue.join() does not accept a timeout, so poll the queue's
        # unfinished-task counter until the deadline passes.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self._queue.unfinished_tasks == 0:
                return True
            time.sleep(0.1)
        return self._queue.unfinished_tasks == 0
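# Rough usage sketch (the handler is a placeholder; the real handler is wired up
# by the web application):
#
#   queue = AsyncTaskQueue(max_size=100, worker_count=2)
#   queue.start(handle_task)                 # handle_task(task: AsyncTask) is assumed
#   queue.submit(AsyncTask(request_id="abc123", api_key="...",
#                          file_path=Path("doc.pdf"), filename="doc.pdf"))
#   queue.stop()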

View File

@@ -0,0 +1,225 @@
"""
Batch Upload Processing Queue
Background queue for async batch upload processing.
"""
import logging
import threading
from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Full, Empty
from typing import Any
from uuid import UUID
logger = logging.getLogger(__name__)
@dataclass
class BatchTask:
"""Task for batch upload processing."""
batch_id: UUID
admin_token: str
zip_content: bytes
zip_filename: str
upload_source: str
auto_label: bool
created_at: datetime
class BatchTaskQueue:
"""Thread-safe queue for async batch upload processing."""
def __init__(self, max_size: int = 20, worker_count: int = 2):
"""Initialize the batch task queue.
Args:
max_size: Maximum queue size
worker_count: Number of worker threads
"""
self._queue: Queue[BatchTask] = Queue(maxsize=max_size)
self._workers: list[threading.Thread] = []
self._stop_event = threading.Event()
self._worker_count = worker_count
self._batch_service: Any | None = None
self._running = False
self._lock = threading.Lock()
def start(self, batch_service: Any) -> None:
"""Start worker threads with batch service.
Args:
batch_service: BatchUploadService instance for processing
"""
with self._lock:
if self._running:
logger.warning("Batch queue already running")
return
self._batch_service = batch_service
self._stop_event.clear()
self._running = True
# Start worker threads
for i in range(self._worker_count):
worker = threading.Thread(
target=self._worker_loop,
name=f"BatchWorker-{i}",
daemon=True,
)
worker.start()
self._workers.append(worker)
logger.info(f"Started {self._worker_count} batch workers")
def stop(self, timeout: float = 30.0) -> None:
"""Stop all worker threads gracefully.
Args:
timeout: Maximum time to wait for workers to finish
"""
with self._lock:
if not self._running:
return
logger.info("Stopping batch queue...")
self._stop_event.set()
self._running = False
# Wait for workers to finish
for worker in self._workers:
worker.join(timeout=timeout)
self._workers.clear()
logger.info("Batch queue stopped")
def submit(self, task: BatchTask) -> bool:
"""Submit a batch task to the queue.
Args:
task: Batch task to process
Returns:
True if task was queued, False if queue is full
"""
try:
self._queue.put(task, block=False)
logger.info(f"Queued batch task: batch_id={task.batch_id}")
return True
except Full:
logger.warning(f"Queue full, rejected task: batch_id={task.batch_id}")
return False
def get_queue_depth(self) -> int:
"""Get the number of pending tasks in queue.
Returns:
Number of tasks waiting to be processed
"""
return self._queue.qsize()
@property
def is_running(self) -> bool:
"""Check if queue is running.
Returns:
True if queue is active
"""
return self._running
def _worker_loop(self) -> None:
"""Worker thread main loop."""
worker_name = threading.current_thread().name
logger.info(f"{worker_name} started")
while not self._stop_event.is_set():
try:
# Get task with timeout to check stop event periodically
task = self._queue.get(timeout=1.0)
self._process_task(task)
self._queue.task_done()
except Empty:
# No tasks, continue loop to check stop event
continue
except Exception as e:
logger.error(f"{worker_name} error processing task: {e}", exc_info=True)
logger.info(f"{worker_name} stopped")
def _process_task(self, task: BatchTask) -> None:
"""Process a single batch task.
Args:
task: Batch task to process
"""
if self._batch_service is None:
logger.error("Batch service not initialized, cannot process task")
return
logger.info(
f"Processing batch task: batch_id={task.batch_id}, "
f"filename={task.zip_filename}"
)
try:
# Process the batch upload using the service
result = self._batch_service.process_zip_upload(
admin_token=task.admin_token,
zip_filename=task.zip_filename,
zip_content=task.zip_content,
upload_source=task.upload_source,
)
logger.info(
f"Batch task completed: batch_id={task.batch_id}, "
f"status={result.get('status')}, "
f"successful_files={result.get('successful_files')}, "
f"failed_files={result.get('failed_files')}"
)
except Exception as e:
logger.error(
f"Error processing batch task {task.batch_id}: {e}",
exc_info=True,
)
# Global batch queue instance
_batch_queue: BatchTaskQueue | None = None
_queue_lock = threading.Lock()
def get_batch_queue() -> BatchTaskQueue:
"""Get or create the global batch queue instance.
Returns:
Batch task queue instance
"""
global _batch_queue
if _batch_queue is None:
with _queue_lock:
if _batch_queue is None:
_batch_queue = BatchTaskQueue(max_size=20, worker_count=2)
return _batch_queue
def init_batch_queue(batch_service: Any) -> None:
"""Initialize and start the batch queue.
Args:
batch_service: BatchUploadService instance
"""
queue = get_batch_queue()
if not queue.is_running:
queue.start(batch_service)
def shutdown_batch_queue() -> None:
"""Shutdown the batch queue gracefully."""
global _batch_queue
if _batch_queue is not None:
_batch_queue.stop()

View File

@@ -0,0 +1,8 @@
-e ../shared
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
sqlmodel>=0.0.22
ultralytics>=8.1.0
httpx>=0.25.0
openai>=1.0.0

View File

@@ -0,0 +1,14 @@
#!/usr/bin/env python
"""
Quick start script for the web server.
Usage:
python run_server.py
python run_server.py --port 8080
python run_server.py --debug --reload
"""
from inference.cli.serve import main
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,17 @@
from setuptools import setup, find_packages
setup(
name="invoice-inference",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
install_requires=[
"invoice-shared",
"fastapi>=0.104.0",
"uvicorn[standard]>=0.24.0",
"python-multipart>=0.0.6",
"sqlmodel>=0.0.22",
"ultralytics>=8.1.0",
"httpx>=0.25.0",
],
)

View File

@@ -0,0 +1,9 @@
PyMuPDF>=1.23.0
paddleocr>=2.7.0
Pillow>=10.0.0
numpy>=1.24.0
opencv-python>=4.8.0
psycopg2-binary>=2.9.0
python-dotenv>=1.0.0
pyyaml>=6.0
thefuzz>=0.20.0

19
packages/shared/setup.py Normal file
View File

@@ -0,0 +1,19 @@
from setuptools import setup, find_packages
setup(
name="invoice-shared",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
install_requires=[
"PyMuPDF>=1.23.0",
"paddleocr>=2.7.0",
"Pillow>=10.0.0",
"numpy>=1.24.0",
"opencv-python>=4.8.0",
"psycopg2-binary>=2.9.0",
"python-dotenv>=1.0.0",
"pyyaml>=6.0",
"thefuzz>=0.20.0",
],
)

View File

@@ -0,0 +1,2 @@
# Invoice Master POC v2
# Automatic invoice information extraction system using YOLO + OCR

View File

@@ -0,0 +1,88 @@
"""
Configuration settings for the invoice extraction system.
"""
import os
import platform
from pathlib import Path
from dotenv import load_dotenv
# Load environment variables from .env file at project root
# Walk up from packages/shared/shared/config.py to find project root
_config_dir = Path(__file__).parent
for _candidate in [_config_dir.parent.parent.parent, _config_dir.parent.parent, _config_dir.parent]:
_env_path = _candidate / '.env'
if _env_path.exists():
load_dotenv(dotenv_path=_env_path)
break
else:
load_dotenv() # fallback: search cwd and parents
# Global DPI setting - must match training DPI for optimal model performance
DEFAULT_DPI = 150
def _is_wsl() -> bool:
"""Check if running inside WSL (Windows Subsystem for Linux)."""
if platform.system() != 'Linux':
return False
# Check for WSL-specific indicators
if os.environ.get('WSL_DISTRO_NAME'):
return True
try:
with open('/proc/version', 'r') as f:
return 'microsoft' in f.read().lower()
except (FileNotFoundError, PermissionError):
return False
# PostgreSQL Database Configuration
# Now loaded from environment variables for security
DATABASE = {
'host': os.getenv('DB_HOST', '192.168.68.31'),
'port': int(os.getenv('DB_PORT', '5432')),
'database': os.getenv('DB_NAME', 'docmaster'),
'user': os.getenv('DB_USER', 'docmaster'),
'password': os.getenv('DB_PASSWORD'), # No default for security
}
# Validate required configuration
if not DATABASE['password']:
raise ValueError(
"DB_PASSWORD environment variable is not set. "
"Please create a .env file based on .env.example and set DB_PASSWORD."
)
# Connection string for psycopg2
def get_db_connection_string():
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
# Paths Configuration - auto-detect WSL vs Windows
if _is_wsl():
# WSL: use native Linux filesystem for better I/O performance
PATHS = {
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
'output_dir': os.path.expanduser('~/invoice-data/dataset'),
'reports_dir': 'reports', # Keep reports in project directory
}
else:
# Windows or native Linux: use relative paths
PATHS = {
'csv_dir': 'data/structured_data',
'pdf_dir': 'data/raw_pdfs',
'output_dir': 'data/dataset',
'reports_dir': 'reports',
}
# Auto-labeling Configuration
AUTOLABEL = {
'workers': 2,
'dpi': DEFAULT_DPI,
'min_confidence': 0.5,
'train_ratio': 0.8,
'val_ratio': 0.1,
'test_ratio': 0.1,
'max_records_per_report': 10000,
}
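# A minimal .env for local development might look like this (placeholder values,
# not real credentials):
#
#   DB_HOST=localhost
#   DB_PORT=5432
#   DB_NAME=docmaster
#   DB_USER=docmaster
#   DB_PASSWORD=change-me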

View File

@@ -0,0 +1,3 @@
from .csv_loader import CSVLoader, InvoiceRow
__all__ = ['CSVLoader', 'InvoiceRow']

View File

@@ -0,0 +1,372 @@
"""
CSV Data Loader
Loads and parses structured invoice data from CSV files.
Follows the CSV specification for invoice data.
"""
import csv
from dataclasses import dataclass, field
from datetime import datetime, date
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Any, Iterator
@dataclass
class InvoiceRow:
"""Parsed invoice data row."""
DocumentId: str
InvoiceDate: date | None = None
InvoiceNumber: str | None = None
InvoiceDueDate: date | None = None
OCR: str | None = None
Message: str | None = None
Bankgiro: str | None = None
Plusgiro: str | None = None
Amount: Decimal | None = None
# New fields
split: str | None = None # train/test split indicator
customer_number: str | None = None # Customer number (needs matching)
supplier_name: str | None = None # Supplier name (no matching)
supplier_organisation_number: str | None = None # Swedish org number (needs matching)
supplier_accounts: str | None = None # Supplier accounts (needs matching)
# Raw values for reference
raw_data: dict = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for matching."""
return {
'DocumentId': self.DocumentId,
'InvoiceDate': self.InvoiceDate.isoformat() if self.InvoiceDate else None,
'InvoiceNumber': self.InvoiceNumber,
'InvoiceDueDate': self.InvoiceDueDate.isoformat() if self.InvoiceDueDate else None,
'OCR': self.OCR,
'Bankgiro': self.Bankgiro,
'Plusgiro': self.Plusgiro,
'Amount': str(self.Amount) if self.Amount else None,
'supplier_organisation_number': self.supplier_organisation_number,
'supplier_accounts': self.supplier_accounts,
}
def get_field_value(self, field_name: str) -> str | None:
"""Get field value as string for matching."""
value = getattr(self, field_name, None)
if value is None:
return None
if isinstance(value, date):
return value.isoformat()
if isinstance(value, Decimal):
return str(value)
return str(value) if value else None
class CSVLoader:
"""Loads invoice data from CSV files."""
# Expected field mappings (CSV header -> InvoiceRow attribute)
FIELD_MAPPINGS = {
'DocumentId': 'DocumentId',
'InvoiceDate': 'InvoiceDate',
'InvoiceNumber': 'InvoiceNumber',
'InvoiceDueDate': 'InvoiceDueDate',
'OCR': 'OCR',
'Message': 'Message',
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'Amount': 'Amount',
# New fields
'split': 'split',
'customer_number': 'customer_number',
'supplier_name': 'supplier_name',
'supplier_organisation_number': 'supplier_organisation_number',
'supplier_accounts': 'supplier_accounts',
}
def __init__(
self,
csv_path: str | Path | list[str | Path],
pdf_dir: str | Path | None = None,
doc_map_path: str | Path | None = None,
encoding: str = 'utf-8'
):
"""
Initialize CSV loader.
Args:
csv_path: Path to CSV file(s). Can be:
- Single path: 'data/file.csv'
- List of paths: ['data/file1.csv', 'data/file2.csv']
- Glob pattern: 'data/*.csv' or 'data/export_*.csv'
pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
doc_map_path: Optional path to document mapping CSV
encoding: CSV file encoding (default: utf-8)
"""
# Handle multiple CSV files
if isinstance(csv_path, list):
self.csv_paths = [Path(p) for p in csv_path]
else:
csv_path = Path(csv_path)
# Check if it's a glob pattern (contains * or ?)
if '*' in str(csv_path) or '?' in str(csv_path):
parent = csv_path.parent
pattern = csv_path.name
self.csv_paths = sorted(parent.glob(pattern))
else:
self.csv_paths = [csv_path]
# For backward compatibility
self.csv_path = self.csv_paths[0] if self.csv_paths else None
self.pdf_dir = Path(pdf_dir) if pdf_dir else (self.csv_path.parent.parent / 'raw_pdfs' if self.csv_path else Path('data/raw_pdfs'))
self.doc_map_path = Path(doc_map_path) if doc_map_path else None
self.encoding = encoding
# Load document mapping if provided
self.doc_map = self._load_doc_map() if self.doc_map_path else {}
def _load_doc_map(self) -> dict[str, str]:
"""Load document ID to filename mapping."""
mapping = {}
if self.doc_map_path and self.doc_map_path.exists():
with open(self.doc_map_path, 'r', encoding=self.encoding) as f:
reader = csv.DictReader(f)
for row in reader:
doc_id = row.get('DocumentId', '').strip()
filename = row.get('FileName', '').strip()
if doc_id and filename:
mapping[doc_id] = filename
return mapping
def _parse_date(self, value: str | None) -> date | None:
"""Parse date from various formats."""
if not value or not value.strip():
return None
value = value.strip()
# Try different date formats
formats = [
'%Y-%m-%d',
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S.%f',
'%d/%m/%Y',
'%d.%m.%Y',
'%d-%m-%Y',
'%Y%m%d',
]
for fmt in formats:
try:
return datetime.strptime(value, fmt).date()
except ValueError:
continue
return None
def _parse_amount(self, value: str | None) -> Decimal | None:
"""Parse monetary amount from various formats."""
if not value or not value.strip():
return None
value = value.strip()
# Remove currency symbols and common suffixes
value = value.replace('SEK', '').replace('kr', '').replace(':-', '')
value = value.strip()
# Remove spaces (thousand separators)
value = value.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator (European format)
if ',' in value and '.' not in value:
value = value.replace(',', '.')
elif ',' in value and '.' in value:
# Assume comma is thousands separator, dot is decimal
value = value.replace(',', '')
try:
return Decimal(value)
except InvalidOperation:
return None
def _parse_string(self, value: str | None) -> str | None:
"""Parse string field with cleanup."""
if value is None:
return None
value = value.strip()
return value if value else None
def _get_field(self, row: dict, *keys: str) -> str | None:
"""Get field value trying multiple possible column names."""
for key in keys:
value = row.get(key)
if value is not None:
return value
return None
def _parse_row(self, row: dict) -> InvoiceRow | None:
"""Parse a single CSV row into InvoiceRow."""
doc_id = self._parse_string(self._get_field(row, 'DocumentId', 'document_id'))
if not doc_id:
return None
return InvoiceRow(
DocumentId=doc_id,
InvoiceDate=self._parse_date(self._get_field(row, 'InvoiceDate', 'invoice_date')),
InvoiceNumber=self._parse_string(self._get_field(row, 'InvoiceNumber', 'invoice_number')),
InvoiceDueDate=self._parse_date(self._get_field(row, 'InvoiceDueDate', 'invoice_due_date')),
OCR=self._parse_string(self._get_field(row, 'OCR', 'ocr')),
Message=self._parse_string(self._get_field(row, 'Message', 'message')),
Bankgiro=self._parse_string(self._get_field(row, 'Bankgiro', 'bankgiro')),
Plusgiro=self._parse_string(self._get_field(row, 'Plusgiro', 'plusgiro')),
Amount=self._parse_amount(self._get_field(row, 'Amount', 'amount', 'invoice_data_amount')),
# New fields
split=self._parse_string(row.get('split')),
customer_number=self._parse_string(row.get('customer_number')),
supplier_name=self._parse_string(row.get('supplier_name')),
supplier_organisation_number=self._parse_string(row.get('supplier_organisation_number')),
supplier_accounts=self._parse_string(row.get('supplier_accounts')),
raw_data=dict(row)
)
def _iter_single_csv(self, csv_path: Path) -> Iterator[InvoiceRow]:
"""Iterate over rows from a single CSV file."""
# Try utf-8-sig first so a leading BOM is stripped correctly
encodings = ['utf-8-sig', self.encoding, 'latin-1']
for enc in encodings:
try:
with open(csv_path, 'r', encoding=enc) as f:
reader = csv.DictReader(f)
for row in reader:
parsed = self._parse_row(row)
if parsed:
yield parsed
return
except UnicodeDecodeError:
continue
raise ValueError(f"Could not read CSV file {csv_path} with any supported encoding")
def load_all(self) -> list[InvoiceRow]:
"""Load all rows from CSV(s)."""
rows = []
for row in self.iter_rows():
rows.append(row)
return rows
def iter_rows(self) -> Iterator[InvoiceRow]:
"""Iterate over CSV rows from all CSV files."""
seen_doc_ids = set()
for csv_path in self.csv_paths:
if not csv_path.exists():
continue
for row in self._iter_single_csv(csv_path):
# Deduplicate by DocumentId
if row.DocumentId not in seen_doc_ids:
seen_doc_ids.add(row.DocumentId)
yield row
def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
"""
Get PDF path for an invoice row.
Uses document mapping if available, otherwise assumes
DocumentId.pdf naming convention.
"""
doc_id = invoice_row.DocumentId
# Check document mapping first
if doc_id in self.doc_map:
filename = self.doc_map[doc_id]
pdf_path = self.pdf_dir / filename
if pdf_path.exists():
return pdf_path
# Try default naming patterns
patterns = [
f"{doc_id}.pdf",
f"{doc_id}.PDF",
f"{doc_id.lower()}.pdf",
f"{doc_id.lower()}.PDF",
f"{doc_id.upper()}.pdf",
f"{doc_id.upper()}.PDF",
]
for pattern in patterns:
pdf_path = self.pdf_dir / pattern
if pdf_path.exists():
return pdf_path
# Try glob patterns for partial matches (both cases)
for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.pdf"):
return pdf_file
for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.PDF"):
return pdf_file
return None
def get_row_by_id(self, doc_id: str) -> InvoiceRow | None:
"""Get a specific row by DocumentId."""
for row in self.iter_rows():
if row.DocumentId == doc_id:
return row
return None
def validate(self) -> list[dict]:
"""
Validate CSV data and return issues.
Returns:
List of validation issues
"""
issues = []
for i, row in enumerate(self.iter_rows(), start=2): # Start at 2 (header is row 1)
# Check required DocumentId
if not row.DocumentId:
issues.append({
'row': i,
'field': 'DocumentId',
'issue': 'Missing required DocumentId'
})
continue
# Check if PDF exists
pdf_path = self.get_pdf_path(row)
if not pdf_path:
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'PDF',
'issue': 'PDF file not found'
})
# Check for at least one matchable field
matchable_fields = [
row.InvoiceNumber,
row.OCR,
row.Bankgiro,
row.Plusgiro,
row.Amount,
row.supplier_organisation_number,
row.supplier_accounts,
]
if not any(matchable_fields):
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'All',
'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount/supplier_organisation_number/supplier_accounts)'
})
return issues
def load_invoice_csv(csv_path: str | Path | list[str | Path], pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
"""Convenience function to load invoice CSV(s)."""
loader = CSVLoader(csv_path, pdf_dir)
return loader.load_all()
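# Illustrative usage sketch (not part of the original file): loading invoice rows
# from one or more CSV exports and checking that their PDFs can be located.
# The paths below are hypothetical placeholders.
if __name__ == '__main__':
    loader = CSVLoader('data/csv/export_*.csv', pdf_dir='data/raw_pdfs')
    for issue in loader.validate():
        print(f"row {issue.get('row')}: {issue['issue']}")
    rows = loader.load_all()  # de-duplicated by DocumentId across all matched files
    print(f"Loaded {len(rows)} unique invoice rows")
    if rows:
        print(rows[0].to_dict())
        print(loader.get_pdf_path(rows[0]))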

View File

@@ -0,0 +1,530 @@
"""
Database utilities for autolabel workflow.
"""
import json
import psycopg2
from psycopg2.extras import execute_values
from typing import Set, Dict, Any, Optional
import sys
from pathlib import Path
from shared.config import get_db_connection_string
class DocumentDB:
"""Database interface for document processing status."""
def __init__(self, connection_string: str | None = None):
self.connection_string = connection_string or get_db_connection_string()
self.conn = None
def connect(self):
"""Connect to database."""
if self.conn is None:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def create_tables(self):
"""Create database tables if they don't exist."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
document_id TEXT PRIMARY KEY,
pdf_path TEXT,
pdf_type TEXT,
success BOOLEAN,
total_pages INTEGER,
fields_matched INTEGER,
fields_total INTEGER,
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
errors JSONB DEFAULT '[]',
-- Extended CSV format fields
split TEXT,
customer_number TEXT,
supplier_name TEXT,
supplier_organisation_number TEXT,
supplier_accounts TEXT
);
CREATE TABLE IF NOT EXISTS field_results (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
field_name TEXT,
csv_value TEXT,
matched BOOLEAN,
score REAL,
matched_text TEXT,
candidate_used TEXT,
bbox JSONB,
page_no INTEGER,
context_keywords JSONB DEFAULT '[]',
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
-- Add new columns to existing tables if they don't exist (for migration)
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
ALTER TABLE documents ADD COLUMN split TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
ALTER TABLE documents ADD COLUMN customer_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
ALTER TABLE documents ADD COLUMN supplier_name TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
END IF;
END $$;
""")
conn.commit()
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
self.conn = None
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def get_successful_doc_ids(self) -> Set[str]:
"""Get all document IDs that have been successfully processed."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents WHERE success = true")
return {row[0] for row in cursor.fetchall()}
def get_failed_doc_ids(self) -> Set[str]:
"""Get all document IDs that failed processing."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents WHERE success = false")
return {row[0] for row in cursor.fetchall()}
def check_document_status(self, doc_id: str) -> Optional[bool]:
"""
Check if a document exists and its success status.
Returns:
True if exists and success=true
False if exists and success=false
None if not exists
"""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute(
"SELECT success FROM documents WHERE document_id = %s",
(doc_id,)
)
row = cursor.fetchone()
if row is None:
return None
return row[0]
def check_documents_status_batch(self, doc_ids: list[str]) -> Dict[str, Optional[bool]]:
"""
Batch check document status for multiple IDs.
Returns:
Dict mapping doc_id to status:
True if exists and success=true
False if exists and success=false
(missing from dict if not exists)
"""
if not doc_ids:
return {}
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute(
"SELECT document_id, success FROM documents WHERE document_id = ANY(%s)",
(doc_ids,)
)
return {row[0]: row[1] for row in cursor.fetchall()}
def delete_document(self, doc_id: str):
"""Delete a document and its field results (for re-processing)."""
conn = self.connect()
with conn.cursor() as cursor:
# field_results will be cascade deleted
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
conn.commit()
def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
"""Get a single document with its field results."""
conn = self.connect()
with conn.cursor() as cursor:
# Get document
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
if not row:
return None
doc = {
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6],
'annotations_generated': row[7],
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
# Get field results
cursor.execute("""
SELECT field_name, csv_value, matched, score, matched_text,
candidate_used, bbox, page_no, context_keywords, error
FROM field_results WHERE document_id = %s
""", (doc_id,))
for fr in cursor.fetchall():
doc['field_results'].append({
'field_name': fr[0],
'csv_value': fr[1],
'matched': fr[2],
'score': fr[3],
'matched_text': fr[4],
'candidate_used': fr[5],
'bbox': fr[6] if isinstance(fr[6], list) else json.loads(fr[6]) if fr[6] else None,
'page_no': fr[7],
'context_keywords': fr[8] if isinstance(fr[8], list) else json.loads(fr[8] or '[]'),
'error': fr[9]
})
return doc
def get_all_documents_summary(self, success_only: bool = False, limit: int | None = None) -> list[Dict[str, Any]]:
"""Get summary of all documents (without field_results for efficiency)."""
conn = self.connect()
with conn.cursor() as cursor:
query = """
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total
FROM documents
"""
params = []
if success_only:
query += " WHERE success = true"
query += " ORDER BY timestamp DESC"
if limit:
# Use parameterized query instead of f-string
query += " LIMIT %s"
params.append(limit)
cursor.execute(query, params if params else None)
return [
{
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6]
}
for row in cursor.fetchall()
]
def get_field_stats(self) -> Dict[str, Dict[str, int]]:
"""Get match statistics per field."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT field_name,
COUNT(*) as total,
SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched
FROM field_results
GROUP BY field_name
ORDER BY field_name
""")
return {
row[0]: {'total': row[1], 'matched': row[2]}
for row in cursor.fetchall()
}
def get_failed_matches(self, field_name: str | None = None, limit: int = 100) -> list[Dict[str, Any]]:
"""Get field results that failed to match."""
conn = self.connect()
with conn.cursor() as cursor:
query = """
SELECT fr.document_id, fr.field_name, fr.csv_value, fr.error,
d.pdf_type
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
WHERE fr.matched = false
"""
params = []
if field_name:
query += " AND fr.field_name = %s"
params.append(field_name)
# Use parameterized query instead of f-string
query += " LIMIT %s"
params.append(limit)
cursor.execute(query, params)
return [
{
'document_id': row[0],
'field_name': row[1],
'csv_value': row[2],
'error': row[3],
'pdf_type': row[4]
}
for row in cursor.fetchall()
]
def get_documents_batch(self, doc_ids: list[str]) -> Dict[str, Dict[str, Any]]:
"""
Get multiple documents with their field results in a single batch query.
This is much more efficient than calling get_document() in a loop.
Args:
doc_ids: List of document IDs to fetch
Returns:
Dict mapping doc_id to document data (with field_results)
"""
if not doc_ids:
return {}
conn = self.connect()
result: Dict[str, Dict[str, Any]] = {}
with conn.cursor() as cursor:
# Batch fetch all documents
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = ANY(%s)
""", (doc_ids,))
for row in cursor.fetchall():
result[row[0]] = {
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6],
'annotations_generated': row[7],
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
if not result:
return {}
# Batch fetch all field results for these documents
cursor.execute("""
SELECT document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error
FROM field_results WHERE document_id = ANY(%s)
""", (list(result.keys()),))
for fr in cursor.fetchall():
doc_id = fr[0]
if doc_id in result:
result[doc_id]['field_results'].append({
'field_name': fr[1],
'csv_value': fr[2],
'matched': fr[3],
'score': fr[4],
'matched_text': fr[5],
'candidate_used': fr[6],
'bbox': fr[7] if isinstance(fr[7], list) else json.loads(fr[7]) if fr[7] else None,
'page_no': fr[8],
'context_keywords': fr[9] if isinstance(fr[9], list) else json.loads(fr[9] or '[]'),
'error': fr[10]
})
return result
def save_document(self, report: Dict[str, Any]):
"""Save or update a document and its field results using batch operations."""
conn = self.connect()
doc_id = report.get('document_id')
with conn.cursor() as cursor:
# Delete existing record if any (for update)
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
# Insert document
cursor.execute("""
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
doc_id,
report.get('pdf_path'),
report.get('pdf_type'),
report.get('success'),
report.get('total_pages'),
report.get('fields_matched'),
report.get('fields_total'),
report.get('annotations_generated'),
report.get('processing_time_ms'),
report.get('timestamp'),
json.dumps(report.get('errors', [])),
# New fields
report.get('split'),
report.get('customer_number'),
report.get('supplier_name'),
report.get('supplier_organisation_number'),
report.get('supplier_accounts'),
))
# Batch insert field results using execute_values
field_results = report.get('field_results', [])
if field_results:
field_values = [
(
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
)
for field in field_results
]
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_values)
conn.commit()
def save_documents_batch(self, reports: list[Dict[str, Any]]):
"""Save multiple documents in a batch."""
if not reports:
return
conn = self.connect()
doc_ids = [r['document_id'] for r in reports]
with conn.cursor() as cursor:
# Delete existing records
cursor.execute(
"DELETE FROM documents WHERE document_id = ANY(%s)",
(doc_ids,)
)
# Batch insert documents
doc_values = [
(
r.get('document_id'),
r.get('pdf_path'),
r.get('pdf_type'),
r.get('success'),
r.get('total_pages'),
r.get('fields_matched'),
r.get('fields_total'),
r.get('annotations_generated'),
r.get('processing_time_ms'),
r.get('timestamp'),
json.dumps(r.get('errors', [])),
# New fields
r.get('split'),
r.get('customer_number'),
r.get('supplier_name'),
r.get('supplier_organisation_number'),
r.get('supplier_accounts'),
)
for r in reports
]
execute_values(cursor, """
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES %s
""", doc_values)
# Batch insert field results
field_values = []
for r in reports:
doc_id = r.get('document_id')
for field in r.get('field_results', []):
field_values.append((
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
))
if field_values:
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_values)
conn.commit()
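# Illustrative usage sketch (assumption: DB settings are resolved via
# shared.config.get_db_connection_string(); the report below is a minimal
# made-up example, not data from the commit).
if __name__ == '__main__':
    with DocumentDB() as db:
        db.create_tables()
        db.save_document({
            'document_id': 'doc-001',
            'pdf_path': 'data/raw_pdfs/doc-001.pdf',
            'pdf_type': 'text',
            'success': True,
            'total_pages': 1,
            'fields_matched': 3,
            'fields_total': 5,
            'annotations_generated': 3,
            'processing_time_ms': 120.5,
            'timestamp': '2026-01-27T12:00:00+00:00',
            'errors': [],
            'field_results': [],
        })
        print(db.check_document_status('doc-001'))  # True
        print(db.get_field_stats())                 # {} until field_results are saved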

View File

@@ -0,0 +1,102 @@
"""
Application-specific exceptions for invoice extraction system.
This module defines a hierarchy of custom exceptions to provide better
error handling and debugging capabilities throughout the application.
"""
class InvoiceExtractionError(Exception):
"""Base exception for all invoice extraction errors."""
def __init__(self, message: str, details: dict | None = None):
"""
Initialize exception with message and optional details.
Args:
message: Human-readable error message
details: Optional dict with additional error context
"""
super().__init__(message)
self.message = message
self.details = details or {}
def __str__(self):
if self.details:
details_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
return f"{self.message} ({details_str})"
return self.message
class PDFProcessingError(InvoiceExtractionError):
"""Error during PDF processing (rendering, conversion)."""
pass
class OCRError(InvoiceExtractionError):
"""Error during OCR processing."""
pass
class ModelInferenceError(InvoiceExtractionError):
"""Error during YOLO model inference."""
pass
class FieldValidationError(InvoiceExtractionError):
"""Error during field validation or normalization."""
def __init__(self, field_name: str, value: str, reason: str, details: dict | None = None):
"""
Initialize field validation error.
Args:
field_name: Name of the field that failed validation
value: The invalid value
reason: Why validation failed
details: Additional context
"""
message = f"Field '{field_name}' validation failed: {reason}"
super().__init__(message, details)
self.field_name = field_name
self.value = value
self.reason = reason
class DatabaseError(InvoiceExtractionError):
"""Error during database operations."""
pass
class ConfigurationError(InvoiceExtractionError):
"""Error in application configuration."""
pass
class PaymentLineParseError(InvoiceExtractionError):
"""Error parsing Swedish payment line format."""
pass
class CustomerNumberParseError(InvoiceExtractionError):
"""Error parsing Swedish customer number."""
pass
class DataLoadError(InvoiceExtractionError):
"""Error loading data from CSV or other sources."""
pass
class AnnotationError(InvoiceExtractionError):
"""Error generating or processing YOLO annotations."""
pass
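# Illustrative usage sketch (hypothetical values): all exceptions share the
# InvoiceExtractionError base, and __str__ appends the details dict to the message.
if __name__ == '__main__':
    try:
        raise FieldValidationError(
            field_name='Bankgiro',
            value='123-456',
            reason='checksum failed',
            details={'document_id': 'doc-001'},
        )
    except InvoiceExtractionError as exc:
        # Prints: Field 'Bankgiro' validation failed: checksum failed (document_id=doc-001)
        print(exc)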

View File

@@ -0,0 +1,4 @@
from .field_matcher import FieldMatcher, find_field_matches
from .models import Match, TokenLike
__all__ = ['FieldMatcher', 'Match', 'TokenLike', 'find_field_matches']

View File

@@ -0,0 +1,92 @@
"""
Context keywords for field matching.
"""
from .models import TokenLike
from .token_index import TokenIndex
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
def find_context_keywords(
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str,
context_radius: float,
token_index: TokenIndex | None = None
) -> tuple[list[str], float]:
"""
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
Args:
tokens: List of all tokens
target_token: The token to find context for
field_name: Name of the field
context_radius: Search radius in pixels
token_index: Optional spatial index for efficient lookup
Returns:
Tuple of (found_keywords, boost_score)
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
# Use spatial index for efficient nearby token lookup
if token_index:
nearby_tokens = token_index.find_nearby(target_token, context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
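# Illustrative sketch (made-up tokens and coordinates; a plain dataclass stands in
# for TokenLike). Each keyword hit adds 0.10 to the boost, capped at 0.25.
# Because this module uses relative imports, run it within its package context.
def _demo_context_keywords() -> None:
    from dataclasses import dataclass

    @dataclass
    class _Tok:
        text: str
        bbox: tuple[float, float, float, float]
        page_no: int = 0

    tokens = [
        _Tok('Bankgiro', (100.0, 200.0, 180.0, 215.0)),
        _Tok('5393-9484', (200.0, 200.0, 290.0, 215.0)),
    ]
    keywords, boost = find_context_keywords(
        tokens, tokens[1], 'Bankgiro', context_radius=200.0
    )
    print(keywords, boost)  # ['bankgiro'] 0.1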

View File

@@ -0,0 +1,219 @@
"""
Field Matching Module - Refactored
Matches normalized field values to tokens extracted from documents.
"""
from .models import TokenLike, Match
from .token_index import TokenIndex
from .utils import bbox_overlap
from .strategies import (
ExactMatcher,
ConcatenatedMatcher,
SubstringMatcher,
FuzzyMatcher,
FlexibleDateMatcher,
)
class FieldMatcher:
"""Matches field values to document tokens."""
def __init__(
self,
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
min_score_threshold: float = 0.5
):
"""
Initialize the matcher.
Args:
context_radius: Distance to search for context keywords (default 200px to handle
typical label-value spacing in scanned invoices at 150 DPI)
min_score_threshold: Minimum score to consider a match valid
"""
self.context_radius = context_radius
self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None
# Initialize matching strategies
self.exact_matcher = ExactMatcher(context_radius)
self.concatenated_matcher = ConcatenatedMatcher(context_radius)
self.substring_matcher = SubstringMatcher(context_radius)
self.fuzzy_matcher = FuzzyMatcher(context_radius)
self.flexible_date_matcher = FlexibleDateMatcher(context_radius)
def find_matches(
self,
tokens: list[TokenLike],
field_name: str,
normalized_values: list[str],
page_no: int = 0
) -> list[Match]:
"""
Find all matches for a field in the token list.
Args:
tokens: List of tokens from the document
field_name: Name of the field to match
normalized_values: List of normalized value variants to search for
page_no: Page number to filter tokens
Returns:
List of Match objects sorted by score (descending)
"""
matches = []
# Filter tokens by page and exclude hidden metadata tokens
# Hidden tokens often have bbox with y < 0 or y > page_height
# These are typically PDF metadata stored as invisible text
page_tokens = [
t for t in tokens
if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
]
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
for value in normalized_values:
# Strategy 1: Exact token match
exact_matches = self.exact_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(exact_matches)
# Strategy 2: Multi-token concatenation
concat_matches = self.concatenated_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(concat_matches)
# Strategy 3: Fuzzy match (for amounts and dates only)
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
fuzzy_matches = self.fuzzy_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(fuzzy_matches)
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in (
'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
'Bankgiro', 'Plusgiro', 'supplier_organisation_number',
'supplier_accounts', 'customer_number'
):
substring_matches = self.substring_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(substring_matches)
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
# Only if no exact matches found for date fields
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
for value in normalized_values:
flexible_matches = self.flexible_date_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(flexible_matches)
# Deduplicate and sort by score
matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True)
# Clear token index to free memory
self._token_index = None
return [m for m in matches if m.score >= self.min_score_threshold]
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""
Remove duplicate matches based on bbox overlap.
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
"""
if not matches:
return []
# Sort by: 1) score descending, 2) prefer matches with context keywords,
# 3) prefer upper positions (smaller y) for same-score matches
# This helps select the "main" occurrence in invoice body rather than footer
matches.sort(key=lambda m: (
-m.score,
-len(m.context_keywords), # More keywords = better
m.bbox[1] # Smaller y (upper position) = better
))
# Use spatial grid for efficient overlap checking
# Grid cell size based on typical bbox size
grid_size = 50.0 # pixels
grid: dict[tuple[int, int], list[Match]] = {}
unique = []
for match in matches:
bbox = match.bbox
# Calculate grid cells this bbox touches
min_gx = int(bbox[0] / grid_size)
min_gy = int(bbox[1] / grid_size)
max_gx = int(bbox[2] / grid_size)
max_gy = int(bbox[3] / grid_size)
# Check for overlap only with matches in nearby grid cells
is_duplicate = False
cells_to_check = set()
for gx in range(min_gx - 1, max_gx + 2):
for gy in range(min_gy - 1, max_gy + 2):
cells_to_check.add((gx, gy))
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
if bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if is_duplicate:
break
if not is_duplicate:
unique.append(match)
# Add to all grid cells this bbox touches
for gx in range(min_gx, max_gx + 1):
for gy in range(min_gy, max_gy + 1):
key = (gx, gy)
if key not in grid:
grid[key] = []
grid[key].append(match)
return unique
def find_field_matches(
tokens: list[TokenLike],
field_values: dict[str, str],
page_no: int = 0
) -> dict[str, list[Match]]:
"""
Convenience function to find matches for multiple fields.
Args:
tokens: List of tokens from the document
field_values: Dict of field_name -> value to search for
page_no: Page number
Returns:
Dict of field_name -> list of matches
"""
from ..normalize import normalize_field
matcher = FieldMatcher()
results = {}
for field_name, value in field_values.items():
if value is None or str(value).strip() == '':
continue
normalized_values = normalize_field(field_name, str(value))
matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
results[field_name] = matches
return results
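# Illustrative usage sketch (made-up tokens, coordinates, and a hand-supplied
# normalized value; a plain dataclass stands in for TokenLike). Because this module
# uses relative imports, run it within its package context.
def _demo_field_matcher() -> None:
    from dataclasses import dataclass

    @dataclass
    class _Tok:
        text: str
        bbox: tuple[float, float, float, float]
        page_no: int = 0

    tokens = [
        _Tok('Fakturanummer:', (80.0, 120.0, 190.0, 135.0)),
        _Tok('2465027205', (200.0, 120.0, 290.0, 135.0)),
    ]
    matcher = FieldMatcher(min_score_threshold=0.5)
    for m in matcher.find_matches(tokens, 'InvoiceNumber', ['2465027205']):
        print(m.field, m.matched_text, round(m.score, 2), m.context_keywords)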

View File

@@ -0,0 +1,875 @@
"""
Field Matching Module
Matches normalized field values to tokens extracted from documents.
"""
from dataclasses import dataclass, field
from typing import Protocol
import re
from functools import cached_property
# Pre-compiled regex patterns (module-level for efficiency)
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
def _normalize_dashes(text: str) -> str:
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
return _DASH_PATTERN.sub('-', text)
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
class FieldMatcher:
"""Matches field values to document tokens."""
def __init__(
self,
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
min_score_threshold: float = 0.5
):
"""
Initialize the matcher.
Args:
context_radius: Distance to search for context keywords (default 200px to handle
typical label-value spacing in scanned invoices at 150 DPI)
min_score_threshold: Minimum score to consider a match valid
"""
self.context_radius = context_radius
self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None
def find_matches(
self,
tokens: list[TokenLike],
field_name: str,
normalized_values: list[str],
page_no: int = 0
) -> list[Match]:
"""
Find all matches for a field in the token list.
Args:
tokens: List of tokens from the document
field_name: Name of the field to match
normalized_values: List of normalized value variants to search for
page_no: Page number to filter tokens
Returns:
List of Match objects sorted by score (descending)
"""
matches = []
# Filter tokens by page and exclude hidden metadata tokens
# Hidden tokens often have bbox with y < 0 or y > page_height
# These are typically PDF metadata stored as invisible text
page_tokens = [
t for t in tokens
if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
]
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
for value in normalized_values:
# Strategy 1: Exact token match
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
matches.extend(exact_matches)
# Strategy 2: Multi-token concatenation
concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
matches.extend(concat_matches)
# Strategy 3: Fuzzy match (for amounts and dates only)
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
matches.extend(fuzzy_matches)
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
# Only if no exact matches found for date fields
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
flexible_matches = self._find_flexible_date_matches(
page_tokens, normalized_values, field_name
)
matches.extend(flexible_matches)
# Deduplicate and sort by score
matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True)
# Clear token index to free memory
self._token_index = None
return [m for m in matches if m.score >= self.min_score_threshold]
def _find_exact_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find tokens that exactly match the value."""
matches = []
value_lower = value.lower()
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts') else None
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match (use cached lowercase from index)
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif value_digits is not None:
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches
def _find_concatenated_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find value by concatenating adjacent tokens."""
matches = []
value_clean = _WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not self._tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = self._find_context_keywords(
tokens, start_token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
def _find_substring_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""
Find value as a substring within longer tokens.
Handles cases like:
- 'Fakturadatum: 2026-01-09' where the date is embedded
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
- 'OCR: 1234567890' where reference number is embedded
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
Only matches if the value appears as a distinct segment (not part of a larger number).
"""
matches = []
# Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts', 'customer_number')
if field_name not in supported_fields:
return matches
# Fields where spaces/dashes should be ignored during matching
# (e.g., org number "55 65 74-6624" should match "5565746624")
ignore_spaces_fields = ('supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts')
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = _normalize_dashes(token_text)
# For certain fields, also try matching with spaces/dashes removed
if field_name in ignore_spaces_fields:
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
value_compact = value.replace(' ', '').replace('-', '')
else:
token_text_compact = None
value_compact = None
# Skip tokens no longer than the value (an equal-length token would be an exact match)
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring (using normalized text)
# Try case-sensitive first, then case-insensitive
idx = None
case_sensitive_match = True
used_compact = False
if value in token_text_normalized:
idx = token_text_normalized.find(value)
elif value.lower() in token_text_normalized.lower():
idx = token_text_normalized.lower().find(value.lower())
case_sensitive_match = False
elif token_text_compact and value_compact in token_text_compact:
# Try compact matching (spaces/dashes removed)
idx = token_text_compact.find(value_compact)
used_compact = True
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
idx = token_text_compact.lower().find(value_compact.lower())
case_sensitive_match = False
used_compact = True
if idx is None:
continue
# For compact matching, verify digit boundaries in the compact string
if used_compact:
# Verify proper boundary in compact text
if idx > 0 and token_text_compact[idx - 1].isdigit():
continue
end_idx = idx + len(value_compact)
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
continue
else:
# Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
# Lower score for case-insensitive match
base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
def _find_fuzzy_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find approximate matches for amounts and dates."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = self._parse_amount(token_text)
value_num = self._parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except Exception:  # defensive: never let amount parsing abort matching
pass
return matches
def _find_flexible_date_matches(
self,
tokens: list[TokenLike],
normalized_values: list[str],
field_name: str
) -> list[Match]:
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: Match dates within 7 days of CSV value
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
from datetime import datetime, timedelta
matches = []
# Parse the target date from normalized values
target_date = None
for value in normalized_values:
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
break
except ValueError:
continue
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token (use pre-compiled pattern)
for match in _DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= self.min_score_threshold:
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
def _find_context_keywords(
self,
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str
) -> tuple[list[str], float]:
"""
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
# Use spatial index for efficient nearby token lookup
if self._token_index:
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = self._token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= self.context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def _parse_amount(self, text: str | int | float) -> float | None:
"""Try to parse text as a monetary amount."""
# Convert to string first
text = str(text)
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
# Pattern: digits + space + exactly 2 digits at end
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
if ore_match:
kronor = ore_match.group(1)
ore = ore_match.group(2)
try:
return float(f"{kronor}.{ore}")
except ValueError:
pass
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
text = re.sub(r'\s*\(.*\)', '', text)
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
text = re.sub(r'[:-]', '', text)
# Remove spaces (thousand separators) but be careful with öre format
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
# Swedish format: "500,00" means 500.00
# Need to handle cases like "500,00." (after removing "kr.")
if ',' in text:
# Remove any trailing dots first (from "kr." removal)
text = text.rstrip('.')
# Now replace comma with dot
if '.' not in text:
text = text.replace(',', '.')
# Remove any remaining non-numeric characters except dot
text = re.sub(r'[^\d.]', '', text)
try:
return float(text)
except ValueError:
return None
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""
Remove duplicate matches based on bbox overlap.
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
"""
if not matches:
return []
# Sort by: 1) score descending, 2) prefer matches with context keywords,
# 3) prefer upper positions (smaller y) for same-score matches
# This helps select the "main" occurrence in invoice body rather than footer
matches.sort(key=lambda m: (
-m.score,
-len(m.context_keywords), # More keywords = better
m.bbox[1] # Smaller y (upper position) = better
))
# Use spatial grid for efficient overlap checking
# Grid cell size based on typical bbox size
grid_size = 50.0 # pixels
grid: dict[tuple[int, int], list[Match]] = {}
unique = []
for match in matches:
bbox = match.bbox
# Calculate grid cells this bbox touches
min_gx = int(bbox[0] / grid_size)
min_gy = int(bbox[1] / grid_size)
max_gx = int(bbox[2] / grid_size)
max_gy = int(bbox[3] / grid_size)
# Check for overlap only with matches in nearby grid cells
is_duplicate = False
cells_to_check = set()
for gx in range(min_gx - 1, max_gx + 2):
for gy in range(min_gy - 1, max_gy + 2):
cells_to_check.add((gx, gy))
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
if self._bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if is_duplicate:
break
if not is_duplicate:
unique.append(match)
# Add to all grid cells this bbox touches
for gx in range(min_gx, max_gx + 1):
for gy in range(min_gy, max_gy + 1):
key = (gx, gy)
if key not in grid:
grid[key] = []
grid[key].append(match)
return unique
def _bbox_overlap(
self,
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
def find_field_matches(
tokens: list[TokenLike],
field_values: dict[str, str],
page_no: int = 0
) -> dict[str, list[Match]]:
"""
Convenience function to find matches for multiple fields.
Args:
tokens: List of tokens from the document
field_values: Dict of field_name -> value to search for
page_no: Page number
Returns:
Dict of field_name -> list of matches
"""
from ..normalize import normalize_field
matcher = FieldMatcher()
results = {}
for field_name, value in field_values.items():
if value is None or str(value).strip() == '':
continue
normalized_values = normalize_field(field_name, str(value))
matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
results[field_name] = matches
return results
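A minimal usage sketch of the convenience function above (not part of the diff): the top-level import path is a placeholder, since the diff only shows relative imports, and the tokens and coordinates are invented for illustration.
from dataclasses import dataclass

from labeling.field_matcher import find_field_matches  # placeholder import path


@dataclass
class Token:
    # Minimal object satisfying the TokenLike protocol (text, bbox, page_no).
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0


tokens = [
    Token("Fakturanummer:", (50.0, 100.0, 160.0, 112.0)),
    Token("2465027205", (170.0, 100.0, 250.0, 112.0)),
    Token("Att betala:", (50.0, 400.0, 130.0, 412.0)),
    Token("1 234,56 kr", (140.0, 400.0, 225.0, 412.0)),
]

results = find_field_matches(
    tokens,
    {"InvoiceNumber": "2465027205", "Amount": "1234.56"},
)
for field_name, found in results.items():
    for m in found:
        print(field_name, repr(m.matched_text), round(m.score, 2), m.bbox)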

View File

@@ -0,0 +1,36 @@
"""
Data models for field matching.
"""
from dataclasses import dataclass
from typing import Protocol
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"

View File

@@ -0,0 +1,17 @@
"""
Matching strategies for field matching.
"""
from .exact_matcher import ExactMatcher
from .concatenated_matcher import ConcatenatedMatcher
from .substring_matcher import SubstringMatcher
from .fuzzy_matcher import FuzzyMatcher
from .flexible_date_matcher import FlexibleDateMatcher
__all__ = [
'ExactMatcher',
'ConcatenatedMatcher',
'SubstringMatcher',
'FuzzyMatcher',
'FlexibleDateMatcher',
]

View File

@@ -0,0 +1,42 @@
"""
Base class for matching strategies.
"""
from abc import ABC, abstractmethod
from ..models import TokenLike, Match
from ..token_index import TokenIndex
class BaseMatchStrategy(ABC):
"""Base class for all matching strategies."""
def __init__(self, context_radius: float = 200.0):
"""
Initialize the strategy.
Args:
context_radius: Distance to search for context keywords
"""
self.context_radius = context_radius
@abstractmethod
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""
Find matches for the given value.
Args:
tokens: List of tokens to search
value: Value to find
field_name: Name of the field
token_index: Optional spatial index for efficient lookup
Returns:
List of Match objects
"""
pass
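For illustration, a strategy that plugs into this contract could look like the toy sketch below; the prefix rule and the 0.6 score are invented for the example and are not part of the diff.
class PrefixMatcher(BaseMatchStrategy):
    """Toy strategy (example only): match tokens that start with the value."""
    def find_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str,
        token_index: TokenIndex | None = None
    ) -> list[Match]:
        results = []
        for token in tokens:
            text = token.text.strip()
            if text != value and text.startswith(value):
                results.append(Match(
                    field=field_name,
                    value=value,
                    bbox=token.bbox,
                    page_no=token.page_no,
                    score=0.6,  # arbitrary score for the toy rule
                    matched_text=text,
                    context_keywords=[],
                ))
        return results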

View File

@@ -0,0 +1,73 @@
"""
Concatenated match strategy - finds value by concatenating adjacent tokens.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import WHITESPACE_PATTERN, tokens_on_same_line
class ConcatenatedMatcher(BaseMatchStrategy):
"""Find value by concatenating adjacent tokens."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find concatenated matches."""
matches = []
value_clean = WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = find_context_keywords(
tokens, start_token, field_name, self.context_radius, token_index
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
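A worked example of the concatenation behaviour above (import path, tokens and coordinates are invented): an OCR number split across three tokens on one line, 5 px apart, is merged before comparison.
from types import SimpleNamespace as Tok

from labeling.field_matcher.strategies import ConcatenatedMatcher  # placeholder path

tokens = [
    Tok(text="1234", bbox=(50.0, 100.0, 90.0, 112.0), page_no=0),
    Tok(text="5678", bbox=(95.0, 100.0, 135.0, 112.0), page_no=0),
    Tok(text="90", bbox=(140.0, 100.0, 155.0, 112.0), page_no=0),
]
matches = ConcatenatedMatcher().find_matches(tokens, "1234567890", "OCR")
# Expect one match: matched_text "1234567890", bbox spanning x=50..155,
# base score 0.85 plus any context-keyword boost.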

View File

@@ -0,0 +1,65 @@
"""
Exact match strategy.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import NON_DIGIT_PATTERN
class ExactMatcher(BaseMatchStrategy):
"""Find tokens that exactly match the value."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find exact matches."""
matches = []
value_lower = value.lower()
value_digits = NON_DIGIT_PATTERN.sub('', value) if field_name in (
'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts'
) else None
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match (use the cached lowercase from the index when available)
elif (token_index.get_text_lower(token) if token_index else token_text.lower()).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif value_digits is not None:
token_digits = NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches
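The digits-only branch above means a hyphenated organisation number on the PDF can still score 0.9 against a compact CSV value; a sketch with an invented token and a placeholder import path:
from types import SimpleNamespace as Tok

from labeling.field_matcher.strategies import ExactMatcher  # placeholder path

tokens = [Tok(text="556574-6624", bbox=(60.0, 700.0, 150.0, 712.0), page_no=0)]
matches = ExactMatcher().find_matches(
    tokens, "5565746624", "supplier_organisation_number"
)
# Expect one match scored 0.9 (digits-only equality), plus any context boost.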

View File

@@ -0,0 +1,149 @@
"""
Flexible date match strategy - finds dates with year-month or nearby date matching.
"""
import re
from datetime import datetime
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import DATE_PATTERN
class FlexibleDateMatcher(BaseMatchStrategy):
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: Match dates within 7 days of CSV value
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find flexible date matches."""
matches = []
# Parse the target date from normalized values
target_date = None
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
except ValueError:
pass
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token (use pre-compiled pattern)
for match in DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= 0.5: # Min threshold for flexible matching
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
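A sketch of the flexible fallback above, with an invented token and a placeholder import path: the CSV says 2026-01-15 but the PDF only prints 2026-01-12, which is in the same year-month and within three days.
from types import SimpleNamespace as Tok

from labeling.field_matcher.strategies import FlexibleDateMatcher  # placeholder path

tokens = [
    Tok(text="Fakturadatum: 2026-01-12", bbox=(50.0, 90.0, 230.0, 102.0), page_no=0),
]
matches = FlexibleDateMatcher().find_matches(tokens, "2026-01-15", "InvoiceDate")
# Expect a single best match labelled with the found date string "2026-01-12":
# same year-month within 3 days gives a 0.8 base score, boosted further if
# "fakturadatum" is registered as an InvoiceDate context keyword.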

View File

@@ -0,0 +1,52 @@
"""
Fuzzy match strategy for amounts and dates.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import parse_amount
class FuzzyMatcher(BaseMatchStrategy):
"""Find approximate matches for amounts and dates."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find fuzzy matches."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = parse_amount(token_text)
value_num = parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except (TypeError, ValueError):
pass
return matches

View File

@@ -0,0 +1,143 @@
"""
Substring match strategy - finds value as substring within longer tokens.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import normalize_dashes
class SubstringMatcher(BaseMatchStrategy):
"""
Find value as a substring within longer tokens.
Handles cases like:
- 'Fakturadatum: 2026-01-09' where the date is embedded
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
- 'OCR: 1234567890' where reference number is embedded
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
Only matches if the value appears as a distinct segment (not part of a larger number).
"""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find substring matches."""
matches = []
# Supported fields for substring matching
supported_fields = (
'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts', 'customer_number'
)
if field_name not in supported_fields:
return matches
# Fields where spaces/dashes should be ignored during matching
# (e.g., org number "55 65 74-6624" should match "5565746624")
ignore_spaces_fields = (
'supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts'
)
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = normalize_dashes(token_text)
# For certain fields, also try matching with spaces/dashes removed
if field_name in ignore_spaces_fields:
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
value_compact = value.replace(' ', '').replace('-', '')
else:
token_text_compact = None
value_compact = None
# Skip tokens that are not longer than the value (equal-length tokens are ExactMatcher's job)
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring (using normalized text)
# Try case-sensitive first, then case-insensitive
idx = None
case_sensitive_match = True
used_compact = False
if value in token_text_normalized:
idx = token_text_normalized.find(value)
elif value.lower() in token_text_normalized.lower():
idx = token_text_normalized.lower().find(value.lower())
case_sensitive_match = False
elif token_text_compact and value_compact in token_text_compact:
# Try compact matching (spaces/dashes removed)
idx = token_text_compact.find(value_compact)
used_compact = True
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
idx = token_text_compact.lower().find(value_compact.lower())
case_sensitive_match = False
used_compact = True
if idx is None:
continue
# For compact matching, verify digit boundaries directly in the compact text
if used_compact:
# Verify proper boundary in compact text
if idx > 0 and token_text_compact[idx - 1].isdigit():
continue
end_idx = idx + len(value_compact)
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
continue
else:
# Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
# Lower score for case-insensitive match
base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
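A sketch of the substring path above, for a value embedded in a label token (token and import path invented):
from types import SimpleNamespace as Tok

from labeling.field_matcher.strategies import SubstringMatcher  # placeholder path

tokens = [
    Tok(text="Fakturanummer: 2465027205", bbox=(50.0, 60.0, 240.0, 72.0), page_no=0),
]
matches = SubstringMatcher().find_matches(tokens, "2465027205", "InvoiceNumber")
# Expect one match at a 0.75 base score (deliberately below exact-match scores),
# plus a 0.1 boost if "fakturanummer" is an inline context keyword.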

View File

@@ -0,0 +1,92 @@
"""
Spatial index for fast token lookup.
"""
from .models import TokenLike
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for token in tokens:
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
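A small sketch of the index above (tokens and import path invented): with the default 100 px grid, a 150 px radius query only visits the surrounding cells rather than every token.
from types import SimpleNamespace as Tok

from labeling.field_matcher.token_index import TokenIndex  # placeholder path

tokens = [
    Tok(text="Att betala", bbox=(50.0, 400.0, 120.0, 412.0), page_no=0),
    Tok(text="1 234,56", bbox=(140.0, 400.0, 210.0, 412.0), page_no=0),
    Tok(text="Footer text", bbox=(50.0, 1100.0, 150.0, 1112.0), page_no=0),
]
index = TokenIndex(tokens, grid_size=100.0)
print([t.text for t in index.find_nearby(tokens[1], radius=150.0)])
# -> ['Att betala']  (the footer token lies far outside the 150 px radius)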

View File

@@ -0,0 +1,91 @@
"""
Utility functions for field matching.
"""
import re
# Pre-compiled regex patterns (module-level for efficiency)
DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
WHITESPACE_PATTERN = re.compile(r'\s+')
NON_DIGIT_PATTERN = re.compile(r'\D')
DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
def normalize_dashes(text: str) -> str:
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
return DASH_PATTERN.sub('-', text)
def parse_amount(text: str | int | float) -> float | None:
"""Try to parse text as a monetary amount."""
# Convert to string first
text = str(text)
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
# Pattern: digits + space + exactly 2 digits at end
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
if ore_match:
kronor = ore_match.group(1)
ore = ore_match.group(2)
try:
return float(f"{kronor}.{ore}")
except ValueError:
pass
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
text = re.sub(r'\s*\(.*\)', '', text)
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
text = re.sub(r'[:-]', '', text)
# Remove spaces (thousand separators) but be careful with öre format
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
# Swedish format: "500,00" means 500.00
# Need to handle cases like "500,00." (after removing "kr.")
if ',' in text:
# Remove any trailing dots first (from "kr." removal)
text = text.rstrip('.')
# Now replace comma with dot
if '.' not in text:
text = text.replace(',', '.')
# Remove any remaining non-numeric characters except dot
text = re.sub(r'[^\d.]', '', text)
try:
return float(text)
except ValueError:
return None
def tokens_on_same_line(token1, token2) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def bbox_overlap(
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
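A few illustrative calls to the helpers above (placeholder import path; the expected values follow from the parsing rules shown):
from labeling.field_matcher.utils import (  # placeholder path
    bbox_overlap, normalize_dashes, parse_amount
)

print(parse_amount("1 234,56 kr"))   # 1234.56  (space thousands, comma decimal)
print(parse_amount("239 00"))        # 239.0    (öre form: 239 kr 00 öre)
print(parse_amount("500,00 kr."))    # 500.0    (trailing dot from "kr." handled)
print(parse_amount("se bilaga"))     # None     (not an amount)
print(normalize_dashes("556574\u20136624"))          # '556574-6624'
print(bbox_overlap((0, 0, 10, 10), (5, 0, 15, 10)))  # 0.333... (IoU)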

View File

@@ -0,0 +1,3 @@
from .normalizer import normalize_field, FieldNormalizer
__all__ = ['normalize_field', 'FieldNormalizer']

View File

@@ -0,0 +1,186 @@
"""
Field Normalization Module
Normalizes field values to generate multiple candidate forms for matching.
This module now delegates to individual normalizer modules for each field type.
Each normalizer is a separate, reusable module that can be used independently.
"""
from dataclasses import dataclass
from typing import Callable
from shared.utils.text_cleaner import TextCleaner
# Import individual normalizers
from .normalizers import (
InvoiceNumberNormalizer,
OCRNormalizer,
BankgiroNormalizer,
PlusgiroNormalizer,
AmountNormalizer,
DateNormalizer,
OrganisationNumberNormalizer,
SupplierAccountsNormalizer,
CustomerNumberNormalizer,
)
@dataclass
class NormalizedValue:
"""Represents a normalized value with its variants."""
original: str
variants: list[str]
field_type: str
class FieldNormalizer:
"""
Handles normalization of different invoice field types.
This class now acts as a facade that delegates to individual
normalizer modules. Each field type has its own specialized
normalizer for better modularity and reusability.
"""
# Instantiate individual normalizers
_invoice_number = InvoiceNumberNormalizer()
_ocr_number = OCRNormalizer()
_bankgiro = BankgiroNormalizer()
_plusgiro = PlusgiroNormalizer()
_amount = AmountNormalizer()
_date = DateNormalizer()
_organisation_number = OrganisationNumberNormalizer()
_supplier_accounts = SupplierAccountsNormalizer()
_customer_number = CustomerNumberNormalizer()
# Common Swedish month names for backward compatibility
SWEDISH_MONTHS = DateNormalizer.SWEDISH_MONTHS
@staticmethod
def clean_text(text: str) -> str:
"""
Remove invisible characters and normalize whitespace and dashes.
Delegates to shared TextCleaner for consistency.
"""
return TextCleaner.clean_text(text)
@staticmethod
def normalize_invoice_number(value: str) -> list[str]:
"""
Normalize invoice number.
Delegates to InvoiceNumberNormalizer.
"""
return FieldNormalizer._invoice_number.normalize(value)
@staticmethod
def normalize_ocr_number(value: str) -> list[str]:
"""
Normalize OCR number (Swedish payment reference).
Delegates to OCRNormalizer.
"""
return FieldNormalizer._ocr_number.normalize(value)
@staticmethod
def normalize_bankgiro(value: str) -> list[str]:
"""
Normalize Bankgiro number.
Delegates to BankgiroNormalizer.
"""
return FieldNormalizer._bankgiro.normalize(value)
@staticmethod
def normalize_plusgiro(value: str) -> list[str]:
"""
Normalize Plusgiro number.
Delegates to PlusgiroNormalizer.
"""
return FieldNormalizer._plusgiro.normalize(value)
@staticmethod
def normalize_organisation_number(value: str) -> list[str]:
"""
Normalize Swedish organisation number and generate VAT number variants.
Delegates to OrganisationNumberNormalizer.
"""
return FieldNormalizer._organisation_number.normalize(value)
@staticmethod
def normalize_supplier_accounts(value: str) -> list[str]:
"""
Normalize supplier accounts field.
Delegates to SupplierAccountsNormalizer.
"""
return FieldNormalizer._supplier_accounts.normalize(value)
@staticmethod
def normalize_customer_number(value: str) -> list[str]:
"""
Normalize customer number.
Delegates to CustomerNumberNormalizer.
"""
return FieldNormalizer._customer_number.normalize(value)
@staticmethod
def normalize_amount(value: str) -> list[str]:
"""
Normalize monetary amount.
Delegates to AmountNormalizer.
"""
return FieldNormalizer._amount.normalize(value)
@staticmethod
def normalize_date(value: str) -> list[str]:
"""
Normalize date to YYYY-MM-DD and generate variants.
Delegates to DateNormalizer.
"""
return FieldNormalizer._date.normalize(value)
# Field type to normalizer mapping
NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
'InvoiceNumber': FieldNormalizer.normalize_invoice_number,
'OCR': FieldNormalizer.normalize_ocr_number,
'Bankgiro': FieldNormalizer.normalize_bankgiro,
'Plusgiro': FieldNormalizer.normalize_plusgiro,
'Amount': FieldNormalizer.normalize_amount,
'InvoiceDate': FieldNormalizer.normalize_date,
'InvoiceDueDate': FieldNormalizer.normalize_date,
'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
'customer_number': FieldNormalizer.normalize_customer_number,
}
def normalize_field(field_name: str, value: str) -> list[str]:
"""
Normalize a field value based on its type.
Args:
field_name: Name of the field (e.g., 'InvoiceNumber', 'Amount')
value: Raw value to normalize
Returns:
List of normalized variants
"""
if value is None or (isinstance(value, str) and not value.strip()):
return []
value = str(value)
normalizer = NORMALIZERS.get(field_name)
if normalizer:
return normalizer(value)
# Default: just clean the text
return [FieldNormalizer.clean_text(value)]
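A sketch of how the dispatcher above might be exercised (placeholder import path; the exact variant lists come from the individual normalizer modules, which are not shown in this diff, so only the fallback and empty-value behaviour is asserted here):
from labeling.normalize import normalize_field  # placeholder path

# Known field types dispatch to their dedicated normalizers:
print(normalize_field("Amount", "1 234,56 kr"))
print(normalize_field("InvoiceDate", "2026-01-15"))

# Unknown field names fall back to plain text cleaning:
print(normalize_field("SomeOtherField", "  ABC-123  "))

# Empty or missing values produce no variants:
print(normalize_field("Amount", "   "))  # []
print(normalize_field("Amount", None))   # []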

Some files were not shown because too many files have changed in this diff.