re-structure
packages/backend/Dockerfile (new file)
@@ -0,0 +1,25 @@
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
    && rm -rf /var/lib/apt/lists/*

# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared

# Install inference package
COPY packages/inference /app/packages/inference
RUN pip install --no-cache-dir -e /app/packages/inference

# Copy frontend (if needed)
COPY frontend /app/frontend

WORKDIR /app/packages/inference

EXPOSE 8000

CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]
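Usage note (an assumption, not part of the commit): the COPY paths reference packages/shared and packages/inference, so the image presumably has to be built with the repository root as the build context, for example: docker build -f packages/backend/Dockerfile -t invoice-backend . run from the repo root; building from inside packages/backend would not find those paths.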
packages/backend/backend/__init__.py (new file, empty)

packages/backend/backend/azure/__init__.py (new file, empty)

packages/backend/backend/azure/aci_trigger.py (new file)
@@ -0,0 +1,105 @@
"""Trigger training jobs on Azure Container Instances."""

import logging
import os

logger = logging.getLogger(__name__)

# Azure SDK is optional; only needed if using ACI trigger
try:
    from azure.identity import DefaultAzureCredential
    from azure.mgmt.containerinstance import ContainerInstanceManagementClient
    from azure.mgmt.containerinstance.models import (
        Container,
        ContainerGroup,
        EnvironmentVariable,
        GpuResource,
        ResourceRequests,
        ResourceRequirements,
    )

    _AZURE_SDK_AVAILABLE = True
except ImportError:
    _AZURE_SDK_AVAILABLE = False


def start_training_container(task_id: str) -> str | None:
    """
    Start an Azure Container Instance for a training task.

    Returns the container group name if successful, None otherwise.
    Requires environment variables:
        AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
    """
    if not _AZURE_SDK_AVAILABLE:
        logger.warning(
            "Azure SDK not installed. Install azure-mgmt-containerinstance "
            "and azure-identity to use ACI trigger."
        )
        return None

    subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
    resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
    image = os.environ.get(
        "AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
    )
    gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
    location = os.environ.get("AZURE_LOCATION", "eastus")

    if not subscription_id:
        logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
        return None

    credential = DefaultAzureCredential()
    client = ContainerInstanceManagementClient(credential, subscription_id)

    container_name = f"training-{task_id[:8]}"

    env_vars = [
        EnvironmentVariable(name="TASK_ID", value=task_id),
    ]

    # Pass DB connection securely
    for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
        val = os.environ.get(var, "")
        if val:
            env_vars.append(EnvironmentVariable(name=var, value=val))

    db_password = os.environ.get("DB_PASSWORD", "")
    if db_password:
        env_vars.append(
            EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
        )

    container = Container(
        name=container_name,
        image=image,
        resources=ResourceRequirements(
            requests=ResourceRequests(
                cpu=4,
                memory_in_gb=16,
                gpu=GpuResource(count=1, sku=gpu_sku),
            )
        ),
        environment_variables=env_vars,
        command=[
            "python",
            "run_training.py",
            "--task-id",
            task_id,
        ],
    )

    group = ContainerGroup(
        location=location,
        containers=[container],
        os_type="Linux",
        restart_policy="Never",
    )

    logger.info("Creating ACI container group: %s", container_name)
    client.container_groups.begin_create_or_update(
        resource_group, container_name, group
    )

    return container_name
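For context, a minimal sketch of how a caller might use this trigger; the surrounding wiring (launch_training and the messages) is hypothetical and not part of this commit:

from backend.azure.aci_trigger import start_training_container

def launch_training(task_id: str) -> None:
    # Returns the ACI container group name, or None if the SDK or required config is missing.
    group_name = start_training_container(task_id)
    if group_name is None:
        # Hypothetical fallback so the task is not silently dropped.
        print(f"Could not start ACI container for task {task_id}")
    else:
        print(f"Training task {task_id} running in container group {group_name}")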
packages/backend/backend/cli/__init__.py (new file, empty)

packages/backend/backend/cli/infer.py (new file)
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Inference CLI

Runs inference on new PDFs to extract invoice data.
"""

import argparse
import json
import sys
from pathlib import Path

from shared.config import DEFAULT_DPI


def main():
    parser = argparse.ArgumentParser(
        description='Extract invoice data from PDFs using trained model'
    )
    parser.add_argument(
        '--model', '-m',
        required=True,
        help='Path to trained YOLO model (.pt file)'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Input PDF file or directory'
    )
    parser.add_argument(
        '--output', '-o',
        help='Output JSON file (default: stdout)'
    )
    parser.add_argument(
        '--confidence',
        type=float,
        default=0.5,
        help='Detection confidence threshold (default: 0.5)'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=DEFAULT_DPI,
        help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
    )
    parser.add_argument(
        '--no-fallback',
        action='store_true',
        help='Disable fallback OCR'
    )
    parser.add_argument(
        '--lang',
        default='en',
        help='OCR language (default: en)'
    )
    parser.add_argument(
        '--gpu',
        action='store_true',
        help='Use GPU'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )

    args = parser.parse_args()

    # Validate model
    model_path = Path(args.model)
    if not model_path.exists():
        print(f"Error: Model not found: {model_path}", file=sys.stderr)
        sys.exit(1)

    # Get input files
    input_path = Path(args.input)
    if input_path.is_file():
        pdf_files = [input_path]
    elif input_path.is_dir():
        pdf_files = list(input_path.glob('*.pdf'))
    else:
        print(f"Error: Input not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    if not pdf_files:
        print("Error: No PDF files found", file=sys.stderr)
        sys.exit(1)

    if args.verbose:
        print(f"Processing {len(pdf_files)} PDF file(s)")
        print(f"Model: {model_path}")

    from backend.pipeline import InferencePipeline

    # Initialize pipeline
    pipeline = InferencePipeline(
        model_path=model_path,
        confidence_threshold=args.confidence,
        ocr_lang=args.lang,
        use_gpu=args.gpu,
        dpi=args.dpi,
        enable_fallback=not args.no_fallback
    )

    # Process files
    results = []

    for pdf_path in pdf_files:
        if args.verbose:
            print(f"Processing: {pdf_path.name}")

        result = pipeline.process_pdf(pdf_path)
        results.append(result.to_json())

        if args.verbose:
            print(f"  Success: {result.success}")
            print(f"  Fields: {len(result.fields)}")
            if result.fallback_used:
                print("  Fallback used: Yes")
            if result.errors:
                print(f"  Errors: {result.errors}")

    # Output results
    if len(results) == 1:
        output = results[0]
    else:
        output = results

    json_output = json.dumps(output, indent=2, ensure_ascii=False)

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(json_output)
        if args.verbose:
            print(f"\nResults written to: {args.output}")
    else:
        print(json_output)


if __name__ == '__main__':
    main()
packages/backend/backend/cli/serve.py (new file)
@@ -0,0 +1,159 @@
"""
Web Server CLI

Command-line interface for starting the web server.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from shared.config import DEFAULT_DPI


def setup_logging(debug: bool = False) -> None:
    """Configure logging."""
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Start the Invoice Field Extraction web server",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to",
    )

    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to listen on",
    )

    parser.add_argument(
        "--model",
        "-m",
        type=Path,
        default=Path("runs/train/invoice_fields/weights/best.pt"),
        help="Path to YOLO model weights",
    )

    parser.add_argument(
        "--confidence",
        type=float,
        default=0.5,
        help="Detection confidence threshold",
    )

    parser.add_argument(
        "--dpi",
        type=int,
        default=DEFAULT_DPI,
        help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
    )

    parser.add_argument(
        "--no-gpu",
        action="store_true",
        help="Disable GPU acceleration",
    )

    parser.add_argument(
        "--reload",
        action="store_true",
        help="Enable auto-reload for development",
    )

    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of worker processes",
    )

    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode",
    )

    return parser.parse_args()


def main() -> None:
    """Main entry point."""
    args = parse_args()
    setup_logging(debug=args.debug)

    logger = logging.getLogger(__name__)

    # Validate model path
    if not args.model.exists():
        logger.error(f"Model file not found: {args.model}")
        sys.exit(1)

    logger.info("=" * 60)
    logger.info("Invoice Field Extraction Web Server")
    logger.info("=" * 60)
    logger.info(f"Model: {args.model}")
    logger.info(f"Confidence threshold: {args.confidence}")
    logger.info(f"GPU enabled: {not args.no_gpu}")
    logger.info(f"Server: http://{args.host}:{args.port}")
    logger.info("=" * 60)

    # Create config
    from backend.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig

    config = AppConfig(
        model=ModelConfig(
            model_path=args.model,
            confidence_threshold=args.confidence,
            use_gpu=not args.no_gpu,
            dpi=args.dpi,
        ),
        server=ServerConfig(
            host=args.host,
            port=args.port,
            debug=args.debug,
            reload=args.reload,
            workers=args.workers,
        ),
        file=FileConfig(),
    )

    # Create and run app
    import uvicorn
    from backend.web.app import create_app

    app = create_app(config)

    uvicorn.run(
        app,
        host=config.server.host,
        port=config.server.port,
        reload=config.server.reload,
        workers=config.server.workers if not config.server.reload else 1,
        log_level="debug" if config.server.debug else "info",
    )


if __name__ == "__main__":
    main()
packages/backend/backend/data/__init__.py (new file, empty)

packages/backend/backend/data/admin_models.py (new file)
@@ -0,0 +1,437 @@
"""
Admin API SQLModel Database Models

Defines the database schema for admin document management, annotations, and training tasks.
Includes batch upload support, training document links, and annotation history.
"""

from datetime import datetime
from typing import Any
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel, Column, JSON

# Import field mappings from single source of truth
from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS


# =============================================================================
# Core Models
# =============================================================================


class AdminToken(SQLModel, table=True):
    """Admin authentication token."""

    __tablename__ = "admin_tokens"

    token: str = Field(primary_key=True, max_length=255)
    name: str = Field(max_length=255)
    is_active: bool = Field(default=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    last_used_at: datetime | None = Field(default=None)
    expires_at: datetime | None = Field(default=None)


class AdminDocument(SQLModel, table=True):
    """Document uploaded for labeling/annotation."""

    __tablename__ = "admin_documents"

    document_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
    filename: str = Field(max_length=255)
    file_size: int
    content_type: str = Field(max_length=100)
    file_path: str = Field(max_length=512)  # Path to stored file
    page_count: int = Field(default=1)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, auto_labeling, labeled, exported
    auto_label_status: str | None = Field(default=None, max_length=20)
    # Auto-label status: running, completed, failed
    auto_label_error: str | None = Field(default=None)
    # v2: Upload source tracking
    upload_source: str = Field(default="ui", max_length=20)
    # Upload source: ui, api
    batch_id: UUID | None = Field(default=None, index=True)
    # Link to batch upload (if uploaded via ZIP)
    group_key: str | None = Field(default=None, max_length=255, index=True)
    # User-defined grouping key for document organization
    category: str = Field(default="invoice", max_length=100, index=True)
    # Document category for training different models (e.g., invoice, letter, receipt)
    csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Original CSV values for reference
    auto_label_queued_at: datetime | None = Field(default=None)
    # When auto-label was queued
    annotation_lock_until: datetime | None = Field(default=None)
    # Lock for manual annotation while auto-label runs
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class AdminAnnotation(SQLModel, table=True):
    """Annotation for a document (bounding box + label)."""

    __tablename__ = "admin_annotations"

    annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    page_number: int = Field(default=1)  # 1-indexed
    class_id: int  # 0-9 for invoice fields
    class_name: str = Field(max_length=50)  # e.g., "invoice_number"
    # Bounding box (normalized 0-1 coordinates)
    x_center: float
    y_center: float
    width: float
    height: float
    # Original pixel coordinates (for display)
    bbox_x: int
    bbox_y: int
    bbox_width: int
    bbox_height: int
    # OCR extracted text (if available)
    text_value: str | None = Field(default=None)
    confidence: float | None = Field(default=None)
    # Source: manual, auto, imported
    source: str = Field(default="manual", max_length=20, index=True)
    # v2: Verification fields
    is_verified: bool = Field(default=False, index=True)
    verified_at: datetime | None = Field(default=None)
    verified_by: str | None = Field(default=None, max_length=255)
    # v2: Override tracking
    override_source: str | None = Field(default=None, max_length=20)
    # If this annotation overrides another: 'auto' or 'imported'
    original_annotation_id: UUID | None = Field(default=None)
    # Reference to the annotation this overrides
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingTask(SQLModel, table=True):
    """Training/fine-tuning task."""

    __tablename__ = "training_tasks"

    task_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, scheduled, running, completed, failed, cancelled
    task_type: str = Field(default="train", max_length=20)
    # Task type: train, finetune
    dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
    # Training configuration
    config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Schedule settings
    scheduled_at: datetime | None = Field(default=None)
    cron_expression: str | None = Field(default=None, max_length=50)
    is_recurring: bool = Field(default=False)
    # Execution details
    started_at: datetime | None = Field(default=None)
    completed_at: datetime | None = Field(default=None)
    error_message: str | None = Field(default=None)
    # Result metrics
    result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    model_path: str | None = Field(default=None, max_length=512)
    # v2: Document count and extracted metrics
    document_count: int = Field(default=0)
    # Count of documents used in training
    metrics_mAP: float | None = Field(default=None, index=True)
    metrics_precision: float | None = Field(default=None)
    metrics_recall: float | None = Field(default=None)
    # Extracted metrics for easy querying
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingLog(SQLModel, table=True):
    """Training log entry."""

    __tablename__ = "training_logs"

    log_id: int | None = Field(default=None, primary_key=True)
    task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
    level: str = Field(max_length=20)  # INFO, WARNING, ERROR
    message: str
    details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)


# =============================================================================
# Batch Upload Models (v2)
# =============================================================================


class BatchUpload(SQLModel, table=True):
    """Batch upload of multiple documents via ZIP file."""

    __tablename__ = "batch_uploads"

    batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
    filename: str = Field(max_length=255)  # ZIP filename
    file_size: int
    upload_source: str = Field(default="ui", max_length=20)
    # Upload source: ui, api
    status: str = Field(default="processing", max_length=20, index=True)
    # Status: processing, completed, partial, failed
    total_files: int = Field(default=0)
    processed_files: int = Field(default=0)
    # Number of files processed so far
    successful_files: int = Field(default=0)
    failed_files: int = Field(default=0)
    csv_filename: str | None = Field(default=None, max_length=255)
    # CSV file used for auto-labeling
    csv_row_count: int | None = Field(default=None)
    error_message: str | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    completed_at: datetime | None = Field(default=None)


class BatchUploadFile(SQLModel, table=True):
    """Individual file within a batch upload."""

    __tablename__ = "batch_upload_files"

    file_id: UUID = Field(default_factory=uuid4, primary_key=True)
    batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
    filename: str = Field(max_length=255)  # PDF filename within ZIP
    document_id: UUID | None = Field(default=None)
    # Link to created AdminDocument (if successful)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, processing, completed, failed, skipped
    error_message: str | None = Field(default=None)
    annotation_count: int = Field(default=0)
    # Number of annotations created for this file
    csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # CSV row data for this file (if available)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    processed_at: datetime | None = Field(default=None)


# =============================================================================
# Training Document Link (v2)
# =============================================================================


class TrainingDataset(SQLModel, table=True):
    """Training dataset containing selected documents with train/val/test splits."""

    __tablename__ = "training_datasets"

    dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    status: str = Field(default="building", max_length=20, index=True)
    # Status: building, ready, trained, archived, failed
    training_status: str | None = Field(default=None, max_length=20, index=True)
    # Training status: pending, scheduled, running, completed, failed, cancelled
    active_training_task_id: UUID | None = Field(default=None, index=True)
    train_ratio: float = Field(default=0.8)
    val_ratio: float = Field(default=0.1)
    seed: int = Field(default=42)
    total_documents: int = Field(default=0)
    total_images: int = Field(default=0)
    total_annotations: int = Field(default=0)
    dataset_path: str | None = Field(default=None, max_length=512)
    error_message: str | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class DatasetDocument(SQLModel, table=True):
    """Junction table linking datasets to documents with split assignment."""

    __tablename__ = "dataset_documents"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    split: str = Field(max_length=10)  # train, val, test
    page_count: int = Field(default=0)
    annotation_count: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingDocumentLink(SQLModel, table=True):
    """Junction table linking training tasks to documents."""

    __tablename__ = "training_document_links"

    link_id: UUID = Field(default_factory=uuid4, primary_key=True)
    task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Snapshot of annotations at training time (includes count, verified count, etc.)
    created_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Model Version Management
# =============================================================================


class ModelVersion(SQLModel, table=True):
    """Model version for inference deployment."""

    __tablename__ = "model_versions"

    version_id: UUID = Field(default_factory=uuid4, primary_key=True)
    version: str = Field(max_length=50, index=True)
    # Semantic version e.g., "1.0.0", "2.1.0"
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    model_path: str = Field(max_length=512)
    # Path to the model weights file
    status: str = Field(default="inactive", max_length=20, index=True)
    # Status: active, inactive, archived
    is_active: bool = Field(default=False, index=True)
    # Only one version can be active at a time for inference

    # Training association
    task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
    dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)

    # Training metrics
    metrics_mAP: float | None = Field(default=None)
    metrics_precision: float | None = Field(default=None)
    metrics_recall: float | None = Field(default=None)
    document_count: int = Field(default=0)
    # Number of documents used in training

    # Training configuration snapshot
    training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Snapshot of epochs, batch_size, etc.

    # File info
    file_size: int | None = Field(default=None)
    # Model file size in bytes

    # Timestamps
    trained_at: datetime | None = Field(default=None)
    # When training completed
    activated_at: datetime | None = Field(default=None)
    # When this version was last activated
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Annotation History (v2)
# =============================================================================


class AnnotationHistory(SQLModel, table=True):
    """History of annotation changes (for override tracking)."""

    __tablename__ = "annotation_history"

    history_id: UUID = Field(default_factory=uuid4, primary_key=True)
    annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    # Change action: created, updated, deleted, override
    action: str = Field(max_length=20, index=True)
    # Previous value (for updates/deletes)
    previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # New value (for creates/updates)
    new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Change metadata
    changed_by: str | None = Field(default=None, max_length=255)
    # User/token who made the change
    change_reason: str | None = Field(default=None)
    # Optional reason for change
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)


# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields
# This ensures consistency with the trained YOLO model


# Read-only models for API responses
class AdminDocumentRead(SQLModel):
    """Admin document response model."""

    document_id: UUID
    filename: str
    file_size: int
    content_type: str
    page_count: int
    status: str
    auto_label_status: str | None
    auto_label_error: str | None
    category: str = "invoice"
    created_at: datetime
    updated_at: datetime


class AdminAnnotationRead(SQLModel):
    """Admin annotation response model."""

    annotation_id: UUID
    document_id: UUID
    page_number: int
    class_id: int
    class_name: str
    x_center: float
    y_center: float
    width: float
    height: float
    bbox_x: int
    bbox_y: int
    bbox_width: int
    bbox_height: int
    text_value: str | None
    confidence: float | None
    source: str
    created_at: datetime


class TrainingTaskRead(SQLModel):
    """Training task response model."""

    task_id: UUID
    name: str
    description: str | None
    status: str
    task_type: str
    config: dict[str, Any] | None
    scheduled_at: datetime | None
    is_recurring: bool
    started_at: datetime | None
    completed_at: datetime | None
    error_message: str | None
    result_metrics: dict[str, Any] | None
    model_path: str | None
    dataset_id: UUID | None
    created_at: datetime


class TrainingDatasetRead(SQLModel):
    """Training dataset response model."""

    dataset_id: UUID
    name: str
    description: str | None
    status: str
    train_ratio: float
    val_ratio: float
    seed: int
    total_documents: int
    total_images: int
    total_annotations: int
    dataset_path: str | None
    error_message: str | None
    created_at: datetime
    updated_at: datetime


class DatasetDocumentRead(SQLModel):
    """Dataset document response model."""

    id: UUID
    dataset_id: UUID
    document_id: UUID
    split: str
    page_count: int
    annotation_count: int
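As a rough usage sketch of these models (assuming the session helpers from backend.data.database later in this commit; filenames and box values are illustrative only):

from backend.data.admin_models import AdminDocument, AdminAnnotation
from backend.data.database import get_session_context

with get_session_context() as session:
    doc = AdminDocument(
        filename="invoice_001.pdf",
        file_size=123456,
        content_type="application/pdf",
        file_path="/data/uploads/invoice_001.pdf",
    )
    session.add(doc)
    session.flush()  # ensure the document row exists before the annotation references it

    session.add(AdminAnnotation(
        document_id=doc.document_id,
        class_id=0,
        class_name="invoice_number",
        x_center=0.5, y_center=0.1, width=0.2, height=0.03,
        bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=24,
    ))
    # get_session_context() commits on clean exit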
packages/backend/backend/data/async_request_db.py (new file)
@@ -0,0 +1,374 @@
"""
Async Request Database Operations

Database interface for async invoice processing requests using SQLModel.
"""

import logging
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID

from sqlalchemy import func, text
from sqlmodel import Session, select

from backend.data.database import get_session_context, create_db_and_tables, close_engine
from backend.data.models import ApiKey, AsyncRequest, RateLimitEvent

logger = logging.getLogger(__name__)


# Legacy dataclasses for backward compatibility
from dataclasses import dataclass


@dataclass(frozen=True)
class ApiKeyConfig:
    """API key configuration and limits (legacy compatibility)."""

    api_key: str
    name: str
    is_active: bool
    requests_per_minute: int
    max_concurrent_jobs: int
    max_file_size_mb: int


class AsyncRequestDB:
    """Database interface for async processing requests using SQLModel."""

    def __init__(self, connection_string: str | None = None) -> None:
        # connection_string is kept for backward compatibility but ignored
        # SQLModel uses the global engine from database.py
        self._initialized = False

    def connect(self):
        """Legacy method - returns self for compatibility."""
        return self

    def close(self) -> None:
        """Close database connections."""
        close_engine()

    def __enter__(self) -> "AsyncRequestDB":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        pass  # Sessions are managed per-operation

    def create_tables(self) -> None:
        """Create async processing tables if they don't exist."""
        create_db_and_tables()
        self._initialized = True

    # ==========================================================================
    # API Key Operations
    # ==========================================================================

    def is_valid_api_key(self, api_key: str) -> bool:
        """Check if API key exists and is active."""
        with get_session_context() as session:
            result = session.get(ApiKey, api_key)
            return result is not None and result.is_active is True

    def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
        """Get API key configuration and limits."""
        with get_session_context() as session:
            result = session.get(ApiKey, api_key)
            if result is None:
                return None
            return ApiKeyConfig(
                api_key=result.api_key,
                name=result.name,
                is_active=result.is_active,
                requests_per_minute=result.requests_per_minute,
                max_concurrent_jobs=result.max_concurrent_jobs,
                max_file_size_mb=result.max_file_size_mb,
            )

    def create_api_key(
        self,
        api_key: str,
        name: str,
        requests_per_minute: int = 10,
        max_concurrent_jobs: int = 3,
        max_file_size_mb: int = 50,
    ) -> None:
        """Create a new API key."""
        with get_session_context() as session:
            existing = session.get(ApiKey, api_key)
            if existing:
                existing.name = name
                existing.requests_per_minute = requests_per_minute
                existing.max_concurrent_jobs = max_concurrent_jobs
                existing.max_file_size_mb = max_file_size_mb
                session.add(existing)
            else:
                new_key = ApiKey(
                    api_key=api_key,
                    name=name,
                    requests_per_minute=requests_per_minute,
                    max_concurrent_jobs=max_concurrent_jobs,
                    max_file_size_mb=max_file_size_mb,
                )
                session.add(new_key)

    def update_api_key_usage(self, api_key: str) -> None:
        """Update API key last used timestamp and increment total requests."""
        with get_session_context() as session:
            key = session.get(ApiKey, api_key)
            if key:
                key.last_used_at = datetime.utcnow()
                key.total_requests += 1
                session.add(key)

    # ==========================================================================
    # Async Request Operations
    # ==========================================================================

    def create_request(
        self,
        api_key: str,
        filename: str,
        file_size: int,
        content_type: str,
        expires_at: datetime,
        request_id: str | None = None,
    ) -> str:
        """Create a new async request."""
        with get_session_context() as session:
            request = AsyncRequest(
                api_key=api_key,
                filename=filename,
                file_size=file_size,
                content_type=content_type,
                expires_at=expires_at,
            )
            if request_id:
                request.request_id = UUID(request_id)
            session.add(request)
            session.flush()  # To get the generated ID
            return str(request.request_id)

    def get_request(self, request_id: str) -> AsyncRequest | None:
        """Get a single async request by ID."""
        with get_session_context() as session:
            result = session.get(AsyncRequest, UUID(request_id))
            if result:
                # Detach from session for use outside context
                session.expunge(result)
            return result

    def get_request_by_api_key(
        self,
        request_id: str,
        api_key: str,
    ) -> AsyncRequest | None:
        """Get a request only if it belongs to the given API key."""
        with get_session_context() as session:
            statement = select(AsyncRequest).where(
                AsyncRequest.request_id == UUID(request_id),
                AsyncRequest.api_key == api_key,
            )
            result = session.exec(statement).first()
            if result:
                session.expunge(result)
            return result

    def update_status(
        self,
        request_id: str,
        status: str,
        error_message: str | None = None,
        increment_retry: bool = False,
    ) -> None:
        """Update request status."""
        with get_session_context() as session:
            request = session.get(AsyncRequest, UUID(request_id))
            if request:
                request.status = status
                if status == "processing":
                    request.started_at = datetime.utcnow()
                if error_message is not None:
                    request.error_message = error_message
                if increment_retry:
                    request.retry_count += 1
                session.add(request)

    def complete_request(
        self,
        request_id: str,
        document_id: str,
        result: dict[str, Any],
        processing_time_ms: float,
        visualization_path: str | None = None,
    ) -> None:
        """Mark request as completed with result."""
        with get_session_context() as session:
            request = session.get(AsyncRequest, UUID(request_id))
            if request:
                request.status = "completed"
                request.document_id = document_id
                request.result = result
                request.processing_time_ms = processing_time_ms
                request.visualization_path = visualization_path
                request.completed_at = datetime.utcnow()
                session.add(request)

    def get_requests_by_api_key(
        self,
        api_key: str,
        status: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[AsyncRequest], int]:
        """Get paginated requests for an API key."""
        with get_session_context() as session:
            # Count query
            count_stmt = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.api_key == api_key
            )
            if status:
                count_stmt = count_stmt.where(AsyncRequest.status == status)
            total = session.exec(count_stmt).one()

            # Fetch query
            statement = select(AsyncRequest).where(
                AsyncRequest.api_key == api_key
            )
            if status:
                statement = statement.where(AsyncRequest.status == status)
            statement = statement.order_by(AsyncRequest.created_at.desc())
            statement = statement.offset(offset).limit(limit)

            results = session.exec(statement).all()
            # Detach results from session
            for r in results:
                session.expunge(r)
            return list(results), total

    def count_active_jobs(self, api_key: str) -> int:
        """Count active (pending + processing) jobs for an API key."""
        with get_session_context() as session:
            statement = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.api_key == api_key,
                AsyncRequest.status.in_(["pending", "processing"]),
            )
            return session.exec(statement).one()

    def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
        """Get pending requests ordered by creation time."""
        with get_session_context() as session:
            statement = select(AsyncRequest).where(
                AsyncRequest.status == "pending"
            ).order_by(AsyncRequest.created_at).limit(limit)
            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def get_queue_position(self, request_id: str) -> int | None:
        """Get position of a request in the pending queue."""
        with get_session_context() as session:
            # Get the request's created_at
            request = session.get(AsyncRequest, UUID(request_id))
            if not request:
                return None

            # Count pending requests created before this one
            statement = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.status == "pending",
                AsyncRequest.created_at < request.created_at,
            )
            count = session.exec(statement).one()
            return count + 1  # 1-based position

    # ==========================================================================
    # Rate Limit Operations
    # ==========================================================================

    def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
        """Record a rate limit event."""
        with get_session_context() as session:
            event = RateLimitEvent(
                api_key=api_key,
                event_type=event_type,
            )
            session.add(event)

    def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
        """Count requests in the last N seconds."""
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(seconds=seconds)
            statement = select(func.count()).select_from(RateLimitEvent).where(
                RateLimitEvent.api_key == api_key,
                RateLimitEvent.event_type == "request",
                RateLimitEvent.created_at > cutoff,
            )
            return session.exec(statement).one()

    # ==========================================================================
    # Cleanup Operations
    # ==========================================================================

    def delete_expired_requests(self) -> int:
        """Delete requests that have expired. Returns count of deleted rows."""
        with get_session_context() as session:
            now = datetime.utcnow()
            statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
            expired = session.exec(statement).all()
            count = len(expired)
            for request in expired:
                session.delete(request)
            logger.info(f"Deleted {count} expired async requests")
            return count

    def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
        """Delete rate limit events older than N hours."""
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(hours=hours)
            statement = select(RateLimitEvent).where(
                RateLimitEvent.created_at < cutoff
            )
            old_events = session.exec(statement).all()
            count = len(old_events)
            for event in old_events:
                session.delete(event)
            return count

    def reset_stale_processing_requests(
        self,
        stale_minutes: int = 10,
        max_retries: int = 3,
    ) -> int:
        """
        Reset requests stuck in 'processing' status.

        Requests that have been processing for more than stale_minutes
        are considered stale. They are either reset to 'pending' (if under
        max_retries) or set to 'failed'.
        """
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
            reset_count = 0

            # Find stale processing requests
            statement = select(AsyncRequest).where(
                AsyncRequest.status == "processing",
                AsyncRequest.started_at < cutoff,
            )
            stale_requests = session.exec(statement).all()

            for request in stale_requests:
                if request.retry_count < max_retries:
                    request.status = "pending"
                    request.started_at = None
                else:
                    request.status = "failed"
                    request.error_message = "Processing timeout after max retries"
                session.add(request)
                reset_count += 1

            if reset_count > 0:
                logger.warning(f"Reset {reset_count} stale processing requests")
            return reset_count
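A minimal sketch of the intended request lifecycle, assuming a worker picks up pending requests elsewhere; the API key must already exist in api_keys, and the result payload and IDs here are illustrative:

from datetime import datetime, timedelta
from backend.data.async_request_db import AsyncRequestDB

db = AsyncRequestDB()
db.create_tables()

request_id = db.create_request(
    api_key="demo-key",
    filename="invoice_001.pdf",
    file_size=123456,
    content_type="application/pdf",
    expires_at=datetime.utcnow() + timedelta(hours=24),
)
db.update_status(request_id, "processing")
db.complete_request(
    request_id,
    document_id="doc-001",
    result={"fields": {}},
    processing_time_ms=850.0,
)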
packages/backend/backend/data/database.py (new file)
@@ -0,0 +1,318 @@
"""
Database Engine and Session Management

Provides SQLModel database engine and session handling.
"""

import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Generator

from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine

import sys
from shared.config import get_db_connection_string

logger = logging.getLogger(__name__)

# Global engine instance
_engine = None


def get_engine():
    """Get or create the database engine."""
    global _engine
    if _engine is None:
        connection_string = get_db_connection_string()
        # Convert psycopg2 format to SQLAlchemy format
        if connection_string.startswith("postgresql://"):
            # Already in correct format
            pass
        elif "host=" in connection_string:
            # Convert DSN format to URL format
            parts = dict(item.split("=") for item in connection_string.split())
            connection_string = (
                f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
                f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
                f"/{parts.get('dbname', 'docmaster')}"
            )

        _engine = create_engine(
            connection_string,
            echo=False,  # Set to True for SQL debugging
            pool_pre_ping=True,  # Verify connections before use
            pool_size=5,
            max_overflow=10,
        )
    return _engine


def run_migrations() -> None:
    """Run database migrations for new columns."""
    engine = get_engine()

    migrations = [
        # Migration 004: Training datasets tables and dataset_id on training_tasks
        (
            "training_datasets_tables",
            """
            CREATE TABLE IF NOT EXISTS training_datasets (
                dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                name VARCHAR(255) NOT NULL,
                description TEXT,
                status VARCHAR(20) NOT NULL DEFAULT 'building',
                train_ratio FLOAT NOT NULL DEFAULT 0.8,
                val_ratio FLOAT NOT NULL DEFAULT 0.1,
                seed INTEGER NOT NULL DEFAULT 42,
                total_documents INTEGER NOT NULL DEFAULT 0,
                total_images INTEGER NOT NULL DEFAULT 0,
                total_annotations INTEGER NOT NULL DEFAULT 0,
                dataset_path VARCHAR(512),
                error_message TEXT,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
            """,
        ),
        (
            "dataset_documents_table",
            """
            CREATE TABLE IF NOT EXISTS dataset_documents (
                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
                document_id UUID NOT NULL REFERENCES admin_documents(document_id),
                split VARCHAR(10) NOT NULL,
                page_count INTEGER NOT NULL DEFAULT 0,
                annotation_count INTEGER NOT NULL DEFAULT 0,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                UNIQUE(dataset_id, document_id)
            );
            CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
            CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
            """,
        ),
        (
            "training_tasks_dataset_id",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
            CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
            """,
        ),
        # Migration 005: Add group_key to admin_documents
        (
            "admin_documents_group_key",
            """
            ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
            CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
            """,
        ),
        # Migration 006: Model versions table
        (
            "model_versions_table",
            """
            CREATE TABLE IF NOT EXISTS model_versions (
                version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                version VARCHAR(50) NOT NULL,
                name VARCHAR(255) NOT NULL,
                description TEXT,
                model_path VARCHAR(512) NOT NULL,
                status VARCHAR(20) NOT NULL DEFAULT 'inactive',
                is_active BOOLEAN NOT NULL DEFAULT FALSE,
                task_id UUID REFERENCES training_tasks(task_id),
                dataset_id UUID REFERENCES training_datasets(dataset_id),
                metrics_mAP FLOAT,
                metrics_precision FLOAT,
                metrics_recall FLOAT,
                document_count INTEGER NOT NULL DEFAULT 0,
                training_config JSONB,
                file_size BIGINT,
                trained_at TIMESTAMP WITH TIME ZONE,
                activated_at TIMESTAMP WITH TIME ZONE,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
            CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
            CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
            CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
            CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
            """,
        ),
        # Migration 009: Add category to admin_documents
        (
            "admin_documents_category",
            """
            ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
            UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
            ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
            CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
            """,
        ),
        # Migration 010: Add training_status and active_training_task_id to training_datasets
        (
            "training_datasets_training_status",
            """
            ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
            ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
            CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status);
            CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id);
            """,
        ),
        # Migration 010b: Update existing datasets with completed training to 'trained' status
        (
            "training_datasets_update_trained_status",
            """
            UPDATE training_datasets d
            SET status = 'trained'
            WHERE d.status = 'ready'
            AND EXISTS (
                SELECT 1 FROM training_tasks t
                WHERE t.dataset_id = d.dataset_id
                AND t.status = 'completed'
            );
            """,
        ),
        # Migration 007: Add extra columns to training_tasks
        (
            "training_tasks_name",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
            UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
            ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
            CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
            """,
        ),
        (
            "training_tasks_description",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
            """,
        ),
        (
            "training_tasks_admin_token",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
            """,
        ),
        (
            "training_tasks_task_type",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
            """,
        ),
        (
            "training_tasks_recurring",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
            """,
        ),
        (
            "training_tasks_metrics",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
            CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
            """,
        ),
        (
            "training_tasks_updated_at",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
            """,
        ),
        # Migration 008: Fix model_versions foreign key constraints
        (
            "model_versions_fk_fix",
            """
            ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
            ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
            ALTER TABLE model_versions
                ADD CONSTRAINT model_versions_dataset_id_fkey
                FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
            ALTER TABLE model_versions
                ADD CONSTRAINT model_versions_task_id_fkey
                FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
            """,
        ),
        # Migration 006b: Ensure only one active model at a time
        (
            "model_versions_single_active",
            """
            CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
                ON model_versions(is_active) WHERE is_active = TRUE;
            """,
        ),
    ]

    with engine.connect() as conn:
        for name, sql in migrations:
            try:
                conn.execute(text(sql))
                conn.commit()
                logger.info(f"Migration '{name}' applied successfully")
            except Exception as e:
                # Log but don't fail - column may already exist
                logger.debug(f"Migration '{name}' skipped or failed: {e}")


def create_db_and_tables() -> None:
    """Create all database tables."""
    from backend.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
    from backend.data.admin_models import (  # noqa: F401
        AdminToken,
        AdminDocument,
        AdminAnnotation,
        TrainingTask,
        TrainingLog,
    )

    engine = get_engine()
    SQLModel.metadata.create_all(engine)
    logger.info("Database tables created/verified")

    # Run migrations for new columns
    run_migrations()


def get_session() -> Session:
    """Get a new database session."""
    engine = get_engine()
    return Session(engine)


@contextmanager
def get_session_context() -> Generator[Session, None, None]:
    """Context manager for database sessions with auto-commit/rollback."""
    session = get_session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def close_engine() -> None:
    """Close the database engine and release connections."""
    global _engine
    if _engine is not None:
        _engine.dispose()
        _engine = None
        logger.info("Database engine closed")


def execute_raw_sql(sql: str) -> None:
    """Execute raw SQL (for migrations)."""
    engine = get_engine()
    with engine.connect() as conn:
        conn.execute(text(sql))
        conn.commit()
95
packages/backend/backend/data/models.py
Normal file
95
packages/backend/backend/data/models.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
SQLModel Database Models
|
||||
|
||||
Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, SQLModel, Column, JSON
|
||||
|
||||
|
||||
class ApiKey(SQLModel, table=True):
|
||||
"""API key configuration and limits."""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
api_key: str = Field(primary_key=True, max_length=255)
|
||||
name: str = Field(max_length=255)
|
||||
is_active: bool = Field(default=True)
|
||||
requests_per_minute: int = Field(default=10)
|
||||
max_concurrent_jobs: int = Field(default=3)
|
||||
max_file_size_mb: int = Field(default=50)
|
||||
total_requests: int = Field(default=0)
|
||||
total_processed: int = Field(default=0)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_used_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
class AsyncRequest(SQLModel, table=True):
|
||||
"""Async request record."""
|
||||
|
||||
__tablename__ = "async_requests"
|
||||
|
||||
request_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
|
||||
status: str = Field(default="pending", max_length=20, index=True)
|
||||
filename: str = Field(max_length=255)
|
||||
file_size: int
|
||||
content_type: str = Field(max_length=100)
|
||||
document_id: str | None = Field(default=None, max_length=100)
|
||||
error_message: str | None = Field(default=None)
|
||||
retry_count: int = Field(default=0)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
started_at: datetime | None = Field(default=None)
|
||||
completed_at: datetime | None = Field(default=None)
|
||||
expires_at: datetime = Field(index=True)
|
||||
result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
processing_time_ms: float | None = Field(default=None)
|
||||
visualization_path: str | None = Field(default=None, max_length=255)
|
||||
|
||||
|
||||
class RateLimitEvent(SQLModel, table=True):
|
||||
"""Rate limit event record."""
|
||||
|
||||
__tablename__ = "rate_limit_events"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
|
||||
event_type: str = Field(max_length=50)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||
|
||||
|
||||
# Read-only models for responses (without table=True)
|
||||
class ApiKeyRead(SQLModel):
|
||||
"""API key response model (read-only)."""
|
||||
|
||||
api_key: str
|
||||
name: str
|
||||
is_active: bool
|
||||
requests_per_minute: int
|
||||
max_concurrent_jobs: int
|
||||
max_file_size_mb: int
|
||||
|
||||
|
||||
class AsyncRequestRead(SQLModel):
|
||||
"""Async request response model (read-only)."""
|
||||
|
||||
request_id: UUID
|
||||
api_key: str
|
||||
status: str
|
||||
filename: str
|
||||
file_size: int
|
||||
content_type: str
|
||||
document_id: str | None
|
||||
error_message: str | None
|
||||
retry_count: int
|
||||
created_at: datetime
|
||||
started_at: datetime | None
|
||||
completed_at: datetime | None
|
||||
expires_at: datetime
|
||||
result: dict[str, Any] | None
|
||||
processing_time_ms: float | None
|
||||
visualization_path: str | None
|
||||
26
packages/backend/backend/data/repositories/__init__.py
Normal file
26
packages/backend/backend/data/repositories/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Repository Pattern Implementation
|
||||
|
||||
Provides domain-specific repository classes to replace the monolithic AdminDB.
|
||||
Each repository handles a single domain following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
from backend.data.repositories.token_repository import TokenRepository
|
||||
from backend.data.repositories.document_repository import DocumentRepository
|
||||
from backend.data.repositories.annotation_repository import AnnotationRepository
|
||||
from backend.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
from backend.data.repositories.dataset_repository import DatasetRepository
|
||||
from backend.data.repositories.model_version_repository import ModelVersionRepository
|
||||
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
|
||||
__all__ = [
|
||||
"BaseRepository",
|
||||
"TokenRepository",
|
||||
"DocumentRepository",
|
||||
"AnnotationRepository",
|
||||
"TrainingTaskRepository",
|
||||
"DatasetRepository",
|
||||
"ModelVersionRepository",
|
||||
"BatchUploadRepository",
|
||||
]
|
||||
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
Annotation Repository
|
||||
|
||||
Handles annotation operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import AdminAnnotation, AnnotationHistory
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnnotationRepository(BaseRepository[AdminAnnotation]):
|
||||
"""Repository for annotation management.
|
||||
|
||||
Handles:
|
||||
- Annotation CRUD operations
|
||||
- Batch annotation creation
|
||||
- Annotation verification
|
||||
- Annotation override tracking
|
||||
"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
document_id: str,
|
||||
page_number: int,
|
||||
class_id: int,
|
||||
class_name: str,
|
||||
x_center: float,
|
||||
y_center: float,
|
||||
width: float,
|
||||
height: float,
|
||||
bbox_x: int,
|
||||
bbox_y: int,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
text_value: str | None = None,
|
||||
confidence: float | None = None,
|
||||
source: str = "manual",
|
||||
) -> str:
|
||||
"""Create a new annotation.
|
||||
|
||||
Returns:
|
||||
Annotation ID as string
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
annotation = AdminAnnotation(
|
||||
document_id=UUID(document_id),
|
||||
page_number=page_number,
|
||||
class_id=class_id,
|
||||
class_name=class_name,
|
||||
x_center=x_center,
|
||||
y_center=y_center,
|
||||
width=width,
|
||||
height=height,
|
||||
bbox_x=bbox_x,
|
||||
bbox_y=bbox_y,
|
||||
bbox_width=bbox_width,
|
||||
bbox_height=bbox_height,
|
||||
text_value=text_value,
|
||||
confidence=confidence,
|
||||
source=source,
|
||||
)
|
||||
session.add(annotation)
|
||||
session.flush()
|
||||
return str(annotation.annotation_id)
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
annotations: list[dict[str, Any]],
|
||||
) -> list[str]:
|
||||
"""Create multiple annotations in a batch.
|
||||
|
||||
Args:
|
||||
annotations: List of annotation data dicts
|
||||
|
||||
Returns:
|
||||
List of annotation IDs
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
ids = []
|
||||
for ann_data in annotations:
|
||||
annotation = AdminAnnotation(
|
||||
document_id=UUID(ann_data["document_id"]),
|
||||
page_number=ann_data.get("page_number", 1),
|
||||
class_id=ann_data["class_id"],
|
||||
class_name=ann_data["class_name"],
|
||||
x_center=ann_data["x_center"],
|
||||
y_center=ann_data["y_center"],
|
||||
width=ann_data["width"],
|
||||
height=ann_data["height"],
|
||||
bbox_x=ann_data["bbox_x"],
|
||||
bbox_y=ann_data["bbox_y"],
|
||||
bbox_width=ann_data["bbox_width"],
|
||||
bbox_height=ann_data["bbox_height"],
|
||||
text_value=ann_data.get("text_value"),
|
||||
confidence=ann_data.get("confidence"),
|
||||
source=ann_data.get("source", "auto"),
|
||||
)
|
||||
session.add(annotation)
|
||||
session.flush()
|
||||
ids.append(str(annotation.annotation_id))
|
||||
return ids
|
||||
|
||||
def get(self, annotation_id: str) -> AdminAnnotation | None:
|
||||
"""Get an annotation by ID."""
|
||||
with get_session_context() as session:
|
||||
result = session.get(AdminAnnotation, UUID(annotation_id))
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def get_for_document(
|
||||
self,
|
||||
document_id: str,
|
||||
page_number: int | None = None,
|
||||
) -> list[AdminAnnotation]:
|
||||
"""Get all annotations for a document."""
|
||||
with get_session_context() as session:
|
||||
statement = select(AdminAnnotation).where(
|
||||
AdminAnnotation.document_id == UUID(document_id)
|
||||
)
|
||||
if page_number is not None:
|
||||
statement = statement.where(AdminAnnotation.page_number == page_number)
|
||||
statement = statement.order_by(AdminAnnotation.class_id)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def update(
|
||||
self,
|
||||
annotation_id: str,
|
||||
x_center: float | None = None,
|
||||
y_center: float | None = None,
|
||||
width: float | None = None,
|
||||
height: float | None = None,
|
||||
bbox_x: int | None = None,
|
||||
bbox_y: int | None = None,
|
||||
bbox_width: int | None = None,
|
||||
bbox_height: int | None = None,
|
||||
text_value: str | None = None,
|
||||
class_id: int | None = None,
|
||||
class_name: str | None = None,
|
||||
) -> bool:
|
||||
"""Update an annotation.
|
||||
|
||||
Returns:
|
||||
True if updated, False if not found
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
annotation = session.get(AdminAnnotation, UUID(annotation_id))
|
||||
if annotation:
|
||||
if x_center is not None:
|
||||
annotation.x_center = x_center
|
||||
if y_center is not None:
|
||||
annotation.y_center = y_center
|
||||
if width is not None:
|
||||
annotation.width = width
|
||||
if height is not None:
|
||||
annotation.height = height
|
||||
if bbox_x is not None:
|
||||
annotation.bbox_x = bbox_x
|
||||
if bbox_y is not None:
|
||||
annotation.bbox_y = bbox_y
|
||||
if bbox_width is not None:
|
||||
annotation.bbox_width = bbox_width
|
||||
if bbox_height is not None:
|
||||
annotation.bbox_height = bbox_height
|
||||
if text_value is not None:
|
||||
annotation.text_value = text_value
|
||||
if class_id is not None:
|
||||
annotation.class_id = class_id
|
||||
if class_name is not None:
|
||||
annotation.class_name = class_name
|
||||
annotation.updated_at = datetime.utcnow()
|
||||
session.add(annotation)
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete(self, annotation_id: str) -> bool:
|
||||
"""Delete an annotation."""
|
||||
with get_session_context() as session:
|
||||
annotation = session.get(AdminAnnotation, UUID(annotation_id))
|
||||
if annotation:
|
||||
session.delete(annotation)
|
||||
session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_for_document(
|
||||
self,
|
||||
document_id: str,
|
||||
source: str | None = None,
|
||||
) -> int:
|
||||
"""Delete all annotations for a document.
|
||||
|
||||
Returns:
|
||||
Count of deleted annotations
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
statement = select(AdminAnnotation).where(
|
||||
AdminAnnotation.document_id == UUID(document_id)
|
||||
)
|
||||
if source:
|
||||
statement = statement.where(AdminAnnotation.source == source)
|
||||
annotations = session.exec(statement).all()
|
||||
count = len(annotations)
|
||||
for ann in annotations:
|
||||
session.delete(ann)
|
||||
session.commit()
|
||||
return count
|
||||
|
||||
def verify(
|
||||
self,
|
||||
annotation_id: str,
|
||||
admin_token: str,
|
||||
) -> AdminAnnotation | None:
|
||||
"""Mark an annotation as verified."""
|
||||
with get_session_context() as session:
|
||||
annotation = session.get(AdminAnnotation, UUID(annotation_id))
|
||||
if not annotation:
|
||||
return None
|
||||
|
||||
annotation.is_verified = True
|
||||
annotation.verified_at = datetime.utcnow()
|
||||
annotation.verified_by = admin_token
|
||||
annotation.updated_at = datetime.utcnow()
|
||||
|
||||
session.add(annotation)
|
||||
session.commit()
|
||||
session.refresh(annotation)
|
||||
session.expunge(annotation)
|
||||
return annotation
|
||||
|
||||
def override(
|
||||
self,
|
||||
annotation_id: str,
|
||||
admin_token: str,
|
||||
change_reason: str | None = None,
|
||||
**updates: Any,
|
||||
) -> AdminAnnotation | None:
|
||||
"""Override an auto-generated annotation.
|
||||
|
||||
Creates a history record and updates the annotation.
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
annotation = session.get(AdminAnnotation, UUID(annotation_id))
|
||||
if not annotation:
|
||||
return None
|
||||
|
||||
previous_value = {
|
||||
"class_id": annotation.class_id,
|
||||
"class_name": annotation.class_name,
|
||||
"bbox": {
|
||||
"x": annotation.bbox_x,
|
||||
"y": annotation.bbox_y,
|
||||
"width": annotation.bbox_width,
|
||||
"height": annotation.bbox_height,
|
||||
},
|
||||
"normalized": {
|
||||
"x_center": annotation.x_center,
|
||||
"y_center": annotation.y_center,
|
||||
"width": annotation.width,
|
||||
"height": annotation.height,
|
||||
},
|
||||
"text_value": annotation.text_value,
|
||||
"confidence": annotation.confidence,
|
||||
"source": annotation.source,
|
||||
}
|
||||
|
||||
for key, value in updates.items():
|
||||
if hasattr(annotation, key):
|
||||
setattr(annotation, key, value)
|
||||
|
||||
if annotation.source == "auto":
|
||||
annotation.override_source = "auto"
|
||||
annotation.source = "manual"
|
||||
|
||||
annotation.updated_at = datetime.utcnow()
|
||||
session.add(annotation)
|
||||
|
||||
history = AnnotationHistory(
|
||||
annotation_id=UUID(annotation_id),
|
||||
document_id=annotation.document_id,
|
||||
action="override",
|
||||
previous_value=previous_value,
|
||||
new_value=updates,
|
||||
changed_by=admin_token,
|
||||
change_reason=change_reason,
|
||||
)
|
||||
session.add(history)
|
||||
|
||||
session.commit()
|
||||
session.refresh(annotation)
|
||||
session.expunge(annotation)
|
||||
return annotation
|
||||
|
||||
def create_history(
|
||||
self,
|
||||
annotation_id: UUID,
|
||||
document_id: UUID,
|
||||
action: str,
|
||||
previous_value: dict[str, Any] | None = None,
|
||||
new_value: dict[str, Any] | None = None,
|
||||
changed_by: str | None = None,
|
||||
change_reason: str | None = None,
|
||||
) -> AnnotationHistory:
|
||||
"""Create an annotation history record."""
|
||||
with get_session_context() as session:
|
||||
history = AnnotationHistory(
|
||||
annotation_id=annotation_id,
|
||||
document_id=document_id,
|
||||
action=action,
|
||||
previous_value=previous_value,
|
||||
new_value=new_value,
|
||||
changed_by=changed_by,
|
||||
change_reason=change_reason,
|
||||
)
|
||||
session.add(history)
|
||||
session.commit()
|
||||
session.refresh(history)
|
||||
session.expunge(history)
|
||||
return history
|
||||
|
||||
def get_history(self, annotation_id: UUID) -> list[AnnotationHistory]:
|
||||
"""Get history for a specific annotation."""
|
||||
with get_session_context() as session:
|
||||
statement = select(AnnotationHistory).where(
|
||||
AnnotationHistory.annotation_id == annotation_id
|
||||
).order_by(AnnotationHistory.created_at.desc())
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_document_history(self, document_id: UUID) -> list[AnnotationHistory]:
|
||||
"""Get all annotation history for a document."""
|
||||
with get_session_context() as session:
|
||||
statement = select(AnnotationHistory).where(
|
||||
AnnotationHistory.document_id == document_id
|
||||
).order_by(AnnotationHistory.created_at.desc())
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
75
packages/backend/backend/data/repositories/base.py
Normal file
75
packages/backend/backend/data/repositories/base.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Base Repository
|
||||
|
||||
Provides common functionality for all repositories.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Generator, TypeVar, Generic
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseRepository(ABC, Generic[T]):
|
||||
"""Base class for all repositories.
|
||||
|
||||
Provides:
|
||||
- Session management via context manager
|
||||
- Logging infrastructure
|
||||
- Common query patterns
|
||||
- Utility methods for datetime and UUID handling
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def _session(self) -> Generator[Session, None, None]:
|
||||
"""Get a database session with auto-commit/rollback."""
|
||||
with get_session_context() as session:
|
||||
yield session
|
||||
|
||||
def _expunge(self, session: Session, entity: T) -> T:
|
||||
"""Detach entity from session for safe return."""
|
||||
session.expunge(entity)
|
||||
return entity
|
||||
|
||||
def _expunge_all(self, session: Session, entities: list[T]) -> list[T]:
|
||||
"""Detach multiple entities from session."""
|
||||
for entity in entities:
|
||||
session.expunge(entity)
|
||||
return entities
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime.
|
||||
|
||||
Use this instead of datetime.utcnow() which is deprecated in Python 3.12+.
|
||||
"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
@staticmethod
|
||||
def _validate_uuid(value: str, field_name: str = "id") -> UUID:
|
||||
"""Validate and convert string to UUID.
|
||||
|
||||
Args:
|
||||
value: String to convert to UUID
|
||||
field_name: Name of field for error message
|
||||
|
||||
Returns:
|
||||
Validated UUID
|
||||
|
||||
Raises:
|
||||
ValueError: If value is not a valid UUID
|
||||
"""
|
||||
try:
|
||||
return UUID(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid {field_name}: {value}") from e
|
||||
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Batch Upload Repository
|
||||
|
||||
Handles batch upload operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import BatchUpload, BatchUploadFile
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BatchUploadRepository(BaseRepository[BatchUpload]):
|
||||
"""Repository for batch upload management.
|
||||
|
||||
Handles:
|
||||
- Batch upload CRUD operations
|
||||
- Batch file tracking
|
||||
- Progress monitoring
|
||||
"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
admin_token: str,
|
||||
filename: str,
|
||||
file_size: int,
|
||||
upload_source: str = "ui",
|
||||
) -> BatchUpload:
|
||||
"""Create a new batch upload record."""
|
||||
with get_session_context() as session:
|
||||
batch = BatchUpload(
|
||||
admin_token=admin_token,
|
||||
filename=filename,
|
||||
file_size=file_size,
|
||||
upload_source=upload_source,
|
||||
)
|
||||
session.add(batch)
|
||||
session.commit()
|
||||
session.refresh(batch)
|
||||
session.expunge(batch)
|
||||
return batch
|
||||
|
||||
def get(self, batch_id: UUID) -> BatchUpload | None:
|
||||
"""Get batch upload by ID."""
|
||||
with get_session_context() as session:
|
||||
result = session.get(BatchUpload, batch_id)
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def update(
|
||||
self,
|
||||
batch_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Update batch upload fields."""
|
||||
with get_session_context() as session:
|
||||
batch = session.get(BatchUpload, batch_id)
|
||||
if batch:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(batch, key):
|
||||
setattr(batch, key, value)
|
||||
session.add(batch)
|
||||
|
||||
def create_file(
|
||||
self,
|
||||
batch_id: UUID,
|
||||
filename: str,
|
||||
**kwargs: Any,
|
||||
) -> BatchUploadFile:
|
||||
"""Create a batch upload file record."""
|
||||
with get_session_context() as session:
|
||||
file_record = BatchUploadFile(
|
||||
batch_id=batch_id,
|
||||
filename=filename,
|
||||
**kwargs,
|
||||
)
|
||||
session.add(file_record)
|
||||
session.commit()
|
||||
session.refresh(file_record)
|
||||
session.expunge(file_record)
|
||||
return file_record
|
||||
|
||||
def update_file(
|
||||
self,
|
||||
file_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Update batch upload file fields."""
|
||||
with get_session_context() as session:
|
||||
file_record = session.get(BatchUploadFile, file_id)
|
||||
if file_record:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(file_record, key):
|
||||
setattr(file_record, key, value)
|
||||
session.add(file_record)
|
||||
|
||||
def get_files(self, batch_id: UUID) -> list[BatchUploadFile]:
|
||||
"""Get all files for a batch upload."""
|
||||
with get_session_context() as session:
|
||||
statement = select(BatchUploadFile).where(
|
||||
BatchUploadFile.batch_id == batch_id
|
||||
).order_by(BatchUploadFile.created_at)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[BatchUpload], int]:
|
||||
"""Get paginated batch uploads."""
|
||||
with get_session_context() as session:
|
||||
count_stmt = select(func.count()).select_from(BatchUpload)
|
||||
total = session.exec(count_stmt).one()
|
||||
|
||||
statement = select(BatchUpload).order_by(
|
||||
BatchUpload.created_at.desc()
|
||||
).offset(offset).limit(limit)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results), total
|
||||
216
packages/backend/backend/data/repositories/dataset_repository.py
Normal file
216
packages/backend/backend/data/repositories/dataset_repository.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Dataset Repository
|
||||
|
||||
Handles training dataset operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetRepository(BaseRepository[TrainingDataset]):
|
||||
"""Repository for training dataset management.
|
||||
|
||||
Handles:
|
||||
- Dataset CRUD operations
|
||||
- Dataset status management
|
||||
- Dataset document linking
|
||||
- Training status tracking
|
||||
"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
train_ratio: float = 0.8,
|
||||
val_ratio: float = 0.1,
|
||||
seed: int = 42,
|
||||
) -> TrainingDataset:
|
||||
"""Create a new training dataset."""
|
||||
with get_session_context() as session:
|
||||
dataset = TrainingDataset(
|
||||
name=name,
|
||||
description=description,
|
||||
train_ratio=train_ratio,
|
||||
val_ratio=val_ratio,
|
||||
seed=seed,
|
||||
)
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
session.refresh(dataset)
|
||||
session.expunge(dataset)
|
||||
return dataset
|
||||
|
||||
def get(self, dataset_id: str | UUID) -> TrainingDataset | None:
|
||||
"""Get a dataset by ID."""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if dataset:
|
||||
session.expunge(dataset)
|
||||
return dataset
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
status: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[TrainingDataset], int]:
|
||||
"""List datasets with optional status filter."""
|
||||
with get_session_context() as session:
|
||||
query = select(TrainingDataset)
|
||||
count_query = select(func.count()).select_from(TrainingDataset)
|
||||
if status:
|
||||
query = query.where(TrainingDataset.status == status)
|
||||
count_query = count_query.where(TrainingDataset.status == status)
|
||||
total = session.exec(count_query).one()
|
||||
datasets = session.exec(
|
||||
query.order_by(TrainingDataset.created_at.desc()).offset(offset).limit(limit)
|
||||
).all()
|
||||
for d in datasets:
|
||||
session.expunge(d)
|
||||
return list(datasets), total
|
||||
|
||||
def get_active_training_tasks(
|
||||
self, dataset_ids: list[str]
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Get active training tasks for datasets.
|
||||
|
||||
Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
|
||||
"""
|
||||
if not dataset_ids:
|
||||
return {}
|
||||
|
||||
valid_uuids = []
|
||||
for d in dataset_ids:
|
||||
try:
|
||||
valid_uuids.append(UUID(d))
|
||||
except ValueError:
|
||||
logger.warning("Invalid UUID in get_active_training_tasks: %s", d)
|
||||
continue
|
||||
|
||||
if not valid_uuids:
|
||||
return {}
|
||||
|
||||
with get_session_context() as session:
|
||||
statement = select(TrainingTask).where(
|
||||
TrainingTask.dataset_id.in_(valid_uuids),
|
||||
TrainingTask.status.in_(["pending", "scheduled", "running"]),
|
||||
)
|
||||
results = session.exec(statement).all()
|
||||
return {
|
||||
str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
|
||||
for t in results
|
||||
}
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
status: str,
|
||||
error_message: str | None = None,
|
||||
total_documents: int | None = None,
|
||||
total_images: int | None = None,
|
||||
total_annotations: int | None = None,
|
||||
dataset_path: str | None = None,
|
||||
) -> None:
|
||||
"""Update dataset status and optional totals."""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if not dataset:
|
||||
return
|
||||
dataset.status = status
|
||||
dataset.updated_at = datetime.utcnow()
|
||||
if error_message is not None:
|
||||
dataset.error_message = error_message
|
||||
if total_documents is not None:
|
||||
dataset.total_documents = total_documents
|
||||
if total_images is not None:
|
||||
dataset.total_images = total_images
|
||||
if total_annotations is not None:
|
||||
dataset.total_annotations = total_annotations
|
||||
if dataset_path is not None:
|
||||
dataset.dataset_path = dataset_path
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
|
||||
def update_training_status(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
training_status: str | None,
|
||||
active_training_task_id: str | UUID | None = None,
|
||||
update_main_status: bool = False,
|
||||
) -> None:
|
||||
"""Update dataset training status."""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if not dataset:
|
||||
return
|
||||
dataset.training_status = training_status
|
||||
dataset.active_training_task_id = (
|
||||
UUID(str(active_training_task_id)) if active_training_task_id else None
|
||||
)
|
||||
dataset.updated_at = datetime.utcnow()
|
||||
if update_main_status and training_status == "completed":
|
||||
dataset.status = "trained"
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
documents: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Batch insert documents into a dataset.
|
||||
|
||||
Each dict: {document_id, split, page_count, annotation_count}
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
for doc in documents:
|
||||
dd = DatasetDocument(
|
||||
dataset_id=UUID(str(dataset_id)),
|
||||
document_id=UUID(str(doc["document_id"])),
|
||||
split=doc["split"],
|
||||
page_count=doc.get("page_count", 0),
|
||||
annotation_count=doc.get("annotation_count", 0),
|
||||
)
|
||||
session.add(dd)
|
||||
session.commit()
|
||||
|
||||
def get_documents(self, dataset_id: str | UUID) -> list[DatasetDocument]:
|
||||
"""Get all documents in a dataset."""
|
||||
with get_session_context() as session:
|
||||
results = session.exec(
|
||||
select(DatasetDocument)
|
||||
.where(DatasetDocument.dataset_id == UUID(str(dataset_id)))
|
||||
).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def delete(self, dataset_id: str | UUID) -> bool:
|
||||
"""Delete a dataset and its document links."""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if not dataset:
|
||||
return False
|
||||
# Delete associated document links first
|
||||
doc_links = session.exec(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == UUID(str(dataset_id))
|
||||
)
|
||||
).all()
|
||||
for link in doc_links:
|
||||
session.delete(link)
|
||||
session.delete(dataset)
|
||||
session.commit()
|
||||
return True
|
||||
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
Document Repository
|
||||
|
||||
Handles document operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import AdminDocument, AdminAnnotation
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentRepository(BaseRepository[AdminDocument]):
|
||||
"""Repository for document management.
|
||||
|
||||
Handles:
|
||||
- Document CRUD operations
|
||||
- Document status management
|
||||
- Document filtering and pagination
|
||||
- Document category management
|
||||
"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
filename: str,
|
||||
file_size: int,
|
||||
content_type: str,
|
||||
file_path: str,
|
||||
page_count: int = 1,
|
||||
upload_source: str = "ui",
|
||||
csv_field_values: dict[str, Any] | None = None,
|
||||
group_key: str | None = None,
|
||||
category: str = "invoice",
|
||||
admin_token: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new document record.
|
||||
|
||||
Args:
|
||||
filename: Original filename
|
||||
file_size: File size in bytes
|
||||
content_type: MIME type
|
||||
file_path: Storage path
|
||||
page_count: Number of pages
|
||||
upload_source: Upload source (ui/api)
|
||||
csv_field_values: CSV field values for reference
|
||||
group_key: User-defined grouping key
|
||||
category: Document category
|
||||
admin_token: Deprecated, kept for compatibility
|
||||
|
||||
Returns:
|
||||
Document ID as string
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
document = AdminDocument(
|
||||
filename=filename,
|
||||
file_size=file_size,
|
||||
content_type=content_type,
|
||||
file_path=file_path,
|
||||
page_count=page_count,
|
||||
upload_source=upload_source,
|
||||
csv_field_values=csv_field_values,
|
||||
group_key=group_key,
|
||||
category=category,
|
||||
)
|
||||
session.add(document)
|
||||
session.flush()
|
||||
return str(document.document_id)
|
||||
|
||||
def get(self, document_id: str) -> AdminDocument | None:
|
||||
"""Get a document by ID.
|
||||
|
||||
Args:
|
||||
document_id: Document UUID as string
|
||||
|
||||
Returns:
|
||||
AdminDocument if found, None otherwise
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
result = session.get(AdminDocument, UUID(document_id))
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def get_by_token(
|
||||
self,
|
||||
document_id: str,
|
||||
admin_token: str | None = None,
|
||||
) -> AdminDocument | None:
|
||||
"""Get a document by ID. Token parameter is deprecated."""
|
||||
return self.get(document_id)
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token: str | None = None,
|
||||
status: str | None = None,
|
||||
upload_source: str | None = None,
|
||||
has_annotations: bool | None = None,
|
||||
auto_label_status: str | None = None,
|
||||
batch_id: str | None = None,
|
||||
category: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[AdminDocument], int]:
|
||||
"""Get paginated documents with optional filters.
|
||||
|
||||
Args:
|
||||
admin_token: Deprecated, kept for compatibility
|
||||
status: Filter by status
|
||||
upload_source: Filter by upload source
|
||||
has_annotations: Filter by annotation presence
|
||||
auto_label_status: Filter by auto-label status
|
||||
batch_id: Filter by batch ID
|
||||
category: Filter by category
|
||||
limit: Page size
|
||||
offset: Pagination offset
|
||||
|
||||
Returns:
|
||||
Tuple of (documents, total_count)
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
where_clauses = []
|
||||
|
||||
if status:
|
||||
where_clauses.append(AdminDocument.status == status)
|
||||
if upload_source:
|
||||
where_clauses.append(AdminDocument.upload_source == upload_source)
|
||||
if auto_label_status:
|
||||
where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
|
||||
if batch_id:
|
||||
where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
|
||||
if category:
|
||||
where_clauses.append(AdminDocument.category == category)
|
||||
|
||||
count_stmt = select(func.count()).select_from(AdminDocument)
|
||||
if where_clauses:
|
||||
count_stmt = count_stmt.where(*where_clauses)
|
||||
|
||||
if has_annotations is not None:
|
||||
if has_annotations:
|
||||
count_stmt = (
|
||||
count_stmt
|
||||
.join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
|
||||
.group_by(AdminDocument.document_id)
|
||||
)
|
||||
else:
|
||||
count_stmt = (
|
||||
count_stmt
|
||||
.outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
|
||||
.where(AdminAnnotation.annotation_id.is_(None))
|
||||
)
|
||||
|
||||
total = session.exec(count_stmt).one()
|
||||
|
||||
statement = select(AdminDocument)
|
||||
if where_clauses:
|
||||
statement = statement.where(*where_clauses)
|
||||
|
||||
if has_annotations is not None:
|
||||
if has_annotations:
|
||||
statement = (
|
||||
statement
|
||||
.join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
|
||||
.group_by(AdminDocument.document_id)
|
||||
)
|
||||
else:
|
||||
statement = (
|
||||
statement
|
||||
.outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
|
||||
.where(AdminAnnotation.annotation_id.is_(None))
|
||||
)
|
||||
|
||||
statement = statement.order_by(AdminDocument.created_at.desc())
|
||||
statement = statement.offset(offset).limit(limit)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results), total
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
document_id: str,
|
||||
status: str,
|
||||
auto_label_status: str | None = None,
|
||||
auto_label_error: str | None = None,
|
||||
) -> None:
|
||||
"""Update document status.
|
||||
|
||||
Args:
|
||||
document_id: Document UUID as string
|
||||
status: New status
|
||||
auto_label_status: Auto-label status
|
||||
auto_label_error: Auto-label error message
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
document = session.get(AdminDocument, UUID(document_id))
|
||||
if document:
|
||||
document.status = status
|
||||
document.updated_at = datetime.now(timezone.utc)
|
||||
if auto_label_status is not None:
|
||||
document.auto_label_status = auto_label_status
|
||||
if auto_label_error is not None:
|
||||
document.auto_label_error = auto_label_error
|
||||
session.add(document)
|
||||
|
||||
def update_file_path(self, document_id: str, file_path: str) -> None:
|
||||
"""Update document file path."""
|
||||
with get_session_context() as session:
|
||||
document = session.get(AdminDocument, UUID(document_id))
|
||||
if document:
|
||||
document.file_path = file_path
|
||||
document.updated_at = datetime.now(timezone.utc)
|
||||
session.add(document)
|
||||
|
||||
def update_group_key(self, document_id: str, group_key: str | None) -> bool:
|
||||
"""Update document group key."""
|
||||
with get_session_context() as session:
|
||||
document = session.get(AdminDocument, UUID(document_id))
|
||||
if document:
|
||||
document.group_key = group_key
|
||||
document.updated_at = datetime.now(timezone.utc)
|
||||
session.add(document)
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_category(self, document_id: str, category: str) -> AdminDocument | None:
|
||||
"""Update document category."""
|
||||
with get_session_context() as session:
|
||||
document = session.get(AdminDocument, UUID(document_id))
|
||||
if document:
|
||||
document.category = category
|
||||
document.updated_at = datetime.now(timezone.utc)
|
||||
session.add(document)
|
||||
session.commit()
|
||||
session.refresh(document)
|
||||
return document
|
||||
return None
|
||||
|
||||
def delete(self, document_id: str) -> bool:
|
||||
"""Delete a document and its annotations.
|
||||
|
||||
Args:
|
||||
document_id: Document UUID as string
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
document = session.get(AdminDocument, UUID(document_id))
|
||||
if document:
|
||||
ann_stmt = select(AdminAnnotation).where(
|
||||
AdminAnnotation.document_id == UUID(document_id)
|
||||
)
|
||||
annotations = session.exec(ann_stmt).all()
|
||||
for ann in annotations:
|
||||
session.delete(ann)
|
||||
session.delete(document)
|
||||
session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_categories(self) -> list[str]:
|
||||
"""Get list of unique document categories."""
|
||||
with get_session_context() as session:
|
||||
statement = (
|
||||
select(AdminDocument.category)
|
||||
.distinct()
|
||||
.order_by(AdminDocument.category)
|
||||
)
|
||||
categories = session.exec(statement).all()
|
||||
return [c for c in categories if c is not None]
|
||||
|
||||
def get_labeled_for_export(
|
||||
self,
|
||||
admin_token: str | None = None,
|
||||
) -> list[AdminDocument]:
|
||||
"""Get all labeled documents ready for export."""
|
||||
with get_session_context() as session:
|
||||
statement = select(AdminDocument).where(
|
||||
AdminDocument.status == "labeled"
|
||||
)
|
||||
if admin_token:
|
||||
statement = statement.where(AdminDocument.admin_token == admin_token)
|
||||
statement = statement.order_by(AdminDocument.created_at)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def count_by_status(
|
||||
self,
|
||||
admin_token: str | None = None,
|
||||
) -> dict[str, int]:
|
||||
"""Count documents by status."""
|
||||
with get_session_context() as session:
|
||||
statement = select(
|
||||
AdminDocument.status,
|
||||
func.count(AdminDocument.document_id),
|
||||
).group_by(AdminDocument.status)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return {status: count for status, count in results}
|
||||
|
||||
def get_by_ids(self, document_ids: list[str]) -> list[AdminDocument]:
|
||||
"""Get documents by list of IDs."""
|
||||
with get_session_context() as session:
|
||||
uuids = [UUID(str(did)) for did in document_ids]
|
||||
results = session.exec(
|
||||
select(AdminDocument).where(AdminDocument.document_id.in_(uuids))
|
||||
).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_for_training(
|
||||
self,
|
||||
admin_token: str | None = None,
|
||||
status: str = "labeled",
|
||||
has_annotations: bool = True,
|
||||
min_annotation_count: int | None = None,
|
||||
exclude_used_in_training: bool = False,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[AdminDocument], int]:
|
||||
"""Get documents suitable for training with filtering."""
|
||||
from backend.data.admin_models import TrainingDocumentLink
|
||||
|
||||
with get_session_context() as session:
|
||||
statement = select(AdminDocument).where(
|
||||
AdminDocument.status == status,
|
||||
)
|
||||
|
||||
if has_annotations or min_annotation_count:
|
||||
annotation_subq = (
|
||||
select(func.count(AdminAnnotation.annotation_id))
|
||||
.where(AdminAnnotation.document_id == AdminDocument.document_id)
|
||||
.correlate(AdminDocument)
|
||||
.scalar_subquery()
|
||||
)
|
||||
|
||||
if has_annotations:
|
||||
statement = statement.where(annotation_subq > 0)
|
||||
|
||||
if min_annotation_count:
|
||||
statement = statement.where(annotation_subq >= min_annotation_count)
|
||||
|
||||
if exclude_used_in_training:
|
||||
from sqlalchemy import exists
|
||||
training_subq = exists(
|
||||
select(1)
|
||||
.select_from(TrainingDocumentLink)
|
||||
.where(TrainingDocumentLink.document_id == AdminDocument.document_id)
|
||||
)
|
||||
statement = statement.where(~training_subq)
|
||||
|
||||
count_statement = select(func.count()).select_from(statement.subquery())
|
||||
total = session.exec(count_statement).one()
|
||||
|
||||
statement = statement.order_by(AdminDocument.created_at.desc())
|
||||
statement = statement.limit(limit).offset(offset)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
|
||||
return list(results), total
|
||||
|
||||
def acquire_annotation_lock(
|
||||
self,
|
||||
document_id: str,
|
||||
admin_token: str | None = None,
|
||||
duration_seconds: int = 300,
|
||||
) -> AdminDocument | None:
|
||||
"""Acquire annotation lock for a document."""
|
||||
from datetime import timedelta
|
||||
|
||||
with get_session_context() as session:
|
||||
doc = session.get(AdminDocument, UUID(document_id))
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
lock_until = doc.annotation_lock_until
|
||||
# Handle PostgreSQL returning offset-naive datetimes
|
||||
if lock_until and lock_until.tzinfo is None:
|
||||
lock_until = lock_until.replace(tzinfo=timezone.utc)
|
||||
if lock_until and lock_until > now:
|
||||
return None
|
||||
|
||||
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
|
||||
session.add(doc)
|
||||
session.commit()
|
||||
session.refresh(doc)
|
||||
session.expunge(doc)
|
||||
return doc
|
||||
|
||||
def release_annotation_lock(
|
||||
self,
|
||||
document_id: str,
|
||||
admin_token: str | None = None,
|
||||
force: bool = False,
|
||||
) -> AdminDocument | None:
|
||||
"""Release annotation lock for a document."""
|
||||
with get_session_context() as session:
|
||||
doc = session.get(AdminDocument, UUID(document_id))
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
doc.annotation_lock_until = None
|
||||
session.add(doc)
|
||||
session.commit()
|
||||
session.refresh(doc)
|
||||
session.expunge(doc)
|
||||
return doc
|
||||
|
||||
def extend_annotation_lock(
|
||||
self,
|
||||
document_id: str,
|
||||
admin_token: str | None = None,
|
||||
additional_seconds: int = 300,
|
||||
) -> AdminDocument | None:
|
||||
"""Extend an existing annotation lock."""
|
||||
from datetime import timedelta
|
||||
|
||||
with get_session_context() as session:
|
||||
doc = session.get(AdminDocument, UUID(document_id))
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
lock_until = doc.annotation_lock_until
|
||||
# Handle PostgreSQL returning offset-naive datetimes
|
||||
if lock_until and lock_until.tzinfo is None:
|
||||
lock_until = lock_until.replace(tzinfo=timezone.utc)
|
||||
if not lock_until or lock_until <= now:
|
||||
return None
|
||||
|
||||
doc.annotation_lock_until = lock_until + timedelta(seconds=additional_seconds)
|
||||
session.add(doc)
|
||||
session.commit()
|
||||
session.refresh(doc)
|
||||
session.expunge(doc)
|
||||
return doc
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Model Version Repository
|
||||
|
||||
Handles model version operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import ModelVersion
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelVersionRepository(BaseRepository[ModelVersion]):
|
||||
"""Repository for model version management.
|
||||
|
||||
Handles:
|
||||
- Model version CRUD operations
|
||||
- Model activation/deactivation
|
||||
- Active model resolution
|
||||
"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
version: str,
|
||||
name: str,
|
||||
model_path: str,
|
||||
description: str | None = None,
|
||||
task_id: str | UUID | None = None,
|
||||
dataset_id: str | UUID | None = None,
|
||||
metrics_mAP: float | None = None,
|
||||
metrics_precision: float | None = None,
|
||||
metrics_recall: float | None = None,
|
||||
document_count: int = 0,
|
||||
training_config: dict[str, Any] | None = None,
|
||||
file_size: int | None = None,
|
||||
trained_at: datetime | None = None,
|
||||
) -> ModelVersion:
|
||||
"""Create a new model version."""
|
||||
with get_session_context() as session:
|
||||
model = ModelVersion(
|
||||
version=version,
|
||||
name=name,
|
||||
model_path=model_path,
|
||||
description=description,
|
||||
task_id=UUID(str(task_id)) if task_id else None,
|
||||
dataset_id=UUID(str(dataset_id)) if dataset_id else None,
|
||||
metrics_mAP=metrics_mAP,
|
||||
metrics_precision=metrics_precision,
|
||||
metrics_recall=metrics_recall,
|
||||
document_count=document_count,
|
||||
training_config=training_config,
|
||||
file_size=file_size,
|
||||
trained_at=trained_at,
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def get(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Get a model version by ID."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if model:
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
status: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ModelVersion], int]:
|
||||
"""List model versions with optional status filter."""
|
||||
with get_session_context() as session:
|
||||
query = select(ModelVersion)
|
||||
count_query = select(func.count()).select_from(ModelVersion)
|
||||
if status:
|
||||
query = query.where(ModelVersion.status == status)
|
||||
count_query = count_query.where(ModelVersion.status == status)
|
||||
total = session.exec(count_query).one()
|
||||
models = session.exec(
|
||||
query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
|
||||
).all()
|
||||
for m in models:
|
||||
session.expunge(m)
|
||||
return list(models), total
|
||||
|
||||
def get_active(self) -> ModelVersion | None:
|
||||
"""Get the currently active model version for inference."""
|
||||
with get_session_context() as session:
|
||||
result = session.exec(
|
||||
select(ModelVersion).where(ModelVersion.is_active == True)
|
||||
).first()
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def activate(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Activate a model version for inference (deactivates all others)."""
|
||||
with get_session_context() as session:
|
||||
all_versions = session.exec(
|
||||
select(ModelVersion).where(ModelVersion.is_active == True)
|
||||
).all()
|
||||
for v in all_versions:
|
||||
v.is_active = False
|
||||
v.status = "inactive"
|
||||
v.updated_at = datetime.utcnow()
|
||||
session.add(v)
|
||||
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
model.is_active = True
|
||||
model.status = "active"
|
||||
model.activated_at = datetime.utcnow()
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def deactivate(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Deactivate a model version."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
model.is_active = False
|
||||
model.status = "inactive"
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def update(
|
||||
self,
|
||||
version_id: str | UUID,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> ModelVersion | None:
|
||||
"""Update model version metadata."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
if name is not None:
|
||||
model.name = name
|
||||
if description is not None:
|
||||
model.description = description
|
||||
if status is not None:
|
||||
model.status = status
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def archive(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Archive a model version."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
if model.is_active:
|
||||
return None
|
||||
model.status = "archived"
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def delete(self, version_id: str | UUID) -> bool:
|
||||
"""Delete a model version."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return False
|
||||
if model.is_active:
|
||||
return False
|
||||
session.delete(model)
|
||||
session.commit()
|
||||
return True
|
||||
117
packages/backend/backend/data/repositories/token_repository.py
Normal file
117
packages/backend/backend/data/repositories/token_repository.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Token Repository
|
||||
|
||||
Handles admin token operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from backend.data.admin_models import AdminToken
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenRepository(BaseRepository[AdminToken]):
|
||||
"""Repository for admin token management.
|
||||
|
||||
Handles:
|
||||
- Token validation (active status, expiration)
|
||||
- Token CRUD operations
|
||||
- Usage tracking
|
||||
"""
|
||||
|
||||
def is_valid(self, token: str) -> bool:
|
||||
"""Check if admin token exists and is active.
|
||||
|
||||
Args:
|
||||
token: The token string to validate
|
||||
|
||||
Returns:
|
||||
True if token exists, is active, and not expired
|
||||
"""
|
||||
with self._session() as session:
|
||||
result = session.get(AdminToken, token)
|
||||
if result is None:
|
||||
return False
|
||||
if not result.is_active:
|
||||
return False
|
||||
if result.expires_at and result.expires_at < self._now():
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(self, token: str) -> AdminToken | None:
|
||||
"""Get admin token details.
|
||||
|
||||
Args:
|
||||
token: The token string
|
||||
|
||||
Returns:
|
||||
AdminToken if found, None otherwise
|
||||
"""
|
||||
with self._session() as session:
|
||||
result = session.get(AdminToken, token)
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def create(
|
||||
self,
|
||||
token: str,
|
||||
name: str,
|
||||
expires_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Create or update an admin token.
|
||||
|
||||
If token exists, updates name, expires_at, and reactivates it.
|
||||
Otherwise creates a new token.
|
||||
|
||||
Args:
|
||||
token: The token string
|
||||
name: Display name for the token
|
||||
expires_at: Optional expiration datetime
|
||||
"""
|
||||
with self._session() as session:
|
||||
existing = session.get(AdminToken, token)
|
||||
if existing:
|
||||
existing.name = name
|
||||
existing.expires_at = expires_at
|
||||
existing.is_active = True
|
||||
session.add(existing)
|
||||
else:
|
||||
new_token = AdminToken(
|
||||
token=token,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(new_token)
|
||||
|
||||
def update_usage(self, token: str) -> None:
|
||||
"""Update admin token last used timestamp.
|
||||
|
||||
Args:
|
||||
token: The token string
|
||||
"""
|
||||
with self._session() as session:
|
||||
admin_token = session.get(AdminToken, token)
|
||||
if admin_token:
|
||||
admin_token.last_used_at = self._now()
|
||||
session.add(admin_token)
|
||||
|
||||
def deactivate(self, token: str) -> bool:
|
||||
"""Deactivate an admin token.
|
||||
|
||||
Args:
|
||||
token: The token string
|
||||
|
||||
Returns:
|
||||
True if token was deactivated, False if not found
|
||||
"""
|
||||
with self._session() as session:
|
||||
admin_token = session.get(AdminToken, token)
|
||||
if admin_token:
|
||||
admin_token.is_active = False
|
||||
session.add(admin_token)
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
Training Task Repository
|
||||
|
||||
Handles training task operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingTaskRepository(BaseRepository[TrainingTask]):
|
||||
"""Repository for training task management.
|
||||
|
||||
Handles:
|
||||
- Training task CRUD operations
|
||||
- Task status management
|
||||
- Training logs
|
||||
- Training document links
|
||||
"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
admin_token: str,
|
||||
name: str,
|
||||
task_type: str = "train",
|
||||
description: str | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
scheduled_at: datetime | None = None,
|
||||
cron_expression: str | None = None,
|
||||
is_recurring: bool = False,
|
||||
dataset_id: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new training task.
|
||||
|
||||
Returns:
|
||||
Task ID as string
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
task = TrainingTask(
|
||||
admin_token=admin_token,
|
||||
name=name,
|
||||
task_type=task_type,
|
||||
description=description,
|
||||
config=config,
|
||||
scheduled_at=scheduled_at,
|
||||
cron_expression=cron_expression,
|
||||
is_recurring=is_recurring,
|
||||
status="scheduled" if scheduled_at else "pending",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
session.add(task)
|
||||
session.flush()
|
||||
return str(task.task_id)
|
||||
|
||||
def get(self, task_id: str) -> TrainingTask | None:
|
||||
"""Get a training task by ID."""
|
||||
with get_session_context() as session:
|
||||
result = session.get(TrainingTask, UUID(task_id))
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def get_by_token(
|
||||
self,
|
||||
task_id: str,
|
||||
admin_token: str | None = None,
|
||||
) -> TrainingTask | None:
|
||||
"""Get a training task by ID. Token parameter is deprecated."""
|
||||
return self.get(task_id)
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[TrainingTask], int]:
|
||||
"""Get paginated training tasks."""
|
||||
with get_session_context() as session:
|
||||
count_stmt = select(func.count()).select_from(TrainingTask)
|
||||
if status:
|
||||
count_stmt = count_stmt.where(TrainingTask.status == status)
|
||||
total = session.exec(count_stmt).one()
|
||||
|
||||
statement = select(TrainingTask)
|
||||
if status:
|
||||
statement = statement.where(TrainingTask.status == status)
|
||||
statement = statement.order_by(TrainingTask.created_at.desc())
|
||||
statement = statement.offset(offset).limit(limit)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results), total
|
||||
|
||||
def get_pending(self) -> list[TrainingTask]:
|
||||
"""Get pending training tasks ready to run."""
|
||||
with get_session_context() as session:
|
||||
now = datetime.utcnow()
|
||||
            statement = select(TrainingTask).where(
                TrainingTask.status.in_(["pending", "scheduled"]),
                (TrainingTask.scheduled_at.is_(None)) | (TrainingTask.scheduled_at <= now),
            ).order_by(TrainingTask.created_at)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_running(self) -> TrainingTask | None:
|
||||
"""Get currently running training task.
|
||||
|
||||
Returns:
|
||||
Running task or None if no task is running
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
result = session.exec(
|
||||
select(TrainingTask)
|
||||
.where(TrainingTask.status == "running")
|
||||
.order_by(TrainingTask.started_at.desc())
|
||||
).first()
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
error_message: str | None = None,
|
||||
result_metrics: dict[str, Any] | None = None,
|
||||
model_path: str | None = None,
|
||||
) -> None:
|
||||
"""Update training task status."""
|
||||
with get_session_context() as session:
|
||||
task = session.get(TrainingTask, UUID(task_id))
|
||||
if task:
|
||||
task.status = status
|
||||
task.updated_at = datetime.utcnow()
|
||||
if status == "running":
|
||||
task.started_at = datetime.utcnow()
|
||||
elif status in ("completed", "failed"):
|
||||
task.completed_at = datetime.utcnow()
|
||||
if error_message is not None:
|
||||
task.error_message = error_message
|
||||
if result_metrics is not None:
|
||||
task.result_metrics = result_metrics
|
||||
if model_path is not None:
|
||||
task.model_path = model_path
|
||||
session.add(task)
|
||||
|
||||
def cancel(self, task_id: str) -> bool:
|
||||
"""Cancel a training task."""
|
||||
with get_session_context() as session:
|
||||
task = session.get(TrainingTask, UUID(task_id))
|
||||
if task and task.status in ("pending", "scheduled"):
|
||||
task.status = "cancelled"
|
||||
task.updated_at = datetime.utcnow()
|
||||
session.add(task)
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_log(
|
||||
self,
|
||||
task_id: str,
|
||||
level: str,
|
||||
message: str,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Add a training log entry."""
|
||||
with get_session_context() as session:
|
||||
log = TrainingLog(
|
||||
task_id=UUID(task_id),
|
||||
level=level,
|
||||
message=message,
|
||||
details=details,
|
||||
)
|
||||
session.add(log)
|
||||
|
||||
def get_logs(
|
||||
self,
|
||||
task_id: str,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[TrainingLog]:
|
||||
"""Get training logs for a task."""
|
||||
with get_session_context() as session:
|
||||
statement = select(TrainingLog).where(
|
||||
TrainingLog.task_id == UUID(task_id)
|
||||
).order_by(TrainingLog.created_at.desc()).offset(offset).limit(limit)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def create_document_link(
|
||||
self,
|
||||
task_id: UUID,
|
||||
document_id: UUID,
|
||||
annotation_snapshot: dict[str, Any] | None = None,
|
||||
) -> TrainingDocumentLink:
|
||||
"""Create a training document link."""
|
||||
with get_session_context() as session:
|
||||
link = TrainingDocumentLink(
|
||||
task_id=task_id,
|
||||
document_id=document_id,
|
||||
annotation_snapshot=annotation_snapshot,
|
||||
)
|
||||
session.add(link)
|
||||
session.commit()
|
||||
session.refresh(link)
|
||||
session.expunge(link)
|
||||
return link
|
||||
|
||||
def get_document_links(self, task_id: UUID) -> list[TrainingDocumentLink]:
|
||||
"""Get all document links for a training task."""
|
||||
with get_session_context() as session:
|
||||
statement = select(TrainingDocumentLink).where(
|
||||
TrainingDocumentLink.task_id == task_id
|
||||
).order_by(TrainingDocumentLink.created_at)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_document_training_tasks(self, document_id: UUID) -> list[TrainingDocumentLink]:
|
||||
"""Get all training tasks that used this document."""
|
||||
with get_session_context() as session:
|
||||
statement = select(TrainingDocumentLink).where(
|
||||
TrainingDocumentLink.document_id == document_id
|
||||
).order_by(TrainingDocumentLink.created_at.desc())
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
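A hedged end-to-end sketch of how this repository is intended to be driven. Instantiating with no arguments assumes BaseRepository needs none; that constructor is not part of this diff.

repo = TrainingTaskRepository()

task_id = repo.create(admin_token="tok_abc123", name="nightly-retrain")   # status "pending"
repo.update_status(task_id, "running")                                    # sets started_at
repo.add_log(task_id, level="INFO", message="training started")
repo.update_status(task_id, "completed", result_metrics={"mAP50": 0.91})  # sets completed_at

tasks, total = repo.get_paginated(status="completed", limit=10)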
|
||||
5
packages/backend/backend/pipeline/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .pipeline import InferencePipeline, InferenceResult
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor

__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']
101
packages/backend/backend/pipeline/constants.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Inference Configuration Constants
|
||||
|
||||
Centralized configuration values for the inference pipeline.
|
||||
Extracted from hardcoded values across multiple modules for easier maintenance.
|
||||
"""
|
||||
|
||||
# ============================================================================
|
||||
# Detection & Model Configuration
|
||||
# ============================================================================
|
||||
|
||||
# YOLO Detection
|
||||
DEFAULT_CONFIDENCE_THRESHOLD = 0.5 # Default confidence threshold for YOLO detection
|
||||
DEFAULT_IOU_THRESHOLD = 0.45 # Default IoU threshold for NMS (Non-Maximum Suppression)
|
||||
|
||||
# ============================================================================
|
||||
# Image Processing Configuration
|
||||
# ============================================================================
|
||||
|
||||
# DPI (Dots Per Inch) for PDF rendering
|
||||
DEFAULT_DPI = 300 # Standard DPI for PDF to image conversion
|
||||
DPI_TO_POINTS_SCALE = 72 # PDF points per inch (used for bbox conversion)
|
||||
|
||||
# ============================================================================
|
||||
# Customer Number Parser Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Pattern confidence scores (higher = more confident)
|
||||
CUSTOMER_NUMBER_CONFIDENCE = {
|
||||
'labeled': 0.98, # Explicit label (e.g., "Kundnummer: ABC 123-X")
|
||||
'dash_format': 0.95, # Standard format with dash (e.g., "JTY 576-3")
|
||||
'no_dash': 0.90, # Format without dash (e.g., "Dwq 211X")
|
||||
'compact': 0.75, # Compact format (e.g., "JTY5763")
|
||||
'generic_base': 0.5, # Base score for generic alphanumeric pattern
|
||||
}
|
||||
|
||||
# Bonus scores for generic pattern matching
|
||||
CUSTOMER_NUMBER_BONUS = {
|
||||
'has_dash': 0.2, # Bonus if contains dash
|
||||
'typical_format': 0.25, # Bonus for format XXX NNN-X
|
||||
'medium_length': 0.1, # Bonus for length 6-12 characters
|
||||
}
|
||||
|
||||
# Customer number length constraints
|
||||
CUSTOMER_NUMBER_LENGTH = {
|
||||
'min': 6, # Minimum length for medium length bonus
|
||||
'max': 12, # Maximum length for medium length bonus
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Field Extraction Confidence Scores
|
||||
# ============================================================================
|
||||
|
||||
# Confidence multipliers and base scores
|
||||
FIELD_CONFIDENCE = {
|
||||
'pdf_text': 1.0, # PDF text extraction (always accurate)
|
||||
'payment_line_high': 0.95, # Payment line parsed successfully
|
||||
'regex_fallback': 0.5, # Regex-based fallback extraction
|
||||
'ocr_penalty': 0.5, # Penalty multiplier when OCR fails
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Payment Line Validation
|
||||
# ============================================================================
|
||||
|
||||
# Account number length thresholds for type detection
|
||||
ACCOUNT_TYPE_THRESHOLD = {
|
||||
'bankgiro_min_length': 7, # Minimum digits for Bankgiro (7-8 digits)
|
||||
'plusgiro_max_length': 6, # Maximum digits for Plusgiro (typically fewer)
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# OCR Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Minimum OCR reference number length
|
||||
MIN_OCR_LENGTH = 5 # Minimum length for valid OCR number
|
||||
|
||||
# ============================================================================
|
||||
# Pattern Matching
|
||||
# ============================================================================
|
||||
|
||||
# Swedish postal code pattern (to exclude from customer numbers)
|
||||
SWEDISH_POSTAL_CODE_PATTERN = r'^SE\s+\d{3}\s*\d{2}'
|
||||
|
||||
# ============================================================================
|
||||
# Usage Notes
|
||||
# ============================================================================
|
||||
"""
|
||||
These constants can be overridden at runtime by passing parameters to
|
||||
constructors or methods. The values here serve as sensible defaults
|
||||
based on Swedish invoice processing requirements.
|
||||
|
||||
Example:
|
||||
from backend.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD
|
||||
|
||||
detector = YOLODetector(
|
||||
model_path="model.pt",
|
||||
confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD # or custom value
|
||||
)
|
||||
"""
|
||||
390
packages/backend/backend/pipeline/customer_number_parser.py
Normal file
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Swedish Customer Number Parser
|
||||
|
||||
Handles extraction and normalization of Swedish customer numbers.
|
||||
Uses Strategy Pattern with multiple matching patterns.
|
||||
|
||||
Common Swedish customer number formats:
|
||||
- JTY 576-3
|
||||
- EMM 256-6
|
||||
- DWQ 211-X
|
||||
- FFL 019N
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
|
||||
from shared.exceptions import CustomerNumberParseError
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomerNumberMatch:
|
||||
"""Customer number match result."""
|
||||
|
||||
value: str
|
||||
"""The normalized customer number"""
|
||||
|
||||
pattern_name: str
|
||||
"""Name of the pattern that matched"""
|
||||
|
||||
confidence: float
|
||||
"""Confidence score (0.0 to 1.0)"""
|
||||
|
||||
raw_text: str
|
||||
"""Original text that was matched"""
|
||||
|
||||
position: int = 0
|
||||
"""Position in text where match was found"""
|
||||
|
||||
|
||||
class CustomerNumberPattern(ABC):
|
||||
"""Abstract base for customer number patterns."""
|
||||
|
||||
@abstractmethod
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""
|
||||
Try to match pattern in text.
|
||||
|
||||
Args:
|
||||
text: Text to search for customer number
|
||||
|
||||
Returns:
|
||||
CustomerNumberMatch if found, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""
|
||||
Format matched groups to standard format.
|
||||
|
||||
Args:
|
||||
match: Regex match object
|
||||
|
||||
Returns:
|
||||
Formatted customer number string
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DashFormatPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Pattern: ABC 123-X (with dash)
|
||||
|
||||
Examples: JTY 576-3, EMM 256-6, DWQ 211-X
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match customer number with dash format."""
|
||||
match = self.PATTERN.search(text)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
# Check if it's not a postal code
|
||||
full_match = match.group(0)
|
||||
if self._is_postal_code(full_match):
|
||||
return None
|
||||
|
||||
formatted = self.format(match)
|
||||
return CustomerNumberMatch(
|
||||
value=formatted,
|
||||
pattern_name="DashFormat",
|
||||
confidence=0.95,
|
||||
raw_text=full_match,
|
||||
position=match.start()
|
||||
)
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Format to standard ABC 123-X format."""
|
||||
prefix = match.group(1).upper()
|
||||
number = match.group(2)
|
||||
suffix = match.group(3).upper()
|
||||
return f"{prefix} {number}-{suffix}"
|
||||
|
||||
def _is_postal_code(self, text: str) -> bool:
|
||||
"""Check if text looks like Swedish postal code."""
|
||||
# SE 106 43, SE10643, etc.
|
||||
return bool(
|
||||
text.upper().startswith('SE ') and
|
||||
re.match(r'^SE\s+\d{3}\s*\d{2}', text, re.IGNORECASE)
|
||||
)
|
||||
|
||||
|
||||
class NoDashFormatPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Pattern: ABC 123X (no dash)
|
||||
|
||||
Examples: Dwq 211X, FFL 019N
|
||||
Converts to: DWQ 211-X, FFL 019-N
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match customer number without dash."""
|
||||
match = self.PATTERN.search(text)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
# Exclude postal codes
|
||||
full_match = match.group(0)
|
||||
if self._is_postal_code(full_match):
|
||||
return None
|
||||
|
||||
formatted = self.format(match)
|
||||
return CustomerNumberMatch(
|
||||
value=formatted,
|
||||
pattern_name="NoDashFormat",
|
||||
confidence=0.90,
|
||||
raw_text=full_match,
|
||||
position=match.start()
|
||||
)
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Format to standard ABC 123-X format (add dash)."""
|
||||
prefix = match.group(1).upper()
|
||||
number = match.group(2)
|
||||
suffix = match.group(3).upper()
|
||||
return f"{prefix} {number}-{suffix}"
|
||||
|
||||
def _is_postal_code(self, text: str) -> bool:
|
||||
"""Check if text looks like Swedish postal code."""
|
||||
return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))
|
||||
|
||||
|
||||
class CompactFormatPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Pattern: ABC123X (compact, no spaces)
|
||||
|
||||
Examples: JTY5763, FFL019N
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(r'\b([A-Z]{2,4})(\d{3,6})([A-Z]?)\b')
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match compact customer number format."""
|
||||
upper_text = text.upper()
|
||||
match = self.PATTERN.search(upper_text)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
# Filter out SE postal codes
|
||||
if match.group(1) == 'SE':
|
||||
return None
|
||||
|
||||
formatted = self.format(match)
|
||||
return CustomerNumberMatch(
|
||||
value=formatted,
|
||||
pattern_name="CompactFormat",
|
||||
confidence=0.75,
|
||||
raw_text=match.group(0),
|
||||
position=match.start()
|
||||
)
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Format to ABC123X or ABC123-X format."""
|
||||
prefix = match.group(1).upper()
|
||||
number = match.group(2)
|
||||
suffix = match.group(3).upper()
|
||||
|
||||
if suffix:
|
||||
return f"{prefix} {number}-{suffix}"
|
||||
else:
|
||||
return f"{prefix}{number}"
|
||||
|
||||
|
||||
class GenericAlphanumericPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Generic pattern: Letters + numbers + optional dash/letter
|
||||
|
||||
Examples: EMM 256-6, ABC 123, FFL 019
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]?\d{0,2}[A-Z]?)\b')
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match generic alphanumeric pattern."""
|
||||
upper_text = text.upper()
|
||||
|
||||
all_matches = []
|
||||
for match in self.PATTERN.finditer(upper_text):
|
||||
matched_text = match.group(1)
|
||||
|
||||
# Filter out pure numbers
|
||||
if re.match(r'^\d+$', matched_text):
|
||||
continue
|
||||
|
||||
# Filter out Swedish postal codes
|
||||
if re.match(r'^SE[\s\-]*\d', matched_text):
|
||||
continue
|
||||
|
||||
# Filter out single letter + digit + space + digit (V4 2)
|
||||
if re.match(r'^[A-Z]\d\s+\d$', matched_text):
|
||||
continue
|
||||
|
||||
# Calculate confidence based on characteristics
|
||||
confidence = self._calculate_confidence(matched_text)
|
||||
|
||||
all_matches.append((confidence, matched_text, match.start()))
|
||||
|
||||
if all_matches:
|
||||
# Return highest confidence match
|
||||
best = max(all_matches, key=lambda x: x[0])
|
||||
return CustomerNumberMatch(
|
||||
value=best[1].strip(),
|
||||
pattern_name="GenericAlphanumeric",
|
||||
confidence=best[0],
|
||||
raw_text=best[1],
|
||||
position=best[2]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Return matched text as-is (already uppercase)."""
|
||||
return match.group(1).strip()
|
||||
|
||||
def _calculate_confidence(self, text: str) -> float:
|
||||
"""Calculate confidence score based on text characteristics."""
|
||||
# Require letters AND digits
|
||||
has_letters = bool(re.search(r'[A-Z]', text, re.IGNORECASE))
|
||||
has_digits = bool(re.search(r'\d', text))
|
||||
|
||||
if not (has_letters and has_digits):
|
||||
return 0.0 # Not a valid customer number
|
||||
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Bonus for containing dash
|
||||
if '-' in text:
|
||||
score += 0.2
|
||||
|
||||
# Bonus for typical format XXX NNN-X
|
||||
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', text):
|
||||
score += 0.25
|
||||
|
||||
# Bonus for medium length
|
||||
if 6 <= len(text) <= 12:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
|
||||
class LabeledPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Pattern: Explicit label + customer number
|
||||
|
||||
Examples:
|
||||
- "Kundnummer: JTY 576-3"
|
||||
- "Customer No: EMM 256-6"
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(
|
||||
r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)'
|
||||
r'\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
|
||||
re.IGNORECASE
|
||||
)
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match customer number with explicit label."""
|
||||
match = self.PATTERN.search(text)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
extracted = match.group(1).strip()
|
||||
# Remove trailing punctuation
|
||||
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
|
||||
|
||||
if extracted and len(extracted) >= 2:
|
||||
return CustomerNumberMatch(
|
||||
value=extracted.upper(),
|
||||
pattern_name="Labeled",
|
||||
confidence=0.98, # Very high confidence when labeled
|
||||
raw_text=match.group(0),
|
||||
position=match.start()
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Return matched customer number."""
|
||||
extracted = match.group(1).strip()
|
||||
return re.sub(r'[\s\.\,\:]+$', '', extracted).upper()
|
||||
|
||||
|
||||
class CustomerNumberParser:
|
||||
"""Parser for Swedish customer numbers."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize parser with patterns ordered by specificity."""
|
||||
self.patterns: List[CustomerNumberPattern] = [
|
||||
LabeledPattern(), # Highest priority - explicit label
|
||||
DashFormatPattern(), # Standard format with dash
|
||||
NoDashFormatPattern(), # Standard format without dash
|
||||
CompactFormatPattern(), # Compact format
|
||||
GenericAlphanumericPattern(), # Fallback generic pattern
|
||||
]
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
|
||||
"""
|
||||
Parse customer number from text.
|
||||
|
||||
Args:
|
||||
text: Text to search for customer number
|
||||
|
||||
Returns:
|
||||
Tuple of (customer_number, is_valid, error_message)
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return None, False, "Empty text"
|
||||
|
||||
text = text.strip()
|
||||
|
||||
# Try each pattern
|
||||
all_matches: List[CustomerNumberMatch] = []
|
||||
for pattern in self.patterns:
|
||||
match = pattern.match(text)
|
||||
if match:
|
||||
all_matches.append(match)
|
||||
|
||||
# No matches
|
||||
if not all_matches:
|
||||
return None, False, "No customer number found"
|
||||
|
||||
# Return highest confidence match
|
||||
best_match = max(all_matches, key=lambda m: (m.confidence, m.position))
|
||||
self.logger.debug(
|
||||
f"Customer number matched: {best_match.value} "
|
||||
f"(pattern: {best_match.pattern_name}, confidence: {best_match.confidence:.2f})"
|
||||
)
|
||||
return best_match.value, True, None
|
||||
|
||||
def parse_all(self, text: str) -> List[CustomerNumberMatch]:
|
||||
"""
|
||||
Find all customer numbers in text.
|
||||
|
||||
Useful for cases with multiple potential matches.
|
||||
|
||||
Args:
|
||||
text: Text to search
|
||||
|
||||
Returns:
|
||||
List of CustomerNumberMatch sorted by confidence (descending)
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
all_matches: List[CustomerNumberMatch] = []
|
||||
for pattern in self.patterns:
|
||||
match = pattern.match(text)
|
||||
if match:
|
||||
all_matches.append(match)
|
||||
|
||||
# Sort by confidence (highest first), then by position (later first)
|
||||
return sorted(all_matches, key=lambda m: (m.confidence, m.position), reverse=True)
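A short usage sketch of the parser; the expected outputs follow directly from the pattern priorities above.

parser = CustomerNumberParser()

value, is_valid, error = parser.parse("Kundnummer: JTY 576-3")
# -> ("JTY 576-3", True, None), via LabeledPattern (confidence 0.98)

value, is_valid, error = parser.parse("Dwq 211X")
# -> ("DWQ 211-X", True, None), via NoDashFormatPattern (dash inserted)

# parse_all() keeps every candidate, ordered by confidence:
for match in parser.parse_all("Faktura till EMM 256-6, ref FFL019N"):
    print(match.pattern_name, match.value, match.confidence)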
|
||||
630
packages/backend/backend/pipeline/field_extractor.py
Normal file
@@ -0,0 +1,630 @@
|
||||
"""
|
||||
Field Extractor Module
|
||||
|
||||
Extracts and validates field values from detected regions.
|
||||
|
||||
This module is used during inference to extract values from OCR text.
|
||||
It uses shared utilities from shared.utils for text cleaning and validation.
|
||||
|
||||
Enhanced features:
|
||||
- Multi-source fusion with confidence weighting
|
||||
- Smart amount parsing with multiple strategies
|
||||
- Enhanced date format unification
|
||||
- OCR error correction integration
|
||||
|
||||
Refactored to use modular normalizers for each field type.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
import re
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.fields import CLASS_TO_FIELD
|
||||
from .yolo_detector import Detection
|
||||
|
||||
# Import shared utilities for text cleaning and validation
|
||||
from shared.utils.validators import FieldValidators
|
||||
from shared.utils.ocr_corrections import OCRCorrections
|
||||
|
||||
# Import new unified parsers
|
||||
from .payment_line_parser import PaymentLineParser
|
||||
from .customer_number_parser import CustomerNumberParser
|
||||
|
||||
# Import normalizers
|
||||
from .normalizers import (
|
||||
BaseNormalizer,
|
||||
NormalizationResult,
|
||||
create_normalizer_registry,
|
||||
EnhancedAmountNormalizer,
|
||||
EnhancedDateNormalizer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedField:
|
||||
"""Represents an extracted field value."""
|
||||
field_name: str
|
||||
raw_text: str
|
||||
normalized_value: str | None
|
||||
confidence: float
|
||||
detection_confidence: float
|
||||
ocr_confidence: float
|
||||
bbox: tuple[float, float, float, float]
|
||||
page_no: int
|
||||
is_valid: bool = True
|
||||
validation_error: str | None = None
|
||||
# Multi-source fusion fields
|
||||
alternative_values: list[tuple[str, float]] = field(default_factory=list) # [(value, confidence), ...]
|
||||
extraction_method: str = 'single' # 'single', 'fused', 'corrected'
|
||||
ocr_corrections_applied: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
result = {
|
||||
'field_name': self.field_name,
|
||||
'value': self.normalized_value,
|
||||
'raw_text': self.raw_text,
|
||||
'confidence': self.confidence,
|
||||
'bbox': list(self.bbox),
|
||||
'page_no': self.page_no,
|
||||
'is_valid': self.is_valid,
|
||||
'validation_error': self.validation_error
|
||||
}
|
||||
if self.alternative_values:
|
||||
result['alternatives'] = self.alternative_values
|
||||
if self.extraction_method != 'single':
|
||||
result['extraction_method'] = self.extraction_method
|
||||
return result
|
||||
|
||||
|
||||
class FieldExtractor:
|
||||
"""Extracts field values from detected regions using OCR or PDF text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ocr_lang: str = 'en',
|
||||
use_gpu: bool = False,
|
||||
bbox_padding: float = 0.1,
|
||||
dpi: int = 300,
|
||||
use_enhanced_parsing: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize field extractor.
|
||||
|
||||
Args:
|
||||
ocr_lang: Language for OCR
|
||||
use_gpu: Whether to use GPU for OCR
|
||||
bbox_padding: Padding to add around bboxes (as fraction)
|
||||
dpi: DPI used for rendering (for coordinate conversion)
|
||||
use_enhanced_parsing: Whether to use enhanced normalizers
|
||||
"""
|
||||
self.ocr_lang = ocr_lang
|
||||
self.use_gpu = use_gpu
|
||||
self.bbox_padding = bbox_padding
|
||||
self.dpi = dpi
|
||||
self._ocr_engine = None # Lazy init
|
||||
self.use_enhanced_parsing = use_enhanced_parsing
|
||||
|
||||
# Initialize new unified parsers
|
||||
self.payment_line_parser = PaymentLineParser()
|
||||
self.customer_number_parser = CustomerNumberParser()
|
||||
|
||||
# Initialize normalizer registry
|
||||
self._normalizers = create_normalizer_registry(use_enhanced=use_enhanced_parsing)
|
||||
|
||||
@property
|
||||
def ocr_engine(self):
|
||||
"""Lazy-load OCR engine only when needed."""
|
||||
if self._ocr_engine is None:
|
||||
from shared.ocr import OCREngine
|
||||
self._ocr_engine = OCREngine(lang=self.ocr_lang)
|
||||
return self._ocr_engine
|
||||
|
||||
def extract_from_detection_with_pdf(
|
||||
self,
|
||||
detection: Detection,
|
||||
pdf_tokens: list,
|
||||
image_width: int,
|
||||
image_height: int
|
||||
) -> ExtractedField:
|
||||
"""
|
||||
Extract field value using PDF text tokens (faster and more accurate for text PDFs).
|
||||
|
||||
Args:
|
||||
detection: Detection object with bbox in pixel coordinates
|
||||
pdf_tokens: List of Token objects from PDF text extraction
|
||||
image_width: Width of rendered image in pixels
|
||||
image_height: Height of rendered image in pixels
|
||||
|
||||
Returns:
|
||||
ExtractedField object
|
||||
"""
|
||||
# Convert detection bbox from pixels to PDF points
|
||||
scale = 72 / self.dpi # points per pixel
|
||||
x0_pdf = detection.bbox[0] * scale
|
||||
y0_pdf = detection.bbox[1] * scale
|
||||
x1_pdf = detection.bbox[2] * scale
|
||||
y1_pdf = detection.bbox[3] * scale
|
||||
|
||||
# Add padding in points
|
||||
pad = 3 # Small padding in points
|
||||
|
||||
# Find tokens that overlap with detection bbox
|
||||
matching_tokens = []
|
||||
for token in pdf_tokens:
|
||||
if token.page_no != detection.page_no:
|
||||
continue
|
||||
tx0, ty0, tx1, ty1 = token.bbox
|
||||
# Check overlap
|
||||
if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and
|
||||
ty0 < y1_pdf + pad and ty1 > y0_pdf - pad):
|
||||
# Calculate overlap ratio to prioritize better matches
|
||||
overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
|
||||
overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
|
||||
if overlap_x > 0 and overlap_y > 0:
|
||||
token_area = (tx1 - tx0) * (ty1 - ty0)
|
||||
overlap_area = overlap_x * overlap_y
|
||||
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
|
||||
matching_tokens.append((token, overlap_ratio))
|
||||
|
||||
# Sort by overlap ratio and combine text
|
||||
matching_tokens.sort(key=lambda x: -x[1])
|
||||
raw_text = ' '.join(t[0].text for t in matching_tokens)
|
||||
|
||||
# Get field name
|
||||
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||
|
||||
# Normalize and validate
|
||||
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||
field_name, raw_text
|
||||
)
|
||||
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
raw_text=raw_text,
|
||||
normalized_value=normalized_value,
|
||||
confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
|
||||
detection_confidence=detection.confidence,
|
||||
ocr_confidence=1.0, # PDF text is always accurate
|
||||
bbox=detection.bbox,
|
||||
page_no=detection.page_no,
|
||||
is_valid=is_valid,
|
||||
validation_error=validation_error
|
||||
)
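For orientation, the pixel-to-point conversion above at the default 300 DPI works out as follows (plain arithmetic, no project code involved):

dpi = 300
scale = 72 / dpi                              # 0.24 points per pixel
x_pixels = 1250
assert abs(x_pixels * scale - 300.0) < 1e-9   # 1250 px -> 300.0 PDF points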
|
||||
|
||||
def extract_from_detection(
|
||||
self,
|
||||
detection: Detection,
|
||||
image: np.ndarray | Image.Image
|
||||
) -> ExtractedField:
|
||||
"""
|
||||
Extract field value from a detection region using OCR.
|
||||
|
||||
Args:
|
||||
detection: Detection object
|
||||
image: Full page image
|
||||
|
||||
Returns:
|
||||
ExtractedField object
|
||||
"""
|
||||
if isinstance(image, Image.Image):
|
||||
image = np.array(image)
|
||||
|
||||
# Get padded bbox
|
||||
h, w = image.shape[:2]
|
||||
bbox = detection.get_padded_bbox(self.bbox_padding, w, h)
|
||||
|
||||
# Crop region
|
||||
x0, y0, x1, y1 = [int(v) for v in bbox]
|
||||
region = image[y0:y1, x0:x1]
|
||||
|
||||
# Run OCR on region
|
||||
ocr_tokens = self.ocr_engine.extract_from_image(region)
|
||||
|
||||
# Combine all OCR text
|
||||
raw_text = ' '.join(t.text for t in ocr_tokens)
|
||||
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
|
||||
|
||||
# Get field name
|
||||
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||
|
||||
# Normalize and validate
|
||||
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||
field_name, raw_text
|
||||
)
|
||||
|
||||
# Combined confidence
|
||||
confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5
|
||||
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
raw_text=raw_text,
|
||||
normalized_value=normalized_value,
|
||||
confidence=confidence,
|
||||
detection_confidence=detection.confidence,
|
||||
ocr_confidence=ocr_confidence,
|
||||
bbox=detection.bbox,
|
||||
page_no=detection.page_no,
|
||||
is_valid=is_valid,
|
||||
validation_error=validation_error
|
||||
)
|
||||
|
||||
def _normalize_and_validate(
|
||||
self,
|
||||
field_name: str,
|
||||
raw_text: str
|
||||
) -> tuple[str | None, bool, str | None]:
|
||||
"""
|
||||
Normalize and validate extracted text for a field.
|
||||
|
||||
Uses modular normalizers for each field type.
|
||||
Falls back to legacy methods for payment_line and customer_number.
|
||||
|
||||
Returns:
|
||||
(normalized_value, is_valid, validation_error)
|
||||
"""
|
||||
text = raw_text.strip()
|
||||
|
||||
if not text:
|
||||
return None, False, "Empty text"
|
||||
|
||||
# Special handling for payment_line and customer_number (use unified parsers)
|
||||
if field_name == 'payment_line':
|
||||
return self._normalize_payment_line(text)
|
||||
|
||||
if field_name == 'customer_number':
|
||||
return self._normalize_customer_number(text)
|
||||
|
||||
# Use normalizer registry for other fields
|
||||
normalizer = self._normalizers.get(field_name)
|
||||
if normalizer:
|
||||
result = normalizer.normalize(text)
|
||||
return result.to_tuple()
|
||||
|
||||
# Fallback for unknown fields
|
||||
return text, True, None
|
||||
|
||||
def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""
|
||||
Normalize payment line region text using unified PaymentLineParser.
|
||||
|
||||
Extracts the machine-readable payment line format from OCR text.
|
||||
Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
|
||||
Examples:
|
||||
- "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00
|
||||
- "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00
|
||||
|
||||
Returns normalized format preserving ALL components including Amount.
|
||||
This allows downstream cross-validation to extract fields properly.
|
||||
"""
|
||||
# Use unified payment line parser
|
||||
return self.payment_line_parser.format_for_field_extractor(
|
||||
self.payment_line_parser.parse(text)
|
||||
)
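An illustrative decomposition of the payment line format documented above. This is not the PaymentLineParser implementation (that lives in payment_line_parser.py and may differ); it only shows how the documented components line up in the raw string.

import re

line = "# 94228110015950070 # 15658 00 8 > 48666036#14#"
m = re.match(
    r"#\s*(?P<ocr>\d+)\s*#\s*(?P<kronor>\d+)\s+(?P<ore>\d{2})\s+(?P<type>\d)"
    r"\s*>\s*(?P<account>\d+)#(?P<check>\d+)#",
    line,
)
if m:
    amount = f"{int(m['kronor'])}.{m['ore']}"   # "15658.00"
    print(m["ocr"], amount, m["account"], m["check"])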
|
||||
|
||||
def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""
|
||||
Normalize customer number text using unified CustomerNumberParser.
|
||||
|
||||
Supports various Swedish customer number formats:
|
||||
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R'
|
||||
- Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N'
|
||||
- Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01'
|
||||
- Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R'
|
||||
"""
|
||||
return self.customer_number_parser.parse(text)
|
||||
|
||||
def extract_all_fields(
|
||||
self,
|
||||
detections: list[Detection],
|
||||
image: np.ndarray | Image.Image
|
||||
) -> list[ExtractedField]:
|
||||
"""
|
||||
Extract fields from all detections.
|
||||
|
||||
Args:
|
||||
detections: List of detections
|
||||
image: Full page image
|
||||
|
||||
Returns:
|
||||
List of ExtractedField objects
|
||||
"""
|
||||
fields = []
|
||||
|
||||
for detection in detections:
|
||||
field = self.extract_from_detection(detection, image)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
@staticmethod
|
||||
def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
Infer OCR field from InvoiceNumber if not detected.
|
||||
|
||||
In Swedish invoices, OCR reference number is often identical to InvoiceNumber.
|
||||
When OCR is not detected but InvoiceNumber is, we can infer OCR value.
|
||||
|
||||
Args:
|
||||
fields: Dict of field_name -> normalized_value
|
||||
|
||||
Returns:
|
||||
Updated fields dict with inferred OCR if applicable
|
||||
"""
|
||||
# If OCR already exists, no need to infer
|
||||
if fields.get('OCR'):
|
||||
return fields
|
||||
|
||||
# If InvoiceNumber exists and is numeric, use it as OCR
|
||||
invoice_number = fields.get('InvoiceNumber')
|
||||
if invoice_number:
|
||||
# Check if it's mostly digits (valid OCR reference)
|
||||
digits_only = re.sub(r'\D', '', invoice_number)
|
||||
if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
|
||||
fields['OCR'] = invoice_number
|
||||
|
||||
return fields
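A quick check of the inference rule, for orientation:

fields = FieldExtractor.infer_ocr_from_invoice_number({"InvoiceNumber": "1234567"})
# fields["OCR"] == "1234567" (purely numeric, at least 5 digits)

fields = FieldExtractor.infer_ocr_from_invoice_number({"InvoiceNumber": "INV-2024-01"})
# "OCR" is not added: the value contains non-digit characters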
|
||||
|
||||
# =========================================================================
|
||||
# Multi-Source Fusion with Confidence Weighting
|
||||
# =========================================================================
|
||||
|
||||
def fuse_multiple_detections(
|
||||
self,
|
||||
extracted_fields: list[ExtractedField]
|
||||
) -> list[ExtractedField]:
|
||||
"""
|
||||
Fuse multiple detections of the same field using confidence-weighted voting.
|
||||
|
||||
When YOLO detects the same field type multiple times (e.g., multiple Amount boxes),
|
||||
this method selects the best value or combines them intelligently.
|
||||
|
||||
Strategies:
|
||||
1. For numeric fields (Amount, OCR): prefer values that pass validation
|
||||
2. For date fields: prefer values in expected range
|
||||
3. For giro numbers: prefer values with valid Luhn checksum
|
||||
4. General: weighted vote by confidence scores
|
||||
|
||||
Args:
|
||||
extracted_fields: List of all extracted fields (may have duplicates)
|
||||
|
||||
Returns:
|
||||
List with duplicates resolved to single best value per field
|
||||
"""
|
||||
# Group fields by name
|
||||
fields_by_name: dict[str, list[ExtractedField]] = defaultdict(list)
|
||||
for field in extracted_fields:
|
||||
fields_by_name[field.field_name].append(field)
|
||||
|
||||
fused_fields = []
|
||||
|
||||
for field_name, candidates in fields_by_name.items():
|
||||
if len(candidates) == 1:
|
||||
# No fusion needed
|
||||
fused_fields.append(candidates[0])
|
||||
else:
|
||||
# Multiple candidates - fuse them
|
||||
fused = self._fuse_field_candidates(field_name, candidates)
|
||||
fused_fields.append(fused)
|
||||
|
||||
return fused_fields
|
||||
|
||||
def _fuse_field_candidates(
|
||||
self,
|
||||
field_name: str,
|
||||
candidates: list[ExtractedField]
|
||||
) -> ExtractedField:
|
||||
"""
|
||||
Fuse multiple candidates for a single field.
|
||||
|
||||
Returns the best candidate with alternatives recorded.
|
||||
"""
|
||||
# Sort by confidence (descending)
|
||||
sorted_candidates = sorted(candidates, key=lambda x: x.confidence, reverse=True)
|
||||
|
||||
# Collect all unique values with their max confidence
|
||||
value_scores: dict[str, tuple[float, ExtractedField]] = {}
|
||||
for c in sorted_candidates:
|
||||
if c.normalized_value:
|
||||
if c.normalized_value not in value_scores:
|
||||
value_scores[c.normalized_value] = (c.confidence, c)
|
||||
else:
|
||||
# Keep the higher confidence one
|
||||
if c.confidence > value_scores[c.normalized_value][0]:
|
||||
value_scores[c.normalized_value] = (c.confidence, c)
|
||||
|
||||
if not value_scores:
|
||||
# No valid values, return the highest confidence candidate
|
||||
return sorted_candidates[0]
|
||||
|
||||
# Field-specific fusion strategy
|
||||
best_value, best_field = self._select_best_value(field_name, value_scores)
|
||||
|
||||
# Record alternatives
|
||||
alternatives = [
|
||||
(v, score) for v, (score, _) in value_scores.items()
|
||||
if v != best_value
|
||||
]
|
||||
|
||||
# Create fused result
|
||||
result = ExtractedField(
|
||||
field_name=field_name,
|
||||
raw_text=best_field.raw_text,
|
||||
normalized_value=best_value,
|
||||
confidence=value_scores[best_value][0],
|
||||
detection_confidence=best_field.detection_confidence,
|
||||
ocr_confidence=best_field.ocr_confidence,
|
||||
bbox=best_field.bbox,
|
||||
page_no=best_field.page_no,
|
||||
is_valid=best_field.is_valid,
|
||||
validation_error=best_field.validation_error,
|
||||
alternative_values=alternatives,
|
||||
extraction_method='fused' if len(value_scores) > 1 else 'single'
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _select_best_value(
|
||||
self,
|
||||
field_name: str,
|
||||
value_scores: dict[str, tuple[float, ExtractedField]]
|
||||
) -> tuple[str, ExtractedField]:
|
||||
"""
|
||||
Select the best value for a field using field-specific logic.
|
||||
|
||||
Returns (best_value, best_field)
|
||||
"""
|
||||
items = list(value_scores.items())
|
||||
|
||||
# Field-specific selection
|
||||
if field_name in ('Bankgiro', 'Plusgiro', 'OCR'):
|
||||
# Prefer values with valid Luhn checksum
|
||||
for value, (score, field) in items:
|
||||
digits = re.sub(r'\D', '', value)
|
||||
if FieldValidators.luhn_checksum(digits):
|
||||
return value, field
|
||||
|
||||
elif field_name == 'Amount':
|
||||
# Prefer larger amounts (usually the total, not subtotals)
|
||||
amounts = []
|
||||
for value, (score, field) in items:
|
||||
try:
|
||||
amt = float(value.replace(',', '.'))
|
||||
amounts.append((amt, value, field))
|
||||
except ValueError:
|
||||
continue
|
||||
            if amounts:
                # Return the largest amount; sort on the numeric value only, since
                # comparing ExtractedField objects on a tied amount would raise TypeError
                amounts.sort(key=lambda x: x[0], reverse=True)
                return amounts[0][1], amounts[0][2]
|
||||
|
||||
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
# Prefer dates in reasonable range
|
||||
from datetime import datetime
|
||||
for value, (score, field) in items:
|
||||
try:
|
||||
dt = datetime.strptime(value, '%Y-%m-%d')
|
||||
# Prefer recent dates (within last 2 years and next 1 year)
|
||||
now = datetime.now()
|
||||
if now.year - 2 <= dt.year <= now.year + 1:
|
||||
return value, field
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Default: return highest confidence value
|
||||
best = max(items, key=lambda x: x[1][0])
|
||||
return best[0], best[1][1]
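A hedged sketch of the fusion path for two Amount detections; the candidate values are made up, but the selection follows the strategy implemented above.

extractor = FieldExtractor()

subtotal = ExtractedField(
    field_name="Amount", raw_text="980,00", normalized_value="980.00",
    confidence=0.90, detection_confidence=0.92, ocr_confidence=0.88,
    bbox=(100.0, 100.0, 200.0, 120.0), page_no=0,
)
total = ExtractedField(
    field_name="Amount", raw_text="1 225,00", normalized_value="1225.00",
    confidence=0.85, detection_confidence=0.87, ocr_confidence=0.83,
    bbox=(100.0, 300.0, 200.0, 320.0), page_no=0,
)

fused = extractor.fuse_multiple_detections([subtotal, total])
# One Amount survives: the Amount strategy prefers the larger value, so
# fused[0].normalized_value == "1225.00" and "980.00" ends up in alternative_values.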
|
||||
|
||||
# =========================================================================
|
||||
# Apply OCR Corrections to Raw Text
|
||||
# =========================================================================
|
||||
|
||||
def apply_ocr_corrections(
|
||||
self,
|
||||
field_name: str,
|
||||
raw_text: str
|
||||
) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Apply OCR corrections to raw text based on field type.
|
||||
|
||||
Returns (corrected_text, list_of_corrections_applied)
|
||||
"""
|
||||
corrections_applied = []
|
||||
|
||||
        # Field names come from CLASS_TO_FIELD, which maps the supplier field to
        # 'supplier_organisation_number' (see the normalizer registry)
        if field_name in ('OCR', 'Bankgiro', 'Plusgiro', 'supplier_organisation_number'):
|
||||
# Aggressive correction for numeric fields
|
||||
result = OCRCorrections.correct_digits(raw_text, aggressive=True)
|
||||
if result.corrections_applied:
|
||||
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
||||
return result.corrected, corrections_applied
|
||||
|
||||
elif field_name == 'Amount':
|
||||
# Conservative correction for amounts (preserve decimal separators)
|
||||
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
|
||||
if result.corrections_applied:
|
||||
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
||||
return result.corrected, corrections_applied
|
||||
|
||||
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
# Conservative correction for dates
|
||||
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
|
||||
if result.corrections_applied:
|
||||
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
||||
return result.corrected, corrections_applied
|
||||
|
||||
# No correction for other fields
|
||||
return raw_text, []
|
||||
|
||||
# =========================================================================
|
||||
# Extraction with All Enhancements
|
||||
# =========================================================================
|
||||
|
||||
def extract_with_enhancements(
|
||||
self,
|
||||
detection: Detection,
|
||||
pdf_tokens: list,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
use_enhanced_parsing: bool = True
|
||||
) -> ExtractedField:
|
||||
"""
|
||||
Extract field value with all enhancements enabled.
|
||||
|
||||
Combines:
|
||||
1. OCR error correction
|
||||
2. Enhanced amount/date parsing
|
||||
3. Multi-strategy extraction
|
||||
|
||||
Args:
|
||||
detection: Detection object
|
||||
pdf_tokens: PDF text tokens
|
||||
image_width: Image width in pixels
|
||||
image_height: Image height in pixels
|
||||
use_enhanced_parsing: Whether to use enhanced parsing methods
|
||||
|
||||
Returns:
|
||||
ExtractedField with enhancements applied
|
||||
"""
|
||||
# First, extract using standard method
|
||||
base_result = self.extract_from_detection_with_pdf(
|
||||
detection, pdf_tokens, image_width, image_height
|
||||
)
|
||||
|
||||
if not use_enhanced_parsing:
|
||||
return base_result
|
||||
|
||||
# Apply OCR corrections
|
||||
corrected_text, corrections = self.apply_ocr_corrections(
|
||||
base_result.field_name, base_result.raw_text
|
||||
)
|
||||
|
||||
# Re-normalize with enhanced methods if corrections were applied
|
||||
if corrections or base_result.normalized_value is None:
|
||||
# Use enhanced normalizers for Amount and Date fields
|
||||
if base_result.field_name == 'Amount':
|
||||
enhanced_normalizer = EnhancedAmountNormalizer()
|
||||
result = enhanced_normalizer.normalize(corrected_text)
|
||||
normalized, is_valid, error = result.to_tuple()
|
||||
elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
enhanced_normalizer = EnhancedDateNormalizer()
|
||||
result = enhanced_normalizer.normalize(corrected_text)
|
||||
normalized, is_valid, error = result.to_tuple()
|
||||
else:
|
||||
# Re-run standard normalization with corrected text
|
||||
normalized, is_valid, error = self._normalize_and_validate(
|
||||
base_result.field_name, corrected_text
|
||||
)
|
||||
|
||||
# Update result if we got a better value
|
||||
if normalized and (not base_result.normalized_value or is_valid):
|
||||
base_result.normalized_value = normalized
|
||||
base_result.is_valid = is_valid
|
||||
base_result.validation_error = error
|
||||
base_result.ocr_corrections_applied = corrections
|
||||
if corrections:
|
||||
base_result.extraction_method = 'corrected'
|
||||
|
||||
return base_result
|
||||
60
packages/backend/backend/pipeline/normalizers/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Normalizers Package
|
||||
|
||||
Provides field-specific normalizers for invoice data extraction.
|
||||
Each normalizer handles a specific field type's normalization and validation.
|
||||
"""
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
from .invoice_number import InvoiceNumberNormalizer
|
||||
from .ocr_number import OcrNumberNormalizer
|
||||
from .bankgiro import BankgiroNormalizer
|
||||
from .plusgiro import PlusgiroNormalizer
|
||||
from .amount import AmountNormalizer, EnhancedAmountNormalizer
|
||||
from .date import DateNormalizer, EnhancedDateNormalizer
|
||||
from .supplier_org_number import SupplierOrgNumberNormalizer
|
||||
|
||||
__all__ = [
|
||||
# Base
|
||||
"BaseNormalizer",
|
||||
"NormalizationResult",
|
||||
# Normalizers
|
||||
"InvoiceNumberNormalizer",
|
||||
"OcrNumberNormalizer",
|
||||
"BankgiroNormalizer",
|
||||
"PlusgiroNormalizer",
|
||||
"AmountNormalizer",
|
||||
"EnhancedAmountNormalizer",
|
||||
"DateNormalizer",
|
||||
"EnhancedDateNormalizer",
|
||||
"SupplierOrgNumberNormalizer",
|
||||
]
|
||||
|
||||
|
||||
# Registry of all normalizers by field name
|
||||
def create_normalizer_registry(
|
||||
use_enhanced: bool = False,
|
||||
) -> dict[str, BaseNormalizer]:
|
||||
"""
|
||||
Create a registry mapping field names to normalizer instances.
|
||||
|
||||
Args:
|
||||
use_enhanced: Whether to use enhanced normalizers for amount/date
|
||||
|
||||
Returns:
|
||||
Dictionary mapping field names to normalizer instances
|
||||
"""
|
||||
amount_normalizer = EnhancedAmountNormalizer() if use_enhanced else AmountNormalizer()
|
||||
date_normalizer = EnhancedDateNormalizer() if use_enhanced else DateNormalizer()
|
||||
|
||||
return {
|
||||
"InvoiceNumber": InvoiceNumberNormalizer(),
|
||||
"OCR": OcrNumberNormalizer(),
|
||||
"Bankgiro": BankgiroNormalizer(),
|
||||
"Plusgiro": PlusgiroNormalizer(),
|
||||
"Amount": amount_normalizer,
|
||||
"InvoiceDate": date_normalizer,
|
||||
"InvoiceDueDate": date_normalizer,
|
||||
# Note: field_name is "supplier_organisation_number" (from CLASS_TO_FIELD mapping)
|
||||
"supplier_organisation_number": SupplierOrgNumberNormalizer(),
|
||||
}
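A minimal sketch of how the registry is meant to be consumed:

registry = create_normalizer_registry(use_enhanced=False)

result = registry["Amount"].normalize("Att betala: 1 234,56 kr")
print(result.value, result.is_valid)   # "1234.56" True

# With use_enhanced=True the Amount/Date entries swap to the Enhanced* variants;
# field names without a registry entry fall through to the caller's default handling.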
|
||||
185
packages/backend/backend/pipeline/normalizers/amount.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Amount Normalizer
|
||||
|
||||
Handles normalization and validation of monetary amounts.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from shared.utils.text_cleaner import TextCleaner
|
||||
from shared.utils.validators import FieldValidators
|
||||
from shared.utils.ocr_corrections import OCRCorrections
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class AmountNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes monetary amounts from Swedish invoices.
|
||||
|
||||
Handles various Swedish amount formats:
|
||||
- With decimal: 1 234,56 kr
|
||||
- With SEK suffix: 1234.56 SEK
|
||||
- Multiple amounts (returns the last one, usually the total)
|
||||
"""
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "Amount"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Split by newlines and process line by line to get the last valid amount
|
||||
lines = text.split("\n")
|
||||
|
||||
# Collect all valid amounts from all lines
|
||||
all_amounts: list[float] = []
|
||||
|
||||
# Pattern for Swedish amount format (with decimals)
|
||||
amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Find all amounts in this line
|
||||
matches = re.findall(amount_pattern, line, re.IGNORECASE)
|
||||
for match in matches:
|
||||
amount_str = match.replace(" ", "").replace(",", ".")
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if amount > 0:
|
||||
all_amounts.append(amount)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Return the last amount found (usually the total)
|
||||
if all_amounts:
|
||||
return NormalizationResult.success(f"{all_amounts[-1]:.2f}")
|
||||
|
||||
# Fallback: try shared validator on cleaned text
|
||||
cleaned = TextCleaner.normalize_amount_text(text)
|
||||
amount = FieldValidators.parse_amount(cleaned)
|
||||
if amount is not None and amount > 0:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
|
||||
# Try to find any decimal number
|
||||
simple_pattern = r"(\d+[,\.]\d{2})"
|
||||
matches = re.findall(simple_pattern, text)
|
||||
if matches:
|
||||
amount_str = matches[-1].replace(",", ".")
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if amount > 0:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Last resort: try to find integer amount (no decimals)
|
||||
# Look for patterns like "Amount: 11699" or standalone numbers
|
||||
int_pattern = r"(?:amount|belopp|summa|total)[:\s]*(\d+)"
|
||||
match = re.search(int_pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
amount = float(match.group(1))
|
||||
if amount > 0:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Very last resort: find any standalone number >= 3 digits
|
||||
standalone_pattern = r"\b(\d{3,})\b"
|
||||
matches = re.findall(standalone_pattern, text)
|
||||
if matches:
|
||||
# Take the last/largest number
|
||||
try:
|
||||
amount = float(matches[-1])
|
||||
if amount > 0:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return NormalizationResult.failure(f"Cannot parse amount: {text}")
|
||||
|
||||
|
||||
class EnhancedAmountNormalizer(AmountNormalizer):
|
||||
"""
|
||||
Enhanced amount parsing with multiple strategies.
|
||||
|
||||
Strategies:
|
||||
1. Pattern matching for Swedish formats
|
||||
2. Context-aware extraction (look for keywords like "Total", "Summa")
|
||||
3. OCR error correction for common digit errors
|
||||
4. Multi-amount handling (prefer last/largest as total)
|
||||
"""
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Strategy 1: Apply OCR corrections first
|
||||
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
|
||||
|
||||
# Strategy 2: Look for labeled amounts (highest priority)
|
||||
labeled_patterns = [
|
||||
# Swedish patterns
|
||||
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
|
||||
(
|
||||
r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
|
||||
0.8,
|
||||
), # Lower priority for VAT
|
||||
# Generic pattern
|
||||
(r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
|
||||
]
|
||||
|
||||
candidates: list[tuple[float, float, int]] = []
|
||||
for pattern, priority in labeled_patterns:
|
||||
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
|
||||
amount_str = match.group(1).replace(" ", "").replace(",", ".")
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if 0 < amount < 10_000_000: # Reasonable range
|
||||
candidates.append((amount, priority, match.start()))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if candidates:
|
||||
# Sort by priority (desc), then by position (later is usually total)
|
||||
candidates.sort(key=lambda x: (-x[1], -x[2]))
|
||||
best_amount = candidates[0][0]
|
||||
return NormalizationResult.success(f"{best_amount:.2f}")
|
||||
|
||||
# Strategy 3: Parse with shared validator
|
||||
cleaned = TextCleaner.normalize_amount_text(corrected_text)
|
||||
amount = FieldValidators.parse_amount(cleaned)
|
||||
if amount is not None and 0 < amount < 10_000_000:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
|
||||
# Strategy 4: Try to extract any decimal number as fallback
|
||||
decimal_pattern = r"(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})"
|
||||
matches = re.findall(decimal_pattern, corrected_text)
|
||||
if matches:
|
||||
# Clean and parse each match
|
||||
amounts: list[float] = []
|
||||
for m in matches:
|
||||
cleaned_m = m.replace(" ", "").replace(".", "").replace(",", ".")
|
||||
# Handle Swedish format: "1 234,56" -> "1234.56"
|
||||
if "," in m and "." not in m:
|
||||
cleaned_m = m.replace(" ", "").replace(",", ".")
|
||||
try:
|
||||
amt = float(cleaned_m)
|
||||
if 0 < amt < 10_000_000:
|
||||
amounts.append(amt)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if amounts:
|
||||
                # Return the largest amount (usually the total)
|
||||
return NormalizationResult.success(f"{max(amounts):.2f}")
|
||||
|
||||
return NormalizationResult.failure(f"Cannot parse amount: {text[:50]}")
|
||||
87
packages/backend/backend/pipeline/normalizers/bankgiro.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Bankgiro Normalizer
|
||||
|
||||
Handles normalization and validation of Swedish Bankgiro numbers.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from shared.utils.validators import FieldValidators
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class BankgiroNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes Swedish Bankgiro numbers.
|
||||
|
||||
Bankgiro rules:
|
||||
- 7 or 8 digits only
|
||||
- Last digit is Luhn (Mod10) check digit
|
||||
- Display format: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits)
|
||||
|
||||
Display pattern: ^\\d{3,4}-\\d{4}$
|
||||
Normalized pattern: ^\\d{7,8}$
|
||||
|
||||
Note: Text may contain both BG and PG numbers. We specifically look for
|
||||
BG display format (XXX-XXXX or XXXX-XXXX) to extract the correct one.
|
||||
"""
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "Bankgiro"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Look for BG display format pattern: 3-4 digits, dash, 4 digits
|
||||
# This distinguishes BG from PG which uses X-X format (digits-single digit)
|
||||
bg_matches = re.findall(r"(\d{3,4})-(\d{4})", text)
|
||||
|
||||
if bg_matches:
|
||||
# Try each match and find one with valid Luhn
|
||||
for match in bg_matches:
|
||||
digits = match[0] + match[1]
|
||||
if len(digits) in (7, 8) and FieldValidators.luhn_checksum(digits):
|
||||
# Valid BG found
|
||||
formatted = self._format_bankgiro(digits)
|
||||
return NormalizationResult.success(formatted)
|
||||
|
||||
# No valid Luhn, use first match
|
||||
digits = bg_matches[0][0] + bg_matches[0][1]
|
||||
if len(digits) in (7, 8):
|
||||
formatted = self._format_bankgiro(digits)
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
# Fallback: try to find 7-8 consecutive digits
|
||||
# But first check if text contains PG format (XXXXXXX-X), if so don't use fallback
|
||||
# to avoid misinterpreting PG as BG
|
||||
pg_format_present = re.search(r"(?<![0-9])\d{1,7}-\d(?!\d)", text)
|
||||
if pg_format_present:
|
||||
return NormalizationResult.failure("No valid Bankgiro found in text")
|
||||
|
||||
digit_match = re.search(r"\b(\d{7,8})\b", text)
|
||||
if digit_match:
|
||||
digits = digit_match.group(1)
|
||||
if len(digits) in (7, 8):
|
||||
formatted = self._format_bankgiro(digits)
|
||||
if FieldValidators.luhn_checksum(digits):
|
||||
return NormalizationResult.success(formatted)
|
||||
else:
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
return NormalizationResult.failure("No valid Bankgiro found in text")
|
||||
|
||||
@staticmethod
|
||||
def _format_bankgiro(digits: str) -> str:
|
||||
"""Format Bankgiro number with dash."""
|
||||
if len(digits) == 8:
|
||||
return f"{digits[:4]}-{digits[4:]}"
|
||||
else:
|
||||
return f"{digits[:3]}-{digits[3:]}"
|
||||
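The Bankgiro check digit mentioned above is a standard Luhn (Mod 10) over all digits; the normalizer delegates this to FieldValidators.luhn_checksum from the shared package. The snippet below is only an independent sketch of that check for reference, not the shared implementation.

def luhn_ok(digits: str) -> bool:
    """Mod 10 check over a digit string; the rightmost digit is the check digit."""
    total = 0
    for i, ch in enumerate(reversed(digits)):
        d = int(ch)
        if i % 2 == 1:      # double every second digit from the right
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return total % 10 == 0

# A 7-digit Bankgiro displayed as "XXX-XXXX" would be checked as luhn_ok("XXXXXXX").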
71
packages/backend/backend/pipeline/normalizers/base.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Base Normalizer Interface
|
||||
|
||||
Defines the contract for all field normalizers.
|
||||
Each normalizer handles a specific field type's normalization and validation.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NormalizationResult:
|
||||
"""Result of a normalization operation."""
|
||||
|
||||
value: str | None
|
||||
is_valid: bool
|
||||
error: str | None = None
|
||||
|
||||
@classmethod
|
||||
def success(cls, value: str) -> "NormalizationResult":
|
||||
"""Create a successful result."""
|
||||
return cls(value=value, is_valid=True, error=None)
|
||||
|
||||
@classmethod
|
||||
def success_with_warning(cls, value: str, warning: str) -> "NormalizationResult":
|
||||
"""Create a successful result with a warning."""
|
||||
return cls(value=value, is_valid=True, error=warning)
|
||||
|
||||
@classmethod
|
||||
def failure(cls, error: str) -> "NormalizationResult":
|
||||
"""Create a failed result."""
|
||||
return cls(value=None, is_valid=False, error=error)
|
||||
|
||||
def to_tuple(self) -> tuple[str | None, bool, str | None]:
|
||||
"""Convert to legacy tuple format for backward compatibility."""
|
||||
return (self.value, self.is_valid, self.error)
|
||||
|
||||
|
||||
class BaseNormalizer(ABC):
|
||||
"""
|
||||
Abstract base class for field normalizers.
|
||||
|
||||
Each normalizer is responsible for:
|
||||
1. Cleaning and normalizing raw text
|
||||
2. Validating the normalized value
|
||||
3. Returning a standardized result
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def field_name(self) -> str:
|
||||
"""The field name this normalizer handles."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
"""
|
||||
Normalize and validate the input text.
|
||||
|
||||
Args:
|
||||
text: Raw text to normalize
|
||||
|
||||
Returns:
|
||||
NormalizationResult with normalized value or error
|
||||
"""
|
||||
pass
|
||||
|
||||
def __call__(self, text: str) -> NormalizationResult:
|
||||
"""Allow using the normalizer as a callable."""
|
||||
return self.normalize(text)
|
||||
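To illustrate the contract defined above, here is a minimal hedged example of a new normalizer built on BaseNormalizer and NormalizationResult. The field itself is a toy example and not part of the pipeline; the import path assumes packages/backend is on PYTHONPATH.

import re

from backend.pipeline.normalizers.base import BaseNormalizer, NormalizationResult


class CurrencyNormalizer(BaseNormalizer):
    """Toy example: extracts a three-letter currency code."""

    @property
    def field_name(self) -> str:
        return "Currency"

    def normalize(self, text: str) -> NormalizationResult:
        match = re.search(r"\b(SEK|EUR|USD)\b", text.upper())
        if match:
            return NormalizationResult.success(match.group(1))
        return NormalizationResult.failure("No currency code found")


# Callable protocol from BaseNormalizer.__call__:
# CurrencyNormalizer()("Totalt 1 234,56 SEK").to_tuple() -> ("SEK", True, None)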
200
packages/backend/backend/pipeline/normalizers/date.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Date Normalizer
|
||||
|
||||
Handles normalization and validation of invoice dates.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
from shared.utils.validators import FieldValidators
|
||||
from shared.utils.ocr_corrections import OCRCorrections
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class DateNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes dates from Swedish invoices.
|
||||
|
||||
Handles various date formats:
|
||||
- 2025-08-29 (ISO format)
|
||||
- 2025.08.29 (dot separator)
|
||||
- 29/08/2025 (European format)
|
||||
- 29.08.2025 (European with dots)
|
||||
- 20250829 (compact format)
|
||||
|
||||
Output format: YYYY-MM-DD (ISO 8601)
|
||||
"""
|
||||
|
||||
# Date patterns with their parsing logic
|
||||
PATTERNS = [
|
||||
# ISO format: 2025-08-29
|
||||
(
|
||||
r"(\d{4})-(\d{1,2})-(\d{1,2})",
|
||||
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
|
||||
),
|
||||
# Dot format: 2025.08.29 (common in Swedish)
|
||||
(
|
||||
r"(\d{4})\.(\d{1,2})\.(\d{1,2})",
|
||||
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
|
||||
),
|
||||
# European slash format: 29/08/2025
|
||||
(
|
||||
r"(\d{1,2})/(\d{1,2})/(\d{4})",
|
||||
lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
|
||||
),
|
||||
# European dot format: 29.08.2025
|
||||
(
|
||||
r"(\d{1,2})\.(\d{1,2})\.(\d{4})",
|
||||
lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
|
||||
),
|
||||
# Compact format: 20250829
|
||||
(
|
||||
r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)",
|
||||
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "Date"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# First, try using shared validator
|
||||
iso_date = FieldValidators.format_date_iso(text)
|
||||
if iso_date and FieldValidators.is_valid_date(iso_date):
|
||||
return NormalizationResult.success(iso_date)
|
||||
|
||||
# Fallback: try original patterns for edge cases
|
||||
for pattern, extractor in self.PATTERNS:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
try:
|
||||
year, month, day = extractor(match)
|
||||
# Validate date
|
||||
parsed_date = datetime(year, month, day)
|
||||
# Sanity check: year should be reasonable (2000-2100)
|
||||
if 2000 <= parsed_date.year <= 2100:
|
||||
return NormalizationResult.success(
|
||||
parsed_date.strftime("%Y-%m-%d")
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return NormalizationResult.failure(f"Cannot parse date: {text}")
|
||||
|
||||
|
||||
class EnhancedDateNormalizer(DateNormalizer):
|
||||
"""
|
||||
Enhanced date parsing with comprehensive format support.
|
||||
|
||||
Additional support for:
|
||||
- Swedish text: "29 december 2024", "29 dec 2024"
|
||||
- OCR error correction: 2O24-12-29 -> 2024-12-29
|
||||
"""
|
||||
|
||||
# Swedish month names
|
||||
SWEDISH_MONTHS = {
|
||||
"januari": 1,
|
||||
"jan": 1,
|
||||
"februari": 2,
|
||||
"feb": 2,
|
||||
"mars": 3,
|
||||
"mar": 3,
|
||||
"april": 4,
|
||||
"apr": 4,
|
||||
"maj": 5,
|
||||
"juni": 6,
|
||||
"jun": 6,
|
||||
"juli": 7,
|
||||
"jul": 7,
|
||||
"augusti": 8,
|
||||
"aug": 8,
|
||||
"september": 9,
|
||||
"sep": 9,
|
||||
"sept": 9,
|
||||
"oktober": 10,
|
||||
"okt": 10,
|
||||
"november": 11,
|
||||
"nov": 11,
|
||||
"december": 12,
|
||||
"dec": 12,
|
||||
}
|
||||
|
||||
# Extended patterns
|
||||
EXTENDED_PATTERNS = [
|
||||
# ISO format: 2025-08-29, 2025/08/29
|
||||
("ymd", r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})"),
|
||||
# Dot format: 2025.08.29
|
||||
("ymd", r"(\d{4})\.(\d{1,2})\.(\d{1,2})"),
|
||||
# European slash: 29/08/2025
|
||||
("dmy", r"(\d{1,2})/(\d{1,2})/(\d{4})"),
|
||||
# European dot: 29.08.2025
|
||||
("dmy", r"(\d{1,2})\.(\d{1,2})\.(\d{4})"),
|
||||
# European dash: 29-08-2025
|
||||
("dmy", r"(\d{1,2})-(\d{1,2})-(\d{4})"),
|
||||
# Compact: 20250829
|
||||
("ymd_compact", r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)"),
|
||||
]
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Apply OCR corrections
|
||||
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
|
||||
|
||||
# Try shared validator first
|
||||
iso_date = FieldValidators.format_date_iso(corrected_text)
|
||||
if iso_date and FieldValidators.is_valid_date(iso_date):
|
||||
return NormalizationResult.success(iso_date)
|
||||
|
||||
# Try Swedish text date pattern: "29 december 2024" or "29 dec 2024"
|
||||
swedish_pattern = r"(\d{1,2})\s+([a-z\u00e5\u00e4\u00f6]+)\s+(\d{4})"
|
||||
match = re.search(swedish_pattern, corrected_text.lower())
|
||||
if match:
|
||||
day = int(match.group(1))
|
||||
month_name = match.group(2)
|
||||
year = int(match.group(3))
|
||||
if month_name in self.SWEDISH_MONTHS:
|
||||
month = self.SWEDISH_MONTHS[month_name]
|
||||
try:
|
||||
dt = datetime(year, month, day)
|
||||
if 2000 <= dt.year <= 2100:
|
||||
return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Extended patterns
|
||||
for fmt, pattern in self.EXTENDED_PATTERNS:
|
||||
match = re.search(pattern, corrected_text)
|
||||
if match:
|
||||
try:
|
||||
if fmt == "ymd":
|
||||
year = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
day = int(match.group(3))
|
||||
elif fmt == "dmy":
|
||||
day = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
year = int(match.group(3))
|
||||
elif fmt == "ymd_compact":
|
||||
year = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
day = int(match.group(3))
|
||||
else:
|
||||
continue
|
||||
|
||||
dt = datetime(year, month, day)
|
||||
if 2000 <= dt.year <= 2100:
|
||||
return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return NormalizationResult.failure(f"Cannot parse date: {text[:50]}")
|
||||
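A brief, hedged usage sketch for the two date normalizers above. The import path is assumed, and the expected outputs in the comments follow the patterns and docstrings coded above (the OCR-correction example mirrors the class docstring).

from backend.pipeline.normalizers.date import DateNormalizer, EnhancedDateNormalizer  # path assumed

basic = DateNormalizer()
enhanced = EnhancedDateNormalizer()

print(basic.normalize("Fakturadatum: 2025.08.29").value)   # expected "2025-08-29" (dot format)
print(enhanced.normalize("29 dec 2024").value)             # Swedish month name -> expected "2024-12-29"
print(enhanced.normalize("2O24-12-29").value)              # OCR letter "O" corrected -> expected "2024-12-29"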
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Invoice Number Normalizer
|
||||
|
||||
Handles normalization and validation of invoice numbers.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class InvoiceNumberNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes invoice numbers from Swedish invoices.
|
||||
|
||||
Invoice numbers can be:
|
||||
- Pure digits: 12345678
|
||||
- Alphanumeric: A3861, INV-2024-001, F12345
|
||||
- With separators: 2024/001, 2024-001
|
||||
|
||||
Strategy:
|
||||
1. Look for common invoice number patterns
|
||||
2. Prefer shorter, more specific matches over long digit sequences
|
||||
"""
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "InvoiceNumber"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Pattern 1: Alphanumeric invoice number (letter + digits or digits + letter)
|
||||
# Examples: A3861, F12345, INV001
|
||||
alpha_patterns = [
|
||||
r"\b([A-Z]{1,3}\d{3,10})\b", # A3861, INV12345
|
||||
r"\b(\d{3,10}[A-Z]{1,3})\b", # 12345A
|
||||
r"\b([A-Z]{2,5}[-/]?\d{3,10})\b", # INV-12345, FAK12345
|
||||
]
|
||||
|
||||
for pattern in alpha_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
return NormalizationResult.success(match.group(1).upper())
|
||||
|
||||
# Pattern 2: Invoice number with year prefix (2024-001, 2024/12345)
|
||||
year_pattern = r"\b(20\d{2}[-/]\d{3,8})\b"
|
||||
match = re.search(year_pattern, text)
|
||||
if match:
|
||||
return NormalizationResult.success(match.group(1))
|
||||
|
||||
# Pattern 3: Short digit sequence (3-10 digits) - prefer shorter sequences
|
||||
# This avoids capturing long OCR numbers
|
||||
digit_sequences = re.findall(r"\b(\d{3,10})\b", text)
|
||||
if digit_sequences:
|
||||
# Prefer shorter sequences (more likely to be invoice number)
|
||||
# Also filter out sequences that look like dates (8 digits starting with 20)
|
||||
valid_sequences = []
|
||||
for seq in digit_sequences:
|
||||
# Skip if it looks like a date (YYYYMMDD)
|
||||
if len(seq) == 8 and seq.startswith("20"):
|
||||
continue
|
||||
# Skip if too long (likely OCR number)
|
||||
if len(seq) > 10:
|
||||
continue
|
||||
valid_sequences.append(seq)
|
||||
|
||||
if valid_sequences:
|
||||
# Return shortest valid sequence
|
||||
return NormalizationResult.success(min(valid_sequences, key=len))
|
||||
|
||||
# Fallback: extract all digits if nothing else works
|
||||
digits = re.sub(r"\D", "", text)
|
||||
if len(digits) >= 3:
|
||||
# Limit to first 15 digits to avoid very long sequences
|
||||
return NormalizationResult.success_with_warning(
|
||||
digits[:15], "Fallback extraction"
|
||||
)
|
||||
|
||||
return NormalizationResult.failure(
|
||||
f"Cannot extract invoice number from: {text[:50]}"
|
||||
)
|
||||
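Hedged usage sketch of InvoiceNumberNormalizer showing the pattern priority described in its docstring (import path assumed; expected values follow the patterns above).

from backend.pipeline.normalizers.invoice_number import InvoiceNumberNormalizer  # path assumed

n = InvoiceNumberNormalizer()
print(n.normalize("Fakturanr: A3861").value)                   # alphanumeric pattern -> "A3861"
print(n.normalize("Faktura 2024/00123").value)                 # year-prefixed pattern -> "2024/00123"
print(n.normalize("Nr 445566, OCR 9422811001595").value)       # shortest valid digit sequence -> "445566"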
37
packages/backend/backend/pipeline/normalizers/ocr_number.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
OCR Number Normalizer
|
||||
|
||||
Handles normalization and validation of OCR reference numbers.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class OcrNumberNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes Swedish OCR payment reference numbers.
|
||||
|
||||
OCR numbers in Swedish payment systems:
|
||||
- Minimum 5 digits
|
||||
- Used for automated payment matching
|
||||
"""
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "OCR"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
digits = re.sub(r"\D", "", text)
|
||||
|
||||
if len(digits) < 5:
|
||||
return NormalizationResult.failure(
|
||||
f"Too few digits for OCR: {len(digits)}"
|
||||
)
|
||||
|
||||
return NormalizationResult.success(digits)
|
||||
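Minimal usage sketch for OcrNumberNormalizer (import path assumed); it simply strips non-digits and enforces the 5-digit minimum.

from backend.pipeline.normalizers.ocr_number import OcrNumberNormalizer  # path assumed

ocr = OcrNumberNormalizer()
print(ocr.normalize("OCR: 9422 8110 0159 5007 0").to_tuple())  # ("94228110015950070", True, None)
print(ocr.normalize("ref 123").to_tuple())                     # (None, False, "Too few digits for OCR: 3")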
90
packages/backend/backend/pipeline/normalizers/plusgiro.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Plusgiro Normalizer
|
||||
|
||||
Handles normalization and validation of Swedish Plusgiro numbers.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from shared.utils.validators import FieldValidators
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class PlusgiroNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes Swedish Plusgiro numbers.
|
||||
|
||||
Plusgiro rules:
|
||||
- 2 to 8 digits
|
||||
- Last digit is Luhn (Mod10) check digit
|
||||
- Display format: XXXXXXX-X (all digits except last, dash, last digit)
|
||||
|
||||
Display pattern: ^\\d{1,7}-\\d$
|
||||
Normalized pattern: ^\\d{2,8}$
|
||||
|
||||
Note: Text may contain both BG and PG numbers. We specifically look for
|
||||
PG display format (X-X, XX-X, ..., XXXXXXX-X) to extract the correct one.
|
||||
"""
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "Plusgiro"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# First look for PG display format: 1-7 digits (possibly with spaces), dash, 1 digit
|
||||
# This is distinct from BG format which has 4 digits after the dash
|
||||
# Pattern allows spaces within the number like "486 98 63-6"
|
||||
# (?<![0-9]) ensures we don't start from within another number (like BG)
|
||||
pg_matches = re.findall(r"(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)", text)
|
||||
|
||||
if pg_matches:
|
||||
# Try each match and find one with valid Luhn
|
||||
for match in pg_matches:
|
||||
# Remove spaces from the first part
|
||||
digits = re.sub(r"\s", "", match[0]) + match[1]
|
||||
if 2 <= len(digits) <= 8 and FieldValidators.luhn_checksum(digits):
|
||||
# Valid PG found
|
||||
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||
return NormalizationResult.success(formatted)
|
||||
|
||||
# No valid Luhn, use first match with most digits
|
||||
best_match = max(pg_matches, key=lambda m: len(re.sub(r"\s", "", m[0])))
|
||||
digits = re.sub(r"\s", "", best_match[0]) + best_match[1]
|
||||
if 2 <= len(digits) <= 8:
|
||||
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
# If no PG format found, extract all digits and format as PG
|
||||
# This handles cases where the number might be in BG format or raw digits
|
||||
all_digits = re.sub(r"\D", "", text)
|
||||
|
||||
# Try to find a valid 2-8 digit sequence
|
||||
if 2 <= len(all_digits) <= 8:
|
||||
formatted = f"{all_digits[:-1]}-{all_digits[-1]}"
|
||||
if FieldValidators.luhn_checksum(all_digits):
|
||||
return NormalizationResult.success(formatted)
|
||||
else:
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
# Try to find any 2-8 digit sequence in text
|
||||
digit_match = re.search(r"\b(\d{2,8})\b", text)
|
||||
if digit_match:
|
||||
digits = digit_match.group(1)
|
||||
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||
if FieldValidators.luhn_checksum(digits):
|
||||
return NormalizationResult.success(formatted)
|
||||
else:
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
return NormalizationResult.failure("No valid Plusgiro found in text")
|
||||
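To make the BG/PG display-format distinction concrete, here is a hedged sketch of the two regexes from the Bankgiro and Plusgiro normalizers side by side. The numbers are format illustrations only (the Plusgiro example is the spaced form quoted in the comment above); no Luhn validity is implied.

import re

text = "Bankgiro 123-4567   Plusgiro 486 98 63-6"

# BG display form: 3-4 digits, dash, 4 digits
print(re.findall(r"(\d{3,4})-(\d{4})", text))                       # [('123', '4567')]
# PG display form: 1-7 digits (spaces allowed), dash, exactly one digit
print(re.findall(r"(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)", text))   # [('486 98 63', '6')]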
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Supplier Organization Number Normalizer
|
||||
|
||||
Handles normalization and validation of Swedish organization numbers.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class SupplierOrgNumberNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes Swedish supplier organization numbers.
|
||||
|
||||
Extracts organization number in format: NNNNNN-NNNN (10 digits)
|
||||
Also handles VAT numbers: SE + 10 digits + 01
|
||||
|
||||
Examples:
|
||||
'org.nr. 516406-1102, Filialregistret...' -> '516406-1102'
|
||||
'Momsreg.nr SE556123456701' -> '556123-4567'
|
||||
"""
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "supplier_org_number"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Pattern 1: Standard org number format: NNNNNN-NNNN
|
||||
org_pattern = r"\b(\d{6})-?(\d{4})\b"
|
||||
match = re.search(org_pattern, text)
|
||||
if match:
|
||||
org_num = f"{match.group(1)}-{match.group(2)}"
|
||||
return NormalizationResult.success(org_num)
|
||||
|
||||
# Pattern 2: VAT number format: SE + 10 digits + 01
|
||||
vat_pattern = r"SE\s*(\d{10})01"
|
||||
match = re.search(vat_pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
digits = match.group(1)
|
||||
org_num = f"{digits[:6]}-{digits[6:]}"
|
||||
return NormalizationResult.success(org_num)
|
||||
|
||||
# Pattern 3: Just 10 consecutive digits
|
||||
digits_pattern = r"\b(\d{10})\b"
|
||||
match = re.search(digits_pattern, text)
|
||||
if match:
|
||||
digits = match.group(1)
|
||||
# Validate: first digit should be 1-9 for Swedish org numbers
|
||||
if digits[0] in "123456789":
|
||||
org_num = f"{digits[:6]}-{digits[6:]}"
|
||||
return NormalizationResult.success(org_num)
|
||||
|
||||
return NormalizationResult.failure(
|
||||
f"Cannot extract org number from: {text[:100]}"
|
||||
)
|
||||
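Usage sketch for SupplierOrgNumberNormalizer based on the examples in its docstring (the module path is assumed, since the file header is not shown in this diff).

from backend.pipeline.normalizers.supplier_org_number import SupplierOrgNumberNormalizer  # path assumed

org = SupplierOrgNumberNormalizer()
print(org.normalize("org.nr. 516406-1102, Filialregistret").value)  # "516406-1102"
print(org.normalize("Momsreg.nr SE556123456701").value)             # VAT form -> "556123-4567"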
261
packages/backend/backend/pipeline/payment_line_parser.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Swedish Payment Line Parser
|
||||
|
||||
Handles parsing and validation of Swedish machine-readable payment lines.
|
||||
Unifies payment line parsing logic that was previously duplicated across multiple modules.
|
||||
|
||||
Standard Swedish payment line format:
|
||||
# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
|
||||
Example:
|
||||
# 94228110015950070 # 15658 00 8 > 48666036#14#
|
||||
|
||||
This parser handles common OCR errors:
|
||||
- Spaces in numbers: "12 0 0" → "1200"
|
||||
- Missing symbols: Missing ">"
|
||||
- Spaces in check digits: "#41 #" → "#41#"
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from shared.exceptions import PaymentLineParseError
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaymentLineData:
|
||||
"""Parsed payment line data."""
|
||||
|
||||
ocr_number: str
|
||||
"""OCR reference number (payment reference)"""
|
||||
|
||||
amount: Optional[str] = None
|
||||
"""Amount in format KRONOR.ÖRE (e.g., '1200.00'), None if not present"""
|
||||
|
||||
account_number: Optional[str] = None
|
||||
"""Bankgiro or Plusgiro account number"""
|
||||
|
||||
record_type: Optional[str] = None
|
||||
"""Record type digit (usually '5' or '8' or '9')"""
|
||||
|
||||
check_digits: Optional[str] = None
|
||||
"""Check digits for account validation"""
|
||||
|
||||
raw_text: str = ""
|
||||
"""Original raw text that was parsed"""
|
||||
|
||||
is_valid: bool = True
|
||||
"""Whether parsing was successful"""
|
||||
|
||||
error: Optional[str] = None
|
||||
"""Error message if parsing failed"""
|
||||
|
||||
parse_method: str = "unknown"
|
||||
"""Which parsing pattern was used (for debugging)"""
|
||||
|
||||
|
||||
class PaymentLineParser:
|
||||
"""Parser for Swedish payment lines with OCR error handling."""
|
||||
|
||||
# Pattern with amount: # OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#
|
||||
FULL_PATTERN = re.compile(
|
||||
r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
# Pattern without amount: # OCR # > ACCOUNT#CHECK#
|
||||
NO_AMOUNT_PATTERN = re.compile(
|
||||
r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
# Alternative pattern: look for OCR > ACCOUNT# pattern
|
||||
ALT_PATTERN = re.compile(
|
||||
r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
# Account only pattern: > ACCOUNT#CHECK#
|
||||
ACCOUNT_ONLY_PATTERN = re.compile(
|
||||
r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize parser with logger."""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def parse(self, text: str) -> PaymentLineData:
|
||||
"""
|
||||
Parse payment line text.
|
||||
|
||||
Handles common OCR errors:
|
||||
- Spaces in numbers: "12 0 0" → "1200"
|
||||
- Missing symbols: Missing ">"
|
||||
- Spaces in check digits: "#41 #" → "#41#"
|
||||
|
||||
Args:
|
||||
text: Raw payment line text from OCR
|
||||
|
||||
Returns:
|
||||
PaymentLineData with parsed fields or error information
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return PaymentLineData(
|
||||
ocr_number="",
|
||||
raw_text=text,
|
||||
is_valid=False,
|
||||
error="Empty payment line text",
|
||||
parse_method="none"
|
||||
)
|
||||
|
||||
text = text.strip()
|
||||
|
||||
# Try full pattern with amount
|
||||
match = self.FULL_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_full_match(match, text)
|
||||
|
||||
# Try pattern without amount
|
||||
match = self.NO_AMOUNT_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_no_amount_match(match, text)
|
||||
|
||||
# Try alternative pattern
|
||||
match = self.ALT_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_alt_match(match, text)
|
||||
|
||||
# Try account only pattern
|
||||
match = self.ACCOUNT_ONLY_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_account_only_match(match, text)
|
||||
|
||||
# No match - return error
|
||||
return PaymentLineData(
|
||||
ocr_number="",
|
||||
raw_text=text,
|
||||
is_valid=False,
|
||||
error="No valid payment line format found",
|
||||
parse_method="none"
|
||||
)
|
||||
|
||||
def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse full pattern match (with amount)."""
|
||||
ocr = self._clean_digits(match.group(1))
|
||||
kronor = self._clean_digits(match.group(2))
|
||||
ore = match.group(3)
|
||||
record_type = match.group(4)
|
||||
account = self._clean_digits(match.group(5))
|
||||
check_digits = match.group(6)
|
||||
|
||||
amount = f"{kronor}.{ore}"
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number=ocr,
|
||||
amount=amount,
|
||||
account_number=account,
|
||||
record_type=record_type,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error=None,
|
||||
parse_method="full"
|
||||
)
|
||||
|
||||
def _parse_no_amount_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse pattern match without amount."""
|
||||
ocr = self._clean_digits(match.group(1))
|
||||
account = self._clean_digits(match.group(2))
|
||||
check_digits = match.group(3)
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number=ocr,
|
||||
amount=None,
|
||||
account_number=account,
|
||||
record_type=None,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error=None,
|
||||
parse_method="no_amount"
|
||||
)
|
||||
|
||||
def _parse_alt_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse alternative pattern match."""
|
||||
ocr = self._clean_digits(match.group(1))
|
||||
account = self._clean_digits(match.group(2))
|
||||
check_digits = match.group(3)
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number=ocr,
|
||||
amount=None,
|
||||
account_number=account,
|
||||
record_type=None,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error=None,
|
||||
parse_method="alternative"
|
||||
)
|
||||
|
||||
def _parse_account_only_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse account-only pattern match."""
|
||||
account = self._clean_digits(match.group(1))
|
||||
check_digits = match.group(2)
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number="",
|
||||
amount=None,
|
||||
account_number=account,
|
||||
record_type=None,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error="Partial payment line (account only)",
|
||||
parse_method="account_only"
|
||||
)
|
||||
|
||||
def _clean_digits(self, text: str) -> str:
|
||||
"""Remove spaces from digit string (OCR error correction)."""
|
||||
return text.replace(' ', '')
|
||||
|
||||
def format_machine_readable(self, data: PaymentLineData) -> str:
|
||||
"""
|
||||
Format parsed data back to machine-readable format.
|
||||
|
||||
Returns:
|
||||
Formatted string in standard Swedish payment line format
|
||||
"""
|
||||
if not data.is_valid:
|
||||
return data.raw_text
|
||||
|
||||
# Full format with amount
|
||||
if data.amount and data.record_type:
|
||||
kronor, ore = data.amount.split('.')
|
||||
return (
|
||||
f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
|
||||
f"{data.account_number}#{data.check_digits}#"
|
||||
)
|
||||
|
||||
# Format without amount
|
||||
if data.ocr_number and data.account_number:
|
||||
return f"# {data.ocr_number} # > {data.account_number}#{data.check_digits}#"
|
||||
|
||||
# Account only
|
||||
if data.account_number:
|
||||
return f"> {data.account_number}#{data.check_digits}#"
|
||||
|
||||
# Fallback
|
||||
return data.raw_text
|
||||
|
||||
def format_for_field_extractor(self, data: PaymentLineData) -> tuple[Optional[str], bool, Optional[str]]:
|
||||
"""
|
||||
Format parsed data for FieldExtractor compatibility.
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_text, is_valid, error_message) matching FieldExtractor's API
|
||||
"""
|
||||
if not data.is_valid:
|
||||
return None, False, data.error
|
||||
|
||||
formatted = self.format_machine_readable(data)
|
||||
return formatted, True, data.error
|
||||
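Hedged usage sketch of PaymentLineParser, using the example payment line from the module docstring above (import path assumed).

from backend.pipeline.payment_line_parser import PaymentLineParser  # path assumed

parser = PaymentLineParser()
data = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")

print(data.ocr_number)      # "94228110015950070"
print(data.amount)          # "15658.00"
print(data.account_number)  # "48666036"
print(data.parse_method)    # "full"

# Round-trip back to the machine-readable form:
print(parser.format_machine_readable(data))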
499
packages/backend/backend/pipeline/pipeline.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
Inference Pipeline
|
||||
|
||||
Complete pipeline for extracting invoice data from PDFs.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import time
|
||||
import re
|
||||
|
||||
from shared.fields import CLASS_TO_FIELD
|
||||
from .yolo_detector import YOLODetector, Detection
|
||||
from .field_extractor import FieldExtractor, ExtractedField
|
||||
from .payment_line_parser import PaymentLineParser
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrossValidationResult:
|
||||
"""Result of cross-validation between payment_line and other fields."""
|
||||
is_valid: bool = False
|
||||
ocr_match: bool | None = None # None if not comparable
|
||||
amount_match: bool | None = None
|
||||
bankgiro_match: bool | None = None
|
||||
plusgiro_match: bool | None = None
|
||||
payment_line_ocr: str | None = None
|
||||
payment_line_amount: str | None = None
|
||||
payment_line_account: str | None = None
|
||||
payment_line_account_type: str | None = None # 'bankgiro' or 'plusgiro'
|
||||
details: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
"""Result of invoice processing."""
|
||||
document_id: str | None = None
|
||||
success: bool = False
|
||||
fields: dict[str, Any] = field(default_factory=dict)
|
||||
confidence: dict[str, float] = field(default_factory=dict)
|
||||
bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict) # Field bboxes in pixels
|
||||
raw_detections: list[Detection] = field(default_factory=list)
|
||||
extracted_fields: list[ExtractedField] = field(default_factory=list)
|
||||
processing_time_ms: float = 0.0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
fallback_used: bool = False
|
||||
cross_validation: CrossValidationResult | None = None
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""Convert to JSON-serializable dictionary."""
|
||||
result = {
|
||||
'DocumentId': self.document_id,
|
||||
'InvoiceNumber': self.fields.get('InvoiceNumber'),
|
||||
'InvoiceDate': self.fields.get('InvoiceDate'),
|
||||
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
|
||||
'OCR': self.fields.get('OCR'),
|
||||
'Bankgiro': self.fields.get('Bankgiro'),
|
||||
'Plusgiro': self.fields.get('Plusgiro'),
|
||||
'Amount': self.fields.get('Amount'),
|
||||
'supplier_org_number': self.fields.get('supplier_org_number'),
|
||||
'customer_number': self.fields.get('customer_number'),
|
||||
'payment_line': self.fields.get('payment_line'),
|
||||
'confidence': self.confidence,
|
||||
'success': self.success,
|
||||
'fallback_used': self.fallback_used
|
||||
}
|
||||
# Add bboxes if present
|
||||
if self.bboxes:
|
||||
result['bboxes'] = {k: list(v) for k, v in self.bboxes.items()}
|
||||
# Add cross-validation results if present
|
||||
if self.cross_validation:
|
||||
result['cross_validation'] = {
|
||||
'is_valid': self.cross_validation.is_valid,
|
||||
'ocr_match': self.cross_validation.ocr_match,
|
||||
'amount_match': self.cross_validation.amount_match,
|
||||
'bankgiro_match': self.cross_validation.bankgiro_match,
|
||||
'plusgiro_match': self.cross_validation.plusgiro_match,
|
||||
'payment_line_ocr': self.cross_validation.payment_line_ocr,
|
||||
'payment_line_amount': self.cross_validation.payment_line_amount,
|
||||
'payment_line_account': self.cross_validation.payment_line_account,
|
||||
'payment_line_account_type': self.cross_validation.payment_line_account_type,
|
||||
'details': self.cross_validation.details,
|
||||
}
|
||||
return result
|
||||
|
||||
def get_field(self, field_name: str) -> tuple[Any, float]:
|
||||
"""Get field value and confidence."""
|
||||
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
|
||||
|
||||
|
||||
class InferencePipeline:
|
||||
"""
|
||||
Complete inference pipeline for invoice data extraction.
|
||||
|
||||
Pipeline flow:
|
||||
1. PDF -> Image rendering
|
||||
2. YOLO detection of field regions
|
||||
3. OCR extraction from detected regions
|
||||
4. Field normalization and validation
|
||||
5. Fallback to full-page OCR if YOLO fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str | Path,
|
||||
confidence_threshold: float = 0.5,
|
||||
ocr_lang: str = 'en',
|
||||
use_gpu: bool = False,
|
||||
dpi: int = 300,
|
||||
enable_fallback: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize inference pipeline.
|
||||
|
||||
Args:
|
||||
model_path: Path to trained YOLO model
|
||||
confidence_threshold: Detection confidence threshold
|
||||
ocr_lang: Language for OCR
|
||||
use_gpu: Whether to use GPU
|
||||
dpi: Resolution for PDF rendering
|
||||
enable_fallback: Enable fallback to full-page OCR
|
||||
"""
|
||||
self.detector = YOLODetector(
|
||||
model_path,
|
||||
confidence_threshold=confidence_threshold,
|
||||
device='cuda' if use_gpu else 'cpu'
|
||||
)
|
||||
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
|
||||
self.payment_line_parser = PaymentLineParser()
|
||||
self.dpi = dpi
|
||||
self.enable_fallback = enable_fallback
|
||||
|
||||
def process_pdf(
|
||||
self,
|
||||
pdf_path: str | Path,
|
||||
document_id: str | None = None
|
||||
) -> InferenceResult:
|
||||
"""
|
||||
Process a PDF and extract invoice fields.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
document_id: Optional document ID
|
||||
|
||||
Returns:
|
||||
InferenceResult with extracted fields
|
||||
"""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
result = InferenceResult(
|
||||
document_id=document_id or Path(pdf_path).stem
|
||||
)
|
||||
|
||||
try:
|
||||
all_detections = []
|
||||
all_extracted = []
|
||||
|
||||
# Process each page
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||
# Convert to numpy array
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
# Run YOLO detection
|
||||
detections = self.detector.detect(image_array, page_no=page_no)
|
||||
all_detections.extend(detections)
|
||||
|
||||
# Extract fields from detections
|
||||
for detection in detections:
|
||||
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||
all_extracted.append(extracted)
|
||||
|
||||
result.raw_detections = all_detections
|
||||
result.extracted_fields = all_extracted
|
||||
|
||||
# Merge extracted fields (prefer highest confidence)
|
||||
self._merge_fields(result)
|
||||
|
||||
# Fallback if key fields are missing
|
||||
if self.enable_fallback and self._needs_fallback(result):
|
||||
self._run_fallback(pdf_path, result)
|
||||
|
||||
result.success = len(result.fields) > 0
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _merge_fields(self, result: InferenceResult) -> None:
|
||||
"""Merge extracted fields, keeping highest confidence for each field."""
|
||||
field_candidates: dict[str, list[ExtractedField]] = {}
|
||||
|
||||
for extracted in result.extracted_fields:
|
||||
if not extracted.is_valid or not extracted.normalized_value:
|
||||
continue
|
||||
|
||||
if extracted.field_name not in field_candidates:
|
||||
field_candidates[extracted.field_name] = []
|
||||
field_candidates[extracted.field_name].append(extracted)
|
||||
|
||||
# Select best candidate for each field
|
||||
for field_name, candidates in field_candidates.items():
|
||||
best = max(candidates, key=lambda x: x.confidence)
|
||||
result.fields[field_name] = best.normalized_value
|
||||
result.confidence[field_name] = best.confidence
|
||||
# Store bbox for each field (useful for payment_line and other fields)
|
||||
result.bboxes[field_name] = best.bbox
|
||||
|
||||
# Perform cross-validation if payment_line is detected
|
||||
self._cross_validate_payment_line(result)
|
||||
|
||||
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
|
||||
"""
|
||||
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
|
||||
|
||||
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
|
||||
|
||||
Returns: (ocr, amount, account) tuple
|
||||
"""
|
||||
parsed = self.payment_line_parser.parse(payment_line)
|
||||
|
||||
if not parsed.is_valid:
|
||||
return None, None, None
|
||||
|
||||
return parsed.ocr_number, parsed.amount, parsed.account_number
|
||||
|
||||
def _cross_validate_payment_line(self, result: InferenceResult) -> None:
|
||||
"""
|
||||
Cross-validate payment_line data against other detected fields.
|
||||
Payment line values take PRIORITY over individually detected fields.
|
||||
|
||||
Swedish payment line (Betalningsrad) contains:
|
||||
- OCR reference number
|
||||
- Amount (kronor and öre)
|
||||
- Bankgiro or Plusgiro account number
|
||||
|
||||
This method:
|
||||
1. Parses payment_line to extract OCR, Amount, Account
|
||||
2. Compares with separately detected fields for validation
|
||||
3. OVERWRITES detected fields with payment_line values (payment_line is authoritative)
|
||||
"""
|
||||
payment_line = result.fields.get('payment_line')
|
||||
if not payment_line:
|
||||
return
|
||||
|
||||
cv = CrossValidationResult()
|
||||
cv.details = []
|
||||
|
||||
# Parse machine-readable payment line format
|
||||
ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
|
||||
|
||||
cv.payment_line_ocr = ocr
|
||||
cv.payment_line_amount = amount
|
||||
|
||||
# Determine account type based on digit count
|
||||
if account:
|
||||
# Bankgiro: 7-8 digits, Plusgiro: typically fewer
|
||||
if len(account) >= 7:
|
||||
cv.payment_line_account_type = 'bankgiro'
|
||||
# Format: XXX-XXXX or XXXX-XXXX
|
||||
if len(account) == 7:
|
||||
cv.payment_line_account = f"{account[:3]}-{account[3:]}"
|
||||
else:
|
||||
cv.payment_line_account = f"{account[:4]}-{account[4:]}"
|
||||
else:
|
||||
cv.payment_line_account_type = 'plusgiro'
|
||||
# Format: XXXXXXX-X
|
||||
cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
|
||||
|
||||
# Cross-validate and OVERRIDE with payment_line values
|
||||
|
||||
# OCR: payment_line takes priority
|
||||
detected_ocr = result.fields.get('OCR')
|
||||
if cv.payment_line_ocr:
|
||||
pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
|
||||
if detected_ocr:
|
||||
detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
|
||||
cv.ocr_match = pl_ocr_digits == detected_ocr_digits
|
||||
if cv.ocr_match:
|
||||
cv.details.append(f"OCR match: {cv.payment_line_ocr}")
|
||||
else:
|
||||
cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
|
||||
else:
|
||||
cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
|
||||
# OVERRIDE: use payment_line OCR
|
||||
result.fields['OCR'] = cv.payment_line_ocr
|
||||
result.confidence['OCR'] = 0.95 # High confidence for payment_line
|
||||
|
||||
# Amount: payment_line takes priority
|
||||
detected_amount = result.fields.get('Amount')
|
||||
if cv.payment_line_amount:
|
||||
if detected_amount:
|
||||
pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
|
||||
det_amount = self._normalize_amount_for_compare(str(detected_amount))
|
||||
cv.amount_match = pl_amount == det_amount
|
||||
if cv.amount_match:
|
||||
cv.details.append(f"Amount match: {cv.payment_line_amount}")
|
||||
else:
|
||||
cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
|
||||
else:
|
||||
cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
|
||||
# OVERRIDE: use payment_line Amount
|
||||
result.fields['Amount'] = cv.payment_line_amount
|
||||
result.confidence['Amount'] = 0.95
|
||||
|
||||
# Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
|
||||
detected_bankgiro = result.fields.get('Bankgiro')
|
||||
if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
|
||||
pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
|
||||
if detected_bankgiro:
|
||||
det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
|
||||
cv.bankgiro_match = pl_bg_digits == det_bg_digits
|
||||
if cv.bankgiro_match:
|
||||
cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
|
||||
else:
|
||||
cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
|
||||
# Do NOT override - keep detected value
|
||||
|
||||
# Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
|
||||
detected_plusgiro = result.fields.get('Plusgiro')
|
||||
if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
|
||||
pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
|
||||
if detected_plusgiro:
|
||||
det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
|
||||
cv.plusgiro_match = pl_pg_digits == det_pg_digits
|
||||
if cv.plusgiro_match:
|
||||
cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
|
||||
else:
|
||||
cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
|
||||
# Do NOT override - keep detected value
|
||||
|
||||
# Determine overall validity
|
||||
# Note: payment_line only contains ONE account (either BG or PG), so when invoice
|
||||
# has both accounts, the other one cannot be matched - this is expected and OK.
|
||||
# Only count the account type that payment_line actually has.
|
||||
matches = [cv.ocr_match, cv.amount_match]
|
||||
|
||||
# Only include account match if payment_line has that account type
|
||||
if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
|
||||
matches.append(cv.bankgiro_match)
|
||||
elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
|
||||
matches.append(cv.plusgiro_match)
|
||||
|
||||
valid_matches = [m for m in matches if m is not None]
|
||||
if valid_matches:
|
||||
match_count = sum(1 for m in valid_matches if m)
|
||||
cv.is_valid = match_count >= min(2, len(valid_matches))
|
||||
cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
|
||||
else:
|
||||
# No comparison possible
|
||||
cv.is_valid = True
|
||||
cv.details.append("No comparison available from payment_line")
|
||||
|
||||
result.cross_validation = cv
|
||||
|
||||
def _normalize_amount_for_compare(self, amount: str) -> float | None:
|
||||
"""Normalize amount string to float for comparison."""
|
||||
try:
|
||||
# Remove spaces, convert comma to dot
|
||||
cleaned = amount.replace(' ', '').replace(',', '.')
|
||||
# Handle Swedish format with space as thousands separator
|
||||
cleaned = re.sub(r'(\d)\s+(\d)', r'\1\2', cleaned)
|
||||
return round(float(cleaned), 2)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
def _needs_fallback(self, result: InferenceResult) -> bool:
|
||||
"""Check if fallback OCR is needed."""
|
||||
# Check for key fields
|
||||
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
|
||||
missing = sum(1 for f in key_fields if f not in result.fields)
|
||||
return missing >= 2 # Fallback if 2+ key fields missing
|
||||
|
||||
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
|
||||
"""Run full-page OCR fallback."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from shared.ocr import OCREngine
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
|
||||
result.fallback_used = True
|
||||
ocr_engine = OCREngine()
|
||||
|
||||
try:
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
# Full page OCR
|
||||
tokens = ocr_engine.extract_from_image(image_array, page_no)
|
||||
full_text = ' '.join(t.text for t in tokens)
|
||||
|
||||
# Try to extract missing fields with regex patterns
|
||||
self._extract_with_patterns(full_text, result)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Fallback OCR error: {e}")
|
||||
|
||||
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
|
||||
"""Extract fields using regex patterns (fallback)."""
|
||||
patterns = {
|
||||
'Amount': [
|
||||
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
|
||||
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
|
||||
],
|
||||
'Bankgiro': [
|
||||
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
|
||||
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
|
||||
],
|
||||
'OCR': [
|
||||
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
|
||||
],
|
||||
'InvoiceNumber': [
|
||||
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
|
||||
],
|
||||
}
|
||||
|
||||
for field_name, field_patterns in patterns.items():
|
||||
if field_name in result.fields:
|
||||
continue
|
||||
|
||||
for pattern in field_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
value = match.group(1).strip()
|
||||
|
||||
# Normalize the value
|
||||
if field_name == 'Amount':
|
||||
value = value.replace(' ', '').replace(',', '.')
|
||||
try:
|
||||
value = f"{float(value):.2f}"
|
||||
except ValueError:
|
||||
continue
|
||||
elif field_name == 'Bankgiro':
|
||||
digits = re.sub(r'\D', '', value)
|
||||
if len(digits) == 8:
|
||||
value = f"{digits[:4]}-{digits[4:]}"
|
||||
|
||||
result.fields[field_name] = value
|
||||
result.confidence[field_name] = 0.5 # Lower confidence for regex
|
||||
break
|
||||
|
||||
def process_image(
|
||||
self,
|
||||
image_path: str | Path,
|
||||
document_id: str | None = None
|
||||
) -> InferenceResult:
|
||||
"""
|
||||
Process a single image (for pre-rendered pages).
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
document_id: Optional document ID
|
||||
|
||||
Returns:
|
||||
InferenceResult with extracted fields
|
||||
"""
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
result = InferenceResult(
|
||||
document_id=document_id or Path(image_path).stem
|
||||
)
|
||||
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
image_array = np.array(image)
|
||||
|
||||
# Run detection
|
||||
detections = self.detector.detect(image_array, page_no=0)
|
||||
result.raw_detections = detections
|
||||
|
||||
# Extract fields
|
||||
for detection in detections:
|
||||
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||
result.extracted_fields.append(extracted)
|
||||
|
||||
# Merge fields
|
||||
self._merge_fields(result)
|
||||
result.success = len(result.fields) > 0
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
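A hedged end-to-end sketch of running the pipeline above. The model path and PDF path are placeholders, and the settings simply mirror the documented defaults; the import path assumes packages/backend is on PYTHONPATH.

from backend.pipeline.pipeline import InferencePipeline  # path assumed

# Placeholder paths; substitute a real trained YOLO model and an invoice PDF.
pipeline = InferencePipeline(
    model_path="models/invoice_fields.pt",
    confidence_threshold=0.5,
    use_gpu=False,
    dpi=300,
)

result = pipeline.process_pdf("invoices/example.pdf", document_id="example")
print(result.success, result.fallback_used, f"{result.processing_time_ms:.0f} ms")
print(result.get_field("Amount"))  # (value, confidence)
if result.cross_validation:
    print(result.to_json()["cross_validation"])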
188
packages/backend/backend/pipeline/yolo_detector.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
YOLO Detection Module
|
||||
|
||||
Runs YOLO model inference for field detection.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
|
||||
# Import field mappings from single source of truth
|
||||
from shared.fields import CLASS_NAMES, CLASS_TO_FIELD
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
"""Represents a single YOLO detection."""
|
||||
class_id: int
|
||||
class_name: str
|
||||
confidence: float
|
||||
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels
|
||||
page_no: int = 0
|
||||
|
||||
@property
|
||||
def x0(self) -> float:
|
||||
return self.bbox[0]
|
||||
|
||||
@property
|
||||
def y0(self) -> float:
|
||||
return self.bbox[1]
|
||||
|
||||
@property
|
||||
def x1(self) -> float:
|
||||
return self.bbox[2]
|
||||
|
||||
@property
|
||||
def y1(self) -> float:
|
||||
return self.bbox[3]
|
||||
|
||||
@property
|
||||
def center(self) -> tuple[float, float]:
|
||||
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||
|
||||
@property
|
||||
def width(self) -> float:
|
||||
return self.x1 - self.x0
|
||||
|
||||
@property
|
||||
def height(self) -> float:
|
||||
return self.y1 - self.y0
|
||||
|
||||
def get_padded_bbox(
|
||||
self,
|
||||
padding: float = 0.1,
|
||||
image_width: float | None = None,
|
||||
image_height: float | None = None
|
||||
) -> tuple[float, float, float, float]:
|
||||
"""Get bbox with padding for OCR extraction."""
|
||||
pad_x = self.width * padding
|
||||
pad_y = self.height * padding
|
||||
|
||||
x0 = self.x0 - pad_x
|
||||
y0 = self.y0 - pad_y
|
||||
x1 = self.x1 + pad_x
|
||||
y1 = self.y1 + pad_y
|
||||
|
||||
if image_width:
|
||||
x0 = max(0, x0)
|
||||
x1 = min(image_width, x1)
|
||||
if image_height:
|
||||
y0 = max(0, y0)
|
||||
y1 = min(image_height, y1)
|
||||
|
||||
return (x0, y0, x1, y1)
|
||||
|
||||
|
||||
# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields
|
||||
# This ensures consistency with the trained YOLO model
|
||||
|
||||
|
||||
class YOLODetector:
|
||||
"""YOLO model wrapper for field detection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str | Path,
|
||||
confidence_threshold: float = 0.5,
|
||||
iou_threshold: float = 0.45,
|
||||
device: str = 'auto'
|
||||
):
|
||||
"""
|
||||
Initialize YOLO detector.
|
||||
|
||||
Args:
|
||||
model_path: Path to trained YOLO model (.pt file)
|
||||
confidence_threshold: Minimum confidence for detections
|
||||
iou_threshold: IOU threshold for NMS
|
||||
device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
|
||||
"""
|
||||
from ultralytics import YOLO
|
||||
|
||||
self.model = YOLO(model_path)
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
self.device = device
|
||||
|
||||
def detect(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0
|
||||
) -> list[Detection]:
|
||||
"""
|
||||
Run detection on an image.
|
||||
|
||||
Args:
|
||||
image: Image path or numpy array
|
||||
page_no: Page number for reference
|
||||
|
||||
Returns:
|
||||
List of Detection objects
|
||||
"""
|
||||
results = self.model.predict(
|
||||
source=image,
|
||||
conf=self.confidence_threshold,
|
||||
iou=self.iou_threshold,
|
||||
device=self.device,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
detections = []
|
||||
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
if boxes is None:
|
||||
continue
|
||||
|
||||
for i in range(len(boxes)):
|
||||
class_id = int(boxes.cls[i])
|
||||
confidence = float(boxes.conf[i])
|
||||
bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1]
|
||||
|
||||
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"
|
||||
|
||||
detections.append(Detection(
|
||||
class_id=class_id,
|
||||
class_name=class_name,
|
||||
confidence=confidence,
|
||||
bbox=tuple(bbox),
|
||||
page_no=page_no
|
||||
))
|
||||
|
||||
return detections
|
||||
|
||||
def detect_pdf(
|
||||
self,
|
||||
pdf_path: str | Path,
|
||||
dpi: int = 300
|
||||
) -> dict[int, list[Detection]]:
|
||||
"""
|
||||
Run detection on all pages of a PDF.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
dpi: Resolution for rendering
|
||||
|
||||
Returns:
|
||||
Dict mapping page number to list of detections
|
||||
"""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
results = {}
|
||||
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
|
||||
# Convert bytes to numpy array
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
detections = self.detect(image_array, page_no=page_no)
|
||||
results[page_no] = detections
|
||||
|
||||
return results
|
||||
|
||||
def get_field_name(self, class_name: str) -> str:
|
||||
"""Convert class name to field name."""
|
||||
return CLASS_TO_FIELD.get(class_name, class_name)
|
||||
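Hedged sketch of using YOLODetector directly on a PDF; the model and PDF paths are placeholders and the import path is assumed.

from backend.pipeline.yolo_detector import YOLODetector  # path assumed

detector = YOLODetector("models/invoice_fields.pt", confidence_threshold=0.5, device="cpu")

# Detect on every rendered page and map class names to field names
for page_no, detections in detector.detect_pdf("invoices/example.pdf", dpi=300).items():
    for det in detections:
        print(page_no, detector.get_field_name(det.class_name), f"{det.confidence:.2f}", det.bbox)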
7
packages/backend/backend/validation/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Cross-validation module for verifying field extraction using LLM.
|
||||
"""
|
||||
|
||||
from .llm_validator import LLMValidator
|
||||
|
||||
__all__ = ['LLMValidator']
|
||||
748
packages/backend/backend/validation/llm_validator.py
Normal file
@@ -0,0 +1,748 @@
"""
LLM-based cross-validation for invoice field extraction.

Uses a vision LLM to extract fields from invoice PDFs and compare with
the autolabel results to identify potential errors.
"""

import json
import base64
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, asdict
from datetime import datetime

import psycopg2
from psycopg2.extras import execute_values

from shared.config import DEFAULT_DPI


@dataclass
class LLMExtractionResult:
    """Result of LLM field extraction."""
    document_id: str
    invoice_number: Optional[str] = None
    invoice_date: Optional[str] = None
    invoice_due_date: Optional[str] = None
    ocr_number: Optional[str] = None
    bankgiro: Optional[str] = None
    plusgiro: Optional[str] = None
    amount: Optional[str] = None
    supplier_organisation_number: Optional[str] = None
    raw_response: Optional[str] = None
    model_used: Optional[str] = None
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


class LLMValidator:
    """
    Cross-validates invoice field extraction using LLM.

    Queries documents with failed field matches from the database,
    sends the PDF images to an LLM for extraction, and stores
    the results for comparison.
    """

    # Fields to extract (excluding customer_number as requested)
    FIELDS_TO_EXTRACT = [
        'InvoiceNumber',
        'InvoiceDate',
        'InvoiceDueDate',
        'OCR',
        'Bankgiro',
        'Plusgiro',
        'Amount',
        'supplier_organisation_number',
    ]

    EXTRACTION_PROMPT = """You are an expert at extracting structured data from Swedish invoices.

Analyze this invoice image and extract the following fields. Return ONLY a valid JSON object with these exact keys:

{
    "invoice_number": "the invoice number/fakturanummer",
    "invoice_date": "the invoice date in YYYY-MM-DD format",
    "invoice_due_date": "the due date/förfallodatum in YYYY-MM-DD format",
    "ocr_number": "the OCR payment reference number",
    "bankgiro": "the bankgiro number (format: XXXX-XXXX or XXXXXXXX)",
    "plusgiro": "the plusgiro number",
    "amount": "the total amount to pay (just the number, e.g., 1234.56)",
    "supplier_organisation_number": "the supplier's organisation number (format: XXXXXX-XXXX)"
}

Rules:
- If a field is not found or not visible, use null
- For dates, convert Swedish month names (januari, februari, etc.) to YYYY-MM-DD
- For amounts, extract just the numeric value without currency symbols
- The OCR number is typically a long number used for payment reference
- Look for "Att betala" or "Summa att betala" for the amount
- Organisation number is 10 digits, often shown as XXXXXX-XXXX

Return ONLY the JSON object, no other text."""

    def __init__(self, connection_string: str = None, pdf_dir: str = None):
        """
        Initialize the validator.

        Args:
            connection_string: PostgreSQL connection string
            pdf_dir: Directory containing PDF files
        """
        import sys
        sys.path.insert(0, str(Path(__file__).parent.parent.parent))
        from config import get_db_connection_string, PATHS

        self.connection_string = connection_string or get_db_connection_string()
        self.pdf_dir = Path(pdf_dir or PATHS['pdf_dir'])
        self.conn = None

    def connect(self):
        """Connect to database."""
        if self.conn is None:
            self.conn = psycopg2.connect(self.connection_string)
        return self.conn

    def close(self):
        """Close database connection."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def create_validation_table(self):
        """Create the llm_validations table if it doesn't exist."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS llm_validations (
                    id SERIAL PRIMARY KEY,
                    document_id TEXT NOT NULL,
                    -- Extracted fields
                    invoice_number TEXT,
                    invoice_date TEXT,
                    invoice_due_date TEXT,
                    ocr_number TEXT,
                    bankgiro TEXT,
                    plusgiro TEXT,
                    amount TEXT,
                    supplier_organisation_number TEXT,
                    -- Metadata
                    raw_response TEXT,
                    model_used TEXT,
                    processing_time_ms REAL,
                    error TEXT,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    -- Comparison results (populated later)
                    comparison_results JSONB,
                    UNIQUE(document_id)
                );

                CREATE INDEX IF NOT EXISTS idx_llm_validations_document_id
                    ON llm_validations(document_id);
            """)
        conn.commit()

def get_documents_with_failed_matches(
|
||||
self,
|
||||
exclude_customer_number: bool = True,
|
||||
limit: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get documents that have at least one failed field match.
|
||||
|
||||
Args:
|
||||
exclude_customer_number: If True, ignore customer_number failures
|
||||
limit: Maximum number of documents to return
|
||||
|
||||
Returns:
|
||||
List of document info with failed fields
|
||||
"""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
# Find documents with failed matches (excluding customer_number if requested)
|
||||
exclude_clause = ""
|
||||
if exclude_customer_number:
|
||||
exclude_clause = "AND fr.field_name != 'customer_number'"
|
||||
|
||||
query = f"""
|
||||
SELECT DISTINCT d.document_id, d.pdf_path, d.pdf_type,
|
||||
d.supplier_name, d.split
|
||||
FROM documents d
|
||||
JOIN field_results fr ON d.document_id = fr.document_id
|
||||
WHERE fr.matched = false
|
||||
AND fr.field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause}
|
||||
AND d.document_id NOT IN (
|
||||
SELECT document_id FROM llm_validations WHERE error IS NULL
|
||||
)
|
||||
ORDER BY d.document_id
|
||||
"""
|
||||
if limit:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
cursor.execute(query)
|
||||
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
doc_id = row[0]
|
||||
|
||||
# Get failed fields for this document
|
||||
exclude_clause_inner = ""
|
||||
if exclude_customer_number:
|
||||
exclude_clause_inner = "AND field_name != 'customer_number'"
|
||||
cursor.execute(f"""
|
||||
SELECT field_name, csv_value, score
|
||||
FROM field_results
|
||||
WHERE document_id = %s
|
||||
AND matched = false
|
||||
AND field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause_inner}
|
||||
""", (doc_id,))
|
||||
|
||||
failed_fields = [
|
||||
{'field': r[0], 'csv_value': r[1], 'score': r[2]}
|
||||
for r in cursor.fetchall()
|
||||
]
|
||||
|
||||
results.append({
|
||||
'document_id': doc_id,
|
||||
'pdf_path': row[1],
|
||||
'pdf_type': row[2],
|
||||
'supplier_name': row[3],
|
||||
'split': row[4],
|
||||
'failed_fields': failed_fields,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
def get_failed_match_stats(self, exclude_customer_number: bool = True) -> Dict[str, Any]:
|
||||
"""Get statistics about failed matches."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
exclude_clause = ""
|
||||
if exclude_customer_number:
|
||||
exclude_clause = "AND field_name != 'customer_number'"
|
||||
|
||||
# Count by field
|
||||
cursor.execute(f"""
|
||||
SELECT field_name, COUNT(*) as cnt
|
||||
FROM field_results
|
||||
WHERE matched = false
|
||||
AND field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause}
|
||||
GROUP BY field_name
|
||||
ORDER BY cnt DESC
|
||||
""")
|
||||
by_field = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# Count documents with failures
|
||||
cursor.execute(f"""
|
||||
SELECT COUNT(DISTINCT document_id)
|
||||
FROM field_results
|
||||
WHERE matched = false
|
||||
AND field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause}
|
||||
""")
|
||||
doc_count = cursor.fetchone()[0]
|
||||
|
||||
# Already validated count
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM llm_validations WHERE error IS NULL
|
||||
""")
|
||||
validated_count = cursor.fetchone()[0]
|
||||
|
||||
return {
|
||||
'documents_with_failures': doc_count,
|
||||
'already_validated': validated_count,
|
||||
'remaining': doc_count - validated_count,
|
||||
'failures_by_field': by_field,
|
||||
}
|
||||
|
||||
def render_pdf_to_image(
|
||||
self,
|
||||
pdf_path: Path,
|
||||
page_no: int = 0,
|
||||
dpi: int = DEFAULT_DPI,
|
||||
max_size_mb: float = 18.0
|
||||
) -> bytes:
|
||||
"""
|
||||
Render a PDF page to PNG image bytes.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
page_no: Page number to render (0-indexed)
|
||||
dpi: Resolution for rendering
|
||||
max_size_mb: Maximum image size in MB (Azure OpenAI limit is 20MB)
|
||||
|
||||
Returns:
|
||||
PNG image bytes
|
||||
"""
|
||||
import fitz # PyMuPDF
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
doc = fitz.open(pdf_path)
|
||||
page = doc[page_no]
|
||||
|
||||
# Try different DPI values until we get a small enough image
|
||||
dpi_values = [dpi, 120, 100, 72, 50]
|
||||
|
||||
for current_dpi in dpi_values:
|
||||
mat = fitz.Matrix(current_dpi / 72, current_dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
png_bytes = pix.tobytes("png")
|
||||
|
||||
size_mb = len(png_bytes) / (1024 * 1024)
|
||||
if size_mb <= max_size_mb:
|
||||
doc.close()
|
||||
return png_bytes
|
||||
|
||||
# If still too large, use JPEG compression
|
||||
mat = fitz.Matrix(72 / 72, 72 / 72) # Lowest DPI
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# Convert to PIL Image and compress as JPEG
|
||||
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
|
||||
# Try different JPEG quality levels
|
||||
for quality in [85, 70, 50, 30]:
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="JPEG", quality=quality)
|
||||
jpeg_bytes = buffer.getvalue()
|
||||
|
||||
size_mb = len(jpeg_bytes) / (1024 * 1024)
|
||||
if size_mb <= max_size_mb:
|
||||
doc.close()
|
||||
return jpeg_bytes
|
||||
|
||||
doc.close()
|
||||
# Return whatever we have, let the API handle the error
|
||||
return jpeg_bytes
|
||||
|
||||
def extract_with_openai(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
model: str = "gpt-4o"
|
||||
) -> LLMExtractionResult:
|
||||
"""
|
||||
Extract fields using OpenAI's vision API (supports Azure OpenAI).
|
||||
|
||||
Args:
|
||||
image_bytes: PNG image bytes
|
||||
model: Model to use (gpt-4o, gpt-4o-mini, etc.)
|
||||
|
||||
Returns:
|
||||
Extraction result
|
||||
"""
|
||||
import openai
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Encode image to base64 and detect format
|
||||
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
# Detect image format (PNG starts with \x89PNG, JPEG with \xFF\xD8)
|
||||
if image_bytes[:4] == b'\x89PNG':
|
||||
media_type = "image/png"
|
||||
else:
|
||||
media_type = "image/jpeg"
|
||||
|
||||
# Check for Azure OpenAI configuration
|
||||
azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
|
||||
azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
|
||||
azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', model)
|
||||
|
||||
if azure_endpoint and azure_api_key:
|
||||
# Use Azure OpenAI
|
||||
client = openai.AzureOpenAI(
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_key=azure_api_key,
|
||||
api_version="2024-02-15-preview"
|
||||
)
|
||||
model = azure_deployment # Use deployment name for Azure
|
||||
else:
|
||||
# Use standard OpenAI
|
||||
client = openai.OpenAI()
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.EXTRACTION_PROMPT},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{media_type};base64,{image_b64}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
max_tokens=1000,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
raw_response = response.choices[0].message.content
|
||||
processing_time = (time.time() - start_time) * 1000
|
||||
|
||||
# Parse JSON response
|
||||
# Try to extract JSON from response (may have markdown code blocks)
|
||||
json_str = raw_response
|
||||
if "```json" in json_str:
|
||||
json_str = json_str.split("```json")[1].split("```")[0]
|
||||
elif "```" in json_str:
|
||||
json_str = json_str.split("```")[1].split("```")[0]
|
||||
|
||||
data = json.loads(json_str.strip())
|
||||
|
||||
return LLMExtractionResult(
|
||||
document_id="", # Will be set by caller
|
||||
invoice_number=data.get('invoice_number'),
|
||||
invoice_date=data.get('invoice_date'),
|
||||
invoice_due_date=data.get('invoice_due_date'),
|
||||
ocr_number=data.get('ocr_number'),
|
||||
bankgiro=data.get('bankgiro'),
|
||||
plusgiro=data.get('plusgiro'),
|
||||
amount=data.get('amount'),
|
||||
supplier_organisation_number=data.get('supplier_organisation_number'),
|
||||
raw_response=raw_response,
|
||||
model_used=model,
|
||||
processing_time_ms=processing_time,
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return LLMExtractionResult(
|
||||
document_id="",
|
||||
raw_response=raw_response if 'raw_response' in dir() else None,
|
||||
model_used=model,
|
||||
processing_time_ms=(time.time() - start_time) * 1000,
|
||||
error=f"JSON parse error: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMExtractionResult(
|
||||
document_id="",
|
||||
model_used=model,
|
||||
processing_time_ms=(time.time() - start_time) * 1000,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def extract_with_anthropic(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
model: str = "claude-sonnet-4-20250514"
|
||||
) -> LLMExtractionResult:
|
||||
"""
|
||||
Extract fields using Anthropic's Claude API.
|
||||
|
||||
Args:
|
||||
image_bytes: PNG image bytes
|
||||
model: Model to use
|
||||
|
||||
Returns:
|
||||
Extraction result
|
||||
"""
|
||||
import anthropic
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Encode image to base64
|
||||
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
client = anthropic.Anthropic()
|
||||
|
||||
try:
|
||||
response = client.messages.create(
|
||||
model=model,
|
||||
max_tokens=1000,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": image_b64,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.EXTRACTION_PROMPT
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
raw_response = response.content[0].text
|
||||
processing_time = (time.time() - start_time) * 1000
|
||||
|
||||
# Parse JSON response
|
||||
json_str = raw_response
|
||||
if "```json" in json_str:
|
||||
json_str = json_str.split("```json")[1].split("```")[0]
|
||||
elif "```" in json_str:
|
||||
json_str = json_str.split("```")[1].split("```")[0]
|
||||
|
||||
data = json.loads(json_str.strip())
|
||||
|
||||
return LLMExtractionResult(
|
||||
document_id="",
|
||||
invoice_number=data.get('invoice_number'),
|
||||
invoice_date=data.get('invoice_date'),
|
||||
invoice_due_date=data.get('invoice_due_date'),
|
||||
ocr_number=data.get('ocr_number'),
|
||||
bankgiro=data.get('bankgiro'),
|
||||
plusgiro=data.get('plusgiro'),
|
||||
amount=data.get('amount'),
|
||||
supplier_organisation_number=data.get('supplier_organisation_number'),
|
||||
raw_response=raw_response,
|
||||
model_used=model,
|
||||
processing_time_ms=processing_time,
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return LLMExtractionResult(
|
||||
document_id="",
|
||||
raw_response=raw_response if 'raw_response' in dir() else None,
|
||||
model_used=model,
|
||||
processing_time_ms=(time.time() - start_time) * 1000,
|
||||
error=f"JSON parse error: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMExtractionResult(
|
||||
document_id="",
|
||||
model_used=model,
|
||||
processing_time_ms=(time.time() - start_time) * 1000,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def save_validation_result(self, result: LLMExtractionResult):
|
||||
"""Save extraction result to database."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
INSERT INTO llm_validations (
|
||||
document_id, invoice_number, invoice_date, invoice_due_date,
|
||||
ocr_number, bankgiro, plusgiro, amount,
|
||||
supplier_organisation_number, raw_response, model_used,
|
||||
processing_time_ms, error
|
||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
ON CONFLICT (document_id) DO UPDATE SET
|
||||
invoice_number = EXCLUDED.invoice_number,
|
||||
invoice_date = EXCLUDED.invoice_date,
|
||||
invoice_due_date = EXCLUDED.invoice_due_date,
|
||||
ocr_number = EXCLUDED.ocr_number,
|
||||
bankgiro = EXCLUDED.bankgiro,
|
||||
plusgiro = EXCLUDED.plusgiro,
|
||||
amount = EXCLUDED.amount,
|
||||
supplier_organisation_number = EXCLUDED.supplier_organisation_number,
|
||||
raw_response = EXCLUDED.raw_response,
|
||||
model_used = EXCLUDED.model_used,
|
||||
processing_time_ms = EXCLUDED.processing_time_ms,
|
||||
error = EXCLUDED.error,
|
||||
created_at = NOW()
|
||||
""", (
|
||||
result.document_id,
|
||||
result.invoice_number,
|
||||
result.invoice_date,
|
||||
result.invoice_due_date,
|
||||
result.ocr_number,
|
||||
result.bankgiro,
|
||||
result.plusgiro,
|
||||
result.amount,
|
||||
result.supplier_organisation_number,
|
||||
result.raw_response,
|
||||
result.model_used,
|
||||
result.processing_time_ms,
|
||||
result.error,
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
def validate_document(
|
||||
self,
|
||||
doc_id: str,
|
||||
provider: str = "openai",
|
||||
model: str = None
|
||||
) -> LLMExtractionResult:
|
||||
"""
|
||||
Validate a single document using LLM.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID
|
||||
provider: LLM provider ("openai" or "anthropic")
|
||||
model: Model to use (defaults based on provider)
|
||||
|
||||
Returns:
|
||||
Extraction result
|
||||
"""
|
||||
# Get PDF path
|
||||
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
|
||||
if not pdf_path.exists():
|
||||
return LLMExtractionResult(
|
||||
document_id=doc_id,
|
||||
error=f"PDF not found: {pdf_path}"
|
||||
)
|
||||
|
||||
# Render first page
|
||||
try:
|
||||
image_bytes = self.render_pdf_to_image(pdf_path, page_no=0)
|
||||
except Exception as e:
|
||||
return LLMExtractionResult(
|
||||
document_id=doc_id,
|
||||
error=f"Failed to render PDF: {str(e)}"
|
||||
)
|
||||
|
||||
# Extract with LLM
|
||||
if provider == "openai":
|
||||
model = model or "gpt-4o"
|
||||
result = self.extract_with_openai(image_bytes, model)
|
||||
elif provider == "anthropic":
|
||||
model = model or "claude-sonnet-4-20250514"
|
||||
result = self.extract_with_anthropic(image_bytes, model)
|
||||
else:
|
||||
return LLMExtractionResult(
|
||||
document_id=doc_id,
|
||||
error=f"Unknown provider: {provider}"
|
||||
)
|
||||
|
||||
result.document_id = doc_id
|
||||
|
||||
# Save to database
|
||||
self.save_validation_result(result)
|
||||
|
||||
return result
|
||||
|
||||
def validate_batch(
|
||||
self,
|
||||
limit: int = 10,
|
||||
provider: str = "openai",
|
||||
model: str = None,
|
||||
verbose: bool = True
|
||||
) -> List[LLMExtractionResult]:
|
||||
"""
|
||||
Validate a batch of documents with failed matches.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of documents to validate
|
||||
provider: LLM provider
|
||||
model: Model to use
|
||||
verbose: Print progress
|
||||
|
||||
Returns:
|
||||
List of extraction results
|
||||
"""
|
||||
# Get documents to validate
|
||||
docs = self.get_documents_with_failed_matches(limit=limit)
|
||||
|
||||
if verbose:
|
||||
print(f"Found {len(docs)} documents with failed matches to validate")
|
||||
|
||||
results = []
|
||||
for i, doc in enumerate(docs):
|
||||
doc_id = doc['document_id']
|
||||
|
||||
if verbose:
|
||||
failed_fields = [f['field'] for f in doc['failed_fields']]
|
||||
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
|
||||
|
||||
result = self.validate_document(doc_id, provider, model)
|
||||
results.append(result)
|
||||
|
||||
if verbose:
|
||||
if result.error:
|
||||
print(f" ERROR: {result.error}")
|
||||
else:
|
||||
print(f" OK ({result.processing_time_ms:.0f}ms)")
|
||||
|
||||
return results
|
||||
|
||||
def compare_results(self, doc_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Compare LLM extraction with autolabel results.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID
|
||||
|
||||
Returns:
|
||||
Comparison results
|
||||
"""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
# Get autolabel results
|
||||
cursor.execute("""
|
||||
SELECT field_name, csv_value, matched, matched_text
|
||||
FROM field_results
|
||||
WHERE document_id = %s
|
||||
""", (doc_id,))
|
||||
|
||||
autolabel = {}
|
||||
for row in cursor.fetchall():
|
||||
autolabel[row[0]] = {
|
||||
'csv_value': row[1],
|
||||
'matched': row[2],
|
||||
'matched_text': row[3],
|
||||
}
|
||||
|
||||
# Get LLM results
|
||||
cursor.execute("""
|
||||
SELECT invoice_number, invoice_date, invoice_due_date,
|
||||
ocr_number, bankgiro, plusgiro, amount,
|
||||
supplier_organisation_number
|
||||
FROM llm_validations
|
||||
WHERE document_id = %s
|
||||
""", (doc_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return {'error': 'No LLM validation found'}
|
||||
|
||||
llm = {
|
||||
'InvoiceNumber': row[0],
|
||||
'InvoiceDate': row[1],
|
||||
'InvoiceDueDate': row[2],
|
||||
'OCR': row[3],
|
||||
'Bankgiro': row[4],
|
||||
'Plusgiro': row[5],
|
||||
'Amount': row[6],
|
||||
'supplier_organisation_number': row[7],
|
||||
}
|
||||
|
||||
# Compare
|
||||
comparison = {}
|
||||
for field in self.FIELDS_TO_EXTRACT:
|
||||
auto = autolabel.get(field, {})
|
||||
llm_value = llm.get(field)
|
||||
|
||||
comparison[field] = {
|
||||
'csv_value': auto.get('csv_value'),
|
||||
'autolabel_matched': auto.get('matched'),
|
||||
'autolabel_text': auto.get('matched_text'),
|
||||
'llm_value': llm_value,
|
||||
'agreement': self._values_match(auto.get('csv_value'), llm_value),
|
||||
}
|
||||
|
||||
return comparison
|
||||
|
||||
def _values_match(self, csv_value: str, llm_value: str) -> bool:
|
||||
"""Check if CSV value matches LLM extracted value."""
|
||||
if csv_value is None or llm_value is None:
|
||||
return csv_value == llm_value
|
||||
|
||||
# Normalize for comparison
|
||||
csv_norm = str(csv_value).strip().lower().replace('-', '').replace(' ', '')
|
||||
llm_norm = str(llm_value).strip().lower().replace('-', '').replace(' ', '')
|
||||
|
||||
return csv_norm == llm_norm
|
||||
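As a quick orientation for how the class above is meant to be driven, a minimal end-to-end sketch using only methods defined in this file; it assumes the database and LLM credentials (OPENAI_API_KEY, or the AZURE_OPENAI_* variables read in extract_with_openai) are already configured in the environment.

# Driver sketch for LLMValidator, based only on the methods defined above.
from backend.validation import LLMValidator

validator = LLMValidator()
try:
    validator.create_validation_table()

    stats = validator.get_failed_match_stats()
    print(f"{stats['remaining']} documents still need LLM validation")

    # Validate a small batch with the OpenAI-backed extractor.
    results = validator.validate_batch(limit=5, provider="openai")

    for result in results:
        if result.error:
            continue
        comparison = validator.compare_results(result.document_id)
        disagreements = {
            field: info for field, info in comparison.items()
            if info.get("agreement") is False
        }
        print(result.document_id, disagreements)
finally:
    validator.close()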
9
packages/backend/backend/web/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""
Web Application Module

Provides REST API and web interface for invoice field extraction.
"""

from .app import create_app

__all__ = ["create_app"]
8
packages/backend/backend/web/admin_routes_new.py
Normal file
@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py

DEPRECATED: Import from backend.web.api.v1.admin.documents instead.
"""
from backend.web.api.v1.admin.documents import *

__all__ = ["create_admin_router"]
0
packages/backend/backend/web/api/__init__.py
Normal file
0
packages/backend/backend/web/api/v1/__init__.py
Normal file
21
packages/backend/backend/web/api/v1/admin/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""
Admin API v1

Document management, annotations, and training endpoints.
"""

from backend.web.api.v1.admin.annotations import create_annotation_router
from backend.web.api.v1.admin.augmentation import create_augmentation_router
from backend.web.api.v1.admin.auth import create_auth_router
from backend.web.api.v1.admin.documents import create_documents_router
from backend.web.api.v1.admin.locks import create_locks_router
from backend.web.api.v1.admin.training import create_training_router

__all__ = [
    "create_annotation_router",
    "create_augmentation_router",
    "create_auth_router",
    "create_documents_router",
    "create_locks_router",
    "create_training_router",
]
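For orientation, a sketch of how these router factories might be mounted on a FastAPI application. The real wiring presumably lives in backend.web.app.create_app (imported by backend.web but not shown in this excerpt), and the /api/v1 prefix here is an assumption.

# Sketch only: the actual application factory is backend.web.app.create_app (not shown here).
from fastapi import FastAPI

from backend.web.api.v1.admin import (
    create_annotation_router,
    create_auth_router,
    create_documents_router,
    create_training_router,
)

app = FastAPI()
for factory in (
    create_auth_router,
    create_documents_router,
    create_annotation_router,
    create_training_router,
):
    # The "/api/v1" mount prefix is an assumption, not taken from this diff.
    app.include_router(factory(), prefix="/api/v1")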
706
packages/backend/backend/web/api/v1/admin/annotations.py
Normal file
@@ -0,0 +1,706 @@
"""
Admin Annotation API Routes

FastAPI endpoints for annotation management.
"""

import io
import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse

from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
from backend.data.repositories import DocumentRepository, AnnotationRepository
from backend.web.core.auth import AdminTokenDep
from backend.web.services.autolabel import get_auto_label_service
from backend.web.services.storage_helpers import get_storage_helper
from backend.web.schemas.admin import (
    AnnotationCreate,
    AnnotationItem,
    AnnotationListResponse,
    AnnotationOverrideRequest,
    AnnotationOverrideResponse,
    AnnotationResponse,
    AnnotationSource,
    AnnotationUpdate,
    AnnotationVerifyRequest,
    AnnotationVerifyResponse,
    AutoLabelRequest,
    AutoLabelResponse,
    BoundingBox,
)
from backend.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

# Global repository instances
_doc_repo: DocumentRepository | None = None
_ann_repo: AnnotationRepository | None = None


def get_doc_repository() -> DocumentRepository:
    """Get the DocumentRepository instance."""
    global _doc_repo
    if _doc_repo is None:
        _doc_repo = DocumentRepository()
    return _doc_repo


def get_ann_repository() -> AnnotationRepository:
    """Get the AnnotationRepository instance."""
    global _ann_repo
    if _ann_repo is None:
        _ann_repo = AnnotationRepository()
    return _ann_repo


# Type aliases for dependency injection
DocRepoDep = Annotated[DocumentRepository, Depends(get_doc_repository)]
AnnRepoDep = Annotated[AnnotationRepository, Depends(get_ann_repository)]


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
    try:
        UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {name} format. Must be a valid UUID.",
        )

def create_annotation_router() -> APIRouter:
|
||||
"""Create annotation API router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])
|
||||
|
||||
# =========================================================================
|
||||
# Image Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
|
||||
"/{document_id}/images/{page_number}",
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {"content": {"image/png": {}}, "description": "Page image"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Get page image",
|
||||
description="Get the image for a specific page.",
|
||||
)
|
||||
async def get_page_image(
|
||||
document_id: str,
|
||||
page_number: int,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
) -> FileResponse | StreamingResponse:
|
||||
"""Get page image."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Validate page number
|
||||
if page_number < 1 or page_number > document.page_count:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
|
||||
)
|
||||
|
||||
# Get storage helper
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Check if image exists
|
||||
if not storage.admin_image_exists(document_id, page_number):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Image for page {page_number} not found",
|
||||
)
|
||||
|
||||
# Try to get local path for efficient file serving
|
||||
local_path = storage.get_admin_image_local_path(document_id, page_number)
|
||||
if local_path is not None:
|
||||
return FileResponse(
|
||||
path=str(local_path),
|
||||
media_type="image/png",
|
||||
filename=f"{document.filename}_page_{page_number}.png",
|
||||
)
|
||||
|
||||
# Fall back to streaming for cloud storage
|
||||
image_content = storage.get_admin_image(document_id, page_number)
|
||||
return StreamingResponse(
|
||||
io.BytesIO(image_content),
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"'
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Annotation Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
|
||||
"/{document_id}/annotations",
|
||||
response_model=AnnotationListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="List annotations",
|
||||
description="Get all annotations for a document.",
|
||||
)
|
||||
async def list_annotations(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
page_number: Annotated[
|
||||
int | None,
|
||||
Query(ge=1, description="Filter by page number"),
|
||||
] = None,
|
||||
) -> AnnotationListResponse:
|
||||
"""List annotations for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Get annotations
|
||||
raw_annotations = ann_repo.get_for_document(document_id, page_number)
|
||||
annotations = [
|
||||
AnnotationItem(
|
||||
annotation_id=str(ann.annotation_id),
|
||||
page_number=ann.page_number,
|
||||
class_id=ann.class_id,
|
||||
class_name=ann.class_name,
|
||||
bbox=BoundingBox(
|
||||
x=ann.bbox_x,
|
||||
y=ann.bbox_y,
|
||||
width=ann.bbox_width,
|
||||
height=ann.bbox_height,
|
||||
),
|
||||
normalized_bbox={
|
||||
"x_center": ann.x_center,
|
||||
"y_center": ann.y_center,
|
||||
"width": ann.width,
|
||||
"height": ann.height,
|
||||
},
|
||||
text_value=ann.text_value,
|
||||
confidence=ann.confidence,
|
||||
source=AnnotationSource(ann.source),
|
||||
created_at=ann.created_at,
|
||||
)
|
||||
for ann in raw_annotations
|
||||
]
|
||||
|
||||
return AnnotationListResponse(
|
||||
document_id=document_id,
|
||||
page_count=document.page_count,
|
||||
total_annotations=len(annotations),
|
||||
annotations=annotations,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/annotations",
|
||||
response_model=AnnotationResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Create annotation",
|
||||
description="Create a new annotation for a document.",
|
||||
)
|
||||
async def create_annotation(
|
||||
document_id: str,
|
||||
request: AnnotationCreate,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
) -> AnnotationResponse:
|
||||
"""Create a new annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Validate page number
|
||||
if request.page_number > document.page_count:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
|
||||
)
|
||||
|
||||
# Get image dimensions for normalization
|
||||
storage = get_storage_helper()
|
||||
dimensions = storage.get_admin_image_dimensions(document_id, request.page_number)
|
||||
if dimensions is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image for page {request.page_number} not available",
|
||||
)
|
||||
image_width, image_height = dimensions
|
||||
|
||||
# Calculate normalized coordinates
|
||||
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
y_center = (request.bbox.y + request.bbox.height / 2) / image_height
|
||||
width = request.bbox.width / image_width
|
||||
height = request.bbox.height / image_height
|
||||
|
||||
# Get class name
|
||||
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
|
||||
|
||||
# Create annotation
|
||||
annotation_id = ann_repo.create(
|
||||
document_id=document_id,
|
||||
page_number=request.page_number,
|
||||
class_id=request.class_id,
|
||||
class_name=class_name,
|
||||
x_center=x_center,
|
||||
y_center=y_center,
|
||||
width=width,
|
||||
height=height,
|
||||
bbox_x=request.bbox.x,
|
||||
bbox_y=request.bbox.y,
|
||||
bbox_width=request.bbox.width,
|
||||
bbox_height=request.bbox.height,
|
||||
text_value=request.text_value,
|
||||
source="manual",
|
||||
)
|
||||
|
||||
# Keep status as pending - user must click "Mark Complete" to finalize
|
||||
# This allows user to add multiple annotations before saving to PostgreSQL
|
||||
|
||||
return AnnotationResponse(
|
||||
annotation_id=annotation_id,
|
||||
message="Annotation created successfully",
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/annotations/{annotation_id}",
|
||||
response_model=AnnotationResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Update annotation",
|
||||
description="Update an existing annotation.",
|
||||
)
|
||||
async def update_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
request: AnnotationUpdate,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
) -> AnnotationResponse:
|
||||
"""Update an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Get existing annotation
|
||||
annotation = ann_repo.get(annotation_id)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Verify annotation belongs to document
|
||||
if str(annotation.document_id) != document_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation does not belong to this document",
|
||||
)
|
||||
|
||||
# Prepare update data
|
||||
update_kwargs = {}
|
||||
|
||||
if request.class_id is not None:
|
||||
update_kwargs["class_id"] = request.class_id
|
||||
update_kwargs["class_name"] = FIELD_CLASSES.get(
|
||||
request.class_id, f"class_{request.class_id}"
|
||||
)
|
||||
|
||||
if request.text_value is not None:
|
||||
update_kwargs["text_value"] = request.text_value
|
||||
|
||||
if request.bbox is not None:
|
||||
# Get image dimensions
|
||||
storage = get_storage_helper()
|
||||
dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number)
|
||||
if dimensions is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image for page {annotation.page_number} not available",
|
||||
)
|
||||
image_width, image_height = dimensions
|
||||
|
||||
# Calculate normalized coordinates
|
||||
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
|
||||
update_kwargs["width"] = request.bbox.width / image_width
|
||||
update_kwargs["height"] = request.bbox.height / image_height
|
||||
update_kwargs["bbox_x"] = request.bbox.x
|
||||
update_kwargs["bbox_y"] = request.bbox.y
|
||||
update_kwargs["bbox_width"] = request.bbox.width
|
||||
update_kwargs["bbox_height"] = request.bbox.height
|
||||
|
||||
# Update annotation
|
||||
if update_kwargs:
|
||||
success = ann_repo.update(annotation_id, **update_kwargs)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to update annotation",
|
||||
)
|
||||
|
||||
return AnnotationResponse(
|
||||
annotation_id=annotation_id,
|
||||
message="Annotation updated successfully",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/annotations/{annotation_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Delete annotation",
|
||||
description="Delete an annotation.",
|
||||
)
|
||||
async def delete_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
) -> dict:
|
||||
"""Delete an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Get existing annotation
|
||||
annotation = ann_repo.get(annotation_id)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Verify annotation belongs to document
|
||||
if str(annotation.document_id) != document_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation does not belong to this document",
|
||||
)
|
||||
|
||||
# Delete annotation
|
||||
ann_repo.delete(annotation_id)
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"annotation_id": annotation_id,
|
||||
"message": "Annotation deleted successfully",
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Auto-Labeling Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/auto-label",
|
||||
response_model=AutoLabelResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Trigger auto-labeling",
|
||||
description="Trigger auto-labeling for a document using field values.",
|
||||
)
|
||||
async def trigger_auto_label(
|
||||
document_id: str,
|
||||
request: AutoLabelRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
) -> AutoLabelResponse:
|
||||
"""Trigger auto-labeling for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Validate field values
|
||||
if not request.field_values:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="At least one field value is required",
|
||||
)
|
||||
|
||||
# Get the actual file path from storage
|
||||
# document.file_path is a relative storage path like "raw_pdfs/uuid.pdf"
|
||||
storage = get_storage_helper()
|
||||
filename = document.file_path.split("/")[-1] if "/" in document.file_path else document.file_path
|
||||
file_path = storage.get_raw_pdf_local_path(filename)
|
||||
if file_path is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Cannot find PDF file: {document.file_path}",
|
||||
)
|
||||
|
||||
# Run auto-labeling
|
||||
service = get_auto_label_service()
|
||||
result = service.auto_label_document(
|
||||
document_id=document_id,
|
||||
file_path=str(file_path),
|
||||
field_values=request.field_values,
|
||||
doc_repo=doc_repo,
|
||||
ann_repo=ann_repo,
|
||||
replace_existing=request.replace_existing,
|
||||
)
|
||||
|
||||
if result["status"] == "failed":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return AutoLabelResponse(
|
||||
document_id=document_id,
|
||||
status=result["status"],
|
||||
annotations_created=result["annotations_created"],
|
||||
message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/annotations",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Delete all annotations",
|
||||
description="Delete all annotations for a document (optionally filter by source).",
|
||||
)
|
||||
async def delete_all_annotations(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
source: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by source (manual, auto, imported)"),
|
||||
] = None,
|
||||
) -> dict:
|
||||
"""Delete all annotations for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate source
|
||||
if source and source not in ("manual", "auto", "imported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid source: {source}",
|
||||
)
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Delete annotations
|
||||
deleted_count = ann_repo.delete_for_document(document_id, source)
|
||||
|
||||
# Update document status if all annotations deleted
|
||||
remaining = ann_repo.get_for_document(document_id)
|
||||
if not remaining:
|
||||
doc_repo.update_status(document_id, "pending")
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"document_id": document_id,
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"Deleted {deleted_count} annotations",
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Phase 5: Annotation Enhancement
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/annotations/{annotation_id}/verify",
|
||||
response_model=AnnotationVerifyResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Annotation not found"},
|
||||
},
|
||||
summary="Verify annotation",
|
||||
description="Mark an annotation as verified by a human reviewer.",
|
||||
)
|
||||
async def verify_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
|
||||
) -> AnnotationVerifyResponse:
|
||||
"""Verify an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Verify the annotation
|
||||
annotation = ann_repo.verify(annotation_id, admin_token)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
return AnnotationVerifyResponse(
|
||||
annotation_id=annotation_id,
|
||||
is_verified=annotation.is_verified,
|
||||
verified_at=annotation.verified_at,
|
||||
verified_by=annotation.verified_by,
|
||||
message="Annotation verified successfully",
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/annotations/{annotation_id}/override",
|
||||
response_model=AnnotationOverrideResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Annotation not found"},
|
||||
},
|
||||
summary="Override annotation",
|
||||
description="Override an auto-generated annotation with manual corrections.",
|
||||
)
|
||||
async def override_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
request: AnnotationOverrideRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
doc_repo: DocRepoDep,
|
||||
ann_repo: AnnRepoDep,
|
||||
) -> AnnotationOverrideResponse:
|
||||
"""Override an auto-generated annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Get document
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Build updates dict from request
|
||||
updates = {}
|
||||
if request.text_value is not None:
|
||||
updates["text_value"] = request.text_value
|
||||
if request.class_id is not None:
|
||||
updates["class_id"] = request.class_id
|
||||
# Update class_name if class_id changed
|
||||
if request.class_id in FIELD_CLASSES:
|
||||
updates["class_name"] = FIELD_CLASSES[request.class_id]
|
||||
if request.class_name is not None:
|
||||
updates["class_name"] = request.class_name
|
||||
if request.bbox:
|
||||
# Update bbox fields
|
||||
if "x" in request.bbox:
|
||||
updates["bbox_x"] = request.bbox["x"]
|
||||
if "y" in request.bbox:
|
||||
updates["bbox_y"] = request.bbox["y"]
|
||||
if "width" in request.bbox:
|
||||
updates["bbox_width"] = request.bbox["width"]
|
||||
if "height" in request.bbox:
|
||||
updates["bbox_height"] = request.bbox["height"]
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No updates provided. Specify at least one field to update.",
|
||||
)
|
||||
|
||||
# Override the annotation
|
||||
annotation = ann_repo.override(
|
||||
annotation_id=annotation_id,
|
||||
admin_token=admin_token,
|
||||
change_reason=request.reason,
|
||||
**updates,
|
||||
)
|
||||
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Get history to return history_id
|
||||
history_records = ann_repo.get_history(UUID(annotation_id))
|
||||
latest_history = history_records[0] if history_records else None
|
||||
|
||||
return AnnotationOverrideResponse(
|
||||
annotation_id=annotation_id,
|
||||
source=annotation.source,
|
||||
override_source=annotation.override_source,
|
||||
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
|
||||
message="Annotation overridden successfully",
|
||||
history_id=str(latest_history.history_id) if latest_history else "",
|
||||
)
|
||||
|
||||
return router
|
||||
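To illustrate the request shape consumed by the create-annotation endpoint above, a hedged client-side sketch. Only the /admin/documents/{document_id}/annotations route and the AnnotationCreate fields (page_number, class_id, bbox, text_value) come from this file; the host, any extra mount prefix, the auth header, and all concrete values are placeholders.

# Hypothetical client call; base URL, auth scheme, and values are assumptions.
import httpx

BASE_URL = "http://localhost:8000"                     # assumed host/port
HEADERS = {"Authorization": "Bearer <admin-token>"}    # assumed auth header

document_id = "00000000-0000-0000-0000-000000000000"   # placeholder UUID
payload = {
    "page_number": 1,
    "class_id": 0,
    "bbox": {"x": 120.0, "y": 340.0, "width": 180.0, "height": 28.0},
    "text_value": "12345",
}

response = httpx.post(
    f"{BASE_URL}/admin/documents/{document_id}/annotations",
    json=payload,
    headers=HEADERS,
)
response.raise_for_status()
print(response.json())  # expected: {"annotation_id": "...", "message": "..."}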
15
packages/backend/backend/web/api/v1/admin/augmentation/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""Augmentation API module."""

from fastapi import APIRouter

from .routes import register_augmentation_routes


def create_augmentation_router() -> APIRouter:
    """Create and configure the augmentation router."""
    router = APIRouter(prefix="/augmentation", tags=["augmentation"])
    register_augmentation_routes(router)
    return router


__all__ = ["create_augmentation_router"]
160
packages/backend/backend/web/api/v1/admin/augmentation/routes.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Augmentation API routes."""
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from backend.web.core.auth import AdminTokenDep, DocumentRepoDep, DatasetRepoDep
|
||||
from backend.web.schemas.admin.augmentation import (
|
||||
AugmentationBatchRequest,
|
||||
AugmentationBatchResponse,
|
||||
AugmentationConfigSchema,
|
||||
AugmentationPreviewRequest,
|
||||
AugmentationPreviewResponse,
|
||||
AugmentationTypeInfo,
|
||||
AugmentationTypesResponse,
|
||||
AugmentedDatasetListResponse,
|
||||
PresetInfo,
|
||||
PresetsResponse,
|
||||
)
|
||||
|
||||
|
||||
def register_augmentation_routes(router: APIRouter) -> None:
|
||||
"""Register augmentation endpoints on the router."""
|
||||
|
||||
@router.get(
|
||||
"/types",
|
||||
response_model=AugmentationTypesResponse,
|
||||
summary="List available augmentation types",
|
||||
)
|
||||
async def list_augmentation_types(
|
||||
admin_token: AdminTokenDep,
|
||||
) -> AugmentationTypesResponse:
|
||||
"""
|
||||
List all available augmentation types with descriptions and parameters.
|
||||
"""
|
||||
from shared.augmentation.pipeline import (
|
||||
AUGMENTATION_REGISTRY,
|
||||
AugmentationPipeline,
|
||||
)
|
||||
|
||||
types = []
|
||||
for name, aug_class in AUGMENTATION_REGISTRY.items():
|
||||
# Create instance with empty params to get preview params
|
||||
aug = aug_class({})
|
||||
types.append(
|
||||
AugmentationTypeInfo(
|
||||
name=name,
|
||||
description=(aug_class.__doc__ or "").strip(),
|
||||
affects_geometry=aug_class.affects_geometry,
|
||||
stage=AugmentationPipeline.STAGE_MAPPING[name],
|
||||
default_params=aug.get_preview_params(),
|
||||
)
|
||||
)
|
||||
|
||||
return AugmentationTypesResponse(augmentation_types=types)
|
||||
|
||||
@router.get(
|
||||
"/presets",
|
||||
response_model=PresetsResponse,
|
||||
summary="Get augmentation presets",
|
||||
)
|
||||
async def get_presets(
|
||||
admin_token: AdminTokenDep,
|
||||
) -> PresetsResponse:
|
||||
"""Get predefined augmentation presets for common use cases."""
|
||||
from shared.augmentation.presets import list_presets
|
||||
|
||||
presets = [PresetInfo(**p) for p in list_presets()]
|
||||
return PresetsResponse(presets=presets)
|
||||
|
||||
@router.post(
|
||||
"/preview/{document_id}",
|
||||
response_model=AugmentationPreviewResponse,
|
||||
summary="Preview augmentation on document image",
|
||||
)
|
||||
async def preview_augmentation(
|
||||
document_id: str,
|
||||
request: AugmentationPreviewRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
page: int = Query(default=1, ge=1, description="Page number"),
|
||||
) -> AugmentationPreviewResponse:
|
||||
"""
|
||||
Preview a single augmentation on a document page.
|
||||
|
||||
Returns URLs to original and augmented preview images.
|
||||
"""
|
||||
from backend.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(doc_repo=docs)
|
||||
return await service.preview_single(
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
augmentation_type=request.augmentation_type,
|
||||
params=request.params,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/preview-config/{document_id}",
|
||||
response_model=AugmentationPreviewResponse,
|
||||
summary="Preview full augmentation config on document",
|
||||
)
|
||||
async def preview_config(
|
||||
document_id: str,
|
||||
config: AugmentationConfigSchema,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
page: int = Query(default=1, ge=1, description="Page number"),
|
||||
) -> AugmentationPreviewResponse:
|
||||
"""Preview complete augmentation pipeline on a document page."""
|
||||
from backend.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(doc_repo=docs)
|
||||
return await service.preview_config(
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/batch",
|
||||
response_model=AugmentationBatchResponse,
|
||||
summary="Create augmented dataset (offline preprocessing)",
|
||||
)
|
||||
async def create_augmented_dataset(
|
||||
request: AugmentationBatchRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
datasets: DatasetRepoDep,
|
||||
) -> AugmentationBatchResponse:
|
||||
"""
|
||||
Create a new augmented dataset from an existing dataset.
|
||||
|
||||
This runs as a background task. The augmented images are stored
|
||||
alongside the original dataset for training.
|
||||
"""
|
||||
from backend.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(doc_repo=docs, dataset_repo=datasets)
|
||||
return await service.create_augmented_dataset(
|
||||
source_dataset_id=request.dataset_id,
|
||||
config=request.config,
|
||||
output_name=request.output_name,
|
||||
multiplier=request.multiplier,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/datasets",
|
||||
response_model=AugmentedDatasetListResponse,
|
||||
summary="List augmented datasets",
|
||||
)
|
||||
async def list_augmented_datasets(
|
||||
admin_token: AdminTokenDep,
|
||||
datasets: DatasetRepoDep,
|
||||
limit: int = Query(default=20, ge=1, le=100, description="Page size"),
|
||||
offset: int = Query(default=0, ge=0, description="Offset"),
|
||||
) -> AugmentedDatasetListResponse:
|
||||
"""List all augmented datasets."""
|
||||
from backend.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(dataset_repo=datasets)
|
||||
return await service.list_augmented_datasets(limit=limit, offset=offset)
|
||||
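A similar sketch for the preview endpoint registered above. The augmentation type name and params are placeholders (the real names come from shared.augmentation.pipeline.AUGMENTATION_REGISTRY, which is not shown here), and the host, mount prefix, and auth header are assumptions.

# Hypothetical call to POST /augmentation/preview/{document_id}; all values are placeholders.
import httpx

response = httpx.post(
    "http://localhost:8000/augmentation/preview/00000000-0000-0000-0000-000000000000",
    params={"page": 1},
    json={"augmentation_type": "rotation", "params": {"max_degrees": 3}},  # placeholder type/params
    headers={"Authorization": "Bearer <admin-token>"},  # assumed auth scheme
)
print(response.json())  # AugmentationPreviewResponse with preview image URLs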
82
packages/backend/backend/web/api/v1/admin/auth.py
Normal file
@@ -0,0 +1,82 @@
"""
Admin Auth Routes

FastAPI endpoints for admin token management.
"""

import logging
import secrets
from datetime import datetime, timedelta, timezone

from fastapi import APIRouter

from backend.web.core.auth import AdminTokenDep, TokenRepoDep
from backend.web.schemas.admin import (
    AdminTokenCreate,
    AdminTokenResponse,
)
from backend.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def create_auth_router() -> APIRouter:
    """Create admin auth router."""
    router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])

    @router.post(
        "/token",
        response_model=AdminTokenResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
        },
        summary="Create admin token",
        description="Create a new admin authentication token.",
    )
    async def create_token(
        request: AdminTokenCreate,
        tokens: TokenRepoDep,
    ) -> AdminTokenResponse:
        """Create a new admin token."""
        # Generate secure token
        token = secrets.token_urlsafe(32)

        # Calculate expiration (use timezone-aware datetime)
        expires_at = None
        if request.expires_in_days:
            expires_at = datetime.now(timezone.utc) + timedelta(days=request.expires_in_days)

        # Create token in database
        tokens.create(
            token=token,
            name=request.name,
            expires_at=expires_at,
        )

        return AdminTokenResponse(
            token=token,
            name=request.name,
            expires_at=expires_at,
            message="Admin token created successfully",
        )

    @router.delete(
        "/token",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Revoke admin token",
        description="Revoke the current admin token.",
    )
    async def revoke_token(
        admin_token: AdminTokenDep,
        tokens: TokenRepoDep,
    ) -> dict:
        """Revoke the current admin token."""
        tokens.deactivate(admin_token)
        return {
            "status": "revoked",
            "message": "Admin token has been revoked",
        }

    return router
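For completeness, a sketch of creating and later revoking a token through the routes above. The /admin/auth/token path and the name/expires_in_days fields follow this file; the host, mount prefix, and auth header name are assumptions.

# Hypothetical client calls for the admin auth routes; base URL and auth header are assumptions.
import httpx

BASE = "http://localhost:8000"  # assumed

created = httpx.post(
    f"{BASE}/admin/auth/token",
    json={"name": "ci-pipeline", "expires_in_days": 30},
).json()
token = created["token"]

# Revoke the token when it is no longer needed.
httpx.delete(f"{BASE}/admin/auth/token", headers={"Authorization": f"Bearer {token}"})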
135
packages/backend/backend/web/api/v1/admin/dashboard.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Dashboard API Routes
|
||||
|
||||
FastAPI endpoints for dashboard statistics and activity.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from backend.web.core.auth import (
|
||||
AdminTokenDep,
|
||||
get_model_version_repository,
|
||||
get_training_task_repository,
|
||||
ModelVersionRepoDep,
|
||||
TrainingTaskRepoDep,
|
||||
)
|
||||
from backend.web.schemas.admin import (
|
||||
DashboardStatsResponse,
|
||||
DashboardActiveModelResponse,
|
||||
ActiveModelInfo,
|
||||
RunningTrainingInfo,
|
||||
RecentActivityResponse,
|
||||
ActivityItem,
|
||||
)
|
||||
from backend.web.services.dashboard_service import (
|
||||
DashboardStatsService,
|
||||
DashboardActivityService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_dashboard_router() -> APIRouter:
|
||||
"""Create dashboard API router."""
|
||||
router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"])
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
response_model=DashboardStatsResponse,
|
||||
summary="Get dashboard statistics",
|
||||
description="Returns document counts and annotation completeness metrics.",
|
||||
)
|
||||
async def get_dashboard_stats(
|
||||
admin_token: AdminTokenDep,
|
||||
) -> DashboardStatsResponse:
|
||||
"""Get dashboard statistics."""
|
||||
service = DashboardStatsService()
|
||||
stats = service.get_stats()
|
||||
|
||||
return DashboardStatsResponse(
|
||||
total_documents=stats["total_documents"],
|
||||
annotation_complete=stats["annotation_complete"],
|
||||
annotation_incomplete=stats["annotation_incomplete"],
|
||||
pending=stats["pending"],
|
||||
completeness_rate=stats["completeness_rate"],
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/active-model",
|
||||
response_model=DashboardActiveModelResponse,
|
||||
summary="Get active model info",
|
||||
description="Returns current active model and running training status.",
|
||||
)
|
||||
async def get_active_model(
|
||||
admin_token: AdminTokenDep,
|
||||
model_repo: ModelVersionRepoDep,
|
||||
task_repo: TrainingTaskRepoDep,
|
||||
) -> DashboardActiveModelResponse:
|
||||
"""Get active model and training status."""
|
||||
# Get active model
|
||||
active_model = model_repo.get_active()
|
||||
model_info = None
|
||||
|
||||
if active_model:
|
||||
model_info = ActiveModelInfo(
|
||||
version_id=str(active_model.version_id),
|
||||
version=active_model.version,
|
||||
name=active_model.name,
|
||||
metrics_mAP=active_model.metrics_mAP,
|
||||
metrics_precision=active_model.metrics_precision,
|
||||
metrics_recall=active_model.metrics_recall,
|
||||
document_count=active_model.document_count,
|
||||
activated_at=active_model.activated_at,
|
||||
)
|
||||
|
||||
# Get running training task
|
||||
running_task = task_repo.get_running()
|
||||
training_info = None
|
||||
|
||||
if running_task:
|
||||
training_info = RunningTrainingInfo(
|
||||
task_id=str(running_task.task_id),
|
||||
name=running_task.name,
|
||||
status=running_task.status,
|
||||
started_at=running_task.started_at,
|
||||
progress=running_task.progress or 0,
|
||||
)
|
||||
|
||||
return DashboardActiveModelResponse(
|
||||
model=model_info,
|
||||
running_training=training_info,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/activity",
|
||||
response_model=RecentActivityResponse,
|
||||
summary="Get recent activity",
|
||||
description="Returns recent system activities sorted by timestamp.",
|
||||
)
|
||||
async def get_recent_activity(
|
||||
admin_token: AdminTokenDep,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=50, description="Maximum number of activities"),
|
||||
] = 10,
|
||||
) -> RecentActivityResponse:
|
||||
"""Get recent system activity."""
|
||||
service = DashboardActivityService()
|
||||
activities = service.get_recent_activities(limit=limit)
|
||||
|
||||
return RecentActivityResponse(
|
||||
activities=[
|
||||
ActivityItem(
|
||||
type=act["type"],
|
||||
description=act["description"],
|
||||
timestamp=act["timestamp"],
|
||||
metadata=act["metadata"],
|
||||
)
|
||||
for act in activities
|
||||
]
|
||||
)
|
||||
|
||||
return router
|
||||
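The three dashboard endpoints are read-only and pair naturally in a polling loop. A sketch follows, under the same assumed /api/v1 prefix and bearer header as the auth example; the response field names are taken from the schema constructors above, while the value format of completeness_rate is not visible in this diff.

# Sketch only: poll the dashboard endpoints defined above.
import httpx

BASE_URL = "http://localhost:8000/api/v1"            # assumed
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme


def print_dashboard_snapshot() -> None:
    with httpx.Client(base_url=BASE_URL, headers=HEADERS) as client:
        stats = client.get("/admin/dashboard/stats").json()
        active = client.get("/admin/dashboard/active-model").json()
        activity = client.get("/admin/dashboard/activity", params={"limit": 5}).json()

    print("documents:", stats["total_documents"], "completeness:", stats["completeness_rate"])
    if active.get("running_training"):
        print("training in progress:", active["running_training"]["progress"])
    for item in activity["activities"]:
        print(item["timestamp"], item["type"], item["description"])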
699
packages/backend/backend/web/api/v1/admin/documents.py
Normal file
@@ -0,0 +1,699 @@
|
||||
"""
|
||||
Admin Document Routes
|
||||
|
||||
FastAPI endpoints for admin document management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||
|
||||
from backend.web.config import DEFAULT_DPI, StorageConfig
|
||||
from backend.web.core.auth import (
|
||||
AdminTokenDep,
|
||||
DocumentRepoDep,
|
||||
AnnotationRepoDep,
|
||||
TrainingTaskRepoDep,
|
||||
)
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
from backend.web.schemas.admin import (
|
||||
AnnotationItem,
|
||||
AnnotationSource,
|
||||
AutoLabelStatus,
|
||||
BoundingBox,
|
||||
DocumentCategoriesResponse,
|
||||
DocumentDetailResponse,
|
||||
DocumentItem,
|
||||
DocumentListResponse,
|
||||
DocumentStatus,
|
||||
DocumentStatsResponse,
|
||||
DocumentUpdateRequest,
|
||||
DocumentUploadResponse,
|
||||
ModelMetrics,
|
||||
TrainingHistoryItem,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation using StorageHelper."""
|
||||
import fitz
|
||||
|
||||
storage = get_storage_helper()
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
page = pdf_doc[page_num]
|
||||
# Render at configured DPI for consistency with training
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# Save to storage using StorageHelper
|
||||
image_bytes = pix.tobytes("png")
|
||||
storage.save_admin_image(document_id, page_num + 1, image_bytes)
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
|
||||
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
"""Create admin documents router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])
|
||||
|
||||
# Directories are created by StorageConfig.__post_init__
|
||||
allowed_extensions = storage_config.allowed_extensions
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=DocumentUploadResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Upload document",
|
||||
description="Upload a PDF or image document for labeling.",
|
||||
)
|
||||
async def upload_document(
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
file: UploadFile = File(..., description="PDF or image file"),
|
||||
auto_label: Annotated[
|
||||
bool,
|
||||
Query(description="Trigger auto-labeling after upload"),
|
||||
] = True,
|
||||
group_key: Annotated[
|
||||
str | None,
|
||||
Query(description="Optional group key for document organization", max_length=255),
|
||||
] = None,
|
||||
category: Annotated[
|
||||
str,
|
||||
Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100),
|
||||
] = "invoice",
|
||||
) -> DocumentUploadResponse:
|
||||
"""Upload a document for labeling."""
|
||||
# Validate group_key length
|
||||
if group_key and len(group_key) > 255:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Group key must be 255 characters or less",
|
||||
)
|
||||
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required")
|
||||
|
||||
# Validate extension
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {file_ext}. "
|
||||
f"Allowed: {', '.join(allowed_extensions)}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(status_code=400, detail="Failed to read file")
|
||||
|
||||
# Get page count (for PDF)
|
||||
page_count = 1
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import fitz
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
page_count = len(pdf_doc)
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get PDF page count: {e}")
|
||||
|
||||
# Create document record (token only used for auth, not stored)
|
||||
document_id = docs.create(
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
group_key=group_key,
|
||||
category=category,
|
||||
)
|
||||
|
||||
# Save file to storage using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
filename = f"{document_id}{file_ext}"
|
||||
try:
|
||||
storage_path = storage.save_raw_pdf(content, filename)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to save file")
|
||||
|
||||
# Update file path in database (using storage path for reference)
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import AdminDocument
|
||||
with get_session_context() as session:
|
||||
doc = session.get(AdminDocument, UUID(document_id))
|
||||
if doc:
|
||||
# Store the storage path (relative path within storage)
|
||||
doc.file_path = storage_path
|
||||
session.add(doc)
|
||||
|
||||
# Convert PDF to images for annotation
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
# Trigger auto-labeling if requested
|
||||
auto_label_started = False
|
||||
if auto_label:
|
||||
# Auto-labeling will be triggered by a background task
|
||||
docs.update_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="running",
|
||||
)
|
||||
auto_label_started = True
|
||||
|
||||
return DocumentUploadResponse(
|
||||
document_id=document_id,
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
page_count=page_count,
|
||||
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
|
||||
category=category,
|
||||
group_key=group_key,
|
||||
auto_label_started=auto_label_started,
|
||||
message="Document uploaded successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=DocumentListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="List documents",
|
||||
description="List all documents for the current admin.",
|
||||
)
|
||||
async def list_documents(
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
annotations: AnnotationRepoDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status"),
|
||||
] = None,
|
||||
upload_source: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by upload source (ui or api)"),
|
||||
] = None,
|
||||
has_annotations: Annotated[
|
||||
bool | None,
|
||||
Query(description="Filter by annotation presence"),
|
||||
] = None,
|
||||
auto_label_status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by auto-label status"),
|
||||
] = None,
|
||||
batch_id: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by batch ID"),
|
||||
] = None,
|
||||
category: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by document category"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> DocumentListResponse:
|
||||
"""List documents."""
|
||||
# Validate status
|
||||
if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}",
|
||||
)
|
||||
|
||||
# Validate upload_source
|
||||
if upload_source and upload_source not in ("ui", "api"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid upload_source: {upload_source}",
|
||||
)
|
||||
|
||||
# Validate auto_label_status
|
||||
if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid auto_label_status: {auto_label_status}",
|
||||
)
|
||||
|
||||
documents, total = docs.get_paginated(
|
||||
admin_token=admin_token,
|
||||
status=status,
|
||||
upload_source=upload_source,
|
||||
has_annotations=has_annotations,
|
||||
auto_label_status=auto_label_status,
|
||||
batch_id=batch_id,
|
||||
category=category,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Get annotation counts and build items
|
||||
items = []
|
||||
for doc in documents:
|
||||
doc_annotations = annotations.get_for_document(str(doc.document_id))
|
||||
|
||||
# Determine if document can be annotated (not locked)
|
||||
can_annotate = True
|
||||
if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
|
||||
from datetime import datetime, timezone
|
||||
can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)
|
||||
|
||||
items.append(
|
||||
DocumentItem(
|
||||
document_id=str(doc.document_id),
|
||||
filename=doc.filename,
|
||||
file_size=doc.file_size,
|
||||
page_count=doc.page_count,
|
||||
status=DocumentStatus(doc.status),
|
||||
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
|
||||
annotation_count=len(doc_annotations),
|
||||
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
|
||||
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
|
||||
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
|
||||
category=doc.category if hasattr(doc, 'category') else "invoice",
|
||||
can_annotate=can_annotate,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
return DocumentListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
documents=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
response_model=DocumentStatsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get document statistics",
|
||||
description="Get document count by status.",
|
||||
)
|
||||
async def get_document_stats(
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
) -> DocumentStatsResponse:
|
||||
"""Get document statistics."""
|
||||
counts = docs.count_by_status(admin_token)
|
||||
|
||||
return DocumentStatsResponse(
|
||||
total=sum(counts.values()),
|
||||
pending=counts.get("pending", 0),
|
||||
auto_labeling=counts.get("auto_labeling", 0),
|
||||
labeled=counts.get("labeled", 0),
|
||||
exported=counts.get("exported", 0),
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/categories",
|
||||
response_model=DocumentCategoriesResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get available categories",
|
||||
description="Get list of all available document categories.",
|
||||
)
|
||||
async def get_categories(
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
) -> DocumentCategoriesResponse:
|
||||
"""Get all available document categories."""
|
||||
categories = docs.get_categories()
|
||||
return DocumentCategoriesResponse(
|
||||
categories=categories,
|
||||
total=len(categories),
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/{document_id}",
|
||||
response_model=DocumentDetailResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Get document detail",
|
||||
description="Get document details with annotations.",
|
||||
)
|
||||
async def get_document(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
annotations: AnnotationRepoDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> DocumentDetailResponse:
|
||||
"""Get document details."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get annotations
|
||||
raw_annotations = annotations.get_for_document(document_id)
|
||||
annotation_items = [
|
||||
AnnotationItem(
|
||||
annotation_id=str(ann.annotation_id),
|
||||
page_number=ann.page_number,
|
||||
class_id=ann.class_id,
|
||||
class_name=ann.class_name,
|
||||
bbox=BoundingBox(
|
||||
x=ann.bbox_x,
|
||||
y=ann.bbox_y,
|
||||
width=ann.bbox_width,
|
||||
height=ann.bbox_height,
|
||||
),
|
||||
normalized_bbox={
|
||||
"x_center": ann.x_center,
|
||||
"y_center": ann.y_center,
|
||||
"width": ann.width,
|
||||
"height": ann.height,
|
||||
},
|
||||
text_value=ann.text_value,
|
||||
confidence=ann.confidence,
|
||||
source=AnnotationSource(ann.source),
|
||||
created_at=ann.created_at,
|
||||
)
|
||||
for ann in raw_annotations
|
||||
]
|
||||
|
||||
# Generate image URLs
|
||||
image_urls = []
|
||||
for page in range(1, document.page_count + 1):
|
||||
image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")
|
||||
|
||||
# Determine if document can be annotated (not locked)
|
||||
can_annotate = True
|
||||
annotation_lock_until = None
|
||||
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
|
||||
from datetime import datetime, timezone
|
||||
annotation_lock_until = document.annotation_lock_until
|
||||
can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)
|
||||
|
||||
# Get CSV field values if available
|
||||
csv_field_values = None
|
||||
if hasattr(document, 'csv_field_values') and document.csv_field_values:
|
||||
csv_field_values = document.csv_field_values
|
||||
|
||||
# Get training history (Phase 5)
|
||||
training_history = []
|
||||
training_links = tasks.get_document_training_tasks(document.document_id)
|
||||
for link in training_links:
|
||||
# Get task details
|
||||
task = tasks.get(str(link.task_id))
|
||||
if task:
|
||||
# Build metrics
|
||||
metrics = None
|
||||
if task.metrics_mAP or task.metrics_precision or task.metrics_recall:
|
||||
metrics = ModelMetrics(
|
||||
mAP=task.metrics_mAP,
|
||||
precision=task.metrics_precision,
|
||||
recall=task.metrics_recall,
|
||||
)
|
||||
|
||||
training_history.append(
|
||||
TrainingHistoryItem(
|
||||
task_id=str(link.task_id),
|
||||
name=task.name,
|
||||
trained_at=link.created_at,
|
||||
model_metrics=metrics,
|
||||
)
|
||||
)
|
||||
|
||||
return DocumentDetailResponse(
|
||||
document_id=str(document.document_id),
|
||||
filename=document.filename,
|
||||
file_size=document.file_size,
|
||||
content_type=document.content_type,
|
||||
page_count=document.page_count,
|
||||
status=DocumentStatus(document.status),
|
||||
auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
|
||||
auto_label_error=document.auto_label_error,
|
||||
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
|
||||
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
|
||||
group_key=document.group_key if hasattr(document, 'group_key') else None,
|
||||
category=document.category if hasattr(document, 'category') else "invoice",
|
||||
csv_field_values=csv_field_values,
|
||||
can_annotate=can_annotate,
|
||||
annotation_lock_until=annotation_lock_until,
|
||||
annotations=annotation_items,
|
||||
image_urls=image_urls,
|
||||
training_history=training_history,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Delete document",
|
||||
description="Delete a document and its annotations.",
|
||||
)
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
) -> dict:
|
||||
"""Delete a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Delete file using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Delete the raw PDF
|
||||
filename = Path(document.file_path).name
|
||||
if filename:
|
||||
try:
|
||||
storage._storage.delete(document.file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete PDF file: {e}")
|
||||
|
||||
# Delete admin images
|
||||
try:
|
||||
storage.delete_admin_images(document_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete admin images: {e}")
|
||||
|
||||
# Delete from database
|
||||
docs.delete(document_id)
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"document_id": document_id,
|
||||
"message": "Document deleted successfully",
|
||||
}
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/status",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Update document status",
|
||||
description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
|
||||
)
|
||||
async def update_document_status(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
annotations: AnnotationRepoDep,
|
||||
status: Annotated[
|
||||
str,
|
||||
Query(description="New status"),
|
||||
],
|
||||
) -> dict:
|
||||
"""Update document status.
|
||||
|
||||
When status is set to 'labeled', the annotations are automatically
|
||||
saved to PostgreSQL documents/field_results tables for consistency
|
||||
with CLI auto-label workflow.
|
||||
"""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate status
|
||||
if status not in ("pending", "labeled", "exported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}",
|
||||
)
|
||||
|
||||
# Verify ownership
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# If marking as labeled, save annotations to PostgreSQL DocumentDB
|
||||
db_save_result = None
|
||||
if status == "labeled":
|
||||
from backend.web.services.db_autolabel import save_manual_annotations_to_document_db
|
||||
|
||||
# Get all annotations for this document
|
||||
doc_annotations = annotations.get_for_document(document_id)
|
||||
|
||||
if doc_annotations:
|
||||
db_save_result = save_manual_annotations_to_document_db(
|
||||
document=document,
|
||||
annotations=doc_annotations,
|
||||
)
|
||||
|
||||
docs.update_status(document_id, status)
|
||||
|
||||
response = {
|
||||
"status": "updated",
|
||||
"document_id": document_id,
|
||||
"new_status": status,
|
||||
"message": "Document status updated",
|
||||
}
|
||||
|
||||
# Include PostgreSQL save result if applicable
|
||||
if db_save_result:
|
||||
response["document_db_saved"] = db_save_result.get("success", False)
|
||||
response["fields_saved"] = db_save_result.get("fields_saved", 0)
|
||||
|
||||
return response
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/group-key",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Update document group key",
|
||||
description="Update the group key for a document.",
|
||||
)
|
||||
async def update_document_group_key(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
group_key: Annotated[
|
||||
str | None,
|
||||
Query(description="New group key (null to clear)"),
|
||||
] = None,
|
||||
) -> dict:
|
||||
"""Update document group key."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate group_key length
|
||||
if group_key and len(group_key) > 255:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Group key must be 255 characters or less",
|
||||
)
|
||||
|
||||
# Verify document exists
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Update group key
|
||||
docs.update_group_key(document_id, group_key)
|
||||
|
||||
return {
|
||||
"status": "updated",
|
||||
"document_id": document_id,
|
||||
"group_key": group_key,
|
||||
"message": "Document group key updated",
|
||||
}
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/category",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Update document category",
|
||||
description="Update the category for a document.",
|
||||
)
|
||||
async def update_document_category(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
request: DocumentUpdateRequest,
|
||||
) -> dict:
|
||||
"""Update document category."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify document exists
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Update category if provided
|
||||
if request.category is not None:
|
||||
docs.update_category(document_id, request.category)
|
||||
|
||||
return {
|
||||
"status": "updated",
|
||||
"document_id": document_id,
|
||||
"category": request.category,
|
||||
"message": "Document category updated",
|
||||
}
|
||||
|
||||
return router
|
||||
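Taken together, the document endpoints above describe a simple lifecycle: upload, wait for auto-labeling, review, then mark as labeled (which also persists annotations to PostgreSQL). The sketch below walks that flow, under the same prefix and header assumptions as the earlier examples.

# Sketch only: upload a PDF, wait for auto-labeling, then mark it labeled.
import time

import httpx

BASE_URL = "http://localhost:8000/api/v1"            # assumed
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme


def upload_and_label(pdf_path: str) -> str:
    with httpx.Client(base_url=BASE_URL, headers=HEADERS, timeout=60.0) as client:
        with open(pdf_path, "rb") as fh:
            resp = client.post(
                "/admin/documents",
                params={"auto_label": True, "category": "invoice"},
                files={"file": (pdf_path, fh, "application/pdf")},
            )
        document_id = resp.json()["document_id"]

        # Poll until auto-labeling is no longer running; the status values mirror
        # the validation list in the list_documents endpoint above.
        while client.get(f"/admin/documents/{document_id}").json()["status"] == "auto_labeling":
            time.sleep(5)

        # Marking as labeled also persists annotations to the PostgreSQL
        # documents/field_results tables, per update_document_status above.
        client.patch(f"/admin/documents/{document_id}/status", params={"status": "labeled"})
        return document_id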
181
packages/backend/backend/web/api/v1/admin/locks.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Admin Document Lock Routes
|
||||
|
||||
FastAPI endpoints for annotation lock management.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from backend.web.core.auth import AdminTokenDep, DocumentRepoDep
|
||||
from backend.web.schemas.admin import (
|
||||
AnnotationLockRequest,
|
||||
AnnotationLockResponse,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def create_locks_router() -> APIRouter:
|
||||
"""Create annotation locks router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/lock",
|
||||
response_model=AnnotationLockResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
409: {"model": ErrorResponse, "description": "Document already locked"},
|
||||
},
|
||||
summary="Acquire annotation lock",
|
||||
description="Acquire a lock on a document to prevent concurrent annotation edits.",
|
||||
)
|
||||
async def acquire_lock(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
request: AnnotationLockRequest = AnnotationLockRequest(),
|
||||
) -> AnnotationLockResponse:
|
||||
"""Acquire annotation lock for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Attempt to acquire lock
|
||||
updated_doc = docs.acquire_annotation_lock(
|
||||
document_id=document_id,
|
||||
admin_token=admin_token,
|
||||
duration_seconds=request.duration_seconds,
|
||||
)
|
||||
|
||||
if updated_doc is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Document is already locked. Please try again later.",
|
||||
)
|
||||
|
||||
return AnnotationLockResponse(
|
||||
document_id=document_id,
|
||||
locked=True,
|
||||
lock_expires_at=updated_doc.annotation_lock_until,
|
||||
message=f"Lock acquired for {request.duration_seconds} seconds",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/lock",
|
||||
response_model=AnnotationLockResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Release annotation lock",
|
||||
description="Release the annotation lock on a document.",
|
||||
)
|
||||
async def release_lock(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query(description="Force release (admin override)"),
|
||||
] = False,
|
||||
) -> AnnotationLockResponse:
|
||||
"""Release annotation lock for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Release lock
|
||||
updated_doc = docs.release_annotation_lock(
|
||||
document_id=document_id,
|
||||
admin_token=admin_token,
|
||||
force=force,
|
||||
)
|
||||
|
||||
if updated_doc is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Failed to release lock",
|
||||
)
|
||||
|
||||
return AnnotationLockResponse(
|
||||
document_id=document_id,
|
||||
locked=False,
|
||||
lock_expires_at=None,
|
||||
message="Lock released successfully",
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/lock",
|
||||
response_model=AnnotationLockResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
|
||||
},
|
||||
summary="Extend annotation lock",
|
||||
description="Extend an existing annotation lock.",
|
||||
)
|
||||
async def extend_lock(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
request: AnnotationLockRequest = AnnotationLockRequest(),
|
||||
) -> AnnotationLockResponse:
|
||||
"""Extend annotation lock for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = docs.get_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Attempt to extend lock
|
||||
updated_doc = docs.extend_annotation_lock(
|
||||
document_id=document_id,
|
||||
admin_token=admin_token,
|
||||
additional_seconds=request.duration_seconds,
|
||||
)
|
||||
|
||||
if updated_doc is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Lock doesn't exist or has expired. Please acquire a new lock.",
|
||||
)
|
||||
|
||||
return AnnotationLockResponse(
|
||||
document_id=document_id,
|
||||
locked=True,
|
||||
lock_expires_at=updated_doc.annotation_lock_until,
|
||||
message=f"Lock extended by {request.duration_seconds} seconds",
|
||||
)
|
||||
|
||||
return router
|
||||
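The lock endpoints are meant to bracket an annotation editing session: acquire before editing, extend if the session runs long, release when done or let the lock expire. A sketch of that pattern follows; duration_seconds is the only AnnotationLockRequest field visible in this diff, so the request bodies below assume nothing else is required.

# Sketch only: hold the annotation lock around an editing session.
import httpx

BASE_URL = "http://localhost:8000/api/v1"            # assumed
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme


def with_annotation_lock(document_id: str, duration_seconds: int = 300) -> None:
    with httpx.Client(base_url=BASE_URL, headers=HEADERS) as client:
        lock_url = f"/admin/documents/{document_id}/lock"

        resp = client.post(lock_url, json={"duration_seconds": duration_seconds})
        if resp.status_code == 409:
            raise RuntimeError("Document is locked by another session")
        resp.raise_for_status()

        try:
            # ... edit annotations here ...
            # Extend the lock if the session runs long (409 means it already expired).
            client.patch(lock_url, json={"duration_seconds": duration_seconds})
        finally:
            client.delete(lock_url)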
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Admin Training API Routes
|
||||
|
||||
FastAPI endpoints for training task management and scheduling.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
from .tasks import register_task_routes
|
||||
from .documents import register_document_routes
|
||||
from .export import register_export_routes
|
||||
from .datasets import register_dataset_routes
|
||||
from .models import register_model_routes
|
||||
|
||||
|
||||
def create_training_router() -> APIRouter:
|
||||
"""Create training API router."""
|
||||
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
|
||||
|
||||
register_task_routes(router)
|
||||
register_document_routes(router)
|
||||
register_export_routes(router)
|
||||
register_dataset_routes(router)
|
||||
register_model_routes(router)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
__all__ = ["create_training_router", "_validate_uuid"]
|
||||
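create_training_router composes the five training route modules into one APIRouter, mirroring the create_*_router factories in the other admin modules. The application factory that mounts these routers is not part of this diff, so the wiring below is only an illustration of how they could be combined.

# Sketch only: one way the routers in this commit could be mounted on an app.
from fastapi import FastAPI

from backend.web.api.v1.admin.auth import create_auth_router
from backend.web.api.v1.admin.dashboard import create_dashboard_router
from backend.web.api.v1.admin.locks import create_locks_router
from backend.web.api.v1.admin.training import create_training_router


def create_app() -> FastAPI:
    app = FastAPI(title="Invoice backend (sketch)")
    for router in (
        create_auth_router(),
        create_dashboard_router(),
        create_locks_router(),
        create_training_router(),
    ):
        # The /api/v1 prefix is assumed from the image URLs in documents.py.
        app.include_router(router, prefix="/api/v1")
    # create_documents_router(StorageConfig(...)) would be mounted the same way,
    # but needs a concrete StorageConfig and is omitted from this sketch.
    return app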
16
packages/backend/backend/web/api/v1/admin/training/_utils.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Shared utilities for training routes."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
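_validate_uuid exists solely to turn malformed path parameters into a 400 before any repository call. A sketch of a unit test for it, assuming pytest is available in the backend's test environment.

# Sketch only: a unit test for the shared UUID validator.
import pytest
from fastapi import HTTPException

from backend.web.api.v1.admin.training._utils import _validate_uuid


def test_validate_uuid() -> None:
    # A well-formed UUID passes silently.
    _validate_uuid("123e4567-e89b-12d3-a456-426614174000", "task_id")

    # Anything else becomes an HTTP 400 before a repository is ever touched.
    with pytest.raises(HTTPException) as exc_info:
        _validate_uuid("not-a-uuid", "task_id")
    assert exc_info.value.status_code == 400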
291
packages/backend/backend/web/api/v1/admin/training/datasets.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Training Dataset Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from backend.web.core.auth import (
|
||||
AdminTokenDep,
|
||||
DatasetRepoDep,
|
||||
DocumentRepoDep,
|
||||
AnnotationRepoDep,
|
||||
ModelVersionRepoDep,
|
||||
TrainingTaskRepoDep,
|
||||
)
|
||||
from backend.web.schemas.admin import (
|
||||
DatasetCreateRequest,
|
||||
DatasetDetailResponse,
|
||||
DatasetDocumentItem,
|
||||
DatasetListItem,
|
||||
DatasetListResponse,
|
||||
DatasetResponse,
|
||||
DatasetTrainRequest,
|
||||
TrainingStatus,
|
||||
TrainingTaskResponse,
|
||||
)
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_dataset_routes(router: APIRouter) -> None:
|
||||
"""Register dataset endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/datasets",
|
||||
response_model=DatasetResponse,
|
||||
summary="Create training dataset",
|
||||
description="Create a dataset from selected documents with train/val/test splits.",
|
||||
)
|
||||
async def create_dataset(
|
||||
request: DatasetCreateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
datasets: DatasetRepoDep,
|
||||
docs: DocumentRepoDep,
|
||||
annotations: AnnotationRepoDep,
|
||||
) -> DatasetResponse:
|
||||
"""Create a training dataset from document IDs."""
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
# Validate minimum document count for proper train/val/test split
|
||||
if len(request.document_ids) < 10:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
|
||||
)
|
||||
|
||||
dataset = datasets.create(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
train_ratio=request.train_ratio,
|
||||
val_ratio=request.val_ratio,
|
||||
seed=request.seed,
|
||||
)
|
||||
|
||||
# Get storage paths from StorageHelper
|
||||
storage = get_storage_helper()
|
||||
datasets_dir = storage.get_datasets_base_path()
|
||||
admin_images_dir = storage.get_admin_images_base_path()
|
||||
|
||||
if datasets_dir is None or admin_images_dir is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Storage not configured for local access",
|
||||
)
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=datasets,
|
||||
documents_repo=docs,
|
||||
annotations_repo=annotations,
|
||||
base_dir=datasets_dir,
|
||||
)
|
||||
try:
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=request.document_ids,
|
||||
train_ratio=request.train_ratio,
|
||||
val_ratio=request.val_ratio,
|
||||
seed=request.seed,
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return DatasetResponse(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
name=dataset.name,
|
||||
status="ready",
|
||||
message="Dataset created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/datasets",
|
||||
response_model=DatasetListResponse,
|
||||
summary="List datasets",
|
||||
)
|
||||
async def list_datasets(
|
||||
admin_token: AdminTokenDep,
|
||||
datasets_repo: DatasetRepoDep,
|
||||
status: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> DatasetListResponse:
|
||||
"""List training datasets."""
|
||||
datasets_list, total = datasets_repo.get_paginated(status=status, limit=limit, offset=offset)
|
||||
|
||||
# Get active training tasks for each dataset (graceful degradation on error)
|
||||
dataset_ids = [str(d.dataset_id) for d in datasets_list]
|
||||
try:
|
||||
active_tasks = datasets_repo.get_active_training_tasks(dataset_ids)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch active training tasks")
|
||||
active_tasks = {}
|
||||
|
||||
return DatasetListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
datasets=[
|
||||
DatasetListItem(
|
||||
dataset_id=str(d.dataset_id),
|
||||
name=d.name,
|
||||
description=d.description,
|
||||
status=d.status,
|
||||
training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
|
||||
active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
|
||||
total_documents=d.total_documents,
|
||||
total_images=d.total_images,
|
||||
total_annotations=d.total_annotations,
|
||||
created_at=d.created_at,
|
||||
)
|
||||
for d in datasets_list
|
||||
],
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/datasets/{dataset_id}",
|
||||
response_model=DatasetDetailResponse,
|
||||
summary="Get dataset detail",
|
||||
)
|
||||
async def get_dataset(
|
||||
dataset_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
datasets_repo: DatasetRepoDep,
|
||||
) -> DatasetDetailResponse:
|
||||
"""Get dataset details with document list."""
|
||||
_validate_uuid(dataset_id, "dataset_id")
|
||||
dataset = datasets_repo.get(dataset_id)
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
docs = datasets_repo.get_documents(dataset_id)
|
||||
return DatasetDetailResponse(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
status=dataset.status,
|
||||
training_status=dataset.training_status,
|
||||
active_training_task_id=(
|
||||
str(dataset.active_training_task_id)
|
||||
if dataset.active_training_task_id
|
||||
else None
|
||||
),
|
||||
train_ratio=dataset.train_ratio,
|
||||
val_ratio=dataset.val_ratio,
|
||||
seed=dataset.seed,
|
||||
total_documents=dataset.total_documents,
|
||||
total_images=dataset.total_images,
|
||||
total_annotations=dataset.total_annotations,
|
||||
dataset_path=dataset.dataset_path,
|
||||
error_message=dataset.error_message,
|
||||
documents=[
|
||||
DatasetDocumentItem(
|
||||
document_id=str(d.document_id),
|
||||
split=d.split,
|
||||
page_count=d.page_count,
|
||||
annotation_count=d.annotation_count,
|
||||
)
|
||||
for d in docs
|
||||
],
|
||||
created_at=dataset.created_at,
|
||||
updated_at=dataset.updated_at,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/datasets/{dataset_id}",
|
||||
summary="Delete dataset",
|
||||
)
|
||||
async def delete_dataset(
|
||||
dataset_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
datasets_repo: DatasetRepoDep,
|
||||
) -> dict:
|
||||
"""Delete a dataset and its files."""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
_validate_uuid(dataset_id, "dataset_id")
|
||||
dataset = datasets_repo.get(dataset_id)
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
if dataset.dataset_path:
|
||||
dataset_dir = Path(dataset.dataset_path)
|
||||
if dataset_dir.exists():
|
||||
shutil.rmtree(dataset_dir)
|
||||
|
||||
datasets_repo.delete(dataset_id)
|
||||
return {"message": "Dataset deleted"}
|
||||
|
||||
@router.post(
|
||||
"/datasets/{dataset_id}/train",
|
||||
response_model=TrainingTaskResponse,
|
||||
summary="Start training from dataset",
|
||||
description="Create a training task. Set base_model_version_id in config for incremental training.",
|
||||
)
|
||||
async def train_from_dataset(
|
||||
dataset_id: str,
|
||||
request: DatasetTrainRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
datasets_repo: DatasetRepoDep,
|
||||
models: ModelVersionRepoDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a training task from a dataset.
|
||||
|
||||
For incremental training, set config.base_model_version_id to a model version UUID.
|
||||
The training will use that model as the starting point instead of a pretrained model.
|
||||
"""
|
||||
_validate_uuid(dataset_id, "dataset_id")
|
||||
dataset = datasets_repo.get(dataset_id)
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
if dataset.status != "ready":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Dataset is not ready (status: {dataset.status})",
|
||||
)
|
||||
|
||||
config_dict = request.config.model_dump()
|
||||
|
||||
# Resolve base_model_version_id to actual model path for incremental training
|
||||
base_model_version_id = config_dict.get("base_model_version_id")
|
||||
if base_model_version_id:
|
||||
_validate_uuid(base_model_version_id, "base_model_version_id")
|
||||
base_model = models.get(base_model_version_id)
|
||||
if not base_model:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Base model version not found: {base_model_version_id}",
|
||||
)
|
||||
# Store the resolved model path for the training worker
|
||||
config_dict["base_model_path"] = base_model.model_path
|
||||
config_dict["base_model_version"] = base_model.version
|
||||
logger.info(
|
||||
"Incremental training: using model %s (%s) as base",
|
||||
base_model.version,
|
||||
base_model.model_path,
|
||||
)
|
||||
|
||||
task_id = tasks.create(
|
||||
admin_token=admin_token,
|
||||
name=request.name,
|
||||
task_type="finetune" if base_model_version_id else "train",
|
||||
config=config_dict,
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
)
|
||||
|
||||
message = (
|
||||
f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
|
||||
if base_model_version_id
|
||||
else "Training task created from dataset"
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.PENDING,
|
||||
message=message,
|
||||
)
|
||||
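The dataset endpoints split the training workflow into two explicit steps: build a dataset from labeled document IDs, then start a task from it, optionally pointing config.base_model_version_id at an existing model version for incremental finetuning. A client sketch of both steps follows; the field names come from the handlers above, and anything else that a real config would carry is assumed.

# Sketch only: create a dataset from labeled documents, then start training.
import httpx

BASE_URL = "http://localhost:8000/api/v1"            # assumed
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme


def build_and_train(document_ids: list[str], base_model_version_id: str | None = None) -> str:
    with httpx.Client(base_url=BASE_URL, headers=HEADERS, timeout=None) as client:
        dataset = client.post(
            "/admin/training/datasets",
            json={
                "name": "invoices-v1",
                "description": "first labeled batch",
                "document_ids": document_ids,  # at least 10, per the endpoint check
                "train_ratio": 0.8,
                "val_ratio": 0.1,
                "seed": 42,
            },
        ).json()

        # Setting base_model_version_id turns the task into incremental finetuning;
        # leaving it out starts from a pretrained model.
        config = {"base_model_version_id": base_model_version_id} if base_model_version_id else {}
        task = client.post(
            f"/admin/training/datasets/{dataset['dataset_id']}/train",
            json={"name": "train-invoices-v1", "config": config},
        ).json()
        return task["task_id"]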
218
packages/backend/backend/web/api/v1/admin/training/documents.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Training Documents and Models Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from backend.web.core.auth import (
|
||||
AdminTokenDep,
|
||||
DocumentRepoDep,
|
||||
AnnotationRepoDep,
|
||||
TrainingTaskRepoDep,
|
||||
)
|
||||
from backend.web.schemas.admin import (
|
||||
ModelMetrics,
|
||||
TrainingDocumentItem,
|
||||
TrainingDocumentsResponse,
|
||||
TrainingModelItem,
|
||||
TrainingModelsResponse,
|
||||
TrainingStatus,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_document_routes(router: APIRouter) -> None:
|
||||
"""Register training document and model endpoints on the router."""
|
||||
|
||||
@router.get(
|
||||
"/documents",
|
||||
response_model=TrainingDocumentsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get documents for training",
|
||||
description="Get labeled documents available for training with filtering options.",
|
||||
)
|
||||
async def get_training_documents(
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
annotations: AnnotationRepoDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
has_annotations: Annotated[
|
||||
bool,
|
||||
Query(description="Only include documents with annotations"),
|
||||
] = True,
|
||||
min_annotation_count: Annotated[
|
||||
int | None,
|
||||
Query(ge=1, description="Minimum annotation count"),
|
||||
] = None,
|
||||
exclude_used_in_training: Annotated[
|
||||
bool,
|
||||
Query(description="Exclude documents already used in training"),
|
||||
] = False,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingDocumentsResponse:
|
||||
"""Get documents available for training."""
|
||||
documents, total = docs.get_for_training(
|
||||
admin_token=admin_token,
|
||||
status="labeled",
|
||||
has_annotations=has_annotations,
|
||||
min_annotation_count=min_annotation_count,
|
||||
exclude_used_in_training=exclude_used_in_training,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
items = []
|
||||
for doc in documents:
|
||||
doc_annotations = annotations.get_for_document(str(doc.document_id))
|
||||
|
||||
sources = {"manual": 0, "auto": 0}
|
||||
for ann in doc_annotations:
|
||||
if ann.source in sources:
|
||||
sources[ann.source] += 1
|
||||
|
||||
training_links = tasks.get_document_training_tasks(doc.document_id)
|
||||
used_in_training = [str(link.task_id) for link in training_links]
|
||||
|
||||
items.append(
|
||||
TrainingDocumentItem(
|
||||
document_id=str(doc.document_id),
|
||||
filename=doc.filename,
|
||||
annotation_count=len(doc_annotations),
|
||||
annotation_sources=sources,
|
||||
used_in_training=used_in_training,
|
||||
last_modified=doc.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
return TrainingDocumentsResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
documents=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/{task_id}/download",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Model not found"},
|
||||
},
|
||||
summary="Download trained model",
|
||||
description="Download trained model weights file.",
|
||||
)
|
||||
async def download_model(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
):
|
||||
"""Download trained model."""
|
||||
from fastapi.responses import FileResponse
|
||||
from pathlib import Path
|
||||
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = tasks.get_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
if not task.model_path:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model file not available for this task",
|
||||
)
|
||||
|
||||
model_path = Path(task.model_path)
|
||||
if not model_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model file not found on disk",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(model_path),
|
||||
media_type="application/octet-stream",
|
||||
filename=f"{task.name}_model.pt",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/completed-tasks",
|
||||
response_model=TrainingModelsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get completed training tasks",
|
||||
description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
|
||||
)
|
||||
async def get_completed_training_tasks(
|
||||
admin_token: AdminTokenDep,
|
||||
tasks_repo: TrainingTaskRepoDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status (completed, failed, etc.)"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingModelsResponse:
|
||||
"""Get list of trained models."""
|
||||
task_list, total = tasks_repo.get_paginated(
|
||||
admin_token=admin_token,
|
||||
status=status if status else "completed",
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
items = []
|
||||
for task in task_list:
|
||||
metrics = ModelMetrics(
|
||||
mAP=task.metrics_mAP,
|
||||
precision=task.metrics_precision,
|
||||
recall=task.metrics_recall,
|
||||
)
|
||||
|
||||
download_url = None
|
||||
if task.model_path and task.status == "completed":
|
||||
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
|
||||
|
||||
items.append(
|
||||
TrainingModelItem(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
status=TrainingStatus(task.status),
|
||||
document_count=task.document_count,
|
||||
created_at=task.created_at,
|
||||
completed_at=task.completed_at,
|
||||
metrics=metrics,
|
||||
model_path=task.model_path,
|
||||
download_url=download_url,
|
||||
)
|
||||
)
|
||||
|
||||
return TrainingModelsResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
models=items,
|
||||
)
|
||||
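The /completed-tasks and /models/{task_id}/download endpoints above make it possible to pull trained weights without touching the server's filesystem. The sketch below fetches the most recent completed task and writes its weights to disk, under the usual prefix and header assumptions.

# Sketch only: grab the newest completed training task and save its weights.
import httpx

BASE_URL = "http://localhost:8000/api/v1"            # assumed
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme


def download_latest_model(dest: str = "model.pt") -> None:
    with httpx.Client(base_url=BASE_URL, headers=HEADERS, timeout=None) as client:
        completed = client.get("/admin/training/completed-tasks", params={"limit": 1}).json()
        if not completed["models"]:
            raise RuntimeError("No completed training tasks yet")

        task_id = completed["models"][0]["task_id"]
        resp = client.get(f"/admin/training/models/{task_id}/download")
        resp.raise_for_status()
        with open(dest, "wb") as fh:
            fh.write(resp.content)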
134
packages/backend/backend/web/api/v1/admin/training/export.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Training Export Endpoints."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from backend.web.core.auth import AdminTokenDep, DocumentRepoDep, AnnotationRepoDep
|
||||
from backend.web.schemas.admin import (
|
||||
ExportRequest,
|
||||
ExportResponse,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_export_routes(router: APIRouter) -> None:
|
||||
"""Register export endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/export",
|
||||
response_model=ExportResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Export annotations",
|
||||
description="Export annotations in YOLO format for training.",
|
||||
)
|
||||
async def export_annotations(
|
||||
request: ExportRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
docs: DocumentRepoDep,
|
||||
annotations: AnnotationRepoDep,
|
||||
) -> ExportResponse:
|
||||
"""Export annotations for training."""
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
# Get storage helper for reading images and exports directory
|
||||
storage = get_storage_helper()
|
||||
|
||||
if request.format not in ("yolo", "coco", "voc"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported export format: {request.format}",
|
||||
)
|
||||
|
||||
documents = docs.get_labeled_for_export(admin_token)
|
||||
|
||||
if not documents:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No labeled documents available for export",
|
||||
)
|
||||
|
||||
# Get exports directory from StorageHelper
|
||||
exports_base = storage.get_exports_base_path()
|
||||
if exports_base is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Storage not configured for local access",
|
||||
)
|
||||
export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_docs = len(documents)
|
||||
train_count = int(total_docs * request.split_ratio)
|
||||
train_docs = documents[:train_count]
|
||||
val_docs = documents[train_count:]
|
||||
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
|
||||
for split, split_docs in [("train", train_docs), ("val", val_docs)]:
|
||||
for doc in split_docs:
|
||||
doc_annotations = annotations.get_for_document(str(doc.document_id))
|
||||
|
||||
if not doc_annotations:
|
||||
continue
|
||||
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
page_annotations = [a for a in doc_annotations if a.page_number == page_num]
|
||||
|
||||
if not page_annotations and not request.include_images:
|
||||
continue
|
||||
|
||||
# Get image from storage
|
||||
doc_id = str(doc.document_id)
|
||||
if not storage.admin_image_exists(doc_id, page_num):
|
||||
continue
|
||||
|
||||
# Download image and save to export directory
|
||||
image_name = f"{doc.document_id}_page{page_num}.png"
|
||||
dst_image = export_dir / "images" / split / image_name
|
||||
image_content = storage.get_admin_image(doc_id, page_num)
|
||||
dst_image.write_bytes(image_content)
|
||||
total_images += 1
|
||||
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
from shared.fields import FIELD_CLASSES
|
||||
|
||||
yaml_content = f"""# Auto-generated YOLO dataset config
|
||||
path: {export_dir.absolute()}
|
||||
train: images/train
|
||||
val: images/val
|
||||
|
||||
nc: {len(FIELD_CLASSES)}
|
||||
names: {list(FIELD_CLASSES.values())}
|
||||
"""
|
||||
(export_dir / "data.yaml").write_text(yaml_content)
|
||||
|
||||
return ExportResponse(
|
||||
status="completed",
|
||||
export_path=str(export_dir),
|
||||
total_images=total_images,
|
||||
total_annotations=total_annotations,
|
||||
train_count=len(train_docs),
|
||||
val_count=len(val_docs),
|
||||
message=f"Exported {total_images} images with {total_annotations} annotations",
|
||||
)
|
||||
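The export handler writes a ready-to-train YOLO directory (images/, labels/, and a data.yaml built from shared.fields.FIELD_CLASSES) and returns its path. A sketch of triggering it from a client; the ExportRequest defaults are not visible in this diff, so all three fields are passed explicitly.

# Sketch only: trigger a YOLO export and report where the dataset landed.
import httpx

BASE_URL = "http://localhost:8000/api/v1"            # assumed
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed header scheme


def export_yolo_dataset() -> str:
    with httpx.Client(base_url=BASE_URL, headers=HEADERS, timeout=None) as client:
        resp = client.post(
            "/admin/training/export",
            json={"format": "yolo", "split_ratio": 0.8, "include_images": True},
        )
        resp.raise_for_status()
        body = resp.json()
        print(body["message"], "->", body["export_path"])
        return body["export_path"]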
333
packages/backend/backend/web/api/v1/admin/training/models.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Model Version Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from backend.web.core.auth import AdminTokenDep, ModelVersionRepoDep
|
||||
from backend.web.schemas.admin import (
|
||||
ModelVersionCreateRequest,
|
||||
ModelVersionUpdateRequest,
|
||||
ModelVersionItem,
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ModelVersionResponse,
|
||||
ActiveModelResponse,
|
||||
)
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_model_routes(router: APIRouter) -> None:
|
||||
"""Register model version endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/models",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Create model version",
|
||||
description="Register a new model version for deployment.",
|
||||
)
|
||||
async def create_model_version(
|
||||
request: ModelVersionCreateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Create a new model version."""
|
||||
if request.task_id:
|
||||
_validate_uuid(request.task_id, "task_id")
|
||||
if request.dataset_id:
|
||||
_validate_uuid(request.dataset_id, "dataset_id")
|
||||
|
||||
model = models.create(
|
||||
version=request.version,
|
||||
name=request.name,
|
||||
model_path=request.model_path,
|
||||
description=request.description,
|
||||
task_id=request.task_id,
|
||||
dataset_id=request.dataset_id,
|
||||
metrics_mAP=request.metrics_mAP,
|
||||
metrics_precision=request.metrics_precision,
|
||||
metrics_recall=request.metrics_recall,
|
||||
document_count=request.document_count,
|
||||
training_config=request.training_config,
|
||||
file_size=request.file_size,
|
||||
trained_at=request.trained_at,
|
||||
)
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
response_model=ModelVersionListResponse,
|
||||
summary="List model versions",
|
||||
)
|
||||
async def list_model_versions(
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
status: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> ModelVersionListResponse:
|
||||
"""List model versions with optional status filter."""
|
||||
model_list, total = models.get_paginated(status=status, limit=limit, offset=offset)
|
||||
return ModelVersionListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
models=[
|
||||
ModelVersionItem(
|
||||
version_id=str(m.version_id),
|
||||
version=m.version,
|
||||
name=m.name,
|
||||
status=m.status,
|
||||
is_active=m.is_active,
|
||||
metrics_mAP=m.metrics_mAP,
|
||||
document_count=m.document_count,
|
||||
trained_at=m.trained_at,
|
||||
activated_at=m.activated_at,
|
||||
created_at=m.created_at,
|
||||
)
|
||||
for m in model_list
|
||||
],
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/active",
|
||||
response_model=ActiveModelResponse,
|
||||
summary="Get active model",
|
||||
description="Get the currently active model for inference.",
|
||||
)
|
||||
async def get_active_model(
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ActiveModelResponse:
|
||||
"""Get the currently active model version."""
|
||||
model = models.get_active()
|
||||
if not model:
|
||||
return ActiveModelResponse(has_active_model=False, model=None)
|
||||
|
||||
return ActiveModelResponse(
|
||||
has_active_model=True,
|
||||
model=ModelVersionItem(
|
||||
version_id=str(model.version_id),
|
||||
version=model.version,
|
||||
name=model.name,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
document_count=model.document_count,
|
||||
trained_at=model.trained_at,
|
||||
activated_at=model.activated_at,
|
||||
created_at=model.created_at,
|
||||
),
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/{version_id}",
|
||||
response_model=ModelVersionDetailResponse,
|
||||
summary="Get model version detail",
|
||||
)
|
||||
async def get_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionDetailResponse:
|
||||
"""Get detailed model version information."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = models.get(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionDetailResponse(
|
||||
version_id=str(model.version_id),
|
||||
version=model.version,
|
||||
name=model.name,
|
||||
description=model.description,
|
||||
model_path=model.model_path,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
task_id=str(model.task_id) if model.task_id else None,
|
||||
dataset_id=str(model.dataset_id) if model.dataset_id else None,
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
metrics_precision=model.metrics_precision,
|
||||
metrics_recall=model.metrics_recall,
|
||||
document_count=model.document_count,
|
||||
training_config=model.training_config,
|
||||
file_size=model.file_size,
|
||||
trained_at=model.trained_at,
|
||||
activated_at=model.activated_at,
|
||||
created_at=model.created_at,
|
||||
updated_at=model.updated_at,
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/models/{version_id}",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Update model version",
|
||||
)
|
||||
async def update_model_version(
|
||||
version_id: str,
|
||||
request: ModelVersionUpdateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Update model version metadata."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = models.update(
|
||||
version_id=version_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
status=request.status,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version updated successfully",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/activate",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Activate model version",
|
||||
description="Activate a model version for inference (deactivates all others).",
|
||||
)
|
||||
async def activate_model_version(
|
||||
version_id: str,
|
||||
request: Request,
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Activate a model version for inference."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = models.activate(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
# Trigger model reload in inference service
|
||||
inference_service = getattr(request.app.state, "inference_service", None)
|
||||
model_reloaded = False
|
||||
if inference_service:
|
||||
try:
|
||||
model_reloaded = inference_service.reload_model()
|
||||
if model_reloaded:
|
||||
logger.info(f"Inference model reloaded to version {model.version}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reload inference model: {e}")
|
||||
|
||||
message = "Model version activated for inference"
|
||||
if model_reloaded:
|
||||
message += " (model reloaded)"
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message=message,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/deactivate",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Deactivate model version",
|
||||
)
|
||||
async def deactivate_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Deactivate a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = models.deactivate(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version deactivated",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/archive",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Archive model version",
|
||||
)
|
||||
async def archive_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Archive a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = models.archive(version_id)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model version not found or cannot archive active model",
|
||||
)
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version archived",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/models/{version_id}",
|
||||
summary="Delete model version",
|
||||
)
|
||||
async def delete_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> dict:
|
||||
"""Delete a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
success = models.delete(version_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model version not found or cannot delete active model",
|
||||
)
|
||||
|
||||
return {"message": "Model version deleted"}
|
||||
|
||||
@router.post(
|
||||
"/models/reload",
|
||||
summary="Reload inference model",
|
||||
description="Reload the inference model from the currently active model version.",
|
||||
)
|
||||
async def reload_inference_model(
|
||||
request: Request,
|
||||
admin_token: AdminTokenDep,
|
||||
) -> dict:
|
||||
"""Reload the inference model from active version."""
|
||||
inference_service = getattr(request.app.state, "inference_service", None)
|
||||
if not inference_service:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Inference service not available",
|
||||
)
|
||||
|
||||
try:
|
||||
model_reloaded = inference_service.reload_model()
|
||||
if model_reloaded:
|
||||
logger.info("Inference model manually reloaded")
|
||||
return {"message": "Model reloaded successfully", "reloaded": True}
|
||||
else:
|
||||
return {"message": "Model already up to date", "reloaded": False}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to reload model: {e}",
|
||||
)
|
||||
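A minimal client sketch of the model-version lifecycle above: register a version, activate it, then read back the active model. The /api/v1/admin/training prefix and the Bearer header are assumptions, since the router mounting and AdminTokenDep wiring are defined elsewhere.

import requests

BASE = "http://localhost:8000/api/v1/admin/training"   # assumed mount prefix
HEADERS = {"Authorization": "Bearer <admin-token>"}     # assumed auth scheme

created = requests.post(
    f"{BASE}/models",
    headers=HEADERS,
    json={
        "version": "2025.01.1",
        "name": "invoice-fields",
        "model_path": "models/invoice_fields.pt",  # hypothetical artifact path
        "metrics_mAP": 0.91,
    },
).json()

version_id = created["version_id"]
requests.post(f"{BASE}/models/{version_id}/activate", headers=HEADERS).raise_for_status()

active = requests.get(f"{BASE}/models/active", headers=HEADERS).json()
if active["has_active_model"]:
    print("active model:", active["model"]["version"])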
263
packages/backend/backend/web/api/v1/admin/training/tasks.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""Training Task Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from backend.web.core.auth import AdminTokenDep, TrainingTaskRepoDep
|
||||
from backend.web.schemas.admin import (
|
||||
TrainingLogItem,
|
||||
TrainingLogsResponse,
|
||||
TrainingStatus,
|
||||
TrainingTaskCreate,
|
||||
TrainingTaskDetailResponse,
|
||||
TrainingTaskItem,
|
||||
TrainingTaskListResponse,
|
||||
TrainingTaskResponse,
|
||||
TrainingType,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_task_routes(router: APIRouter) -> None:
|
||||
"""Register training task endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/tasks",
|
||||
response_model=TrainingTaskResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Create training task",
|
||||
description="Create a new training task.",
|
||||
)
|
||||
async def create_training_task(
|
||||
request: TrainingTaskCreate,
|
||||
admin_token: AdminTokenDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a new training task."""
|
||||
config_dict = request.config.model_dump() if request.config else {}
|
||||
|
||||
task_id = tasks.create(
|
||||
admin_token=admin_token,
|
||||
name=request.name,
|
||||
task_type=request.task_type.value,
|
||||
description=request.description,
|
||||
config=config_dict,
|
||||
scheduled_at=request.scheduled_at,
|
||||
cron_expression=request.cron_expression,
|
||||
is_recurring=bool(request.cron_expression),
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
|
||||
message="Training task created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks",
|
||||
response_model=TrainingTaskListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="List training tasks",
|
||||
description="List all training tasks.",
|
||||
)
|
||||
async def list_training_tasks(
|
||||
admin_token: AdminTokenDep,
|
||||
tasks_repo: TrainingTaskRepoDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingTaskListResponse:
|
||||
"""List training tasks."""
|
||||
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
|
||||
if status and status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
task_list, total = tasks_repo.get_paginated(
|
||||
admin_token=admin_token,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
items = [
|
||||
TrainingTaskItem(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
task_type=TrainingType(task.task_type),
|
||||
status=TrainingStatus(task.status),
|
||||
scheduled_at=task.scheduled_at,
|
||||
is_recurring=task.is_recurring,
|
||||
started_at=task.started_at,
|
||||
completed_at=task.completed_at,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
for task in task_list
|
||||
]
|
||||
|
||||
return TrainingTaskListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
tasks=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
response_model=TrainingTaskDetailResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
},
|
||||
summary="Get training task detail",
|
||||
description="Get training task details.",
|
||||
)
|
||||
async def get_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> TrainingTaskDetailResponse:
|
||||
"""Get training task details."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = tasks.get_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
return TrainingTaskDetailResponse(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
description=task.description,
|
||||
task_type=TrainingType(task.task_type),
|
||||
status=TrainingStatus(task.status),
|
||||
config=task.config,
|
||||
scheduled_at=task.scheduled_at,
|
||||
cron_expression=task.cron_expression,
|
||||
is_recurring=task.is_recurring,
|
||||
started_at=task.started_at,
|
||||
completed_at=task.completed_at,
|
||||
error_message=task.error_message,
|
||||
result_metrics=task.result_metrics,
|
||||
model_path=task.model_path,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/tasks/{task_id}/cancel",
|
||||
response_model=TrainingTaskResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
|
||||
},
|
||||
summary="Cancel training task",
|
||||
description="Cancel a pending or scheduled training task.",
|
||||
)
|
||||
async def cancel_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Cancel a training task."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = tasks.get_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
if task.status not in ("pending", "scheduled"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Cannot cancel task with status: {task.status}",
|
||||
)
|
||||
|
||||
success = tasks.cancel(task_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to cancel training task",
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.CANCELLED,
|
||||
message="Training task cancelled successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/logs",
|
||||
response_model=TrainingLogsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
},
|
||||
summary="Get training logs",
|
||||
description="Get training task logs.",
|
||||
)
|
||||
async def get_training_logs(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=500, description="Maximum logs to return"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingLogsResponse:
|
||||
"""Get training logs."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = tasks.get_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
logs = tasks.get_logs(task_id, limit, offset)
|
||||
|
||||
items = [
|
||||
TrainingLogItem(
|
||||
level=log.level,
|
||||
message=log.message,
|
||||
details=log.details,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
|
||||
return TrainingLogsResponse(
|
||||
task_id=task_id,
|
||||
logs=items,
|
||||
)
|
||||
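A minimal sketch of driving these task endpoints from a client: create a task, poll its status, then read the logs. The route prefix, the Bearer header, and the "train" task_type value are assumptions (the TrainingType members are defined elsewhere).

import time

import requests

BASE = "http://localhost:8000/api/v1/admin/training"   # assumed mount prefix
HEADERS = {"Authorization": "Bearer <admin-token>"}     # assumed auth scheme

task = requests.post(
    f"{BASE}/tasks",
    headers=HEADERS,
    json={"name": "nightly-retrain", "task_type": "train"},  # task_type value assumed
).json()
task_id = task["task_id"]

while True:
    detail = requests.get(f"{BASE}/tasks/{task_id}", headers=HEADERS).json()
    if detail["status"] in ("completed", "failed", "cancelled"):
        break
    time.sleep(30)

logs = requests.get(f"{BASE}/tasks/{task_id}/logs", headers=HEADERS).json()
for entry in logs["logs"]:
    print(entry["level"], entry["message"])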
248
packages/backend/backend/web/api/v1/batch/routes.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Batch Upload API Routes
|
||||
|
||||
Endpoints for batch uploading documents via ZIP files with CSV metadata.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from backend.data.repositories import BatchUploadRepository
|
||||
from backend.web.core.auth import validate_admin_token
|
||||
from backend.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE
|
||||
from backend.web.workers.batch_queue import BatchTask, get_batch_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global repository instance
|
||||
_batch_repo: BatchUploadRepository | None = None
|
||||
|
||||
|
||||
def get_batch_repository() -> BatchUploadRepository:
|
||||
"""Get the BatchUploadRepository instance."""
|
||||
global _batch_repo
|
||||
if _batch_repo is None:
|
||||
_batch_repo = BatchUploadRepository()
|
||||
return _batch_repo
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_batch(
|
||||
file: UploadFile = File(...),
|
||||
upload_source: str = Form(default="ui"),
|
||||
async_mode: bool = Form(default=True),
|
||||
auto_label: bool = Form(default=True),
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
|
||||
) -> dict:
|
||||
"""Upload a batch of documents via ZIP file.
|
||||
|
||||
The ZIP file can contain:
|
||||
- Multiple PDF files
|
||||
- Optional CSV file with field values for auto-labeling
|
||||
|
||||
CSV format:
|
||||
- Required column: DocumentId (matches PDF filename without extension)
|
||||
- Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
|
||||
OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number
|
||||
|
||||
Args:
file: ZIP file upload
upload_source: Upload source (ui or api)
async_mode: If true, queue the upload and return 202 immediately
auto_label: If true, auto-label documents using the CSV field values
admin_token: Admin authentication token
batch_repo: Batch upload repository

Returns:
Batch upload result with batch_id and status
"""
|
||||
if not file.filename or not file.filename.lower().endswith('.zip'):
|
||||
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
|
||||
|
||||
# Check compressed size
|
||||
if file.size and file.size > MAX_COMPRESSED_SIZE:
|
||||
max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File size exceeds {max_mb:.0f}MB limit"
|
||||
)
|
||||
|
||||
try:
|
||||
# Read file content
|
||||
zip_content = await file.read()
|
||||
|
||||
# Additional security validation before processing
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
|
||||
# Quick validation of ZIP structure
|
||||
test_zip.testzip()
|
||||
except zipfile.BadZipFile:
|
||||
raise HTTPException(status_code=400, detail="Invalid ZIP file format")
|
||||
|
||||
if async_mode:
|
||||
# Async mode: Queue task and return immediately
|
||||
from uuid import uuid4
|
||||
|
||||
batch_id = uuid4()
|
||||
|
||||
# Create batch task for background processing
|
||||
task = BatchTask(
|
||||
batch_id=batch_id,
|
||||
admin_token=admin_token,
|
||||
zip_content=zip_content,
|
||||
zip_filename=file.filename,
|
||||
upload_source=upload_source,
|
||||
auto_label=auto_label,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Submit to queue
|
||||
queue = get_batch_queue()
|
||||
if not queue.submit(task):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Processing queue is full. Please try again later."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch upload queued: batch_id={batch_id}, "
|
||||
f"filename={file.filename}, async_mode=True"
|
||||
)
|
||||
|
||||
# Return 202 Accepted with batch_id and status URL
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={
|
||||
"status": "accepted",
|
||||
"batch_id": str(batch_id),
|
||||
"message": "Batch upload queued for processing",
|
||||
"status_url": f"/api/v1/admin/batch/status/{batch_id}",
|
||||
"queue_depth": queue.get_queue_depth(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Sync mode: Process immediately and return results
|
||||
service = BatchUploadService(batch_repo)
|
||||
result = service.process_zip_upload(
|
||||
admin_token=admin_token,
|
||||
zip_filename=file.filename,
|
||||
zip_content=zip_content,
|
||||
upload_source=upload_source,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch upload completed: batch_id={result.get('batch_id')}, "
|
||||
f"status={result.get('status')}, files={result.get('successful_files')}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing batch upload: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to process batch upload. Please contact support."
|
||||
)
|
||||
|
||||
|
||||
@router.get("/status/{batch_id}")
|
||||
async def get_batch_status(
|
||||
batch_id: str,
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
|
||||
) -> dict:
|
||||
"""Get batch upload status and file processing details.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
admin_token: Admin authentication token
|
||||
batch_repo: Batch upload repository
|
||||
|
||||
Returns:
|
||||
Batch status with file processing details
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
batch_uuid = UUID(batch_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid batch ID format")
|
||||
|
||||
# Check batch exists and verify ownership
|
||||
batch = batch_repo.get(batch_uuid)
|
||||
if not batch:
|
||||
raise HTTPException(status_code=404, detail="Batch not found")
|
||||
|
||||
# CRITICAL: Verify ownership
|
||||
if batch.admin_token != admin_token:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You do not have access to this batch"
|
||||
)
|
||||
|
||||
# Now safe to return details
|
||||
service = BatchUploadService(batch_repo)
|
||||
result = service.get_batch_status(batch_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_batch_uploads(
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> dict:
|
||||
"""List batch uploads for the current admin token.
|
||||
|
||||
Args:
|
||||
admin_token: Admin authentication token
|
||||
batch_repo: Batch upload repository
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of batch uploads
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if limit < 1 or limit > 100:
|
||||
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
|
||||
if offset < 0:
|
||||
raise HTTPException(status_code=400, detail="Offset must be non-negative")
|
||||
|
||||
# Get batch uploads filtered by admin token
|
||||
batches, total = batch_repo.get_paginated(
|
||||
admin_token=admin_token,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return {
|
||||
"batches": [
|
||||
{
|
||||
"batch_id": str(b.batch_id),
|
||||
"filename": b.filename,
|
||||
"status": b.status,
|
||||
"total_files": b.total_files,
|
||||
"successful_files": b.successful_files,
|
||||
"failed_files": b.failed_files,
|
||||
"created_at": b.created_at.isoformat() if b.created_at else None,
|
||||
"completed_at": b.completed_at.isoformat() if b.completed_at else None,
|
||||
}
|
||||
for b in batches
|
||||
],
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
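A minimal sketch of a batch upload client that builds the ZIP in memory: one PDF plus a CSV keyed by DocumentId, submitted in async mode and then polled via the returned status_url. The Bearer header and the CSV filename inside the ZIP are assumptions.

import io
import zipfile

import requests

HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed auth scheme

# Build the ZIP in memory: one PDF plus a CSV keyed by DocumentId.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.write("invoice_001.pdf")  # local file, hypothetical
    zf.writestr(
        "metadata.csv",  # CSV filename inside the ZIP is assumed
        "DocumentId,InvoiceNumber,Amount\ninvoice_001,12345,1500.00\n",
    )
buf.seek(0)

resp = requests.post(
    "http://localhost:8000/api/v1/admin/batch/upload",
    headers=HEADERS,
    files={"file": ("batch.zip", buf, "application/zip")},
    data={"async_mode": "true", "auto_label": "true"},
)
print(resp.status_code)  # 202 when queued in async mode

status_url = resp.json()["status_url"]
status = requests.get(f"http://localhost:8000{status_url}", headers=HEADERS).json()
print(status)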
16
packages/backend/backend/web/api/v1/public/__init__.py
Normal file
@@ -0,0 +1,16 @@
"""
Public API v1

Customer-facing endpoints for inference, async processing, and labeling.
"""

from backend.web.api.v1.public.inference import create_inference_router
from backend.web.api.v1.public.async_api import create_async_router, set_async_service
from backend.web.api.v1.public.labeling import create_labeling_router

__all__ = [
    "create_inference_router",
    "create_async_router",
    "set_async_service",
    "create_labeling_router",
]
372
packages/backend/backend/web/api/v1/public/async_api.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
Async API Routes
|
||||
|
||||
FastAPI endpoints for async invoice processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||
|
||||
from backend.web.dependencies import (
|
||||
ApiKeyDep,
|
||||
AsyncDBDep,
|
||||
PollRateLimitDep,
|
||||
SubmitRateLimitDep,
|
||||
)
|
||||
from backend.web.schemas.inference import (
|
||||
AsyncRequestItem,
|
||||
AsyncRequestsListResponse,
|
||||
AsyncResultResponse,
|
||||
AsyncStatus,
|
||||
AsyncStatusResponse,
|
||||
AsyncSubmitResponse,
|
||||
DetectionResult,
|
||||
InferenceResult,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
|
||||
|
||||
def _validate_request_id(request_id: str) -> None:
|
||||
"""Validate that request_id is a valid UUID format."""
|
||||
try:
|
||||
UUID(request_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid request ID format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global reference to async processing service (set during app startup)
|
||||
_async_service = None
|
||||
|
||||
|
||||
def set_async_service(service) -> None:
|
||||
"""Set the async processing service instance."""
|
||||
global _async_service
|
||||
_async_service = service
|
||||
|
||||
|
||||
def get_async_service():
|
||||
"""Get the async processing service instance."""
|
||||
if _async_service is None:
|
||||
raise RuntimeError("AsyncProcessingService not initialized")
|
||||
return _async_service
|
||||
|
||||
|
||||
def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
|
||||
"""Create async API router."""
|
||||
router = APIRouter(prefix="/async", tags=["Async Processing"])
|
||||
|
||||
@router.post(
|
||||
"/submit",
|
||||
response_model=AsyncSubmitResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
|
||||
503: {"model": ErrorResponse, "description": "Queue full"},
|
||||
},
|
||||
summary="Submit PDF for async processing",
|
||||
description="Submit a PDF or image file for asynchronous processing. "
|
||||
"Returns a request_id that can be used to poll for results.",
|
||||
)
|
||||
async def submit_document(
|
||||
api_key: SubmitRateLimitDep,
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
) -> AsyncSubmitResponse:
|
||||
"""Submit a document for async processing."""
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required")
|
||||
|
||||
# Validate file extension
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {file_ext}. "
|
||||
f"Allowed: {', '.join(allowed_extensions)}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(status_code=400, detail="Failed to read file")
|
||||
|
||||
# Check file size (get from config via service)
|
||||
service = get_async_service()
|
||||
max_size = service._async_config.max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File too large. Maximum size: "
|
||||
f"{service._async_config.max_file_size_mb}MB",
|
||||
)
|
||||
|
||||
# Submit request
|
||||
result = service.submit_request(
|
||||
api_key=api_key,
|
||||
file_content=content,
|
||||
filename=file.filename,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
if "queue" in (result.error or "").lower():
|
||||
raise HTTPException(status_code=503, detail=result.error)
|
||||
raise HTTPException(status_code=500, detail=result.error)
|
||||
|
||||
return AsyncSubmitResponse(
|
||||
status="accepted",
|
||||
message="Request submitted for processing",
|
||||
request_id=result.request_id,
|
||||
estimated_wait_seconds=result.estimated_wait_seconds,
|
||||
poll_url=f"/api/v1/async/status/{result.request_id}",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/status/{request_id}",
|
||||
response_model=AsyncStatusResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
429: {"model": ErrorResponse, "description": "Polling too frequently"},
|
||||
},
|
||||
summary="Get request status",
|
||||
description="Get the current processing status of an async request.",
|
||||
)
|
||||
async def get_status(
|
||||
request_id: str,
|
||||
api_key: PollRateLimitDep,
|
||||
db: AsyncDBDep,
|
||||
) -> AsyncStatusResponse:
|
||||
"""Get the status of an async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database (validates API key ownership)
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Get queue position for pending requests
|
||||
position = None
|
||||
if request.status == "pending":
|
||||
position = db.get_queue_position(request_id)
|
||||
|
||||
# Build result URL for completed requests
|
||||
result_url = None
|
||||
if request.status == "completed":
|
||||
result_url = f"/api/v1/async/result/{request_id}"
|
||||
|
||||
return AsyncStatusResponse(
|
||||
request_id=str(request.request_id),
|
||||
status=AsyncStatus(request.status),
|
||||
filename=request.filename,
|
||||
created_at=request.created_at,
|
||||
started_at=request.started_at,
|
||||
completed_at=request.completed_at,
|
||||
position_in_queue=position,
|
||||
error_message=request.error_message,
|
||||
result_url=result_url,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/result/{request_id}",
|
||||
response_model=AsyncResultResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
409: {"model": ErrorResponse, "description": "Request not completed"},
|
||||
429: {"model": ErrorResponse, "description": "Polling too frequently"},
|
||||
},
|
||||
summary="Get extraction results",
|
||||
description="Get the extraction results for a completed async request.",
|
||||
)
|
||||
async def get_result(
|
||||
request_id: str,
|
||||
api_key: PollRateLimitDep,
|
||||
db: AsyncDBDep,
|
||||
) -> AsyncResultResponse:
|
||||
"""Get the results of a completed async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database (validates API key ownership)
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Check if completed or failed
|
||||
if request.status not in ("completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Request not yet completed. Current status: {request.status}",
|
||||
)
|
||||
|
||||
# Build inference result from stored data
|
||||
inference_result = None
|
||||
if request.result:
|
||||
# Convert detections to DetectionResult objects
|
||||
detections = []
|
||||
for d in request.result.get("detections", []):
|
||||
detections.append(DetectionResult(
|
||||
field=d.get("field", ""),
|
||||
confidence=d.get("confidence", 0.0),
|
||||
bbox=d.get("bbox", [0, 0, 0, 0]),
|
||||
))
|
||||
|
||||
inference_result = InferenceResult(
|
||||
document_id=request.result.get("document_id", str(request.request_id)[:8]),
|
||||
success=request.result.get("success", False),
|
||||
document_type=request.result.get("document_type", "invoice"),
|
||||
fields=request.result.get("fields", {}),
|
||||
confidence=request.result.get("confidence", {}),
|
||||
detections=detections,
|
||||
processing_time_ms=request.processing_time_ms or 0.0,
|
||||
errors=request.result.get("errors", []),
|
||||
)
|
||||
|
||||
# Build visualization URL
|
||||
viz_url = None
|
||||
if request.visualization_path:
|
||||
viz_url = f"/api/v1/results/{request.visualization_path}"
|
||||
|
||||
return AsyncResultResponse(
|
||||
request_id=str(request.request_id),
|
||||
status=AsyncStatus(request.status),
|
||||
processing_time_ms=request.processing_time_ms or 0.0,
|
||||
result=inference_result,
|
||||
visualization_url=viz_url,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/requests",
|
||||
response_model=AsyncRequestsListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
},
|
||||
summary="List requests",
|
||||
description="List all async requests for the authenticated API key.",
|
||||
)
|
||||
async def list_requests(
|
||||
api_key: ApiKeyDep,
|
||||
db: AsyncDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status (pending, processing, completed, failed)"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Maximum number of results"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Pagination offset"),
|
||||
] = 0,
|
||||
) -> AsyncRequestsListResponse:
|
||||
"""List all requests for the authenticated API key."""
|
||||
# Validate status filter
|
||||
if status and status not in ("pending", "processing", "completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status filter: {status}. "
|
||||
"Must be one of: pending, processing, completed, failed",
|
||||
)
|
||||
|
||||
# Get requests from database
|
||||
requests, total = db.get_requests_by_api_key(
|
||||
api_key=api_key,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Convert to response items
|
||||
items = [
|
||||
AsyncRequestItem(
|
||||
request_id=str(r.request_id),
|
||||
status=AsyncStatus(r.status),
|
||||
filename=r.filename,
|
||||
file_size=r.file_size,
|
||||
created_at=r.created_at,
|
||||
completed_at=r.completed_at,
|
||||
)
|
||||
for r in requests
|
||||
]
|
||||
|
||||
return AsyncRequestsListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
requests=items,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/requests/{request_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
|
||||
},
|
||||
summary="Cancel/delete request",
|
||||
description="Cancel a pending request or delete a completed/failed request.",
|
||||
)
|
||||
async def delete_request(
|
||||
request_id: str,
|
||||
api_key: ApiKeyDep,
|
||||
db: AsyncDBDep,
|
||||
) -> dict:
|
||||
"""Delete or cancel an async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Cannot delete processing requests
|
||||
if request.status == "processing":
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot delete a request that is currently processing",
|
||||
)
|
||||
|
||||
# Delete from database (will cascade delete related records)
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"DELETE FROM async_requests WHERE request_id = %s",
|
||||
(request_id,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"request_id": request_id,
|
||||
"message": "Request deleted successfully",
|
||||
}
|
||||
|
||||
return router
|
||||
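A minimal sketch of the submit/poll/result flow these endpoints implement. The X-API-Key header name is an assumption; the /api/v1/async paths match the poll_url built above.

import time

import requests

BASE = "http://localhost:8000/api/v1/async"
HEADERS = {"X-API-Key": "<api-key>"}  # assumed header name

with open("invoice.pdf", "rb") as f:  # local file, hypothetical
    submitted = requests.post(
        f"{BASE}/submit",
        headers=HEADERS,
        files={"file": ("invoice.pdf", f, "application/pdf")},
    ).json()

request_id = submitted["request_id"]
while True:
    status = requests.get(f"{BASE}/status/{request_id}", headers=HEADERS).json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(5)  # PollRateLimitDep throttles polling, so keep this modest

result = requests.get(f"{BASE}/result/{request_id}", headers=HEADERS).json()
print(result["status"], result["result"]["fields"] if result["result"] else None)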
194
packages/backend/backend/web/api/v1/public/inference.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Inference API Routes
|
||||
|
||||
FastAPI route definitions for the inference API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from backend.web.schemas.inference import (
|
||||
DetectionResult,
|
||||
HealthResponse,
|
||||
InferenceResponse,
|
||||
InferenceResult,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.web.services import InferenceService
|
||||
from backend.web.config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_inference_router(
|
||||
inference_service: "InferenceService",
|
||||
storage_config: "StorageConfig",
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Create API router with inference endpoints.
|
||||
|
||||
Args:
|
||||
inference_service: Inference service instance
|
||||
storage_config: Storage configuration
|
||||
|
||||
Returns:
|
||||
Configured APIRouter
|
||||
"""
|
||||
router = APIRouter(prefix="/api/v1", tags=["inference"])
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health_check() -> HealthResponse:
|
||||
"""Check service health status."""
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
model_loaded=inference_service.is_initialized,
|
||||
gpu_available=inference_service.gpu_available,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/infer",
|
||||
response_model=InferenceResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
500: {"model": ErrorResponse, "description": "Processing error"},
|
||||
},
|
||||
)
|
||||
async def infer_document(
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
) -> InferenceResponse:
|
||||
"""
|
||||
Process a document and extract invoice fields.
|
||||
|
||||
Accepts PDF or image files (PNG, JPG, JPEG).
|
||||
Returns extracted field values with confidence scores.
|
||||
"""
|
||||
# Validate file extension
|
||||
if not file.filename:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Filename is required",
|
||||
)
|
||||
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in storage_config.allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
|
||||
)
|
||||
|
||||
# Generate document ID
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Get storage helper and uploads directory
|
||||
storage = get_storage_helper()
|
||||
uploads_dir = storage.get_uploads_base_path(subfolder="inference")
|
||||
if uploads_dir is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Storage not configured for local access",
|
||||
)
|
||||
|
||||
# Save uploaded file to temporary location for processing
|
||||
upload_path = uploads_dir / f"{doc_id}{file_ext}"
|
||||
try:
|
||||
with open(upload_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save uploaded file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to save uploaded file",
|
||||
)
|
||||
|
||||
try:
|
||||
# Process based on file type
|
||||
if file_ext == ".pdf":
|
||||
service_result = inference_service.process_pdf(
|
||||
upload_path, document_id=doc_id
|
||||
)
|
||||
else:
|
||||
service_result = inference_service.process_image(
|
||||
upload_path, document_id=doc_id
|
||||
)
|
||||
|
||||
# Build response
|
||||
viz_url = None
|
||||
if service_result.visualization_path:
|
||||
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
|
||||
|
||||
inference_result = InferenceResult(
|
||||
document_id=service_result.document_id,
|
||||
success=service_result.success,
|
||||
document_type=service_result.document_type,
|
||||
fields=service_result.fields,
|
||||
confidence=service_result.confidence,
|
||||
detections=[
|
||||
DetectionResult(**d) for d in service_result.detections
|
||||
],
|
||||
processing_time_ms=service_result.processing_time_ms,
|
||||
visualization_url=viz_url,
|
||||
errors=service_result.errors,
|
||||
)
|
||||
|
||||
return InferenceResponse(
|
||||
status="success" if service_result.success else "partial",
|
||||
message=f"Processed document {doc_id}",
|
||||
result=inference_result,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup uploaded file
|
||||
upload_path.unlink(missing_ok=True)
|
||||
|
||||
@router.get("/results/{filename}", response_model=None)
|
||||
async def get_result_image(filename: str) -> FileResponse:
|
||||
"""Get visualization result image."""
|
||||
storage = get_storage_helper()
|
||||
file_path = storage.get_result_local_path(filename)
|
||||
|
||||
if file_path is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
media_type="image/png",
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
@router.delete("/results/{filename}")
|
||||
async def delete_result(filename: str) -> dict:
|
||||
"""Delete a result file."""
|
||||
storage = get_storage_helper()
|
||||
|
||||
if not storage.result_exists(filename):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
)
|
||||
|
||||
storage.delete_result(filename)
|
||||
return {"status": "deleted", "filename": filename}
|
||||
|
||||
return router
|
||||
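A minimal sketch of a synchronous call to the /api/v1/infer endpoint above; no authentication is shown for this router, and the local file name is a placeholder.

import requests

with open("invoice.pdf", "rb") as f:  # local file, hypothetical
    resp = requests.post(
        "http://localhost:8000/api/v1/infer",
        files={"file": ("invoice.pdf", f, "application/pdf")},
    )
resp.raise_for_status()

result = resp.json()["result"]
for field, value in result["fields"].items():
    print(f"{field}: {value} (confidence {result['confidence'].get(field)})")
# The annotated page, when produced, is served from result["visualization_url"].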
197
packages/backend/backend/web/api/v1/public/labeling.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Labeling API Routes
|
||||
|
||||
FastAPI endpoints for pre-labeling documents with expected field values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
|
||||
from backend.data.repositories import DocumentRepository
|
||||
from backend.web.schemas.labeling import PreLabelResponse
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.web.services import InferenceService
|
||||
from backend.web.config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation using StorageHelper."""
|
||||
import fitz
|
||||
|
||||
storage = get_storage_helper()
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
page = pdf_doc[page_num]
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# Save to storage using StorageHelper
|
||||
image_bytes = pix.tobytes("png")
|
||||
storage.save_admin_image(document_id, page_num + 1, image_bytes)
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
|
||||
def get_doc_repository() -> DocumentRepository:
|
||||
"""Get document repository instance."""
|
||||
return DocumentRepository()
|
||||
|
||||
|
||||
def create_labeling_router(
|
||||
inference_service: "InferenceService",
|
||||
storage_config: "StorageConfig",
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Create API router with labeling endpoints.
|
||||
|
||||
Args:
|
||||
inference_service: Inference service instance
|
||||
storage_config: Storage configuration
|
||||
|
||||
Returns:
|
||||
Configured APIRouter
|
||||
"""
|
||||
router = APIRouter(prefix="/api/v1", tags=["labeling"])
|
||||
|
||||
@router.post(
|
||||
"/pre-label",
|
||||
response_model=PreLabelResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file or field values"},
|
||||
500: {"model": ErrorResponse, "description": "Processing error"},
|
||||
},
|
||||
summary="Pre-label document with expected values",
|
||||
description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
|
||||
)
|
||||
async def pre_label(
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
field_values: str = Form(
|
||||
...,
|
||||
description="JSON object with expected field values. "
|
||||
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
|
||||
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
|
||||
),
|
||||
doc_repo: DocumentRepository = Depends(get_doc_repository),
|
||||
) -> PreLabelResponse:
|
||||
"""
|
||||
Upload a document with expected field values for pre-labeling.
|
||||
|
||||
Returns document_id which can be used to retrieve results later.
|
||||
|
||||
Example field_values JSON:
|
||||
```json
|
||||
{
|
||||
"InvoiceNumber": "12345",
|
||||
"Amount": "1500.00",
|
||||
"Bankgiro": "123-4567",
|
||||
"OCR": "1234567890"
|
||||
}
|
||||
```
|
||||
"""
|
||||
# Parse field_values JSON
|
||||
try:
|
||||
expected_values = json.loads(field_values)
|
||||
if not isinstance(expected_values, dict):
|
||||
raise ValueError("field_values must be a JSON object")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid JSON in field_values: {e}",
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
if not file.filename:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Filename is required",
|
||||
)
|
||||
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in storage_config.allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read file",
|
||||
)
|
||||
|
||||
# Get page count for PDF
|
||||
page_count = 1
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import fitz
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
page_count = len(pdf_doc)
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get PDF page count: {e}")
|
||||
|
||||
# Create document record with field_values
|
||||
document_id = doc_repo.create(
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
upload_source="api",
|
||||
csv_field_values=expected_values,
|
||||
)
|
||||
|
||||
# Save file to storage using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
filename = f"{document_id}{file_ext}"
|
||||
try:
|
||||
storage_path = storage.save_raw_pdf(content, filename)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to save file",
|
||||
)
|
||||
|
||||
# Update file path in database (using storage path)
|
||||
doc_repo.update_file_path(document_id, storage_path)
|
||||
|
||||
# Convert PDF to images for annotation UI
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
# Trigger auto-labeling
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="pending",
|
||||
)
|
||||
|
||||
logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")
|
||||
|
||||
return PreLabelResponse(document_id=document_id)
|
||||
|
||||
return router
|
||||
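A minimal sketch of a pre-label upload: the expected values travel as a JSON string in the field_values form field alongside the file, and the response carries the document_id used to fetch results later.

import json

import requests

expected = {
    "InvoiceNumber": "12345",
    "Amount": "1500.00",
    "Bankgiro": "123-4567",
    "OCR": "1234567890",
}

with open("invoice.pdf", "rb") as f:  # local file, hypothetical
    resp = requests.post(
        "http://localhost:8000/api/v1/pre-label",
        files={"file": ("invoice.pdf", f, "application/pdf")},
        data={"field_values": json.dumps(expected)},
    )
print(resp.json()["document_id"])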
953
packages/backend/backend/web/app.py
Normal file
@@ -0,0 +1,953 @@
|
||||
"""
|
||||
FastAPI Application Factory
|
||||
|
||||
Creates and configures the FastAPI application.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from .config import AppConfig, default_config
|
||||
from backend.web.services import InferenceService
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
# Public API imports
|
||||
from backend.web.api.v1.public import (
|
||||
create_inference_router,
|
||||
create_async_router,
|
||||
set_async_service,
|
||||
create_labeling_router,
|
||||
)
|
||||
|
||||
# Async processing imports
|
||||
from backend.data.async_request_db import AsyncRequestDB
|
||||
from backend.web.workers.async_queue import AsyncTaskQueue
|
||||
from backend.web.services.async_processing import AsyncProcessingService
|
||||
from backend.web.dependencies import init_dependencies
|
||||
from backend.web.core.rate_limiter import RateLimiter
|
||||
|
||||
# Admin API imports
|
||||
from backend.web.api.v1.admin import (
|
||||
create_annotation_router,
|
||||
create_augmentation_router,
|
||||
create_auth_router,
|
||||
create_documents_router,
|
||||
create_locks_router,
|
||||
create_training_router,
|
||||
)
|
||||
from backend.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
from backend.web.core.scheduler import start_scheduler, stop_scheduler
|
||||
from backend.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
|
||||
|
||||
# Batch upload imports
|
||||
from backend.web.api.v1.batch.routes import router as batch_upload_router
|
||||
from backend.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from backend.web.services.batch_upload import BatchUploadService
|
||||
from backend.data.repositories import ModelVersionRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
"""
|
||||
Create and configure FastAPI application.
|
||||
|
||||
Args:
|
||||
config: Application configuration. Uses default if not provided.
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application
|
||||
"""
|
||||
config = config or default_config
|
||||
|
||||
# Create model path resolver that reads from database
|
||||
def get_active_model_path():
|
||||
"""Resolve active model path from database."""
|
||||
try:
|
||||
model_repo = ModelVersionRepository()
|
||||
active_model = model_repo.get_active()
|
||||
if active_model and active_model.model_path:
|
||||
return active_model.model_path
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get active model from database: {e}")
|
||||
return None
|
||||
|
||||
# Create inference service with database model resolver
|
||||
inference_service = InferenceService(
|
||||
model_config=config.model,
|
||||
storage_config=config.storage,
|
||||
model_path_resolver=get_active_model_path,
|
||||
)
|
||||
|
||||
# Create async processing components
|
||||
async_db = AsyncRequestDB()
|
||||
rate_limiter = RateLimiter(async_db)
|
||||
task_queue = AsyncTaskQueue(
|
||||
max_size=config.async_processing.queue_max_size,
|
||||
worker_count=config.async_processing.worker_count,
|
||||
)
|
||||
async_service = AsyncProcessingService(
|
||||
inference_service=inference_service,
|
||||
db=async_db,
|
||||
queue=task_queue,
|
||||
rate_limiter=rate_limiter,
|
||||
async_config=config.async_processing,
|
||||
storage_config=config.storage,
|
||||
)
|
||||
|
||||
# Initialize dependencies for FastAPI
|
||||
init_dependencies(async_db, rate_limiter)
|
||||
set_async_service(async_service)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan manager."""
|
||||
logger.info("Starting Invoice Inference API...")
|
||||
|
||||
# Initialize async request database tables
|
||||
try:
|
||||
async_db.create_tables()
|
||||
logger.info("Async database tables ready")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize async database: {e}")
|
||||
|
||||
# Initialize admin database tables (admin_tokens, admin_documents, training_tasks, etc.)
|
||||
try:
|
||||
from backend.data.database import create_db_and_tables
|
||||
create_db_and_tables()
|
||||
logger.info("Admin database tables ready")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize admin database: {e}")
|
||||
|
||||
# Initialize inference service on startup
|
||||
try:
|
||||
inference_service.initialize()
|
||||
logger.info("Inference service ready")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize inference service: {e}")
|
||||
# Continue anyway - service will retry on first request
|
||||
|
||||
# Start async processing service
|
||||
try:
|
||||
async_service.start()
|
||||
logger.info("Async processing service started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start async processing: {e}")
|
||||
|
||||
# Start batch upload queue
|
||||
try:
|
||||
batch_service = BatchUploadService()
|
||||
init_batch_queue(batch_service)
|
||||
logger.info("Batch upload queue started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start batch upload queue: {e}")
|
||||
|
||||
# Start training scheduler
|
||||
try:
|
||||
start_scheduler()
|
||||
logger.info("Training scheduler started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start training scheduler: {e}")
|
||||
|
||||
# Start auto-label scheduler
|
||||
try:
|
||||
start_autolabel_scheduler()
|
||||
logger.info("AutoLabel scheduler started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start autolabel scheduler: {e}")
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down Invoice Inference API...")
|
||||
|
||||
# Stop auto-label scheduler
|
||||
try:
|
||||
stop_autolabel_scheduler()
|
||||
logger.info("AutoLabel scheduler stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping autolabel scheduler: {e}")
|
||||
|
||||
# Stop training scheduler
|
||||
try:
|
||||
stop_scheduler()
|
||||
logger.info("Training scheduler stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping training scheduler: {e}")
|
||||
|
||||
# Stop batch upload queue
|
||||
try:
|
||||
shutdown_batch_queue()
|
||||
logger.info("Batch upload queue stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping batch upload queue: {e}")
|
||||
|
||||
# Stop async processing service
|
||||
try:
|
||||
async_service.stop(timeout=30.0)
|
||||
logger.info("Async processing service stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping async service: {e}")
|
||||
|
||||
# Close database connection
|
||||
try:
|
||||
async_db.close()
|
||||
logger.info("Database connection closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing database: {e}")
|
||||
|
||||
# Create FastAPI app
|
||||
# The inference service is attached to app.state after the app is created
# so routes (e.g. model reload) can access it; see app.state.inference_service below.
|
||||
|
||||
app = FastAPI(
|
||||
title="Invoice Field Extraction API",
|
||||
description="""
|
||||
REST API for extracting fields from Swedish invoices.
|
||||
|
||||
## Features
|
||||
- YOLO-based field detection
|
||||
- OCR text extraction
|
||||
- Field normalization and validation
|
||||
- Visualization of detections
|
||||
|
||||
## Supported Fields
|
||||
- InvoiceNumber
|
||||
- InvoiceDate
|
||||
- InvoiceDueDate
|
||||
- OCR (reference number)
|
||||
- Bankgiro
|
||||
- Plusgiro
|
||||
- Amount
|
||||
- supplier_org_number (Swedish organization number)
|
||||
- customer_number
|
||||
- payment_line (machine-readable payment code)
|
||||
""",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Mount static files for results using StorageHelper
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
if results_dir:
|
||||
app.mount(
|
||||
"/static/results",
|
||||
StaticFiles(directory=str(results_dir)),
|
||||
name="results",
|
||||
)
|
||||
else:
|
||||
logger.warning("Could not mount static results directory: local storage not available")
|
||||
|
||||
# Include public API routes
|
||||
inference_router = create_inference_router(inference_service, config.storage)
|
||||
app.include_router(inference_router)
|
||||
|
||||
async_router = create_async_router(config.storage.allowed_extensions)
|
||||
app.include_router(async_router, prefix="/api/v1")
|
||||
|
||||
labeling_router = create_labeling_router(inference_service, config.storage)
|
||||
app.include_router(labeling_router)
|
||||
|
||||
# Include admin API routes
|
||||
auth_router = create_auth_router()
|
||||
app.include_router(auth_router, prefix="/api/v1")
|
||||
|
||||
documents_router = create_documents_router(config.storage)
|
||||
app.include_router(documents_router, prefix="/api/v1")
|
||||
|
||||
locks_router = create_locks_router()
|
||||
app.include_router(locks_router, prefix="/api/v1")
|
||||
|
||||
annotation_router = create_annotation_router()
|
||||
app.include_router(annotation_router, prefix="/api/v1")
|
||||
|
||||
training_router = create_training_router()
|
||||
app.include_router(training_router, prefix="/api/v1")
|
||||
|
||||
augmentation_router = create_augmentation_router()
|
||||
app.include_router(augmentation_router, prefix="/api/v1/admin")
|
||||
|
||||
# Include dashboard routes
|
||||
dashboard_router = create_dashboard_router()
|
||||
app.include_router(dashboard_router, prefix="/api/v1")
|
||||
|
||||
# Include batch upload routes
|
||||
app.include_router(batch_upload_router)
|
||||
|
||||
# Store inference service in app state for access by routes
|
||||
app.state.inference_service = inference_service
|
||||
|
||||
# Root endpoint - serve HTML UI
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root() -> str:
|
||||
"""Serve the web UI."""
|
||||
return get_html_ui()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_html_ui() -> str:
|
||||
"""Generate HTML UI for the web application."""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Invoice Field Extraction</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
color: white;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
header h1 {
|
||||
font-size: 2.5rem;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
header p {
|
||||
opacity: 0.9;
|
||||
font-size: 1.1rem;
|
||||
}
|
||||
|
||||
.main-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 16px;
|
||||
padding: 24px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.card h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
font-size: 1.3rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.upload-card {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 20px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.upload-card h2 {
|
||||
margin-bottom: 0;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.upload-area {
|
||||
border: 2px dashed #ddd;
|
||||
border-radius: 10px;
|
||||
padding: 15px 25px;
|
||||
text-align: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
background: #fafafa;
|
||||
flex: 1;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.upload-area:hover, .upload-area.dragover {
|
||||
border-color: #667eea;
|
||||
background: #f0f4ff;
|
||||
}
|
||||
|
||||
.upload-area.has-file {
|
||||
border-color: #10b981;
|
||||
background: #ecfdf5;
|
||||
}
|
||||
|
||||
.upload-icon {
|
||||
font-size: 24px;
|
||||
display: inline;
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.upload-area p {
|
||||
color: #666;
|
||||
margin: 0;
|
||||
display: inline;
|
||||
}
|
||||
|
||||
.upload-area small {
|
||||
color: #999;
|
||||
display: block;
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
#file-input {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.file-name {
|
||||
margin-top: 15px;
|
||||
padding: 10px 15px;
|
||||
background: #e0f2fe;
|
||||
border-radius: 8px;
|
||||
color: #0369a1;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.btn {
|
||||
display: inline-block;
|
||||
padding: 12px 24px;
|
||||
border: none;
|
||||
border-radius: 10px;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover:not(:disabled) {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn-primary:disabled {
|
||||
opacity: 0.6;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.loading {
|
||||
display: none;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.loading.active {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.results {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.results.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.result-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 20px;
|
||||
padding-bottom: 15px;
|
||||
border-bottom: 2px solid #eee;
|
||||
}
|
||||
|
||||
.result-status {
|
||||
padding: 6px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.result-status.success {
|
||||
background: #dcfce7;
|
||||
color: #166534;
|
||||
}
|
||||
|
||||
.result-status.partial {
|
||||
background: #fef3c7;
|
||||
color: #92400e;
|
||||
}
|
||||
|
||||
.result-status.error {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
}
|
||||
|
||||
.fields-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.field-item {
|
||||
padding: 12px;
|
||||
background: #f8fafc;
|
||||
border-radius: 10px;
|
||||
border-left: 4px solid #667eea;
|
||||
}
|
||||
|
||||
.field-item label {
|
||||
display: block;
|
||||
font-size: 0.75rem;
|
||||
color: #64748b;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.field-item .value {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 600;
|
||||
color: #1e293b;
|
||||
}
|
||||
|
||||
.field-item .confidence {
|
||||
font-size: 0.75rem;
|
||||
color: #10b981;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.visualization {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.visualization img {
|
||||
width: 100%;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.processing-time {
|
||||
text-align: center;
|
||||
color: #64748b;
|
||||
font-size: 0.9rem;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
.cross-validation {
|
||||
background: #f8fafc;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 10px;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.cross-validation h3 {
|
||||
margin: 0 0 10px 0;
|
||||
color: #334155;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.cv-status {
|
||||
font-weight: 600;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 10px;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.cv-status.valid {
|
||||
background: #dcfce7;
|
||||
color: #166534;
|
||||
}
|
||||
|
||||
.cv-status.invalid {
|
||||
background: #fef3c7;
|
||||
color: #92400e;
|
||||
}
|
||||
|
||||
.cv-details {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.cv-item {
|
||||
background: white;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 6px;
|
||||
padding: 6px 12px;
|
||||
font-size: 0.85rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.cv-item.match {
|
||||
border-color: #86efac;
|
||||
background: #f0fdf4;
|
||||
}
|
||||
|
||||
.cv-item.mismatch {
|
||||
border-color: #fca5a5;
|
||||
background: #fef2f2;
|
||||
}
|
||||
|
||||
.cv-icon {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.cv-item.match .cv-icon {
|
||||
color: #16a34a;
|
||||
}
|
||||
|
||||
.cv-item.mismatch .cv-icon {
|
||||
color: #dc2626;
|
||||
}
|
||||
|
||||
.cv-summary {
|
||||
margin-top: 10px;
|
||||
font-size: 0.8rem;
|
||||
color: #64748b;
|
||||
}
|
||||
|
||||
.error-message {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
footer {
|
||||
text-align: center;
|
||||
color: white;
|
||||
opacity: 0.8;
|
||||
margin-top: 30px;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1>📄 Invoice Field Extraction</h1>
|
||||
<p>Upload a Swedish invoice (PDF or image) to extract fields automatically</p>
|
||||
</header>
|
||||
|
||||
<div class="main-content">
|
||||
<!-- Upload Section - Compact -->
|
||||
<div class="card upload-card">
|
||||
<h2>📤 Upload</h2>
|
||||
|
||||
<div class="upload-area" id="upload-area">
|
||||
<span class="upload-icon">📁</span>
|
||||
<p>Drag & drop or <strong>click to browse</strong></p>
|
||||
<small>PDF, PNG, JPG (max 50MB)</small>
|
||||
<input type="file" id="file-input" accept=".pdf,.png,.jpg,.jpeg">
|
||||
</div>
|
||||
|
||||
<div class="file-name" id="file-name" style="display: none;"></div>
|
||||
|
||||
<button class="btn btn-primary" id="submit-btn" disabled>
|
||||
🚀 Extract
|
||||
</button>
|
||||
|
||||
<div class="loading" id="loading">
|
||||
<div class="spinner"></div>
|
||||
<p>Processing...</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Results Section - Full Width -->
|
||||
<div class="card">
|
||||
<h2>📊 Extraction Results</h2>
|
||||
|
||||
<div id="placeholder" style="text-align: center; padding: 30px; color: #999;">
|
||||
<div style="font-size: 48px; margin-bottom: 10px;">🔍</div>
|
||||
<p>Upload a document to see extraction results</p>
|
||||
</div>
|
||||
|
||||
<div class="results" id="results">
|
||||
<div class="result-header">
|
||||
<span>Document: <strong id="doc-id"></strong></span>
|
||||
<span class="result-status" id="result-status"></span>
|
||||
</div>
|
||||
|
||||
<div class="fields-grid" id="fields-grid"></div>
|
||||
|
||||
<div class="processing-time" id="processing-time"></div>
|
||||
|
||||
<div class="cross-validation" id="cross-validation" style="display: none;"></div>
|
||||
|
||||
<div class="error-message" id="error-message" style="display: none;"></div>
|
||||
|
||||
<div class="visualization" id="visualization" style="display: none;">
|
||||
<h3 style="margin-bottom: 10px; color: #333;">🎯 Detection Visualization</h3>
|
||||
<img id="viz-image" src="" alt="Detection visualization">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<footer>
|
||||
<p>Powered by ColaCoder</p>
|
||||
</footer>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const uploadArea = document.getElementById('upload-area');
|
||||
const fileInput = document.getElementById('file-input');
|
||||
const fileName = document.getElementById('file-name');
|
||||
const submitBtn = document.getElementById('submit-btn');
|
||||
const loading = document.getElementById('loading');
|
||||
const placeholder = document.getElementById('placeholder');
|
||||
const results = document.getElementById('results');
|
||||
|
||||
let selectedFile = null;
|
||||
|
||||
// Drag and drop handlers
|
||||
uploadArea.addEventListener('click', () => fileInput.click());
|
||||
|
||||
uploadArea.addEventListener('dragover', (e) => {
|
||||
e.preventDefault();
|
||||
uploadArea.classList.add('dragover');
|
||||
});
|
||||
|
||||
uploadArea.addEventListener('dragleave', () => {
|
||||
uploadArea.classList.remove('dragover');
|
||||
});
|
||||
|
||||
uploadArea.addEventListener('drop', (e) => {
|
||||
e.preventDefault();
|
||||
uploadArea.classList.remove('dragover');
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
handleFile(files[0]);
|
||||
}
|
||||
});
|
||||
|
||||
fileInput.addEventListener('change', (e) => {
|
||||
if (e.target.files.length > 0) {
|
||||
handleFile(e.target.files[0]);
|
||||
}
|
||||
});
|
||||
|
||||
function handleFile(file) {
|
||||
const validTypes = ['.pdf', '.png', '.jpg', '.jpeg'];
|
||||
const ext = '.' + file.name.split('.').pop().toLowerCase();
|
||||
|
||||
if (!validTypes.includes(ext)) {
|
||||
alert('Please upload a PDF, PNG, or JPG file.');
|
||||
return;
|
||||
}
|
||||
|
||||
selectedFile = file;
|
||||
fileName.textContent = `📎 ${file.name}`;
|
||||
fileName.style.display = 'block';
|
||||
uploadArea.classList.add('has-file');
|
||||
submitBtn.disabled = false;
|
||||
}
|
||||
|
||||
submitBtn.addEventListener('click', async () => {
|
||||
if (!selectedFile) return;
|
||||
|
||||
// Show loading
|
||||
submitBtn.disabled = true;
|
||||
loading.classList.add('active');
|
||||
placeholder.style.display = 'none';
|
||||
results.classList.remove('active');
|
||||
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append('file', selectedFile);
|
||||
|
||||
const response = await fetch('/api/v1/infer', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.detail || 'Processing failed');
|
||||
}
|
||||
|
||||
displayResults(data);
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
document.getElementById('error-message').textContent = error.message;
|
||||
document.getElementById('error-message').style.display = 'block';
|
||||
results.classList.add('active');
|
||||
} finally {
|
||||
loading.classList.remove('active');
|
||||
submitBtn.disabled = false;
|
||||
}
|
||||
});
|
||||
|
||||
function displayResults(data) {
|
||||
const result = data.result;
|
||||
|
||||
// Document ID
|
||||
document.getElementById('doc-id').textContent = result.document_id;
|
||||
|
||||
// Status
|
||||
const statusEl = document.getElementById('result-status');
|
||||
statusEl.textContent = result.success ? 'Success' : 'Partial';
|
||||
statusEl.className = 'result-status ' + (result.success ? 'success' : 'partial');
|
||||
|
||||
// Fields
|
||||
const fieldsGrid = document.getElementById('fields-grid');
|
||||
fieldsGrid.innerHTML = '';
|
||||
|
||||
const fieldOrder = [
|
||||
'InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR',
|
||||
'Amount', 'Bankgiro', 'Plusgiro',
|
||||
'supplier_org_number', 'customer_number', 'payment_line'
|
||||
];
|
||||
|
||||
fieldOrder.forEach(field => {
|
||||
const value = result.fields[field];
|
||||
const confidence = result.confidence[field];
|
||||
|
||||
if (value !== null && value !== undefined) {
|
||||
const fieldDiv = document.createElement('div');
|
||||
fieldDiv.className = 'field-item';
|
||||
fieldDiv.innerHTML = `
|
||||
<label>${formatFieldName(field)}</label>
|
||||
<div class="value">${value}</div>
|
||||
${confidence ? `<div class="confidence">✓ ${(confidence * 100).toFixed(1)}% confident</div>` : ''}
|
||||
`;
|
||||
fieldsGrid.appendChild(fieldDiv);
|
||||
}
|
||||
});
|
||||
|
||||
// Processing time
|
||||
document.getElementById('processing-time').textContent =
|
||||
`⏱️ Processed in ${result.processing_time_ms.toFixed(0)}ms`;
|
||||
|
||||
// Cross-validation results
|
||||
const cvDiv = document.getElementById('cross-validation');
|
||||
if (result.cross_validation) {
|
||||
const cv = result.cross_validation;
|
||||
let cvHtml = '<h3>🔍 Cross-Validation (Payment Line)</h3>';
|
||||
cvHtml += `<div class="cv-status ${cv.is_valid ? 'valid' : 'invalid'}">`;
|
||||
cvHtml += cv.is_valid ? '✅ Valid' : '⚠️ Mismatch Detected';
|
||||
cvHtml += '</div>';
|
||||
|
||||
cvHtml += '<div class="cv-details">';
|
||||
if (cv.payment_line_ocr) {
|
||||
const matchIcon = cv.ocr_match === true ? '✓' : (cv.ocr_match === false ? '✗' : '—');
|
||||
cvHtml += `<div class="cv-item ${cv.ocr_match === true ? 'match' : (cv.ocr_match === false ? 'mismatch' : '')}">`;
|
||||
cvHtml += `<span class="cv-icon">${matchIcon}</span> OCR: ${cv.payment_line_ocr}</div>`;
|
||||
}
|
||||
if (cv.payment_line_amount) {
|
||||
const matchIcon = cv.amount_match === true ? '✓' : (cv.amount_match === false ? '✗' : '—');
|
||||
cvHtml += `<div class="cv-item ${cv.amount_match === true ? 'match' : (cv.amount_match === false ? 'mismatch' : '')}">`;
|
||||
cvHtml += `<span class="cv-icon">${matchIcon}</span> Amount: ${cv.payment_line_amount}</div>`;
|
||||
}
|
||||
if (cv.payment_line_account) {
|
||||
const accountType = cv.payment_line_account_type === 'bankgiro' ? 'Bankgiro' : 'Plusgiro';
|
||||
const matchField = cv.payment_line_account_type === 'bankgiro' ? cv.bankgiro_match : cv.plusgiro_match;
|
||||
const matchIcon = matchField === true ? '✓' : (matchField === false ? '✗' : '—');
|
||||
cvHtml += `<div class="cv-item ${matchField === true ? 'match' : (matchField === false ? 'mismatch' : '')}">`;
|
||||
cvHtml += `<span class="cv-icon">${matchIcon}</span> ${accountType}: ${cv.payment_line_account}</div>`;
|
||||
}
|
||||
cvHtml += '</div>';
|
||||
|
||||
if (cv.details && cv.details.length > 0) {
|
||||
cvHtml += '<div class="cv-summary">' + cv.details[cv.details.length - 1] + '</div>';
|
||||
}
|
||||
|
||||
cvDiv.innerHTML = cvHtml;
|
||||
cvDiv.style.display = 'block';
|
||||
} else {
|
||||
cvDiv.style.display = 'none';
|
||||
}
|
||||
|
||||
// Visualization
|
||||
if (result.visualization_url) {
|
||||
const vizDiv = document.getElementById('visualization');
|
||||
const vizImg = document.getElementById('viz-image');
|
||||
vizImg.src = result.visualization_url;
|
||||
vizDiv.style.display = 'block';
|
||||
}
|
||||
|
||||
// Errors
|
||||
if (result.errors && result.errors.length > 0) {
|
||||
document.getElementById('error-message').textContent = result.errors.join(', ');
|
||||
document.getElementById('error-message').style.display = 'block';
|
||||
} else {
|
||||
document.getElementById('error-message').style.display = 'none';
|
||||
}
|
||||
|
||||
results.classList.add('active');
|
||||
}
|
||||
|
||||
function formatFieldName(name) {
|
||||
const nameMap = {
|
||||
'InvoiceNumber': 'Invoice Number',
|
||||
'InvoiceDate': 'Invoice Date',
|
||||
'InvoiceDueDate': 'Due Date',
|
||||
'OCR': 'OCR Reference',
|
||||
'Amount': 'Amount',
|
||||
'Bankgiro': 'Bankgiro',
|
||||
'Plusgiro': 'Plusgiro',
|
||||
'supplier_org_number': 'Supplier Org Number',
|
||||
'customer_number': 'Customer Number',
|
||||
'payment_line': 'Payment Line'
|
||||
};
|
||||
return nameMap[name] || name.replace(/([A-Z])/g, ' $1').replace(/_/g, ' ').trim();
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
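The application factory above can be served directly with uvicorn, and the /api/v1/infer endpoint used by the bundled UI can also be called from a script. A minimal sketch; host, port, and the file name are placeholders:

import uvicorn

from backend.web.app import create_app

app = create_app()

if __name__ == "__main__":
    # Serve the configured app; values here are placeholders.
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Calling the same endpoint the web UI uses, from a script:
#
#   import requests
#   with open("invoice.pdf", "rb") as fh:
#       resp = requests.post(
#           "http://localhost:8000/api/v1/infer",
#           files={"file": fh},
#       )
#   fields = resp.json()["result"]["fields"]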
194
packages/backend/backend/web/config.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Web Application Configuration
|
||||
|
||||
Centralized configuration for the web application.
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
|
||||
def get_storage_backend(
|
||||
config_path: Path | str | None = None,
|
||||
) -> "StorageBackend":
|
||||
"""Get storage backend for file operations.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to storage configuration file.
|
||||
If not provided, the STORAGE_CONFIG_PATH environment variable is used; otherwise configuration falls back to individual environment variables.
|
||||
|
||||
Returns:
|
||||
Configured StorageBackend instance.
|
||||
"""
|
||||
from shared.storage import get_storage_backend as _get_storage_backend
|
||||
|
||||
# Check for config file path
|
||||
if config_path is None:
|
||||
config_path_str = os.environ.get("STORAGE_CONFIG_PATH")
|
||||
if config_path_str:
|
||||
config_path = Path(config_path_str)
|
||||
|
||||
return _get_storage_backend(config_path=config_path)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelConfig:
|
||||
"""YOLO model configuration.
|
||||
|
||||
Note: Model files are stored locally (not in STORAGE_BASE_PATH) because:
|
||||
- Models need to be accessible by inference service on any platform
|
||||
- Models may be version-controlled or deployed separately
|
||||
- Models are part of the application, not user data
|
||||
"""
|
||||
|
||||
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
|
||||
confidence_threshold: float = 0.5
|
||||
use_gpu: bool = True
|
||||
dpi: int = DEFAULT_DPI
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ServerConfig:
|
||||
"""Server configuration."""
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
debug: bool = False
|
||||
reload: bool = False
|
||||
workers: int = 1
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FileConfig:
|
||||
"""File handling configuration.
|
||||
|
||||
This config holds file handling settings. For file operations,
|
||||
use the storage backend with PREFIXES from shared.storage.prefixes.
|
||||
|
||||
Example:
|
||||
from shared.storage import PREFIXES, get_storage_backend
|
||||
|
||||
storage = get_storage_backend()
|
||||
path = PREFIXES.document_path(document_id)
|
||||
storage.upload_bytes(content, path)
|
||||
|
||||
Note: The path fields (upload_dir, result_dir, etc.) are deprecated.
|
||||
They are kept for backward compatibility with existing code and tests.
|
||||
New code should use the storage backend with PREFIXES instead.
|
||||
"""
|
||||
|
||||
max_file_size_mb: int = 50
|
||||
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
|
||||
dpi: int = DEFAULT_DPI
|
||||
presigned_url_expiry_seconds: int = 3600
|
||||
|
||||
# Deprecated path fields - kept for backward compatibility
|
||||
# New code should use storage backend with PREFIXES instead
|
||||
# All paths are now under data/ to match WSL storage layout
|
||||
upload_dir: Path = field(default_factory=lambda: Path("data/uploads"))
|
||||
result_dir: Path = field(default_factory=lambda: Path("data/results"))
|
||||
admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs"))
|
||||
admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images"))
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist (for backward compatibility)."""
|
||||
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
|
||||
object.__setattr__(self, "result_dir", Path(self.result_dir))
|
||||
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
|
||||
object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
|
||||
self.upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.result_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
StorageConfig = FileConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AsyncConfig:
|
||||
"""Async processing configuration.
|
||||
|
||||
Note: For file paths, use the storage backend with PREFIXES.
|
||||
Example: PREFIXES.upload_path(filename, "async")
|
||||
"""
|
||||
|
||||
# Queue settings
|
||||
queue_max_size: int = 100
|
||||
worker_count: int = 1
|
||||
task_timeout_seconds: int = 300
|
||||
|
||||
# Rate limiting defaults
|
||||
default_requests_per_minute: int = 10
|
||||
default_max_concurrent_jobs: int = 3
|
||||
default_min_poll_interval_ms: int = 1000
|
||||
|
||||
# Storage
|
||||
result_retention_days: int = 7
|
||||
max_file_size_mb: int = 50
|
||||
|
||||
# Deprecated: kept for backward compatibility
|
||||
# Path under data/ to match WSL storage layout
|
||||
temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async"))
|
||||
|
||||
# Cleanup
|
||||
cleanup_interval_hours: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist (for backward compatibility)."""
|
||||
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
|
||||
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
"""Main application configuration."""
|
||||
|
||||
model: ModelConfig = field(default_factory=ModelConfig)
|
||||
server: ServerConfig = field(default_factory=ServerConfig)
|
||||
file: FileConfig = field(default_factory=FileConfig)
|
||||
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
|
||||
storage_backend: "StorageBackend | None" = None
|
||||
|
||||
@property
|
||||
def storage(self) -> FileConfig:
|
||||
"""Backward compatibility alias for file config."""
|
||||
return self.file
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
|
||||
"""Create config from dictionary."""
|
||||
file_config = config_dict.get("file", config_dict.get("storage", {}))
|
||||
return cls(
|
||||
model=ModelConfig(**config_dict.get("model", {})),
|
||||
server=ServerConfig(**config_dict.get("server", {})),
|
||||
file=FileConfig(**file_config),
|
||||
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
|
||||
)
|
||||
|
||||
|
||||
def create_app_config(
|
||||
storage_config_path: Path | str | None = None,
|
||||
) -> AppConfig:
|
||||
"""Create application configuration with storage backend.
|
||||
|
||||
Args:
|
||||
storage_config_path: Optional path to storage configuration file.
|
||||
|
||||
Returns:
|
||||
Configured AppConfig instance with storage backend initialized.
|
||||
"""
|
||||
storage_backend = get_storage_backend(config_path=storage_config_path)
|
||||
return AppConfig(storage_backend=storage_backend)
|
||||
|
||||
|
||||
# Default configuration instance
|
||||
default_config = AppConfig()
|
||||
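A small usage sketch for the configuration module above: building an AppConfig from a plain dictionary and reading the backward-compatible storage alias. The override values are arbitrary examples.

from backend.web.config import AppConfig

# Arbitrary example overrides; unknown keys in a sub-dict would raise TypeError
# because each dataclass only accepts its declared fields.
config = AppConfig.from_dict(
    {
        "model": {"confidence_threshold": 0.6, "use_gpu": False},
        "server": {"port": 9000},
        "file": {"max_file_size_mb": 25},
        "async_processing": {"worker_count": 2},
    }
)

assert config.storage is config.file           # backward-compatibility alias
print(config.model.confidence_threshold)       # 0.6
print(config.storage.allowed_extensions)       # ('.pdf', '.png', '.jpg', '.jpeg')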
61
packages/backend/backend/web/core/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Core Components
|
||||
|
||||
Reusable core functionality: authentication, rate limiting, scheduling.
|
||||
"""
|
||||
|
||||
from backend.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_token_repository,
|
||||
get_document_repository,
|
||||
get_annotation_repository,
|
||||
get_dataset_repository,
|
||||
get_training_task_repository,
|
||||
get_model_version_repository,
|
||||
get_batch_upload_repository,
|
||||
AdminTokenDep,
|
||||
TokenRepoDep,
|
||||
DocumentRepoDep,
|
||||
AnnotationRepoDep,
|
||||
DatasetRepoDep,
|
||||
TrainingTaskRepoDep,
|
||||
ModelVersionRepoDep,
|
||||
BatchUploadRepoDep,
|
||||
)
|
||||
from backend.web.core.rate_limiter import RateLimiter
|
||||
from backend.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
|
||||
from backend.web.core.autolabel_scheduler import (
|
||||
start_autolabel_scheduler,
|
||||
stop_autolabel_scheduler,
|
||||
get_autolabel_scheduler,
|
||||
)
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus, TaskManager
|
||||
|
||||
__all__ = [
|
||||
"validate_admin_token",
|
||||
"get_token_repository",
|
||||
"get_document_repository",
|
||||
"get_annotation_repository",
|
||||
"get_dataset_repository",
|
||||
"get_training_task_repository",
|
||||
"get_model_version_repository",
|
||||
"get_batch_upload_repository",
|
||||
"AdminTokenDep",
|
||||
"TokenRepoDep",
|
||||
"DocumentRepoDep",
|
||||
"AnnotationRepoDep",
|
||||
"DatasetRepoDep",
|
||||
"TrainingTaskRepoDep",
|
||||
"ModelVersionRepoDep",
|
||||
"BatchUploadRepoDep",
|
||||
"RateLimiter",
|
||||
"start_scheduler",
|
||||
"stop_scheduler",
|
||||
"get_training_scheduler",
|
||||
"start_autolabel_scheduler",
|
||||
"stop_autolabel_scheduler",
|
||||
"get_autolabel_scheduler",
|
||||
"TaskRunner",
|
||||
"TaskStatus",
|
||||
"TaskManager",
|
||||
]
|
||||
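For consumers of the package, the aggregated exports above allow a single import site, e.g.:

from backend.web.core import (
    AdminTokenDep,
    DocumentRepoDep,
    RateLimiter,
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
)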
115
packages/backend/backend/web/core/auth.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Admin Authentication
|
||||
|
||||
FastAPI dependencies for admin token authentication and repository access.
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
|
||||
from backend.data.repositories import (
|
||||
TokenRepository,
|
||||
DocumentRepository,
|
||||
AnnotationRepository,
|
||||
DatasetRepository,
|
||||
TrainingTaskRepository,
|
||||
ModelVersionRepository,
|
||||
BatchUploadRepository,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_token_repository() -> TokenRepository:
|
||||
"""Get the TokenRepository instance (thread-safe singleton)."""
|
||||
return TokenRepository()
|
||||
|
||||
|
||||
def reset_token_repository() -> None:
|
||||
"""Reset the TokenRepository instance (for testing)."""
|
||||
get_token_repository.cache_clear()
|
||||
|
||||
|
||||
async def validate_admin_token(
|
||||
x_admin_token: Annotated[str | None, Header()] = None,
|
||||
token_repo: TokenRepository = Depends(get_token_repository),
|
||||
) -> str:
|
||||
"""Validate admin token from header."""
|
||||
if not x_admin_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Admin token required. Provide X-Admin-Token header.",
|
||||
)
|
||||
|
||||
if not token_repo.is_valid(x_admin_token):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or expired admin token.",
|
||||
)
|
||||
|
||||
# Update last used timestamp
|
||||
token_repo.update_usage(x_admin_token)
|
||||
|
||||
return x_admin_token
|
||||
|
||||
|
||||
# Type alias for dependency injection
|
||||
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
|
||||
TokenRepoDep = Annotated[TokenRepository, Depends(get_token_repository)]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_document_repository() -> DocumentRepository:
|
||||
"""Get the DocumentRepository instance (thread-safe singleton)."""
|
||||
return DocumentRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_annotation_repository() -> AnnotationRepository:
|
||||
"""Get the AnnotationRepository instance (thread-safe singleton)."""
|
||||
return AnnotationRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_dataset_repository() -> DatasetRepository:
|
||||
"""Get the DatasetRepository instance (thread-safe singleton)."""
|
||||
return DatasetRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_training_task_repository() -> TrainingTaskRepository:
|
||||
"""Get the TrainingTaskRepository instance (thread-safe singleton)."""
|
||||
return TrainingTaskRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_model_version_repository() -> ModelVersionRepository:
|
||||
"""Get the ModelVersionRepository instance (thread-safe singleton)."""
|
||||
return ModelVersionRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_batch_upload_repository() -> BatchUploadRepository:
|
||||
"""Get the BatchUploadRepository instance (thread-safe singleton)."""
|
||||
return BatchUploadRepository()
|
||||
|
||||
|
||||
def reset_all_repositories() -> None:
|
||||
"""Reset all repository instances (for testing)."""
|
||||
get_token_repository.cache_clear()
|
||||
get_document_repository.cache_clear()
|
||||
get_annotation_repository.cache_clear()
|
||||
get_dataset_repository.cache_clear()
|
||||
get_training_task_repository.cache_clear()
|
||||
get_model_version_repository.cache_clear()
|
||||
get_batch_upload_repository.cache_clear()
|
||||
|
||||
|
||||
# Repository dependency type aliases
|
||||
DocumentRepoDep = Annotated[DocumentRepository, Depends(get_document_repository)]
|
||||
AnnotationRepoDep = Annotated[AnnotationRepository, Depends(get_annotation_repository)]
|
||||
DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)]
|
||||
TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)]
|
||||
ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)]
|
||||
BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)]
|
||||
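A sketch of how the dependency aliases above are meant to be used in a route: FastAPI resolves AdminTokenDep by calling validate_admin_token, which checks the X-Admin-Token header against the token repository. The route path, response shape, and the count_all() call are illustrative only.

from fastapi import APIRouter

from backend.web.core.auth import AdminTokenDep, DocumentRepoDep

router = APIRouter(prefix="/admin", tags=["admin"])  # illustrative prefix


@router.get("/documents/count")
def count_documents(token: AdminTokenDep, doc_repo: DocumentRepoDep) -> dict:
    # `token` is the validated admin token; `doc_repo` is the cached
    # DocumentRepository singleton. count_all() is hypothetical -- the
    # repository's real query methods are not part of this diff.
    return {"count": doc_repo.count_all()}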
202
packages/backend/backend/web/core/autolabel_scheduler.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Auto-Label Scheduler
|
||||
|
||||
Background scheduler for processing documents pending auto-labeling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from backend.data.repositories import DocumentRepository, AnnotationRepository
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.services.db_autolabel import (
|
||||
get_pending_autolabel_documents,
|
||||
process_document_autolabel,
|
||||
)
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoLabelScheduler(TaskRunner):
|
||||
"""Scheduler for auto-labeling tasks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
check_interval_seconds: int = 10,
|
||||
batch_size: int = 5,
|
||||
output_dir: Path | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize auto-label scheduler.
|
||||
|
||||
Args:
|
||||
check_interval_seconds: Interval to check for pending tasks
|
||||
batch_size: Number of documents to process per batch
|
||||
output_dir: Output directory for temporary files
|
||||
"""
|
||||
self._check_interval = check_interval_seconds
|
||||
self._batch_size = batch_size
|
||||
|
||||
# Get output directory from StorageHelper
|
||||
if output_dir is None:
|
||||
storage = get_storage_helper()
|
||||
output_dir = storage.get_autolabel_output_path()
|
||||
self._output_dir = output_dir or Path("data/autolabel_output")
|
||||
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._lock = threading.Lock()
|
||||
self._doc_repo = DocumentRepository()
|
||||
self._ann_repo = AnnotationRepository()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Unique identifier for this runner."""
|
||||
return "autolabel_scheduler"
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if scheduler is running."""
|
||||
return self._running
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
"""Get current status of the scheduler."""
|
||||
try:
|
||||
pending_docs = get_pending_autolabel_documents(limit=1000)
|
||||
pending_count = len(pending_docs)
|
||||
except Exception:
|
||||
pending_count = 0
|
||||
|
||||
return TaskStatus(
|
||||
name=self.name,
|
||||
is_running=self._running,
|
||||
pending_count=pending_count,
|
||||
processing_count=1 if self._running else 0,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
logger.warning("AutoLabel scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("AutoLabel scheduler started")
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
"""Stop the scheduler.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for graceful shutdown.
|
||||
If None, uses default of 5 seconds.
|
||||
"""
|
||||
# Minimize lock scope to avoid potential deadlock
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
thread_to_join = self._thread
|
||||
|
||||
effective_timeout = timeout if timeout is not None else 5.0
|
||||
if thread_to_join:
|
||||
thread_to_join.join(timeout=effective_timeout)
|
||||
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
logger.info("AutoLabel scheduler stopped")
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
"""Main scheduler loop."""
|
||||
while self._running:
|
||||
try:
|
||||
self._process_pending_documents()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)
|
||||
|
||||
# Wait for next check interval
|
||||
self._stop_event.wait(timeout=self._check_interval)
|
||||
|
||||
def _process_pending_documents(self) -> None:
|
||||
"""Check and process pending auto-label documents."""
|
||||
try:
|
||||
documents = get_pending_autolabel_documents(limit=self._batch_size)
|
||||
|
||||
if not documents:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(documents)} pending autolabel documents")
|
||||
|
||||
for doc in documents:
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
|
||||
try:
|
||||
result = process_document_autolabel(
|
||||
document=doc,
|
||||
output_dir=self._output_dir,
|
||||
doc_repo=self._doc_repo,
|
||||
ann_repo=self._ann_repo,
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
logger.info(
|
||||
f"AutoLabel completed for document {doc.document_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"AutoLabel failed for document {doc.document_id}: "
|
||||
f"{result.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing document {doc.document_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching pending documents: {e}", exc_info=True)
|
||||
|
||||
|
||||
# Global scheduler instance
|
||||
_autolabel_scheduler: AutoLabelScheduler | None = None
|
||||
_autolabel_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_autolabel_scheduler() -> AutoLabelScheduler:
|
||||
"""Get the auto-label scheduler instance.
|
||||
|
||||
Uses double-checked locking pattern for thread safety.
|
||||
"""
|
||||
global _autolabel_scheduler
|
||||
|
||||
if _autolabel_scheduler is None:
|
||||
with _autolabel_lock:
|
||||
if _autolabel_scheduler is None:
|
||||
_autolabel_scheduler = AutoLabelScheduler()
|
||||
|
||||
return _autolabel_scheduler
|
||||
|
||||
|
||||
def start_autolabel_scheduler() -> None:
|
||||
"""Start the global auto-label scheduler."""
|
||||
scheduler = get_autolabel_scheduler()
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def stop_autolabel_scheduler() -> None:
|
||||
"""Stop the global auto-label scheduler."""
|
||||
global _autolabel_scheduler
|
||||
if _autolabel_scheduler:
|
||||
_autolabel_scheduler.stop()
|
||||
_autolabel_scheduler = None
|
||||
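Outside of the FastAPI lifespan hooks, the scheduler above can also be driven manually, which is handy in scripts and tests. A minimal sketch:

from backend.web.core.autolabel_scheduler import get_autolabel_scheduler

scheduler = get_autolabel_scheduler()
scheduler.start()                     # background thread, checks every 10 seconds by default

status = scheduler.get_status()
print(status.name, status.is_running, status.pending_count)

scheduler.stop(timeout=5.0)           # graceful shutdown, joins the worker thread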
211
packages/backend/backend/web/core/rate_limiter.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Rate Limiter Implementation
|
||||
|
||||
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.async_request_db import AsyncRequestDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
"""Rate limit configuration for an API key."""
|
||||
|
||||
requests_per_minute: int = 10
|
||||
max_concurrent_jobs: int = 3
|
||||
min_poll_interval_ms: int = 1000 # Minimum time between status polls
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitStatus:
|
||||
"""Current rate limit status."""
|
||||
|
||||
allowed: bool
|
||||
remaining_requests: int
|
||||
reset_at: datetime
|
||||
retry_after_seconds: int | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Thread-safe rate limiter with sliding window algorithm.
|
||||
|
||||
Tracks:
|
||||
- Requests per minute (sliding window)
|
||||
- Concurrent active jobs
|
||||
- Poll frequency per request_id
|
||||
"""
|
||||
|
||||
def __init__(self, db: "AsyncRequestDB") -> None:
|
||||
self._db = db
|
||||
self._lock = Lock()
|
||||
# In-memory tracking for fast checks
|
||||
self._request_windows: dict[str, list[float]] = defaultdict(list)
|
||||
# (api_key, request_id) -> last_poll timestamp
|
||||
self._poll_timestamps: dict[tuple[str, str], float] = {}
|
||||
# Cache for API key configs (TTL 60 seconds)
|
||||
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
|
||||
self._config_cache_ttl = 60.0
|
||||
|
||||
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
|
||||
"""Check if API key can submit a new request."""
|
||||
config = self._get_config(api_key)
|
||||
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - 60 # 1 minute window
|
||||
|
||||
# Clean old entries
|
||||
self._request_windows[api_key] = [
|
||||
ts for ts in self._request_windows[api_key]
|
||||
if ts > window_start
|
||||
]
|
||||
|
||||
current_count = len(self._request_windows[api_key])
|
||||
|
||||
if current_count >= config.requests_per_minute:
|
||||
oldest = min(self._request_windows[api_key])
|
||||
retry_after = int(oldest + 60 - now) + 1
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=0,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
|
||||
retry_after_seconds=max(1, retry_after),
|
||||
reason="Rate limit exceeded: too many requests per minute",
|
||||
)
|
||||
|
||||
# Check concurrent jobs (query database) - inside lock for thread safety
|
||||
active_jobs = self._db.count_active_jobs(api_key)
|
||||
if active_jobs >= config.max_concurrent_jobs:
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=config.requests_per_minute - current_count,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=30),
|
||||
retry_after_seconds=30,
|
||||
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
|
||||
)
|
||||
|
||||
return RateLimitStatus(
|
||||
allowed=True,
|
||||
remaining_requests=config.requests_per_minute - current_count - 1,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=60),
|
||||
)
|
||||
|
||||
def record_request(self, api_key: str) -> None:
|
||||
"""Record a successful request submission."""
|
||||
with self._lock:
|
||||
self._request_windows[api_key].append(time.time())
|
||||
|
||||
# Also record in database for persistence
|
||||
try:
|
||||
self._db.record_rate_limit_event(api_key, "request")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record rate limit event: {e}")
|
||||
|
||||
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
|
||||
"""Check if polling is allowed (prevent abuse)."""
|
||||
config = self._get_config(api_key)
|
||||
key = (api_key, request_id)
|
||||
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
last_poll = self._poll_timestamps.get(key, 0)
|
||||
elapsed_ms = (now - last_poll) * 1000
|
||||
|
||||
if elapsed_ms < config.min_poll_interval_ms:
|
||||
# Suggest a longer wait: double the minimum poll interval, capped at 5 seconds
|
||||
wait_ms = min(
|
||||
config.min_poll_interval_ms * 2,
|
||||
5000, # Max 5 seconds
|
||||
)
|
||||
retry_after = int(wait_ms / 1000) + 1
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=0,
|
||||
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
|
||||
retry_after_seconds=retry_after,
|
||||
reason="Polling too frequently. Please wait before retrying.",
|
||||
)
|
||||
|
||||
# Update poll timestamp
|
||||
self._poll_timestamps[key] = now
|
||||
|
||||
return RateLimitStatus(
|
||||
allowed=True,
|
||||
remaining_requests=999, # No limit on poll count, just frequency
|
||||
reset_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
def _get_config(self, api_key: str) -> RateLimitConfig:
|
||||
"""Get rate limit config for API key with caching."""
|
||||
now = time.time()
|
||||
|
||||
# Check cache
|
||||
if api_key in self._config_cache:
|
||||
cached_config, cached_at = self._config_cache[api_key]
|
||||
if now - cached_at < self._config_cache_ttl:
|
||||
return cached_config
|
||||
|
||||
# Query database
|
||||
db_config = self._db.get_api_key_config(api_key)
|
||||
if db_config:
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=db_config.requests_per_minute,
|
||||
max_concurrent_jobs=db_config.max_concurrent_jobs,
|
||||
)
|
||||
else:
|
||||
config = RateLimitConfig() # Default limits
|
||||
|
||||
# Cache result
|
||||
self._config_cache[api_key] = (config, now)
|
||||
return config
|
||||
|
||||
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
|
||||
"""Clean up old poll timestamps to prevent memory leak."""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
cutoff = now - max_age_seconds
|
||||
old_keys = [
|
||||
k for k, v in self._poll_timestamps.items()
|
||||
if v < cutoff
|
||||
]
|
||||
for key in old_keys:
|
||||
del self._poll_timestamps[key]
|
||||
return len(old_keys)
|
||||
|
||||
def cleanup_request_windows(self) -> None:
|
||||
"""Clean up expired entries from request windows."""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - 60
|
||||
|
||||
for api_key in list(self._request_windows.keys()):
|
||||
self._request_windows[api_key] = [
|
||||
ts for ts in self._request_windows[api_key]
|
||||
if ts > window_start
|
||||
]
|
||||
# Remove empty entries
|
||||
if not self._request_windows[api_key]:
|
||||
del self._request_windows[api_key]
|
||||
|
||||
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
|
||||
"""Generate rate limit headers for HTTP response."""
|
||||
headers = {
|
||||
"X-RateLimit-Remaining": str(status.remaining_requests),
|
||||
"X-RateLimit-Reset": status.reset_at.isoformat(),
|
||||
}
|
||||
if status.retry_after_seconds:
|
||||
headers["Retry-After"] = str(status.retry_after_seconds)
|
||||
return headers
|
||||
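The intended call sequence for the rate limiter above, as used by the async submission endpoints: check first, record on success, and surface the standard headers on rejection. The api_key value is a placeholder.

from backend.data.async_request_db import AsyncRequestDB
from backend.web.core.rate_limiter import RateLimiter

db = AsyncRequestDB()
limiter = RateLimiter(db)

api_key = "example-key"  # placeholder

status = limiter.check_submit_limit(api_key)
if status.allowed:
    limiter.record_request(api_key)   # adds a timestamp to the sliding window
    # ... enqueue the job ...
else:
    # 429-style rejection; headers include Retry-After when set
    headers = limiter.get_rate_limit_headers(status)
    print(status.reason, headers)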
571
packages/backend/backend/web/core/scheduler.py
Normal file
@@ -0,0 +1,571 @@
|
||||
"""
|
||||
Admin Training Scheduler
|
||||
|
||||
Background scheduler for training tasks using APScheduler.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.data.repositories import (
|
||||
TrainingTaskRepository,
|
||||
DatasetRepository,
|
||||
ModelVersionRepository,
|
||||
DocumentRepository,
|
||||
AnnotationRepository,
|
||||
)
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingScheduler(TaskRunner):
|
||||
"""Scheduler for training tasks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
check_interval_seconds: int = 60,
|
||||
):
|
||||
"""
|
||||
Initialize training scheduler.
|
||||
|
||||
Args:
|
||||
check_interval_seconds: Interval to check for pending tasks
|
||||
"""
|
||||
self._check_interval = check_interval_seconds
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._lock = threading.Lock()
|
||||
# Repositories
|
||||
self._training_tasks = TrainingTaskRepository()
|
||||
self._datasets = DatasetRepository()
|
||||
self._model_versions = ModelVersionRepository()
|
||||
self._documents = DocumentRepository()
|
||||
self._annotations = AnnotationRepository()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Unique identifier for this runner."""
|
||||
return "training_scheduler"
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the scheduler is currently active."""
|
||||
return self._running
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
"""Get current status of the scheduler."""
|
||||
try:
|
||||
pending_tasks = self._training_tasks.get_pending()
|
||||
pending_count = len(pending_tasks)
|
||||
except Exception:
|
||||
pending_count = 0
|
||||
|
||||
return TaskStatus(
|
||||
name=self.name,
|
||||
is_running=self._running,
|
||||
pending_count=pending_count,
|
||||
processing_count=1 if self._running else 0,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
logger.warning("Training scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Training scheduler started")
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
"""Stop the scheduler.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for graceful shutdown.
|
||||
If None, uses default of 5 seconds.
|
||||
"""
|
||||
# Minimize lock scope to avoid potential deadlock
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
thread_to_join = self._thread
|
||||
|
||||
effective_timeout = timeout if timeout is not None else 5.0
|
||||
if thread_to_join:
|
||||
thread_to_join.join(timeout=effective_timeout)
|
||||
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
logger.info("Training scheduler stopped")
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
"""Main scheduler loop."""
|
||||
while self._running:
|
||||
try:
|
||||
self._check_pending_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scheduler loop: {e}")
|
||||
|
||||
# Wait for next check interval
|
||||
self._stop_event.wait(timeout=self._check_interval)
|
||||
|
||||
def _check_pending_tasks(self) -> None:
|
||||
"""Check and execute pending training tasks."""
|
||||
try:
|
||||
tasks = self._training_tasks.get_pending()
|
||||
|
||||
for task in tasks:
|
||||
task_id = str(task.task_id)
|
||||
|
||||
# Check if scheduled time has passed
|
||||
if task.scheduled_at and task.scheduled_at > datetime.utcnow():
|
||||
continue
|
||||
|
||||
logger.info(f"Starting training task: {task_id}")
|
||||
|
||||
try:
|
||||
dataset_id = getattr(task, "dataset_id", None)
|
||||
self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._training_tasks.update_status(
|
||||
task_id=task_id,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking pending tasks: {e}")
|
||||
|
||||
def _execute_task(
|
||||
self, task_id: str, config: dict[str, Any], dataset_id: str | None = None
|
||||
) -> None:
|
||||
"""Execute a training task."""
|
||||
# Update status to running
|
||||
self._training_tasks.update_status(task_id, "running")
|
||||
self._training_tasks.add_log(task_id, "INFO", "Training task started")
|
||||
|
||||
# Update dataset training status to running
|
||||
if dataset_id:
|
||||
self._datasets.update_training_status(
|
||||
dataset_id,
|
||||
training_status="running",
|
||||
active_training_task_id=task_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Get training configuration
|
||||
model_name = config.get("model_name", "yolo11n.pt")
|
||||
base_model_path = config.get("base_model_path") # For incremental training
|
||||
epochs = config.get("epochs", 100)
|
||||
batch_size = config.get("batch_size", 16)
|
||||
image_size = config.get("image_size", 640)
|
||||
learning_rate = config.get("learning_rate", 0.01)
|
||||
device = config.get("device", "0")
|
||||
project_name = config.get("project_name", "invoice_fields")
|
||||
|
||||
# Get augmentation config if present
|
||||
augmentation_config = config.get("augmentation")
|
||||
augmentation_multiplier = config.get("augmentation_multiplier", 2)
|
||||
|
||||
# Determine which model to use as base
|
||||
if base_model_path:
|
||||
# Incremental training: use existing trained model
|
||||
if not Path(base_model_path).exists():
|
||||
raise ValueError(f"Base model not found: {base_model_path}")
|
||||
effective_model = base_model_path
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Incremental training from: {base_model_path}",
|
||||
)
|
||||
else:
|
||||
# Train from pretrained model
|
||||
effective_model = model_name
|
||||
|
||||
# Use dataset if available, otherwise export from scratch
|
||||
if dataset_id:
|
||||
dataset = self._datasets.get(dataset_id)
|
||||
if not dataset or not dataset.dataset_path:
|
||||
raise ValueError(f"Dataset {dataset_id} not found or has no path")
|
||||
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
|
||||
dataset_path = Path(dataset.dataset_path)
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
|
||||
)
|
||||
else:
|
||||
export_result = self._export_training_data(task_id)
|
||||
if not export_result:
|
||||
raise ValueError("Failed to export training data")
|
||||
data_yaml = export_result["data_yaml"]
|
||||
dataset_path = Path(data_yaml).parent
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Exported {export_result['total_images']} images for training",
|
||||
)
|
||||
|
||||
# Apply augmentation if config is provided
|
||||
if augmentation_config and self._has_enabled_augmentations(augmentation_config):
|
||||
aug_result = self._apply_augmentation(
|
||||
task_id, dataset_path, augmentation_config, augmentation_multiplier
|
||||
)
|
||||
if aug_result:
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Augmentation complete: {aug_result['augmented_images']} new images "
|
||||
f"(total: {aug_result['total_images']})",
|
||||
)
|
||||
|
||||
# Run YOLO training
|
||||
result = self._run_yolo_training(
|
||||
task_id=task_id,
|
||||
model_name=effective_model, # Use base model or pretrained model
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
image_size=image_size,
|
||||
learning_rate=learning_rate,
|
||||
device=device,
|
||||
project_name=project_name,
|
||||
)
|
||||
|
||||
# Update task with results
|
||||
self._training_tasks.update_status(
|
||||
task_id=task_id,
|
||||
status="completed",
|
||||
result_metrics=result.get("metrics"),
|
||||
model_path=result.get("model_path"),
|
||||
)
|
||||
self._training_tasks.add_log(task_id, "INFO", "Training completed successfully")
|
||||
|
||||
# Update dataset training status to completed and main status to trained
|
||||
if dataset_id:
|
||||
self._datasets.update_training_status(
|
||||
dataset_id,
|
||||
training_status="completed",
|
||||
active_training_task_id=None,
|
||||
update_main_status=True, # Set main status to 'trained'
|
||||
)
|
||||
|
||||
# Auto-create model version for the completed training
|
||||
self._create_model_version_from_training(
|
||||
task_id=task_id,
|
||||
config=config,
|
||||
dataset_id=dataset_id,
|
||||
result=result,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}")
|
||||
# Update dataset training status to failed
|
||||
if dataset_id:
|
||||
self._datasets.update_training_status(
|
||||
dataset_id,
|
||||
training_status="failed",
|
||||
active_training_task_id=None,
|
||||
)
|
||||
raise
|
||||
|
||||
def _create_model_version_from_training(
|
||||
self,
|
||||
task_id: str,
|
||||
config: dict[str, Any],
|
||||
dataset_id: str | None,
|
||||
result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Create a model version entry from completed training."""
|
||||
try:
|
||||
model_path = result.get("model_path")
|
||||
if not model_path:
|
||||
logger.warning(f"No model path in training result for task {task_id}")
|
||||
return
|
||||
|
||||
# Get task info for name
|
||||
task = self._training_tasks.get(task_id)
|
||||
task_name = task.name if task else f"Task {task_id[:8]}"
|
||||
|
||||
# Generate version number based on existing versions
|
||||
existing_versions = self._model_versions.get_paginated(limit=1, offset=0)
|
||||
version_count = existing_versions[1] if existing_versions else 0
|
||||
version = f"v{version_count + 1}.0"
|
||||
|
||||
# Extract metrics from result
|
||||
metrics = result.get("metrics", {})
|
||||
metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
|
||||
metrics_precision = metrics.get("precision")
|
||||
metrics_recall = metrics.get("recall")
|
||||
|
||||
# Get file size if possible
|
||||
file_size = None
|
||||
model_file = Path(model_path)
|
||||
if model_file.exists():
|
||||
file_size = model_file.stat().st_size
|
||||
|
||||
# Get document count from dataset if available
|
||||
document_count = 0
|
||||
if dataset_id:
|
||||
dataset = self._datasets.get(dataset_id)
|
||||
if dataset:
|
||||
document_count = dataset.total_documents
|
||||
|
||||
# Create model version
|
||||
model_version = self._model_versions.create(
|
||||
version=version,
|
||||
name=task_name,
|
||||
model_path=str(model_path),
|
||||
description=f"Auto-created from training task {task_id[:8]}",
|
||||
task_id=task_id,
|
||||
dataset_id=dataset_id,
|
||||
metrics_mAP=metrics_mAP,
|
||||
metrics_precision=metrics_precision,
|
||||
metrics_recall=metrics_recall,
|
||||
document_count=document_count,
|
||||
training_config=config,
|
||||
file_size=file_size,
|
||||
trained_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created model version {version} (ID: {model_version.version_id}) "
|
||||
f"from training task {task_id}"
|
||||
)
|
||||
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Model version {version} created (mAP: {mAP_display})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create model version for task {task_id}: {e}")
|
||||
self._training_tasks.add_log(
|
||||
task_id, "WARNING",
|
||||
f"Failed to auto-create model version: {e}",
|
||||
)
|
||||
|
||||
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
|
||||
"""Export training data for a task."""
|
||||
from pathlib import Path
|
||||
from shared.fields import FIELD_CLASSES
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
# Get storage helper for reading images
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Get all labeled documents
|
||||
documents = self._documents.get_labeled_for_export()
|
||||
|
||||
if not documents:
|
||||
self._training_tasks.add_log(task_id, "ERROR", "No labeled documents available")
|
||||
return None
|
||||
|
||||
# Create export directory using StorageHelper
|
||||
training_base = storage.get_training_data_path()
|
||||
if training_base is None:
|
||||
self._training_tasks.add_log(task_id, "ERROR", "Storage not configured for local access")
|
||||
return None
|
||||
export_dir = training_base / task_id
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# YOLO format directories
|
||||
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 80/20 train/val split
|
||||
total_docs = len(documents)
|
||||
train_count = int(total_docs * 0.8)
|
||||
train_docs = documents[:train_count]
|
||||
val_docs = documents[train_count:]
|
||||
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
|
||||
# Export documents
|
||||
for split, docs in [("train", train_docs), ("val", val_docs)]:
|
||||
for doc in docs:
|
||||
annotations = self._annotations.get_for_document(str(doc.document_id))
|
||||
|
||||
if not annotations:
|
||||
continue
|
||||
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
page_annotations = [a for a in annotations if a.page_number == page_num]
|
||||
|
||||
# Get image from storage
|
||||
doc_id = str(doc.document_id)
|
||||
if not storage.admin_image_exists(doc_id, page_num):
|
||||
continue
|
||||
|
||||
# Download image and save to export directory
|
||||
image_name = f"{doc.document_id}_page{page_num}.png"
|
||||
dst_image = export_dir / "images" / split / image_name
|
||||
image_content = storage.get_admin_image(doc_id, page_num)
|
||||
dst_image.write_bytes(image_content)
|
||||
total_images += 1
|
||||
|
||||
# Write YOLO label
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
# Create data.yaml
|
||||
yaml_path = export_dir / "data.yaml"
|
||||
yaml_content = f"""path: {export_dir.absolute()}
|
||||
train: images/train
|
||||
val: images/val
|
||||
|
||||
nc: {len(FIELD_CLASSES)}
|
||||
names: {list(FIELD_CLASSES.values())}
|
||||
"""
|
||||
yaml_path.write_text(yaml_content)
|
||||
|
||||
return {
|
||||
"data_yaml": str(yaml_path),
|
||||
"total_images": total_images,
|
||||
"total_annotations": total_annotations,
|
||||
}
|
||||
|
||||
def _run_yolo_training(
|
||||
self,
|
||||
task_id: str,
|
||||
model_name: str,
|
||||
data_yaml: str,
|
||||
epochs: int,
|
||||
batch_size: int,
|
||||
image_size: int,
|
||||
learning_rate: float,
|
||||
device: str,
|
||||
project_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Run YOLO training using shared trainer."""
|
||||
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
|
||||
|
||||
# Create log callback that writes to DB
|
||||
def log_callback(level: str, message: str) -> None:
|
||||
self._training_tasks.add_log(task_id, level, message)
|
||||
|
||||
# Create shared training config
|
||||
# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
|
||||
# because models need to be accessible by inference service on any platform
|
||||
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
|
||||
config = SharedTrainingConfig(
|
||||
model_path=model_name,
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
image_size=image_size,
|
||||
learning_rate=learning_rate,
|
||||
device=device,
|
||||
project="runs/train",
|
||||
name=f"{project_name}/task_{task_id[:8]}",
|
||||
workers=0,
|
||||
)
|
||||
|
||||
# Run training using shared trainer
|
||||
trainer = YOLOTrainer(config=config, log_callback=log_callback)
|
||||
result = trainer.train()
|
||||
|
||||
if not result.success:
|
||||
raise ValueError(result.error or "Training failed")
|
||||
|
||||
return {
|
||||
"model_path": result.model_path,
|
||||
"metrics": result.metrics,
|
||||
}
|
||||
|
||||
def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
|
||||
"""Check if any augmentations are enabled in the config."""
|
||||
augmentation_fields = [
|
||||
"perspective_warp", "wrinkle", "edge_damage", "stain",
|
||||
"lighting_variation", "shadow", "gaussian_blur", "motion_blur",
|
||||
"gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
|
||||
]
|
||||
for field in augmentation_fields:
|
||||
if field in aug_config:
|
||||
field_config = aug_config[field]
|
||||
if isinstance(field_config, dict) and field_config.get("enabled", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _apply_augmentation(
|
||||
self,
|
||||
task_id: str,
|
||||
dataset_path: Path,
|
||||
aug_config: dict[str, Any],
|
||||
multiplier: int,
|
||||
) -> dict[str, int] | None:
|
||||
"""Apply augmentation to dataset before training."""
|
||||
try:
|
||||
from shared.augmentation import DatasetAugmenter
|
||||
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Applying augmentation with multiplier={multiplier}",
|
||||
)
|
||||
|
||||
augmenter = DatasetAugmenter(aug_config)
|
||||
result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Augmentation failed for task {task_id}: {e}")
|
||||
self._training_tasks.add_log(
|
||||
task_id, "WARNING",
|
||||
f"Augmentation failed: {e}. Continuing with original dataset.",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Global scheduler instance
|
||||
_scheduler: TrainingScheduler | None = None
|
||||
_scheduler_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_training_scheduler() -> TrainingScheduler:
|
||||
"""Get the training scheduler instance.
|
||||
|
||||
Uses double-checked locking pattern for thread safety.
|
||||
"""
|
||||
global _scheduler
|
||||
|
||||
if _scheduler is None:
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None:
|
||||
_scheduler = TrainingScheduler()
|
||||
|
||||
return _scheduler
|
||||
|
||||
|
||||
def start_scheduler() -> None:
|
||||
"""Start the global training scheduler."""
|
||||
scheduler = get_training_scheduler()
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def stop_scheduler() -> None:
|
||||
"""Stop the global training scheduler."""
|
||||
global _scheduler
|
||||
if _scheduler:
|
||||
_scheduler.stop()
|
||||
_scheduler = None
|
||||
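The scheduler module above exposes a process-wide singleton plus start/stop helpers. A minimal sketch of wiring those helpers into an application lifecycle follows; the FastAPI lifespan hook is an assumption for illustration and is not part of this commit.

# Illustrative sketch (not part of this commit): starting and stopping the
# scheduler with the application. The FastAPI lifespan usage is assumed.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from backend.web.core.scheduler import start_scheduler, stop_scheduler


@asynccontextmanager
async def lifespan(app: FastAPI):
    start_scheduler()   # begins polling for pending training tasks (default every 60s)
    try:
        yield
    finally:
        stop_scheduler()  # sets the stop event and joins the worker thread


app = FastAPI(lifespan=lifespan)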
161
packages/backend/backend/web/core/task_interface.py
Normal file
@@ -0,0 +1,161 @@
"""Unified task management interface.

Provides abstract base class for all task runners (schedulers and queues)
and a TaskManager facade for unified lifecycle management.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class TaskStatus:
    """Status of a task runner.

    Attributes:
        name: Unique identifier for the runner.
        is_running: Whether the runner is currently active.
        pending_count: Number of tasks waiting to be processed.
        processing_count: Number of tasks currently being processed.
        error: Optional error message if runner is in error state.
    """

    name: str
    is_running: bool
    pending_count: int
    processing_count: int
    error: str | None = None


class TaskRunner(ABC):
    """Abstract base class for all task runners.

    All schedulers and task queues should implement this interface
    to enable unified lifecycle management and monitoring.

    Note:
        Implementations may have different `start()` signatures based on
        their initialization needs (e.g., handler functions, services).
        Use the implementation-specific start methods for initialization,
        and use TaskManager for unified status monitoring.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for this runner."""
        pass

    @abstractmethod
    def start(self, *args, **kwargs) -> None:
        """Start the task runner.

        Should be idempotent - calling start on an already running
        runner should have no effect.

        Note:
            Implementations may require additional parameters (handlers,
            services). See implementation-specific documentation.
        """
        pass

    @abstractmethod
    def stop(self, timeout: float | None = None) -> None:
        """Stop the task runner gracefully.

        Args:
            timeout: Maximum time to wait for graceful shutdown in seconds.
                If None, use implementation default.
        """
        pass

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Check if the runner is currently active."""
        pass

    @abstractmethod
    def get_status(self) -> TaskStatus:
        """Get current status of the runner.

        Returns:
            TaskStatus with current state information.
        """
        pass


class TaskManager:
    """Unified manager for all task runners.

    Provides centralized lifecycle management and monitoring
    for all registered task runners.
    """

    def __init__(self) -> None:
        """Initialize the task manager."""
        self._runners: dict[str, TaskRunner] = {}

    def register(self, runner: TaskRunner) -> None:
        """Register a task runner.

        Args:
            runner: TaskRunner instance to register.
        """
        self._runners[runner.name] = runner

    def get_runner(self, name: str) -> TaskRunner | None:
        """Get a specific runner by name.

        Args:
            name: Name of the runner to retrieve.

        Returns:
            TaskRunner if found, None otherwise.
        """
        return self._runners.get(name)

    @property
    def runner_names(self) -> list[str]:
        """Get names of all registered runners.

        Returns:
            List of runner names.
        """
        return list(self._runners.keys())

    def start_all(self) -> None:
        """Start all registered runners that support no-argument start.

        Note:
            Runners requiring initialization parameters (like AsyncTaskQueue
            or BatchTaskQueue) should be started individually before
            registering with TaskManager.
        """
        for runner in self._runners.values():
            try:
                runner.start()
            except TypeError:
                # Runner requires arguments - skip (should be started individually)
                pass

    def stop_all(self, timeout: float = 30.0) -> None:
        """Stop all registered runners gracefully.

        Args:
            timeout: Total timeout to distribute across all runners.
        """
        if not self._runners:
            return

        per_runner_timeout = timeout / len(self._runners)
        for runner in self._runners.values():
            runner.stop(timeout=per_runner_timeout)

    def get_all_status(self) -> dict[str, TaskStatus]:
        """Get status of all registered runners.

        Returns:
            Dict mapping runner names to their status.
        """
        return {name: runner.get_status() for name, runner in self._runners.items()}
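To make the TaskRunner contract concrete, here is a minimal hypothetical runner registered with a TaskManager; the NoopRunner class is invented purely for illustration and does not exist in this commit.

# Minimal sketch of the TaskRunner contract (hypothetical runner).
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus


class NoopRunner(TaskRunner):
    def __init__(self) -> None:
        self._running = False

    @property
    def name(self) -> str:
        return "noop_runner"

    @property
    def is_running(self) -> bool:
        return self._running

    def start(self) -> None:
        self._running = True

    def stop(self, timeout: float | None = None) -> None:
        self._running = False

    def get_status(self) -> TaskStatus:
        return TaskStatus(self.name, self._running, pending_count=0, processing_count=0)


manager = TaskManager()
manager.register(NoopRunner())
manager.start_all()            # NoopRunner.start() takes no arguments, so it is started here
print(manager.get_all_status())
manager.stop_all(timeout=5.0)  # timeout is split evenly across registered runners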
133
packages/backend/backend/web/dependencies.py
Normal file
@@ -0,0 +1,133 @@
"""
FastAPI Dependencies

Dependency injection for the async API endpoints.
"""

import logging
from typing import Annotated

from fastapi import Depends, Header, HTTPException, Request

from backend.data.async_request_db import AsyncRequestDB
from backend.web.rate_limiter import RateLimiter

logger = logging.getLogger(__name__)

# Global instances (initialized in app startup)
_async_db: AsyncRequestDB | None = None
_rate_limiter: RateLimiter | None = None


def init_dependencies(db: AsyncRequestDB, rate_limiter: RateLimiter) -> None:
    """Initialize global dependency instances."""
    global _async_db, _rate_limiter
    _async_db = db
    _rate_limiter = rate_limiter


def get_async_db() -> AsyncRequestDB:
    """Get async request database instance."""
    if _async_db is None:
        raise RuntimeError("AsyncRequestDB not initialized")
    return _async_db


def get_rate_limiter() -> RateLimiter:
    """Get rate limiter instance."""
    if _rate_limiter is None:
        raise RuntimeError("RateLimiter not initialized")
    return _rate_limiter


async def verify_api_key(
    x_api_key: Annotated[str | None, Header()] = None,
) -> str:
    """
    Verify API key exists and is active.

    Raises:
        HTTPException: 401 if API key is missing or invalid
    """
    if not x_api_key:
        raise HTTPException(
            status_code=401,
            detail="X-API-Key header is required",
            headers={"WWW-Authenticate": "API-Key"},
        )

    db = get_async_db()
    if not db.is_valid_api_key(x_api_key):
        raise HTTPException(
            status_code=401,
            detail="Invalid or inactive API key",
            headers={"WWW-Authenticate": "API-Key"},
        )

    # Update usage tracking
    try:
        db.update_api_key_usage(x_api_key)
    except Exception as e:
        logger.warning(f"Failed to update API key usage: {e}")

    return x_api_key


async def check_submit_rate_limit(
    api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
    """
    Check rate limit before processing submit request.

    Raises:
        HTTPException: 429 if rate limit exceeded
    """
    rate_limiter = get_rate_limiter()
    status = rate_limiter.check_submit_limit(api_key)

    if not status.allowed:
        headers = rate_limiter.get_rate_limit_headers(status)
        raise HTTPException(
            status_code=429,
            detail=status.reason or "Rate limit exceeded",
            headers=headers,
        )

    return api_key


async def check_poll_rate_limit(
    request: Request,
    api_key: Annotated[str, Depends(verify_api_key)],
) -> str:
    """
    Check poll rate limit to prevent abuse.

    Raises:
        HTTPException: 429 if polling too frequently
    """
    # Extract request_id from path parameters
    request_id = request.path_params.get("request_id")
    if not request_id:
        return api_key  # No request_id, skip poll limit check

    rate_limiter = get_rate_limiter()
    status = rate_limiter.check_poll_limit(api_key, request_id)

    if not status.allowed:
        headers = rate_limiter.get_rate_limit_headers(status)
        raise HTTPException(
            status_code=429,
            detail=status.reason or "Polling too frequently",
            headers=headers,
        )

    return api_key


# Type aliases for cleaner route signatures
ApiKeyDep = Annotated[str, Depends(verify_api_key)]
SubmitRateLimitDep = Annotated[str, Depends(check_submit_rate_limit)]
PollRateLimitDep = Annotated[str, Depends(check_poll_rate_limit)]
AsyncDBDep = Annotated[AsyncRequestDB, Depends(get_async_db)]
RateLimiterDep = Annotated[RateLimiter, Depends(get_rate_limiter)]
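A sketch of how a route could consume these dependency aliases; the endpoint path and handler body are assumptions for illustration and are not defined in this commit.

# Hypothetical route using the aliases above.
from fastapi import APIRouter

from backend.web.dependencies import AsyncDBDep, SubmitRateLimitDep

router = APIRouter()


@router.post("/api/async/submit")
async def submit_document(api_key: SubmitRateLimitDep, db: AsyncDBDep) -> dict:
    # verify_api_key and check_submit_rate_limit have already run here;
    # a 401/429 would have been raised before this body executes.
    return {"status": "accepted", "api_key_suffix": api_key[-4:]}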
211
packages/backend/backend/web/rate_limiter.py
Normal file
@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation

Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""

import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from backend.data.async_request_db import AsyncRequestDB

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class RateLimitConfig:
    """Rate limit configuration for an API key."""

    requests_per_minute: int = 10
    max_concurrent_jobs: int = 3
    min_poll_interval_ms: int = 1000  # Minimum time between status polls


@dataclass
class RateLimitStatus:
    """Current rate limit status."""

    allowed: bool
    remaining_requests: int
    reset_at: datetime
    retry_after_seconds: int | None = None
    reason: str | None = None


class RateLimiter:
    """
    Thread-safe rate limiter with sliding window algorithm.

    Tracks:
    - Requests per minute (sliding window)
    - Concurrent active jobs
    - Poll frequency per request_id
    """

    def __init__(self, db: "AsyncRequestDB") -> None:
        self._db = db
        self._lock = Lock()
        # In-memory tracking for fast checks
        self._request_windows: dict[str, list[float]] = defaultdict(list)
        # (api_key, request_id) -> last_poll timestamp
        self._poll_timestamps: dict[tuple[str, str], float] = {}
        # Cache for API key configs (TTL 60 seconds)
        self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
        self._config_cache_ttl = 60.0

    def check_submit_limit(self, api_key: str) -> RateLimitStatus:
        """Check if API key can submit a new request."""
        config = self._get_config(api_key)

        with self._lock:
            now = time.time()
            window_start = now - 60  # 1 minute window

            # Clean old entries
            self._request_windows[api_key] = [
                ts for ts in self._request_windows[api_key]
                if ts > window_start
            ]

            current_count = len(self._request_windows[api_key])

            if current_count >= config.requests_per_minute:
                oldest = min(self._request_windows[api_key])
                retry_after = int(oldest + 60 - now) + 1
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=0,
                    reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
                    retry_after_seconds=max(1, retry_after),
                    reason="Rate limit exceeded: too many requests per minute",
                )

            # Check concurrent jobs (query database) - inside lock for thread safety
            active_jobs = self._db.count_active_jobs(api_key)
            if active_jobs >= config.max_concurrent_jobs:
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=config.requests_per_minute - current_count,
                    reset_at=datetime.utcnow() + timedelta(seconds=30),
                    retry_after_seconds=30,
                    reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
                )

            return RateLimitStatus(
                allowed=True,
                remaining_requests=config.requests_per_minute - current_count - 1,
                reset_at=datetime.utcnow() + timedelta(seconds=60),
            )

    def record_request(self, api_key: str) -> None:
        """Record a successful request submission."""
        with self._lock:
            self._request_windows[api_key].append(time.time())

        # Also record in database for persistence
        try:
            self._db.record_rate_limit_event(api_key, "request")
        except Exception as e:
            logger.warning(f"Failed to record rate limit event: {e}")

    def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
        """Check if polling is allowed (prevent abuse)."""
        config = self._get_config(api_key)
        key = (api_key, request_id)

        with self._lock:
            now = time.time()
            last_poll = self._poll_timestamps.get(key, 0)
            elapsed_ms = (now - last_poll) * 1000

            if elapsed_ms < config.min_poll_interval_ms:
                # Suggest exponential backoff
                wait_ms = min(
                    config.min_poll_interval_ms * 2,
                    5000,  # Max 5 seconds
                )
                retry_after = int(wait_ms / 1000) + 1
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=0,
                    reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
                    retry_after_seconds=retry_after,
                    reason="Polling too frequently. Please wait before retrying.",
                )

            # Update poll timestamp
            self._poll_timestamps[key] = now

            return RateLimitStatus(
                allowed=True,
                remaining_requests=999,  # No limit on poll count, just frequency
                reset_at=datetime.utcnow(),
            )

    def _get_config(self, api_key: str) -> RateLimitConfig:
        """Get rate limit config for API key with caching."""
        now = time.time()

        # Check cache
        if api_key in self._config_cache:
            cached_config, cached_at = self._config_cache[api_key]
            if now - cached_at < self._config_cache_ttl:
                return cached_config

        # Query database
        db_config = self._db.get_api_key_config(api_key)
        if db_config:
            config = RateLimitConfig(
                requests_per_minute=db_config.requests_per_minute,
                max_concurrent_jobs=db_config.max_concurrent_jobs,
            )
        else:
            config = RateLimitConfig()  # Default limits

        # Cache result
        self._config_cache[api_key] = (config, now)
        return config

    def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
        """Clean up old poll timestamps to prevent memory leak."""
        with self._lock:
            now = time.time()
            cutoff = now - max_age_seconds
            old_keys = [
                k for k, v in self._poll_timestamps.items()
                if v < cutoff
            ]
            for key in old_keys:
                del self._poll_timestamps[key]
            return len(old_keys)

    def cleanup_request_windows(self) -> None:
        """Clean up expired entries from request windows."""
        with self._lock:
            now = time.time()
            window_start = now - 60

            for api_key in list(self._request_windows.keys()):
                self._request_windows[api_key] = [
                    ts for ts in self._request_windows[api_key]
                    if ts > window_start
                ]
                # Remove empty entries
                if not self._request_windows[api_key]:
                    del self._request_windows[api_key]

    def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
        """Generate rate limit headers for HTTP response."""
        headers = {
            "X-RateLimit-Remaining": str(status.remaining_requests),
            "X-RateLimit-Reset": status.reset_at.isoformat(),
        }
        if status.retry_after_seconds:
            headers["Retry-After"] = str(status.retry_after_seconds)
        return headers
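A rough usage sketch of the sliding-window submit check with a stubbed database object; the stub below is an assumption that only implements the methods this module calls and stands in for the real AsyncRequestDB.

# Sketch: exercising the per-minute window with default limits (10 req/min).
from backend.web.rate_limiter import RateLimiter


class _StubDB:
    def count_active_jobs(self, api_key: str) -> int:
        return 0

    def get_api_key_config(self, api_key: str):
        return None  # falls back to RateLimitConfig() defaults

    def record_rate_limit_event(self, api_key: str, event: str) -> None:
        pass


limiter = RateLimiter(db=_StubDB())
for i in range(12):
    status = limiter.check_submit_limit("demo-key")
    if status.allowed:
        limiter.record_request("demo-key")
    else:
        print(f"blocked after {i} requests, retry in {status.retry_after_seconds}s")
        break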
11
packages/backend/backend/web/schemas/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""
API Schemas

Pydantic models for request/response validation.
"""

# Import everything from sub-modules for backward compatibility
from backend.web.schemas.common import *  # noqa: F401, F403
from backend.web.schemas.admin import *  # noqa: F401, F403
from backend.web.schemas.inference import *  # noqa: F401, F403
from backend.web.schemas.labeling import *  # noqa: F401, F403
19
packages/backend/backend/web/schemas/admin/__init__.py
Normal file
@@ -0,0 +1,19 @@
"""
Admin API Request/Response Schemas

Pydantic models for admin API validation and serialization.
"""

from .enums import *  # noqa: F401, F403
from .auth import *  # noqa: F401, F403
from .documents import *  # noqa: F401, F403
from .annotations import *  # noqa: F401, F403
from .training import *  # noqa: F401, F403
from .datasets import *  # noqa: F401, F403
from .models import *  # noqa: F401, F403
from .dashboard import *  # noqa: F401, F403

# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse

DocumentDetailResponse.model_rebuild()
152
packages/backend/backend/web/schemas/admin/annotations.py
Normal file
@@ -0,0 +1,152 @@
"""Admin Annotation Schemas."""

from datetime import datetime

from pydantic import BaseModel, Field

from .enums import AnnotationSource


class BoundingBox(BaseModel):
    """Bounding box coordinates."""

    x: int = Field(..., ge=0, description="X coordinate (pixels)")
    y: int = Field(..., ge=0, description="Y coordinate (pixels)")
    width: int = Field(..., ge=1, description="Width (pixels)")
    height: int = Field(..., ge=1, description="Height (pixels)")


class AnnotationCreate(BaseModel):
    """Request to create an annotation."""

    page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
    class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)")
    bbox: BoundingBox = Field(..., description="Bounding box in pixels")
    text_value: str | None = Field(None, description="Text value (optional)")


class AnnotationUpdate(BaseModel):
    """Request to update an annotation."""

    class_id: int | None = Field(None, ge=0, le=9, description="New class ID")
    bbox: BoundingBox | None = Field(None, description="New bounding box")
    text_value: str | None = Field(None, description="New text value")


class AnnotationItem(BaseModel):
    """Single annotation item."""

    annotation_id: str = Field(..., description="Annotation UUID")
    page_number: int = Field(..., ge=1, description="Page number")
    class_id: int = Field(..., ge=0, le=9, description="Class ID")
    class_name: str = Field(..., description="Class name")
    bbox: BoundingBox = Field(..., description="Bounding box in pixels")
    normalized_bbox: dict[str, float] = Field(
        ..., description="Normalized bbox (x_center, y_center, width, height)"
    )
    text_value: str | None = Field(None, description="Text value")
    confidence: float | None = Field(None, ge=0, le=1, description="Confidence score")
    source: AnnotationSource = Field(..., description="Annotation source")
    created_at: datetime = Field(..., description="Creation timestamp")


class AnnotationResponse(BaseModel):
    """Response for annotation operation."""

    annotation_id: str = Field(..., description="Annotation UUID")
    message: str = Field(..., description="Status message")


class AnnotationListResponse(BaseModel):
    """Response for annotation list."""

    document_id: str = Field(..., description="Document UUID")
    page_count: int = Field(..., ge=1, description="Total pages")
    total_annotations: int = Field(..., ge=0, description="Total annotations")
    annotations: list[AnnotationItem] = Field(
        default_factory=list, description="Annotation list"
    )


class AnnotationLockRequest(BaseModel):
    """Request to acquire annotation lock."""

    duration_seconds: int = Field(
        default=300,
        ge=60,
        le=3600,
        description="Lock duration in seconds (60-3600)",
    )


class AnnotationLockResponse(BaseModel):
    """Response for annotation lock operation."""

    document_id: str = Field(..., description="Document UUID")
    locked: bool = Field(..., description="Whether lock was acquired/released")
    lock_expires_at: datetime | None = Field(
        None, description="Lock expiration time"
    )
    message: str = Field(..., description="Status message")


class AutoLabelRequest(BaseModel):
    """Request to trigger auto-labeling."""

    field_values: dict[str, str] = Field(
        ...,
        description="Field values to match (e.g., {'invoice_number': '12345'})",
    )
    replace_existing: bool = Field(
        default=False, description="Replace existing auto annotations"
    )


class AutoLabelResponse(BaseModel):
    """Response for auto-labeling."""

    document_id: str = Field(..., description="Document UUID")
    status: str = Field(..., description="Auto-labeling status")
    annotations_created: int = Field(
        default=0, ge=0, description="Number of annotations created"
    )
    message: str = Field(..., description="Status message")


class AnnotationVerifyRequest(BaseModel):
    """Request to verify an annotation."""

    pass  # No body needed, just POST to verify


class AnnotationVerifyResponse(BaseModel):
    """Response for annotation verification."""

    annotation_id: str = Field(..., description="Annotation UUID")
    is_verified: bool = Field(..., description="Verification status")
    verified_at: datetime = Field(..., description="Verification timestamp")
    verified_by: str = Field(..., description="Admin token who verified")
    message: str = Field(..., description="Status message")


class AnnotationOverrideRequest(BaseModel):
    """Request to override an annotation."""

    bbox: dict[str, int] | None = Field(
        None, description="Updated bounding box {x, y, width, height}"
    )
    text_value: str | None = Field(None, description="Updated text value")
    class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID")
    class_name: str | None = Field(None, description="Updated class name")
    reason: str | None = Field(None, description="Reason for override")


class AnnotationOverrideResponse(BaseModel):
    """Response for annotation override."""

    annotation_id: str = Field(..., description="Annotation UUID")
    source: str = Field(..., description="New source (manual)")
    override_source: str | None = Field(None, description="Original source (auto)")
    original_annotation_id: str | None = Field(None, description="Original annotation ID")
    message: str = Field(..., description="Status message")
    history_id: str = Field(..., description="History record UUID")
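A short sketch showing how an AnnotationCreate payload maps onto the normalized bbox form that AnnotationItem.normalized_bbox describes; the page dimensions and field values are illustrative assumptions.

# Sketch: pixel bbox -> normalized YOLO-style center/width/height in [0, 1].
from backend.web.schemas.admin.annotations import AnnotationCreate, BoundingBox

page_width, page_height = 2480, 3508  # assumed A4 page rendered at 300 DPI

ann = AnnotationCreate(
    page_number=1,
    class_id=3,
    bbox=BoundingBox(x=620, y=480, width=300, height=60),
    text_value="INV-2024-0042",
)

normalized_bbox = {
    "x_center": (ann.bbox.x + ann.bbox.width / 2) / page_width,
    "y_center": (ann.bbox.y + ann.bbox.height / 2) / page_height,
    "width": ann.bbox.width / page_width,
    "height": ann.bbox.height / page_height,
}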
187
packages/backend/backend/web/schemas/admin/augmentation.py
Normal file
@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field


class AugmentationParamsSchema(BaseModel):
    """Single augmentation parameters."""

    enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
    probability: float = Field(
        default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
    )
    params: dict[str, Any] = Field(
        default_factory=dict, description="Type-specific parameters"
    )


class AugmentationConfigSchema(BaseModel):
    """Complete augmentation configuration."""

    # Geometric transforms
    perspective_warp: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Degradation effects
    wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
    edge_damage: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)

    # Lighting effects
    lighting_variation: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)

    # Blur effects
    gaussian_blur: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    motion_blur: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Noise effects
    gaussian_noise: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    salt_pepper: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Texture effects
    paper_texture: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    scanner_artifacts: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Global settings
    preserve_bboxes: bool = Field(
        default=True, description="Whether to adjust bboxes for geometric transforms"
    )
    seed: int | None = Field(default=None, description="Random seed for reproducibility")


class AugmentationTypeInfo(BaseModel):
    """Information about an augmentation type."""

    name: str = Field(..., description="Augmentation name")
    description: str = Field(..., description="Augmentation description")
    affects_geometry: bool = Field(
        ..., description="Whether this augmentation affects bbox coordinates"
    )
    stage: str = Field(..., description="Processing stage")
    default_params: dict[str, Any] = Field(
        default_factory=dict, description="Default parameters"
    )


class AugmentationTypesResponse(BaseModel):
    """Response for listing augmentation types."""

    augmentation_types: list[AugmentationTypeInfo] = Field(
        ..., description="Available augmentation types"
    )


class PresetInfo(BaseModel):
    """Information about a preset."""

    name: str = Field(..., description="Preset name")
    description: str = Field(..., description="Preset description")


class PresetsResponse(BaseModel):
    """Response for listing presets."""

    presets: list[PresetInfo] = Field(..., description="Available presets")


class AugmentationPreviewRequest(BaseModel):
    """Request to preview augmentation on an image."""

    augmentation_type: str = Field(..., description="Type of augmentation to preview")
    params: dict[str, Any] = Field(
        default_factory=dict, description="Override parameters"
    )


class AugmentationPreviewResponse(BaseModel):
    """Response with preview image data."""

    preview_url: str = Field(..., description="URL to preview image")
    original_url: str = Field(..., description="URL to original image")
    applied_params: dict[str, Any] = Field(..., description="Applied parameters")


class AugmentationBatchRequest(BaseModel):
    """Request to augment a dataset offline."""

    dataset_id: str = Field(..., description="Source dataset UUID")
    config: AugmentationConfigSchema = Field(..., description="Augmentation config")
    output_name: str = Field(
        ..., min_length=1, max_length=255, description="Output dataset name"
    )
    multiplier: int = Field(
        default=2, ge=1, le=10, description="Augmented copies per image"
    )


class AugmentationBatchResponse(BaseModel):
    """Response for batch augmentation."""

    task_id: str = Field(..., description="Background task UUID")
    status: str = Field(..., description="Task status")
    message: str = Field(..., description="Status message")
    estimated_images: int = Field(..., description="Estimated total images")


class AugmentedDatasetItem(BaseModel):
    """Single augmented dataset in list."""

    dataset_id: str = Field(..., description="Dataset UUID")
    source_dataset_id: str = Field(..., description="Source dataset UUID")
    name: str = Field(..., description="Dataset name")
    status: str = Field(..., description="Dataset status")
    multiplier: int = Field(..., description="Augmentation multiplier")
    total_original_images: int = Field(..., description="Original image count")
    total_augmented_images: int = Field(..., description="Augmented image count")
    created_at: datetime = Field(..., description="Creation timestamp")


class AugmentedDatasetListResponse(BaseModel):
    """Response for listing augmented datasets."""

    total: int = Field(..., ge=0, description="Total datasets")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    datasets: list[AugmentedDatasetItem] = Field(
        default_factory=list, description="Dataset list"
    )


class AugmentedDatasetDetailResponse(BaseModel):
    """Detailed augmented dataset response."""

    dataset_id: str = Field(..., description="Dataset UUID")
    source_dataset_id: str = Field(..., description="Source dataset UUID")
    name: str = Field(..., description="Dataset name")
    status: str = Field(..., description="Dataset status")
    config: AugmentationConfigSchema | None = Field(
        None, description="Augmentation config used"
    )
    multiplier: int = Field(..., description="Augmentation multiplier")
    total_original_images: int = Field(..., description="Original image count")
    total_augmented_images: int = Field(..., description="Augmented image count")
    dataset_path: str | None = Field(None, description="Dataset path on disk")
    error_message: str | None = Field(None, description="Error message if failed")
    created_at: datetime = Field(..., description="Creation timestamp")
    completed_at: datetime | None = Field(None, description="Completion timestamp")
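A sketch of enabling a subset of augmentations via AugmentationConfigSchema; the keys inside params are illustrative assumptions, since the schema accepts arbitrary type-specific parameters.

# Sketch: minimal augmentation config with two effects enabled.
from backend.web.schemas.admin.augmentation import (
    AugmentationConfigSchema,
    AugmentationParamsSchema,
)

config = AugmentationConfigSchema(
    gaussian_blur=AugmentationParamsSchema(
        enabled=True, probability=0.3, params={"kernel_size": 5}  # assumed key
    ),
    paper_texture=AugmentationParamsSchema(enabled=True, probability=0.5),
    preserve_bboxes=True,
    seed=42,
)
print(config.model_dump(exclude_defaults=True))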
23
packages/backend/backend/web/schemas/admin/auth.py
Normal file
@@ -0,0 +1,23 @@
"""Admin Auth Schemas."""

from datetime import datetime

from pydantic import BaseModel, Field


class AdminTokenCreate(BaseModel):
    """Request to create an admin token."""

    name: str = Field(..., min_length=1, max_length=255, description="Token name")
    expires_in_days: int | None = Field(
        None, ge=1, le=365, description="Token expiration in days (optional)"
    )


class AdminTokenResponse(BaseModel):
    """Response with created admin token."""

    token: str = Field(..., description="Admin token")
    name: str = Field(..., description="Token name")
    expires_at: datetime | None = Field(None, description="Token expiration time")
    message: str = Field(..., description="Status message")
92
packages/backend/backend/web/schemas/admin/dashboard.py
Normal file
@@ -0,0 +1,92 @@
"""
Dashboard API Schemas

Pydantic models for dashboard statistics and activity endpoints.
"""

from datetime import datetime
from typing import Any, Literal

from pydantic import BaseModel, Field


# Activity type literals for type safety
ActivityType = Literal[
    "document_uploaded",
    "annotation_modified",
    "training_completed",
    "training_failed",
    "model_activated",
]


class DashboardStatsResponse(BaseModel):
    """Response for dashboard statistics."""

    total_documents: int = Field(..., description="Total number of documents")
    annotation_complete: int = Field(
        ..., description="Documents with complete annotations"
    )
    annotation_incomplete: int = Field(
        ..., description="Documents with incomplete annotations"
    )
    pending: int = Field(..., description="Documents pending processing")
    completeness_rate: float = Field(
        ..., description="Annotation completeness percentage"
    )


class ActiveModelInfo(BaseModel):
    """Active model information."""

    version_id: str = Field(..., description="Model version UUID")
    version: str = Field(..., description="Model version string")
    name: str = Field(..., description="Model name")
    metrics_mAP: float | None = Field(None, description="Mean Average Precision")
    metrics_precision: float | None = Field(None, description="Precision score")
    metrics_recall: float | None = Field(None, description="Recall score")
    document_count: int = Field(0, description="Number of training documents")
    activated_at: datetime | None = Field(None, description="Activation timestamp")


class RunningTrainingInfo(BaseModel):
    """Running training task information."""

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Training task name")
    status: str = Field(..., description="Training status")
    started_at: datetime | None = Field(None, description="Start timestamp")
    progress: int = Field(0, description="Training progress percentage")


class DashboardActiveModelResponse(BaseModel):
    """Response for dashboard active model endpoint."""

    model: ActiveModelInfo | None = Field(
        None, description="Active model info, null if none"
    )
    running_training: RunningTrainingInfo | None = Field(
        None, description="Running training task, null if none"
    )


class ActivityItem(BaseModel):
    """Single activity item."""

    type: ActivityType = Field(
        ...,
        description="Activity type: document_uploaded, annotation_modified, training_completed, training_failed, model_activated",
    )
    description: str = Field(..., description="Human-readable description")
    timestamp: datetime = Field(..., description="Activity timestamp")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )


class RecentActivityResponse(BaseModel):
    """Response for recent activity endpoint."""

    activities: list[ActivityItem] = Field(
        default_factory=list, description="List of recent activities"
    )
90
packages/backend/backend/web/schemas/admin/datasets.py
Normal file
@@ -0,0 +1,90 @@
"""Admin Dataset Schemas."""

from datetime import datetime

from pydantic import BaseModel, Field

from .training import TrainingConfig


class DatasetCreateRequest(BaseModel):
    """Request to create a training dataset."""

    name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
    description: str | None = Field(None, description="Optional description")
    document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
    category: str | None = Field(None, description="Filter documents by category (optional)")
    train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
    val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
    seed: int = Field(42, description="Random seed for split")


class DatasetDocumentItem(BaseModel):
    """Document within a dataset."""

    document_id: str
    split: str
    page_count: int
    annotation_count: int


class DatasetResponse(BaseModel):
    """Response after creating a dataset."""

    dataset_id: str
    name: str
    status: str
    message: str


class DatasetDetailResponse(BaseModel):
    """Detailed dataset info with documents."""

    dataset_id: str
    name: str
    description: str | None
    status: str
    training_status: str | None = None
    active_training_task_id: str | None = None
    train_ratio: float
    val_ratio: float
    seed: int
    total_documents: int
    total_images: int
    total_annotations: int
    dataset_path: str | None
    error_message: str | None
    documents: list[DatasetDocumentItem]
    created_at: datetime
    updated_at: datetime


class DatasetListItem(BaseModel):
    """Dataset in list view."""

    dataset_id: str
    name: str
    description: str | None
    status: str
    training_status: str | None = None
    active_training_task_id: str | None = None
    total_documents: int
    total_images: int
    total_annotations: int
    created_at: datetime


class DatasetListResponse(BaseModel):
    """Paginated dataset list."""

    total: int
    limit: int
    offset: int
    datasets: list[DatasetListItem]


class DatasetTrainRequest(BaseModel):
    """Request to start training from a dataset."""

    name: str = Field(..., min_length=1, max_length=255, description="Training task name")
    config: TrainingConfig = Field(..., description="Training configuration")
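A sketch of a DatasetCreateRequest payload; the UUIDs are placeholders, and the note about the remaining split share is an assumption of this example rather than documented behavior.

# Sketch: with train_ratio=0.8 and val_ratio=0.1, the remaining 10% is
# presumably available to the consuming service as a test split (assumption).
from backend.web.schemas.admin.datasets import DatasetCreateRequest

request = DatasetCreateRequest(
    name="invoices-2024-q1",
    description="Verified invoices from Q1 batch uploads",
    document_ids=["6f1c0000-0000-0000-0000-000000000001"],  # placeholder UUID
    category="invoice",
    train_ratio=0.8,
    val_ratio=0.1,
    seed=42,
)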
123
packages/backend/backend/web/schemas/admin/documents.py
Normal file
123
packages/backend/backend/web/schemas/admin/documents.py
Normal file
@@ -0,0 +1,123 @@
"""Admin Document Schemas."""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field

from .enums import AutoLabelStatus, DocumentStatus

if TYPE_CHECKING:
    from .annotations import AnnotationItem
    from .training import TrainingHistoryItem


class DocumentUploadResponse(BaseModel):
    """Response for document upload."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
    group_key: str | None = Field(None, description="User-defined group key")
    auto_label_started: bool = Field(
        default=False, description="Whether auto-labeling was started"
    )
    message: str = Field(..., description="Status message")


class DocumentItem(BaseModel):
    """Single document in list."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_status: AutoLabelStatus | None = Field(
        None, description="Auto-labeling status"
    )
    annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
    group_key: str | None = Field(None, description="User-defined group key")
    category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
    can_annotate: bool = Field(default=True, description="Whether document can be annotated")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")


class DocumentListResponse(BaseModel):
    """Response for document list."""

    total: int = Field(..., ge=0, description="Total documents")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    documents: list[DocumentItem] = Field(
        default_factory=list, description="Document list"
    )


class DocumentDetailResponse(BaseModel):
    """Response for document detail."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    content_type: str = Field(..., description="MIME type")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_status: AutoLabelStatus | None = Field(
        None, description="Auto-labeling status"
    )
    auto_label_error: str | None = Field(None, description="Auto-labeling error")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
    group_key: str | None = Field(None, description="User-defined group key")
    category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
    csv_field_values: dict[str, str] | None = Field(
        None, description="CSV field values if uploaded via batch"
    )
    can_annotate: bool = Field(default=True, description="Whether document can be annotated")
    annotation_lock_until: datetime | None = Field(
        None, description="Lock expiration time if document is locked"
    )
    annotations: list["AnnotationItem"] = Field(
        default_factory=list, description="Document annotations"
    )
    image_urls: list[str] = Field(
        default_factory=list, description="URLs to page images"
    )
    training_history: list["TrainingHistoryItem"] = Field(
        default_factory=list, description="Training tasks that used this document"
    )
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")


class DocumentStatsResponse(BaseModel):
    """Document statistics response."""

    total: int = Field(..., ge=0, description="Total documents")
    pending: int = Field(default=0, ge=0, description="Pending documents")
    auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
    labeled: int = Field(default=0, ge=0, description="Labeled documents")
    exported: int = Field(default=0, ge=0, description="Exported documents")


class DocumentUpdateRequest(BaseModel):
    """Request for updating document metadata."""

    category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)")
    group_key: str | None = Field(None, description="User-defined group key")


class DocumentCategoriesResponse(BaseModel):
    """Response for available document categories."""

    categories: list[str] = Field(..., description="List of available categories")
    total: int = Field(..., ge=0, description="Total number of categories")
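A sketch of building an upload response from these schemas; the import paths and all field values are assumptions used only for illustration:

from backend.web.schemas.admin.documents import DocumentUploadResponse  # path assumed
from backend.web.schemas.admin.enums import DocumentStatus

resp = DocumentUploadResponse(
    document_id="11111111-2222-3333-4444-555555555555",  # illustrative
    filename="invoice_001.pdf",
    file_size=182044,
    page_count=2,
    status=DocumentStatus.PENDING,
    auto_label_started=True,
    message="Uploaded; auto-labeling queued",
)
print(resp.model_dump())  # category falls back to its "invoice" default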
46
packages/backend/backend/web/schemas/admin/enums.py
Normal file
46
packages/backend/backend/web/schemas/admin/enums.py
Normal file
@@ -0,0 +1,46 @@
"""Admin API Enums."""

from enum import Enum


class DocumentStatus(str, Enum):
    """Document status enum."""

    PENDING = "pending"
    AUTO_LABELING = "auto_labeling"
    LABELED = "labeled"
    EXPORTED = "exported"


class AutoLabelStatus(str, Enum):
    """Auto-labeling status enum."""

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


class TrainingStatus(str, Enum):
    """Training task status enum."""

    PENDING = "pending"
    SCHEDULED = "scheduled"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TrainingType(str, Enum):
    """Training task type enum."""

    TRAIN = "train"
    FINETUNE = "finetune"


class AnnotationSource(str, Enum):
    """Annotation source enum."""

    MANUAL = "manual"
    AUTO = "auto"
    IMPORTED = "imported"
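Because these enums subclass str, they serialize as plain strings and round-trip from stored values; a small illustrative check (import path assumed):

from backend.web.schemas.admin.enums import DocumentStatus, TrainingStatus  # path assumed

assert DocumentStatus.LABELED.value == "labeled"
assert DocumentStatus("labeled") is DocumentStatus.LABELED  # rebuild from a stored string
assert len(list(TrainingStatus)) == 6  # pending through cancelled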
95
packages/backend/backend/web/schemas/admin/models.py
Normal file
95
packages/backend/backend/web/schemas/admin/models.py
Normal file
@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field


class ModelVersionCreateRequest(BaseModel):
    """Request to create a model version."""

    version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
    name: str = Field(..., min_length=1, max_length=255, description="Model name")
    model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
    description: str | None = Field(None, description="Optional description")
    task_id: str | None = Field(None, description="Training task UUID")
    dataset_id: str | None = Field(None, description="Dataset UUID")
    metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
    metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
    metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
    document_count: int = Field(0, ge=0, description="Documents used in training")
    training_config: dict[str, Any] | None = Field(None, description="Training configuration")
    file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
    trained_at: datetime | None = Field(None, description="Training completion time")


class ModelVersionUpdateRequest(BaseModel):
    """Request to update a model version."""

    name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
    description: str | None = Field(None, description="Description")
    status: str | None = Field(None, description="Status (inactive, archived)")


class ModelVersionItem(BaseModel):
    """Model version in list view."""

    version_id: str = Field(..., description="Version UUID")
    version: str = Field(..., description="Semantic version")
    name: str = Field(..., description="Model name")
    status: str = Field(..., description="Status (active, inactive, archived)")
    is_active: bool = Field(..., description="Is currently active for inference")
    metrics_mAP: float | None = Field(None, description="Mean Average Precision")
    document_count: int = Field(..., description="Documents used in training")
    trained_at: datetime | None = Field(None, description="Training completion time")
    activated_at: datetime | None = Field(None, description="Last activation time")
    created_at: datetime = Field(..., description="Creation timestamp")


class ModelVersionListResponse(BaseModel):
    """Paginated model version list."""

    total: int = Field(..., ge=0, description="Total model versions")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")


class ModelVersionDetailResponse(BaseModel):
    """Detailed model version info."""

    version_id: str = Field(..., description="Version UUID")
    version: str = Field(..., description="Semantic version")
    name: str = Field(..., description="Model name")
    description: str | None = Field(None, description="Description")
    model_path: str = Field(..., description="Path to model file")
    status: str = Field(..., description="Status (active, inactive, archived)")
    is_active: bool = Field(..., description="Is currently active for inference")
    task_id: str | None = Field(None, description="Training task UUID")
    dataset_id: str | None = Field(None, description="Dataset UUID")
    metrics_mAP: float | None = Field(None, description="Mean Average Precision")
    metrics_precision: float | None = Field(None, description="Precision")
    metrics_recall: float | None = Field(None, description="Recall")
    document_count: int = Field(..., description="Documents used in training")
    training_config: dict[str, Any] | None = Field(None, description="Training configuration")
    file_size: int | None = Field(None, description="Model file size in bytes")
    trained_at: datetime | None = Field(None, description="Training completion time")
    activated_at: datetime | None = Field(None, description="Last activation time")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")


class ModelVersionResponse(BaseModel):
    """Response for model version operation."""

    version_id: str = Field(..., description="Version UUID")
    status: str = Field(..., description="Model status")
    message: str = Field(..., description="Status message")


class ActiveModelResponse(BaseModel):
    """Response for active model query."""

    has_active_model: bool = Field(..., description="Whether an active model exists")
    model: ModelVersionItem | None = Field(None, description="Active model if exists")
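A sketch of registering a trained model version with the request schema above; import path and values are assumptions for illustration:

from backend.web.schemas.admin.models import ModelVersionCreateRequest  # path assumed

req = ModelVersionCreateRequest(
    version="1.2.0",
    name="invoice-fields-yolo11n",
    model_path="/models/invoice_fields/1.2.0/best.pt",  # illustrative path
    metrics_mAP=0.91,
    document_count=450,
)
print(req.model_dump(exclude_none=True))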
219
packages/backend/backend/web/schemas/admin/training.py
Normal file
219
packages/backend/backend/web/schemas/admin/training.py
Normal file
@@ -0,0 +1,219 @@
"""Admin Training Schemas."""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field

from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType


class TrainingConfig(BaseModel):
    """Training configuration."""

    model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
    base_model_version_id: str | None = Field(
        default=None,
        description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
    )
    epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
    batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
    image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
    learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate")
    device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
    project_name: str = Field(
        default="invoice_fields", description="Training project name"
    )

    # Data augmentation settings
    augmentation: AugmentationConfigSchema | None = Field(
        default=None,
        description="Augmentation configuration. If provided, augments dataset before training.",
    )
    augmentation_multiplier: int = Field(
        default=2,
        ge=1,
        le=10,
        description="Number of augmented copies per original image",
    )


class TrainingTaskCreate(BaseModel):
    """Request to create a training task."""

    name: str = Field(..., min_length=1, max_length=255, description="Task name")
    description: str | None = Field(None, max_length=1000, description="Description")
    task_type: TrainingType = Field(
        default=TrainingType.TRAIN, description="Task type"
    )
    config: TrainingConfig = Field(
        default_factory=TrainingConfig, description="Training configuration"
    )
    scheduled_at: datetime | None = Field(
        None, description="Scheduled execution time"
    )
    cron_expression: str | None = Field(
        None, max_length=50, description="Cron expression for recurring tasks"
    )


class TrainingTaskItem(BaseModel):
    """Single training task in list."""

    task_id: str = Field(..., description="Task UUID")
    name: str = Field(..., description="Task name")
    task_type: TrainingType = Field(..., description="Task type")
    status: TrainingStatus = Field(..., description="Task status")
    scheduled_at: datetime | None = Field(None, description="Scheduled time")
    is_recurring: bool = Field(default=False, description="Is recurring task")
    started_at: datetime | None = Field(None, description="Start time")
    completed_at: datetime | None = Field(None, description="Completion time")
    created_at: datetime = Field(..., description="Creation timestamp")


class TrainingTaskListResponse(BaseModel):
    """Response for training task list."""

    total: int = Field(..., ge=0, description="Total tasks")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list")


class TrainingTaskDetailResponse(BaseModel):
    """Response for training task detail."""

    task_id: str = Field(..., description="Task UUID")
    name: str = Field(..., description="Task name")
    description: str | None = Field(None, description="Description")
    task_type: TrainingType = Field(..., description="Task type")
    status: TrainingStatus = Field(..., description="Task status")
    config: dict[str, Any] | None = Field(None, description="Training configuration")
    scheduled_at: datetime | None = Field(None, description="Scheduled time")
    cron_expression: str | None = Field(None, description="Cron expression")
    is_recurring: bool = Field(default=False, description="Is recurring task")
    started_at: datetime | None = Field(None, description="Start time")
    completed_at: datetime | None = Field(None, description="Completion time")
    error_message: str | None = Field(None, description="Error message")
    result_metrics: dict[str, Any] | None = Field(None, description="Result metrics")
    model_path: str | None = Field(None, description="Trained model path")
    created_at: datetime = Field(..., description="Creation timestamp")


class TrainingTaskResponse(BaseModel):
    """Response for training task operation."""

    task_id: str = Field(..., description="Task UUID")
    status: TrainingStatus = Field(..., description="Task status")
    message: str = Field(..., description="Status message")


class TrainingLogItem(BaseModel):
    """Single training log entry."""

    level: str = Field(..., description="Log level")
    message: str = Field(..., description="Log message")
    details: dict[str, Any] | None = Field(None, description="Additional details")
    created_at: datetime = Field(..., description="Timestamp")


class TrainingLogsResponse(BaseModel):
    """Response for training logs."""

    task_id: str = Field(..., description="Task UUID")
    logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries")


class ExportRequest(BaseModel):
    """Request to export annotations."""

    format: str = Field(
        default="yolo", description="Export format (yolo, coco, voc)"
    )
    include_images: bool = Field(
        default=True, description="Include images in export"
    )
    split_ratio: float = Field(
        default=0.8, ge=0.5, le=1.0, description="Train/val split ratio"
    )


class ExportResponse(BaseModel):
    """Response for export operation."""

    status: str = Field(..., description="Export status")
    export_path: str = Field(..., description="Path to exported dataset")
    total_images: int = Field(..., ge=0, description="Total images exported")
    total_annotations: int = Field(..., ge=0, description="Total annotations")
    train_count: int = Field(..., ge=0, description="Training set count")
    val_count: int = Field(..., ge=0, description="Validation set count")
    message: str = Field(..., description="Status message")


class TrainingDocumentItem(BaseModel):
    """Document item for training page."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Filename")
    annotation_count: int = Field(..., ge=0, description="Total annotations")
    annotation_sources: dict[str, int] = Field(
        ..., description="Annotation counts by source (manual, auto)"
    )
    used_in_training: list[str] = Field(
        default_factory=list, description="List of training task IDs that used this document"
    )
    last_modified: datetime = Field(..., description="Last modification time")


class TrainingDocumentsResponse(BaseModel):
    """Response for GET /admin/training/documents."""

    total: int = Field(..., ge=0, description="Total document count")
    limit: int = Field(..., ge=1, le=100, description="Page size")
    offset: int = Field(..., ge=0, description="Pagination offset")
    documents: list[TrainingDocumentItem] = Field(
        default_factory=list, description="Documents available for training"
    )


class ModelMetrics(BaseModel):
    """Training model metrics."""

    mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
    precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
    recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")


class TrainingModelItem(BaseModel):
    """Trained model item for model list."""

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Model name")
    status: TrainingStatus = Field(..., description="Training status")
    document_count: int = Field(..., ge=0, description="Documents used in training")
    created_at: datetime = Field(..., description="Creation timestamp")
    completed_at: datetime | None = Field(None, description="Completion timestamp")
    metrics: ModelMetrics = Field(..., description="Model metrics")
    model_path: str | None = Field(None, description="Path to model weights")
    download_url: str | None = Field(None, description="Download URL for model")


class TrainingModelsResponse(BaseModel):
    """Response for GET /admin/training/models."""

    total: int = Field(..., ge=0, description="Total model count")
    limit: int = Field(..., ge=1, le=100, description="Page size")
    offset: int = Field(..., ge=0, description="Pagination offset")
    models: list[TrainingModelItem] = Field(
        default_factory=list, description="Trained models"
    )


class TrainingHistoryItem(BaseModel):
    """Training history for a document."""

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Training task name")
    trained_at: datetime = Field(..., description="Training timestamp")
    model_metrics: ModelMetrics | None = Field(None, description="Model metrics")
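A sketch of a scheduled fine-tune request built from these schemas; the cron string, hyperparameters, and import paths are assumptions for illustration:

from backend.web.schemas.admin.enums import TrainingType  # path assumed
from backend.web.schemas.admin.training import TrainingConfig, TrainingTaskCreate

task = TrainingTaskCreate(
    name="nightly-finetune",
    task_type=TrainingType.FINETUNE,
    config=TrainingConfig(epochs=50, batch_size=8, device="cpu"),
    cron_expression="0 2 * * *",  # every night at 02:00
)
print(task.config.model_name)  # "yolo11n.pt" unless base_model_version_id is set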
15
packages/backend/backend/web/schemas/common.py
Normal file
15
packages/backend/backend/web/schemas/common.py
Normal file
@@ -0,0 +1,15 @@
"""
Common Schemas

Shared Pydantic models used across multiple API modules.
"""

from pydantic import BaseModel, Field


class ErrorResponse(BaseModel):
    """Error response."""

    status: str = Field(default="error", description="Error status")
    message: str = Field(..., description="Error message")
    detail: str | None = Field(None, description="Detailed error information")
196
packages/backend/backend/web/schemas/inference.py
Normal file
196
packages/backend/backend/web/schemas/inference.py
Normal file
@@ -0,0 +1,196 @@
"""
API Request/Response Schemas

Pydantic models for API validation and serialization.
"""

from datetime import datetime
from enum import Enum

from pydantic import BaseModel, Field


# =============================================================================
# Enums
# =============================================================================


class AsyncStatus(str, Enum):
    """Async request status enum."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


# =============================================================================
# Sync API Schemas (existing)
# =============================================================================


class DetectionResult(BaseModel):
    """Single detection result."""

    field: str = Field(..., description="Field type (e.g., invoice_number, amount)")
    confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
    bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")


class ExtractedField(BaseModel):
    """Extracted and normalized field value."""

    field_name: str = Field(..., description="Field name")
    value: str | None = Field(None, description="Extracted value")
    confidence: float = Field(..., ge=0, le=1, description="Extraction confidence")
    is_valid: bool = Field(True, description="Whether the value passed validation")


class InferenceResult(BaseModel):
    """Complete inference result for a document."""

    document_id: str = Field(..., description="Document identifier")
    success: bool = Field(..., description="Whether inference succeeded")
    document_type: str = Field(
        default="invoice", description="Document type: 'invoice' or 'letter'"
    )
    fields: dict[str, str | None] = Field(
        default_factory=dict, description="Extracted field values"
    )
    confidence: dict[str, float] = Field(
        default_factory=dict, description="Confidence scores per field"
    )
    detections: list[DetectionResult] = Field(
        default_factory=list, description="Raw YOLO detections"
    )
    processing_time_ms: float = Field(..., description="Processing time in milliseconds")
    visualization_url: str | None = Field(
        None, description="URL to visualization image"
    )
    errors: list[str] = Field(default_factory=list, description="Error messages")


class InferenceResponse(BaseModel):
    """API response for inference endpoint."""

    status: str = Field(..., description="Response status: success or error")
    message: str = Field(..., description="Response message")
    result: InferenceResult | None = Field(None, description="Inference result")


class BatchInferenceResponse(BaseModel):
    """API response for batch inference endpoint."""

    status: str = Field(..., description="Response status")
    message: str = Field(..., description="Response message")
    total: int = Field(..., description="Total documents processed")
    successful: int = Field(..., description="Number of successful extractions")
    results: list[InferenceResult] = Field(
        default_factory=list, description="Individual results"
    )


class HealthResponse(BaseModel):
    """Health check response."""

    status: str = Field(..., description="Service status")
    model_loaded: bool = Field(..., description="Whether model is loaded")
    gpu_available: bool = Field(..., description="Whether GPU is available")
    version: str = Field(..., description="API version")


class ErrorResponse(BaseModel):
    """Error response."""

    status: str = Field(default="error", description="Error status")
    message: str = Field(..., description="Error message")
    detail: str | None = Field(None, description="Detailed error information")


# =============================================================================
# Async API Schemas
# =============================================================================


class AsyncSubmitResponse(BaseModel):
    """Response for async submit endpoint."""

    status: str = Field(default="accepted", description="Response status")
    message: str = Field(..., description="Response message")
    request_id: str = Field(..., description="Unique request identifier (UUID)")
    estimated_wait_seconds: int = Field(
        ..., ge=0, description="Estimated wait time in seconds"
    )
    poll_url: str = Field(..., description="URL to poll for status updates")


class AsyncStatusResponse(BaseModel):
    """Response for async status endpoint."""

    request_id: str = Field(..., description="Unique request identifier")
    status: AsyncStatus = Field(..., description="Current processing status")
    filename: str = Field(..., description="Original filename")
    created_at: datetime = Field(..., description="Request creation timestamp")
    started_at: datetime | None = Field(
        None, description="Processing start timestamp"
    )
    completed_at: datetime | None = Field(
        None, description="Processing completion timestamp"
    )
    position_in_queue: int | None = Field(
        None, description="Position in queue (for pending status)"
    )
    error_message: str | None = Field(
        None, description="Error message (for failed status)"
    )
    result_url: str | None = Field(
        None, description="URL to fetch results (for completed status)"
    )


class AsyncResultResponse(BaseModel):
    """Response for async result endpoint."""

    request_id: str = Field(..., description="Unique request identifier")
    status: AsyncStatus = Field(..., description="Processing status")
    processing_time_ms: float = Field(
        ..., ge=0, description="Total processing time in milliseconds"
    )
    result: InferenceResult | None = Field(
        None, description="Extraction result (when completed)"
    )
    visualization_url: str | None = Field(
        None, description="URL to visualization image"
    )


class AsyncRequestItem(BaseModel):
    """Single item in async requests list."""

    request_id: str = Field(..., description="Unique request identifier")
    status: AsyncStatus = Field(..., description="Current processing status")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    created_at: datetime = Field(..., description="Request creation timestamp")
    completed_at: datetime | None = Field(
        None, description="Processing completion timestamp"
    )


class AsyncRequestsListResponse(BaseModel):
    """Response for async requests list endpoint."""

    total: int = Field(..., ge=0, description="Total number of requests")
    limit: int = Field(..., ge=1, description="Maximum items per page")
    offset: int = Field(..., ge=0, description="Current offset")
    requests: list[AsyncRequestItem] = Field(
        default_factory=list, description="List of requests"
    )


class RateLimitInfo(BaseModel):
    """Rate limit information (included in headers)."""

    limit: int = Field(..., description="Maximum requests per minute")
    remaining: int = Field(..., description="Remaining requests in current window")
    reset_at: datetime = Field(..., description="Time when limit resets")
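A sketch of validating an inference payload against these schemas on the consuming side; the payload values and import path are illustrative assumptions:

from backend.web.schemas.inference import InferenceResponse  # path assumed

payload = {
    "status": "success",
    "message": "ok",
    "result": {
        "document_id": "doc-001",
        "success": True,
        "fields": {"invoice_number": "INV-1042", "amount": "1250.00"},
        "confidence": {"invoice_number": 0.97, "amount": 0.88},
        "processing_time_ms": 412.5,
    },
}
resp = InferenceResponse.model_validate(payload)
print(resp.result.fields["invoice_number"])  # "INV-1042"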
13
packages/backend/backend/web/schemas/labeling.py
Normal file
13
packages/backend/backend/web/schemas/labeling.py
Normal file
@@ -0,0 +1,13 @@
"""
Labeling API Schemas

Pydantic models for pre-labeling and label validation endpoints.
"""

from pydantic import BaseModel, Field


class PreLabelResponse(BaseModel):
    """API response for pre-label endpoint."""

    document_id: str = Field(..., description="Document identifier for retrieving results")
18
packages/backend/backend/web/services/__init__.py
Normal file
18
packages/backend/backend/web/services/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""
Business Logic Services

Service layer for processing requests and orchestrating data operations.
"""

from backend.web.services.autolabel import AutoLabelService, get_auto_label_service
from backend.web.services.inference import InferenceService
from backend.web.services.async_processing import AsyncProcessingService
from backend.web.services.batch_upload import BatchUploadService

__all__ = [
    "AutoLabelService",
    "get_auto_label_service",
    "InferenceService",
    "AsyncProcessingService",
    "BatchUploadService",
]
386
packages/backend/backend/web/services/async_processing.py
Normal file
386
packages/backend/backend/web/services/async_processing.py
Normal file
@@ -0,0 +1,386 @@
"""
Async Processing Service

Manages async request lifecycle and background processing.
"""

import logging
import re
import shutil
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from threading import Event, Thread
from typing import TYPE_CHECKING

from backend.data.async_request_db import AsyncRequestDB
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from backend.web.core.rate_limiter import RateLimiter
from backend.web.services.storage_helpers import get_storage_helper

if TYPE_CHECKING:
    from backend.web.config import AsyncConfig, StorageConfig
    from backend.web.services.inference import InferenceService

logger = logging.getLogger(__name__)


@dataclass
class AsyncSubmitResult:
    """Result from async submit operation."""

    success: bool
    request_id: str | None = None
    estimated_wait_seconds: int = 0
    error: str | None = None


class AsyncProcessingService:
    """
    Manages async request lifecycle and processing.

    Coordinates between:
    - HTTP endpoints (submit/status/result)
    - Background task queue
    - Database storage
    - Rate limiting
    """

    def __init__(
        self,
        inference_service: "InferenceService",
        db: AsyncRequestDB,
        queue: AsyncTaskQueue,
        rate_limiter: RateLimiter,
        async_config: "AsyncConfig",
        storage_config: "StorageConfig",
    ) -> None:
        self._inference = inference_service
        self._db = db
        self._queue = queue
        self._rate_limiter = rate_limiter
        self._async_config = async_config
        self._storage_config = storage_config

        # Cleanup thread
        self._cleanup_stop_event = Event()
        self._cleanup_thread: Thread | None = None

    def start(self) -> None:
        """Start the async processing service."""
        # Start the task queue with our handler
        self._queue.start(self._process_task)

        # Start cleanup thread
        self._cleanup_stop_event.clear()
        self._cleanup_thread = Thread(
            target=self._cleanup_loop,
            name="async-cleanup",
            daemon=True,
        )
        self._cleanup_thread.start()
        logger.info("AsyncProcessingService started")

    def stop(self, timeout: float = 30.0) -> None:
        """Stop the async processing service."""
        # Stop cleanup thread
        self._cleanup_stop_event.set()
        if self._cleanup_thread and self._cleanup_thread.is_alive():
            self._cleanup_thread.join(timeout=5.0)

        # Stop task queue
        self._queue.stop(timeout=timeout)
        logger.info("AsyncProcessingService stopped")

    def submit_request(
        self,
        api_key: str,
        file_content: bytes,
        filename: str,
        content_type: str,
    ) -> AsyncSubmitResult:
        """
        Submit a new async processing request.

        Args:
            api_key: API key for the request
            file_content: File content as bytes
            filename: Original filename
            content_type: File content type

        Returns:
            AsyncSubmitResult with request_id and status
        """
        # Generate request ID
        request_id = str(uuid.uuid4())

        # Save file to temp storage
        file_path = self._save_upload(request_id, filename, file_content)
        file_size = len(file_content)

        try:
            # Calculate expiration
            expires_at = datetime.utcnow() + timedelta(
                days=self._async_config.result_retention_days
            )

            # Create database record
            self._db.create_request(
                api_key=api_key,
                filename=filename,
                file_size=file_size,
                content_type=content_type,
                expires_at=expires_at,
                request_id=request_id,
            )

            # Record rate limit event
            self._rate_limiter.record_request(api_key)

            # Create and queue task
            task = AsyncTask(
                request_id=request_id,
                api_key=api_key,
                file_path=file_path,
                filename=filename,
                created_at=datetime.utcnow(),
            )

            if not self._queue.submit(task):
                # Queue is full
                self._db.update_status(
                    request_id,
                    "failed",
                    error_message="Processing queue is full",
                )
                # Cleanup file
                file_path.unlink(missing_ok=True)
                return AsyncSubmitResult(
                    success=False,
                    request_id=request_id,
                    error="Processing queue is full. Please try again later.",
                )

            # Estimate wait time
            estimated_wait = self._estimate_wait()

            return AsyncSubmitResult(
                success=True,
                request_id=request_id,
                estimated_wait_seconds=estimated_wait,
            )

        except Exception as e:
            logger.error(f"Failed to submit request: {e}", exc_info=True)
            # Cleanup file on error
            file_path.unlink(missing_ok=True)
            return AsyncSubmitResult(
                success=False,
                # Return generic error message to avoid leaking implementation details
                error="Failed to process request. Please try again later.",
            )

    # Allowed file extensions whitelist
    ALLOWED_EXTENSIONS = frozenset({".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"})

    def _save_upload(
        self,
        request_id: str,
        filename: str,
        content: bytes,
    ) -> Path:
        """Save uploaded file to temp storage using StorageHelper."""
        # Extract extension from filename
        ext = Path(filename).suffix.lower()

        # Validate extension: must be alphanumeric only (e.g., .pdf, .png)
        if not ext or not re.match(r'^\.[a-z0-9]+$', ext):
            ext = ".pdf"

        # Validate against whitelist
        if ext not in self.ALLOWED_EXTENSIONS:
            ext = ".pdf"

        # Get upload directory from StorageHelper
        storage = get_storage_helper()
        upload_dir = storage.get_uploads_base_path(subfolder="async")
        if upload_dir is None:
            raise ValueError("Storage not configured for local access")

        # Build file path - request_id is a UUID so it's safe
        file_path = upload_dir / f"{request_id}{ext}"

        # Defense in depth: ensure path is within upload_dir
        if not file_path.resolve().is_relative_to(upload_dir.resolve()):
            raise ValueError("Invalid file path detected")

        file_path.write_bytes(content)

        return file_path

    def _process_task(self, task: AsyncTask) -> None:
        """
        Process a single task (called by worker thread).

        This method is called by the AsyncTaskQueue worker threads.
        """
        start_time = time.time()

        try:
            # Update status to processing
            self._db.update_status(task.request_id, "processing")

            # Ensure file exists
            if not task.file_path.exists():
                raise FileNotFoundError(f"Upload file not found: {task.file_path}")

            # Run inference based on file type
            file_ext = task.file_path.suffix.lower()
            if file_ext == ".pdf":
                result = self._inference.process_pdf(
                    task.file_path,
                    document_id=task.request_id[:8],
                )
            else:
                result = self._inference.process_image(
                    task.file_path,
                    document_id=task.request_id[:8],
                )

            # Calculate processing time
            processing_time_ms = (time.time() - start_time) * 1000

            # Prepare result for storage
            result_data = {
                "document_id": result.document_id,
                "success": result.success,
                "document_type": result.document_type,
                "fields": result.fields,
                "confidence": result.confidence,
                "detections": result.detections,
                "errors": result.errors,
            }

            # Get visualization path as string
            viz_path = None
            if result.visualization_path:
                viz_path = str(result.visualization_path.name)

            # Store result in database
            self._db.complete_request(
                request_id=task.request_id,
                document_id=result.document_id,
                result=result_data,
                processing_time_ms=processing_time_ms,
                visualization_path=viz_path,
            )

            logger.info(
                f"Task {task.request_id} completed successfully "
                f"in {processing_time_ms:.0f}ms"
            )

        except Exception as e:
            logger.error(
                f"Task {task.request_id} failed: {e}",
                exc_info=True,
            )
            self._db.update_status(
                task.request_id,
                "failed",
                error_message=str(e),
                increment_retry=True,
            )

        finally:
            # Cleanup uploaded file
            if task.file_path.exists():
                task.file_path.unlink(missing_ok=True)

    def _estimate_wait(self) -> int:
        """Estimate wait time based on queue depth."""
        queue_depth = self._queue.get_queue_depth()
        processing_count = self._queue.get_processing_count()
        total_pending = queue_depth + processing_count

        # Estimate ~5 seconds per document
        avg_processing_time = 5
        return total_pending * avg_processing_time

    def _cleanup_loop(self) -> None:
        """Background cleanup loop."""
        logger.info("Cleanup thread started")
        cleanup_interval = self._async_config.cleanup_interval_hours * 3600

        while not self._cleanup_stop_event.wait(timeout=cleanup_interval):
            try:
                self._run_cleanup()
            except Exception as e:
                logger.error(f"Cleanup failed: {e}", exc_info=True)

        logger.info("Cleanup thread stopped")

    def _run_cleanup(self) -> None:
        """Run cleanup operations."""
        logger.info("Running cleanup...")

        # Delete expired requests
        deleted_requests = self._db.delete_expired_requests()

        # Reset stale processing requests
        reset_count = self._db.reset_stale_processing_requests(
            stale_minutes=self._async_config.task_timeout_seconds // 60,
            max_retries=3,
        )

        # Cleanup old rate limit events
        deleted_events = self._db.cleanup_old_rate_limit_events(hours=1)

        # Cleanup old poll timestamps
        cleaned_polls = self._rate_limiter.cleanup_poll_timestamps()

        # Cleanup rate limiter request windows
        self._rate_limiter.cleanup_request_windows()

        # Cleanup orphaned upload files
        orphan_count = self._cleanup_orphan_files()

        logger.info(
            f"Cleanup complete: {deleted_requests} expired requests, "
            f"{reset_count} stale requests reset, "
            f"{deleted_events} rate limit events, "
            f"{cleaned_polls} poll timestamps, "
            f"{orphan_count} orphan files"
        )

    def _cleanup_orphan_files(self) -> int:
        """Clean up upload files that don't have matching requests."""
        storage = get_storage_helper()
        upload_dir = storage.get_uploads_base_path(subfolder="async")
        if upload_dir is None or not upload_dir.exists():
            return 0

        count = 0
        # Files older than 1 hour without matching request are considered orphans
        cutoff = time.time() - 3600

        for file_path in upload_dir.iterdir():
            if not file_path.is_file():
                continue

            # Check if file is old enough
            if file_path.stat().st_mtime > cutoff:
                continue

            # Extract request_id from filename
            request_id = file_path.stem

            # Check if request exists in database
            request = self._db.get_request(request_id)
            if request is None:
                file_path.unlink(missing_ok=True)
                count += 1

        return count
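The wait estimate above is simply queue depth plus in-flight count times a flat five-second budget per document; restated as a standalone function for clarity (illustrative only, not part of the commit):

AVG_SECONDS_PER_DOCUMENT = 5  # same flat per-document budget the service assumes

def estimate_wait_seconds(queue_depth: int, processing_count: int) -> int:
    """Mirror of AsyncProcessingService._estimate_wait, for illustration only."""
    return (queue_depth + processing_count) * AVG_SECONDS_PER_DOCUMENT

assert estimate_wait_seconds(queue_depth=3, processing_count=1) == 20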
322
packages/backend/backend/web/services/augmentation_service.py
Normal file
322
packages/backend/backend/web/services/augmentation_service.py
Normal file
@@ -0,0 +1,322 @@
"""Augmentation service for handling augmentation operations."""

import base64
import io
import re
import uuid
from pathlib import Path
from typing import Any

import numpy as np
from fastapi import HTTPException
from PIL import Image

from backend.data.repositories import DocumentRepository, DatasetRepository
from backend.web.schemas.admin.augmentation import (
    AugmentationBatchResponse,
    AugmentationConfigSchema,
    AugmentationPreviewResponse,
    AugmentedDatasetItem,
    AugmentedDatasetListResponse,
)

# Constants
PREVIEW_MAX_SIZE = 800
PREVIEW_SEED = 42
UUID_PATTERN = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)


class AugmentationService:
    """Service for augmentation operations."""

    def __init__(
        self,
        doc_repo: DocumentRepository | None = None,
        dataset_repo: DatasetRepository | None = None,
    ) -> None:
        """Initialize service with repository connections."""
        self.doc_repo = doc_repo or DocumentRepository()
        self.dataset_repo = dataset_repo or DatasetRepository()

    def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
        """
        Validate UUID format to prevent path traversal.

        Args:
            value: Value to validate.
            field_name: Field name for error message.

        Raises:
            HTTPException: If value is not a valid UUID.
        """
        if not UUID_PATTERN.match(value):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid {field_name} format: {value}",
            )

    async def preview_single(
        self,
        document_id: str,
        page: int,
        augmentation_type: str,
        params: dict[str, Any],
    ) -> AugmentationPreviewResponse:
        """
        Preview a single augmentation on a document page.

        Args:
            document_id: Document UUID.
            page: Page number (1-indexed).
            augmentation_type: Name of augmentation to apply.
            params: Override parameters.

        Returns:
            Preview response with image URLs.

        Raises:
            HTTPException: If document not found or augmentation invalid.
        """
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline

        # Validate augmentation type
        if augmentation_type not in AUGMENTATION_REGISTRY:
            raise HTTPException(
                status_code=400,
                detail=f"Unknown augmentation type: {augmentation_type}. "
                f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
            )

        # Get document and load image
        image = await self._load_document_page(document_id, page)

        # Create config with only this augmentation enabled
        config_kwargs = {
            augmentation_type: AugmentationParams(
                enabled=True,
                probability=1.0,  # Always apply for preview
                params=params,
            ),
            "seed": PREVIEW_SEED,  # Deterministic preview
        }
        config = AugmentationConfig(**config_kwargs)
        pipeline = AugmentationPipeline(config)

        # Apply augmentation
        result = pipeline.apply(image)

        # Convert to base64 URLs
        original_url = self._image_to_data_url(image)
        preview_url = self._image_to_data_url(result.image)

        return AugmentationPreviewResponse(
            preview_url=preview_url,
            original_url=original_url,
            applied_params=params,
        )

    async def preview_config(
        self,
        document_id: str,
        page: int,
        config: AugmentationConfigSchema,
    ) -> AugmentationPreviewResponse:
        """
        Preview full augmentation config on a document page.

        Args:
            document_id: Document UUID.
            page: Page number (1-indexed).
            config: Full augmentation configuration.

        Returns:
            Preview response with image URLs.
        """
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        # Load image
        image = await self._load_document_page(document_id, page)

        # Convert Pydantic model to internal config
        config_dict = config.model_dump()
        internal_config = AugmentationConfig.from_dict(config_dict)
        pipeline = AugmentationPipeline(internal_config)

        # Apply augmentation
        result = pipeline.apply(image)

        # Convert to base64 URLs
        original_url = self._image_to_data_url(image)
        preview_url = self._image_to_data_url(result.image)

        return AugmentationPreviewResponse(
            preview_url=preview_url,
            original_url=original_url,
            applied_params=config_dict,
        )

    async def create_augmented_dataset(
        self,
        source_dataset_id: str,
        config: AugmentationConfigSchema,
        output_name: str,
        multiplier: int,
    ) -> AugmentationBatchResponse:
        """
        Create a new augmented dataset from an existing dataset.

        Args:
            source_dataset_id: Source dataset UUID.
            config: Augmentation configuration.
            output_name: Name for the new dataset.
            multiplier: Number of augmented copies per image.

        Returns:
            Batch response with task ID.

        Raises:
            HTTPException: If source dataset not found.
        """
        # Validate source dataset exists
        try:
            source_dataset = self.dataset_repo.get(source_dataset_id)
            if source_dataset is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"Source dataset not found: {source_dataset_id}",
                )
        except Exception as e:
            raise HTTPException(
                status_code=404,
                detail=f"Source dataset not found: {source_dataset_id}",
            ) from e

        # Create task ID for background processing
        task_id = str(uuid.uuid4())

        # Estimate total images
        estimated_images = (
            source_dataset.total_images * multiplier
            if hasattr(source_dataset, "total_images")
            else 0
        )

        # TODO: Queue background task for actual augmentation
        # For now, return pending status

        return AugmentationBatchResponse(
            task_id=task_id,
            status="pending",
            message=f"Augmentation task queued for dataset '{output_name}'",
            estimated_images=estimated_images,
        )

    async def list_augmented_datasets(
        self,
        limit: int = 20,
        offset: int = 0,
    ) -> AugmentedDatasetListResponse:
        """
        List augmented datasets.

        Args:
            limit: Maximum number of datasets to return.
            offset: Number of datasets to skip.

        Returns:
            List response with datasets.
        """
        # TODO: Implement actual database query for augmented datasets
        # For now, return empty list

        return AugmentedDatasetListResponse(
            total=0,
            limit=limit,
            offset=offset,
            datasets=[],
        )

    async def _load_document_page(
        self,
        document_id: str,
        page: int,
    ) -> np.ndarray:
        """
        Load a document page as numpy array.

        Args:
            document_id: Document UUID.
            page: Page number (1-indexed).

        Returns:
            Image as numpy array (H, W, C) with dtype uint8.

        Raises:
            HTTPException: If document or page not found.
        """
        # Validate document_id format to prevent path traversal
        self._validate_uuid(document_id, "document_id")

        # Get document from database
        try:
            document = self.doc_repo.get(document_id)
            if document is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"Document not found: {document_id}",
                )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=404,
                detail=f"Document not found: {document_id}",
            ) from e

        # Get image path for page
        if hasattr(document, "images_dir"):
            images_dir = Path(document.images_dir)
        else:
            # Fallback to constructed path
            from backend.web.core.config import get_settings

            settings = get_settings()
            images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"

        # Find image for page
        page_idx = page - 1  # Convert to 0-indexed
        image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))

        if page_idx >= len(image_files):
            raise HTTPException(
                status_code=404,
                detail=f"Page {page} not found for document {document_id}",
            )

        # Load image
        image_path = image_files[page_idx]
        pil_image = Image.open(image_path).convert("RGB")
        return np.array(pil_image)

    def _image_to_data_url(self, image: np.ndarray) -> str:
        """Convert numpy image to base64 data URL."""
        pil_image = Image.fromarray(image)

        # Resize for preview if too large
        max_size = PREVIEW_MAX_SIZE
        if max(pil_image.size) > max_size:
            ratio = max_size / max(pil_image.size)
            new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)

        # Convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return f"data:image/png;base64,{base64_data}"
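_image_to_data_url returns a base64 PNG data URL sized for preview; a client-side sketch of the inverse decode, using only the standard library and Pillow (illustrative, not part of the commit):

import base64
import io

from PIL import Image

def data_url_to_image(data_url: str) -> Image.Image:
    """Decode the preview data URL produced by _image_to_data_url back into a PIL image."""
    prefix = "data:image/png;base64,"
    if not data_url.startswith(prefix):
        raise ValueError("expected a PNG data URL")
    return Image.open(io.BytesIO(base64.b64decode(data_url[len(prefix):])))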
343
packages/backend/backend/web/services/autolabel.py
Normal file
343
packages/backend/backend/web/services/autolabel.py
Normal file
@@ -0,0 +1,343 @@
"""
Admin Auto-Labeling Service

Uses FieldMatcher to automatically create annotations from field values.
"""

import logging
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image

from shared.config import DEFAULT_DPI
from backend.data.repositories import DocumentRepository, AnnotationRepository
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken

logger = logging.getLogger(__name__)


class AutoLabelService:
    """Service for automatic document labeling using field matching."""

    def __init__(self, ocr_engine: OCREngine | None = None):
        """
        Initialize auto-label service.

        Args:
            ocr_engine: OCR engine instance (creates one if not provided)
        """
        self._ocr_engine = ocr_engine
        self._field_matcher = FieldMatcher()

    @property
    def ocr_engine(self) -> OCREngine:
        """Lazy initialization of OCR engine."""
        if self._ocr_engine is None:
            self._ocr_engine = OCREngine(lang="en")
        return self._ocr_engine

    def auto_label_document(
        self,
        document_id: str,
        file_path: str,
        field_values: dict[str, str],
        doc_repo: DocumentRepository | None = None,
        ann_repo: AnnotationRepository | None = None,
        replace_existing: bool = False,
        skip_lock_check: bool = False,
    ) -> dict[str, Any]:
        """
        Auto-label a document using field matching.

        Args:
            document_id: Document UUID
            file_path: Path to document file
            field_values: Dict of field_name -> value to match
            doc_repo: Document repository (created if None)
            ann_repo: Annotation repository (created if None)
            replace_existing: Whether to replace existing auto annotations
            skip_lock_check: Skip annotation lock check (for batch processing)

        Returns:
            Dict with status and annotation count
        """
        # Initialize repositories if not provided
        if doc_repo is None:
            doc_repo = DocumentRepository()
        if ann_repo is None:
            ann_repo = AnnotationRepository()

        try:
            # Get document info first
            document = doc_repo.get(document_id)
            if document is None:
                raise ValueError(f"Document not found: {document_id}")

            # Check annotation lock unless explicitly skipped
            if not skip_lock_check:
                from datetime import datetime, timezone
                if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
                    if document.annotation_lock_until > datetime.now(timezone.utc):
                        raise ValueError(
                            f"Document is locked for annotation until {document.annotation_lock_until}. "
                            "Auto-labeling skipped."
                        )

            # Update status to running
            doc_repo.update_status(
                document_id=document_id,
                status="auto_labeling",
                auto_label_status="running",
            )

            # Delete existing auto annotations if requested
            if replace_existing:
                deleted = ann_repo.delete_for_document(
                    document_id=document_id,
                    source="auto",
                )
                logger.info(f"Deleted {deleted} existing auto annotations")

            # Process document
            path = Path(file_path)
            annotations_created = 0

            if path.suffix.lower() == ".pdf":
                # Process PDF (all pages)
                annotations_created = self._process_pdf(
                    document_id, path, field_values, ann_repo
                )
            else:
                # Process single image
                annotations_created = self._process_image(
                    document_id, path, field_values, ann_repo, page_number=1
                )

            # Update document status
            status = "labeled" if annotations_created > 0 else "pending"
            doc_repo.update_status(
                document_id=document_id,
                status=status,
                auto_label_status="completed",
            )

            return {
                "status": "completed",
                "annotations_created": annotations_created,
            }

        except Exception as e:
            logger.error(f"Auto-labeling failed for {document_id}: {e}")
            doc_repo.update_status(
                document_id=document_id,
                status="pending",
                auto_label_status="failed",
                auto_label_error=str(e),
            )
            return {
                "status": "failed",
                "error": str(e),
                "annotations_created": 0,
            }

    def _process_pdf(
        self,
        document_id: str,
        pdf_path: Path,
        field_values: dict[str, str],
        ann_repo: AnnotationRepository,
    ) -> int:
        """Process PDF document and create annotations."""
        from shared.pdf.renderer import render_pdf_to_images
        import io

        total_annotations = 0

        for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
            # Convert to numpy array
            image = Image.open(io.BytesIO(image_bytes))
            image_array = np.array(image)

            # Extract tokens
            tokens = self.ocr_engine.extract_from_image(
                image_array,
                page_no=page_no,
            )

            # Find matches
            annotations = self._find_annotations(
                document_id,
                tokens,
                field_values,
|
||||
page_number=page_no + 1, # 1-indexed
|
||||
image_width=image_array.shape[1],
|
||||
image_height=image_array.shape[0],
|
||||
)
|
||||
|
||||
# Save annotations
|
||||
if annotations:
|
||||
ann_repo.create_batch(annotations)
|
||||
total_annotations += len(annotations)
|
||||
|
||||
return total_annotations
|
||||
|
||||
def _process_image(
|
||||
self,
|
||||
document_id: str,
|
||||
image_path: Path,
|
||||
field_values: dict[str, str],
|
||||
ann_repo: AnnotationRepository,
|
||||
page_number: int = 1,
|
||||
) -> int:
|
||||
"""Process single image and create annotations."""
|
||||
# Load image
|
||||
image = Image.open(image_path)
|
||||
image_array = np.array(image)
|
||||
|
||||
# Extract tokens
|
||||
tokens = self.ocr_engine.extract_from_image(
|
||||
image_array,
|
||||
page_no=0,
|
||||
)
|
||||
|
||||
# Find matches
|
||||
annotations = self._find_annotations(
|
||||
document_id,
|
||||
tokens,
|
||||
field_values,
|
||||
page_number=page_number,
|
||||
image_width=image_array.shape[1],
|
||||
image_height=image_array.shape[0],
|
||||
)
|
||||
|
||||
# Save annotations
|
||||
if annotations:
|
||||
ann_repo.create_batch(annotations)
|
||||
|
||||
return len(annotations)
|
||||
|
||||
def _find_annotations(
|
||||
self,
|
||||
document_id: str,
|
||||
tokens: list[OCRToken],
|
||||
field_values: dict[str, str],
|
||||
page_number: int,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Find annotations for field values using token matching."""
|
||||
from shared.normalize import normalize_field
|
||||
|
||||
annotations = []
|
||||
|
||||
for field_name, value in field_values.items():
|
||||
if not value or not value.strip():
|
||||
continue
|
||||
|
||||
# Map field name to class ID
|
||||
class_id = self._get_class_id(field_name)
|
||||
if class_id is None:
|
||||
logger.warning(f"Unknown field name: {field_name}")
|
||||
continue
|
||||
|
||||
class_name = FIELD_CLASSES[class_id]
|
||||
|
||||
# Normalize value
|
||||
try:
|
||||
normalized_values = normalize_field(field_name, value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to normalize {field_name}={value}: {e}")
|
||||
normalized_values = [value]
|
||||
|
||||
# Find matches
|
||||
matches = self._field_matcher.find_matches(
|
||||
tokens=tokens,
|
||||
field_name=field_name,
|
||||
normalized_values=normalized_values,
|
||||
page_no=page_number - 1, # 0-indexed for matcher
|
||||
)
|
||||
|
||||
# Take best match
|
||||
if matches:
|
||||
best_match = matches[0]
|
||||
bbox = best_match.bbox # (x0, y0, x1, y1)
|
||||
|
||||
# Calculate normalized coordinates (YOLO format)
|
||||
x_center = (bbox[0] + bbox[2]) / 2 / image_width
|
||||
y_center = (bbox[1] + bbox[3]) / 2 / image_height
|
||||
width = (bbox[2] - bbox[0]) / image_width
|
||||
height = (bbox[3] - bbox[1]) / image_height
|
||||
|
||||
# Pixel coordinates
|
||||
bbox_x = int(bbox[0])
|
||||
bbox_y = int(bbox[1])
|
||||
bbox_width = int(bbox[2] - bbox[0])
|
||||
bbox_height = int(bbox[3] - bbox[1])
|
||||
|
||||
annotations.append({
|
||||
"document_id": document_id,
|
||||
"page_number": page_number,
|
||||
"class_id": class_id,
|
||||
"class_name": class_name,
|
||||
"x_center": x_center,
|
||||
"y_center": y_center,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"bbox_x": bbox_x,
|
||||
"bbox_y": bbox_y,
|
||||
"bbox_width": bbox_width,
|
||||
"bbox_height": bbox_height,
|
||||
"text_value": best_match.matched_text,
|
||||
"confidence": best_match.score,
|
||||
"source": "auto",
|
||||
})
|
||||
|
||||
return annotations
|
||||
|
||||
def _get_class_id(self, field_name: str) -> int | None:
|
||||
"""Map field name to class ID."""
|
||||
# Direct match
|
||||
if field_name in FIELD_CLASS_IDS:
|
||||
return FIELD_CLASS_IDS[field_name]
|
||||
|
||||
# Handle alternative names
|
||||
name_mapping = {
|
||||
"InvoiceNumber": "invoice_number",
|
||||
"InvoiceDate": "invoice_date",
|
||||
"InvoiceDueDate": "invoice_due_date",
|
||||
"OCR": "ocr_number",
|
||||
"Bankgiro": "bankgiro",
|
||||
"Plusgiro": "plusgiro",
|
||||
"Amount": "amount",
|
||||
"supplier_organisation_number": "supplier_organisation_number",
|
||||
"PaymentLine": "payment_line",
|
||||
"customer_number": "customer_number",
|
||||
}
|
||||
|
||||
mapped_name = name_mapping.get(field_name)
|
||||
if mapped_name and mapped_name in FIELD_CLASS_IDS:
|
||||
return FIELD_CLASS_IDS[mapped_name]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global service instance
|
||||
_auto_label_service: AutoLabelService | None = None
|
||||
|
||||
|
||||
def get_auto_label_service() -> AutoLabelService:
|
||||
"""Get the auto-label service instance."""
|
||||
global _auto_label_service
|
||||
if _auto_label_service is None:
|
||||
_auto_label_service = AutoLabelService()
|
||||
return _auto_label_service
|
||||
|
||||
|
||||
def reset_auto_label_service() -> None:
|
||||
"""Reset the auto-label service (for testing)."""
|
||||
global _auto_label_service
|
||||
_auto_label_service = None
|
||||
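Before the batch-upload service, one arithmetic detail of _find_annotations is worth pinning down: the pixel bounding box returned by the matcher is converted to YOLO-normalized center/size coordinates. A small standalone sketch of that conversion with made-up numbers:

# Illustrative sketch, not part of the commit: pixel bbox -> YOLO-normalized
# coordinates, as computed in AutoLabelService._find_annotations above.
def to_yolo(bbox, image_width, image_height):
    """Convert an (x0, y0, x1, y1) pixel box to (x_center, y_center, width, height) in [0, 1]."""
    x0, y0, x1, y1 = bbox
    return (
        (x0 + x1) / 2 / image_width,
        (y0 + y1) / 2 / image_height,
        (x1 - x0) / image_width,
        (y1 - y0) / image_height,
    )


# A 200x50 px box at (100, 300) on a 1000x2000 px page image:
assert to_yolo((100, 300, 300, 350), 1000, 2000) == (0.2, 0.1625, 0.2, 0.025)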
548
packages/backend/backend/web/services/batch_upload.py
Normal file
548
packages/backend/backend/web/services/batch_upload.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
Batch Upload Service
|
||||
|
||||
Handles ZIP file uploads with multiple PDFs and optional CSV for auto-labeling.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.data.repositories import BatchUploadRepository
|
||||
from shared.fields import CSV_TO_CLASS_MAPPING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Security limits
|
||||
MAX_COMPRESSED_SIZE = 100 * 1024 * 1024 # 100 MB
|
||||
MAX_UNCOMPRESSED_SIZE = 200 * 1024 * 1024 # 200 MB
|
||||
MAX_INDIVIDUAL_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||
MAX_FILES_IN_ZIP = 1000
|
||||
|
||||
|
||||
class CSVRowData(BaseModel):
|
||||
"""Validated CSV row data with security checks."""
|
||||
|
||||
document_id: str = Field(..., min_length=1, max_length=255, pattern=r'^[a-zA-Z0-9\-_\.]+$')
|
||||
invoice_number: str | None = Field(None, max_length=255)
|
||||
invoice_date: str | None = Field(None, max_length=50)
|
||||
invoice_due_date: str | None = Field(None, max_length=50)
|
||||
amount: str | None = Field(None, max_length=100)
|
||||
ocr: str | None = Field(None, max_length=100)
|
||||
bankgiro: str | None = Field(None, max_length=50)
|
||||
plusgiro: str | None = Field(None, max_length=50)
|
||||
customer_number: str | None = Field(None, max_length=255)
|
||||
supplier_organisation_number: str | None = Field(None, max_length=50)
|
||||
|
||||
@field_validator('*', mode='before')
|
||||
@classmethod
|
||||
def strip_whitespace(cls, v):
|
||||
"""Strip whitespace from all string fields."""
|
||||
if isinstance(v, str):
|
||||
return v.strip()
|
||||
return v
|
||||
|
||||
@field_validator('*', mode='before')
|
||||
@classmethod
|
||||
def reject_suspicious_patterns(cls, v):
|
||||
"""Reject values with suspicious characters."""
|
||||
if isinstance(v, str):
|
||||
# Reject SQL/shell metacharacters and newlines
|
||||
dangerous_chars = [';', '|', '&', '`', '$', '\n', '\r', '\x00']
|
||||
if any(char in v for char in dangerous_chars):
|
||||
raise ValueError(f"Suspicious characters detected in value")
|
||||
return v
|
||||
|
||||
|
||||
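# Illustrative usage sketch, not part of the committed module: how CSVRowData
# behaves on clean vs. suspicious input. pydantic wraps the ValueError raised
# by reject_suspicious_patterns above in a ValidationError.
#
#     from pydantic import ValidationError
#
#     CSVRowData(document_id="inv-001", amount="1234.50")       # accepted
#     try:
#         CSVRowData(document_id="inv-001", amount="1; rm -rf")  # ';' is rejected
#     except ValidationError as exc:
#         print(exc)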
class BatchUploadService:
|
||||
"""Service for handling batch uploads of documents via ZIP files."""
|
||||
|
||||
def __init__(self, batch_repo: BatchUploadRepository | None = None):
|
||||
"""Initialize the batch upload service.
|
||||
|
||||
Args:
|
||||
batch_repo: Batch upload repository (created if None)
|
||||
"""
|
||||
self.batch_repo = batch_repo or BatchUploadRepository()
|
||||
|
||||
def _safe_extract_filename(self, zip_path: str) -> str:
|
||||
"""Safely extract filename from ZIP path, preventing path traversal.
|
||||
|
||||
Args:
|
||||
zip_path: Path from ZIP file entry
|
||||
|
||||
Returns:
|
||||
Safe filename
|
||||
|
||||
Raises:
|
||||
ValueError: If path contains traversal attempts or is invalid
|
||||
"""
|
||||
# Reject absolute paths
|
||||
if zip_path.startswith('/') or zip_path.startswith('\\'):
|
||||
raise ValueError(f"Absolute path rejected: {zip_path}")
|
||||
|
||||
# Reject path traversal attempts
|
||||
if '..' in zip_path:
|
||||
raise ValueError(f"Path traversal rejected: {zip_path}")
|
||||
|
||||
# Reject Windows drive letters
|
||||
if len(zip_path) >= 2 and zip_path[1] == ':':
|
||||
raise ValueError(f"Windows path rejected: {zip_path}")
|
||||
|
||||
# Get only the basename
|
||||
safe_name = Path(zip_path).name
|
||||
if not safe_name or safe_name in ['.', '..']:
|
||||
raise ValueError(f"Invalid filename: {zip_path}")
|
||||
|
||||
# Validate filename doesn't contain suspicious characters
|
||||
if any(char in safe_name for char in ['\\', '/', '\x00', '\n', '\r']):
|
||||
raise ValueError(f"Invalid characters in filename: {safe_name}")
|
||||
|
||||
return safe_name
|
||||
|
||||
def _validate_zip_safety(self, zip_file: zipfile.ZipFile) -> None:
|
||||
"""Validate ZIP file against Zip bomb and other attacks.
|
||||
|
||||
Args:
|
||||
zip_file: Opened ZIP file
|
||||
|
||||
Raises:
|
||||
ValueError: If ZIP file is unsafe
|
||||
"""
|
||||
total_uncompressed = 0
|
||||
file_count = 0
|
||||
|
||||
for zip_info in zip_file.infolist():
|
||||
file_count += 1
|
||||
|
||||
# Check file count limit
|
||||
if file_count > MAX_FILES_IN_ZIP:
|
||||
raise ValueError(
|
||||
f"ZIP contains too many files (max {MAX_FILES_IN_ZIP})"
|
||||
)
|
||||
|
||||
# Check individual file size
|
||||
if zip_info.file_size > MAX_INDIVIDUAL_FILE_SIZE:
|
||||
max_mb = MAX_INDIVIDUAL_FILE_SIZE / (1024 * 1024)
|
||||
raise ValueError(
|
||||
f"File '{zip_info.filename}' exceeds {max_mb:.0f}MB limit"
|
||||
)
|
||||
|
||||
# Accumulate uncompressed size
|
||||
total_uncompressed += zip_info.file_size
|
||||
|
||||
# Check total uncompressed size (Zip bomb protection)
|
||||
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
|
||||
max_mb = MAX_UNCOMPRESSED_SIZE / (1024 * 1024)
|
||||
raise ValueError(
|
||||
f"Total uncompressed size exceeds {max_mb:.0f}MB limit"
|
||||
)
|
||||
|
||||
# Validate filename safety
|
||||
try:
|
||||
self._safe_extract_filename(zip_info.filename)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Rejecting malicious ZIP entry: {e}")
|
||||
raise ValueError(f"Invalid file in ZIP: {zip_info.filename}")
|
||||
|
||||
def process_zip_upload(
|
||||
self,
|
||||
admin_token: str,
|
||||
zip_filename: str,
|
||||
zip_content: bytes,
|
||||
upload_source: str = "ui",
|
||||
) -> dict[str, Any]:
|
||||
"""Process a ZIP file containing PDFs and optional CSV.
|
||||
|
||||
Args:
|
||||
admin_token: Admin authentication token
|
||||
zip_filename: Name of the ZIP file
|
||||
zip_content: ZIP file content as bytes
|
||||
upload_source: Upload source (ui or api)
|
||||
|
||||
Returns:
|
||||
Dictionary with batch upload results
|
||||
"""
|
||||
batch = self.batch_repo.create(
|
||||
admin_token=admin_token,
|
||||
filename=zip_filename,
|
||||
file_size=len(zip_content),
|
||||
upload_source=upload_source,
|
||||
)
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
|
||||
# Validate ZIP safety first
|
||||
self._validate_zip_safety(zip_file)
|
||||
|
||||
result = self._process_zip_contents(
|
||||
batch_id=batch.batch_id,
|
||||
admin_token=admin_token,
|
||||
zip_file=zip_file,
|
||||
)
|
||||
|
||||
# Update batch upload status
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status=result["status"],
|
||||
total_files=result["total_files"],
|
||||
processed_files=result["processed_files"],
|
||||
successful_files=result["successful_files"],
|
||||
failed_files=result["failed_files"],
|
||||
csv_filename=result.get("csv_filename"),
|
||||
csv_row_count=result.get("csv_row_count"),
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
**result,
|
||||
}
|
||||
|
||||
except zipfile.BadZipFile as e:
|
||||
logger.error(f"Invalid ZIP file {zip_filename}: {e}")
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="Invalid ZIP file format",
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"status": "failed",
|
||||
"error": "Invalid ZIP file format",
|
||||
}
|
||||
except ValueError as e:
|
||||
# Security validation errors
|
||||
logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="ZIP file validation failed",
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"status": "failed",
|
||||
"error": "ZIP file validation failed",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="Processing error",
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"status": "failed",
|
||||
"error": "Failed to process batch upload",
|
||||
}
|
||||
|
||||
def _process_zip_contents(
|
||||
self,
|
||||
batch_id: UUID,
|
||||
admin_token: str,
|
||||
zip_file: zipfile.ZipFile,
|
||||
) -> dict[str, Any]:
|
||||
"""Process contents of ZIP file.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
admin_token: Admin authentication token
|
||||
zip_file: Opened ZIP file
|
||||
|
||||
Returns:
|
||||
Processing results dictionary
|
||||
"""
|
||||
# Extract file lists
|
||||
pdf_files = []
|
||||
csv_file = None
|
||||
csv_data = {}
|
||||
|
||||
for file_info in zip_file.filelist:
|
||||
if file_info.is_dir():
|
||||
continue
|
||||
|
||||
try:
|
||||
# Use safe filename extraction
|
||||
filename = self._safe_extract_filename(file_info.filename)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Skipping invalid file: {e}")
|
||||
continue
|
||||
|
||||
if filename.lower().endswith('.pdf'):
|
||||
pdf_files.append(file_info)
|
||||
elif filename.lower().endswith('.csv'):
|
||||
if csv_file is None:
|
||||
csv_file = file_info
|
||||
# Parse CSV
|
||||
csv_data = self._parse_csv_file(zip_file, file_info)
|
||||
else:
|
||||
logger.warning(f"Multiple CSV files found, using first: {csv_file.filename}")
|
||||
|
||||
if not pdf_files:
|
||||
return {
|
||||
"status": "failed",
|
||||
"total_files": 0,
|
||||
"processed_files": 0,
|
||||
"successful_files": 0,
|
||||
"failed_files": 0,
|
||||
"error": "No PDF files found in ZIP",
|
||||
}
|
||||
|
||||
# Process each PDF file
|
||||
total_files = len(pdf_files)
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
for pdf_info in pdf_files:
|
||||
file_record = None
|
||||
|
||||
try:
|
||||
# Use safe filename extraction
|
||||
filename = self._safe_extract_filename(pdf_info.filename)
|
||||
|
||||
# Create batch upload file record
|
||||
file_record = self.batch_repo.create_file(
|
||||
batch_id=batch_id,
|
||||
filename=filename,
|
||||
status="processing",
|
||||
)
|
||||
|
||||
# Get CSV data for this file if available
|
||||
document_id_base = Path(filename).stem
|
||||
csv_row_data = csv_data.get(document_id_base)
|
||||
|
||||
# Extract PDF content
|
||||
pdf_content = zip_file.read(pdf_info.filename)
|
||||
|
||||
# TODO: Save PDF file and create document
|
||||
# For now, just mark as completed
|
||||
|
||||
self.batch_repo.update_file(
|
||||
file_id=file_record.file_id,
|
||||
status="completed",
|
||||
csv_row_data=csv_row_data,
|
||||
processed_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
successful_files += 1
|
||||
|
||||
except ValueError as e:
|
||||
# Path validation error
|
||||
logger.warning(f"Skipping invalid file: {e}")
|
||||
if file_record:
|
||||
self.batch_repo.update_file(
|
||||
file_id=file_record.file_id,
|
||||
status="failed",
|
||||
error_message="Invalid filename",
|
||||
processed_at=datetime.utcnow(),
|
||||
)
|
||||
failed_files += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing PDF: {e}", exc_info=True)
|
||||
if file_record:
|
||||
self.batch_repo.update_file(
|
||||
file_id=file_record.file_id,
|
||||
status="failed",
|
||||
error_message="Processing error",
|
||||
processed_at=datetime.utcnow(),
|
||||
)
|
||||
failed_files += 1
|
||||
|
||||
# Determine overall status
|
||||
if failed_files == 0:
|
||||
status = "completed"
|
||||
elif successful_files == 0:
|
||||
status = "failed"
|
||||
else:
|
||||
status = "partial"
|
||||
|
||||
result = {
|
||||
"status": status,
|
||||
"total_files": total_files,
|
||||
"processed_files": total_files,
|
||||
"successful_files": successful_files,
|
||||
"failed_files": failed_files,
|
||||
}
|
||||
|
||||
if csv_file:
|
||||
result["csv_filename"] = Path(csv_file.filename).name
|
||||
result["csv_row_count"] = len(csv_data)
|
||||
|
||||
return result
|
||||
|
||||
def _parse_csv_file(
|
||||
self,
|
||||
zip_file: zipfile.ZipFile,
|
||||
csv_file_info: zipfile.ZipInfo,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Parse CSV file and extract field values with validation.
|
||||
|
||||
Args:
|
||||
zip_file: Opened ZIP file
|
||||
csv_file_info: CSV file info
|
||||
|
||||
Returns:
|
||||
Dictionary mapping DocumentId to validated field values
|
||||
"""
|
||||
# Try multiple encodings
|
||||
csv_bytes = zip_file.read(csv_file_info.filename)
|
||||
encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'cp1252']
|
||||
csv_content = None
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
csv_content = csv_bytes.decode(encoding)
|
||||
logger.info(f"CSV decoded with {encoding}")
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
if csv_content is None:
|
||||
logger.error("Failed to decode CSV with any encoding")
|
||||
raise ValueError("Unable to decode CSV file")
|
||||
|
||||
csv_reader = csv.DictReader(io.StringIO(csv_content))
|
||||
result = {}
|
||||
|
||||
# Case-insensitive column mapping
|
||||
field_name_map = {
|
||||
'DocumentId': ['DocumentId', 'documentid', 'document_id'],
|
||||
'InvoiceNumber': ['InvoiceNumber', 'invoicenumber', 'invoice_number'],
|
||||
'InvoiceDate': ['InvoiceDate', 'invoicedate', 'invoice_date'],
|
||||
'InvoiceDueDate': ['InvoiceDueDate', 'invoiceduedate', 'invoice_due_date'],
|
||||
'Amount': ['Amount', 'amount'],
|
||||
'OCR': ['OCR', 'ocr'],
|
||||
'Bankgiro': ['Bankgiro', 'bankgiro'],
|
||||
'Plusgiro': ['Plusgiro', 'plusgiro'],
|
||||
'customer_number': ['customer_number', 'customernumber', 'CustomerNumber'],
|
||||
'supplier_organisation_number': ['supplier_organisation_number', 'supplierorganisationnumber'],
|
||||
}
|
||||
|
||||
for row_num, row in enumerate(csv_reader, start=2):
|
||||
try:
|
||||
# Create case-insensitive lookup
|
||||
row_lower = {k.lower(): v for k, v in row.items()}
|
||||
|
||||
# Find DocumentId with case-insensitive matching
|
||||
document_id = None
|
||||
for variant in field_name_map['DocumentId']:
|
||||
if variant.lower() in row_lower:
|
||||
document_id = row_lower[variant.lower()]
|
||||
break
|
||||
|
||||
if not document_id:
|
||||
logger.warning(f"Row {row_num}: No DocumentId found")
|
||||
continue
|
||||
|
||||
# Validate using Pydantic model
|
||||
csv_row_dict = {'document_id': document_id}
|
||||
|
||||
# Map CSV field names to model attribute names
|
||||
csv_to_model_attr = {
|
||||
'InvoiceNumber': 'invoice_number',
|
||||
'InvoiceDate': 'invoice_date',
|
||||
'InvoiceDueDate': 'invoice_due_date',
|
||||
'Amount': 'amount',
|
||||
'OCR': 'ocr',
|
||||
'Bankgiro': 'bankgiro',
|
||||
'Plusgiro': 'plusgiro',
|
||||
'customer_number': 'customer_number',
|
||||
'supplier_organisation_number': 'supplier_organisation_number',
|
||||
}
|
||||
|
||||
for csv_field in field_name_map.keys():
|
||||
if csv_field == 'DocumentId':
|
||||
continue
|
||||
|
||||
model_attr = csv_to_model_attr.get(csv_field)
|
||||
if not model_attr:
|
||||
continue
|
||||
|
||||
for variant in field_name_map[csv_field]:
|
||||
if variant.lower() in row_lower and row_lower[variant.lower()]:
|
||||
csv_row_dict[model_attr] = row_lower[variant.lower()]
|
||||
break
|
||||
|
||||
# Validate
|
||||
validated_row = CSVRowData(**csv_row_dict)
|
||||
|
||||
# Extract only the fields we care about (map back to CSV field names)
|
||||
field_values = {}
|
||||
model_attr_to_csv = {
|
||||
'invoice_number': 'InvoiceNumber',
|
||||
'invoice_date': 'InvoiceDate',
|
||||
'invoice_due_date': 'InvoiceDueDate',
|
||||
'amount': 'Amount',
|
||||
'ocr': 'OCR',
|
||||
'bankgiro': 'Bankgiro',
|
||||
'plusgiro': 'Plusgiro',
|
||||
'customer_number': 'customer_number',
|
||||
'supplier_organisation_number': 'supplier_organisation_number',
|
||||
}
|
||||
|
||||
for model_attr, csv_field in model_attr_to_csv.items():
|
||||
value = getattr(validated_row, model_attr, None)
|
||||
if value and csv_field in CSV_TO_CLASS_MAPPING:
|
||||
field_values[csv_field] = value
|
||||
|
||||
if field_values:
|
||||
result[document_id] = field_values
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Row {row_num}: Validation error - {e}")
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> dict[str, Any]:
|
||||
"""Get batch upload status.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
|
||||
Returns:
|
||||
Batch status dictionary
|
||||
"""
|
||||
batch = self.batch_repo.get(UUID(batch_id))
|
||||
if not batch:
|
||||
return {
|
||||
"error": "Batch upload not found",
|
||||
}
|
||||
|
||||
files = self.batch_repo.get_files(batch.batch_id)
|
||||
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"filename": batch.filename,
|
||||
"status": batch.status,
|
||||
"total_files": batch.total_files,
|
||||
"processed_files": batch.processed_files,
|
||||
"successful_files": batch.successful_files,
|
||||
"failed_files": batch.failed_files,
|
||||
"csv_filename": batch.csv_filename,
|
||||
"csv_row_count": batch.csv_row_count,
|
||||
"error_message": batch.error_message,
|
||||
"created_at": batch.created_at.isoformat() if batch.created_at else None,
|
||||
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
|
||||
"files": [
|
||||
{
|
||||
"filename": f.filename,
|
||||
"status": f.status,
|
||||
"error_message": f.error_message,
|
||||
"annotation_count": f.annotation_count,
|
||||
}
|
||||
for f in files
|
||||
],
|
||||
}
|
||||
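The two security helpers at the top of this service are easiest to verify against concrete ZIP entries. Here is a standalone sketch (re-implementing the same checks rather than importing the service) that builds an in-memory ZIP with one benign and one malicious entry:

# Illustrative sketch, not part of the commit: the kind of entries
# _safe_extract_filename is meant to accept and reject.
import io
import zipfile
from pathlib import Path


def safe_name(entry: str) -> str:
    if entry.startswith(("/", "\\")) or ".." in entry or (len(entry) >= 2 and entry[1] == ":"):
        raise ValueError(f"rejected: {entry}")
    name = Path(entry).name
    if not name or any(c in name for c in ("\\", "/", "\x00", "\n", "\r")):
        raise ValueError(f"rejected: {entry}")
    return name


buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("invoices/inv-001.pdf", b"%PDF-1.4")  # accepted as "inv-001.pdf"
    zf.writestr("../../etc/passwd", b"x")             # rejected: path traversal

with zipfile.ZipFile(buf) as zf:
    for info in zf.infolist():
        try:
            print("ok:", safe_name(info.filename))
        except ValueError as err:
            print(err)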
276
packages/backend/backend/web/services/dashboard_service.py
Normal file
276
packages/backend/backend/web/services/dashboard_service.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Dashboard Service
|
||||
|
||||
Business logic for dashboard statistics and activity aggregation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, exists, and_, or_
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import (
|
||||
AdminDocument,
|
||||
AdminAnnotation,
|
||||
AnnotationHistory,
|
||||
TrainingTask,
|
||||
ModelVersion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Field class IDs for completeness calculation
|
||||
# Identifiers: invoice_number (0) or ocr_number (3)
|
||||
IDENTIFIER_CLASS_IDS = {0, 3}
|
||||
# Payment accounts: bankgiro (4) or plusgiro (5)
|
||||
PAYMENT_CLASS_IDS = {4, 5}
|
||||
|
||||
|
||||
def is_annotation_complete(annotations: list[dict[str, Any]]) -> bool:
|
||||
"""Check if a document's annotations are complete.
|
||||
|
||||
A document is complete if it has:
|
||||
- At least one identifier field (invoice_number OR ocr_number)
|
||||
- At least one payment field (bankgiro OR plusgiro)
|
||||
|
||||
Args:
|
||||
annotations: List of annotation dicts with class_id
|
||||
|
||||
Returns:
|
||||
True if document has required fields
|
||||
"""
|
||||
class_ids = {ann.get("class_id") for ann in annotations}
|
||||
|
||||
has_identifier = bool(class_ids & IDENTIFIER_CLASS_IDS)
|
||||
has_payment = bool(class_ids & PAYMENT_CLASS_IDS)
|
||||
|
||||
return has_identifier and has_payment
|
||||
|
||||
|
||||
class DashboardStatsService:
|
||||
"""Service for computing dashboard statistics."""
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get dashboard statistics.
|
||||
|
||||
Returns:
|
||||
Dict with total_documents, annotation_complete, annotation_incomplete,
|
||||
pending, and completeness_rate
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
# Total documents
|
||||
total = session.exec(
|
||||
select(func.count()).select_from(AdminDocument)
|
||||
).one()
|
||||
|
||||
# Pending documents (status in ['pending', 'auto_labeling'])
|
||||
pending = session.exec(
|
||||
select(func.count())
|
||||
.select_from(AdminDocument)
|
||||
.where(AdminDocument.status.in_(["pending", "auto_labeling"]))
|
||||
).one()
|
||||
|
||||
# Complete annotations: labeled + has identifier + has payment
|
||||
complete = self._count_complete(session)
|
||||
|
||||
# Incomplete: labeled but not complete
|
||||
labeled_count = session.exec(
|
||||
select(func.count())
|
||||
.select_from(AdminDocument)
|
||||
.where(AdminDocument.status == "labeled")
|
||||
).one()
|
||||
incomplete = labeled_count - complete
|
||||
|
||||
# Calculate completeness rate
|
||||
total_assessed = complete + incomplete
|
||||
completeness_rate = (
|
||||
round(complete / total_assessed * 100, 2)
|
||||
if total_assessed > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_documents": total,
|
||||
"annotation_complete": complete,
|
||||
"annotation_incomplete": incomplete,
|
||||
"pending": pending,
|
||||
"completeness_rate": completeness_rate,
|
||||
}
|
||||
|
||||
def _count_complete(self, session) -> int:
|
||||
"""Count documents with complete annotations.
|
||||
|
||||
A document is complete if it:
|
||||
1. Has status = 'labeled'
|
||||
2. Has at least one identifier annotation (class_id 0 or 3)
|
||||
3. Has at least one payment annotation (class_id 4 or 5)
|
||||
"""
|
||||
# Subquery for documents with identifier
|
||||
has_identifier = exists(
|
||||
select(1)
|
||||
.select_from(AdminAnnotation)
|
||||
.where(
|
||||
and_(
|
||||
AdminAnnotation.document_id == AdminDocument.document_id,
|
||||
AdminAnnotation.class_id.in_(IDENTIFIER_CLASS_IDS),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Subquery for documents with payment
|
||||
has_payment = exists(
|
||||
select(1)
|
||||
.select_from(AdminAnnotation)
|
||||
.where(
|
||||
and_(
|
||||
AdminAnnotation.document_id == AdminDocument.document_id,
|
||||
AdminAnnotation.class_id.in_(PAYMENT_CLASS_IDS),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
count = session.exec(
|
||||
select(func.count())
|
||||
.select_from(AdminDocument)
|
||||
.where(
|
||||
and_(
|
||||
AdminDocument.status == "labeled",
|
||||
has_identifier,
|
||||
has_payment,
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class DashboardActivityService:
|
||||
"""Service for aggregating recent activities."""
|
||||
|
||||
def get_recent_activities(self, limit: int = 10) -> list[dict[str, Any]]:
|
||||
"""Get recent system activities.
|
||||
|
||||
Aggregates from:
|
||||
- Document uploads
|
||||
- Annotation modifications
|
||||
- Training completions/failures
|
||||
- Model activations
|
||||
|
||||
Args:
|
||||
limit: Maximum number of activities to return
|
||||
|
||||
Returns:
|
||||
List of activity dicts sorted by timestamp DESC
|
||||
"""
|
||||
activities = []
|
||||
|
||||
with get_session_context() as session:
|
||||
# Document uploads (recent 10)
|
||||
uploads = session.exec(
|
||||
select(AdminDocument)
|
||||
.order_by(AdminDocument.created_at.desc())
|
||||
.limit(limit)
|
||||
).all()
|
||||
|
||||
for doc in uploads:
|
||||
activities.append({
|
||||
"type": "document_uploaded",
|
||||
"description": f"Uploaded {doc.filename}",
|
||||
"timestamp": doc.created_at,
|
||||
"metadata": {
|
||||
"document_id": str(doc.document_id),
|
||||
"filename": doc.filename,
|
||||
},
|
||||
})
|
||||
|
||||
# Annotation modifications (from history)
|
||||
modifications = session.exec(
|
||||
select(AnnotationHistory)
|
||||
.where(AnnotationHistory.action == "override")
|
||||
.order_by(AnnotationHistory.created_at.desc())
|
||||
.limit(limit)
|
||||
).all()
|
||||
|
||||
for mod in modifications:
|
||||
# Get document filename
|
||||
doc = session.get(AdminDocument, mod.document_id)
|
||||
filename = doc.filename if doc else "Unknown"
|
||||
field_name = ""
|
||||
if mod.new_value and isinstance(mod.new_value, dict):
|
||||
field_name = mod.new_value.get("class_name", "")
|
||||
|
||||
activities.append({
|
||||
"type": "annotation_modified",
|
||||
"description": f"Modified {filename} {field_name}".strip(),
|
||||
"timestamp": mod.created_at,
|
||||
"metadata": {
|
||||
"annotation_id": str(mod.annotation_id),
|
||||
"document_id": str(mod.document_id),
|
||||
"field_name": field_name,
|
||||
},
|
||||
})
|
||||
|
||||
# Training completions and failures
|
||||
training_tasks = session.exec(
|
||||
select(TrainingTask)
|
||||
.where(TrainingTask.status.in_(["completed", "failed"]))
|
||||
.order_by(TrainingTask.updated_at.desc())
|
||||
.limit(limit)
|
||||
).all()
|
||||
|
||||
for task in training_tasks:
|
||||
if task.updated_at is None:
|
||||
continue
|
||||
if task.status == "completed":
|
||||
# Use metrics_mAP field directly
|
||||
mAP = task.metrics_mAP or 0.0
|
||||
activities.append({
|
||||
"type": "training_completed",
|
||||
"description": f"Training complete: {task.name}, mAP {mAP:.1%}",
|
||||
"timestamp": task.updated_at,
|
||||
"metadata": {
|
||||
"task_id": str(task.task_id),
|
||||
"task_name": task.name,
|
||||
"mAP": mAP,
|
||||
},
|
||||
})
|
||||
else:
|
||||
activities.append({
|
||||
"type": "training_failed",
|
||||
"description": f"Training failed: {task.name}",
|
||||
"timestamp": task.updated_at,
|
||||
"metadata": {
|
||||
"task_id": str(task.task_id),
|
||||
"task_name": task.name,
|
||||
"error": task.error_message or "",
|
||||
},
|
||||
})
|
||||
|
||||
# Model activations
|
||||
model_versions = session.exec(
|
||||
select(ModelVersion)
|
||||
.where(ModelVersion.activated_at.is_not(None))
|
||||
.order_by(ModelVersion.activated_at.desc())
|
||||
.limit(limit)
|
||||
).all()
|
||||
|
||||
for model in model_versions:
|
||||
if model.activated_at is None:
|
||||
continue
|
||||
activities.append({
|
||||
"type": "model_activated",
|
||||
"description": f"Activated model {model.version}",
|
||||
"timestamp": model.activated_at,
|
||||
"metadata": {
|
||||
"version_id": str(model.version_id),
|
||||
"version": model.version,
|
||||
},
|
||||
})
|
||||
|
||||
# Sort all activities by timestamp DESC and return top N
|
||||
activities.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
return activities[:limit]
|
||||
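The completeness rule at the top of this module is the one piece of business logic the dashboard numbers hinge on, so here is a tiny standalone restatement with a few checks (same class-ID constants as above):

# Illustrative sketch, not part of the commit: the rule used by
# is_annotation_complete and _count_complete.
IDENTIFIER_CLASS_IDS = {0, 3}  # invoice_number, ocr_number
PAYMENT_CLASS_IDS = {4, 5}     # bankgiro, plusgiro


def is_complete(annotations: list[dict]) -> bool:
    class_ids = {ann.get("class_id") for ann in annotations}
    return bool(class_ids & IDENTIFIER_CLASS_IDS) and bool(class_ids & PAYMENT_CLASS_IDS)


assert is_complete([{"class_id": 0}, {"class_id": 4}])       # invoice_number + bankgiro
assert not is_complete([{"class_id": 0}, {"class_id": 1}])   # identifier, but no payment account
assert not is_complete([{"class_id": 4}, {"class_id": 5}])   # payment accounts, but no identifier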
265
packages/backend/backend/web/services/dataset_builder.py
Normal file
265
packages/backend/backend/web/services/dataset_builder.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Dataset Builder Service
|
||||
|
||||
Creates training datasets by copying images from admin storage,
|
||||
generating YOLO label files, and splitting into train/val/test sets.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from shared.fields import FIELD_CLASSES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetBuilder:
|
||||
"""Builds YOLO training datasets from admin documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
datasets_repo,
|
||||
documents_repo,
|
||||
annotations_repo,
|
||||
base_dir: Path,
|
||||
):
|
||||
self._datasets_repo = datasets_repo
|
||||
self._documents_repo = documents_repo
|
||||
self._annotations_repo = annotations_repo
|
||||
self._base_dir = Path(base_dir)
|
||||
|
||||
def build_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_ids: list[str],
|
||||
train_ratio: float,
|
||||
val_ratio: float,
|
||||
seed: int,
|
||||
admin_images_dir: Path,
|
||||
) -> dict:
|
||||
"""Build a complete YOLO dataset from document IDs.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset record.
|
||||
document_ids: List of document UUIDs to include.
|
||||
train_ratio: Fraction for training set.
|
||||
val_ratio: Fraction for validation set.
|
||||
seed: Random seed for reproducible splits.
|
||||
admin_images_dir: Root directory of admin images.
|
||||
|
||||
Returns:
|
||||
Summary dict with total_documents, total_images, total_annotations.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid documents found.
|
||||
"""
|
||||
try:
|
||||
return self._do_build(
|
||||
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
|
||||
)
|
||||
except Exception as e:
|
||||
self._datasets_repo.update_status(
|
||||
dataset_id=dataset_id,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
def _do_build(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_ids: list[str],
|
||||
train_ratio: float,
|
||||
val_ratio: float,
|
||||
seed: int,
|
||||
admin_images_dir: Path,
|
||||
) -> dict:
|
||||
# 1. Fetch documents
|
||||
documents = self._documents_repo.get_by_ids(document_ids)
|
||||
if not documents:
|
||||
raise ValueError("No valid documents found for the given IDs")
|
||||
|
||||
# 2. Create directory structure
|
||||
dataset_dir = self._base_dir / dataset_id
|
||||
for split in ["train", "val", "test"]:
|
||||
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
|
||||
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. Group documents by group_key and assign splits
|
||||
doc_list = list(documents)
|
||||
doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)
|
||||
|
||||
# 4. Process each document
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
dataset_docs = []
|
||||
|
||||
for doc in doc_list:
|
||||
doc_id = str(doc.document_id)
|
||||
split = doc_splits[doc_id]
|
||||
annotations = self._annotations_repo.get_for_document(str(doc.document_id))
|
||||
|
||||
# Group annotations by page
|
||||
page_annotations: dict[int, list] = {}
|
||||
for ann in annotations:
|
||||
page_annotations.setdefault(ann.page_number, []).append(ann)
|
||||
|
||||
doc_image_count = 0
|
||||
doc_ann_count = 0
|
||||
|
||||
# Copy images and write labels for each page
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
|
||||
if not src_image.exists():
|
||||
logger.warning("Image not found: %s", src_image)
|
||||
continue
|
||||
|
||||
dst_name = f"{doc_id}_page{page_num}"
|
||||
dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
|
||||
shutil.copy2(src_image, dst_image)
|
||||
doc_image_count += 1
|
||||
|
||||
# Write YOLO label file
|
||||
page_anns = page_annotations.get(page_num, [])
|
||||
label_lines = []
|
||||
for ann in page_anns:
|
||||
label_lines.append(
|
||||
f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
|
||||
f"{ann.width:.6f} {ann.height:.6f}"
|
||||
)
|
||||
doc_ann_count += 1
|
||||
|
||||
label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
|
||||
label_path.write_text("\n".join(label_lines))
|
||||
|
||||
total_images += doc_image_count
|
||||
total_annotations += doc_ann_count
|
||||
|
||||
dataset_docs.append({
|
||||
"document_id": doc_id,
|
||||
"split": split,
|
||||
"page_count": doc_image_count,
|
||||
"annotation_count": doc_ann_count,
|
||||
})
|
||||
|
||||
# 5. Record document-split assignments in DB
|
||||
self._datasets_repo.add_documents(
|
||||
dataset_id=dataset_id,
|
||||
documents=dataset_docs,
|
||||
)
|
||||
|
||||
# 6. Generate data.yaml
|
||||
self._generate_data_yaml(dataset_dir)
|
||||
|
||||
# 7. Update dataset status
|
||||
self._datasets_repo.update_status(
|
||||
dataset_id=dataset_id,
|
||||
status="ready",
|
||||
total_documents=len(doc_list),
|
||||
total_images=total_images,
|
||||
total_annotations=total_annotations,
|
||||
dataset_path=str(dataset_dir),
|
||||
)
|
||||
|
||||
return {
|
||||
"total_documents": len(doc_list),
|
||||
"total_images": total_images,
|
||||
"total_annotations": total_annotations,
|
||||
}
|
||||
|
||||
def _assign_splits_by_group(
|
||||
self,
|
||||
documents: list,
|
||||
train_ratio: float,
|
||||
val_ratio: float,
|
||||
seed: int,
|
||||
) -> dict[str, str]:
|
||||
"""Assign splits based on group_key.
|
||||
|
||||
Logic:
|
||||
- Documents with same group_key stay together in the same split
|
||||
- Groups with only 1 document go directly to train
|
||||
- Groups with 2+ documents participate in shuffle & split
|
||||
|
||||
Args:
|
||||
documents: List of AdminDocument objects
|
||||
train_ratio: Fraction for training set
|
||||
val_ratio: Fraction for validation set
|
||||
seed: Random seed for reproducibility
|
||||
|
||||
Returns:
|
||||
Dict mapping document_id (str) -> split ("train"/"val"/"test")
|
||||
"""
|
||||
# Group documents by group_key
|
||||
# None/empty group_key treated as unique (each doc is its own group)
|
||||
groups: dict[str | None, list] = {}
|
||||
for doc in documents:
|
||||
key = doc.group_key if doc.group_key else None
|
||||
if key is None:
|
||||
# Treat each ungrouped doc as its own unique group
|
||||
# Use document_id as pseudo-key
|
||||
key = f"__ungrouped_{doc.document_id}"
|
||||
groups.setdefault(key, []).append(doc)
|
||||
|
||||
# Separate single-doc groups from multi-doc groups
|
||||
single_doc_groups: list[tuple[str | None, list]] = []
|
||||
multi_doc_groups: list[tuple[str | None, list]] = []
|
||||
|
||||
for key, docs in groups.items():
|
||||
if len(docs) == 1:
|
||||
single_doc_groups.append((key, docs))
|
||||
else:
|
||||
multi_doc_groups.append((key, docs))
|
||||
|
||||
# Initialize result mapping
|
||||
doc_splits: dict[str, str] = {}
|
||||
|
||||
# Combine all groups for splitting
|
||||
all_groups = single_doc_groups + multi_doc_groups
|
||||
|
||||
# Shuffle all groups and assign splits
|
||||
if all_groups:
|
||||
rng = random.Random(seed)
|
||||
rng.shuffle(all_groups)
|
||||
|
||||
n_groups = len(all_groups)
|
||||
n_train = max(1, round(n_groups * train_ratio))
|
||||
# Ensure at least 1 in val if we have more than 1 group
|
||||
n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
|
||||
|
||||
for i, (_key, docs) in enumerate(all_groups):
|
||||
if i < n_train:
|
||||
split = "train"
|
||||
elif i < n_train + n_val:
|
||||
split = "val"
|
||||
else:
|
||||
split = "test"
|
||||
|
||||
for doc in docs:
|
||||
doc_splits[str(doc.document_id)] = split
|
||||
|
||||
logger.info(
|
||||
"Split assignment: %d total groups shuffled (train=%d, val=%d)",
|
||||
len(all_groups),
|
||||
sum(1 for s in doc_splits.values() if s == "train"),
|
||||
sum(1 for s in doc_splits.values() if s == "val"),
|
||||
)
|
||||
|
||||
return doc_splits
|
||||
|
||||
def _generate_data_yaml(self, dataset_dir: Path) -> None:
|
||||
"""Generate YOLO data.yaml configuration file."""
|
||||
data = {
|
||||
"path": str(dataset_dir.absolute()),
|
||||
"train": "images/train",
|
||||
"val": "images/val",
|
||||
"test": "images/test",
|
||||
"nc": len(FIELD_CLASSES),
|
||||
"names": FIELD_CLASSES,
|
||||
}
|
||||
yaml_path = dataset_dir / "data.yaml"
|
||||
yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
|
||||
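The split logic in _assign_splits_by_group is the least obvious part of the builder, so here is a standalone sketch of the same assignment on plain (doc_id, group_key) tuples instead of AdminDocument rows:

# Illustrative sketch, not part of the commit: group-aware train/val/test
# assignment mirroring DatasetBuilder._assign_splits_by_group above.
import random


def assign_splits(docs, train_ratio=0.8, val_ratio=0.1, seed=42):
    groups = {}
    for doc_id, group_key in docs:
        key = group_key or f"__ungrouped_{doc_id}"
        groups.setdefault(key, []).append(doc_id)

    keys = list(groups)
    rng = random.Random(seed)
    rng.shuffle(keys)

    n_groups = len(keys)
    n_train = max(1, round(n_groups * train_ratio))
    n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))

    splits = {}
    for i, key in enumerate(keys):
        split = "train" if i < n_train else "val" if i < n_train + n_val else "test"
        for doc_id in groups[key]:
            splits[doc_id] = split
    return splits


# Documents sharing a group_key always land in the same split.
splits = assign_splits([("d1", "acme"), ("d2", "acme"), ("d3", None), ("d4", "globex"), ("d5", None)])
assert splits["d1"] == splits["d2"]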
550
packages/backend/backend/web/services/db_autolabel.py
Normal file
550
packages/backend/backend/web/services/db_autolabel.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""
|
||||
Database-based Auto-labeling Service
|
||||
|
||||
Processes documents with field values stored in the database (csv_field_values).
|
||||
Used by the pre-label API to create annotations from expected values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from shared.fields import CSV_TO_CLASS_MAPPING
|
||||
from backend.data.admin_models import AdminDocument
|
||||
from backend.data.repositories import DocumentRepository, AnnotationRepository
|
||||
from shared.data.db import DocumentDB
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize DocumentDB for saving match reports
|
||||
_document_db: DocumentDB | None = None
|
||||
|
||||
|
||||
def get_document_db() -> DocumentDB:
|
||||
"""Get or create DocumentDB instance with connection and tables initialized.
|
||||
|
||||
Follows the same pattern as CLI autolabel (src/cli/autolabel.py lines 370-373).
|
||||
"""
|
||||
global _document_db
|
||||
if _document_db is None:
|
||||
_document_db = DocumentDB()
|
||||
_document_db.connect()
|
||||
_document_db.create_tables() # Ensure tables exist
|
||||
logger.info("Connected to PostgreSQL DocumentDB for match reports")
|
||||
return _document_db
|
||||
|
||||
|
||||
def convert_csv_field_values_to_row_dict(
|
||||
document: AdminDocument,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert AdminDocument.csv_field_values to row_dict format for autolabel.
|
||||
|
||||
Args:
|
||||
document: AdminDocument with csv_field_values
|
||||
|
||||
Returns:
|
||||
Dictionary in row_dict format compatible with autolabel_tasks
|
||||
"""
|
||||
csv_values = document.csv_field_values or {}
|
||||
|
||||
# Build row_dict with DocumentId
|
||||
row_dict = {
|
||||
"DocumentId": str(document.document_id),
|
||||
}
|
||||
|
||||
# Map csv_field_values to row_dict format
|
||||
# csv_field_values uses keys like: InvoiceNumber, InvoiceDate, Amount, OCR, Bankgiro, etc.
|
||||
# row_dict uses same keys
|
||||
for key, value in csv_values.items():
|
||||
if value is not None and value != "":
|
||||
row_dict[key] = str(value)
|
||||
|
||||
return row_dict
|
||||
|
||||
|
||||
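# Illustrative sketch, not part of the committed module: a document whose
# csv_field_values is {"InvoiceNumber": "12345", "Amount": "", "OCR": None}
# is converted by convert_csv_field_values_to_row_dict above into
#     {"DocumentId": "<document uuid>", "InvoiceNumber": "12345"}
# -- empty strings and None values are dropped, everything else is stringified.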
def get_pending_autolabel_documents(
|
||||
limit: int = 10,
|
||||
) -> list[AdminDocument]:
|
||||
"""
|
||||
Get documents pending auto-labeling.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of documents to return
|
||||
|
||||
Returns:
|
||||
List of AdminDocument records with status='auto_labeling' and auto_label_status='pending'
|
||||
"""
|
||||
from sqlmodel import select
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import AdminDocument
|
||||
|
||||
with get_session_context() as session:
|
||||
statement = select(AdminDocument).where(
|
||||
AdminDocument.status == "auto_labeling",
|
||||
AdminDocument.auto_label_status == "pending",
|
||||
).order_by(AdminDocument.created_at).limit(limit)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
|
||||
def process_document_autolabel(
|
||||
document: AdminDocument,
|
||||
output_dir: Path | None = None,
|
||||
dpi: int = DEFAULT_DPI,
|
||||
min_confidence: float = 0.5,
|
||||
doc_repo: DocumentRepository | None = None,
|
||||
ann_repo: AnnotationRepository | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Process a single document for auto-labeling using csv_field_values.
|
||||
|
||||
Args:
|
||||
document: AdminDocument with csv_field_values and file_path
|
||||
output_dir: Output directory for temp files
|
||||
dpi: Rendering DPI
|
||||
min_confidence: Minimum match confidence
|
||||
doc_repo: Document repository (created if None)
|
||||
ann_repo: Annotation repository (created if None)
|
||||
|
||||
Returns:
|
||||
Result dictionary with success status and annotations
|
||||
"""
|
||||
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
|
||||
from shared.pdf import PDFDocument
|
||||
|
||||
# Initialize repositories if not provided
|
||||
if doc_repo is None:
|
||||
doc_repo = DocumentRepository()
|
||||
if ann_repo is None:
|
||||
ann_repo = AnnotationRepository()
|
||||
|
||||
document_id = str(document.document_id)
|
||||
file_path = Path(document.file_path)
|
||||
|
||||
# Get output directory from StorageHelper
|
||||
storage = get_storage_helper()
|
||||
if output_dir is None:
|
||||
output_dir = storage.get_autolabel_output_path()
|
||||
if output_dir is None:
|
||||
output_dir = Path("data/autolabel_output")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Mark as processing
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="running",
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if file exists
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
# Convert csv_field_values to row_dict
|
||||
row_dict = convert_csv_field_values_to_row_dict(document)
|
||||
|
||||
if len(row_dict) <= 1: # Only has DocumentId
|
||||
raise ValueError("No field values to match")
|
||||
|
||||
# Determine PDF type (text or scanned)
|
||||
is_scanned = False
|
||||
with PDFDocument(file_path) as pdf_doc:
|
||||
# Check if first page has extractable text
|
||||
tokens = list(pdf_doc.extract_text_tokens(0))
|
||||
is_scanned = len(tokens) < 10 # Threshold for "no text"
|
||||
|
||||
# Build task data
|
||||
# Use raw_pdfs base path for pdf_path
|
||||
# This ensures consistency with CLI autolabel for reprocess_failed.py
|
||||
raw_pdfs_dir = storage.get_raw_pdfs_base_path()
|
||||
if raw_pdfs_dir is None:
|
||||
raise ValueError("Storage not configured for local access")
|
||||
pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf"
|
||||
|
||||
task_data = {
|
||||
"row_dict": row_dict,
|
||||
"pdf_path": str(pdf_path_for_report),
|
||||
"output_dir": str(output_dir),
|
||||
"dpi": dpi,
|
||||
"min_confidence": min_confidence,
|
||||
}
|
||||
|
||||
# Process based on PDF type
|
||||
if is_scanned:
|
||||
result = process_scanned_pdf(task_data)
|
||||
else:
|
||||
result = process_text_pdf(task_data)
|
||||
|
||||
# Save report to DocumentDB (same as CLI autolabel)
|
||||
if result.get("report"):
|
||||
try:
|
||||
doc_db = get_document_db()
|
||||
doc_db.save_document(result["report"])
|
||||
logger.info(f"Saved match report to DocumentDB for {document_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save report to DocumentDB: {e}")
|
||||
|
||||
# Save annotations to database
|
||||
if result.get("success") and result.get("report"):
|
||||
_save_annotations_to_db(
|
||||
ann_repo=ann_repo,
|
||||
document_id=document_id,
|
||||
report=result["report"],
|
||||
page_annotations=result.get("pages", []),
|
||||
dpi=dpi,
|
||||
)
|
||||
|
||||
# Mark as completed
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="labeled",
|
||||
auto_label_status="completed",
|
||||
)
|
||||
else:
|
||||
# Mark as failed
|
||||
errors = result.get("report", {}).get("errors", ["Unknown error"])
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="pending",
|
||||
auto_label_status="failed",
|
||||
auto_label_error="; ".join(errors) if errors else "No annotations generated",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document {document_id}: {e}", exc_info=True)
|
||||
|
||||
# Mark as failed
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="pending",
|
||||
auto_label_status="failed",
|
||||
auto_label_error=str(e),
|
||||
)
|
||||
|
||||
return {
|
||||
"doc_id": document_id,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
def _save_annotations_to_db(
|
||||
ann_repo: AnnotationRepository,
|
||||
document_id: str,
|
||||
report: dict[str, Any],
|
||||
page_annotations: list[dict[str, Any]],
|
||||
dpi: int = 200,
|
||||
) -> int:
|
||||
"""
|
||||
Save generated annotations to database.
|
||||
|
||||
Args:
|
||||
ann_repo: Annotation repository instance
|
||||
document_id: Document ID
|
||||
report: AutoLabelReport as dict
|
||||
page_annotations: List of page annotation data
|
||||
dpi: DPI used for rendering images (for coordinate conversion)
|
||||
|
||||
Returns:
|
||||
Number of annotations saved
|
||||
"""
|
||||
from shared.fields import FIELD_CLASS_IDS
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
# Mapping from CSV field names to internal field names
|
||||
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
|
||||
"InvoiceNumber": "invoice_number",
|
||||
"InvoiceDate": "invoice_date",
|
||||
"InvoiceDueDate": "invoice_due_date",
|
||||
"OCR": "ocr_number",
|
||||
"Bankgiro": "bankgiro",
|
||||
"Plusgiro": "plusgiro",
|
||||
"Amount": "amount",
|
||||
"supplier_organisation_number": "supplier_organisation_number",
|
||||
"customer_number": "customer_number",
|
||||
"payment_line": "payment_line",
|
||||
}
|
||||
|
||||
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
|
||||
scale = dpi / 72.0
|
||||
|
||||
# Get storage helper for image dimensions
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Cache for image dimensions per page
|
||||
image_dimensions: dict[int, tuple[int, int]] = {}
|
||||
|
||||
def get_image_dimensions(page_no: int) -> tuple[int, int] | None:
|
||||
"""Get image dimensions for a page (1-indexed)."""
|
||||
if page_no in image_dimensions:
|
||||
return image_dimensions[page_no]
|
||||
|
||||
# Get dimensions from storage helper
|
||||
dims = storage.get_admin_image_dimensions(document_id, page_no)
|
||||
if dims:
|
||||
image_dimensions[page_no] = dims
|
||||
return dims
|
||||
|
||||
return None
|
||||
|
||||
annotation_count = 0
|
||||
|
||||
# Get field results from report (list of dicts)
|
||||
field_results = report.get("field_results", [])
|
||||
|
||||
for field_info in field_results:
|
||||
if not field_info.get("matched"):
|
||||
continue
|
||||
|
||||
csv_field_name = field_info.get("field_name", "")
|
||||
|
||||
# Map CSV field name to internal field name
|
||||
field_name = CSV_TO_INTERNAL_FIELD.get(csv_field_name, csv_field_name)
|
||||
|
||||
# Get class_id from field name
|
||||
class_id = FIELD_CLASS_IDS.get(field_name)
|
||||
if class_id is None:
|
||||
logger.warning(f"Unknown field name: {csv_field_name} -> {field_name}")
|
||||
continue
|
||||
|
||||
# Get bbox info (list: [x, y, x2, y2] in PDF points - 72 DPI)
|
||||
bbox = field_info.get("bbox", [])
|
||||
if not bbox or len(bbox) < 4:
|
||||
continue
|
||||
|
||||
# Convert PDF points (72 DPI) to pixel coordinates (at configured DPI)
|
||||
pdf_x1, pdf_y1, pdf_x2, pdf_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
x1 = pdf_x1 * scale
|
||||
y1 = pdf_y1 * scale
|
||||
x2 = pdf_x2 * scale
|
||||
y2 = pdf_y2 * scale
|
||||
|
||||
bbox_width = x2 - x1
|
||||
bbox_height = y2 - y1
|
||||
|
||||
# Get page number (convert to 1-indexed)
|
||||
page_no = field_info.get("page_no", 0) + 1
|
||||
|
||||
# Get image dimensions for normalization
|
||||
dims = get_image_dimensions(page_no)
|
||||
if dims:
|
||||
img_width, img_height = dims
|
||||
# Calculate normalized coordinates
|
||||
x_center = (x1 + x2) / 2 / img_width
|
||||
y_center = (y1 + y2) / 2 / img_height
|
||||
width = bbox_width / img_width
|
||||
height = bbox_height / img_height
|
||||
else:
|
||||
# Fallback: use pixel coordinates as-is for normalization
|
||||
# (will be slightly off but better than /1000)
|
||||
logger.warning(f"Could not get image dimensions for page {page_no}, using estimates")
|
||||
# Estimate A4 at configured DPI: 595 x 842 points * scale
|
||||
estimated_width = 595 * scale
|
||||
estimated_height = 842 * scale
|
||||
x_center = (x1 + x2) / 2 / estimated_width
|
||||
y_center = (y1 + y2) / 2 / estimated_height
|
||||
width = bbox_width / estimated_width
|
||||
height = bbox_height / estimated_height
|
||||
|
||||
# Create annotation
|
||||
try:
|
||||
ann_repo.create(
|
||||
document_id=document_id,
|
||||
page_number=page_no,
|
||||
class_id=class_id,
|
||||
class_name=field_name,
|
||||
x_center=x_center,
|
||||
y_center=y_center,
|
||||
width=width,
|
||||
height=height,
|
||||
bbox_x=int(x1),
|
||||
bbox_y=int(y1),
|
||||
bbox_width=int(bbox_width),
|
||||
bbox_height=int(bbox_height),
|
||||
text_value=field_info.get("matched_text"),
|
||||
confidence=field_info.get("score"),
|
||||
source="auto",
|
||||
)
|
||||
annotation_count += 1
|
||||
logger.info(f"Saved annotation for {field_name}: bbox=({int(x1)}, {int(y1)}, {int(bbox_width)}, {int(bbox_height)})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save annotation for {field_name}: {e}")
|
||||
|
||||
return annotation_count
|
||||
|
||||
|
||||
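# Illustrative sketch, not part of the committed module: the coordinate
# scaling used in _save_annotations_to_db above. Match bboxes arrive in PDF
# points (72 per inch); with this helper's default dpi=200 the scale factor
# is 200 / 72, so for example:
#
#     scale = 200 / 72                  # ~2.778
#     x1_px = 100 * scale               # 100 pt  -> ~277.8 px
#     width_px = (250 - 100) * scale    # 150 pt  -> ~416.7 px
#
# The pixel box is then normalized against the rendered page-image dimensions,
# or against an A4 estimate (595 x 842 points * scale) when those are unavailable.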
def run_pending_autolabel_batch(
|
||||
batch_size: int = 10,
|
||||
output_dir: Path | None = None,
|
||||
doc_repo: DocumentRepository | None = None,
|
||||
ann_repo: AnnotationRepository | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Process a batch of pending auto-label documents.
|
||||
|
||||
Args:
|
||||
batch_size: Number of documents to process
|
||||
output_dir: Output directory for temp files
|
||||
doc_repo: Document repository (created if None)
|
||||
ann_repo: Annotation repository (created if None)
|
||||
|
||||
Returns:
|
||||
Summary of processing results
|
||||
"""
|
||||
if doc_repo is None:
|
||||
doc_repo = DocumentRepository()
|
||||
if ann_repo is None:
|
||||
ann_repo = AnnotationRepository()
|
||||
|
||||
documents = get_pending_autolabel_documents(limit=batch_size)
|
||||
|
||||
results = {
|
||||
"total": len(documents),
|
||||
"successful": 0,
|
||||
"failed": 0,
|
||||
"documents": [],
|
||||
}
|
||||
|
||||
for doc in documents:
|
||||
result = process_document_autolabel(
|
||||
document=doc,
|
||||
output_dir=output_dir,
|
||||
doc_repo=doc_repo,
|
||||
ann_repo=ann_repo,
|
||||
)
|
||||
|
||||
doc_result = {
|
||||
"document_id": str(doc.document_id),
|
||||
"success": result.get("success", False),
|
||||
}
|
||||
|
||||
if result.get("success"):
|
||||
results["successful"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
doc_result["error"] = result.get("error") or "Unknown error"
|
||||
|
||||
results["documents"].append(doc_result)
|
||||
|
||||
return results


def save_manual_annotations_to_document_db(
    document: AdminDocument,
    annotations: list,
) -> dict[str, Any]:
    """
    Save manual annotations to PostgreSQL documents and field_results tables.

    Called when a user marks a document as 'labeled' from the web UI.
    This ensures manually labeled documents are also tracked in the same
    database as auto-labeled documents for consistency.

    Args:
        document: AdminDocument instance
        annotations: List of AdminAnnotation instances

    Returns:
        Dict with success status and details
    """
    from datetime import datetime

    document_id = str(document.document_id)

    # Build pdf_path using raw_pdfs base path (same as auto-label)
    storage = get_storage_helper()
    raw_pdfs_dir = storage.get_raw_pdfs_base_path()
    if raw_pdfs_dir is None:
        return {
            "success": False,
            "document_id": document_id,
            "error": "Storage not configured for local access",
        }
    pdf_path = raw_pdfs_dir / f"{document_id}.pdf"

    # Build report dict compatible with DocumentDB.save_document()
    field_results = []
    fields_total = len(annotations)
    fields_matched = 0

    for ann in annotations:
        # All manual annotations are considered "matched" since the user verified them
        field_result = {
            "field_name": ann.class_name,
            "csv_value": ann.text_value or "",  # Manual annotations may not have a CSV value
            "matched": True,
            "score": ann.confidence or 1.0,  # Manual = high confidence
            "matched_text": ann.text_value,
            "candidate_used": "manual",
            "bbox": [ann.bbox_x, ann.bbox_y, ann.bbox_x + ann.bbox_width, ann.bbox_y + ann.bbox_height],
            "page_no": ann.page_number - 1,  # Convert to 0-indexed
            "context_keywords": [],
            "error": None,
        }
        field_results.append(field_result)
        fields_matched += 1

    # Determine PDF type
    pdf_type = "unknown"
    if pdf_path.exists():
        try:
            from shared.pdf import PDFDocument

            with PDFDocument(pdf_path) as pdf_doc:
                tokens = list(pdf_doc.extract_text_tokens(0))
                pdf_type = "scanned" if len(tokens) < 10 else "text"
        except Exception as e:
            logger.warning(f"Could not determine PDF type: {e}")

    # Build report
    report = {
        "document_id": document_id,
        "pdf_path": str(pdf_path),
        "pdf_type": pdf_type,
        "success": fields_matched > 0,
        "total_pages": document.page_count,
        "fields_matched": fields_matched,
        "fields_total": fields_total,
        "annotations_generated": fields_matched,
        "processing_time_ms": 0,  # Manual labeling - no processing time
        "timestamp": datetime.utcnow().isoformat(),
        "errors": [],
        "field_results": field_results,
        # Extended fields (from CSV if available)
        "split": None,
        "customer_number": document.csv_field_values.get("customer_number") if document.csv_field_values else None,
        "supplier_name": document.csv_field_values.get("supplier_name") if document.csv_field_values else None,
        "supplier_organisation_number": document.csv_field_values.get("supplier_organisation_number") if document.csv_field_values else None,
        "supplier_accounts": document.csv_field_values.get("supplier_accounts") if document.csv_field_values else None,
    }

    # Save to PostgreSQL DocumentDB
    try:
        doc_db = get_document_db()
        doc_db.save_document(report)
        logger.info(f"Saved manual annotations to DocumentDB for {document_id}: {fields_matched} fields")

        return {
            "success": True,
            "document_id": document_id,
            "fields_saved": fields_matched,
            "message": f"Saved {fields_matched} annotations to DocumentDB",
        }

    except Exception as e:
        logger.error(f"Failed to save manual annotations to DocumentDB: {e}", exc_info=True)
        return {
            "success": False,
            "document_id": document_id,
            "error": str(e),
        }
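A minimal sketch of how the web layer might invoke the function above when a document is marked as labeled (illustrative only; the repository lookup helpers shown here are assumptions, not part of this commit):

    # Hypothetical caller - repository method names are assumed for illustration.
    document = doc_repo.get(document_id)
    annotations = ann_repo.list_for_document(document_id)
    result = save_manual_annotations_to_document_db(document, annotations)
    if not result["success"]:
        logger.warning("DocumentDB sync failed for %s: %s", document_id, result["error"])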
217
packages/backend/backend/web/services/document_service.py
Normal file
@@ -0,0 +1,217 @@
"""
Document Service for storage-backed file operations.

Provides a unified interface for document upload, download, and serving
using the storage abstraction layer.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from uuid import uuid4

if TYPE_CHECKING:
    from shared.storage.base import StorageBackend


@dataclass
class DocumentResult:
    """Result of document operation."""

    id: str
    file_path: str
    filename: str | None = None


class DocumentService:
    """Service for document file operations using storage backend.

    Provides upload, download, and URL generation for documents and images.
    """

    # Storage path prefixes
    DOCUMENTS_PREFIX = "documents"
    IMAGES_PREFIX = "images"

    def __init__(
        self,
        storage_backend: "StorageBackend",
        admin_db: Any | None = None,
    ) -> None:
        """Initialize document service.

        Args:
            storage_backend: Storage backend for file operations.
            admin_db: Optional AdminDB instance for database operations.
        """
        self._storage = storage_backend
        self._admin_db = admin_db

    def upload_document(
        self,
        content: bytes,
        filename: str,
        dataset_id: str | None = None,
        document_id: str | None = None,
    ) -> DocumentResult:
        """Upload a document to storage.

        Args:
            content: Document content as bytes.
            filename: Original filename.
            dataset_id: Optional dataset ID for organization.
            document_id: Optional document ID (generated if not provided).

        Returns:
            DocumentResult with ID and storage path.
        """
        if document_id is None:
            document_id = str(uuid4())

        # Extract extension from filename
        ext = ""
        if "." in filename:
            ext = "." + filename.rsplit(".", 1)[-1].lower()

        # Build logical path
        remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}"

        # Upload via storage backend
        self._storage.upload_bytes(content, remote_path, overwrite=True)

        return DocumentResult(
            id=document_id,
            file_path=remote_path,
            filename=filename,
        )

    def download_document(self, remote_path: str) -> bytes:
        """Download a document from storage.

        Args:
            remote_path: Logical path to the document.

        Returns:
            Document content as bytes.
        """
        return self._storage.download_bytes(remote_path)

    def get_document_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get a URL for accessing a document.

        Args:
            remote_path: Logical path to the document.
            expires_in_seconds: URL validity duration.

        Returns:
            Pre-signed URL for document access.
        """
        return self._storage.get_presigned_url(remote_path, expires_in_seconds)

    def document_exists(self, remote_path: str) -> bool:
        """Check if a document exists in storage.

        Args:
            remote_path: Logical path to the document.

        Returns:
            True if document exists.
        """
        return self._storage.exists(remote_path)

    def delete_document_files(self, remote_path: str) -> bool:
        """Delete a document from storage.

        Args:
            remote_path: Logical path to the document.

        Returns:
            True if document was deleted.
        """
        return self._storage.delete(remote_path)

    def save_page_image(
        self,
        document_id: str,
        page_num: int,
        content: bytes,
    ) -> str:
        """Save a page image to storage.

        Args:
            document_id: Document ID.
            page_num: Page number (1-indexed).
            content: Image content as bytes.

        Returns:
            Logical path where image was stored.
        """
        remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
        self._storage.upload_bytes(content, remote_path, overwrite=True)
        return remote_path

    def get_page_image_url(
        self,
        document_id: str,
        page_num: int,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get a URL for accessing a page image.

        Args:
            document_id: Document ID.
            page_num: Page number (1-indexed).
            expires_in_seconds: URL validity duration.

        Returns:
            Pre-signed URL for image access.
        """
        remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
        return self._storage.get_presigned_url(remote_path, expires_in_seconds)

    def get_page_image(self, document_id: str, page_num: int) -> bytes:
        """Download a page image from storage.

        Args:
            document_id: Document ID.
            page_num: Page number (1-indexed).

        Returns:
            Image content as bytes.
        """
        remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
        return self._storage.download_bytes(remote_path)

    def delete_document_images(self, document_id: str) -> int:
        """Delete all images for a document.

        Args:
            document_id: Document ID.

        Returns:
            Number of images deleted.
        """
        prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
        image_paths = self._storage.list_files(prefix)

        deleted_count = 0
        for path in image_paths:
            if self._storage.delete(path):
                deleted_count += 1

        return deleted_count

    def list_document_images(self, document_id: str) -> list[str]:
        """List all images for a document.

        Args:
            document_id: Document ID.

        Returns:
            List of image paths.
        """
        prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
        return self._storage.list_files(prefix)
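A brief usage sketch for the class above (illustrative; it relies only on the methods defined here plus the shared get_storage_backend() factory used elsewhere in this commit):

    from shared.storage import get_storage_backend

    service = DocumentService(storage_backend=get_storage_backend())
    with open("invoice.pdf", "rb") as f:
        doc = service.upload_document(f.read(), filename="invoice.pdf")
    url = service.get_document_url(doc.file_path, expires_in_seconds=600)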
360
packages/backend/backend/web/services/inference.py
Normal file
@@ -0,0 +1,360 @@
"""
Inference Service

Business logic for invoice field extraction.
"""

from __future__ import annotations

import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable

import numpy as np
from PIL import Image

from backend.web.services.storage_helpers import get_storage_helper

if TYPE_CHECKING:
    from .config import ModelConfig, StorageConfig

logger = logging.getLogger(__name__)


# Type alias for model path resolver function
ModelPathResolver = Callable[[], Path | None]


@dataclass
class ServiceResult:
    """Result from inference service."""

    document_id: str
    success: bool = False
    document_type: str = "invoice"  # "invoice" or "letter"
    fields: dict[str, str | None] = field(default_factory=dict)
    confidence: dict[str, float] = field(default_factory=dict)
    detections: list[dict] = field(default_factory=list)
    processing_time_ms: float = 0.0
    visualization_path: Path | None = None
    errors: list[str] = field(default_factory=list)


class InferenceService:
    """
    Service for running invoice field extraction.

    Encapsulates YOLO detection and OCR extraction logic.
    Supports dynamic model loading from database.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        storage_config: StorageConfig,
        model_path_resolver: ModelPathResolver | None = None,
    ) -> None:
        """
        Initialize inference service.

        Args:
            model_config: Model configuration (default model settings)
            storage_config: Storage configuration
            model_path_resolver: Optional function to resolve model path from database.
                If provided, will be called to get the active model path.
                If it returns None, falls back to model_config.model_path.
        """
        self.model_config = model_config
        self.storage_config = storage_config
        self._model_path_resolver = model_path_resolver
        self._pipeline = None
        self._detector = None
        self._is_initialized = False
        self._current_model_path: Path | None = None

    def _resolve_model_path(self) -> Path:
        """Resolve the model path to use for inference.

        Priority:
            1. Active model from database (via resolver)
            2. Default model from config
        """
        if self._model_path_resolver:
            try:
                db_model_path = self._model_path_resolver()
                if db_model_path and Path(db_model_path).exists():
                    logger.info(f"Using active model from database: {db_model_path}")
                    return Path(db_model_path)
                elif db_model_path:
                    logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
            except Exception as e:
                logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")

        return self.model_config.model_path

    def initialize(self) -> None:
        """Initialize the inference pipeline (lazy loading)."""
        if self._is_initialized:
            return

        logger.info("Initializing inference service...")
        start_time = time.time()

        try:
            from backend.pipeline.pipeline import InferencePipeline
            from backend.pipeline.yolo_detector import YOLODetector

            # Resolve model path (from DB or config)
            model_path = self._resolve_model_path()
            self._current_model_path = model_path

            # Initialize YOLO detector for visualization
            self._detector = YOLODetector(
                str(model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                device="cuda" if self.model_config.use_gpu else "cpu",
            )

            # Initialize full pipeline
            self._pipeline = InferencePipeline(
                model_path=str(model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                use_gpu=self.model_config.use_gpu,
                dpi=self.model_config.dpi,
                enable_fallback=True,
            )

            self._is_initialized = True
            elapsed = time.time() - start_time
            logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")

        except Exception as e:
            logger.error(f"Failed to initialize inference service: {e}")
            raise

    def reload_model(self) -> bool:
        """Reload the model if the active model has changed.

        Returns:
            True if model was reloaded, False if no change needed.
        """
        new_model_path = self._resolve_model_path()

        if self._current_model_path == new_model_path:
            logger.debug("Model unchanged, no reload needed")
            return False

        logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
        self._is_initialized = False
        self._pipeline = None
        self._detector = None
        self.initialize()
        return True

    @property
    def current_model_path(self) -> Path | None:
        """Get the currently loaded model path."""
        return self._current_model_path

    @property
    def is_initialized(self) -> bool:
        """Check if service is initialized."""
        return self._is_initialized

    @property
    def gpu_available(self) -> bool:
        """Check if GPU is available."""
        try:
            import torch

            return torch.cuda.is_available()
        except ImportError:
            return False

    def process_image(
        self,
        image_path: Path,
        document_id: str | None = None,
        save_visualization: bool = True,
    ) -> ServiceResult:
        """
        Process an image file and extract invoice fields.

        Args:
            image_path: Path to image file
            document_id: Optional document ID
            save_visualization: Whether to save visualization

        Returns:
            ServiceResult with extracted fields
        """
        if not self._is_initialized:
            self.initialize()

        doc_id = document_id or str(uuid.uuid4())[:8]
        start_time = time.time()

        result = ServiceResult(document_id=doc_id)

        try:
            # Run inference pipeline
            pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)

            result.fields = pipeline_result.fields
            result.confidence = pipeline_result.confidence
            result.success = pipeline_result.success
            result.errors = pipeline_result.errors

            # Determine document type based on payment_line presence
            # If no payment_line found, it's likely a letter, not an invoice
            if not result.fields.get('payment_line'):
                result.document_type = "letter"
            else:
                result.document_type = "invoice"

            # Get raw detections for visualization
            result.detections = [
                {
                    "field": d.class_name,
                    "confidence": d.confidence,
                    "bbox": list(d.bbox),
                }
                for d in pipeline_result.raw_detections
            ]

            # Save visualization if requested
            if save_visualization and pipeline_result.raw_detections:
                viz_path = self._save_visualization(image_path, doc_id)
                result.visualization_path = viz_path

        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            result.errors.append(str(e))
            result.success = False

        result.processing_time_ms = (time.time() - start_time) * 1000
        return result

    def process_pdf(
        self,
        pdf_path: Path,
        document_id: str | None = None,
        save_visualization: bool = True,
    ) -> ServiceResult:
        """
        Process a PDF file and extract invoice fields.

        Args:
            pdf_path: Path to PDF file
            document_id: Optional document ID
            save_visualization: Whether to save visualization

        Returns:
            ServiceResult with extracted fields
        """
        if not self._is_initialized:
            self.initialize()

        doc_id = document_id or str(uuid.uuid4())[:8]
        start_time = time.time()

        result = ServiceResult(document_id=doc_id)

        try:
            # Run inference pipeline
            pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)

            result.fields = pipeline_result.fields
            result.confidence = pipeline_result.confidence
            result.success = pipeline_result.success
            result.errors = pipeline_result.errors

            # Determine document type based on payment_line presence
            # If no payment_line found, it's likely a letter, not an invoice
            if not result.fields.get('payment_line'):
                result.document_type = "letter"
            else:
                result.document_type = "invoice"

            # Get raw detections
            result.detections = [
                {
                    "field": d.class_name,
                    "confidence": d.confidence,
                    "bbox": list(d.bbox),
                }
                for d in pipeline_result.raw_detections
            ]

            # Save visualization (render first page)
            if save_visualization and pipeline_result.raw_detections:
                viz_path = self._save_pdf_visualization(pdf_path, doc_id)
                result.visualization_path = viz_path

        except Exception as e:
            logger.error(f"Error processing PDF {pdf_path}: {e}")
            result.errors.append(str(e))
            result.success = False

        result.processing_time_ms = (time.time() - start_time) * 1000
        return result

    def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
        """Save visualization image with detections."""
        from ultralytics import YOLO

        # Get storage helper for results directory
        storage = get_storage_helper()
        results_dir = storage.get_results_base_path()
        if results_dir is None:
            logger.warning("Cannot save visualization: local storage not available")
            return None

        # Load model and run prediction with visualization
        model = YOLO(str(self.model_config.model_path))
        results = model.predict(str(image_path), verbose=False)

        # Save annotated image
        output_path = results_dir / f"{doc_id}_result.png"
        for r in results:
            r.save(filename=str(output_path))

        return output_path

    def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
        """Save visualization for PDF (first page)."""
        from shared.pdf.renderer import render_pdf_to_images
        from ultralytics import YOLO
        import io

        # Get storage helper for results directory
        storage = get_storage_helper()
        results_dir = storage.get_results_base_path()
        if results_dir is None:
            logger.warning("Cannot save visualization: local storage not available")
            return None

        # Render first page
        for page_no, image_bytes in render_pdf_to_images(
            pdf_path, dpi=self.model_config.dpi
        ):
            image = Image.open(io.BytesIO(image_bytes))
            temp_path = results_dir / f"{doc_id}_temp.png"
            image.save(temp_path)

            # Run YOLO and save visualization
            model = YOLO(str(self.model_config.model_path))
            results = model.predict(str(temp_path), verbose=False)

            output_path = results_dir / f"{doc_id}_result.png"
            for r in results:
                r.save(filename=str(output_path))

            # Cleanup temp file
            temp_path.unlink(missing_ok=True)
            return output_path

        # If no pages rendered
        return None
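Wiring sketch for the service above (illustrative only; the ModelConfig/StorageConfig constructor arguments are assumptions inferred from the attributes referenced in this file):

    from pathlib import Path
    from backend.web.services.config import ModelConfig, StorageConfig

    service = InferenceService(
        model_config=ModelConfig(          # field names assumed from usage above
            model_path=Path("models/best.pt"),
            confidence_threshold=0.5,
            use_gpu=False,
            dpi=150,
        ),
        storage_config=StorageConfig(),
        model_path_resolver=lambda: None,  # no DB override; use the configured default
    )
    result = service.process_pdf(Path("sample_invoice.pdf"), save_visualization=False)
    print(result.document_type, result.fields, result.processing_time_ms)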
830
packages/backend/backend/web/services/storage_helpers.py
Normal file
@@ -0,0 +1,830 @@
"""
Storage helpers for web services.

Provides convenience functions for common storage operations,
wrapping the storage backend with proper path handling using prefixes.
"""

from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4

from shared.storage import PREFIXES, get_storage_backend
from shared.storage.local import LocalStorageBackend

if TYPE_CHECKING:
    from shared.storage.base import StorageBackend


def get_default_storage() -> "StorageBackend":
    """Get the default storage backend.

    Returns:
        Configured StorageBackend instance.
    """
    return get_storage_backend()


class StorageHelper:
    """Helper class for storage operations with prefixes.

    Provides high-level operations for document storage, including
    upload, download, and URL generation with proper path prefixes.
    """

    def __init__(self, storage: "StorageBackend | None" = None) -> None:
        """Initialize storage helper.

        Args:
            storage: Storage backend to use. If None, creates default.
        """
        self._storage = storage or get_default_storage()

    @property
    def storage(self) -> "StorageBackend":
        """Get the underlying storage backend."""
        return self._storage

    # Document operations

    def upload_document(
        self,
        content: bytes,
        filename: str,
        document_id: str | None = None,
    ) -> tuple[str, str]:
        """Upload a document to storage.

        Args:
            content: Document content as bytes.
            filename: Original filename (used for extension).
            document_id: Optional document ID. Generated if not provided.

        Returns:
            Tuple of (document_id, storage_path).
        """
        if document_id is None:
            document_id = str(uuid4())

        ext = Path(filename).suffix.lower() or ".pdf"
        path = PREFIXES.document_path(document_id, ext)
        self._storage.upload_bytes(content, path, overwrite=True)

        return document_id, path

    def download_document(self, document_id: str, extension: str = ".pdf") -> bytes:
        """Download a document from storage.

        Args:
            document_id: Document identifier.
            extension: File extension.

        Returns:
            Document content as bytes.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.download_bytes(path)

    def get_document_url(
        self,
        document_id: str,
        extension: str = ".pdf",
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for a document.

        Args:
            document_id: Document identifier.
            extension: File extension.
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def document_exists(self, document_id: str, extension: str = ".pdf") -> bool:
        """Check if a document exists.

        Args:
            document_id: Document identifier.
            extension: File extension.

        Returns:
            True if document exists.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.exists(path)

    def delete_document(self, document_id: str, extension: str = ".pdf") -> bool:
        """Delete a document.

        Args:
            document_id: Document identifier.
            extension: File extension.

        Returns:
            True if document was deleted.
        """
        path = PREFIXES.document_path(document_id, extension)
        return self._storage.delete(path)

    # Image operations

    def save_page_image(
        self,
        document_id: str,
        page_num: int,
        content: bytes,
    ) -> str:
        """Save a page image to storage.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            content: Image content as bytes.

        Returns:
            Storage path where image was saved.
        """
        path = PREFIXES.image_path(document_id, page_num)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_page_image(self, document_id: str, page_num: int) -> bytes:
        """Download a page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Image content as bytes.
        """
        path = PREFIXES.image_path(document_id, page_num)
        return self._storage.download_bytes(path)

    def get_page_image_url(
        self,
        document_id: str,
        page_num: int,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for a page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.image_path(document_id, page_num)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def delete_document_images(self, document_id: str) -> int:
        """Delete all images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            Number of images deleted.
        """
        prefix = f"{PREFIXES.IMAGES}/{document_id}/"
        images = self._storage.list_files(prefix)
        deleted = 0
        for img_path in images:
            if self._storage.delete(img_path):
                deleted += 1
        return deleted

    def list_document_images(self, document_id: str) -> list[str]:
        """List all images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            List of image paths.
        """
        prefix = f"{PREFIXES.IMAGES}/{document_id}/"
        return self._storage.list_files(prefix)

    # Upload staging operations

    def save_upload(
        self,
        content: bytes,
        filename: str,
        subfolder: str | None = None,
    ) -> str:
        """Save a file to upload staging area.

        Args:
            content: File content as bytes.
            filename: Filename to save as.
            subfolder: Optional subfolder (e.g., "async").

        Returns:
            Storage path where file was saved.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_upload(self, filename: str, subfolder: str | None = None) -> bytes:
        """Get a file from upload staging area.

        Args:
            filename: Filename to retrieve.
            subfolder: Optional subfolder.

        Returns:
            File content as bytes.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        return self._storage.download_bytes(path)

    def delete_upload(self, filename: str, subfolder: str | None = None) -> bool:
        """Delete a file from upload staging area.

        Args:
            filename: Filename to delete.
            subfolder: Optional subfolder.

        Returns:
            True if file was deleted.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        return self._storage.delete(path)

    # Result operations

    def save_result(self, content: bytes, filename: str) -> str:
        """Save a result file.

        Args:
            content: File content as bytes.
            filename: Filename to save as.

        Returns:
            Storage path where file was saved.
        """
        path = PREFIXES.result_path(filename)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_result(self, filename: str) -> bytes:
        """Get a result file.

        Args:
            filename: Filename to retrieve.

        Returns:
            File content as bytes.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.download_bytes(path)

    def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str:
        """Get presigned URL for a result file.

        Args:
            filename: Filename.
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def result_exists(self, filename: str) -> bool:
        """Check if a result file exists.

        Args:
            filename: Filename to check.

        Returns:
            True if file exists.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.exists(path)

    def delete_result(self, filename: str) -> bool:
        """Delete a result file.

        Args:
            filename: Filename to delete.

        Returns:
            True if file was deleted.
        """
        path = PREFIXES.result_path(filename)
        return self._storage.delete(path)

    # Export operations

    def save_export(self, content: bytes, export_id: str, filename: str) -> str:
        """Save an export file.

        Args:
            content: File content as bytes.
            export_id: Export identifier.
            filename: Filename to save as.

        Returns:
            Storage path where file was saved.
        """
        path = PREFIXES.export_path(export_id, filename)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_export_url(
        self,
        export_id: str,
        filename: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for an export file.

        Args:
            export_id: Export identifier.
            filename: Filename.
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = PREFIXES.export_path(export_id, filename)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    # Admin image operations

    def get_admin_image_path(self, document_id: str, page_num: int) -> str:
        """Get the storage path for an admin image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Storage path like "admin_images/doc123/page_1.png"
        """
        return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png"

    def save_admin_image(
        self,
        document_id: str,
        page_num: int,
        content: bytes,
    ) -> str:
        """Save an admin page image to storage.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            content: Image content as bytes.

        Returns:
            Storage path where image was saved.
        """
        path = self.get_admin_image_path(document_id, page_num)
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_admin_image(self, document_id: str, page_num: int) -> bytes:
        """Download an admin page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Image content as bytes.
        """
        path = self.get_admin_image_path(document_id, page_num)
        return self._storage.download_bytes(path)

    def get_admin_image_url(
        self,
        document_id: str,
        page_num: int,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Get presigned URL for an admin page image.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).
            expires_in_seconds: URL expiration time.

        Returns:
            Presigned URL string.
        """
        path = self.get_admin_image_path(document_id, page_num)
        return self._storage.get_presigned_url(path, expires_in_seconds)

    def admin_image_exists(self, document_id: str, page_num: int) -> bool:
        """Check if an admin page image exists.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            True if image exists.
        """
        path = self.get_admin_image_path(document_id, page_num)
        return self._storage.exists(path)

    def list_admin_images(self, document_id: str) -> list[str]:
        """List all admin images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            List of image paths.
        """
        prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
        return self._storage.list_files(prefix)

    def delete_admin_images(self, document_id: str) -> int:
        """Delete all admin images for a document.

        Args:
            document_id: Document identifier.

        Returns:
            Number of images deleted.
        """
        prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
        images = self._storage.list_files(prefix)
        deleted = 0
        for img_path in images:
            if self._storage.delete(img_path):
                deleted += 1
        return deleted

    def get_admin_image_local_path(
        self, document_id: str, page_num: int
    ) -> Path | None:
        """Get the local filesystem path for an admin image.

        This method is useful for serving files via FileResponse.
        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            # Cloud storage - cannot get local path
            return None

        remote_path = self.get_admin_image_path(document_id, page_num)
        try:
            full_path = self._storage._get_full_path(remote_path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_admin_image_dimensions(
        self, document_id: str, page_num: int
    ) -> tuple[int, int] | None:
        """Get the dimensions (width, height) of an admin image.

        This method is useful for normalizing bounding box coordinates.

        Args:
            document_id: Document identifier.
            page_num: Page number (1-indexed).

        Returns:
            Tuple of (width, height) if image exists, None otherwise.
        """
        from PIL import Image

        # Try local path first for efficiency
        local_path = self.get_admin_image_local_path(document_id, page_num)
        if local_path is not None:
            with Image.open(local_path) as img:
                return img.size

        # Fall back to downloading for cloud storage
        if not self.admin_image_exists(document_id, page_num):
            return None

        try:
            import io

            image_bytes = self.get_admin_image(document_id, page_num)
            with Image.open(io.BytesIO(image_bytes)) as img:
                return img.size
        except Exception:
            return None

    # Raw PDF operations (legacy compatibility)

    def save_raw_pdf(self, content: bytes, filename: str) -> str:
        """Save a raw PDF for auto-labeling pipeline.

        Args:
            content: PDF content as bytes.
            filename: Filename to save as.

        Returns:
            Storage path where file was saved.
        """
        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        self._storage.upload_bytes(content, path, overwrite=True)
        return path

    def get_raw_pdf(self, filename: str) -> bytes:
        """Get a raw PDF from storage.

        Args:
            filename: Filename to retrieve.

        Returns:
            PDF content as bytes.
        """
        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        return self._storage.download_bytes(path)

    def raw_pdf_exists(self, filename: str) -> bool:
        """Check if a raw PDF exists.

        Args:
            filename: Filename to check.

        Returns:
            True if file exists.
        """
        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        return self._storage.exists(path)

    def get_raw_pdf_local_path(self, filename: str) -> Path | None:
        """Get the local filesystem path for a raw PDF.

        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            filename: Filename to retrieve.

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        path = f"{PREFIXES.RAW_PDFS}/{filename}"
        try:
            full_path = self._storage._get_full_path(path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_raw_pdf_path(self, filename: str) -> str:
        """Get the storage path for a raw PDF (not the local filesystem path).

        Args:
            filename: Filename.

        Returns:
            Storage path like "raw_pdfs/filename.pdf"
        """
        return f"{PREFIXES.RAW_PDFS}/{filename}"

    # Result local path operations

    def get_result_local_path(self, filename: str) -> Path | None:
        """Get the local filesystem path for a result file.

        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            filename: Filename to retrieve.

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        path = PREFIXES.result_path(filename)
        try:
            full_path = self._storage._get_full_path(path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_results_base_path(self) -> Path | None:
        """Get the base directory path for results (local storage only).

        Used for mounting static file directories.

        Returns:
            Path to results directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.RESULTS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    # Upload local path operations

    def get_upload_local_path(
        self, filename: str, subfolder: str | None = None
    ) -> Path | None:
        """Get the local filesystem path for an upload file.

        Only works with LocalStorageBackend; returns None for cloud storage.

        Args:
            filename: Filename to retrieve.
            subfolder: Optional subfolder.

        Returns:
            Path object if using local storage and file exists, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        path = PREFIXES.upload_path(filename, subfolder)
        try:
            full_path = self._storage._get_full_path(path)
            if full_path.exists():
                return full_path
            return None
        except Exception:
            return None

    def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None:
        """Get the base directory path for uploads (local storage only).

        Args:
            subfolder: Optional subfolder (e.g., "async").

        Returns:
            Path to uploads directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            if subfolder:
                base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}")
            else:
                base_path = self._storage._get_full_path(PREFIXES.UPLOADS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def upload_exists(self, filename: str, subfolder: str | None = None) -> bool:
        """Check if an upload file exists.

        Args:
            filename: Filename to check.
            subfolder: Optional subfolder.

        Returns:
            True if file exists.
        """
        path = PREFIXES.upload_path(filename, subfolder)
        return self._storage.exists(path)

    # Dataset operations

    def get_datasets_base_path(self) -> Path | None:
        """Get the base directory path for datasets (local storage only).

        Returns:
            Path to datasets directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.DATASETS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_admin_images_base_path(self) -> Path | None:
        """Get the base directory path for admin images (local storage only).

        Returns:
            Path to admin_images directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_raw_pdfs_base_path(self) -> Path | None:
        """Get the base directory path for raw PDFs (local storage only).

        Returns:
            Path to raw_pdfs directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_autolabel_output_path(self) -> Path | None:
        """Get the directory path for autolabel output (local storage only).

        Returns:
            Path to autolabel_output directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            # Use a dedicated subfolder for autolabel output
            base_path = self._storage._get_full_path("autolabel_output")
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_training_data_path(self) -> Path | None:
        """Get the directory path for training data exports (local storage only).

        Returns:
            Path to training directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path("training")
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None

    def get_exports_base_path(self) -> Path | None:
        """Get the base directory path for exports (local storage only).

        Returns:
            Path to exports directory if using local storage, None otherwise.
        """
        if not isinstance(self._storage, LocalStorageBackend):
            return None

        try:
            base_path = self._storage._get_full_path(PREFIXES.EXPORTS)
            base_path.mkdir(parents=True, exist_ok=True)
            return base_path
        except Exception:
            return None


# Default instance for convenience
_default_helper: StorageHelper | None = None


def get_storage_helper() -> StorageHelper:
    """Get the default storage helper instance.

    Creates the helper on first call with default storage backend.

    Returns:
        Default StorageHelper instance.
    """
    global _default_helper
    if _default_helper is None:
        _default_helper = StorageHelper()
    return _default_helper
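Typical use goes through the module-level singleton above rather than constructing StorageHelper directly (illustrative sketch; the file contents are placeholders):

    helper = get_storage_helper()
    doc_id, stored_path = helper.upload_document(b"%PDF-1.4 ...", "invoice.pdf")
    helper.save_page_image(doc_id, 1, b"<png bytes>")
    local = helper.get_admin_image_local_path(doc_id, 1)  # None on cloud backends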
24
packages/backend/backend/web/workers/__init__.py
Normal file
@@ -0,0 +1,24 @@
"""
Background Task Queues

Worker queues for asynchronous and batch processing.
"""

from backend.web.workers.async_queue import AsyncTaskQueue, AsyncTask
from backend.web.workers.batch_queue import (
    BatchTaskQueue,
    BatchTask,
    init_batch_queue,
    shutdown_batch_queue,
    get_batch_queue,
)

__all__ = [
    "AsyncTaskQueue",
    "AsyncTask",
    "BatchTaskQueue",
    "BatchTask",
    "init_batch_queue",
    "shutdown_batch_queue",
    "get_batch_queue",
]
Some files were not shown because too many files have changed in this diff.