restructure project
packages/inference/Dockerfile | 25 lines (new file)
@@ -0,0 +1,25 @@
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
    && rm -rf /var/lib/apt/lists/*

# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared

# Install inference package
COPY packages/inference /app/packages/inference
RUN pip install --no-cache-dir -e /app/packages/inference

# Copy frontend (if needed)
COPY frontend /app/frontend

WORKDIR /app/packages/inference

EXPOSE 8000

CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]
packages/inference/inference/__init__.py | 0 lines (new file)
packages/inference/inference/azure/__init__.py | 0 lines (new file)
packages/inference/inference/azure/aci_trigger.py | 105 lines (new file)
@@ -0,0 +1,105 @@
"""Trigger training jobs on Azure Container Instances."""

import logging
import os

logger = logging.getLogger(__name__)

# Azure SDK is optional; only needed if using ACI trigger
try:
    from azure.identity import DefaultAzureCredential
    from azure.mgmt.containerinstance import ContainerInstanceManagementClient
    from azure.mgmt.containerinstance.models import (
        Container,
        ContainerGroup,
        EnvironmentVariable,
        GpuResource,
        ResourceRequests,
        ResourceRequirements,
    )

    _AZURE_SDK_AVAILABLE = True
except ImportError:
    _AZURE_SDK_AVAILABLE = False


def start_training_container(task_id: str) -> str | None:
    """
    Start an Azure Container Instance for a training task.

    Returns the container group name if successful, None otherwise.
    Requires environment variables:
        AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
    """
    if not _AZURE_SDK_AVAILABLE:
        logger.warning(
            "Azure SDK not installed. Install azure-mgmt-containerinstance "
            "and azure-identity to use ACI trigger."
        )
        return None

    subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
    resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
    image = os.environ.get(
        "AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
    )
    gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
    location = os.environ.get("AZURE_LOCATION", "eastus")

    if not subscription_id:
        logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
        return None

    credential = DefaultAzureCredential()
    client = ContainerInstanceManagementClient(credential, subscription_id)

    container_name = f"training-{task_id[:8]}"

    env_vars = [
        EnvironmentVariable(name="TASK_ID", value=task_id),
    ]

    # Pass DB connection securely
    for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
        val = os.environ.get(var, "")
        if val:
            env_vars.append(EnvironmentVariable(name=var, value=val))

    db_password = os.environ.get("DB_PASSWORD", "")
    if db_password:
        env_vars.append(
            EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
        )

    container = Container(
        name=container_name,
        image=image,
        resources=ResourceRequirements(
            requests=ResourceRequests(
                cpu=4,
                memory_in_gb=16,
                gpu=GpuResource(count=1, sku=gpu_sku),
            )
        ),
        environment_variables=env_vars,
        command=[
            "python",
            "run_training.py",
            "--task-id",
            task_id,
        ],
    )

    group = ContainerGroup(
        location=location,
        containers=[container],
        os_type="Linux",
        restart_policy="Never",
    )

    logger.info("Creating ACI container group: %s", container_name)
    client.container_groups.begin_create_or_update(
        resource_group, container_name, group
    )

    return container_name
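Editor's note: a minimal usage sketch for the trigger above, assuming the Azure SDK extras are installed and the listed environment variables are set; the subscription ID and task ID shown are placeholders, not real values.

import os
import uuid

from inference.azure.aci_trigger import start_training_container

# Hypothetical values for illustration only; real deployments read these from secrets.
os.environ.setdefault("AZURE_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000")
os.environ.setdefault("AZURE_RESOURCE_GROUP", "invoice-training-rg")

task_id = str(uuid.uuid4())
group_name = start_training_container(task_id)
if group_name is None:
    print("ACI trigger skipped (SDK missing or subscription not configured)")
else:
    print(f"Started container group: {group_name}")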
packages/inference/inference/cli/__init__.py | 0 lines (new file)
packages/inference/inference/cli/infer.py | 141 lines (new file)
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Inference CLI

Runs inference on new PDFs to extract invoice data.
"""

import argparse
import json
import sys
from pathlib import Path

from shared.config import DEFAULT_DPI


def main():
    parser = argparse.ArgumentParser(
        description='Extract invoice data from PDFs using trained model'
    )
    parser.add_argument(
        '--model', '-m',
        required=True,
        help='Path to trained YOLO model (.pt file)'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Input PDF file or directory'
    )
    parser.add_argument(
        '--output', '-o',
        help='Output JSON file (default: stdout)'
    )
    parser.add_argument(
        '--confidence',
        type=float,
        default=0.5,
        help='Detection confidence threshold (default: 0.5)'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=DEFAULT_DPI,
        help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
    )
    parser.add_argument(
        '--no-fallback',
        action='store_true',
        help='Disable fallback OCR'
    )
    parser.add_argument(
        '--lang',
        default='en',
        help='OCR language (default: en)'
    )
    parser.add_argument(
        '--gpu',
        action='store_true',
        help='Use GPU'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )

    args = parser.parse_args()

    # Validate model
    model_path = Path(args.model)
    if not model_path.exists():
        print(f"Error: Model not found: {model_path}", file=sys.stderr)
        sys.exit(1)

    # Get input files
    input_path = Path(args.input)
    if input_path.is_file():
        pdf_files = [input_path]
    elif input_path.is_dir():
        pdf_files = list(input_path.glob('*.pdf'))
    else:
        print(f"Error: Input not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    if not pdf_files:
        print("Error: No PDF files found", file=sys.stderr)
        sys.exit(1)

    if args.verbose:
        print(f"Processing {len(pdf_files)} PDF file(s)")
        print(f"Model: {model_path}")

    from inference.pipeline import InferencePipeline

    # Initialize pipeline
    pipeline = InferencePipeline(
        model_path=model_path,
        confidence_threshold=args.confidence,
        ocr_lang=args.lang,
        use_gpu=args.gpu,
        dpi=args.dpi,
        enable_fallback=not args.no_fallback
    )

    # Process files
    results = []

    for pdf_path in pdf_files:
        if args.verbose:
            print(f"Processing: {pdf_path.name}")

        result = pipeline.process_pdf(pdf_path)
        results.append(result.to_json())

        if args.verbose:
            print(f"  Success: {result.success}")
            print(f"  Fields: {len(result.fields)}")
            if result.fallback_used:
                print("  Fallback used: Yes")
            if result.errors:
                print(f"  Errors: {result.errors}")

    # Output results
    if len(results) == 1:
        output = results[0]
    else:
        output = results

    json_output = json.dumps(output, indent=2, ensure_ascii=False)

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(json_output)
        if args.verbose:
            print(f"\nResults written to: {args.output}")
    else:
        print(json_output)


if __name__ == '__main__':
    main()
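Editor's note: the pipeline this CLI wires up can also be driven directly from Python; a minimal sketch, using the same constructor arguments the CLI passes. The weights path and sample PDF below are placeholders, not files guaranteed to exist in the repository.

from pathlib import Path

from inference.pipeline import InferencePipeline

# Placeholder paths for illustration; substitute a real trained model and PDF.
pipeline = InferencePipeline(
    model_path=Path("runs/train/invoice_fields/weights/best.pt"),
    confidence_threshold=0.5,
    ocr_lang="en",
    use_gpu=False,
    dpi=300,  # must match the DPI used at training time
    enable_fallback=True,
)
result = pipeline.process_pdf(Path("invoice.pdf"))
print(result.to_json())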
packages/inference/inference/cli/serve.py | 159 lines (new file)
@@ -0,0 +1,159 @@
"""
Web Server CLI

Command-line interface for starting the web server.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent.parent

from shared.config import DEFAULT_DPI


def setup_logging(debug: bool = False) -> None:
    """Configure logging."""
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Start the Invoice Field Extraction web server",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to",
    )

    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to listen on",
    )

    parser.add_argument(
        "--model",
        "-m",
        type=Path,
        default=Path("runs/train/invoice_fields/weights/best.pt"),
        help="Path to YOLO model weights",
    )

    parser.add_argument(
        "--confidence",
        type=float,
        default=0.5,
        help="Detection confidence threshold",
    )

    parser.add_argument(
        "--dpi",
        type=int,
        default=DEFAULT_DPI,
        help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
    )

    parser.add_argument(
        "--no-gpu",
        action="store_true",
        help="Disable GPU acceleration",
    )

    parser.add_argument(
        "--reload",
        action="store_true",
        help="Enable auto-reload for development",
    )

    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of worker processes",
    )

    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode",
    )

    return parser.parse_args()


def main() -> None:
    """Main entry point."""
    args = parse_args()
    setup_logging(debug=args.debug)

    logger = logging.getLogger(__name__)

    # Validate model path
    if not args.model.exists():
        logger.error(f"Model file not found: {args.model}")
        sys.exit(1)

    logger.info("=" * 60)
    logger.info("Invoice Field Extraction Web Server")
    logger.info("=" * 60)
    logger.info(f"Model: {args.model}")
    logger.info(f"Confidence threshold: {args.confidence}")
    logger.info(f"GPU enabled: {not args.no_gpu}")
    logger.info(f"Server: http://{args.host}:{args.port}")
    logger.info("=" * 60)

    # Create config
    from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig

    config = AppConfig(
        model=ModelConfig(
            model_path=args.model,
            confidence_threshold=args.confidence,
            use_gpu=not args.no_gpu,
            dpi=args.dpi,
        ),
        server=ServerConfig(
            host=args.host,
            port=args.port,
            debug=args.debug,
            reload=args.reload,
            workers=args.workers,
        ),
        storage=StorageConfig(),
    )

    # Create and run app
    import uvicorn
    from inference.web.app import create_app

    app = create_app(config)

    uvicorn.run(
        app,
        host=config.server.host,
        port=config.server.port,
        reload=config.server.reload,
        workers=config.server.workers if not config.server.reload else 1,
        log_level="debug" if config.server.debug else "info",
    )


if __name__ == "__main__":
    main()
packages/inference/inference/data/__init__.py | 0 lines (new file)
packages/inference/inference/data/admin_db.py | 1316 lines (new file)
File diff suppressed because it is too large.
packages/inference/inference/data/admin_models.py | 407 lines (new file)
@@ -0,0 +1,407 @@
"""
Admin API SQLModel Database Models

Defines the database schema for admin document management, annotations, and training tasks.
Includes batch upload support, training document links, and annotation history.
"""

from datetime import datetime
from typing import Any
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel, Column, JSON


# =============================================================================
# CSV to Field Class Mapping
# =============================================================================

CSV_TO_CLASS_MAPPING: dict[str, int] = {
    "InvoiceNumber": 0,  # invoice_number
    "InvoiceDate": 1,  # invoice_date
    "InvoiceDueDate": 2,  # invoice_due_date
    "OCR": 3,  # ocr_number
    "Bankgiro": 4,  # bankgiro
    "Plusgiro": 5,  # plusgiro
    "Amount": 6,  # amount
    "supplier_organisation_number": 7,  # supplier_organisation_number
    # 8: payment_line (derived from OCR/Bankgiro/Amount)
    "customer_number": 9,  # customer_number
}


# =============================================================================
# Core Models
# =============================================================================


class AdminToken(SQLModel, table=True):
    """Admin authentication token."""

    __tablename__ = "admin_tokens"

    token: str = Field(primary_key=True, max_length=255)
    name: str = Field(max_length=255)
    is_active: bool = Field(default=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    last_used_at: datetime | None = Field(default=None)
    expires_at: datetime | None = Field(default=None)


class AdminDocument(SQLModel, table=True):
    """Document uploaded for labeling/annotation."""

    __tablename__ = "admin_documents"

    document_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
    filename: str = Field(max_length=255)
    file_size: int
    content_type: str = Field(max_length=100)
    file_path: str = Field(max_length=512)  # Path to stored file
    page_count: int = Field(default=1)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, auto_labeling, labeled, exported
    auto_label_status: str | None = Field(default=None, max_length=20)
    # Auto-label status: running, completed, failed
    auto_label_error: str | None = Field(default=None)
    # v2: Upload source tracking
    upload_source: str = Field(default="ui", max_length=20)
    # Upload source: ui, api
    batch_id: UUID | None = Field(default=None, index=True)
    # Link to batch upload (if uploaded via ZIP)
    csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Original CSV values for reference
    auto_label_queued_at: datetime | None = Field(default=None)
    # When auto-label was queued
    annotation_lock_until: datetime | None = Field(default=None)
    # Lock for manual annotation while auto-label runs
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class AdminAnnotation(SQLModel, table=True):
    """Annotation for a document (bounding box + label)."""

    __tablename__ = "admin_annotations"

    annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    page_number: int = Field(default=1)  # 1-indexed
    class_id: int  # 0-9 for invoice fields
    class_name: str = Field(max_length=50)  # e.g., "invoice_number"
    # Bounding box (normalized 0-1 coordinates)
    x_center: float
    y_center: float
    width: float
    height: float
    # Original pixel coordinates (for display)
    bbox_x: int
    bbox_y: int
    bbox_width: int
    bbox_height: int
    # OCR extracted text (if available)
    text_value: str | None = Field(default=None)
    confidence: float | None = Field(default=None)
    # Source: manual, auto, imported
    source: str = Field(default="manual", max_length=20, index=True)
    # v2: Verification fields
    is_verified: bool = Field(default=False, index=True)
    verified_at: datetime | None = Field(default=None)
    verified_by: str | None = Field(default=None, max_length=255)
    # v2: Override tracking
    override_source: str | None = Field(default=None, max_length=20)
    # If this annotation overrides another: 'auto' or 'imported'
    original_annotation_id: UUID | None = Field(default=None)
    # Reference to the annotation this overrides
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingTask(SQLModel, table=True):
    """Training/fine-tuning task."""

    __tablename__ = "training_tasks"

    task_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, scheduled, running, completed, failed, cancelled
    task_type: str = Field(default="train", max_length=20)
    # Task type: train, finetune
    dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
    # Training configuration
    config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Schedule settings
    scheduled_at: datetime | None = Field(default=None)
    cron_expression: str | None = Field(default=None, max_length=50)
    is_recurring: bool = Field(default=False)
    # Execution details
    started_at: datetime | None = Field(default=None)
    completed_at: datetime | None = Field(default=None)
    error_message: str | None = Field(default=None)
    # Result metrics
    result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    model_path: str | None = Field(default=None, max_length=512)
    # v2: Document count and extracted metrics
    document_count: int = Field(default=0)
    # Count of documents used in training
    metrics_mAP: float | None = Field(default=None, index=True)
    metrics_precision: float | None = Field(default=None)
    metrics_recall: float | None = Field(default=None)
    # Extracted metrics for easy querying
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingLog(SQLModel, table=True):
    """Training log entry."""

    __tablename__ = "training_logs"

    log_id: int | None = Field(default=None, primary_key=True)
    task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
    level: str = Field(max_length=20)  # INFO, WARNING, ERROR
    message: str
    details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)


# =============================================================================
# Batch Upload Models (v2)
# =============================================================================


class BatchUpload(SQLModel, table=True):
    """Batch upload of multiple documents via ZIP file."""

    __tablename__ = "batch_uploads"

    batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
    filename: str = Field(max_length=255)  # ZIP filename
    file_size: int
    upload_source: str = Field(default="ui", max_length=20)
    # Upload source: ui, api
    status: str = Field(default="processing", max_length=20, index=True)
    # Status: processing, completed, partial, failed
    total_files: int = Field(default=0)
    processed_files: int = Field(default=0)
    # Number of files processed so far
    successful_files: int = Field(default=0)
    failed_files: int = Field(default=0)
    csv_filename: str | None = Field(default=None, max_length=255)
    # CSV file used for auto-labeling
    csv_row_count: int | None = Field(default=None)
    error_message: str | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    completed_at: datetime | None = Field(default=None)


class BatchUploadFile(SQLModel, table=True):
    """Individual file within a batch upload."""

    __tablename__ = "batch_upload_files"

    file_id: UUID = Field(default_factory=uuid4, primary_key=True)
    batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
    filename: str = Field(max_length=255)  # PDF filename within ZIP
    document_id: UUID | None = Field(default=None)
    # Link to created AdminDocument (if successful)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, processing, completed, failed, skipped
    error_message: str | None = Field(default=None)
    annotation_count: int = Field(default=0)
    # Number of annotations created for this file
    csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # CSV row data for this file (if available)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    processed_at: datetime | None = Field(default=None)


# =============================================================================
# Training Document Link (v2)
# =============================================================================


class TrainingDataset(SQLModel, table=True):
    """Training dataset containing selected documents with train/val/test splits."""

    __tablename__ = "training_datasets"

    dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    status: str = Field(default="building", max_length=20, index=True)
    # Status: building, ready, training, archived, failed
    train_ratio: float = Field(default=0.8)
    val_ratio: float = Field(default=0.1)
    seed: int = Field(default=42)
    total_documents: int = Field(default=0)
    total_images: int = Field(default=0)
    total_annotations: int = Field(default=0)
    dataset_path: str | None = Field(default=None, max_length=512)
    error_message: str | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class DatasetDocument(SQLModel, table=True):
    """Junction table linking datasets to documents with split assignment."""

    __tablename__ = "dataset_documents"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    split: str = Field(max_length=10)  # train, val, test
    page_count: int = Field(default=0)
    annotation_count: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingDocumentLink(SQLModel, table=True):
    """Junction table linking training tasks to documents."""

    __tablename__ = "training_document_links"

    link_id: UUID = Field(default_factory=uuid4, primary_key=True)
    task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Snapshot of annotations at training time (includes count, verified count, etc.)
    created_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Annotation History (v2)
# =============================================================================


class AnnotationHistory(SQLModel, table=True):
    """History of annotation changes (for override tracking)."""

    __tablename__ = "annotation_history"

    history_id: UUID = Field(default_factory=uuid4, primary_key=True)
    annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    # Change action: created, updated, deleted, override
    action: str = Field(max_length=20, index=True)
    # Previous value (for updates/deletes)
    previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # New value (for creates/updates)
    new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Change metadata
    changed_by: str | None = Field(default=None, max_length=255)
    # User/token who made the change
    change_reason: str | None = Field(default=None)
    # Optional reason for change
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)


# Field class mapping (same as src/cli/train.py)
FIELD_CLASSES = {
    0: "invoice_number",
    1: "invoice_date",
    2: "invoice_due_date",
    3: "ocr_number",
    4: "bankgiro",
    5: "plusgiro",
    6: "amount",
    7: "supplier_organisation_number",
    8: "payment_line",
    9: "customer_number",
}

FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}


# Read-only models for API responses
class AdminDocumentRead(SQLModel):
    """Admin document response model."""

    document_id: UUID
    filename: str
    file_size: int
    content_type: str
    page_count: int
    status: str
    auto_label_status: str | None
    auto_label_error: str | None
    created_at: datetime
    updated_at: datetime


class AdminAnnotationRead(SQLModel):
    """Admin annotation response model."""

    annotation_id: UUID
    document_id: UUID
    page_number: int
    class_id: int
    class_name: str
    x_center: float
    y_center: float
    width: float
    height: float
    bbox_x: int
    bbox_y: int
    bbox_width: int
    bbox_height: int
    text_value: str | None
    confidence: float | None
    source: str
    created_at: datetime


class TrainingTaskRead(SQLModel):
    """Training task response model."""

    task_id: UUID
    name: str
    description: str | None
    status: str
    task_type: str
    config: dict[str, Any] | None
    scheduled_at: datetime | None
    is_recurring: bool
    started_at: datetime | None
    completed_at: datetime | None
    error_message: str | None
    result_metrics: dict[str, Any] | None
    model_path: str | None
    dataset_id: UUID | None
    created_at: datetime


class TrainingDatasetRead(SQLModel):
    """Training dataset response model."""

    dataset_id: UUID
    name: str
    description: str | None
    status: str
    train_ratio: float
    val_ratio: float
    seed: int
    total_documents: int
    total_images: int
    total_annotations: int
    dataset_path: str | None
    error_message: str | None
    created_at: datetime
    updated_at: datetime


class DatasetDocumentRead(SQLModel):
    """Dataset document response model."""

    id: UUID
    dataset_id: UUID
    document_id: UUID
    split: str
    page_count: int
    annotation_count: int
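Editor's note: a hedged sketch of how these table models are meant to be used together with the session helper from database.py below, assuming the database is reachable; the filename, path, and coordinates are made up for illustration.

from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from inference.data.database import create_db_and_tables, get_session_context

create_db_and_tables()

with get_session_context() as session:
    doc = AdminDocument(
        filename="invoice_001.pdf",  # placeholder filename
        file_size=12345,
        content_type="application/pdf",
        file_path="/data/uploads/invoice_001.pdf",  # placeholder storage path
    )
    session.add(doc)
    session.flush()  # populate doc.document_id before referencing it

    session.add(
        AdminAnnotation(
            document_id=doc.document_id,
            class_id=0,
            class_name=FIELD_CLASSES[0],  # "invoice_number"
            x_center=0.5, y_center=0.1, width=0.2, height=0.03,  # normalized box
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=24,  # pixel box
        )
    )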
packages/inference/inference/data/async_request_db.py | 374 lines (new file)
@@ -0,0 +1,374 @@
"""
Async Request Database Operations

Database interface for async invoice processing requests using SQLModel.
"""

import logging
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID

from sqlalchemy import func, text
from sqlmodel import Session, select

from inference.data.database import get_session_context, create_db_and_tables, close_engine
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent

logger = logging.getLogger(__name__)


# Legacy dataclasses for backward compatibility
from dataclasses import dataclass


@dataclass(frozen=True)
class ApiKeyConfig:
    """API key configuration and limits (legacy compatibility)."""

    api_key: str
    name: str
    is_active: bool
    requests_per_minute: int
    max_concurrent_jobs: int
    max_file_size_mb: int


class AsyncRequestDB:
    """Database interface for async processing requests using SQLModel."""

    def __init__(self, connection_string: str | None = None) -> None:
        # connection_string is kept for backward compatibility but ignored
        # SQLModel uses the global engine from database.py
        self._initialized = False

    def connect(self):
        """Legacy method - returns self for compatibility."""
        return self

    def close(self) -> None:
        """Close database connections."""
        close_engine()

    def __enter__(self) -> "AsyncRequestDB":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        pass  # Sessions are managed per-operation

    def create_tables(self) -> None:
        """Create async processing tables if they don't exist."""
        create_db_and_tables()
        self._initialized = True

    # ==========================================================================
    # API Key Operations
    # ==========================================================================

    def is_valid_api_key(self, api_key: str) -> bool:
        """Check if API key exists and is active."""
        with get_session_context() as session:
            result = session.get(ApiKey, api_key)
            return result is not None and result.is_active is True

    def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
        """Get API key configuration and limits."""
        with get_session_context() as session:
            result = session.get(ApiKey, api_key)
            if result is None:
                return None
            return ApiKeyConfig(
                api_key=result.api_key,
                name=result.name,
                is_active=result.is_active,
                requests_per_minute=result.requests_per_minute,
                max_concurrent_jobs=result.max_concurrent_jobs,
                max_file_size_mb=result.max_file_size_mb,
            )

    def create_api_key(
        self,
        api_key: str,
        name: str,
        requests_per_minute: int = 10,
        max_concurrent_jobs: int = 3,
        max_file_size_mb: int = 50,
    ) -> None:
        """Create a new API key."""
        with get_session_context() as session:
            existing = session.get(ApiKey, api_key)
            if existing:
                existing.name = name
                existing.requests_per_minute = requests_per_minute
                existing.max_concurrent_jobs = max_concurrent_jobs
                existing.max_file_size_mb = max_file_size_mb
                session.add(existing)
            else:
                new_key = ApiKey(
                    api_key=api_key,
                    name=name,
                    requests_per_minute=requests_per_minute,
                    max_concurrent_jobs=max_concurrent_jobs,
                    max_file_size_mb=max_file_size_mb,
                )
                session.add(new_key)

    def update_api_key_usage(self, api_key: str) -> None:
        """Update API key last used timestamp and increment total requests."""
        with get_session_context() as session:
            key = session.get(ApiKey, api_key)
            if key:
                key.last_used_at = datetime.utcnow()
                key.total_requests += 1
                session.add(key)

    # ==========================================================================
    # Async Request Operations
    # ==========================================================================

    def create_request(
        self,
        api_key: str,
        filename: str,
        file_size: int,
        content_type: str,
        expires_at: datetime,
        request_id: str | None = None,
    ) -> str:
        """Create a new async request."""
        with get_session_context() as session:
            request = AsyncRequest(
                api_key=api_key,
                filename=filename,
                file_size=file_size,
                content_type=content_type,
                expires_at=expires_at,
            )
            if request_id:
                request.request_id = UUID(request_id)
            session.add(request)
            session.flush()  # To get the generated ID
            return str(request.request_id)

    def get_request(self, request_id: str) -> AsyncRequest | None:
        """Get a single async request by ID."""
        with get_session_context() as session:
            result = session.get(AsyncRequest, UUID(request_id))
            if result:
                # Detach from session for use outside context
                session.expunge(result)
            return result

    def get_request_by_api_key(
        self,
        request_id: str,
        api_key: str,
    ) -> AsyncRequest | None:
        """Get a request only if it belongs to the given API key."""
        with get_session_context() as session:
            statement = select(AsyncRequest).where(
                AsyncRequest.request_id == UUID(request_id),
                AsyncRequest.api_key == api_key,
            )
            result = session.exec(statement).first()
            if result:
                session.expunge(result)
            return result

    def update_status(
        self,
        request_id: str,
        status: str,
        error_message: str | None = None,
        increment_retry: bool = False,
    ) -> None:
        """Update request status."""
        with get_session_context() as session:
            request = session.get(AsyncRequest, UUID(request_id))
            if request:
                request.status = status
                if status == "processing":
                    request.started_at = datetime.utcnow()
                if error_message is not None:
                    request.error_message = error_message
                if increment_retry:
                    request.retry_count += 1
                session.add(request)

    def complete_request(
        self,
        request_id: str,
        document_id: str,
        result: dict[str, Any],
        processing_time_ms: float,
        visualization_path: str | None = None,
    ) -> None:
        """Mark request as completed with result."""
        with get_session_context() as session:
            request = session.get(AsyncRequest, UUID(request_id))
            if request:
                request.status = "completed"
                request.document_id = document_id
                request.result = result
                request.processing_time_ms = processing_time_ms
                request.visualization_path = visualization_path
                request.completed_at = datetime.utcnow()
                session.add(request)

    def get_requests_by_api_key(
        self,
        api_key: str,
        status: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[AsyncRequest], int]:
        """Get paginated requests for an API key."""
        with get_session_context() as session:
            # Count query
            count_stmt = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.api_key == api_key
            )
            if status:
                count_stmt = count_stmt.where(AsyncRequest.status == status)
            total = session.exec(count_stmt).one()

            # Fetch query
            statement = select(AsyncRequest).where(
                AsyncRequest.api_key == api_key
            )
            if status:
                statement = statement.where(AsyncRequest.status == status)
            statement = statement.order_by(AsyncRequest.created_at.desc())
            statement = statement.offset(offset).limit(limit)

            results = session.exec(statement).all()
            # Detach results from session
            for r in results:
                session.expunge(r)
            return list(results), total

    def count_active_jobs(self, api_key: str) -> int:
        """Count active (pending + processing) jobs for an API key."""
        with get_session_context() as session:
            statement = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.api_key == api_key,
                AsyncRequest.status.in_(["pending", "processing"]),
            )
            return session.exec(statement).one()

    def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
        """Get pending requests ordered by creation time."""
        with get_session_context() as session:
            statement = select(AsyncRequest).where(
                AsyncRequest.status == "pending"
            ).order_by(AsyncRequest.created_at).limit(limit)
            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def get_queue_position(self, request_id: str) -> int | None:
        """Get position of a request in the pending queue."""
        with get_session_context() as session:
            # Get the request's created_at
            request = session.get(AsyncRequest, UUID(request_id))
            if not request:
                return None

            # Count pending requests created before this one
            statement = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.status == "pending",
                AsyncRequest.created_at < request.created_at,
            )
            count = session.exec(statement).one()
            return count + 1  # 1-based position

    # ==========================================================================
    # Rate Limit Operations
    # ==========================================================================

    def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
        """Record a rate limit event."""
        with get_session_context() as session:
            event = RateLimitEvent(
                api_key=api_key,
                event_type=event_type,
            )
            session.add(event)

    def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
        """Count requests in the last N seconds."""
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(seconds=seconds)
            statement = select(func.count()).select_from(RateLimitEvent).where(
                RateLimitEvent.api_key == api_key,
                RateLimitEvent.event_type == "request",
                RateLimitEvent.created_at > cutoff,
            )
            return session.exec(statement).one()

    # ==========================================================================
    # Cleanup Operations
    # ==========================================================================

    def delete_expired_requests(self) -> int:
        """Delete requests that have expired. Returns count of deleted rows."""
        with get_session_context() as session:
            now = datetime.utcnow()
            statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
            expired = session.exec(statement).all()
            count = len(expired)
            for request in expired:
                session.delete(request)
            logger.info(f"Deleted {count} expired async requests")
            return count

    def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
        """Delete rate limit events older than N hours."""
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(hours=hours)
            statement = select(RateLimitEvent).where(
                RateLimitEvent.created_at < cutoff
            )
            old_events = session.exec(statement).all()
            count = len(old_events)
            for event in old_events:
                session.delete(event)
            return count

    def reset_stale_processing_requests(
        self,
        stale_minutes: int = 10,
        max_retries: int = 3,
    ) -> int:
        """
        Reset requests stuck in 'processing' status.

        Requests that have been processing for more than stale_minutes
        are considered stale. They are either reset to 'pending' (if under
        max_retries) or set to 'failed'.
        """
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
            reset_count = 0

            # Find stale processing requests
            statement = select(AsyncRequest).where(
                AsyncRequest.status == "processing",
                AsyncRequest.started_at < cutoff,
            )
            stale_requests = session.exec(statement).all()

            for request in stale_requests:
                if request.retry_count < max_retries:
                    request.status = "pending"
                    request.started_at = None
                else:
                    request.status = "failed"
                    request.error_message = "Processing timeout after max retries"
                session.add(request)
                reset_count += 1

            if reset_count > 0:
                logger.warning(f"Reset {reset_count} stale processing requests")
            return reset_count
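Editor's note: a short, hedged sketch of the request lifecycle this class supports, using only the methods defined above; the API key name, document ID, and result payload are invented for illustration.

from datetime import datetime, timedelta

from inference.data.async_request_db import AsyncRequestDB

db = AsyncRequestDB()
db.create_tables()

db.create_api_key(api_key="demo-key", name="Demo client")  # placeholder key
request_id = db.create_request(
    api_key="demo-key",
    filename="invoice.pdf",
    file_size=2048,
    content_type="application/pdf",
    expires_at=datetime.utcnow() + timedelta(hours=24),
)
db.update_status(request_id, "processing")
db.complete_request(
    request_id,
    document_id="doc-123",               # placeholder document id
    result={"invoice_number": "12345"},  # placeholder extraction result
    processing_time_ms=850.0,
)
print(db.get_request(request_id).status)  # expected: "completed" (get_request may return None if the ID is unknown)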
packages/inference/inference/data/database.py | 102 lines (new file)
@@ -0,0 +1,102 @@
"""
Database Engine and Session Management

Provides SQLModel database engine and session handling.
"""

import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Generator

from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine

import sys
from shared.config import get_db_connection_string

logger = logging.getLogger(__name__)

# Global engine instance
_engine = None


def get_engine():
    """Get or create the database engine."""
    global _engine
    if _engine is None:
        connection_string = get_db_connection_string()
        # Convert psycopg2 format to SQLAlchemy format
        if connection_string.startswith("postgresql://"):
            # Already in correct format
            pass
        elif "host=" in connection_string:
            # Convert DSN format to URL format
            parts = dict(item.split("=") for item in connection_string.split())
            connection_string = (
                f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
                f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
                f"/{parts.get('dbname', 'docmaster')}"
            )

        _engine = create_engine(
            connection_string,
            echo=False,  # Set to True for SQL debugging
            pool_pre_ping=True,  # Verify connections before use
            pool_size=5,
            max_overflow=10,
        )
    return _engine


def create_db_and_tables() -> None:
    """Create all database tables."""
    from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
    from inference.data.admin_models import (  # noqa: F401
        AdminToken,
        AdminDocument,
        AdminAnnotation,
        TrainingTask,
        TrainingLog,
    )

    engine = get_engine()
    SQLModel.metadata.create_all(engine)
    logger.info("Database tables created/verified")


def get_session() -> Session:
    """Get a new database session."""
    engine = get_engine()
    return Session(engine)


@contextmanager
def get_session_context() -> Generator[Session, None, None]:
    """Context manager for database sessions with auto-commit/rollback."""
    session = get_session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def close_engine() -> None:
    """Close the database engine and release connections."""
    global _engine
    if _engine is not None:
        _engine.dispose()
        _engine = None
        logger.info("Database engine closed")


def execute_raw_sql(sql: str) -> None:
    """Execute raw SQL (for migrations)."""
    engine = get_engine()
    with engine.connect() as conn:
        conn.execute(text(sql))
        conn.commit()
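Editor's note: for context, a minimal sketch of the intended session usage, assuming the shared config points at a reachable PostgreSQL instance; the API key values are placeholders.

from sqlmodel import select

from inference.data.database import create_db_and_tables, get_session_context
from inference.data.models import ApiKey

create_db_and_tables()

with get_session_context() as session:
    # Commits on success, rolls back on any exception, always closes.
    session.add(ApiKey(api_key="demo-key", name="Demo client"))  # placeholder values

with get_session_context() as session:
    keys = session.exec(select(ApiKey)).all()
    print([k.name for k in keys])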
packages/inference/inference/data/models.py | 95 lines (new file)
@@ -0,0 +1,95 @@
"""
SQLModel Database Models

Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
"""

from datetime import datetime
from typing import Any
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel, Column, JSON


class ApiKey(SQLModel, table=True):
    """API key configuration and limits."""

    __tablename__ = "api_keys"

    api_key: str = Field(primary_key=True, max_length=255)
    name: str = Field(max_length=255)
    is_active: bool = Field(default=True)
    requests_per_minute: int = Field(default=10)
    max_concurrent_jobs: int = Field(default=3)
    max_file_size_mb: int = Field(default=50)
    total_requests: int = Field(default=0)
    total_processed: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    last_used_at: datetime | None = Field(default=None)


class AsyncRequest(SQLModel, table=True):
    """Async request record."""

    __tablename__ = "async_requests"

    request_id: UUID = Field(default_factory=uuid4, primary_key=True)
    api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
    status: str = Field(default="pending", max_length=20, index=True)
    filename: str = Field(max_length=255)
    file_size: int
    content_type: str = Field(max_length=100)
    document_id: str | None = Field(default=None, max_length=100)
    error_message: str | None = Field(default=None)
    retry_count: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: datetime | None = Field(default=None)
    completed_at: datetime | None = Field(default=None)
    expires_at: datetime = Field(index=True)
    result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    processing_time_ms: float | None = Field(default=None)
    visualization_path: str | None = Field(default=None, max_length=255)


class RateLimitEvent(SQLModel, table=True):
    """Rate limit event record."""

    __tablename__ = "rate_limit_events"

    id: int | None = Field(default=None, primary_key=True)
    api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
    event_type: str = Field(max_length=50)
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)


# Read-only models for responses (without table=True)
class ApiKeyRead(SQLModel):
    """API key response model (read-only)."""

    api_key: str
    name: str
    is_active: bool
    requests_per_minute: int
    max_concurrent_jobs: int
    max_file_size_mb: int


class AsyncRequestRead(SQLModel):
    """Async request response model (read-only)."""

    request_id: UUID
    api_key: str
    status: str
    filename: str
    file_size: int
    content_type: str
    document_id: str | None
    error_message: str | None
    retry_count: int
    created_at: datetime
    started_at: datetime | None
    completed_at: datetime | None
    expires_at: datetime
    result: dict[str, Any] | None
    processing_time_ms: float | None
    visualization_path: str | None
packages/inference/inference/pipeline/__init__.py | 5 lines (new file)
@@ -0,0 +1,5 @@
from .pipeline import InferencePipeline, InferenceResult
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor

__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']
packages/inference/inference/pipeline/constants.py | 101 lines (new file)
@@ -0,0 +1,101 @@
"""
Inference Configuration Constants

Centralized configuration values for the inference pipeline.
Extracted from hardcoded values across multiple modules for easier maintenance.
"""

# ============================================================================
# Detection & Model Configuration
# ============================================================================

# YOLO Detection
DEFAULT_CONFIDENCE_THRESHOLD = 0.5  # Default confidence threshold for YOLO detection
DEFAULT_IOU_THRESHOLD = 0.45  # Default IoU threshold for NMS (Non-Maximum Suppression)

# ============================================================================
# Image Processing Configuration
# ============================================================================

# DPI (Dots Per Inch) for PDF rendering
DEFAULT_DPI = 300  # Standard DPI for PDF to image conversion
DPI_TO_POINTS_SCALE = 72  # PDF points per inch (used for bbox conversion)

# ============================================================================
# Customer Number Parser Configuration
# ============================================================================

# Pattern confidence scores (higher = more confident)
CUSTOMER_NUMBER_CONFIDENCE = {
    'labeled': 0.98,       # Explicit label (e.g., "Kundnummer: ABC 123-X")
    'dash_format': 0.95,   # Standard format with dash (e.g., "JTY 576-3")
    'no_dash': 0.90,       # Format without dash (e.g., "Dwq 211X")
    'compact': 0.75,       # Compact format (e.g., "JTY5763")
    'generic_base': 0.5,   # Base score for generic alphanumeric pattern
}

# Bonus scores for generic pattern matching
CUSTOMER_NUMBER_BONUS = {
    'has_dash': 0.2,         # Bonus if contains dash
    'typical_format': 0.25,  # Bonus for format XXX NNN-X
    'medium_length': 0.1,    # Bonus for length 6-12 characters
}

# Customer number length constraints
CUSTOMER_NUMBER_LENGTH = {
    'min': 6,   # Minimum length for medium length bonus
    'max': 12,  # Maximum length for medium length bonus
}

# ============================================================================
# Field Extraction Confidence Scores
# ============================================================================

# Confidence multipliers and base scores
FIELD_CONFIDENCE = {
    'pdf_text': 1.0,            # PDF text extraction (always accurate)
    'payment_line_high': 0.95,  # Payment line parsed successfully
    'regex_fallback': 0.5,      # Regex-based fallback extraction
    'ocr_penalty': 0.5,         # Penalty multiplier when OCR fails
}

# ============================================================================
# Payment Line Validation
# ============================================================================

# Account number length thresholds for type detection
ACCOUNT_TYPE_THRESHOLD = {
    'bankgiro_min_length': 7,  # Minimum digits for Bankgiro (7-8 digits)
    'plusgiro_max_length': 6,  # Maximum digits for Plusgiro (typically fewer)
}

# ============================================================================
# OCR Configuration
# ============================================================================

# Minimum OCR reference number length
MIN_OCR_LENGTH = 5  # Minimum length for valid OCR number

# ============================================================================
# Pattern Matching
# ============================================================================

# Swedish postal code pattern (to exclude from customer numbers)
SWEDISH_POSTAL_CODE_PATTERN = r'^SE\s+\d{3}\s*\d{2}'

# ============================================================================
# Usage Notes
# ============================================================================
"""
These constants can be overridden at runtime by passing parameters to
constructors or methods. The values here serve as sensible defaults
based on Swedish invoice processing requirements.

Example:
    from inference.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD

    detector = YOLODetector(
        model_path="model.pt",
        confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD  # or custom value
    )
"""
packages/inference/inference/pipeline/customer_number_parser.py | 390 lines (new file)
@@ -0,0 +1,390 @@
"""
Swedish Customer Number Parser

Handles extraction and normalization of Swedish customer numbers.
Uses Strategy Pattern with multiple matching patterns.

Common Swedish customer number formats:
- JTY 576-3
- EMM 256-6
- DWQ 211-X
- FFL 019N
"""

import re
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List

from shared.exceptions import CustomerNumberParseError


@dataclass
class CustomerNumberMatch:
    """Customer number match result."""

    value: str
    """The normalized customer number"""

    pattern_name: str
    """Name of the pattern that matched"""

    confidence: float
    """Confidence score (0.0 to 1.0)"""

    raw_text: str
    """Original text that was matched"""

    position: int = 0
    """Position in text where match was found"""


class CustomerNumberPattern(ABC):
    """Abstract base for customer number patterns."""

    @abstractmethod
    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        """
        Try to match pattern in text.

        Args:
            text: Text to search for customer number

        Returns:
            CustomerNumberMatch if found, None otherwise
        """
        pass

    @abstractmethod
    def format(self, match: re.Match) -> str:
        """
        Format matched groups to standard format.

        Args:
            match: Regex match object

        Returns:
            Formatted customer number string
        """
        pass


class DashFormatPattern(CustomerNumberPattern):
    """
    Pattern: ABC 123-X (with dash)

    Examples: JTY 576-3, EMM 256-6, DWQ 211-X
    """

    PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')

    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        """Match customer number with dash format."""
        match = self.PATTERN.search(text)
        if not match:
            return None

        # Check if it's not a postal code
        full_match = match.group(0)
        if self._is_postal_code(full_match):
            return None

        formatted = self.format(match)
        return CustomerNumberMatch(
            value=formatted,
            pattern_name="DashFormat",
            confidence=0.95,
            raw_text=full_match,
            position=match.start()
        )

    def format(self, match: re.Match) -> str:
        """Format to standard ABC 123-X format."""
        prefix = match.group(1).upper()
        number = match.group(2)
        suffix = match.group(3).upper()
        return f"{prefix} {number}-{suffix}"

    def _is_postal_code(self, text: str) -> bool:
        """Check if text looks like Swedish postal code."""
        # SE 106 43, SE10643, etc.
        return bool(
            text.upper().startswith('SE ') and
            re.match(r'^SE\s+\d{3}\s*\d{2}', text, re.IGNORECASE)
        )


class NoDashFormatPattern(CustomerNumberPattern):
    """
    Pattern: ABC 123X (no dash)

    Examples: Dwq 211X, FFL 019N
    Converts to: DWQ 211-X, FFL 019-N
    """

    PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')

    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        """Match customer number without dash."""
        match = self.PATTERN.search(text)
        if not match:
            return None

        # Exclude postal codes
        full_match = match.group(0)
        if self._is_postal_code(full_match):
            return None

        formatted = self.format(match)
        return CustomerNumberMatch(
            value=formatted,
            pattern_name="NoDashFormat",
            confidence=0.90,
            raw_text=full_match,
            position=match.start()
        )

    def format(self, match: re.Match) -> str:
        """Format to standard ABC 123-X format (add dash)."""
        prefix = match.group(1).upper()
        number = match.group(2)
        suffix = match.group(3).upper()
        return f"{prefix} {number}-{suffix}"

    def _is_postal_code(self, text: str) -> bool:
        """Check if text looks like Swedish postal code."""
        return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))
|
||||
|
||||
|
||||
class CompactFormatPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Pattern: ABC123X (compact, no spaces)
|
||||
|
||||
Examples: JTY5763, FFL019N
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(r'\b([A-Z]{2,4})(\d{3,6})([A-Z]?)\b')
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match compact customer number format."""
|
||||
upper_text = text.upper()
|
||||
match = self.PATTERN.search(upper_text)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
# Filter out SE postal codes
|
||||
if match.group(1) == 'SE':
|
||||
return None
|
||||
|
||||
formatted = self.format(match)
|
||||
return CustomerNumberMatch(
|
||||
value=formatted,
|
||||
pattern_name="CompactFormat",
|
||||
confidence=0.75,
|
||||
raw_text=match.group(0),
|
||||
position=match.start()
|
||||
)
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Format to ABC123X or ABC123-X format."""
|
||||
prefix = match.group(1).upper()
|
||||
number = match.group(2)
|
||||
suffix = match.group(3).upper()
|
||||
|
||||
if suffix:
|
||||
return f"{prefix} {number}-{suffix}"
|
||||
else:
|
||||
return f"{prefix}{number}"
|
||||
|
||||
|
||||
class GenericAlphanumericPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Generic pattern: Letters + numbers + optional dash/letter
|
||||
|
||||
Examples: EMM 256-6, ABC 123, FFL 019
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]?\d{0,2}[A-Z]?)\b')
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match generic alphanumeric pattern."""
|
||||
upper_text = text.upper()
|
||||
|
||||
all_matches = []
|
||||
for match in self.PATTERN.finditer(upper_text):
|
||||
matched_text = match.group(1)
|
||||
|
||||
# Filter out pure numbers
|
||||
if re.match(r'^\d+$', matched_text):
|
||||
continue
|
||||
|
||||
# Filter out Swedish postal codes
|
||||
if re.match(r'^SE[\s\-]*\d', matched_text):
|
||||
continue
|
||||
|
||||
# Filter out single letter + digit + space + digit (V4 2)
|
||||
if re.match(r'^[A-Z]\d\s+\d$', matched_text):
|
||||
continue
|
||||
|
||||
# Calculate confidence based on characteristics
|
||||
confidence = self._calculate_confidence(matched_text)
|
||||
|
||||
all_matches.append((confidence, matched_text, match.start()))
|
||||
|
||||
if all_matches:
|
||||
# Return highest confidence match
|
||||
best = max(all_matches, key=lambda x: x[0])
|
||||
return CustomerNumberMatch(
|
||||
value=best[1].strip(),
|
||||
pattern_name="GenericAlphanumeric",
|
||||
confidence=best[0],
|
||||
raw_text=best[1],
|
||||
position=best[2]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Return matched text as-is (already uppercase)."""
|
||||
return match.group(1).strip()
|
||||
|
||||
def _calculate_confidence(self, text: str) -> float:
|
||||
"""Calculate confidence score based on text characteristics."""
|
||||
# Require letters AND digits
|
||||
has_letters = bool(re.search(r'[A-Z]', text, re.IGNORECASE))
|
||||
has_digits = bool(re.search(r'\d', text))
|
||||
|
||||
if not (has_letters and has_digits):
|
||||
return 0.0 # Not a valid customer number
|
||||
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Bonus for containing dash
|
||||
if '-' in text:
|
||||
score += 0.2
|
||||
|
||||
# Bonus for typical format XXX NNN-X
|
||||
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', text):
|
||||
score += 0.25
|
||||
|
||||
# Bonus for medium length
|
||||
if 6 <= len(text) <= 12:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
|
||||
class LabeledPattern(CustomerNumberPattern):
|
||||
"""
|
||||
Pattern: Explicit label + customer number
|
||||
|
||||
Examples:
|
||||
- "Kundnummer: JTY 576-3"
|
||||
- "Customer No: EMM 256-6"
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(
|
||||
r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)'
|
||||
r'\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
|
||||
re.IGNORECASE
|
||||
)
|
||||
|
||||
def match(self, text: str) -> Optional[CustomerNumberMatch]:
|
||||
"""Match customer number with explicit label."""
|
||||
match = self.PATTERN.search(text)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
extracted = match.group(1).strip()
|
||||
# Remove trailing punctuation
|
||||
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
|
||||
|
||||
if extracted and len(extracted) >= 2:
|
||||
return CustomerNumberMatch(
|
||||
value=extracted.upper(),
|
||||
pattern_name="Labeled",
|
||||
confidence=0.98, # Very high confidence when labeled
|
||||
raw_text=match.group(0),
|
||||
position=match.start()
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def format(self, match: re.Match) -> str:
|
||||
"""Return matched customer number."""
|
||||
extracted = match.group(1).strip()
|
||||
return re.sub(r'[\s\.\,\:]+$', '', extracted).upper()
|
||||
|
||||
|
||||
class CustomerNumberParser:
|
||||
"""Parser for Swedish customer numbers."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize parser with patterns ordered by specificity."""
|
||||
self.patterns: List[CustomerNumberPattern] = [
|
||||
LabeledPattern(), # Highest priority - explicit label
|
||||
DashFormatPattern(), # Standard format with dash
|
||||
NoDashFormatPattern(), # Standard format without dash
|
||||
CompactFormatPattern(), # Compact format
|
||||
GenericAlphanumericPattern(), # Fallback generic pattern
|
||||
]
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
|
||||
"""
|
||||
Parse customer number from text.
|
||||
|
||||
Args:
|
||||
text: Text to search for customer number
|
||||
|
||||
Returns:
|
||||
Tuple of (customer_number, is_valid, error_message)
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return None, False, "Empty text"
|
||||
|
||||
text = text.strip()
|
||||
|
||||
# Try each pattern
|
||||
all_matches: List[CustomerNumberMatch] = []
|
||||
for pattern in self.patterns:
|
||||
match = pattern.match(text)
|
||||
if match:
|
||||
all_matches.append(match)
|
||||
|
||||
# No matches
|
||||
if not all_matches:
|
||||
return None, False, "No customer number found"
|
||||
|
||||
# Return highest confidence match
|
||||
best_match = max(all_matches, key=lambda m: (m.confidence, m.position))
|
||||
self.logger.debug(
|
||||
f"Customer number matched: {best_match.value} "
|
||||
f"(pattern: {best_match.pattern_name}, confidence: {best_match.confidence:.2f})"
|
||||
)
|
||||
return best_match.value, True, None
|
||||
|
||||
def parse_all(self, text: str) -> List[CustomerNumberMatch]:
|
||||
"""
|
||||
Find all customer numbers in text.
|
||||
|
||||
Useful for cases with multiple potential matches.
|
||||
|
||||
Args:
|
||||
text: Text to search
|
||||
|
||||
Returns:
|
||||
List of CustomerNumberMatch sorted by confidence (descending)
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
all_matches: List[CustomerNumberMatch] = []
|
||||
for pattern in self.patterns:
|
||||
match = pattern.match(text)
|
||||
if match:
|
||||
all_matches.append(match)
|
||||
|
||||
# Sort by confidence (highest first), then by position (later first)
|
||||
return sorted(all_matches, key=lambda m: (m.confidence, m.position), reverse=True)
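# A brief usage sketch of the parser. The sample strings below are invented
# for illustration; the expected values follow from the patterns defined above.
if __name__ == "__main__":
    parser = CustomerNumberParser()

    value, ok, error = parser.parse("Kundnummer: JTY 576-3")
    print(value, ok, error)  # JTY 576-3 True None

    # List every candidate match, highest confidence first
    for m in parser.parse_all("Ert kundnr EMM 256-6"):
        print(m.pattern_name, m.value, f"{m.confidence:.2f}")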
|
||||
1183
packages/inference/inference/pipeline/field_extractor.py
Normal file
File diff suppressed because it is too large
261
packages/inference/inference/pipeline/payment_line_parser.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Swedish Payment Line Parser
|
||||
|
||||
Handles parsing and validation of Swedish machine-readable payment lines.
|
||||
Unifies payment line parsing logic that was previously duplicated across multiple modules.
|
||||
|
||||
Standard Swedish payment line format:
|
||||
# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
|
||||
Example:
|
||||
# 94228110015950070 # 15658 00 8 > 48666036#14#
|
||||
|
||||
This parser handles common OCR errors:
|
||||
- Spaces in numbers: "12 0 0" → "1200"
|
||||
- Missing symbols: Missing ">"
|
||||
- Spaces in check digits: "#41 #" → "#41#"
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from shared.exceptions import PaymentLineParseError
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaymentLineData:
|
||||
"""Parsed payment line data."""
|
||||
|
||||
ocr_number: str
|
||||
"""OCR reference number (payment reference)"""
|
||||
|
||||
amount: Optional[str] = None
|
||||
"""Amount in format KRONOR.ÖRE (e.g., '1200.00'), None if not present"""
|
||||
|
||||
account_number: Optional[str] = None
|
||||
"""Bankgiro or Plusgiro account number"""
|
||||
|
||||
record_type: Optional[str] = None
|
||||
"""Record type digit (usually '5' or '8' or '9')"""
|
||||
|
||||
check_digits: Optional[str] = None
|
||||
"""Check digits for account validation"""
|
||||
|
||||
raw_text: str = ""
|
||||
"""Original raw text that was parsed"""
|
||||
|
||||
is_valid: bool = True
|
||||
"""Whether parsing was successful"""
|
||||
|
||||
error: Optional[str] = None
|
||||
"""Error message if parsing failed"""
|
||||
|
||||
parse_method: str = "unknown"
|
||||
"""Which parsing pattern was used (for debugging)"""
|
||||
|
||||
|
||||
class PaymentLineParser:
|
||||
"""Parser for Swedish payment lines with OCR error handling."""
|
||||
|
||||
# Pattern with amount: # OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#
|
||||
FULL_PATTERN = re.compile(
|
||||
r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
# Pattern without amount: # OCR # > ACCOUNT#CHECK#
|
||||
NO_AMOUNT_PATTERN = re.compile(
|
||||
r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
# Alternative pattern: look for OCR > ACCOUNT# pattern
|
||||
ALT_PATTERN = re.compile(
|
||||
r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
# Account only pattern: > ACCOUNT#CHECK#
|
||||
ACCOUNT_ONLY_PATTERN = re.compile(
|
||||
r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize parser with logger."""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def parse(self, text: str) -> PaymentLineData:
|
||||
"""
|
||||
Parse payment line text.
|
||||
|
||||
Handles common OCR errors:
|
||||
- Spaces in numbers: "12 0 0" → "1200"
|
||||
- Missing symbols: Missing ">"
|
||||
- Spaces in check digits: "#41 #" → "#41#"
|
||||
|
||||
Args:
|
||||
text: Raw payment line text from OCR
|
||||
|
||||
Returns:
|
||||
PaymentLineData with parsed fields or error information
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return PaymentLineData(
|
||||
ocr_number="",
|
||||
raw_text=text,
|
||||
is_valid=False,
|
||||
error="Empty payment line text",
|
||||
parse_method="none"
|
||||
)
|
||||
|
||||
text = text.strip()
|
||||
|
||||
# Try full pattern with amount
|
||||
match = self.FULL_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_full_match(match, text)
|
||||
|
||||
# Try pattern without amount
|
||||
match = self.NO_AMOUNT_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_no_amount_match(match, text)
|
||||
|
||||
# Try alternative pattern
|
||||
match = self.ALT_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_alt_match(match, text)
|
||||
|
||||
# Try account only pattern
|
||||
match = self.ACCOUNT_ONLY_PATTERN.search(text)
|
||||
if match:
|
||||
return self._parse_account_only_match(match, text)
|
||||
|
||||
# No match - return error
|
||||
return PaymentLineData(
|
||||
ocr_number="",
|
||||
raw_text=text,
|
||||
is_valid=False,
|
||||
error="No valid payment line format found",
|
||||
parse_method="none"
|
||||
)
|
||||
|
||||
def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse full pattern match (with amount)."""
|
||||
ocr = self._clean_digits(match.group(1))
|
||||
kronor = self._clean_digits(match.group(2))
|
||||
ore = match.group(3)
|
||||
record_type = match.group(4)
|
||||
account = self._clean_digits(match.group(5))
|
||||
check_digits = match.group(6)
|
||||
|
||||
amount = f"{kronor}.{ore}"
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number=ocr,
|
||||
amount=amount,
|
||||
account_number=account,
|
||||
record_type=record_type,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error=None,
|
||||
parse_method="full"
|
||||
)
|
||||
|
||||
def _parse_no_amount_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse pattern match without amount."""
|
||||
ocr = self._clean_digits(match.group(1))
|
||||
account = self._clean_digits(match.group(2))
|
||||
check_digits = match.group(3)
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number=ocr,
|
||||
amount=None,
|
||||
account_number=account,
|
||||
record_type=None,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error=None,
|
||||
parse_method="no_amount"
|
||||
)
|
||||
|
||||
def _parse_alt_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse alternative pattern match."""
|
||||
ocr = self._clean_digits(match.group(1))
|
||||
account = self._clean_digits(match.group(2))
|
||||
check_digits = match.group(3)
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number=ocr,
|
||||
amount=None,
|
||||
account_number=account,
|
||||
record_type=None,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error=None,
|
||||
parse_method="alternative"
|
||||
)
|
||||
|
||||
def _parse_account_only_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
|
||||
"""Parse account-only pattern match."""
|
||||
account = self._clean_digits(match.group(1))
|
||||
check_digits = match.group(2)
|
||||
|
||||
return PaymentLineData(
|
||||
ocr_number="",
|
||||
amount=None,
|
||||
account_number=account,
|
||||
record_type=None,
|
||||
check_digits=check_digits,
|
||||
raw_text=raw_text,
|
||||
is_valid=True,
|
||||
error="Partial payment line (account only)",
|
||||
parse_method="account_only"
|
||||
)
|
||||
|
||||
def _clean_digits(self, text: str) -> str:
|
||||
"""Remove spaces from digit string (OCR error correction)."""
|
||||
return text.replace(' ', '')
|
||||
|
||||
def format_machine_readable(self, data: PaymentLineData) -> str:
|
||||
"""
|
||||
Format parsed data back to machine-readable format.
|
||||
|
||||
Returns:
|
||||
Formatted string in standard Swedish payment line format
|
||||
"""
|
||||
if not data.is_valid:
|
||||
return data.raw_text
|
||||
|
||||
# Full format with amount
|
||||
if data.amount and data.record_type:
|
||||
kronor, ore = data.amount.split('.')
|
||||
return (
|
||||
f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
|
||||
f"{data.account_number}#{data.check_digits}#"
|
||||
)
|
||||
|
||||
# Format without amount
|
||||
if data.ocr_number and data.account_number:
|
||||
return f"# {data.ocr_number} # > {data.account_number}#{data.check_digits}#"
|
||||
|
||||
# Account only
|
||||
if data.account_number:
|
||||
return f"> {data.account_number}#{data.check_digits}#"
|
||||
|
||||
# Fallback
|
||||
return data.raw_text
|
||||
|
||||
def format_for_field_extractor(self, data: PaymentLineData) -> tuple[Optional[str], bool, Optional[str]]:
|
||||
"""
|
||||
Format parsed data for FieldExtractor compatibility.
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_text, is_valid, error_message) matching FieldExtractor's API
|
||||
"""
|
||||
if not data.is_valid:
|
||||
return None, False, data.error
|
||||
|
||||
formatted = self.format_machine_readable(data)
|
||||
return formatted, True, data.error
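# A brief usage sketch; the sample line is the one from the module docstring.
if __name__ == "__main__":
    parser = PaymentLineParser()
    data = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")
    print(data.ocr_number)      # 94228110015950070
    print(data.amount)          # 15658.00
    print(data.account_number)  # 48666036
    print(data.parse_method)    # full
    print(parser.format_machine_readable(data))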
|
||||
498
packages/inference/inference/pipeline/pipeline.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
Inference Pipeline
|
||||
|
||||
Complete pipeline for extracting invoice data from PDFs.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import time
|
||||
import re
|
||||
|
||||
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
|
||||
from .field_extractor import FieldExtractor, ExtractedField
|
||||
from .payment_line_parser import PaymentLineParser
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrossValidationResult:
|
||||
"""Result of cross-validation between payment_line and other fields."""
|
||||
is_valid: bool = False
|
||||
ocr_match: bool | None = None # None if not comparable
|
||||
amount_match: bool | None = None
|
||||
bankgiro_match: bool | None = None
|
||||
plusgiro_match: bool | None = None
|
||||
payment_line_ocr: str | None = None
|
||||
payment_line_amount: str | None = None
|
||||
payment_line_account: str | None = None
|
||||
payment_line_account_type: str | None = None # 'bankgiro' or 'plusgiro'
|
||||
details: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
"""Result of invoice processing."""
|
||||
document_id: str | None = None
|
||||
success: bool = False
|
||||
fields: dict[str, Any] = field(default_factory=dict)
|
||||
confidence: dict[str, float] = field(default_factory=dict)
|
||||
bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict) # Field bboxes in pixels
|
||||
raw_detections: list[Detection] = field(default_factory=list)
|
||||
extracted_fields: list[ExtractedField] = field(default_factory=list)
|
||||
processing_time_ms: float = 0.0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
fallback_used: bool = False
|
||||
cross_validation: CrossValidationResult | None = None
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""Convert to JSON-serializable dictionary."""
|
||||
result = {
|
||||
'DocumentId': self.document_id,
|
||||
'InvoiceNumber': self.fields.get('InvoiceNumber'),
|
||||
'InvoiceDate': self.fields.get('InvoiceDate'),
|
||||
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
|
||||
'OCR': self.fields.get('OCR'),
|
||||
'Bankgiro': self.fields.get('Bankgiro'),
|
||||
'Plusgiro': self.fields.get('Plusgiro'),
|
||||
'Amount': self.fields.get('Amount'),
|
||||
'supplier_org_number': self.fields.get('supplier_org_number'),
|
||||
'customer_number': self.fields.get('customer_number'),
|
||||
'payment_line': self.fields.get('payment_line'),
|
||||
'confidence': self.confidence,
|
||||
'success': self.success,
|
||||
'fallback_used': self.fallback_used
|
||||
}
|
||||
# Add bboxes if present
|
||||
if self.bboxes:
|
||||
result['bboxes'] = {k: list(v) for k, v in self.bboxes.items()}
|
||||
# Add cross-validation results if present
|
||||
if self.cross_validation:
|
||||
result['cross_validation'] = {
|
||||
'is_valid': self.cross_validation.is_valid,
|
||||
'ocr_match': self.cross_validation.ocr_match,
|
||||
'amount_match': self.cross_validation.amount_match,
|
||||
'bankgiro_match': self.cross_validation.bankgiro_match,
|
||||
'plusgiro_match': self.cross_validation.plusgiro_match,
|
||||
'payment_line_ocr': self.cross_validation.payment_line_ocr,
|
||||
'payment_line_amount': self.cross_validation.payment_line_amount,
|
||||
'payment_line_account': self.cross_validation.payment_line_account,
|
||||
'payment_line_account_type': self.cross_validation.payment_line_account_type,
|
||||
'details': self.cross_validation.details,
|
||||
}
|
||||
return result
|
||||
|
||||
def get_field(self, field_name: str) -> tuple[Any, float]:
|
||||
"""Get field value and confidence."""
|
||||
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
|
||||
|
||||
|
||||
class InferencePipeline:
|
||||
"""
|
||||
Complete inference pipeline for invoice data extraction.
|
||||
|
||||
Pipeline flow:
|
||||
1. PDF -> Image rendering
|
||||
2. YOLO detection of field regions
|
||||
3. OCR extraction from detected regions
|
||||
4. Field normalization and validation
|
||||
5. Fallback to full-page OCR if YOLO fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str | Path,
|
||||
confidence_threshold: float = 0.5,
|
||||
ocr_lang: str = 'en',
|
||||
use_gpu: bool = False,
|
||||
dpi: int = 300,
|
||||
enable_fallback: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize inference pipeline.
|
||||
|
||||
Args:
|
||||
model_path: Path to trained YOLO model
|
||||
confidence_threshold: Detection confidence threshold
|
||||
ocr_lang: Language for OCR
|
||||
use_gpu: Whether to use GPU
|
||||
dpi: Resolution for PDF rendering
|
||||
enable_fallback: Enable fallback to full-page OCR
|
||||
"""
|
||||
self.detector = YOLODetector(
|
||||
model_path,
|
||||
confidence_threshold=confidence_threshold,
|
||||
device='cuda' if use_gpu else 'cpu'
|
||||
)
|
||||
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
|
||||
self.payment_line_parser = PaymentLineParser()
|
||||
self.dpi = dpi
|
||||
self.enable_fallback = enable_fallback
|
||||
|
||||
def process_pdf(
|
||||
self,
|
||||
pdf_path: str | Path,
|
||||
document_id: str | None = None
|
||||
) -> InferenceResult:
|
||||
"""
|
||||
Process a PDF and extract invoice fields.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
document_id: Optional document ID
|
||||
|
||||
Returns:
|
||||
InferenceResult with extracted fields
|
||||
"""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
result = InferenceResult(
|
||||
document_id=document_id or Path(pdf_path).stem
|
||||
)
|
||||
|
||||
try:
|
||||
all_detections = []
|
||||
all_extracted = []
|
||||
|
||||
# Process each page
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||
# Convert to numpy array
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
# Run YOLO detection
|
||||
detections = self.detector.detect(image_array, page_no=page_no)
|
||||
all_detections.extend(detections)
|
||||
|
||||
# Extract fields from detections
|
||||
for detection in detections:
|
||||
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||
all_extracted.append(extracted)
|
||||
|
||||
result.raw_detections = all_detections
|
||||
result.extracted_fields = all_extracted
|
||||
|
||||
# Merge extracted fields (prefer highest confidence)
|
||||
self._merge_fields(result)
|
||||
|
||||
# Fallback if key fields are missing
|
||||
if self.enable_fallback and self._needs_fallback(result):
|
||||
self._run_fallback(pdf_path, result)
|
||||
|
||||
result.success = len(result.fields) > 0
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _merge_fields(self, result: InferenceResult) -> None:
|
||||
"""Merge extracted fields, keeping highest confidence for each field."""
|
||||
field_candidates: dict[str, list[ExtractedField]] = {}
|
||||
|
||||
for extracted in result.extracted_fields:
|
||||
if not extracted.is_valid or not extracted.normalized_value:
|
||||
continue
|
||||
|
||||
if extracted.field_name not in field_candidates:
|
||||
field_candidates[extracted.field_name] = []
|
||||
field_candidates[extracted.field_name].append(extracted)
|
||||
|
||||
# Select best candidate for each field
|
||||
for field_name, candidates in field_candidates.items():
|
||||
best = max(candidates, key=lambda x: x.confidence)
|
||||
result.fields[field_name] = best.normalized_value
|
||||
result.confidence[field_name] = best.confidence
|
||||
# Store bbox for each field (useful for payment_line and other fields)
|
||||
result.bboxes[field_name] = best.bbox
|
||||
|
||||
# Perform cross-validation if payment_line is detected
|
||||
self._cross_validate_payment_line(result)
|
||||
|
||||
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
|
||||
"""
|
||||
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
|
||||
|
||||
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
|
||||
|
||||
Returns: (ocr, amount, account) tuple
|
||||
"""
|
||||
parsed = self.payment_line_parser.parse(payment_line)
|
||||
|
||||
if not parsed.is_valid:
|
||||
return None, None, None
|
||||
|
||||
return parsed.ocr_number, parsed.amount, parsed.account_number
|
||||
|
||||
def _cross_validate_payment_line(self, result: InferenceResult) -> None:
|
||||
"""
|
||||
Cross-validate payment_line data against other detected fields.
|
||||
Payment line values take PRIORITY over individually detected fields.
|
||||
|
||||
Swedish payment line (Betalningsrad) contains:
|
||||
- OCR reference number
|
||||
- Amount (kronor and öre)
|
||||
- Bankgiro or Plusgiro account number
|
||||
|
||||
This method:
|
||||
1. Parses payment_line to extract OCR, Amount, Account
|
||||
2. Compares with separately detected fields for validation
|
||||
3. OVERWRITES detected fields with payment_line values (payment_line is authoritative)
|
||||
"""
|
||||
payment_line = result.fields.get('payment_line')
|
||||
if not payment_line:
|
||||
return
|
||||
|
||||
cv = CrossValidationResult()
|
||||
cv.details = []
|
||||
|
||||
# Parse machine-readable payment line format
|
||||
ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
|
||||
|
||||
cv.payment_line_ocr = ocr
|
||||
cv.payment_line_amount = amount
|
||||
|
||||
# Determine account type based on digit count
|
||||
if account:
|
||||
# Bankgiro: 7-8 digits, Plusgiro: typically fewer
|
||||
if len(account) >= 7:
|
||||
cv.payment_line_account_type = 'bankgiro'
|
||||
# Format: XXX-XXXX or XXXX-XXXX
|
||||
if len(account) == 7:
|
||||
cv.payment_line_account = f"{account[:3]}-{account[3:]}"
|
||||
else:
|
||||
cv.payment_line_account = f"{account[:4]}-{account[4:]}"
|
||||
else:
|
||||
cv.payment_line_account_type = 'plusgiro'
|
||||
# Format: XXXXXXX-X
|
||||
cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
|
||||
|
||||
# Cross-validate and OVERRIDE with payment_line values
|
||||
|
||||
# OCR: payment_line takes priority
|
||||
detected_ocr = result.fields.get('OCR')
|
||||
if cv.payment_line_ocr:
|
||||
pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
|
||||
if detected_ocr:
|
||||
detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
|
||||
cv.ocr_match = pl_ocr_digits == detected_ocr_digits
|
||||
if cv.ocr_match:
|
||||
cv.details.append(f"OCR match: {cv.payment_line_ocr}")
|
||||
else:
|
||||
cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
|
||||
else:
|
||||
cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
|
||||
# OVERRIDE: use payment_line OCR
|
||||
result.fields['OCR'] = cv.payment_line_ocr
|
||||
result.confidence['OCR'] = 0.95 # High confidence for payment_line
|
||||
|
||||
# Amount: payment_line takes priority
|
||||
detected_amount = result.fields.get('Amount')
|
||||
if cv.payment_line_amount:
|
||||
if detected_amount:
|
||||
pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
|
||||
det_amount = self._normalize_amount_for_compare(str(detected_amount))
|
||||
cv.amount_match = pl_amount == det_amount
|
||||
if cv.amount_match:
|
||||
cv.details.append(f"Amount match: {cv.payment_line_amount}")
|
||||
else:
|
||||
cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
|
||||
else:
|
||||
cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
|
||||
# OVERRIDE: use payment_line Amount
|
||||
result.fields['Amount'] = cv.payment_line_amount
|
||||
result.confidence['Amount'] = 0.95
|
||||
|
||||
# Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
|
||||
detected_bankgiro = result.fields.get('Bankgiro')
|
||||
if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
|
||||
pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
|
||||
if detected_bankgiro:
|
||||
det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
|
||||
cv.bankgiro_match = pl_bg_digits == det_bg_digits
|
||||
if cv.bankgiro_match:
|
||||
cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
|
||||
else:
|
||||
cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
|
||||
# Do NOT override - keep detected value
|
||||
|
||||
# Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
|
||||
detected_plusgiro = result.fields.get('Plusgiro')
|
||||
if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
|
||||
pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
|
||||
if detected_plusgiro:
|
||||
det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
|
||||
cv.plusgiro_match = pl_pg_digits == det_pg_digits
|
||||
if cv.plusgiro_match:
|
||||
cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
|
||||
else:
|
||||
cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
|
||||
# Do NOT override - keep detected value
|
||||
|
||||
# Determine overall validity
|
||||
# Note: payment_line only contains ONE account (either BG or PG), so when invoice
|
||||
# has both accounts, the other one cannot be matched - this is expected and OK.
|
||||
# Only count the account type that payment_line actually has.
|
||||
matches = [cv.ocr_match, cv.amount_match]
|
||||
|
||||
# Only include account match if payment_line has that account type
|
||||
if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
|
||||
matches.append(cv.bankgiro_match)
|
||||
elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
|
||||
matches.append(cv.plusgiro_match)
|
||||
|
||||
valid_matches = [m for m in matches if m is not None]
|
||||
if valid_matches:
|
||||
match_count = sum(1 for m in valid_matches if m)
|
||||
cv.is_valid = match_count >= min(2, len(valid_matches))
|
||||
cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
|
||||
else:
|
||||
# No comparison possible
|
||||
cv.is_valid = True
|
||||
cv.details.append("No comparison available from payment_line")
|
||||
|
||||
result.cross_validation = cv
|
||||
|
||||
def _normalize_amount_for_compare(self, amount: str) -> float | None:
|
||||
"""Normalize amount string to float for comparison."""
|
||||
try:
|
||||
# Remove spaces, convert comma to dot
|
||||
cleaned = amount.replace(' ', '').replace(',', '.')
|
||||
# Handle Swedish format with space as thousands separator
|
||||
cleaned = re.sub(r'(\d)\s+(\d)', r'\1\2', cleaned)
|
||||
return round(float(cleaned), 2)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
def _needs_fallback(self, result: InferenceResult) -> bool:
|
||||
"""Check if fallback OCR is needed."""
|
||||
# Check for key fields
|
||||
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
|
||||
missing = sum(1 for f in key_fields if f not in result.fields)
|
||||
return missing >= 2 # Fallback if 2+ key fields missing
|
||||
|
||||
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
|
||||
"""Run full-page OCR fallback."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from shared.ocr import OCREngine
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
|
||||
result.fallback_used = True
|
||||
ocr_engine = OCREngine()
|
||||
|
||||
try:
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
# Full page OCR
|
||||
tokens = ocr_engine.extract_from_image(image_array, page_no)
|
||||
full_text = ' '.join(t.text for t in tokens)
|
||||
|
||||
# Try to extract missing fields with regex patterns
|
||||
self._extract_with_patterns(full_text, result)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Fallback OCR error: {e}")
|
||||
|
||||
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
|
||||
"""Extract fields using regex patterns (fallback)."""
|
||||
patterns = {
|
||||
'Amount': [
|
||||
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
|
||||
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
|
||||
],
|
||||
'Bankgiro': [
|
||||
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
|
||||
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
|
||||
],
|
||||
'OCR': [
|
||||
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
|
||||
],
|
||||
'InvoiceNumber': [
|
||||
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
|
||||
],
|
||||
}
|
||||
|
||||
for field_name, field_patterns in patterns.items():
|
||||
if field_name in result.fields:
|
||||
continue
|
||||
|
||||
for pattern in field_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
value = match.group(1).strip()
|
||||
|
||||
# Normalize the value
|
||||
if field_name == 'Amount':
|
||||
value = value.replace(' ', '').replace(',', '.')
|
||||
try:
|
||||
value = f"{float(value):.2f}"
|
||||
except ValueError:
|
||||
continue
|
||||
elif field_name == 'Bankgiro':
|
||||
digits = re.sub(r'\D', '', value)
|
||||
if len(digits) == 8:
|
||||
value = f"{digits[:4]}-{digits[4:]}"
|
||||
|
||||
result.fields[field_name] = value
|
||||
result.confidence[field_name] = 0.5 # Lower confidence for regex
|
||||
break
|
||||
|
||||
def process_image(
|
||||
self,
|
||||
image_path: str | Path,
|
||||
document_id: str | None = None
|
||||
) -> InferenceResult:
|
||||
"""
|
||||
Process a single image (for pre-rendered pages).
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
document_id: Optional document ID
|
||||
|
||||
Returns:
|
||||
InferenceResult with extracted fields
|
||||
"""
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
result = InferenceResult(
|
||||
document_id=document_id or Path(image_path).stem
|
||||
)
|
||||
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
image_array = np.array(image)
|
||||
|
||||
# Run detection
|
||||
detections = self.detector.detect(image_array, page_no=0)
|
||||
result.raw_detections = detections
|
||||
|
||||
# Extract fields
|
||||
for detection in detections:
|
||||
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||
result.extracted_fields.append(extracted)
|
||||
|
||||
# Merge fields
|
||||
self._merge_fields(result)
|
||||
result.success = len(result.fields) > 0
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
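# A minimal usage sketch. The model and PDF paths are placeholders (assumed,
# not shipped with the repository); the flow mirrors the class docstring:
# render -> detect -> OCR -> merge -> optional fallback.
if __name__ == "__main__":
    pipeline = InferencePipeline(model_path="models/invoice_fields.pt", use_gpu=False)
    result = pipeline.process_pdf("invoices/sample.pdf")
    if result.success:
        amount, conf = result.get_field("Amount")
        print(f"Amount={amount} (confidence {conf:.2f})")
    print(result.to_json())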
|
||||
210
packages/inference/inference/pipeline/yolo_detector.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
YOLO Detection Module
|
||||
|
||||
Runs YOLO model inference for field detection.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
"""Represents a single YOLO detection."""
|
||||
class_id: int
|
||||
class_name: str
|
||||
confidence: float
|
||||
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels
|
||||
page_no: int = 0
|
||||
|
||||
@property
|
||||
def x0(self) -> float:
|
||||
return self.bbox[0]
|
||||
|
||||
@property
|
||||
def y0(self) -> float:
|
||||
return self.bbox[1]
|
||||
|
||||
@property
|
||||
def x1(self) -> float:
|
||||
return self.bbox[2]
|
||||
|
||||
@property
|
||||
def y1(self) -> float:
|
||||
return self.bbox[3]
|
||||
|
||||
@property
|
||||
def center(self) -> tuple[float, float]:
|
||||
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||
|
||||
@property
|
||||
def width(self) -> float:
|
||||
return self.x1 - self.x0
|
||||
|
||||
@property
|
||||
def height(self) -> float:
|
||||
return self.y1 - self.y0
|
||||
|
||||
def get_padded_bbox(
|
||||
self,
|
||||
padding: float = 0.1,
|
||||
image_width: float | None = None,
|
||||
image_height: float | None = None
|
||||
) -> tuple[float, float, float, float]:
|
||||
"""Get bbox with padding for OCR extraction."""
|
||||
pad_x = self.width * padding
|
||||
pad_y = self.height * padding
|
||||
|
||||
x0 = self.x0 - pad_x
|
||||
y0 = self.y0 - pad_y
|
||||
x1 = self.x1 + pad_x
|
||||
y1 = self.y1 + pad_y
|
||||
|
||||
if image_width:
|
||||
x0 = max(0, x0)
|
||||
x1 = min(image_width, x1)
|
||||
if image_height:
|
||||
y0 = max(0, y0)
|
||||
y1 = min(image_height, y1)
|
||||
|
||||
return (x0, y0, x1, y1)
|
||||
|
||||
|
||||
# Class names (must match training configuration)
|
||||
CLASS_NAMES = [
|
||||
'invoice_number',
|
||||
'invoice_date',
|
||||
'invoice_due_date',
|
||||
'ocr_number',
|
||||
'bankgiro',
|
||||
'plusgiro',
|
||||
'amount',
|
||||
'supplier_org_number', # Matches training class name
|
||||
'customer_number',
|
||||
'payment_line', # Machine code payment line at bottom of invoice
|
||||
]
|
||||
|
||||
# Mapping from class name to field name
|
||||
CLASS_TO_FIELD = {
|
||||
'invoice_number': 'InvoiceNumber',
|
||||
'invoice_date': 'InvoiceDate',
|
||||
'invoice_due_date': 'InvoiceDueDate',
|
||||
'ocr_number': 'OCR',
|
||||
'bankgiro': 'Bankgiro',
|
||||
'plusgiro': 'Plusgiro',
|
||||
'amount': 'Amount',
|
||||
'supplier_org_number': 'supplier_org_number',
|
||||
'customer_number': 'customer_number',
|
||||
'payment_line': 'payment_line',
|
||||
}
|
||||
|
||||
|
||||
class YOLODetector:
|
||||
"""YOLO model wrapper for field detection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str | Path,
|
||||
confidence_threshold: float = 0.5,
|
||||
iou_threshold: float = 0.45,
|
||||
device: str = 'auto'
|
||||
):
|
||||
"""
|
||||
Initialize YOLO detector.
|
||||
|
||||
Args:
|
||||
model_path: Path to trained YOLO model (.pt file)
|
||||
confidence_threshold: Minimum confidence for detections
|
||||
iou_threshold: IOU threshold for NMS
|
||||
device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
|
||||
"""
|
||||
from ultralytics import YOLO
|
||||
|
||||
self.model = YOLO(model_path)
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
self.device = device
|
||||
|
||||
def detect(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0
|
||||
) -> list[Detection]:
|
||||
"""
|
||||
Run detection on an image.
|
||||
|
||||
Args:
|
||||
image: Image path or numpy array
|
||||
page_no: Page number for reference
|
||||
|
||||
Returns:
|
||||
List of Detection objects
|
||||
"""
|
||||
results = self.model.predict(
|
||||
source=image,
|
||||
conf=self.confidence_threshold,
|
||||
iou=self.iou_threshold,
|
||||
device=self.device,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
detections = []
|
||||
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
if boxes is None:
|
||||
continue
|
||||
|
||||
for i in range(len(boxes)):
|
||||
class_id = int(boxes.cls[i])
|
||||
confidence = float(boxes.conf[i])
|
||||
bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1]
|
||||
|
||||
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"
|
||||
|
||||
detections.append(Detection(
|
||||
class_id=class_id,
|
||||
class_name=class_name,
|
||||
confidence=confidence,
|
||||
bbox=tuple(bbox),
|
||||
page_no=page_no
|
||||
))
|
||||
|
||||
return detections
|
||||
|
||||
def detect_pdf(
|
||||
self,
|
||||
pdf_path: str | Path,
|
||||
dpi: int = 300
|
||||
) -> dict[int, list[Detection]]:
|
||||
"""
|
||||
Run detection on all pages of a PDF.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
dpi: Resolution for rendering
|
||||
|
||||
Returns:
|
||||
Dict mapping page number to list of detections
|
||||
"""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
results = {}
|
||||
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
|
||||
# Convert bytes to numpy array
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
detections = self.detect(image_array, page_no=page_no)
|
||||
results[page_no] = detections
|
||||
|
||||
return results
|
||||
|
||||
def get_field_name(self, class_name: str) -> str:
|
||||
"""Convert class name to field name."""
|
||||
return CLASS_TO_FIELD.get(class_name, class_name)
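# A short usage sketch. The model and PDF paths are placeholders; detect_pdf
# renders each page and returns detections keyed by page number.
if __name__ == "__main__":
    detector = YOLODetector("models/invoice_fields.pt", confidence_threshold=0.5)
    for page_no, detections in detector.detect_pdf("invoices/sample.pdf", dpi=300).items():
        for det in detections:
            field = detector.get_field_name(det.class_name)
            print(page_no, field, f"{det.confidence:.2f}", det.bbox)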
|
||||
7
packages/inference/inference/validation/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Cross-validation module for verifying field extraction using LLM.
|
||||
"""
|
||||
|
||||
from .llm_validator import LLMValidator
|
||||
|
||||
__all__ = ['LLMValidator']
|
||||
748
packages/inference/inference/validation/llm_validator.py
Normal file
@@ -0,0 +1,748 @@
|
||||
"""
|
||||
LLM-based cross-validation for invoice field extraction.
|
||||
|
||||
Uses a vision LLM to extract fields from invoice PDFs and compare with
|
||||
the autolabel results to identify potential errors.
|
||||
"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMExtractionResult:
|
||||
"""Result of LLM field extraction."""
|
||||
document_id: str
|
||||
invoice_number: Optional[str] = None
|
||||
invoice_date: Optional[str] = None
|
||||
invoice_due_date: Optional[str] = None
|
||||
ocr_number: Optional[str] = None
|
||||
bankgiro: Optional[str] = None
|
||||
plusgiro: Optional[str] = None
|
||||
amount: Optional[str] = None
|
||||
supplier_organisation_number: Optional[str] = None
|
||||
raw_response: Optional[str] = None
|
||||
model_used: Optional[str] = None
|
||||
processing_time_ms: Optional[float] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
class LLMValidator:
|
||||
"""
|
||||
Cross-validates invoice field extraction using LLM.
|
||||
|
||||
Queries documents with failed field matches from the database,
|
||||
sends the PDF images to an LLM for extraction, and stores
|
||||
the results for comparison.
|
||||
"""
|
||||
|
||||
# Fields to extract (excluding customer_number as requested)
|
||||
FIELDS_TO_EXTRACT = [
|
||||
'InvoiceNumber',
|
||||
'InvoiceDate',
|
||||
'InvoiceDueDate',
|
||||
'OCR',
|
||||
'Bankgiro',
|
||||
'Plusgiro',
|
||||
'Amount',
|
||||
'supplier_organisation_number',
|
||||
]
|
||||
|
||||
EXTRACTION_PROMPT = """You are an expert at extracting structured data from Swedish invoices.
|
||||
|
||||
Analyze this invoice image and extract the following fields. Return ONLY a valid JSON object with these exact keys:
|
||||
|
||||
{
|
||||
"invoice_number": "the invoice number/fakturanummer",
|
||||
"invoice_date": "the invoice date in YYYY-MM-DD format",
|
||||
"invoice_due_date": "the due date/förfallodatum in YYYY-MM-DD format",
|
||||
"ocr_number": "the OCR payment reference number",
|
||||
"bankgiro": "the bankgiro number (format: XXXX-XXXX or XXXXXXXX)",
|
||||
"plusgiro": "the plusgiro number",
|
||||
"amount": "the total amount to pay (just the number, e.g., 1234.56)",
|
||||
"supplier_organisation_number": "the supplier's organisation number (format: XXXXXX-XXXX)"
|
||||
}
|
||||
|
||||
Rules:
|
||||
- If a field is not found or not visible, use null
|
||||
- For dates, convert Swedish month names (januari, februari, etc.) to YYYY-MM-DD
|
||||
- For amounts, extract just the numeric value without currency symbols
|
||||
- The OCR number is typically a long number used for payment reference
|
||||
- Look for "Att betala" or "Summa att betala" for the amount
|
||||
- Organisation number is 10 digits, often shown as XXXXXX-XXXX
|
||||
|
||||
Return ONLY the JSON object, no other text."""
|
||||
|
||||
def __init__(self, connection_string: Optional[str] = None, pdf_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the validator.
|
||||
|
||||
Args:
|
||||
connection_string: PostgreSQL connection string
|
||||
pdf_dir: Directory containing PDF files
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from config import get_db_connection_string, PATHS
|
||||
|
||||
self.connection_string = connection_string or get_db_connection_string()
|
||||
self.pdf_dir = Path(pdf_dir or PATHS['pdf_dir'])
|
||||
self.conn = None
|
||||
|
||||
def connect(self):
|
||||
"""Connect to database."""
|
||||
if self.conn is None:
|
||||
self.conn = psycopg2.connect(self.connection_string)
|
||||
return self.conn
|
||||
|
||||
def close(self):
|
||||
"""Close database connection."""
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
def create_validation_table(self):
|
||||
"""Create the llm_validation table if it doesn't exist."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS llm_validations (
|
||||
id SERIAL PRIMARY KEY,
|
||||
document_id TEXT NOT NULL,
|
||||
-- Extracted fields
|
||||
invoice_number TEXT,
|
||||
invoice_date TEXT,
|
||||
invoice_due_date TEXT,
|
||||
ocr_number TEXT,
|
||||
bankgiro TEXT,
|
||||
plusgiro TEXT,
|
||||
amount TEXT,
|
||||
supplier_organisation_number TEXT,
|
||||
-- Metadata
|
||||
raw_response TEXT,
|
||||
model_used TEXT,
|
||||
processing_time_ms REAL,
|
||||
error TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
-- Comparison results (populated later)
|
||||
comparison_results JSONB,
|
||||
UNIQUE(document_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_llm_validations_document_id
|
||||
ON llm_validations(document_id);
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
def get_documents_with_failed_matches(
|
||||
self,
|
||||
exclude_customer_number: bool = True,
|
||||
limit: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get documents that have at least one failed field match.
|
||||
|
||||
Args:
|
||||
exclude_customer_number: If True, ignore customer_number failures
|
||||
limit: Maximum number of documents to return
|
||||
|
||||
Returns:
|
||||
List of document info with failed fields
|
||||
"""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
# Find documents with failed matches (excluding customer_number if requested)
|
||||
exclude_clause = ""
|
||||
if exclude_customer_number:
|
||||
exclude_clause = "AND fr.field_name != 'customer_number'"
|
||||
|
||||
query = f"""
|
||||
SELECT DISTINCT d.document_id, d.pdf_path, d.pdf_type,
|
||||
d.supplier_name, d.split
|
||||
FROM documents d
|
||||
JOIN field_results fr ON d.document_id = fr.document_id
|
||||
WHERE fr.matched = false
|
||||
AND fr.field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause}
|
||||
AND d.document_id NOT IN (
|
||||
SELECT document_id FROM llm_validations WHERE error IS NULL
|
||||
)
|
||||
ORDER BY d.document_id
|
||||
"""
|
||||
if limit:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
cursor.execute(query)
|
||||
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
doc_id = row[0]
|
||||
|
||||
# Get failed fields for this document
|
||||
exclude_clause_inner = ""
|
||||
if exclude_customer_number:
|
||||
exclude_clause_inner = "AND field_name != 'customer_number'"
|
||||
cursor.execute(f"""
|
||||
SELECT field_name, csv_value, score
|
||||
FROM field_results
|
||||
WHERE document_id = %s
|
||||
AND matched = false
|
||||
AND field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause_inner}
|
||||
""", (doc_id,))
|
||||
|
||||
failed_fields = [
|
||||
{'field': r[0], 'csv_value': r[1], 'score': r[2]}
|
||||
for r in cursor.fetchall()
|
||||
]
|
||||
|
||||
results.append({
|
||||
'document_id': doc_id,
|
||||
'pdf_path': row[1],
|
||||
'pdf_type': row[2],
|
||||
'supplier_name': row[3],
|
||||
'split': row[4],
|
||||
'failed_fields': failed_fields,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
def get_failed_match_stats(self, exclude_customer_number: bool = True) -> Dict[str, Any]:
|
||||
"""Get statistics about failed matches."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
exclude_clause = ""
|
||||
if exclude_customer_number:
|
||||
exclude_clause = "AND field_name != 'customer_number'"
|
||||
|
||||
# Count by field
|
||||
cursor.execute(f"""
|
||||
SELECT field_name, COUNT(*) as cnt
|
||||
FROM field_results
|
||||
WHERE matched = false
|
||||
AND field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause}
|
||||
GROUP BY field_name
|
||||
ORDER BY cnt DESC
|
||||
""")
|
||||
by_field = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# Count documents with failures
|
||||
cursor.execute(f"""
|
||||
SELECT COUNT(DISTINCT document_id)
|
||||
FROM field_results
|
||||
WHERE matched = false
|
||||
AND field_name NOT LIKE 'supplier_accounts%%'
|
||||
{exclude_clause}
|
||||
""")
|
||||
doc_count = cursor.fetchone()[0]
|
||||
|
||||
# Already validated count
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM llm_validations WHERE error IS NULL
|
||||
""")
|
||||
validated_count = cursor.fetchone()[0]
|
||||
|
||||
return {
|
||||
'documents_with_failures': doc_count,
|
||||
'already_validated': validated_count,
|
||||
'remaining': doc_count - validated_count,
|
||||
'failures_by_field': by_field,
|
||||
}
|
||||
|
||||
def render_pdf_to_image(
|
||||
self,
|
||||
pdf_path: Path,
|
||||
page_no: int = 0,
|
||||
dpi: int = DEFAULT_DPI,
|
||||
max_size_mb: float = 18.0
|
||||
) -> bytes:
|
||||
"""
|
||||
Render a PDF page to PNG image bytes.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
page_no: Page number to render (0-indexed)
|
||||
dpi: Resolution for rendering
|
||||
max_size_mb: Maximum image size in MB (Azure OpenAI limit is 20MB)
|
||||
|
||||
Returns:
|
||||
PNG image bytes
|
||||
"""
|
||||
import fitz # PyMuPDF
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
doc = fitz.open(pdf_path)
|
||||
page = doc[page_no]
|
||||
|
||||
# Try different DPI values until we get a small enough image
|
||||
dpi_values = [dpi, 120, 100, 72, 50]
|
||||
|
||||
for current_dpi in dpi_values:
|
||||
mat = fitz.Matrix(current_dpi / 72, current_dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
png_bytes = pix.tobytes("png")
|
||||
|
||||
size_mb = len(png_bytes) / (1024 * 1024)
|
||||
if size_mb <= max_size_mb:
|
||||
doc.close()
|
||||
return png_bytes
|
||||
|
||||
# If still too large, use JPEG compression
|
||||
mat = fitz.Matrix(72 / 72, 72 / 72) # Lowest DPI
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# Convert to PIL Image and compress as JPEG
|
||||
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
|
||||
# Try different JPEG quality levels
|
||||
for quality in [85, 70, 50, 30]:
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="JPEG", quality=quality)
|
||||
jpeg_bytes = buffer.getvalue()
|
||||
|
||||
size_mb = len(jpeg_bytes) / (1024 * 1024)
|
||||
if size_mb <= max_size_mb:
|
||||
doc.close()
|
||||
return jpeg_bytes
|
||||
|
||||
doc.close()
|
||||
# Return whatever we have, let the API handle the error
|
||||
return jpeg_bytes
|
||||
|
||||
    def extract_with_openai(
        self,
        image_bytes: bytes,
        model: str = "gpt-4o"
    ) -> LLMExtractionResult:
        """
        Extract fields using OpenAI's vision API (supports Azure OpenAI).

        Args:
            image_bytes: PNG or JPEG image bytes
            model: Model to use (gpt-4o, gpt-4o-mini, etc.)

        Returns:
            Extraction result
        """
        import openai
        import time

        start_time = time.time()

        # Encode image to base64 and detect format
        image_b64 = base64.b64encode(image_bytes).decode('utf-8')

        # Detect image format (PNG starts with \x89PNG, JPEG with \xFF\xD8)
        if image_bytes[:4] == b'\x89PNG':
            media_type = "image/png"
        else:
            media_type = "image/jpeg"

        # Check for Azure OpenAI configuration
        azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
        azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
        azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', model)

        if azure_endpoint and azure_api_key:
            # Use Azure OpenAI
            client = openai.AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=azure_api_key,
                api_version="2024-02-15-preview"
            )
            model = azure_deployment  # Use deployment name for Azure
        else:
            # Use standard OpenAI
            client = openai.OpenAI()

        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self.EXTRACTION_PROMPT},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{media_type};base64,{image_b64}",
                                    "detail": "high"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=1000,
                temperature=0,
            )

            raw_response = response.choices[0].message.content
            processing_time = (time.time() - start_time) * 1000

            # Parse JSON response
            # Try to extract JSON from response (may have markdown code blocks)
            json_str = raw_response
            if "```json" in json_str:
                json_str = json_str.split("```json")[1].split("```")[0]
            elif "```" in json_str:
                json_str = json_str.split("```")[1].split("```")[0]

            data = json.loads(json_str.strip())

            return LLMExtractionResult(
                document_id="",  # Will be set by caller
                invoice_number=data.get('invoice_number'),
                invoice_date=data.get('invoice_date'),
                invoice_due_date=data.get('invoice_due_date'),
                ocr_number=data.get('ocr_number'),
                bankgiro=data.get('bankgiro'),
                plusgiro=data.get('plusgiro'),
                amount=data.get('amount'),
                supplier_organisation_number=data.get('supplier_organisation_number'),
                raw_response=raw_response,
                model_used=model,
                processing_time_ms=processing_time,
            )

        except json.JSONDecodeError as e:
            return LLMExtractionResult(
                document_id="",
                raw_response=raw_response if 'raw_response' in locals() else None,
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=f"JSON parse error: {str(e)}"
            )
        except Exception as e:
            return LLMExtractionResult(
                document_id="",
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=str(e)
            )

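The Azure branch is taken only when both the endpoint and the key are set; a sketch of what each mode expects, continuing the placeholder `validator` and `image_bytes` from the rendering sketch above (all values are placeholders):

# Illustrative sketch (not part of the diff).
# Azure OpenAI mode: the deployment name replaces the `model` argument.
#   AZURE_OPENAI_ENDPOINT=https://my-resource.openai.azure.com
#   AZURE_OPENAI_API_KEY=<key>
#   AZURE_OPENAI_DEPLOYMENT=gpt-4o
# Standard OpenAI mode: leave the Azure variables unset; the SDK reads OPENAI_API_KEY.
result = validator.extract_with_openai(image_bytes, model="gpt-4o-mini")
print(result.model_used, result.error or result.invoice_number)
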
    def extract_with_anthropic(
        self,
        image_bytes: bytes,
        model: str = "claude-sonnet-4-20250514"
    ) -> LLMExtractionResult:
        """
        Extract fields using Anthropic's Claude API.

        Args:
            image_bytes: PNG or JPEG image bytes
            model: Model to use

        Returns:
            Extraction result
        """
        import anthropic
        import time

        start_time = time.time()

        # Encode image to base64 and detect format (render_pdf_to_image may
        # fall back to JPEG for oversized pages)
        image_b64 = base64.b64encode(image_bytes).decode('utf-8')
        media_type = "image/png" if image_bytes[:4] == b'\x89PNG' else "image/jpeg"

        client = anthropic.Anthropic()

        try:
            response = client.messages.create(
                model=model,
                max_tokens=1000,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": media_type,
                                    "data": image_b64,
                                }
                            },
                            {
                                "type": "text",
                                "text": self.EXTRACTION_PROMPT
                            }
                        ]
                    }
                ],
            )

            raw_response = response.content[0].text
            processing_time = (time.time() - start_time) * 1000

            # Parse JSON response
            json_str = raw_response
            if "```json" in json_str:
                json_str = json_str.split("```json")[1].split("```")[0]
            elif "```" in json_str:
                json_str = json_str.split("```")[1].split("```")[0]

            data = json.loads(json_str.strip())

            return LLMExtractionResult(
                document_id="",
                invoice_number=data.get('invoice_number'),
                invoice_date=data.get('invoice_date'),
                invoice_due_date=data.get('invoice_due_date'),
                ocr_number=data.get('ocr_number'),
                bankgiro=data.get('bankgiro'),
                plusgiro=data.get('plusgiro'),
                amount=data.get('amount'),
                supplier_organisation_number=data.get('supplier_organisation_number'),
                raw_response=raw_response,
                model_used=model,
                processing_time_ms=processing_time,
            )

        except json.JSONDecodeError as e:
            return LLMExtractionResult(
                document_id="",
                raw_response=raw_response if 'raw_response' in locals() else None,
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=f"JSON parse error: {str(e)}"
            )
        except Exception as e:
            return LLMExtractionResult(
                document_id="",
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=str(e)
            )

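Both extractors strip markdown fences with identical string handling; a small shared helper along these lines would remove the duplication (the name and placement are suggestions, not part of the commit):

def _strip_code_fences(raw: str) -> str:
    """Return the JSON payload from a response that may be wrapped in ``` fences."""
    if "```json" in raw:
        return raw.split("```json")[1].split("```")[0].strip()
    if "```" in raw:
        return raw.split("```")[1].split("```")[0].strip()
    return raw.strip()
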
    def save_validation_result(self, result: LLMExtractionResult):
        """Save extraction result to database."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("""
                INSERT INTO llm_validations (
                    document_id, invoice_number, invoice_date, invoice_due_date,
                    ocr_number, bankgiro, plusgiro, amount,
                    supplier_organisation_number, raw_response, model_used,
                    processing_time_ms, error
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (document_id) DO UPDATE SET
                    invoice_number = EXCLUDED.invoice_number,
                    invoice_date = EXCLUDED.invoice_date,
                    invoice_due_date = EXCLUDED.invoice_due_date,
                    ocr_number = EXCLUDED.ocr_number,
                    bankgiro = EXCLUDED.bankgiro,
                    plusgiro = EXCLUDED.plusgiro,
                    amount = EXCLUDED.amount,
                    supplier_organisation_number = EXCLUDED.supplier_organisation_number,
                    raw_response = EXCLUDED.raw_response,
                    model_used = EXCLUDED.model_used,
                    processing_time_ms = EXCLUDED.processing_time_ms,
                    error = EXCLUDED.error,
                    created_at = NOW()
            """, (
                result.document_id,
                result.invoice_number,
                result.invoice_date,
                result.invoice_due_date,
                result.ocr_number,
                result.bankgiro,
                result.plusgiro,
                result.amount,
                result.supplier_organisation_number,
                result.raw_response,
                result.model_used,
                result.processing_time_ms,
                result.error,
            ))
            conn.commit()

    def validate_document(
        self,
        doc_id: str,
        provider: str = "openai",
        model: str | None = None
    ) -> LLMExtractionResult:
        """
        Validate a single document using LLM.

        Args:
            doc_id: Document ID
            provider: LLM provider ("openai" or "anthropic")
            model: Model to use (defaults based on provider)

        Returns:
            Extraction result
        """
        # Get PDF path
        pdf_path = self.pdf_dir / f"{doc_id}.pdf"
        if not pdf_path.exists():
            return LLMExtractionResult(
                document_id=doc_id,
                error=f"PDF not found: {pdf_path}"
            )

        # Render first page
        try:
            image_bytes = self.render_pdf_to_image(pdf_path, page_no=0)
        except Exception as e:
            return LLMExtractionResult(
                document_id=doc_id,
                error=f"Failed to render PDF: {str(e)}"
            )

        # Extract with LLM
        if provider == "openai":
            model = model or "gpt-4o"
            result = self.extract_with_openai(image_bytes, model)
        elif provider == "anthropic":
            model = model or "claude-sonnet-4-20250514"
            result = self.extract_with_anthropic(image_bytes, model)
        else:
            return LLMExtractionResult(
                document_id=doc_id,
                error=f"Unknown provider: {provider}"
            )

        result.document_id = doc_id

        # Save to database
        self.save_validation_result(result)

        return result

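An end-to-end sketch of validating a single failed document, continuing the placeholder `validator` from the earlier sketches; the document ID is illustrative:

# Illustrative sketch (not part of the diff).
result = validator.validate_document(
    "3f2a9c1e-0000-0000-0000-000000000000",  # placeholder document ID
    provider="openai",
    model="gpt-4o-mini",
)
if result.error:
    print(f"validation failed: {result.error}")
else:
    print(result.invoice_number, result.amount, f"{result.processing_time_ms:.0f} ms")
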
    def validate_batch(
        self,
        limit: int = 10,
        provider: str = "openai",
        model: str | None = None,
        verbose: bool = True
    ) -> List[LLMExtractionResult]:
        """
        Validate a batch of documents with failed matches.

        Args:
            limit: Maximum number of documents to validate
            provider: LLM provider
            model: Model to use
            verbose: Print progress

        Returns:
            List of extraction results
        """
        # Get documents to validate
        docs = self.get_documents_with_failed_matches(limit=limit)

        if verbose:
            print(f"Found {len(docs)} documents with failed matches to validate")

        results = []
        for i, doc in enumerate(docs):
            doc_id = doc['document_id']

            if verbose:
                failed_fields = [f['field'] for f in doc['failed_fields']]
                print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")

            result = self.validate_document(doc_id, provider, model)
            results.append(result)

            if verbose:
                if result.error:
                    print(f"  ERROR: {result.error}")
                else:
                    print(f"  OK ({result.processing_time_ms:.0f}ms)")

        return results

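Batch validation is a thin loop over validate_document; a short driving sketch with the same placeholder `validator`:

# Illustrative sketch (not part of the diff).
results = validator.validate_batch(limit=25, provider="anthropic", verbose=True)
failed = [r for r in results if r.error]
print(f"{len(results) - len(failed)} succeeded, {len(failed)} failed")
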
    def compare_results(self, doc_id: str) -> Dict[str, Any]:
        """
        Compare LLM extraction with autolabel results.

        Args:
            doc_id: Document ID

        Returns:
            Comparison results
        """
        conn = self.connect()
        with conn.cursor() as cursor:
            # Get autolabel results
            cursor.execute("""
                SELECT field_name, csv_value, matched, matched_text
                FROM field_results
                WHERE document_id = %s
            """, (doc_id,))

            autolabel = {}
            for row in cursor.fetchall():
                autolabel[row[0]] = {
                    'csv_value': row[1],
                    'matched': row[2],
                    'matched_text': row[3],
                }

            # Get LLM results
            cursor.execute("""
                SELECT invoice_number, invoice_date, invoice_due_date,
                       ocr_number, bankgiro, plusgiro, amount,
                       supplier_organisation_number
                FROM llm_validations
                WHERE document_id = %s
            """, (doc_id,))

            row = cursor.fetchone()
            if not row:
                return {'error': 'No LLM validation found'}

            llm = {
                'InvoiceNumber': row[0],
                'InvoiceDate': row[1],
                'InvoiceDueDate': row[2],
                'OCR': row[3],
                'Bankgiro': row[4],
                'Plusgiro': row[5],
                'Amount': row[6],
                'supplier_organisation_number': row[7],
            }

            # Compare
            comparison = {}
            for field in self.FIELDS_TO_EXTRACT:
                auto = autolabel.get(field, {})
                llm_value = llm.get(field)

                comparison[field] = {
                    'csv_value': auto.get('csv_value'),
                    'autolabel_matched': auto.get('matched'),
                    'autolabel_text': auto.get('matched_text'),
                    'llm_value': llm_value,
                    'agreement': self._values_match(auto.get('csv_value'), llm_value),
                }

            return comparison

    def _values_match(self, csv_value: str | None, llm_value: str | None) -> bool:
        """Check if CSV value matches LLM extracted value."""
        if csv_value is None or llm_value is None:
            return csv_value == llm_value

        # Normalize for comparison
        csv_norm = str(csv_value).strip().lower().replace('-', '').replace(' ', '')
        llm_norm = str(llm_value).strip().lower().replace('-', '').replace(' ', '')

        return csv_norm == llm_norm
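The normalization in _values_match means separator and case differences do not count as disagreement; a small worked example with illustrative values:

# '556016-0680' (CSV) and '556016 0680' (LLM) both normalize to '5560160680',
# so they are reported as matching despite the different separators.
def _norm(v: str) -> str:
    return str(v).strip().lower().replace('-', '').replace(' ', '')

assert _norm("556016-0680") == _norm("556016 0680")
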
9
packages/inference/inference/web/__init__.py
Normal file
9
packages/inference/inference/web/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""
Web Application Module

Provides REST API and web interface for invoice field extraction.
"""

from .app import create_app

__all__ = ["create_app"]
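A sketch of serving the app this package exposes; uvicorn and the factory's call signature are assumptions beyond what the diff shows (run_server.py in the Dockerfile is the real entry point):

# Illustrative sketch (not part of the diff).
import uvicorn
from inference.web import create_app

app = create_app()  # assumed to need no required arguments here
uvicorn.run(app, host="0.0.0.0", port=8000)
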
8
packages/inference/inference/web/admin_routes_new.py
Normal file
8
packages/inference/inference/web/admin_routes_new.py
Normal file
@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py

DEPRECATED: Import from inference.web.api.v1.admin.documents instead.
"""
from inference.web.api.v1.admin.documents import *

__all__ = ["create_admin_router"]
0
packages/inference/inference/web/api/__init__.py
Normal file
0
packages/inference/inference/web/api/__init__.py
Normal file
0
packages/inference/inference/web/api/v1/__init__.py
Normal file
0
packages/inference/inference/web/api/v1/__init__.py
Normal file
19
packages/inference/inference/web/api/v1/admin/__init__.py
Normal file
19
packages/inference/inference/web/api/v1/admin/__init__.py
Normal file
@@ -0,0 +1,19 @@
"""
Admin API v1

Document management, annotations, and training endpoints.
"""

from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.api.v1.admin.training import create_training_router

__all__ = [
    "create_annotation_router",
    "create_auth_router",
    "create_documents_router",
    "create_locks_router",
    "create_training_router",
]
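How these factories are mounted is not shown in this commit; a plausible wiring sketch, using the /api/v1 prefix implied by the image URLs generated in documents.py (the training router's signature is not visible here and is omitted):

# Illustrative sketch (not part of the diff).
from fastapi import FastAPI
from inference.web.api.v1.admin import (
    create_annotation_router,
    create_auth_router,
    create_documents_router,
    create_locks_router,
)

def mount_admin_api(app: FastAPI, storage_config) -> None:
    app.include_router(create_auth_router(), prefix="/api/v1")
    app.include_router(create_documents_router(storage_config), prefix="/api/v1")
    app.include_router(create_annotation_router(), prefix="/api/v1")
    app.include_router(create_locks_router(), prefix="/api/v1")
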
644
packages/inference/inference/web/api/v1/admin/annotations.py
Normal file
644
packages/inference/inference/web/api/v1/admin/annotations.py
Normal file
@@ -0,0 +1,644 @@
|
||||
"""
|
||||
Admin Annotation API Routes
|
||||
|
||||
FastAPI endpoints for annotation management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.services.autolabel import get_auto_label_service
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationCreate,
|
||||
AnnotationItem,
|
||||
AnnotationListResponse,
|
||||
AnnotationOverrideRequest,
|
||||
AnnotationOverrideResponse,
|
||||
AnnotationResponse,
|
||||
AnnotationSource,
|
||||
AnnotationUpdate,
|
||||
AnnotationVerifyRequest,
|
||||
AnnotationVerifyResponse,
|
||||
AutoLabelRequest,
|
||||
AutoLabelResponse,
|
||||
BoundingBox,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Image storage directory
|
||||
ADMIN_IMAGES_DIR = Path("data/admin_images")
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def create_annotation_router() -> APIRouter:
|
||||
"""Create annotation API router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])
|
||||
|
||||
# =========================================================================
|
||||
# Image Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
|
||||
"/{document_id}/images/{page_number}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Get page image",
|
||||
description="Get the image for a specific page.",
|
||||
)
|
||||
async def get_page_image(
|
||||
document_id: str,
|
||||
page_number: int,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> FileResponse:
|
||||
"""Get page image."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Validate page number
|
||||
if page_number < 1 or page_number > document.page_count:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
|
||||
)
|
||||
|
||||
# Find image file
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
|
||||
if not image_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Image for page {page_number} not found",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=f"{document.filename}_page_{page_number}.png",
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Annotation Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
|
||||
"/{document_id}/annotations",
|
||||
response_model=AnnotationListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="List annotations",
|
||||
description="Get all annotations for a document.",
|
||||
)
|
||||
async def list_annotations(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
page_number: Annotated[
|
||||
int | None,
|
||||
Query(ge=1, description="Filter by page number"),
|
||||
] = None,
|
||||
) -> AnnotationListResponse:
|
||||
"""List annotations for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get annotations
|
||||
raw_annotations = db.get_annotations_for_document(document_id, page_number)
|
||||
annotations = [
|
||||
AnnotationItem(
|
||||
annotation_id=str(ann.annotation_id),
|
||||
page_number=ann.page_number,
|
||||
class_id=ann.class_id,
|
||||
class_name=ann.class_name,
|
||||
bbox=BoundingBox(
|
||||
x=ann.bbox_x,
|
||||
y=ann.bbox_y,
|
||||
width=ann.bbox_width,
|
||||
height=ann.bbox_height,
|
||||
),
|
||||
normalized_bbox={
|
||||
"x_center": ann.x_center,
|
||||
"y_center": ann.y_center,
|
||||
"width": ann.width,
|
||||
"height": ann.height,
|
||||
},
|
||||
text_value=ann.text_value,
|
||||
confidence=ann.confidence,
|
||||
source=AnnotationSource(ann.source),
|
||||
created_at=ann.created_at,
|
||||
)
|
||||
for ann in raw_annotations
|
||||
]
|
||||
|
||||
return AnnotationListResponse(
|
||||
document_id=document_id,
|
||||
page_count=document.page_count,
|
||||
total_annotations=len(annotations),
|
||||
annotations=annotations,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/annotations",
|
||||
response_model=AnnotationResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Create annotation",
|
||||
description="Create a new annotation for a document.",
|
||||
)
|
||||
async def create_annotation(
|
||||
document_id: str,
|
||||
request: AnnotationCreate,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AnnotationResponse:
|
||||
"""Create a new annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Validate page number
|
||||
if request.page_number > document.page_count:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
|
||||
)
|
||||
|
||||
# Get image dimensions for normalization
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
|
||||
if not image_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image for page {request.page_number} not available",
|
||||
)
|
||||
|
||||
from PIL import Image
|
||||
with Image.open(image_path) as img:
|
||||
image_width, image_height = img.size
|
||||
|
||||
# Calculate normalized coordinates
|
||||
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
y_center = (request.bbox.y + request.bbox.height / 2) / image_height
|
||||
width = request.bbox.width / image_width
|
||||
height = request.bbox.height / image_height
|
||||
|
||||
# Get class name
|
||||
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
|
||||
|
||||
# Create annotation
|
||||
annotation_id = db.create_annotation(
|
||||
document_id=document_id,
|
||||
page_number=request.page_number,
|
||||
class_id=request.class_id,
|
||||
class_name=class_name,
|
||||
x_center=x_center,
|
||||
y_center=y_center,
|
||||
width=width,
|
||||
height=height,
|
||||
bbox_x=request.bbox.x,
|
||||
bbox_y=request.bbox.y,
|
||||
bbox_width=request.bbox.width,
|
||||
bbox_height=request.bbox.height,
|
||||
text_value=request.text_value,
|
||||
source="manual",
|
||||
)
|
||||
|
||||
# Keep status as pending - user must click "Mark Complete" to finalize
|
||||
# This allows user to add multiple annotations before saving to PostgreSQL
|
||||
|
||||
return AnnotationResponse(
|
||||
annotation_id=annotation_id,
|
||||
message="Annotation created successfully",
|
||||
)
|
||||
|
||||
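The stored normalized_bbox is the pixel box converted to center/size fractions of the page image, exactly as computed in create_annotation above; a worked example with illustrative numbers:

# Illustrative only: a 200x50 px box at (100, 300) on a 1700x2200 px page image.
image_width, image_height = 1700, 2200
x, y, w, h = 100, 300, 200, 50

x_center = (x + w / 2) / image_width    # 200 / 1700 = 0.1176
y_center = (y + h / 2) / image_height   # 325 / 2200 = 0.1477
width = w / image_width                 # 200 / 1700 = 0.1176
height = h / image_height               #  50 / 2200 = 0.0227
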
@router.patch(
|
||||
"/{document_id}/annotations/{annotation_id}",
|
||||
response_model=AnnotationResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Update annotation",
|
||||
description="Update an existing annotation.",
|
||||
)
|
||||
async def update_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
request: AnnotationUpdate,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AnnotationResponse:
|
||||
"""Update an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get existing annotation
|
||||
annotation = db.get_annotation(annotation_id)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Verify annotation belongs to document
|
||||
if str(annotation.document_id) != document_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation does not belong to this document",
|
||||
)
|
||||
|
||||
# Prepare update data
|
||||
update_kwargs = {}
|
||||
|
||||
if request.class_id is not None:
|
||||
update_kwargs["class_id"] = request.class_id
|
||||
update_kwargs["class_name"] = FIELD_CLASSES.get(
|
||||
request.class_id, f"class_{request.class_id}"
|
||||
)
|
||||
|
||||
if request.text_value is not None:
|
||||
update_kwargs["text_value"] = request.text_value
|
||||
|
||||
        if request.bbox is not None:
            # Get image dimensions (reject the update if the page image is missing,
            # mirroring the check in create_annotation)
            image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
            if not image_path.exists():
                raise HTTPException(
                    status_code=400,
                    detail=f"Image for page {annotation.page_number} not available",
                )
            from PIL import Image
            with Image.open(image_path) as img:
                image_width, image_height = img.size
|
||||
|
||||
# Calculate normalized coordinates
|
||||
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
|
||||
update_kwargs["width"] = request.bbox.width / image_width
|
||||
update_kwargs["height"] = request.bbox.height / image_height
|
||||
update_kwargs["bbox_x"] = request.bbox.x
|
||||
update_kwargs["bbox_y"] = request.bbox.y
|
||||
update_kwargs["bbox_width"] = request.bbox.width
|
||||
update_kwargs["bbox_height"] = request.bbox.height
|
||||
|
||||
# Update annotation
|
||||
if update_kwargs:
|
||||
success = db.update_annotation(annotation_id, **update_kwargs)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to update annotation",
|
||||
)
|
||||
|
||||
return AnnotationResponse(
|
||||
annotation_id=annotation_id,
|
||||
message="Annotation updated successfully",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/annotations/{annotation_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Delete annotation",
|
||||
description="Delete an annotation.",
|
||||
)
|
||||
async def delete_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> dict:
|
||||
"""Delete an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get existing annotation
|
||||
annotation = db.get_annotation(annotation_id)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Verify annotation belongs to document
|
||||
if str(annotation.document_id) != document_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation does not belong to this document",
|
||||
)
|
||||
|
||||
# Delete annotation
|
||||
db.delete_annotation(annotation_id)
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"annotation_id": annotation_id,
|
||||
"message": "Annotation deleted successfully",
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Auto-Labeling Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/auto-label",
|
||||
response_model=AutoLabelResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Trigger auto-labeling",
|
||||
description="Trigger auto-labeling for a document using field values.",
|
||||
)
|
||||
async def trigger_auto_label(
|
||||
document_id: str,
|
||||
request: AutoLabelRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AutoLabelResponse:
|
||||
"""Trigger auto-labeling for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Validate field values
|
||||
if not request.field_values:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="At least one field value is required",
|
||||
)
|
||||
|
||||
# Run auto-labeling
|
||||
service = get_auto_label_service()
|
||||
result = service.auto_label_document(
|
||||
document_id=document_id,
|
||||
file_path=document.file_path,
|
||||
field_values=request.field_values,
|
||||
db=db,
|
||||
replace_existing=request.replace_existing,
|
||||
)
|
||||
|
||||
if result["status"] == "failed":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return AutoLabelResponse(
|
||||
document_id=document_id,
|
||||
status=result["status"],
|
||||
annotations_created=result["annotations_created"],
|
||||
message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/annotations",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Delete all annotations",
|
||||
description="Delete all annotations for a document (optionally filter by source).",
|
||||
)
|
||||
async def delete_all_annotations(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
source: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by source (manual, auto, imported)"),
|
||||
] = None,
|
||||
) -> dict:
|
||||
"""Delete all annotations for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate source
|
||||
if source and source not in ("manual", "auto", "imported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid source: {source}",
|
||||
)
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Delete annotations
|
||||
deleted_count = db.delete_annotations_for_document(document_id, source)
|
||||
|
||||
# Update document status if all annotations deleted
|
||||
remaining = db.get_annotations_for_document(document_id)
|
||||
if not remaining:
|
||||
db.update_document_status(document_id, "pending")
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"document_id": document_id,
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"Deleted {deleted_count} annotations",
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Phase 5: Annotation Enhancement
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/annotations/{annotation_id}/verify",
|
||||
response_model=AnnotationVerifyResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Annotation not found"},
|
||||
},
|
||||
summary="Verify annotation",
|
||||
description="Mark an annotation as verified by a human reviewer.",
|
||||
)
|
||||
async def verify_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
|
||||
) -> AnnotationVerifyResponse:
|
||||
"""Verify an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership of document
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Verify the annotation
|
||||
annotation = db.verify_annotation(annotation_id, admin_token)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
return AnnotationVerifyResponse(
|
||||
annotation_id=annotation_id,
|
||||
is_verified=annotation.is_verified,
|
||||
verified_at=annotation.verified_at,
|
||||
verified_by=annotation.verified_by,
|
||||
message="Annotation verified successfully",
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/annotations/{annotation_id}/override",
|
||||
response_model=AnnotationOverrideResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Annotation not found"},
|
||||
},
|
||||
summary="Override annotation",
|
||||
description="Override an auto-generated annotation with manual corrections.",
|
||||
)
|
||||
async def override_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
request: AnnotationOverrideRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AnnotationOverrideResponse:
|
||||
"""Override an auto-generated annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership of document
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Build updates dict from request
|
||||
updates = {}
|
||||
if request.text_value is not None:
|
||||
updates["text_value"] = request.text_value
|
||||
if request.class_id is not None:
|
||||
updates["class_id"] = request.class_id
|
||||
# Update class_name if class_id changed
|
||||
if request.class_id in FIELD_CLASSES:
|
||||
updates["class_name"] = FIELD_CLASSES[request.class_id]
|
||||
if request.class_name is not None:
|
||||
updates["class_name"] = request.class_name
|
||||
if request.bbox:
|
||||
# Update bbox fields
|
||||
if "x" in request.bbox:
|
||||
updates["bbox_x"] = request.bbox["x"]
|
||||
if "y" in request.bbox:
|
||||
updates["bbox_y"] = request.bbox["y"]
|
||||
if "width" in request.bbox:
|
||||
updates["bbox_width"] = request.bbox["width"]
|
||||
if "height" in request.bbox:
|
||||
updates["bbox_height"] = request.bbox["height"]
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No updates provided. Specify at least one field to update.",
|
||||
)
|
||||
|
||||
# Override the annotation
|
||||
annotation = db.override_annotation(
|
||||
annotation_id=annotation_id,
|
||||
admin_token=admin_token,
|
||||
change_reason=request.reason,
|
||||
**updates,
|
||||
)
|
||||
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Get history to return history_id
|
||||
history_records = db.get_annotation_history(UUID(annotation_id))
|
||||
latest_history = history_records[0] if history_records else None
|
||||
|
||||
return AnnotationOverrideResponse(
|
||||
annotation_id=annotation_id,
|
||||
source=annotation.source,
|
||||
override_source=annotation.override_source,
|
||||
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
|
||||
message="Annotation overridden successfully",
|
||||
history_id=str(latest_history.history_id) if latest_history else "",
|
||||
)
|
||||
|
||||
return router
|
||||
82
packages/inference/inference/web/api/v1/admin/auth.py
Normal file
82
packages/inference/inference/web/api/v1/admin/auth.py
Normal file
@@ -0,0 +1,82 @@
"""
Admin Auth Routes

FastAPI endpoints for admin token management.
"""

import logging
import secrets
from datetime import datetime, timedelta

from fastapi import APIRouter

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    AdminTokenCreate,
    AdminTokenResponse,
)
from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def create_auth_router() -> APIRouter:
    """Create admin auth router."""
    router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])

    @router.post(
        "/token",
        response_model=AdminTokenResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
        },
        summary="Create admin token",
        description="Create a new admin authentication token.",
    )
    async def create_token(
        request: AdminTokenCreate,
        db: AdminDBDep,
    ) -> AdminTokenResponse:
        """Create a new admin token."""
        # Generate secure token
        token = secrets.token_urlsafe(32)

        # Calculate expiration
        expires_at = None
        if request.expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)

        # Create token in database
        db.create_admin_token(
            token=token,
            name=request.name,
            expires_at=expires_at,
        )

        return AdminTokenResponse(
            token=token,
            name=request.name,
            expires_at=expires_at,
            message="Admin token created successfully",
        )

    @router.delete(
        "/token",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Revoke admin token",
        description="Revoke the current admin token.",
    )
    async def revoke_token(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Revoke the current admin token."""
        db.deactivate_admin_token(admin_token)
        return {
            "status": "revoked",
            "message": "Admin token has been revoked",
        }

    return router
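A sketch of exercising these endpoints with httpx; the base URL, the /api/v1 mount, and the header consumed by AdminTokenDep are assumptions not visible in this file:

# Illustrative sketch (not part of the diff); header name is an assumption.
import httpx

base = "http://localhost:8000/api/v1"
created = httpx.post(f"{base}/admin/auth/token",
                     json={"name": "ci-bot", "expires_in_days": 30}).json()
token = created["token"]

# Revoke the same token later.
httpx.delete(f"{base}/admin/auth/token",
             headers={"Authorization": f"Bearer {token}"})
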
551
packages/inference/inference/web/api/v1/admin/documents.py
Normal file
551
packages/inference/inference/web/api/v1/admin/documents.py
Normal file
@@ -0,0 +1,551 @@
|
||||
"""
|
||||
Admin Document Routes
|
||||
|
||||
FastAPI endpoints for admin document management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||
|
||||
from inference.web.config import DEFAULT_DPI, StorageConfig
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationItem,
|
||||
AnnotationSource,
|
||||
AutoLabelStatus,
|
||||
BoundingBox,
|
||||
DocumentDetailResponse,
|
||||
DocumentItem,
|
||||
DocumentListResponse,
|
||||
DocumentStatus,
|
||||
DocumentStatsResponse,
|
||||
DocumentUploadResponse,
|
||||
ModelMetrics,
|
||||
TrainingHistoryItem,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation."""
|
||||
import fitz
|
||||
|
||||
doc_images_dir = images_dir / document_id
|
||||
doc_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
page = pdf_doc[page_num]
|
||||
# Render at configured DPI for consistency with training
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
image_path = doc_images_dir / f"page_{page_num + 1}.png"
|
||||
pix.save(str(image_path))
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
|
||||
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
"""Create admin documents router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])
|
||||
|
||||
# Directories are created by StorageConfig.__post_init__
|
||||
allowed_extensions = storage_config.allowed_extensions
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=DocumentUploadResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Upload document",
|
||||
description="Upload a PDF or image document for labeling.",
|
||||
)
|
||||
async def upload_document(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
file: UploadFile = File(..., description="PDF or image file"),
|
||||
auto_label: Annotated[
|
||||
bool,
|
||||
Query(description="Trigger auto-labeling after upload"),
|
||||
] = True,
|
||||
) -> DocumentUploadResponse:
|
||||
"""Upload a document for labeling."""
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required")
|
||||
|
||||
# Validate extension
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {file_ext}. "
|
||||
f"Allowed: {', '.join(allowed_extensions)}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(status_code=400, detail="Failed to read file")
|
||||
|
||||
# Get page count (for PDF)
|
||||
page_count = 1
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import fitz
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
page_count = len(pdf_doc)
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get PDF page count: {e}")
|
||||
|
||||
# Create document record (token only used for auth, not stored)
|
||||
document_id = db.create_document(
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
)
|
||||
|
||||
# Save file to admin uploads
|
||||
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
|
||||
try:
|
||||
file_path.write_bytes(content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to save file")
|
||||
|
||||
# Update file path in database
|
||||
from inference.data.database import get_session_context
|
||||
from inference.data.admin_models import AdminDocument
|
||||
with get_session_context() as session:
|
||||
doc = session.get(AdminDocument, UUID(document_id))
|
||||
if doc:
|
||||
doc.file_path = str(file_path)
|
||||
session.add(doc)
|
||||
|
||||
# Convert PDF to images for annotation
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count,
|
||||
storage_config.admin_images_dir, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
# Trigger auto-labeling if requested
|
||||
auto_label_started = False
|
||||
if auto_label:
|
||||
# Auto-labeling will be triggered by a background task
|
||||
db.update_document_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="running",
|
||||
)
|
||||
auto_label_started = True
|
||||
|
||||
return DocumentUploadResponse(
|
||||
document_id=document_id,
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
page_count=page_count,
|
||||
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
|
||||
auto_label_started=auto_label_started,
|
||||
message="Document uploaded successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=DocumentListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="List documents",
|
||||
description="List all documents for the current admin.",
|
||||
)
|
||||
async def list_documents(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status"),
|
||||
] = None,
|
||||
upload_source: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by upload source (ui or api)"),
|
||||
] = None,
|
||||
has_annotations: Annotated[
|
||||
bool | None,
|
||||
Query(description="Filter by annotation presence"),
|
||||
] = None,
|
||||
auto_label_status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by auto-label status"),
|
||||
] = None,
|
||||
batch_id: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by batch ID"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> DocumentListResponse:
|
||||
"""List documents."""
|
||||
# Validate status
|
||||
if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}",
|
||||
)
|
||||
|
||||
# Validate upload_source
|
||||
if upload_source and upload_source not in ("ui", "api"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid upload_source: {upload_source}",
|
||||
)
|
||||
|
||||
# Validate auto_label_status
|
||||
if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid auto_label_status: {auto_label_status}",
|
||||
)
|
||||
|
||||
documents, total = db.get_documents_by_token(
|
||||
admin_token=admin_token,
|
||||
status=status,
|
||||
upload_source=upload_source,
|
||||
has_annotations=has_annotations,
|
||||
auto_label_status=auto_label_status,
|
||||
batch_id=batch_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Get annotation counts and build items
|
||||
items = []
|
||||
for doc in documents:
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
|
||||
# Determine if document can be annotated (not locked)
|
||||
can_annotate = True
|
||||
if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
|
||||
from datetime import datetime, timezone
|
||||
can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)
|
||||
|
||||
items.append(
|
||||
DocumentItem(
|
||||
document_id=str(doc.document_id),
|
||||
filename=doc.filename,
|
||||
file_size=doc.file_size,
|
||||
page_count=doc.page_count,
|
||||
status=DocumentStatus(doc.status),
|
||||
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
|
||||
annotation_count=len(annotations),
|
||||
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
|
||||
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
|
||||
can_annotate=can_annotate,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
return DocumentListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
documents=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
response_model=DocumentStatsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get document statistics",
|
||||
description="Get document count by status.",
|
||||
)
|
||||
async def get_document_stats(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> DocumentStatsResponse:
|
||||
"""Get document statistics."""
|
||||
counts = db.count_documents_by_status(admin_token)
|
||||
|
||||
return DocumentStatsResponse(
|
||||
total=sum(counts.values()),
|
||||
pending=counts.get("pending", 0),
|
||||
auto_labeling=counts.get("auto_labeling", 0),
|
||||
labeled=counts.get("labeled", 0),
|
||||
exported=counts.get("exported", 0),
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/{document_id}",
|
||||
response_model=DocumentDetailResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Get document detail",
|
||||
description="Get document details with annotations.",
|
||||
)
|
||||
async def get_document(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> DocumentDetailResponse:
|
||||
"""Get document details."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get annotations
|
||||
raw_annotations = db.get_annotations_for_document(document_id)
|
||||
annotations = [
|
||||
AnnotationItem(
|
||||
annotation_id=str(ann.annotation_id),
|
||||
page_number=ann.page_number,
|
||||
class_id=ann.class_id,
|
||||
class_name=ann.class_name,
|
||||
bbox=BoundingBox(
|
||||
x=ann.bbox_x,
|
||||
y=ann.bbox_y,
|
||||
width=ann.bbox_width,
|
||||
height=ann.bbox_height,
|
||||
),
|
||||
normalized_bbox={
|
||||
"x_center": ann.x_center,
|
||||
"y_center": ann.y_center,
|
||||
"width": ann.width,
|
||||
"height": ann.height,
|
||||
},
|
||||
text_value=ann.text_value,
|
||||
confidence=ann.confidence,
|
||||
source=AnnotationSource(ann.source),
|
||||
created_at=ann.created_at,
|
||||
)
|
||||
for ann in raw_annotations
|
||||
]
|
||||
|
||||
# Generate image URLs
|
||||
image_urls = []
|
||||
for page in range(1, document.page_count + 1):
|
||||
image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")
|
||||
|
||||
# Determine if document can be annotated (not locked)
|
||||
can_annotate = True
|
||||
annotation_lock_until = None
|
||||
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
|
||||
from datetime import datetime, timezone
|
||||
annotation_lock_until = document.annotation_lock_until
|
||||
can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)
|
||||
|
||||
# Get CSV field values if available
|
||||
csv_field_values = None
|
||||
if hasattr(document, 'csv_field_values') and document.csv_field_values:
|
||||
csv_field_values = document.csv_field_values
|
||||
|
||||
# Get training history (Phase 5)
|
||||
training_history = []
|
||||
training_links = db.get_document_training_tasks(document.document_id)
|
||||
for link in training_links:
|
||||
# Get task details
|
||||
task = db.get_training_task(str(link.task_id))
|
||||
if task:
|
||||
# Build metrics
|
||||
metrics = None
|
||||
if task.metrics_mAP or task.metrics_precision or task.metrics_recall:
|
||||
metrics = ModelMetrics(
|
||||
mAP=task.metrics_mAP,
|
||||
precision=task.metrics_precision,
|
||||
recall=task.metrics_recall,
|
||||
)
|
||||
|
||||
training_history.append(
|
||||
TrainingHistoryItem(
|
||||
task_id=str(link.task_id),
|
||||
name=task.name,
|
||||
trained_at=link.created_at,
|
||||
model_metrics=metrics,
|
||||
)
|
||||
)
|
||||
|
||||
return DocumentDetailResponse(
|
||||
document_id=str(document.document_id),
|
||||
filename=document.filename,
|
||||
file_size=document.file_size,
|
||||
content_type=document.content_type,
|
||||
page_count=document.page_count,
|
||||
status=DocumentStatus(document.status),
|
||||
auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
|
||||
auto_label_error=document.auto_label_error,
|
||||
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
|
||||
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
|
||||
csv_field_values=csv_field_values,
|
||||
can_annotate=can_annotate,
|
||||
annotation_lock_until=annotation_lock_until,
|
||||
annotations=annotations,
|
||||
image_urls=image_urls,
|
||||
training_history=training_history,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Delete document",
|
||||
description="Delete a document and its annotations.",
|
||||
)
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> dict:
|
||||
"""Delete a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Delete file
|
||||
file_path = Path(document.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
        # Delete images (this module has no ADMIN_IMAGES_DIR; use the storage config)
        images_dir = storage_config.admin_images_dir / document_id
        if images_dir.exists():
            import shutil
            shutil.rmtree(images_dir)
|
||||
|
||||
# Delete from database
|
||||
db.delete_document(document_id)
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"document_id": document_id,
|
||||
"message": "Document deleted successfully",
|
||||
}
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/status",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Update document status",
|
||||
description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
|
||||
)
|
||||
async def update_document_status(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str,
|
||||
Query(description="New status"),
|
||||
],
|
||||
) -> dict:
|
||||
"""Update document status.
|
||||
|
||||
When status is set to 'labeled', the annotations are automatically
|
||||
saved to PostgreSQL documents/field_results tables for consistency
|
||||
with CLI auto-label workflow.
|
||||
"""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate status
|
||||
if status not in ("pending", "labeled", "exported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}",
|
||||
)
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# If marking as labeled, save annotations to PostgreSQL DocumentDB
|
||||
db_save_result = None
|
||||
if status == "labeled":
|
||||
from inference.web.services.db_autolabel import save_manual_annotations_to_document_db
|
||||
|
||||
# Get all annotations for this document
|
||||
annotations = db.get_annotations_for_document(document_id)
|
||||
|
||||
if annotations:
|
||||
db_save_result = save_manual_annotations_to_document_db(
|
||||
document=document,
|
||||
annotations=annotations,
|
||||
db=db,
|
||||
)
|
||||
|
||||
db.update_document_status(document_id, status)
|
||||
|
||||
response = {
|
||||
"status": "updated",
|
||||
"document_id": document_id,
|
||||
"new_status": status,
|
||||
"message": "Document status updated",
|
||||
}
|
||||
|
||||
# Include PostgreSQL save result if applicable
|
||||
if db_save_result:
|
||||
response["document_db_saved"] = db_save_result.get("success", False)
|
||||
response["fields_saved"] = db_save_result.get("fields_saved", 0)
|
||||
|
||||
return response
|
||||
|
||||
return router
|
||||
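The two endpoints above are easiest to read end to end from the client side. A minimal sketch with httpx, assuming the service runs on localhost:8000, the router is mounted under /api/v1/admin/documents, and the admin token travels in an X-Admin-Token header; the AdminTokenDep wiring is not part of this diff, so the base URL and header name are assumptions.

# Hypothetical client sketch; URL prefix and X-Admin-Token header are assumptions.
import httpx

BASE = "http://localhost:8000/api/v1/admin/documents"
HEADERS = {"X-Admin-Token": "my-admin-token"}  # placeholder token

doc_id = "11111111-2222-3333-4444-555555555555"

with httpx.Client(headers=HEADERS) as client:
    # Mark the document as labeled; annotations get persisted to PostgreSQL.
    r = client.patch(f"{BASE}/{doc_id}/status", params={"status": "labeled"})
    r.raise_for_status()
    print(r.json())  # includes document_db_saved / fields_saved when annotations exist

    # Delete the document, its page images, and its database record.
    r = client.delete(f"{BASE}/{doc_id}")
    r.raise_for_status()
    print(r.json())  # {"status": "deleted", ...}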
184
packages/inference/inference/web/api/v1/admin/locks.py
Normal file
184
packages/inference/inference/web/api/v1/admin/locks.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Admin Document Lock Routes
|
||||
|
||||
FastAPI endpoints for annotation lock management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationLockRequest,
|
||||
AnnotationLockResponse,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def create_locks_router() -> APIRouter:
|
||||
"""Create annotation locks router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/lock",
|
||||
response_model=AnnotationLockResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
409: {"model": ErrorResponse, "description": "Document already locked"},
|
||||
},
|
||||
summary="Acquire annotation lock",
|
||||
description="Acquire a lock on a document to prevent concurrent annotation edits.",
|
||||
)
|
||||
async def acquire_lock(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
request: AnnotationLockRequest = AnnotationLockRequest(),
|
||||
) -> AnnotationLockResponse:
|
||||
"""Acquire annotation lock for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Attempt to acquire lock
|
||||
updated_doc = db.acquire_annotation_lock(
|
||||
document_id=document_id,
|
||||
admin_token=admin_token,
|
||||
duration_seconds=request.duration_seconds,
|
||||
)
|
||||
|
||||
if updated_doc is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Document is already locked. Please try again later.",
|
||||
)
|
||||
|
||||
return AnnotationLockResponse(
|
||||
document_id=document_id,
|
||||
locked=True,
|
||||
lock_expires_at=updated_doc.annotation_lock_until,
|
||||
message=f"Lock acquired for {request.duration_seconds} seconds",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/lock",
|
||||
response_model=AnnotationLockResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Release annotation lock",
|
||||
description="Release the annotation lock on a document.",
|
||||
)
|
||||
async def release_lock(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query(description="Force release (admin override)"),
|
||||
] = False,
|
||||
) -> AnnotationLockResponse:
|
||||
"""Release annotation lock for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Release lock
|
||||
updated_doc = db.release_annotation_lock(
|
||||
document_id=document_id,
|
||||
admin_token=admin_token,
|
||||
force=force,
|
||||
)
|
||||
|
||||
if updated_doc is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Failed to release lock",
|
||||
)
|
||||
|
||||
return AnnotationLockResponse(
|
||||
document_id=document_id,
|
||||
locked=False,
|
||||
lock_expires_at=None,
|
||||
message="Lock released successfully",
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/lock",
|
||||
response_model=AnnotationLockResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
|
||||
},
|
||||
summary="Extend annotation lock",
|
||||
description="Extend an existing annotation lock.",
|
||||
)
|
||||
async def extend_lock(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
request: AnnotationLockRequest = AnnotationLockRequest(),
|
||||
) -> AnnotationLockResponse:
|
||||
"""Extend annotation lock for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Attempt to extend lock
|
||||
updated_doc = db.extend_annotation_lock(
|
||||
document_id=document_id,
|
||||
admin_token=admin_token,
|
||||
additional_seconds=request.duration_seconds,
|
||||
)
|
||||
|
||||
if updated_doc is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Lock doesn't exist or has expired. Please acquire a new lock.",
|
||||
)
|
||||
|
||||
return AnnotationLockResponse(
|
||||
document_id=document_id,
|
||||
locked=True,
|
||||
lock_expires_at=updated_doc.annotation_lock_until,
|
||||
message=f"Lock extended by {request.duration_seconds} seconds",
|
||||
)
|
||||
|
||||
return router
|
||||
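The lock endpoints lean entirely on the AdminDB methods (acquire_annotation_lock, extend_annotation_lock, release_annotation_lock), whose implementation is not in this diff. Judging from the 409 responses and the annotation_lock_until field, the intended semantics is an expiry-based compare-and-set; below is a minimal sketch of that idea only, with assumed table and column names, not the actual AdminDB method.

# Sketch of the expiry-based compare-and-set the endpoints imply.
# Table/column names (admin_documents, annotation_locked_by, annotation_lock_until)
# are assumptions; only annotation_lock_until appears in this diff.
from datetime import datetime, timedelta


def acquire_annotation_lock(conn, document_id: str, admin_token: str,
                            duration_seconds: int):
    """Return the new expiry time, or None if another token holds a live lock."""
    expires = datetime.utcnow() + timedelta(seconds=duration_seconds)
    with conn.cursor() as cur:
        cur.execute(
            """
            UPDATE admin_documents
               SET annotation_locked_by = %s,
                   annotation_lock_until = %s
             WHERE document_id = %s
               AND (annotation_lock_until IS NULL
                    OR annotation_lock_until < NOW()
                    OR annotation_locked_by = %s)
            RETURNING annotation_lock_until
            """,
            (admin_token, expires, document_id, admin_token),
        )
        row = cur.fetchone()
    conn.commit()
    return row[0] if row else None

Because the UPDATE only matches rows whose lock is absent, expired, or already held by the caller, two concurrent acquirers cannot both succeed; the router above simply maps a None return to the 409 response.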
@@ -0,0 +1,28 @@
"""
Admin Training API Routes

FastAPI endpoints for training task management and scheduling.
"""

from fastapi import APIRouter

from ._utils import _validate_uuid
from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes


def create_training_router() -> APIRouter:
    """Create training API router."""
    router = APIRouter(prefix="/admin/training", tags=["Admin Training"])

    register_task_routes(router)
    register_document_routes(router)
    register_export_routes(router)
    register_dataset_routes(router)

    return router


__all__ = ["create_training_router", "_validate_uuid"]
@@ -0,0 +1,16 @@
"""Shared utilities for training routes."""

from uuid import UUID

from fastapi import HTTPException


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
    try:
        UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {name} format. Must be a valid UUID.",
        )
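Since _validate_uuid is shared by every training route, a tiny test pins down its contract: no return on valid input, HTTPException with status 400 otherwise. A sketch assuming pytest is available in the dev environment.

import pytest
from fastapi import HTTPException

from inference.web.api.v1.admin.training._utils import _validate_uuid


def test_valid_uuid_passes():
    # Should return None and raise nothing.
    _validate_uuid("123e4567-e89b-12d3-a456-426614174000")


def test_invalid_uuid_raises_400():
    with pytest.raises(HTTPException) as exc:
        _validate_uuid("not-a-uuid", "task_id")
    assert exc.value.status_code == 400
    assert "task_id" in exc.value.detail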
@@ -0,0 +1,209 @@
|
||||
"""Training Dataset Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
DatasetCreateRequest,
|
||||
DatasetDetailResponse,
|
||||
DatasetDocumentItem,
|
||||
DatasetListItem,
|
||||
DatasetListResponse,
|
||||
DatasetResponse,
|
||||
DatasetTrainRequest,
|
||||
TrainingStatus,
|
||||
TrainingTaskResponse,
|
||||
)
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_dataset_routes(router: APIRouter) -> None:
|
||||
"""Register dataset endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/datasets",
|
||||
response_model=DatasetResponse,
|
||||
summary="Create training dataset",
|
||||
description="Create a dataset from selected documents with train/val/test splits.",
|
||||
)
|
||||
async def create_dataset(
|
||||
request: DatasetCreateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> DatasetResponse:
|
||||
"""Create a training dataset from document IDs."""
|
||||
from pathlib import Path
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
dataset = db.create_dataset(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
train_ratio=request.train_ratio,
|
||||
val_ratio=request.val_ratio,
|
||||
seed=request.seed,
|
||||
)
|
||||
|
||||
builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
|
||||
try:
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=request.document_ids,
|
||||
train_ratio=request.train_ratio,
|
||||
val_ratio=request.val_ratio,
|
||||
seed=request.seed,
|
||||
admin_images_dir=Path("data/admin_images"),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return DatasetResponse(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
name=dataset.name,
|
||||
status="ready",
|
||||
message="Dataset created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/datasets",
|
||||
response_model=DatasetListResponse,
|
||||
summary="List datasets",
|
||||
)
|
||||
async def list_datasets(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> DatasetListResponse:
|
||||
"""List training datasets."""
|
||||
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
|
||||
return DatasetListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
datasets=[
|
||||
DatasetListItem(
|
||||
dataset_id=str(d.dataset_id),
|
||||
name=d.name,
|
||||
description=d.description,
|
||||
status=d.status,
|
||||
total_documents=d.total_documents,
|
||||
total_images=d.total_images,
|
||||
total_annotations=d.total_annotations,
|
||||
created_at=d.created_at,
|
||||
)
|
||||
for d in datasets
|
||||
],
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/datasets/{dataset_id}",
|
||||
response_model=DatasetDetailResponse,
|
||||
summary="Get dataset detail",
|
||||
)
|
||||
async def get_dataset(
|
||||
dataset_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> DatasetDetailResponse:
|
||||
"""Get dataset details with document list."""
|
||||
_validate_uuid(dataset_id, "dataset_id")
|
||||
dataset = db.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
docs = db.get_dataset_documents(dataset_id)
|
||||
return DatasetDetailResponse(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
status=dataset.status,
|
||||
train_ratio=dataset.train_ratio,
|
||||
val_ratio=dataset.val_ratio,
|
||||
seed=dataset.seed,
|
||||
total_documents=dataset.total_documents,
|
||||
total_images=dataset.total_images,
|
||||
total_annotations=dataset.total_annotations,
|
||||
dataset_path=dataset.dataset_path,
|
||||
error_message=dataset.error_message,
|
||||
documents=[
|
||||
DatasetDocumentItem(
|
||||
document_id=str(d.document_id),
|
||||
split=d.split,
|
||||
page_count=d.page_count,
|
||||
annotation_count=d.annotation_count,
|
||||
)
|
||||
for d in docs
|
||||
],
|
||||
created_at=dataset.created_at,
|
||||
updated_at=dataset.updated_at,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/datasets/{dataset_id}",
|
||||
summary="Delete dataset",
|
||||
)
|
||||
async def delete_dataset(
|
||||
dataset_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> dict:
|
||||
"""Delete a dataset and its files."""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
_validate_uuid(dataset_id, "dataset_id")
|
||||
dataset = db.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
if dataset.dataset_path:
|
||||
dataset_dir = Path(dataset.dataset_path)
|
||||
if dataset_dir.exists():
|
||||
shutil.rmtree(dataset_dir)
|
||||
|
||||
db.delete_dataset(dataset_id)
|
||||
return {"message": "Dataset deleted"}
|
||||
|
||||
@router.post(
|
||||
"/datasets/{dataset_id}/train",
|
||||
response_model=TrainingTaskResponse,
|
||||
summary="Start training from dataset",
|
||||
)
|
||||
async def train_from_dataset(
|
||||
dataset_id: str,
|
||||
request: DatasetTrainRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a training task from a dataset."""
|
||||
_validate_uuid(dataset_id, "dataset_id")
|
||||
dataset = db.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
if dataset.status != "ready":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Dataset is not ready (status: {dataset.status})",
|
||||
)
|
||||
|
||||
config_dict = request.config.model_dump()
|
||||
task_id = db.create_training_task(
|
||||
admin_token=admin_token,
|
||||
name=request.name,
|
||||
task_type="train",
|
||||
config=config_dict,
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.PENDING,
|
||||
message="Training task created from dataset",
|
||||
)
|
||||
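Put together, the dataset flow is create-then-train: POST /datasets with document IDs and split ratios, then POST /datasets/{id}/train with a training config. A minimal client sketch, assuming the /api/v1/admin/training mount point and an X-Admin-Token header, neither of which is pinned down in this file; the config fields are likewise an assumption about the TrainingConfig schema.

# Hypothetical client sketch; URL prefix, header name, and config fields are assumptions.
import httpx

BASE = "http://localhost:8000/api/v1/admin/training"
HEADERS = {"X-Admin-Token": "my-admin-token"}

with httpx.Client(headers=HEADERS, timeout=300.0) as client:
    created = client.post(f"{BASE}/datasets", json={
        "name": "invoices-2024-q1",
        "description": "Labeled invoices, first quarter",
        "document_ids": ["<doc-uuid-1>", "<doc-uuid-2>"],
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
    }).json()

    task = client.post(f"{BASE}/datasets/{created['dataset_id']}/train", json={
        "name": "yolo-invoices-v1",
        "config": {"epochs": 50},  # assumed field; depends on the TrainingConfig schema
    }).json()
    print(task["task_id"], task["status"])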
@@ -0,0 +1,211 @@
|
||||
"""Training Documents and Models Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
ModelMetrics,
|
||||
TrainingDocumentItem,
|
||||
TrainingDocumentsResponse,
|
||||
TrainingModelItem,
|
||||
TrainingModelsResponse,
|
||||
TrainingStatus,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_document_routes(router: APIRouter) -> None:
|
||||
"""Register training document and model endpoints on the router."""
|
||||
|
||||
@router.get(
|
||||
"/documents",
|
||||
response_model=TrainingDocumentsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get documents for training",
|
||||
description="Get labeled documents available for training with filtering options.",
|
||||
)
|
||||
async def get_training_documents(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
has_annotations: Annotated[
|
||||
bool,
|
||||
Query(description="Only include documents with annotations"),
|
||||
] = True,
|
||||
min_annotation_count: Annotated[
|
||||
int | None,
|
||||
Query(ge=1, description="Minimum annotation count"),
|
||||
] = None,
|
||||
exclude_used_in_training: Annotated[
|
||||
bool,
|
||||
Query(description="Exclude documents already used in training"),
|
||||
] = False,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingDocumentsResponse:
|
||||
"""Get documents available for training."""
|
||||
documents, total = db.get_documents_for_training(
|
||||
admin_token=admin_token,
|
||||
status="labeled",
|
||||
has_annotations=has_annotations,
|
||||
min_annotation_count=min_annotation_count,
|
||||
exclude_used_in_training=exclude_used_in_training,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
items = []
|
||||
for doc in documents:
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
|
||||
sources = {"manual": 0, "auto": 0}
|
||||
for ann in annotations:
|
||||
if ann.source in sources:
|
||||
sources[ann.source] += 1
|
||||
|
||||
training_links = db.get_document_training_tasks(doc.document_id)
|
||||
used_in_training = [str(link.task_id) for link in training_links]
|
||||
|
||||
items.append(
|
||||
TrainingDocumentItem(
|
||||
document_id=str(doc.document_id),
|
||||
filename=doc.filename,
|
||||
annotation_count=len(annotations),
|
||||
annotation_sources=sources,
|
||||
used_in_training=used_in_training,
|
||||
last_modified=doc.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
return TrainingDocumentsResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
documents=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/{task_id}/download",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Model not found"},
|
||||
},
|
||||
summary="Download trained model",
|
||||
description="Download trained model weights file.",
|
||||
)
|
||||
async def download_model(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
):
|
||||
"""Download trained model."""
|
||||
from fastapi.responses import FileResponse
|
||||
from pathlib import Path
|
||||
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
if not task.model_path:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model file not available for this task",
|
||||
)
|
||||
|
||||
model_path = Path(task.model_path)
|
||||
if not model_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model file not found on disk",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(model_path),
|
||||
media_type="application/octet-stream",
|
||||
filename=f"{task.name}_model.pt",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
response_model=TrainingModelsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get trained models",
|
||||
description="Get list of trained models with metrics and download links.",
|
||||
)
|
||||
async def get_training_models(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status (completed, failed, etc.)"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingModelsResponse:
|
||||
"""Get list of trained models."""
|
||||
tasks, total = db.get_training_tasks_by_token(
|
||||
admin_token=admin_token,
|
||||
status=status if status else "completed",
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
items = []
|
||||
for task in tasks:
|
||||
metrics = ModelMetrics(
|
||||
mAP=task.metrics_mAP,
|
||||
precision=task.metrics_precision,
|
||||
recall=task.metrics_recall,
|
||||
)
|
||||
|
||||
download_url = None
|
||||
if task.model_path and task.status == "completed":
|
||||
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
|
||||
|
||||
items.append(
|
||||
TrainingModelItem(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
status=TrainingStatus(task.status),
|
||||
document_count=task.document_count,
|
||||
created_at=task.created_at,
|
||||
completed_at=task.completed_at,
|
||||
metrics=metrics,
|
||||
model_path=task.model_path,
|
||||
download_url=download_url,
|
||||
)
|
||||
)
|
||||
|
||||
return TrainingModelsResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
models=items,
|
||||
)
|
||||
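The download endpoint streams the weights file back as application/octet-stream, and the models listing already embeds the matching download_url for completed tasks. A short sketch of fetching a finished model to disk, with the same assumed host and auth header as the earlier client sketches.

# Hypothetical client sketch; host and X-Admin-Token header are assumptions.
import httpx

HOST = "http://localhost:8000"
HEADERS = {"X-Admin-Token": "my-admin-token"}

with httpx.Client(headers=HEADERS) as client:
    models = client.get(f"{HOST}/api/v1/admin/training/models", params={"limit": 5}).json()
    for m in models["models"]:
        if not m["download_url"]:
            continue  # only completed tasks with a model_path expose a URL
        with client.stream("GET", f"{HOST}{m['download_url']}") as resp:
            resp.raise_for_status()
            out_name = f"{m['name']}_model.pt"
            with open(out_name, "wb") as fh:
                for chunk in resp.iter_bytes():
                    fh.write(chunk)
        print("saved", out_name)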
121
packages/inference/inference/web/api/v1/admin/training/export.py
Normal file
121
packages/inference/inference/web/api/v1/admin/training/export.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Training Export Endpoints."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
ExportRequest,
|
||||
ExportResponse,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_export_routes(router: APIRouter) -> None:
|
||||
"""Register export endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/export",
|
||||
response_model=ExportResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Export annotations",
|
||||
description="Export annotations in YOLO format for training.",
|
||||
)
|
||||
async def export_annotations(
|
||||
request: ExportRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ExportResponse:
|
||||
"""Export annotations for training."""
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
if request.format not in ("yolo", "coco", "voc"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported export format: {request.format}",
|
||||
)
|
||||
|
||||
documents = db.get_labeled_documents_for_export(admin_token)
|
||||
|
||||
if not documents:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No labeled documents available for export",
|
||||
)
|
||||
|
||||
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_docs = len(documents)
|
||||
train_count = int(total_docs * request.split_ratio)
|
||||
train_docs = documents[:train_count]
|
||||
val_docs = documents[train_count:]
|
||||
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
|
||||
for split, docs in [("train", train_docs), ("val", val_docs)]:
|
||||
for doc in docs:
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
|
||||
if not annotations:
|
||||
continue
|
||||
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
page_annotations = [a for a in annotations if a.page_number == page_num]
|
||||
|
||||
if not page_annotations and not request.include_images:
|
||||
continue
|
||||
|
||||
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
|
||||
if not src_image.exists():
|
||||
continue
|
||||
|
||||
image_name = f"{doc.document_id}_page{page_num}.png"
|
||||
dst_image = export_dir / "images" / split / image_name
|
||||
shutil.copy(src_image, dst_image)
|
||||
total_images += 1
|
||||
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
from inference.data.admin_models import FIELD_CLASSES
|
||||
|
||||
yaml_content = f"""# Auto-generated YOLO dataset config
|
||||
path: {export_dir.absolute()}
|
||||
train: images/train
|
||||
val: images/val
|
||||
|
||||
nc: {len(FIELD_CLASSES)}
|
||||
names: {list(FIELD_CLASSES.values())}
|
||||
"""
|
||||
(export_dir / "data.yaml").write_text(yaml_content)
|
||||
|
||||
return ExportResponse(
|
||||
status="completed",
|
||||
export_path=str(export_dir),
|
||||
total_images=total_images,
|
||||
total_annotations=total_annotations,
|
||||
train_count=len(train_docs),
|
||||
val_count=len(val_docs),
|
||||
message=f"Exported {total_images} images with {total_annotations} annotations",
|
||||
)
|
||||
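The label lines written above assume annotations are already stored in normalized YOLO form: class id plus x_center, y_center, width, height as fractions of the image size. When annotations arrive as pixel boxes instead, the conversion is one division per coordinate; a small sketch, where the pixel-box input format is an assumption rather than anything shown in this diff.

def to_yolo_line(class_id: int, x_min: float, y_min: float,
                 x_max: float, y_max: float,
                 img_w: int, img_h: int) -> str:
    """Convert a pixel-space box to one normalized YOLO label line."""
    x_center = (x_min + x_max) / 2.0 / img_w
    y_center = (y_min + y_max) / 2.0 / img_h
    width = (x_max - x_min) / img_w
    height = (y_max - y_min) / img_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


# Example: a 200x50 px box at (100, 400) on a 1240x1754 page image.
print(to_yolo_line(2, 100, 400, 300, 450, 1240, 1754))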
263
packages/inference/inference/web/api/v1/admin/training/tasks.py
Normal file
263
packages/inference/inference/web/api/v1/admin/training/tasks.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""Training Task Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
TrainingLogItem,
|
||||
TrainingLogsResponse,
|
||||
TrainingStatus,
|
||||
TrainingTaskCreate,
|
||||
TrainingTaskDetailResponse,
|
||||
TrainingTaskItem,
|
||||
TrainingTaskListResponse,
|
||||
TrainingTaskResponse,
|
||||
TrainingType,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_task_routes(router: APIRouter) -> None:
|
||||
"""Register training task endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/tasks",
|
||||
response_model=TrainingTaskResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Create training task",
|
||||
description="Create a new training task.",
|
||||
)
|
||||
async def create_training_task(
|
||||
request: TrainingTaskCreate,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a new training task."""
|
||||
config_dict = request.config.model_dump() if request.config else {}
|
||||
|
||||
task_id = db.create_training_task(
|
||||
admin_token=admin_token,
|
||||
name=request.name,
|
||||
task_type=request.task_type.value,
|
||||
description=request.description,
|
||||
config=config_dict,
|
||||
scheduled_at=request.scheduled_at,
|
||||
cron_expression=request.cron_expression,
|
||||
is_recurring=bool(request.cron_expression),
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
|
||||
message="Training task created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks",
|
||||
response_model=TrainingTaskListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="List training tasks",
|
||||
description="List all training tasks.",
|
||||
)
|
||||
async def list_training_tasks(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingTaskListResponse:
|
||||
"""List training tasks."""
|
||||
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
|
||||
if status and status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
tasks, total = db.get_training_tasks_by_token(
|
||||
admin_token=admin_token,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
items = [
|
||||
TrainingTaskItem(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
task_type=TrainingType(task.task_type),
|
||||
status=TrainingStatus(task.status),
|
||||
scheduled_at=task.scheduled_at,
|
||||
is_recurring=task.is_recurring,
|
||||
started_at=task.started_at,
|
||||
completed_at=task.completed_at,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
return TrainingTaskListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
tasks=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
response_model=TrainingTaskDetailResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
},
|
||||
summary="Get training task detail",
|
||||
description="Get training task details.",
|
||||
)
|
||||
async def get_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskDetailResponse:
|
||||
"""Get training task details."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
return TrainingTaskDetailResponse(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
description=task.description,
|
||||
task_type=TrainingType(task.task_type),
|
||||
status=TrainingStatus(task.status),
|
||||
config=task.config,
|
||||
scheduled_at=task.scheduled_at,
|
||||
cron_expression=task.cron_expression,
|
||||
is_recurring=task.is_recurring,
|
||||
started_at=task.started_at,
|
||||
completed_at=task.completed_at,
|
||||
error_message=task.error_message,
|
||||
result_metrics=task.result_metrics,
|
||||
model_path=task.model_path,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/tasks/{task_id}/cancel",
|
||||
response_model=TrainingTaskResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
|
||||
},
|
||||
summary="Cancel training task",
|
||||
description="Cancel a pending or scheduled training task.",
|
||||
)
|
||||
async def cancel_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Cancel a training task."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
if task.status not in ("pending", "scheduled"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Cannot cancel task with status: {task.status}",
|
||||
)
|
||||
|
||||
success = db.cancel_training_task(task_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to cancel training task",
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.CANCELLED,
|
||||
message="Training task cancelled successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/logs",
|
||||
response_model=TrainingLogsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
},
|
||||
summary="Get training logs",
|
||||
description="Get training task logs.",
|
||||
)
|
||||
async def get_training_logs(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=500, description="Maximum logs to return"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingLogsResponse:
|
||||
"""Get training logs."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
logs = db.get_training_logs(task_id, limit, offset)
|
||||
|
||||
items = [
|
||||
TrainingLogItem(
|
||||
level=log.level,
|
||||
message=log.message,
|
||||
details=log.details,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
|
||||
return TrainingLogsResponse(
|
||||
task_id=task_id,
|
||||
logs=items,
|
||||
)
|
||||
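Scheduling is driven entirely by the request payload: a bare task runs as soon as a worker picks it up, scheduled_at defers it once, and cron_expression makes it recurring (is_recurring is derived from it server-side, not sent by the client). A sketch of the three payload shapes, using only the TrainingTaskCreate fields referenced in this file; the config contents are an assumption.

# Payload sketches only; the exact TrainingConfig fields are an assumption.
run_now = {
    "name": "retrain-invoices",
    "task_type": "train",
    "config": {"epochs": 50},
}

run_once_later = {
    "name": "nightly-eval",
    "task_type": "train",
    "scheduled_at": "2024-06-01T02:00:00Z",
}

recurring = {
    "name": "weekly-retrain",
    "task_type": "train",
    "cron_expression": "0 3 * * 1",  # 03:00 every Monday; server sets is_recurring=True
}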
236
packages/inference/inference/web/api/v1/batch/routes.py
Normal file
236
packages/inference/inference/web/api/v1/batch/routes.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
Batch Upload API Routes
|
||||
|
||||
Endpoints for batch uploading documents via ZIP files with CSV metadata.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
|
||||
from inference.web.workers.batch_queue import BatchTask, get_batch_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_batch(
|
||||
file: UploadFile = File(...),
|
||||
upload_source: str = Form(default="ui"),
|
||||
async_mode: bool = Form(default=True),
|
||||
auto_label: bool = Form(default=True),
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
) -> dict:
|
||||
"""Upload a batch of documents via ZIP file.
|
||||
|
||||
The ZIP file can contain:
|
||||
- Multiple PDF files
|
||||
- Optional CSV file with field values for auto-labeling
|
||||
|
||||
CSV format:
|
||||
- Required column: DocumentId (matches PDF filename without extension)
|
||||
- Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
|
||||
OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number
|
||||
|
||||
Args:
|
||||
file: ZIP file upload
|
||||
upload_source: Upload source (ui or api)
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
|
||||
Returns:
|
||||
Batch upload result with batch_id and status
|
||||
"""
|
||||
if not file.filename.lower().endswith('.zip'):
|
||||
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
|
||||
|
||||
# Check compressed size
|
||||
if file.size and file.size > MAX_COMPRESSED_SIZE:
|
||||
max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File size exceeds {max_mb:.0f}MB limit"
|
||||
)
|
||||
|
||||
try:
|
||||
# Read file content
|
||||
zip_content = await file.read()
|
||||
|
||||
# Additional security validation before processing
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
|
||||
# Quick validation of ZIP structure
|
||||
test_zip.testzip()
|
||||
except zipfile.BadZipFile:
|
||||
raise HTTPException(status_code=400, detail="Invalid ZIP file format")
|
||||
|
||||
if async_mode:
|
||||
# Async mode: Queue task and return immediately
|
||||
from uuid import uuid4
|
||||
|
||||
batch_id = uuid4()
|
||||
|
||||
# Create batch task for background processing
|
||||
task = BatchTask(
|
||||
batch_id=batch_id,
|
||||
admin_token=admin_token,
|
||||
zip_content=zip_content,
|
||||
zip_filename=file.filename,
|
||||
upload_source=upload_source,
|
||||
auto_label=auto_label,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Submit to queue
|
||||
queue = get_batch_queue()
|
||||
if not queue.submit(task):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Processing queue is full. Please try again later."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch upload queued: batch_id={batch_id}, "
|
||||
f"filename={file.filename}, async_mode=True"
|
||||
)
|
||||
|
||||
# Return 202 Accepted with batch_id and status URL
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={
|
||||
"status": "accepted",
|
||||
"batch_id": str(batch_id),
|
||||
"message": "Batch upload queued for processing",
|
||||
"status_url": f"/api/v1/admin/batch/status/{batch_id}",
|
||||
"queue_depth": queue.get_queue_depth(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Sync mode: Process immediately and return results
|
||||
service = BatchUploadService(admin_db)
|
||||
result = service.process_zip_upload(
|
||||
admin_token=admin_token,
|
||||
zip_filename=file.filename,
|
||||
zip_content=zip_content,
|
||||
upload_source=upload_source,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch upload completed: batch_id={result.get('batch_id')}, "
|
||||
f"status={result.get('status')}, files={result.get('successful_files')}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing batch upload: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to process batch upload. Please contact support."
|
||||
)
|
||||
|
||||
|
||||
@router.get("/status/{batch_id}")
|
||||
async def get_batch_status(
|
||||
batch_id: str,
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
) -> dict:
|
||||
"""Get batch upload status and file processing details.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
|
||||
Returns:
|
||||
Batch status with file processing details
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
batch_uuid = UUID(batch_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid batch ID format")
|
||||
|
||||
# Check batch exists and verify ownership
|
||||
batch = admin_db.get_batch_upload(batch_uuid)
|
||||
if not batch:
|
||||
raise HTTPException(status_code=404, detail="Batch not found")
|
||||
|
||||
# CRITICAL: Verify ownership
|
||||
if batch.admin_token != admin_token:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You do not have access to this batch"
|
||||
)
|
||||
|
||||
# Now safe to return details
|
||||
service = BatchUploadService(admin_db)
|
||||
result = service.get_batch_status(batch_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_batch_uploads(
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> dict:
|
||||
"""List batch uploads for the current admin token.
|
||||
|
||||
Args:
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of batch uploads
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if limit < 1 or limit > 100:
|
||||
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
|
||||
if offset < 0:
|
||||
raise HTTPException(status_code=400, detail="Offset must be non-negative")
|
||||
|
||||
# Get batch uploads filtered by admin token
|
||||
batches, total = admin_db.get_batch_uploads_by_token(
|
||||
admin_token=admin_token,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return {
|
||||
"batches": [
|
||||
{
|
||||
"batch_id": str(b.batch_id),
|
||||
"filename": b.filename,
|
||||
"status": b.status,
|
||||
"total_files": b.total_files,
|
||||
"successful_files": b.successful_files,
|
||||
"failed_files": b.failed_files,
|
||||
"created_at": b.created_at.isoformat() if b.created_at else None,
|
||||
"completed_at": b.completed_at.isoformat() if b.completed_at else None,
|
||||
}
|
||||
for b in batches
|
||||
],
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
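The upload endpoint expects a single ZIP whose PDFs are matched to CSV rows by DocumentId (the PDF filename without extension). A short sketch that builds such an archive in memory and submits it in async mode; the base URL and X-Admin-Token header are assumptions, the column names come from the docstring above, and the CSV filename is not pinned down by this route.

import csv
import io
import zipfile

import httpx

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
    zf.write("invoice_001.pdf")  # DocumentId = invoice_001 (placeholder local files)
    zf.write("invoice_002.pdf")
    rows = [
        {"DocumentId": "invoice_001", "InvoiceNumber": "A-1001", "Amount": "1250.00"},
        {"DocumentId": "invoice_002", "InvoiceNumber": "A-1002", "Amount": "980.50"},
    ]
    csv_buf = io.StringIO()
    writer = csv.DictWriter(csv_buf, fieldnames=["DocumentId", "InvoiceNumber", "Amount"])
    writer.writeheader()
    writer.writerows(rows)
    # "metadata.csv" is a guessed name; the route only requires a CSV somewhere in the ZIP.
    zf.writestr("metadata.csv", csv_buf.getvalue())

resp = httpx.post(
    "http://localhost:8000/api/v1/admin/batch/upload",
    headers={"X-Admin-Token": "my-admin-token"},
    files={"file": ("batch.zip", buf.getvalue(), "application/zip")},
    data={"upload_source": "api", "async_mode": "true", "auto_label": "true"},
)
print(resp.status_code, resp.json())  # 202 with batch_id and status_url in async mode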
16
packages/inference/inference/web/api/v1/public/__init__.py
Normal file
16
packages/inference/inference/web/api/v1/public/__init__.py
Normal file
@@ -0,0 +1,16 @@
"""
Public API v1

Customer-facing endpoints for inference, async processing, and labeling.
"""

from inference.web.api.v1.public.inference import create_inference_router
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.api.v1.public.labeling import create_labeling_router

__all__ = [
    "create_inference_router",
    "create_async_router",
    "set_async_service",
    "create_labeling_router",
]
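These factories are wired together during app startup elsewhere in the restructure. A rough sketch of the intended wiring; the app factory, its arguments, and the create_labeling_router signature are assumptions, only the imported names and set_async_service come from this package.

# Hypothetical wiring sketch; everything except the imported factory functions
# is assumed rather than taken from this diff.
from fastapi import FastAPI

from inference.web.api.v1.public import (
    create_async_router,
    create_inference_router,
    create_labeling_router,
    set_async_service,
)


def build_app(inference_service, storage_config, async_service) -> FastAPI:
    app = FastAPI(title="Invoice Inference API")
    set_async_service(async_service)  # async routes resolve the service via this global
    # create_inference_router already carries the /api/v1 prefix internally.
    app.include_router(create_inference_router(inference_service, storage_config))
    app.include_router(create_async_router(storage_config.allowed_extensions), prefix="/api/v1")
    app.include_router(create_labeling_router(), prefix="/api/v1")  # assumed no-arg factory
    return app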
372
packages/inference/inference/web/api/v1/public/async_api.py
Normal file
372
packages/inference/inference/web/api/v1/public/async_api.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
Async API Routes
|
||||
|
||||
FastAPI endpoints for async invoice processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||
|
||||
from inference.web.dependencies import (
|
||||
ApiKeyDep,
|
||||
AsyncDBDep,
|
||||
PollRateLimitDep,
|
||||
SubmitRateLimitDep,
|
||||
)
|
||||
from inference.web.schemas.inference import (
|
||||
AsyncRequestItem,
|
||||
AsyncRequestsListResponse,
|
||||
AsyncResultResponse,
|
||||
AsyncStatus,
|
||||
AsyncStatusResponse,
|
||||
AsyncSubmitResponse,
|
||||
DetectionResult,
|
||||
InferenceResult,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
|
||||
def _validate_request_id(request_id: str) -> None:
|
||||
"""Validate that request_id is a valid UUID format."""
|
||||
try:
|
||||
UUID(request_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid request ID format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global reference to async processing service (set during app startup)
|
||||
_async_service = None
|
||||
|
||||
|
||||
def set_async_service(service) -> None:
|
||||
"""Set the async processing service instance."""
|
||||
global _async_service
|
||||
_async_service = service
|
||||
|
||||
|
||||
def get_async_service():
|
||||
"""Get the async processing service instance."""
|
||||
if _async_service is None:
|
||||
raise RuntimeError("AsyncProcessingService not initialized")
|
||||
return _async_service
|
||||
|
||||
|
||||
def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
|
||||
"""Create async API router."""
|
||||
router = APIRouter(prefix="/async", tags=["Async Processing"])
|
||||
|
||||
@router.post(
|
||||
"/submit",
|
||||
response_model=AsyncSubmitResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
|
||||
503: {"model": ErrorResponse, "description": "Queue full"},
|
||||
},
|
||||
summary="Submit PDF for async processing",
|
||||
description="Submit a PDF or image file for asynchronous processing. "
|
||||
"Returns a request_id that can be used to poll for results.",
|
||||
)
|
||||
async def submit_document(
|
||||
api_key: SubmitRateLimitDep,
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
) -> AsyncSubmitResponse:
|
||||
"""Submit a document for async processing."""
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required")
|
||||
|
||||
# Validate file extension
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {file_ext}. "
|
||||
f"Allowed: {', '.join(allowed_extensions)}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(status_code=400, detail="Failed to read file")
|
||||
|
||||
# Check file size (get from config via service)
|
||||
service = get_async_service()
|
||||
max_size = service._async_config.max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File too large. Maximum size: "
|
||||
f"{service._async_config.max_file_size_mb}MB",
|
||||
)
|
||||
|
||||
# Submit request
|
||||
result = service.submit_request(
|
||||
api_key=api_key,
|
||||
file_content=content,
|
||||
filename=file.filename,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
if "queue" in (result.error or "").lower():
|
||||
raise HTTPException(status_code=503, detail=result.error)
|
||||
raise HTTPException(status_code=500, detail=result.error)
|
||||
|
||||
return AsyncSubmitResponse(
|
||||
status="accepted",
|
||||
message="Request submitted for processing",
|
||||
request_id=result.request_id,
|
||||
estimated_wait_seconds=result.estimated_wait_seconds,
|
||||
poll_url=f"/api/v1/async/status/{result.request_id}",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/status/{request_id}",
|
||||
response_model=AsyncStatusResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
429: {"model": ErrorResponse, "description": "Polling too frequently"},
|
||||
},
|
||||
summary="Get request status",
|
||||
description="Get the current processing status of an async request.",
|
||||
)
|
||||
async def get_status(
|
||||
request_id: str,
|
||||
api_key: PollRateLimitDep,
|
||||
db: AsyncDBDep,
|
||||
) -> AsyncStatusResponse:
|
||||
"""Get the status of an async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database (validates API key ownership)
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Get queue position for pending requests
|
||||
position = None
|
||||
if request.status == "pending":
|
||||
position = db.get_queue_position(request_id)
|
||||
|
||||
# Build result URL for completed requests
|
||||
result_url = None
|
||||
if request.status == "completed":
|
||||
result_url = f"/api/v1/async/result/{request_id}"
|
||||
|
||||
return AsyncStatusResponse(
|
||||
request_id=str(request.request_id),
|
||||
status=AsyncStatus(request.status),
|
||||
filename=request.filename,
|
||||
created_at=request.created_at,
|
||||
started_at=request.started_at,
|
||||
completed_at=request.completed_at,
|
||||
position_in_queue=position,
|
||||
error_message=request.error_message,
|
||||
result_url=result_url,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/result/{request_id}",
|
||||
response_model=AsyncResultResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
409: {"model": ErrorResponse, "description": "Request not completed"},
|
||||
429: {"model": ErrorResponse, "description": "Polling too frequently"},
|
||||
},
|
||||
summary="Get extraction results",
|
||||
description="Get the extraction results for a completed async request.",
|
||||
)
|
||||
async def get_result(
|
||||
request_id: str,
|
||||
api_key: PollRateLimitDep,
|
||||
db: AsyncDBDep,
|
||||
) -> AsyncResultResponse:
|
||||
"""Get the results of a completed async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database (validates API key ownership)
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Check if completed or failed
|
||||
if request.status not in ("completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Request not yet completed. Current status: {request.status}",
|
||||
)
|
||||
|
||||
# Build inference result from stored data
|
||||
inference_result = None
|
||||
if request.result:
|
||||
# Convert detections to DetectionResult objects
|
||||
detections = []
|
||||
for d in request.result.get("detections", []):
|
||||
detections.append(DetectionResult(
|
||||
field=d.get("field", ""),
|
||||
confidence=d.get("confidence", 0.0),
|
||||
bbox=d.get("bbox", [0, 0, 0, 0]),
|
||||
))
|
||||
|
||||
inference_result = InferenceResult(
|
||||
document_id=request.result.get("document_id", str(request.request_id)[:8]),
|
||||
success=request.result.get("success", False),
|
||||
document_type=request.result.get("document_type", "invoice"),
|
||||
fields=request.result.get("fields", {}),
|
||||
confidence=request.result.get("confidence", {}),
|
||||
detections=detections,
|
||||
processing_time_ms=request.processing_time_ms or 0.0,
|
||||
errors=request.result.get("errors", []),
|
||||
)
|
||||
|
||||
# Build visualization URL
|
||||
viz_url = None
|
||||
if request.visualization_path:
|
||||
viz_url = f"/api/v1/results/{request.visualization_path}"
|
||||
|
||||
return AsyncResultResponse(
|
||||
request_id=str(request.request_id),
|
||||
status=AsyncStatus(request.status),
|
||||
processing_time_ms=request.processing_time_ms or 0.0,
|
||||
result=inference_result,
|
||||
visualization_url=viz_url,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/requests",
|
||||
response_model=AsyncRequestsListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
},
|
||||
summary="List requests",
|
||||
description="List all async requests for the authenticated API key.",
|
||||
)
|
||||
async def list_requests(
|
||||
api_key: ApiKeyDep,
|
||||
db: AsyncDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status (pending, processing, completed, failed)"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Maximum number of results"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Pagination offset"),
|
||||
] = 0,
|
||||
) -> AsyncRequestsListResponse:
|
||||
"""List all requests for the authenticated API key."""
|
||||
# Validate status filter
|
||||
if status and status not in ("pending", "processing", "completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status filter: {status}. "
|
||||
"Must be one of: pending, processing, completed, failed",
|
||||
)
|
||||
|
||||
# Get requests from database
|
||||
requests, total = db.get_requests_by_api_key(
|
||||
api_key=api_key,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Convert to response items
|
||||
items = [
|
||||
AsyncRequestItem(
|
||||
request_id=str(r.request_id),
|
||||
status=AsyncStatus(r.status),
|
||||
filename=r.filename,
|
||||
file_size=r.file_size,
|
||||
created_at=r.created_at,
|
||||
completed_at=r.completed_at,
|
||||
)
|
||||
for r in requests
|
||||
]
|
||||
|
||||
return AsyncRequestsListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
requests=items,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/requests/{request_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
|
||||
},
|
||||
summary="Cancel/delete request",
|
||||
description="Cancel a pending request or delete a completed/failed request.",
|
||||
)
|
||||
async def delete_request(
|
||||
request_id: str,
|
||||
api_key: ApiKeyDep,
|
||||
db: AsyncDBDep,
|
||||
) -> dict:
|
||||
"""Delete or cancel an async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Cannot delete processing requests
|
||||
if request.status == "processing":
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot delete a request that is currently processing",
|
||||
)
|
||||
|
||||
# Delete from database (will cascade delete related records)
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"DELETE FROM async_requests WHERE request_id = %s",
|
||||
(request_id,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"request_id": request_id,
|
||||
"message": "Request deleted successfully",
|
||||
}
|
||||
|
||||
return router
|
||||
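From a client's point of view the async flow is submit, poll /status until the request leaves pending/processing, then fetch /result. A compact polling sketch, assuming API keys are sent as an X-API-Key header; the dependency that resolves ApiKeyDep is not shown in this diff, so the header name and host are assumptions.

# Hypothetical client sketch; host and X-API-Key header are assumptions.
import time

import httpx

HOST = "http://localhost:8000"
HEADERS = {"X-API-Key": "my-api-key"}

with httpx.Client(headers=HEADERS) as client:
    with open("invoice.pdf", "rb") as fh:
        submit = client.post(
            f"{HOST}/api/v1/async/submit",
            files={"file": ("invoice.pdf", fh, "application/pdf")},
        ).json()

    while True:
        status = client.get(f"{HOST}{submit['poll_url']}").json()
        if status["status"] in ("completed", "failed"):
            break
        time.sleep(5)  # keep polling slow enough to respect the poll rate limit

    if status["status"] == "completed":
        result = client.get(f"{HOST}{status['result_url']}").json()
        print(result["result"]["fields"])
    else:
        print("failed:", status["error_message"])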
183
packages/inference/inference/web/api/v1/public/inference.py
Normal file
183
packages/inference/inference/web/api/v1/public/inference.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Inference API Routes
|
||||
|
||||
FastAPI route definitions for the inference API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from inference.web.schemas.inference import (
|
||||
DetectionResult,
|
||||
HealthResponse,
|
||||
InferenceResponse,
|
||||
InferenceResult,
|
||||
)
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.web.services import InferenceService
|
||||
from inference.web.config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_inference_router(
|
||||
inference_service: "InferenceService",
|
||||
storage_config: "StorageConfig",
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Create API router with inference endpoints.
|
||||
|
||||
Args:
|
||||
inference_service: Inference service instance
|
||||
storage_config: Storage configuration
|
||||
|
||||
Returns:
|
||||
Configured APIRouter
|
||||
"""
|
||||
router = APIRouter(prefix="/api/v1", tags=["inference"])
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health_check() -> HealthResponse:
|
||||
"""Check service health status."""
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
model_loaded=inference_service.is_initialized,
|
||||
gpu_available=inference_service.gpu_available,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/infer",
|
||||
response_model=InferenceResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
500: {"model": ErrorResponse, "description": "Processing error"},
|
||||
},
|
||||
)
|
||||
async def infer_document(
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
) -> InferenceResponse:
|
||||
"""
|
||||
Process a document and extract invoice fields.
|
||||
|
||||
Accepts PDF or image files (PNG, JPG, JPEG).
|
||||
Returns extracted field values with confidence scores.
|
||||
"""
|
||||
# Validate file extension
|
||||
if not file.filename:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Filename is required",
|
||||
)
|
||||
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in storage_config.allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
|
||||
)
|
||||
|
||||
# Generate document ID
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Save uploaded file
|
||||
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
|
||||
try:
|
||||
with open(upload_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
        except Exception as e:
            logger.error(f"Failed to save uploaded file: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save uploaded file",
            ) from e
|
||||
|
||||
try:
|
||||
# Process based on file type
|
||||
if file_ext == ".pdf":
|
||||
service_result = inference_service.process_pdf(
|
||||
upload_path, document_id=doc_id
|
||||
)
|
||||
else:
|
||||
service_result = inference_service.process_image(
|
||||
upload_path, document_id=doc_id
|
||||
)
|
||||
|
||||
# Build response
|
||||
viz_url = None
|
||||
if service_result.visualization_path:
|
||||
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
|
||||
|
||||
inference_result = InferenceResult(
|
||||
document_id=service_result.document_id,
|
||||
success=service_result.success,
|
||||
document_type=service_result.document_type,
|
||||
fields=service_result.fields,
|
||||
confidence=service_result.confidence,
|
||||
detections=[
|
||||
DetectionResult(**d) for d in service_result.detections
|
||||
],
|
||||
processing_time_ms=service_result.processing_time_ms,
|
||||
visualization_url=viz_url,
|
||||
errors=service_result.errors,
|
||||
)
|
||||
|
||||
return InferenceResponse(
|
||||
status="success" if service_result.success else "partial",
|
||||
message=f"Processed document {doc_id}",
|
||||
result=inference_result,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup uploaded file
|
||||
upload_path.unlink(missing_ok=True)
|
||||
|
||||
@router.get("/results/{filename}")
|
||||
async def get_result_image(filename: str) -> FileResponse:
|
||||
"""Get visualization result image."""
|
||||
file_path = storage_config.result_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
media_type="image/png",
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
@router.delete("/results/{filename}")
|
||||
async def delete_result(filename: str) -> dict:
|
||||
"""Delete a result file."""
|
||||
file_path = storage_config.result_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
)
|
||||
|
||||
file_path.unlink()
|
||||
return {"status": "deleted", "filename": filename}
|
||||
|
||||
return router
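The `/api/v1/infer` route above accepts a single multipart file part named `file`. A minimal client sketch follows; the localhost:8000 base URL is an assumption about deployment, and `requests` is used for brevity.

```python
# Sketch: calling the /api/v1/infer endpoint defined above.
# The base URL is an assumption; the "file" part name matches the
# UploadFile parameter in the route.
import requests


def extract_fields(path: str) -> dict:
    """Upload a PDF or image and return the extracted field map."""
    with open(path, "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/v1/infer",
            files={"file": (path, fh)},
            timeout=120,
        )
    resp.raise_for_status()
    data = resp.json()
    # data["result"]["fields"] maps field names to values;
    # data["result"]["confidence"] holds per-field scores.
    return data["result"]["fields"]
```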
203
packages/inference/inference/web/api/v1/public/labeling.py
Normal file
203
packages/inference/inference/web/api/v1/public/labeling.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Labeling API Routes
|
||||
|
||||
FastAPI endpoints for pre-labeling documents with expected field values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.schemas.labeling import PreLabelResponse
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.web.services import InferenceService
|
||||
from inference.web.config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage directory for pre-label uploads (legacy, now uses storage_config)
|
||||
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation."""
|
||||
import fitz
|
||||
|
||||
doc_images_dir = images_dir / document_id
|
||||
doc_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
page = pdf_doc[page_num]
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
image_path = doc_images_dir / f"page_{page_num + 1}.png"
|
||||
pix.save(str(image_path))
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
|
||||
def get_admin_db() -> AdminDB:
|
||||
"""Get admin database instance."""
|
||||
return AdminDB()
|
||||
|
||||
|
||||
def create_labeling_router(
|
||||
inference_service: "InferenceService",
|
||||
storage_config: "StorageConfig",
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Create API router with labeling endpoints.
|
||||
|
||||
Args:
|
||||
inference_service: Inference service instance
|
||||
storage_config: Storage configuration
|
||||
|
||||
Returns:
|
||||
Configured APIRouter
|
||||
"""
|
||||
router = APIRouter(prefix="/api/v1", tags=["labeling"])
|
||||
|
||||
# Ensure upload directory exists
|
||||
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@router.post(
|
||||
"/pre-label",
|
||||
response_model=PreLabelResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file or field values"},
|
||||
500: {"model": ErrorResponse, "description": "Processing error"},
|
||||
},
|
||||
summary="Pre-label document with expected values",
|
||||
description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
|
||||
)
|
||||
async def pre_label(
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
field_values: str = Form(
|
||||
...,
|
||||
description="JSON object with expected field values. "
|
||||
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
|
||||
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
|
||||
),
|
||||
db: AdminDB = Depends(get_admin_db),
|
||||
) -> PreLabelResponse:
|
||||
"""
|
||||
Upload a document with expected field values for pre-labeling.
|
||||
|
||||
Returns document_id which can be used to retrieve results later.
|
||||
|
||||
Example field_values JSON:
|
||||
```json
|
||||
{
|
||||
"InvoiceNumber": "12345",
|
||||
"Amount": "1500.00",
|
||||
"Bankgiro": "123-4567",
|
||||
"OCR": "1234567890"
|
||||
}
|
||||
```
|
||||
"""
|
||||
# Parse field_values JSON
|
||||
try:
|
||||
expected_values = json.loads(field_values)
|
||||
if not isinstance(expected_values, dict):
|
||||
raise ValueError("field_values must be a JSON object")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid JSON in field_values: {e}",
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
if not file.filename:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Filename is required",
|
||||
)
|
||||
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in storage_config.allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read file",
|
||||
)
|
||||
|
||||
# Get page count for PDF
|
||||
page_count = 1
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import fitz
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
page_count = len(pdf_doc)
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get PDF page count: {e}")
|
||||
|
||||
# Create document record with field_values
|
||||
document_id = db.create_document(
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
upload_source="api",
|
||||
csv_field_values=expected_values,
|
||||
)
|
||||
|
||||
# Save file to admin uploads
|
||||
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
|
||||
try:
|
||||
file_path.write_bytes(content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to save file",
|
||||
)
|
||||
|
||||
# Update file path in database
|
||||
db.update_document_file_path(document_id, str(file_path))
|
||||
|
||||
# Convert PDF to images for annotation UI
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count,
|
||||
storage_config.admin_images_dir, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
# Trigger auto-labeling
|
||||
db.update_document_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="pending",
|
||||
)
|
||||
|
||||
logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")
|
||||
|
||||
return PreLabelResponse(document_id=document_id)
|
||||
|
||||
return router
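The pre-label route takes a multipart upload plus a `field_values` form string containing JSON. A client sketch, assuming the server runs on localhost:8000:

```python
# Sketch: uploading a document to /api/v1/pre-label with expected values.
# The base URL is an assumption; the multipart layout (a "file" part plus a
# "field_values" JSON string) follows the File/Form parameters above.
import json

import requests


def pre_label(path: str, expected: dict[str, str]) -> str:
    """Upload a document plus expected field values; return the document_id."""
    with open(path, "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/v1/pre-label",
            files={"file": (path, fh)},
            data={"field_values": json.dumps(expected)},
            timeout=120,
        )
    resp.raise_for_status()
    return resp.json()["document_id"]


doc_id = pre_label("invoice.pdf", {"InvoiceNumber": "12345", "Amount": "1500.00"})
print(f"queued for auto-labeling: {doc_id}")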
913
packages/inference/inference/web/app.py
Normal file
913
packages/inference/inference/web/app.py
Normal file
@@ -0,0 +1,913 @@
|
||||
"""
|
||||
FastAPI Application Factory
|
||||
|
||||
Creates and configures the FastAPI application.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from .config import AppConfig, default_config
|
||||
from inference.web.services import InferenceService
|
||||
|
||||
# Public API imports
|
||||
from inference.web.api.v1.public import (
|
||||
create_inference_router,
|
||||
create_async_router,
|
||||
set_async_service,
|
||||
create_labeling_router,
|
||||
)
|
||||
|
||||
# Async processing imports
|
||||
from inference.data.async_request_db import AsyncRequestDB
|
||||
from inference.web.workers.async_queue import AsyncTaskQueue
|
||||
from inference.web.services.async_processing import AsyncProcessingService
|
||||
from inference.web.dependencies import init_dependencies
|
||||
from inference.web.core.rate_limiter import RateLimiter
|
||||
|
||||
# Admin API imports
|
||||
from inference.web.api.v1.admin import (
|
||||
create_annotation_router,
|
||||
create_auth_router,
|
||||
create_documents_router,
|
||||
create_locks_router,
|
||||
create_training_router,
|
||||
)
|
||||
from inference.web.core.scheduler import start_scheduler, stop_scheduler
|
||||
from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
|
||||
|
||||
# Batch upload imports
|
||||
from inference.web.api.v1.batch.routes import router as batch_upload_router
|
||||
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from inference.web.services.batch_upload import BatchUploadService
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
"""
|
||||
Create and configure FastAPI application.
|
||||
|
||||
Args:
|
||||
config: Application configuration. Uses default if not provided.
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application
|
||||
"""
|
||||
config = config or default_config
|
||||
|
||||
# Create inference service
|
||||
inference_service = InferenceService(
|
||||
model_config=config.model,
|
||||
storage_config=config.storage,
|
||||
)
|
||||
|
||||
# Create async processing components
|
||||
async_db = AsyncRequestDB()
|
||||
rate_limiter = RateLimiter(async_db)
|
||||
task_queue = AsyncTaskQueue(
|
||||
max_size=config.async_processing.queue_max_size,
|
||||
worker_count=config.async_processing.worker_count,
|
||||
)
|
||||
async_service = AsyncProcessingService(
|
||||
inference_service=inference_service,
|
||||
db=async_db,
|
||||
queue=task_queue,
|
||||
rate_limiter=rate_limiter,
|
||||
async_config=config.async_processing,
|
||||
storage_config=config.storage,
|
||||
)
|
||||
|
||||
# Initialize dependencies for FastAPI
|
||||
init_dependencies(async_db, rate_limiter)
|
||||
set_async_service(async_service)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan manager."""
|
||||
logger.info("Starting Invoice Inference API...")
|
||||
|
||||
# Initialize database tables
|
||||
try:
|
||||
async_db.create_tables()
|
||||
logger.info("Async database tables ready")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize async database: {e}")
|
||||
|
||||
# Initialize inference service on startup
|
||||
try:
|
||||
inference_service.initialize()
|
||||
logger.info("Inference service ready")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize inference service: {e}")
|
||||
# Continue anyway - service will retry on first request
|
||||
|
||||
# Start async processing service
|
||||
try:
|
||||
async_service.start()
|
||||
logger.info("Async processing service started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start async processing: {e}")
|
||||
|
||||
# Start batch upload queue
|
||||
try:
|
||||
admin_db = AdminDB()
|
||||
batch_service = BatchUploadService(admin_db)
|
||||
init_batch_queue(batch_service)
|
||||
logger.info("Batch upload queue started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start batch upload queue: {e}")
|
||||
|
||||
# Start training scheduler
|
||||
try:
|
||||
start_scheduler()
|
||||
logger.info("Training scheduler started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start training scheduler: {e}")
|
||||
|
||||
# Start auto-label scheduler
|
||||
try:
|
||||
start_autolabel_scheduler()
|
||||
logger.info("AutoLabel scheduler started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start autolabel scheduler: {e}")
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down Invoice Inference API...")
|
||||
|
||||
# Stop auto-label scheduler
|
||||
try:
|
||||
stop_autolabel_scheduler()
|
||||
logger.info("AutoLabel scheduler stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping autolabel scheduler: {e}")
|
||||
|
||||
# Stop training scheduler
|
||||
try:
|
||||
stop_scheduler()
|
||||
logger.info("Training scheduler stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping training scheduler: {e}")
|
||||
|
||||
# Stop batch upload queue
|
||||
try:
|
||||
shutdown_batch_queue()
|
||||
logger.info("Batch upload queue stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping batch upload queue: {e}")
|
||||
|
||||
# Stop async processing service
|
||||
try:
|
||||
async_service.stop(timeout=30.0)
|
||||
logger.info("Async processing service stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping async service: {e}")
|
||||
|
||||
# Close database connection
|
||||
try:
|
||||
async_db.close()
|
||||
logger.info("Database connection closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing database: {e}")
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Invoice Field Extraction API",
|
||||
description="""
|
||||
REST API for extracting fields from Swedish invoices.
|
||||
|
||||
## Features
|
||||
- YOLO-based field detection
|
||||
- OCR text extraction
|
||||
- Field normalization and validation
|
||||
- Visualization of detections
|
||||
|
||||
## Supported Fields
|
||||
- InvoiceNumber
|
||||
- InvoiceDate
|
||||
- InvoiceDueDate
|
||||
- OCR (reference number)
|
||||
- Bankgiro
|
||||
- Plusgiro
|
||||
- Amount
|
||||
- supplier_org_number (Swedish organization number)
|
||||
- customer_number
|
||||
- payment_line (machine-readable payment code)
|
||||
""",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Mount static files for results
|
||||
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
|
||||
app.mount(
|
||||
"/static/results",
|
||||
StaticFiles(directory=str(config.storage.result_dir)),
|
||||
name="results",
|
||||
)
|
||||
|
||||
# Include public API routes
|
||||
inference_router = create_inference_router(inference_service, config.storage)
|
||||
app.include_router(inference_router)
|
||||
|
||||
async_router = create_async_router(config.storage.allowed_extensions)
|
||||
app.include_router(async_router, prefix="/api/v1")
|
||||
|
||||
labeling_router = create_labeling_router(inference_service, config.storage)
|
||||
app.include_router(labeling_router)
|
||||
|
||||
# Include admin API routes
|
||||
auth_router = create_auth_router()
|
||||
app.include_router(auth_router, prefix="/api/v1")
|
||||
|
||||
documents_router = create_documents_router(config.storage)
|
||||
app.include_router(documents_router, prefix="/api/v1")
|
||||
|
||||
locks_router = create_locks_router()
|
||||
app.include_router(locks_router, prefix="/api/v1")
|
||||
|
||||
annotation_router = create_annotation_router()
|
||||
app.include_router(annotation_router, prefix="/api/v1")
|
||||
|
||||
training_router = create_training_router()
|
||||
app.include_router(training_router, prefix="/api/v1")
|
||||
|
||||
# Include batch upload routes
|
||||
app.include_router(batch_upload_router)
|
||||
|
||||
# Root endpoint - serve HTML UI
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root() -> str:
|
||||
"""Serve the web UI."""
|
||||
return get_html_ui()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_html_ui() -> str:
|
||||
"""Generate HTML UI for the web application."""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Invoice Field Extraction</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
color: white;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
header h1 {
|
||||
font-size: 2.5rem;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
header p {
|
||||
opacity: 0.9;
|
||||
font-size: 1.1rem;
|
||||
}
|
||||
|
||||
.main-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 16px;
|
||||
padding: 24px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.card h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
font-size: 1.3rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.upload-card {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 20px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.upload-card h2 {
|
||||
margin-bottom: 0;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.upload-area {
|
||||
border: 2px dashed #ddd;
|
||||
border-radius: 10px;
|
||||
padding: 15px 25px;
|
||||
text-align: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
background: #fafafa;
|
||||
flex: 1;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.upload-area:hover, .upload-area.dragover {
|
||||
border-color: #667eea;
|
||||
background: #f0f4ff;
|
||||
}
|
||||
|
||||
.upload-area.has-file {
|
||||
border-color: #10b981;
|
||||
background: #ecfdf5;
|
||||
}
|
||||
|
||||
.upload-icon {
|
||||
font-size: 24px;
|
||||
display: inline;
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.upload-area p {
|
||||
color: #666;
|
||||
margin: 0;
|
||||
display: inline;
|
||||
}
|
||||
|
||||
.upload-area small {
|
||||
color: #999;
|
||||
display: block;
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
#file-input {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.file-name {
|
||||
margin-top: 15px;
|
||||
padding: 10px 15px;
|
||||
background: #e0f2fe;
|
||||
border-radius: 8px;
|
||||
color: #0369a1;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.btn {
|
||||
display: inline-block;
|
||||
padding: 12px 24px;
|
||||
border: none;
|
||||
border-radius: 10px;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover:not(:disabled) {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn-primary:disabled {
|
||||
opacity: 0.6;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.loading {
|
||||
display: none;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.loading.active {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.results {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.results.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.result-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 20px;
|
||||
padding-bottom: 15px;
|
||||
border-bottom: 2px solid #eee;
|
||||
}
|
||||
|
||||
.result-status {
|
||||
padding: 6px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.result-status.success {
|
||||
background: #dcfce7;
|
||||
color: #166534;
|
||||
}
|
||||
|
||||
.result-status.partial {
|
||||
background: #fef3c7;
|
||||
color: #92400e;
|
||||
}
|
||||
|
||||
.result-status.error {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
}
|
||||
|
||||
.fields-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.field-item {
|
||||
padding: 12px;
|
||||
background: #f8fafc;
|
||||
border-radius: 10px;
|
||||
border-left: 4px solid #667eea;
|
||||
}
|
||||
|
||||
.field-item label {
|
||||
display: block;
|
||||
font-size: 0.75rem;
|
||||
color: #64748b;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.field-item .value {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 600;
|
||||
color: #1e293b;
|
||||
}
|
||||
|
||||
.field-item .confidence {
|
||||
font-size: 0.75rem;
|
||||
color: #10b981;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.visualization {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.visualization img {
|
||||
width: 100%;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.processing-time {
|
||||
text-align: center;
|
||||
color: #64748b;
|
||||
font-size: 0.9rem;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
.cross-validation {
|
||||
background: #f8fafc;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 10px;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.cross-validation h3 {
|
||||
margin: 0 0 10px 0;
|
||||
color: #334155;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.cv-status {
|
||||
font-weight: 600;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 10px;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.cv-status.valid {
|
||||
background: #dcfce7;
|
||||
color: #166534;
|
||||
}
|
||||
|
||||
.cv-status.invalid {
|
||||
background: #fef3c7;
|
||||
color: #92400e;
|
||||
}
|
||||
|
||||
.cv-details {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.cv-item {
|
||||
background: white;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 6px;
|
||||
padding: 6px 12px;
|
||||
font-size: 0.85rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.cv-item.match {
|
||||
border-color: #86efac;
|
||||
background: #f0fdf4;
|
||||
}
|
||||
|
||||
.cv-item.mismatch {
|
||||
border-color: #fca5a5;
|
||||
background: #fef2f2;
|
||||
}
|
||||
|
||||
.cv-icon {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.cv-item.match .cv-icon {
|
||||
color: #16a34a;
|
||||
}
|
||||
|
||||
.cv-item.mismatch .cv-icon {
|
||||
color: #dc2626;
|
||||
}
|
||||
|
||||
.cv-summary {
|
||||
margin-top: 10px;
|
||||
font-size: 0.8rem;
|
||||
color: #64748b;
|
||||
}
|
||||
|
||||
.error-message {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
footer {
|
||||
text-align: center;
|
||||
color: white;
|
||||
opacity: 0.8;
|
||||
margin-top: 30px;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1>📄 Invoice Field Extraction</h1>
|
||||
<p>Upload a Swedish invoice (PDF or image) to extract fields automatically</p>
|
||||
</header>
|
||||
|
||||
<div class="main-content">
|
||||
<!-- Upload Section - Compact -->
|
||||
<div class="card upload-card">
|
||||
<h2>📤 Upload</h2>
|
||||
|
||||
<div class="upload-area" id="upload-area">
|
||||
<span class="upload-icon">📁</span>
|
||||
<p>Drag & drop or <strong>click to browse</strong></p>
|
||||
<small>PDF, PNG, JPG (max 50MB)</small>
|
||||
<input type="file" id="file-input" accept=".pdf,.png,.jpg,.jpeg">
|
||||
</div>
|
||||
|
||||
<div class="file-name" id="file-name" style="display: none;"></div>
|
||||
|
||||
<button class="btn btn-primary" id="submit-btn" disabled>
|
||||
🚀 Extract
|
||||
</button>
|
||||
|
||||
<div class="loading" id="loading">
|
||||
<div class="spinner"></div>
|
||||
<p>Processing...</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Results Section - Full Width -->
|
||||
<div class="card">
|
||||
<h2>📊 Extraction Results</h2>
|
||||
|
||||
<div id="placeholder" style="text-align: center; padding: 30px; color: #999;">
|
||||
<div style="font-size: 48px; margin-bottom: 10px;">🔍</div>
|
||||
<p>Upload a document to see extraction results</p>
|
||||
</div>
|
||||
|
||||
<div class="results" id="results">
|
||||
<div class="result-header">
|
||||
<span>Document: <strong id="doc-id"></strong></span>
|
||||
<span class="result-status" id="result-status"></span>
|
||||
</div>
|
||||
|
||||
<div class="fields-grid" id="fields-grid"></div>
|
||||
|
||||
<div class="processing-time" id="processing-time"></div>
|
||||
|
||||
<div class="cross-validation" id="cross-validation" style="display: none;"></div>
|
||||
|
||||
<div class="error-message" id="error-message" style="display: none;"></div>
|
||||
|
||||
<div class="visualization" id="visualization" style="display: none;">
|
||||
<h3 style="margin-bottom: 10px; color: #333;">🎯 Detection Visualization</h3>
|
||||
<img id="viz-image" src="" alt="Detection visualization">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<footer>
|
||||
<p>Powered by ColaCoder</p>
|
||||
</footer>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const uploadArea = document.getElementById('upload-area');
|
||||
const fileInput = document.getElementById('file-input');
|
||||
const fileName = document.getElementById('file-name');
|
||||
const submitBtn = document.getElementById('submit-btn');
|
||||
const loading = document.getElementById('loading');
|
||||
const placeholder = document.getElementById('placeholder');
|
||||
const results = document.getElementById('results');
|
||||
|
||||
let selectedFile = null;
|
||||
|
||||
// Drag and drop handlers
|
||||
uploadArea.addEventListener('click', () => fileInput.click());
|
||||
|
||||
uploadArea.addEventListener('dragover', (e) => {
|
||||
e.preventDefault();
|
||||
uploadArea.classList.add('dragover');
|
||||
});
|
||||
|
||||
uploadArea.addEventListener('dragleave', () => {
|
||||
uploadArea.classList.remove('dragover');
|
||||
});
|
||||
|
||||
uploadArea.addEventListener('drop', (e) => {
|
||||
e.preventDefault();
|
||||
uploadArea.classList.remove('dragover');
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
handleFile(files[0]);
|
||||
}
|
||||
});
|
||||
|
||||
fileInput.addEventListener('change', (e) => {
|
||||
if (e.target.files.length > 0) {
|
||||
handleFile(e.target.files[0]);
|
||||
}
|
||||
});
|
||||
|
||||
function handleFile(file) {
|
||||
const validTypes = ['.pdf', '.png', '.jpg', '.jpeg'];
|
||||
const ext = '.' + file.name.split('.').pop().toLowerCase();
|
||||
|
||||
if (!validTypes.includes(ext)) {
|
||||
alert('Please upload a PDF, PNG, or JPG file.');
|
||||
return;
|
||||
}
|
||||
|
||||
selectedFile = file;
|
||||
fileName.textContent = `📎 ${file.name}`;
|
||||
fileName.style.display = 'block';
|
||||
uploadArea.classList.add('has-file');
|
||||
submitBtn.disabled = false;
|
||||
}
|
||||
|
||||
submitBtn.addEventListener('click', async () => {
|
||||
if (!selectedFile) return;
|
||||
|
||||
// Show loading
|
||||
submitBtn.disabled = true;
|
||||
loading.classList.add('active');
|
||||
placeholder.style.display = 'none';
|
||||
results.classList.remove('active');
|
||||
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append('file', selectedFile);
|
||||
|
||||
const response = await fetch('/api/v1/infer', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.detail || 'Processing failed');
|
||||
}
|
||||
|
||||
displayResults(data);
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
document.getElementById('error-message').textContent = error.message;
|
||||
document.getElementById('error-message').style.display = 'block';
|
||||
results.classList.add('active');
|
||||
} finally {
|
||||
loading.classList.remove('active');
|
||||
submitBtn.disabled = false;
|
||||
}
|
||||
});
|
||||
|
||||
function displayResults(data) {
|
||||
const result = data.result;
|
||||
|
||||
// Document ID
|
||||
document.getElementById('doc-id').textContent = result.document_id;
|
||||
|
||||
// Status
|
||||
const statusEl = document.getElementById('result-status');
|
||||
statusEl.textContent = result.success ? 'Success' : 'Partial';
|
||||
statusEl.className = 'result-status ' + (result.success ? 'success' : 'partial');
|
||||
|
||||
// Fields
|
||||
const fieldsGrid = document.getElementById('fields-grid');
|
||||
fieldsGrid.innerHTML = '';
|
||||
|
||||
const fieldOrder = [
|
||||
'InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR',
|
||||
'Amount', 'Bankgiro', 'Plusgiro',
|
||||
'supplier_org_number', 'customer_number', 'payment_line'
|
||||
];
|
||||
|
||||
fieldOrder.forEach(field => {
|
||||
const value = result.fields[field];
|
||||
const confidence = result.confidence[field];
|
||||
|
||||
if (value !== null && value !== undefined) {
|
||||
const fieldDiv = document.createElement('div');
|
||||
fieldDiv.className = 'field-item';
|
||||
fieldDiv.innerHTML = `
|
||||
<label>${formatFieldName(field)}</label>
|
||||
<div class="value">${value}</div>
|
||||
${confidence ? `<div class="confidence">✓ ${(confidence * 100).toFixed(1)}% confident</div>` : ''}
|
||||
`;
|
||||
fieldsGrid.appendChild(fieldDiv);
|
||||
}
|
||||
});
|
||||
|
||||
// Processing time
|
||||
document.getElementById('processing-time').textContent =
|
||||
`⏱️ Processed in ${result.processing_time_ms.toFixed(0)}ms`;
|
||||
|
||||
// Cross-validation results
|
||||
const cvDiv = document.getElementById('cross-validation');
|
||||
if (result.cross_validation) {
|
||||
const cv = result.cross_validation;
|
||||
let cvHtml = '<h3>🔍 Cross-Validation (Payment Line)</h3>';
|
||||
cvHtml += `<div class="cv-status ${cv.is_valid ? 'valid' : 'invalid'}">`;
|
||||
cvHtml += cv.is_valid ? '✅ Valid' : '⚠️ Mismatch Detected';
|
||||
cvHtml += '</div>';
|
||||
|
||||
cvHtml += '<div class="cv-details">';
|
||||
if (cv.payment_line_ocr) {
|
||||
const matchIcon = cv.ocr_match === true ? '✓' : (cv.ocr_match === false ? '✗' : '—');
|
||||
cvHtml += `<div class="cv-item ${cv.ocr_match === true ? 'match' : (cv.ocr_match === false ? 'mismatch' : '')}">`;
|
||||
cvHtml += `<span class="cv-icon">${matchIcon}</span> OCR: ${cv.payment_line_ocr}</div>`;
|
||||
}
|
||||
if (cv.payment_line_amount) {
|
||||
const matchIcon = cv.amount_match === true ? '✓' : (cv.amount_match === false ? '✗' : '—');
|
||||
cvHtml += `<div class="cv-item ${cv.amount_match === true ? 'match' : (cv.amount_match === false ? 'mismatch' : '')}">`;
|
||||
cvHtml += `<span class="cv-icon">${matchIcon}</span> Amount: ${cv.payment_line_amount}</div>`;
|
||||
}
|
||||
if (cv.payment_line_account) {
|
||||
const accountType = cv.payment_line_account_type === 'bankgiro' ? 'Bankgiro' : 'Plusgiro';
|
||||
const matchField = cv.payment_line_account_type === 'bankgiro' ? cv.bankgiro_match : cv.plusgiro_match;
|
||||
const matchIcon = matchField === true ? '✓' : (matchField === false ? '✗' : '—');
|
||||
cvHtml += `<div class="cv-item ${matchField === true ? 'match' : (matchField === false ? 'mismatch' : '')}">`;
|
||||
cvHtml += `<span class="cv-icon">${matchIcon}</span> ${accountType}: ${cv.payment_line_account}</div>`;
|
||||
}
|
||||
cvHtml += '</div>';
|
||||
|
||||
if (cv.details && cv.details.length > 0) {
|
||||
cvHtml += '<div class="cv-summary">' + cv.details[cv.details.length - 1] + '</div>';
|
||||
}
|
||||
|
||||
cvDiv.innerHTML = cvHtml;
|
||||
cvDiv.style.display = 'block';
|
||||
} else {
|
||||
cvDiv.style.display = 'none';
|
||||
}
|
||||
|
||||
// Visualization
|
||||
if (result.visualization_url) {
|
||||
const vizDiv = document.getElementById('visualization');
|
||||
const vizImg = document.getElementById('viz-image');
|
||||
vizImg.src = result.visualization_url;
|
||||
vizDiv.style.display = 'block';
|
||||
}
|
||||
|
||||
// Errors
|
||||
if (result.errors && result.errors.length > 0) {
|
||||
document.getElementById('error-message').textContent = result.errors.join(', ');
|
||||
document.getElementById('error-message').style.display = 'block';
|
||||
} else {
|
||||
document.getElementById('error-message').style.display = 'none';
|
||||
}
|
||||
|
||||
results.classList.add('active');
|
||||
}
|
||||
|
||||
function formatFieldName(name) {
|
||||
const nameMap = {
|
||||
'InvoiceNumber': 'Invoice Number',
|
||||
'InvoiceDate': 'Invoice Date',
|
||||
'InvoiceDueDate': 'Due Date',
|
||||
'OCR': 'OCR Reference',
|
||||
'Amount': 'Amount',
|
||||
'Bankgiro': 'Bankgiro',
|
||||
'Plusgiro': 'Plusgiro',
|
||||
'supplier_org_number': 'Supplier Org Number',
|
||||
'customer_number': 'Customer Number',
|
||||
'payment_line': 'Payment Line'
|
||||
};
|
||||
return nameMap[name] || name.replace(/([A-Z])/g, ' $1').replace(/_/g, ' ').trim();
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
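One way to serve the `create_app` factory above is with uvicorn; this is a sketch that reuses the `ServerConfig` defaults (0.0.0.0:8000) rather than a prescribed entry point.

```python
# Sketch: serving create_app() with uvicorn. Host and port mirror the
# ServerConfig defaults; adjust for your deployment.
import uvicorn

from inference.web.app import create_app
from inference.web.config import default_config

app = create_app(default_config)

if __name__ == "__main__":
    uvicorn.run(
        app,
        host=default_config.server.host,
        port=default_config.server.port,
    )
```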
113
packages/inference/inference/web/config.py
Normal file
113
packages/inference/inference/web/config.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Web Application Configuration
|
||||
|
||||
Centralized configuration for the web application.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from shared.config import DEFAULT_DPI, PATHS
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelConfig:
|
||||
"""YOLO model configuration."""
|
||||
|
||||
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
|
||||
confidence_threshold: float = 0.5
|
||||
use_gpu: bool = True
|
||||
dpi: int = DEFAULT_DPI
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ServerConfig:
|
||||
"""Server configuration."""
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
debug: bool = False
|
||||
reload: bool = False
|
||||
workers: int = 1
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageConfig:
|
||||
"""File storage configuration.
|
||||
|
||||
Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
|
||||
directly in raw_pdfs directory. This ensures consistency with CLI autolabel
|
||||
and avoids storing duplicate files.
|
||||
"""
|
||||
|
||||
upload_dir: Path = Path("uploads")
|
||||
result_dir: Path = Path("results")
|
||||
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
|
||||
admin_images_dir: Path = Path("data/admin_images")
|
||||
max_file_size_mb: int = 50
|
||||
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
|
||||
dpi: int = DEFAULT_DPI
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist."""
|
||||
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
|
||||
object.__setattr__(self, "result_dir", Path(self.result_dir))
|
||||
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
|
||||
object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
|
||||
self.upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.result_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AsyncConfig:
|
||||
"""Async processing configuration."""
|
||||
|
||||
# Queue settings
|
||||
queue_max_size: int = 100
|
||||
worker_count: int = 1
|
||||
task_timeout_seconds: int = 300
|
||||
|
||||
# Rate limiting defaults
|
||||
default_requests_per_minute: int = 10
|
||||
default_max_concurrent_jobs: int = 3
|
||||
default_min_poll_interval_ms: int = 1000
|
||||
|
||||
# Storage
|
||||
result_retention_days: int = 7
|
||||
temp_upload_dir: Path = Path("uploads/async")
|
||||
max_file_size_mb: int = 50
|
||||
|
||||
# Cleanup
|
||||
cleanup_interval_hours: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist."""
|
||||
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
|
||||
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
"""Main application configuration."""
|
||||
|
||||
model: ModelConfig = field(default_factory=ModelConfig)
|
||||
server: ServerConfig = field(default_factory=ServerConfig)
|
||||
storage: StorageConfig = field(default_factory=StorageConfig)
|
||||
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
|
||||
"""Create config from dictionary."""
|
||||
return cls(
|
||||
model=ModelConfig(**config_dict.get("model", {})),
|
||||
server=ServerConfig(**config_dict.get("server", {})),
|
||||
storage=StorageConfig(**config_dict.get("storage", {})),
|
||||
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
|
||||
)
|
||||
|
||||
|
||||
# Default configuration instance
|
||||
default_config = AppConfig()
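`AppConfig.from_dict` builds the nested dataclasses from a plain dictionary, so a deployment only needs to override the keys it cares about; untouched sections keep their defaults. A short sketch (the override values are arbitrary examples):

```python
# Sketch: overriding a few settings via AppConfig.from_dict.
# Only keys present in the dict are overridden; the rest keep the
# dataclass defaults defined above. The values here are arbitrary examples.
from inference.web.config import AppConfig

config = AppConfig.from_dict(
    {
        "model": {"confidence_threshold": 0.6, "use_gpu": False},
        "server": {"port": 9000},
        "async_processing": {"worker_count": 2},
    }
)

assert config.model.confidence_threshold == 0.6
assert config.storage.max_file_size_mb == 50  # untouched section keeps its default
```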
28
packages/inference/inference/web/core/__init__.py
Normal file
28
packages/inference/inference/web/core/__init__.py
Normal file
@@ -0,0 +1,28 @@
"""
Core Components

Reusable core functionality: authentication, rate limiting, scheduling.
"""

from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from inference.web.core.rate_limiter import RateLimiter
from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from inference.web.core.autolabel_scheduler import (
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
    get_autolabel_scheduler,
)

__all__ = [
    "validate_admin_token",
    "get_admin_db",
    "AdminTokenDep",
    "AdminDBDep",
    "RateLimiter",
    "start_scheduler",
    "stop_scheduler",
    "get_training_scheduler",
    "start_autolabel_scheduler",
    "stop_autolabel_scheduler",
    "get_autolabel_scheduler",
]
60
packages/inference/inference/web/core/auth.py
Normal file
60
packages/inference/inference/web/core/auth.py
Normal file
@@ -0,0 +1,60 @@
"""
Admin Authentication

FastAPI dependencies for admin token authentication.
"""

import logging
from typing import Annotated

from fastapi import Depends, Header, HTTPException

from inference.data.admin_db import AdminDB
from inference.data.database import get_session_context

logger = logging.getLogger(__name__)

# Global AdminDB instance
_admin_db: AdminDB | None = None


def get_admin_db() -> AdminDB:
    """Get the AdminDB instance."""
    global _admin_db
    if _admin_db is None:
        _admin_db = AdminDB()
    return _admin_db


def reset_admin_db() -> None:
    """Reset the AdminDB instance (for testing)."""
    global _admin_db
    _admin_db = None


async def validate_admin_token(
    x_admin_token: Annotated[str | None, Header()] = None,
    admin_db: AdminDB = Depends(get_admin_db),
) -> str:
    """Validate admin token from header."""
    if not x_admin_token:
        raise HTTPException(
            status_code=401,
            detail="Admin token required. Provide X-Admin-Token header.",
        )

    if not admin_db.is_valid_admin_token(x_admin_token):
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired admin token.",
        )

    # Update last used timestamp
    admin_db.update_admin_token_usage(x_admin_token)

    return x_admin_token


# Type aliases for dependency injection
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
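The `Annotated` aliases above let admin routes declare authentication declaratively. A sketch of a hypothetical route using them; the `/admin/ping` path is invented for illustration and is not part of this package's routers.

```python
# Sketch: a hypothetical admin-only route protected by AdminTokenDep.
# The /admin/ping path is invented for illustration; real admin routes in
# this package are produced by the create_*_router factories.
from fastapi import APIRouter

from inference.web.core.auth import AdminDBDep, AdminTokenDep

router = APIRouter(prefix="/admin", tags=["admin"])


@router.get("/ping")
async def admin_ping(token: AdminTokenDep, db: AdminDBDep) -> dict:
    """Return a trivial payload; a 401 is raised before this body runs."""
    return {"status": "ok", "token_suffix": token[-4:]}
```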
153
packages/inference/inference/web/core/autolabel_scheduler.py
Normal file
153
packages/inference/inference/web/core/autolabel_scheduler.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Auto-Label Scheduler
|
||||
|
||||
Background scheduler for processing documents pending auto-labeling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.services.db_autolabel import (
|
||||
get_pending_autolabel_documents,
|
||||
process_document_autolabel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoLabelScheduler:
|
||||
"""Scheduler for auto-labeling tasks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
check_interval_seconds: int = 10,
|
||||
batch_size: int = 5,
|
||||
output_dir: Path | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize auto-label scheduler.
|
||||
|
||||
Args:
|
||||
check_interval_seconds: Interval to check for pending tasks
|
||||
batch_size: Number of documents to process per batch
|
||||
output_dir: Output directory for temporary files
|
||||
"""
|
||||
self._check_interval = check_interval_seconds
|
||||
self._batch_size = batch_size
|
||||
self._output_dir = output_dir or Path("data/autolabel_output")
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._db = AdminDB()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
if self._running:
|
||||
logger.warning("AutoLabel scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("AutoLabel scheduler started")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the scheduler."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
self._thread = None
|
||||
|
||||
logger.info("AutoLabel scheduler stopped")
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if scheduler is running."""
|
||||
return self._running
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
"""Main scheduler loop."""
|
||||
while self._running:
|
||||
try:
|
||||
self._process_pending_documents()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)
|
||||
|
||||
# Wait for next check interval
|
||||
self._stop_event.wait(timeout=self._check_interval)
|
||||
|
||||
def _process_pending_documents(self) -> None:
|
||||
"""Check and process pending auto-label documents."""
|
||||
try:
|
||||
documents = get_pending_autolabel_documents(
|
||||
self._db, limit=self._batch_size
|
||||
)
|
||||
|
||||
if not documents:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(documents)} pending autolabel documents")
|
||||
|
||||
for doc in documents:
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
|
||||
try:
|
||||
result = process_document_autolabel(
|
||||
document=doc,
|
||||
db=self._db,
|
||||
output_dir=self._output_dir,
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
logger.info(
|
||||
f"AutoLabel completed for document {doc.document_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"AutoLabel failed for document {doc.document_id}: "
|
||||
f"{result.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing document {doc.document_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching pending documents: {e}", exc_info=True)
|
||||
|
||||
|
||||
# Global scheduler instance
|
||||
_autolabel_scheduler: AutoLabelScheduler | None = None
|
||||
|
||||
|
||||
def get_autolabel_scheduler() -> AutoLabelScheduler:
|
||||
"""Get the auto-label scheduler instance."""
|
||||
global _autolabel_scheduler
|
||||
if _autolabel_scheduler is None:
|
||||
_autolabel_scheduler = AutoLabelScheduler()
|
||||
return _autolabel_scheduler
|
||||
|
||||
|
||||
def start_autolabel_scheduler() -> None:
|
||||
"""Start the global auto-label scheduler."""
|
||||
scheduler = get_autolabel_scheduler()
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def stop_autolabel_scheduler() -> None:
|
||||
"""Stop the global auto-label scheduler."""
|
||||
global _autolabel_scheduler
|
||||
if _autolabel_scheduler:
|
||||
_autolabel_scheduler.stop()
|
||||
_autolabel_scheduler = None
|
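Outside the FastAPI lifespan, the module-level helpers above can drive the scheduler directly, e.g. from a maintenance script. A sketch; the 60-second window is arbitrary.

```python
# Sketch: running the auto-label scheduler from a standalone script rather
# than the FastAPI lifespan. The sleep duration is arbitrary.
import time

from inference.web.core.autolabel_scheduler import (
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
)

start_autolabel_scheduler()      # spawns the daemon worker thread
try:
    time.sleep(60)               # let it poll for pending documents
finally:
    stop_autolabel_scheduler()   # sets the stop event and joins the thread
```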
211
packages/inference/inference/web/core/rate_limiter.py
Normal file
211
packages/inference/inference/web/core/rate_limiter.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Rate Limiter Implementation
|
||||
|
||||
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.data.async_request_db import AsyncRequestDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
"""Rate limit configuration for an API key."""
|
||||
|
||||
requests_per_minute: int = 10
|
||||
max_concurrent_jobs: int = 3
|
||||
min_poll_interval_ms: int = 1000 # Minimum time between status polls
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitStatus:
|
||||
"""Current rate limit status."""
|
||||
|
||||
allowed: bool
|
||||
remaining_requests: int
|
||||
reset_at: datetime
|
||||
retry_after_seconds: int | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Thread-safe rate limiter with sliding window algorithm.
|
||||
|
||||
Tracks:
|
||||
- Requests per minute (sliding window)
|
||||
- Concurrent active jobs
|
||||
- Poll frequency per request_id
|
||||
"""
|
||||
|
||||
def __init__(self, db: "AsyncRequestDB") -> None:
|
||||
self._db = db
|
||||
self._lock = Lock()
|
||||
# In-memory tracking for fast checks
|
||||
self._request_windows: dict[str, list[float]] = defaultdict(list)
|
||||
# (api_key, request_id) -> last_poll timestamp
|
||||
self._poll_timestamps: dict[tuple[str, str], float] = {}
|
||||
# Cache for API key configs (TTL 60 seconds)
|
||||
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
|
||||
self._config_cache_ttl = 60.0
|
||||
|
||||
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
|
||||
"""Check if API key can submit a new request."""
|
||||
config = self._get_config(api_key)
|
||||
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - 60 # 1 minute window
|
||||
|
||||
# Clean old entries
|
||||
self._request_windows[api_key] = [
|
||||
ts for ts in self._request_windows[api_key]
|
||||
if ts > window_start
|
||||
]
|
||||
|
||||
current_count = len(self._request_windows[api_key])
|
||||
|
||||
if current_count >= config.requests_per_minute:
|
||||
oldest = min(self._request_windows[api_key])
|
||||
retry_after = int(oldest + 60 - now) + 1
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=0,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
|
||||
retry_after_seconds=max(1, retry_after),
|
||||
reason="Rate limit exceeded: too many requests per minute",
|
||||
)
|
||||
|
||||
# Check concurrent jobs (query database) - inside lock for thread safety
|
||||
active_jobs = self._db.count_active_jobs(api_key)
|
||||
if active_jobs >= config.max_concurrent_jobs:
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=config.requests_per_minute - current_count,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=30),
|
||||
retry_after_seconds=30,
|
||||
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
|
||||
)
|
||||
|
||||
return RateLimitStatus(
|
||||
allowed=True,
|
||||
remaining_requests=config.requests_per_minute - current_count - 1,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=60),
|
||||
)
|
||||
|
||||
def record_request(self, api_key: str) -> None:
|
||||
"""Record a successful request submission."""
|
||||
with self._lock:
|
||||
self._request_windows[api_key].append(time.time())
|
||||
|
||||
# Also record in database for persistence
|
||||
try:
|
||||
self._db.record_rate_limit_event(api_key, "request")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record rate limit event: {e}")
|
||||
|
||||
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
|
||||
"""Check if polling is allowed (prevent abuse)."""
|
||||
config = self._get_config(api_key)
|
||||
key = (api_key, request_id)
|
||||
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
last_poll = self._poll_timestamps.get(key, 0)
|
||||
elapsed_ms = (now - last_poll) * 1000
|
||||
|
||||
if elapsed_ms < config.min_poll_interval_ms:
|
||||
# Suggest exponential backoff
|
||||
wait_ms = min(
|
||||
config.min_poll_interval_ms * 2,
|
||||
5000, # Max 5 seconds
|
||||
)
|
||||
retry_after = int(wait_ms / 1000) + 1
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=0,
|
||||
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
|
||||
retry_after_seconds=retry_after,
|
||||
reason="Polling too frequently. Please wait before retrying.",
|
||||
)
|
||||
|
||||
# Update poll timestamp
|
||||
self._poll_timestamps[key] = now
|
||||
|
||||
return RateLimitStatus(
|
||||
allowed=True,
|
||||
remaining_requests=999, # No limit on poll count, just frequency
|
||||
reset_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
def _get_config(self, api_key: str) -> RateLimitConfig:
|
||||
"""Get rate limit config for API key with caching."""
|
||||
now = time.time()
|
||||
|
||||
# Check cache
|
||||
if api_key in self._config_cache:
|
||||
cached_config, cached_at = self._config_cache[api_key]
|
||||
if now - cached_at < self._config_cache_ttl:
|
||||
return cached_config
|
||||
|
||||
# Query database
|
||||
db_config = self._db.get_api_key_config(api_key)
|
||||
if db_config:
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=db_config.requests_per_minute,
|
||||
max_concurrent_jobs=db_config.max_concurrent_jobs,
|
||||
)
|
||||
else:
|
||||
config = RateLimitConfig() # Default limits
|
||||
|
||||
# Cache result
|
||||
self._config_cache[api_key] = (config, now)
|
||||
return config
|
||||
|
||||
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
|
||||
"""Clean up old poll timestamps to prevent memory leak."""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
cutoff = now - max_age_seconds
|
||||
old_keys = [
|
||||
k for k, v in self._poll_timestamps.items()
|
||||
if v < cutoff
|
||||
]
|
||||
for key in old_keys:
|
||||
del self._poll_timestamps[key]
|
||||
return len(old_keys)
|
||||
|
||||
def cleanup_request_windows(self) -> None:
|
||||
"""Clean up expired entries from request windows."""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - 60
|
||||
|
||||
for api_key in list(self._request_windows.keys()):
|
||||
self._request_windows[api_key] = [
|
||||
ts for ts in self._request_windows[api_key]
|
||||
if ts > window_start
|
||||
]
|
||||
# Remove empty entries
|
||||
if not self._request_windows[api_key]:
|
||||
del self._request_windows[api_key]
|
||||
|
||||
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
|
||||
"""Generate rate limit headers for HTTP response."""
|
||||
headers = {
|
||||
"X-RateLimit-Remaining": str(status.remaining_requests),
|
||||
"X-RateLimit-Reset": status.reset_at.isoformat(),
|
||||
}
|
||||
if status.retry_after_seconds:
|
||||
headers["Retry-After"] = str(status.retry_after_seconds)
|
||||
return headers
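The limiter combines an in-memory sliding window with database-backed concurrency checks. A sketch of how a submit path would use it; `db` stands in for a configured `AsyncRequestDB`, and the 429 handling simply forwards the headers helper above.

```python
# Sketch: guarding a submit path with RateLimiter. `db` is assumed to be a
# configured AsyncRequestDB; the 429 response reuses get_rate_limit_headers.
from fastapi import HTTPException

from inference.data.async_request_db import AsyncRequestDB
from inference.web.core.rate_limiter import RateLimiter

db = AsyncRequestDB()
limiter = RateLimiter(db)


def guard_submit(api_key: str) -> dict[str, str]:
    """Raise 429 if the key is over its limits, otherwise record the request."""
    rl_status = limiter.check_submit_limit(api_key)
    if not rl_status.allowed:
        raise HTTPException(
            status_code=429,
            detail=rl_status.reason,
            headers=limiter.get_rate_limit_headers(rl_status),
        )
    limiter.record_request(api_key)
    return limiter.get_rate_limit_headers(rl_status)
```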
340
packages/inference/inference/web/core/scheduler.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Admin Training Scheduler
|
||||
|
||||
Background scheduler that periodically checks for and runs pending training tasks on a daemon thread.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingScheduler:
|
||||
"""Scheduler for training tasks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
check_interval_seconds: int = 60,
|
||||
):
|
||||
"""
|
||||
Initialize training scheduler.
|
||||
|
||||
Args:
|
||||
check_interval_seconds: Interval to check for pending tasks
|
||||
"""
|
||||
self._check_interval = check_interval_seconds
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._db = AdminDB()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
if self._running:
|
||||
logger.warning("Training scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Training scheduler started")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the scheduler."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
self._thread = None
|
||||
|
||||
logger.info("Training scheduler stopped")
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
"""Main scheduler loop."""
|
||||
while self._running:
|
||||
try:
|
||||
self._check_pending_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scheduler loop: {e}")
|
||||
|
||||
# Wait for next check interval
|
||||
self._stop_event.wait(timeout=self._check_interval)
|
||||
|
||||
def _check_pending_tasks(self) -> None:
|
||||
"""Check and execute pending training tasks."""
|
||||
try:
|
||||
tasks = self._db.get_pending_training_tasks()
|
||||
|
||||
for task in tasks:
|
||||
task_id = str(task.task_id)
|
||||
|
||||
# Check if scheduled time has passed
|
||||
if task.scheduled_at and task.scheduled_at > datetime.utcnow():
|
||||
continue
|
||||
|
||||
logger.info(f"Starting training task: {task_id}")
|
||||
|
||||
try:
|
||||
dataset_id = getattr(task, "dataset_id", None)
|
||||
self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._db.update_training_task_status(
|
||||
task_id=task_id,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking pending tasks: {e}")
|
||||
|
||||
def _execute_task(
|
||||
self, task_id: str, config: dict[str, Any], dataset_id: str | None = None
|
||||
) -> None:
|
||||
"""Execute a training task."""
|
||||
# Update status to running
|
||||
self._db.update_training_task_status(task_id, "running")
|
||||
self._db.add_training_log(task_id, "INFO", "Training task started")
|
||||
|
||||
try:
|
||||
# Get training configuration
|
||||
model_name = config.get("model_name", "yolo11n.pt")
|
||||
epochs = config.get("epochs", 100)
|
||||
batch_size = config.get("batch_size", 16)
|
||||
image_size = config.get("image_size", 640)
|
||||
learning_rate = config.get("learning_rate", 0.01)
|
||||
device = config.get("device", "0")
|
||||
project_name = config.get("project_name", "invoice_fields")
|
||||
|
||||
# Use dataset if available, otherwise export from scratch
|
||||
if dataset_id:
|
||||
dataset = self._db.get_dataset(dataset_id)
|
||||
if not dataset or not dataset.dataset_path:
|
||||
raise ValueError(f"Dataset {dataset_id} not found or has no path")
|
||||
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
|
||||
)
|
||||
else:
|
||||
export_result = self._export_training_data(task_id)
|
||||
if not export_result:
|
||||
raise ValueError("Failed to export training data")
|
||||
data_yaml = export_result["data_yaml"]
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Exported {export_result['total_images']} images for training",
|
||||
)
|
||||
|
||||
# Run YOLO training
|
||||
result = self._run_yolo_training(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
image_size=image_size,
|
||||
learning_rate=learning_rate,
|
||||
device=device,
|
||||
project_name=project_name,
|
||||
)
|
||||
|
||||
# Update task with results
|
||||
self._db.update_training_task_status(
|
||||
task_id=task_id,
|
||||
status="completed",
|
||||
result_metrics=result.get("metrics"),
|
||||
model_path=result.get("model_path"),
|
||||
)
|
||||
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
|
||||
raise
|
||||
|
||||
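_execute_task reads its hyperparameters from the task's config dict with defaults. For reference, a config payload that exercises every key it looks up would look like this (values illustrative):

# Illustrative config payload for a training task; keys mirror the
# config.get(...) calls in _execute_task, values are examples only.
example_config = {
    "model_name": "yolo11n.pt",
    "epochs": 100,
    "batch_size": 16,
    "image_size": 640,
    "learning_rate": 0.01,
    "device": "0",            # GPU index as a string, or "cpu"
    "project_name": "invoice_fields",
}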
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
|
||||
"""Export training data for a task."""
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from inference.data.admin_models import FIELD_CLASSES
|
||||
|
||||
# Get all labeled documents
|
||||
documents = self._db.get_labeled_documents_for_export()
|
||||
|
||||
if not documents:
|
||||
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
|
||||
return None
|
||||
|
||||
# Create export directory
|
||||
export_dir = Path("data/training") / task_id
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# YOLO format directories
|
||||
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 80/20 train/val split
|
||||
total_docs = len(documents)
|
||||
train_count = int(total_docs * 0.8)
|
||||
train_docs = documents[:train_count]
|
||||
val_docs = documents[train_count:]
|
||||
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
|
||||
# Export documents
|
||||
for split, docs in [("train", train_docs), ("val", val_docs)]:
|
||||
for doc in docs:
|
||||
annotations = self._db.get_annotations_for_document(str(doc.document_id))
|
||||
|
||||
if not annotations:
|
||||
continue
|
||||
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
page_annotations = [a for a in annotations if a.page_number == page_num]
|
||||
|
||||
# Copy image
|
||||
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
|
||||
if not src_image.exists():
|
||||
continue
|
||||
|
||||
image_name = f"{doc.document_id}_page{page_num}.png"
|
||||
dst_image = export_dir / "images" / split / image_name
|
||||
shutil.copy(src_image, dst_image)
|
||||
total_images += 1
|
||||
|
||||
# Write YOLO label
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
# Create data.yaml
|
||||
yaml_path = export_dir / "data.yaml"
|
||||
yaml_content = f"""path: {export_dir.absolute()}
|
||||
train: images/train
|
||||
val: images/val
|
||||
|
||||
nc: {len(FIELD_CLASSES)}
|
||||
names: {list(FIELD_CLASSES.values())}
|
||||
"""
|
||||
yaml_path.write_text(yaml_content)
|
||||
|
||||
return {
|
||||
"data_yaml": str(yaml_path),
|
||||
"total_images": total_images,
|
||||
"total_annotations": total_annotations,
|
||||
}
|
||||
|
||||
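The export above produces an Ultralytics-style dataset directory (images/, labels/, data.yaml). A small sanity-check sketch before training; the helper name and the use of PyYAML are assumptions, not part of this commit:

# Hedged sketch: sanity-check an exported dataset before starting a training run.
from pathlib import Path

import yaml  # PyYAML


def check_exported_dataset(export_dir: Path) -> dict[str, int]:
    meta = yaml.safe_load((export_dir / "data.yaml").read_text())
    counts: dict[str, int] = {}
    for split in ("train", "val"):
        counts[f"{split}_images"] = len(list((export_dir / "images" / split).glob("*.png")))
        counts[f"{split}_labels"] = len(list((export_dir / "labels" / split).glob("*.txt")))
    counts["classes"] = int(meta["nc"])
    return counts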
def _run_yolo_training(
|
||||
self,
|
||||
task_id: str,
|
||||
model_name: str,
|
||||
data_yaml: str,
|
||||
epochs: int,
|
||||
batch_size: int,
|
||||
image_size: int,
|
||||
learning_rate: float,
|
||||
device: str,
|
||||
project_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Run YOLO training."""
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Log training start
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
|
||||
)
|
||||
|
||||
# Load model
|
||||
model = YOLO(model_name)
|
||||
|
||||
# Train
|
||||
results = model.train(
|
||||
data=data_yaml,
|
||||
epochs=epochs,
|
||||
batch=batch_size,
|
||||
imgsz=image_size,
|
||||
lr0=learning_rate,
|
||||
device=device,
|
||||
project=f"runs/train/{project_name}",
|
||||
name=f"task_{task_id[:8]}",
|
||||
exist_ok=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Get best model path
|
||||
best_model = Path(results.save_dir) / "weights" / "best.pt"
|
||||
|
||||
# Extract metrics
|
||||
metrics = {}
|
||||
if hasattr(results, "results_dict"):
|
||||
metrics = {
|
||||
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
|
||||
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
|
||||
"precision": results.results_dict.get("metrics/precision(B)", 0),
|
||||
"recall": results.results_dict.get("metrics/recall(B)", 0),
|
||||
}
|
||||
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
|
||||
)
|
||||
|
||||
return {
|
||||
"model_path": str(best_model) if best_model.exists() else None,
|
||||
"metrics": metrics,
|
||||
}
|
||||
|
||||
except ImportError:
|
||||
self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
|
||||
raise ValueError("Ultralytics (YOLO) not installed")
|
||||
except Exception as e:
|
||||
self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Global scheduler instance
|
||||
_scheduler: TrainingScheduler | None = None
|
||||
|
||||
|
||||
def get_training_scheduler() -> TrainingScheduler:
|
||||
"""Get the training scheduler instance."""
|
||||
global _scheduler
|
||||
if _scheduler is None:
|
||||
_scheduler = TrainingScheduler()
|
||||
return _scheduler
|
||||
|
||||
|
||||
def start_scheduler() -> None:
|
||||
"""Start the global training scheduler."""
|
||||
scheduler = get_training_scheduler()
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def stop_scheduler() -> None:
|
||||
"""Stop the global training scheduler."""
|
||||
global _scheduler
|
||||
if _scheduler:
|
||||
_scheduler.stop()
|
||||
_scheduler = None
|
||||
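start_scheduler/stop_scheduler are module-level helpers around the singleton; one plausible way to tie them to the web application's lifetime is a FastAPI lifespan handler. The wiring below is illustrative and not taken from this commit:

# Hedged sketch: start the scheduler on app startup and stop it on shutdown.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from inference.web.core.scheduler import start_scheduler, stop_scheduler


@asynccontextmanager
async def lifespan(app: FastAPI):
    start_scheduler()      # begin polling for pending training tasks
    try:
        yield
    finally:
        stop_scheduler()   # signal the loop and join the daemon thread

app = FastAPI(lifespan=lifespan)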
133
packages/inference/inference/web/dependencies.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
FastAPI Dependencies
|
||||
|
||||
Dependency injection for the async API endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, Request
|
||||
|
||||
from inference.data.async_request_db import AsyncRequestDB
|
||||
from inference.web.rate_limiter import RateLimiter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global instances (initialized in app startup)
|
||||
_async_db: AsyncRequestDB | None = None
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
|
||||
|
||||
def init_dependencies(db: AsyncRequestDB, rate_limiter: RateLimiter) -> None:
|
||||
"""Initialize global dependency instances."""
|
||||
global _async_db, _rate_limiter
|
||||
_async_db = db
|
||||
_rate_limiter = rate_limiter
|
||||
|
||||
|
||||
def get_async_db() -> AsyncRequestDB:
|
||||
"""Get async request database instance."""
|
||||
if _async_db is None:
|
||||
raise RuntimeError("AsyncRequestDB not initialized")
|
||||
return _async_db
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""Get rate limiter instance."""
|
||||
if _rate_limiter is None:
|
||||
raise RuntimeError("RateLimiter not initialized")
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def verify_api_key(
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Verify API key exists and is active.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if API key is missing or invalid
|
||||
"""
|
||||
if not x_api_key:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="X-API-Key header is required",
|
||||
headers={"WWW-Authenticate": "API-Key"},
|
||||
)
|
||||
|
||||
db = get_async_db()
|
||||
if not db.is_valid_api_key(x_api_key):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or inactive API key",
|
||||
headers={"WWW-Authenticate": "API-Key"},
|
||||
)
|
||||
|
||||
# Update usage tracking
|
||||
try:
|
||||
db.update_api_key_usage(x_api_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update API key usage: {e}")
|
||||
|
||||
return x_api_key
|
||||
|
||||
|
||||
async def check_submit_rate_limit(
|
||||
api_key: Annotated[str, Depends(verify_api_key)],
|
||||
) -> str:
|
||||
"""
|
||||
Check rate limit before processing submit request.
|
||||
|
||||
Raises:
|
||||
HTTPException: 429 if rate limit exceeded
|
||||
"""
|
||||
rate_limiter = get_rate_limiter()
|
||||
status = rate_limiter.check_submit_limit(api_key)
|
||||
|
||||
if not status.allowed:
|
||||
headers = rate_limiter.get_rate_limit_headers(status)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=status.reason or "Rate limit exceeded",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
async def check_poll_rate_limit(
|
||||
request: Request,
|
||||
api_key: Annotated[str, Depends(verify_api_key)],
|
||||
) -> str:
|
||||
"""
|
||||
Check poll rate limit to prevent abuse.
|
||||
|
||||
Raises:
|
||||
HTTPException: 429 if polling too frequently
|
||||
"""
|
||||
# Extract request_id from path parameters
|
||||
request_id = request.path_params.get("request_id")
|
||||
if not request_id:
|
||||
return api_key # No request_id, skip poll limit check
|
||||
|
||||
rate_limiter = get_rate_limiter()
|
||||
status = rate_limiter.check_poll_limit(api_key, request_id)
|
||||
|
||||
if not status.allowed:
|
||||
headers = rate_limiter.get_rate_limit_headers(status)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=status.reason or "Polling too frequently",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
# Type aliases for cleaner route signatures
|
||||
ApiKeyDep = Annotated[str, Depends(verify_api_key)]
|
||||
SubmitRateLimitDep = Annotated[str, Depends(check_submit_rate_limit)]
|
||||
PollRateLimitDep = Annotated[str, Depends(check_poll_rate_limit)]
|
||||
AsyncDBDep = Annotated[AsyncRequestDB, Depends(get_async_db)]
|
||||
RateLimiterDep = Annotated[RateLimiter, Depends(get_rate_limiter)]
|
||||
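The Annotated type aliases at the end of the module keep route signatures short. A sketch of a submit endpoint that relies on them; the path, the response shape, and the `create_request` database method are illustrative assumptions:

# Hedged sketch of a route using the dependency aliases above.
from fastapi import APIRouter, UploadFile

from inference.web.dependencies import AsyncDBDep, RateLimiterDep, SubmitRateLimitDep

router = APIRouter()


@router.post("/async/submit")
async def submit_document(
    file: UploadFile,
    api_key: SubmitRateLimitDep,   # verifies the key and the per-minute limit
    db: AsyncDBDep,
    rate_limiter: RateLimiterDep,
) -> dict[str, str]:
    request_id = db.create_request(api_key=api_key, filename=file.filename)  # assumed DB method
    rate_limiter.record_request(api_key)  # count this submission in the sliding window
    return {"request_id": request_id, "status": "accepted"}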
211
packages/inference/inference/web/rate_limiter.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Rate Limiter Implementation
|
||||
|
||||
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.data.async_request_db import AsyncRequestDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
"""Rate limit configuration for an API key."""
|
||||
|
||||
requests_per_minute: int = 10
|
||||
max_concurrent_jobs: int = 3
|
||||
min_poll_interval_ms: int = 1000 # Minimum time between status polls
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitStatus:
|
||||
"""Current rate limit status."""
|
||||
|
||||
allowed: bool
|
||||
remaining_requests: int
|
||||
reset_at: datetime
|
||||
retry_after_seconds: int | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Thread-safe rate limiter with sliding window algorithm.
|
||||
|
||||
Tracks:
|
||||
- Requests per minute (sliding window)
|
||||
- Concurrent active jobs
|
||||
- Poll frequency per request_id
|
||||
"""
|
||||
|
||||
def __init__(self, db: "AsyncRequestDB") -> None:
|
||||
self._db = db
|
||||
self._lock = Lock()
|
||||
# In-memory tracking for fast checks
|
||||
self._request_windows: dict[str, list[float]] = defaultdict(list)
|
||||
# (api_key, request_id) -> last_poll timestamp
|
||||
self._poll_timestamps: dict[tuple[str, str], float] = {}
|
||||
# Cache for API key configs (TTL 60 seconds)
|
||||
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
|
||||
self._config_cache_ttl = 60.0
|
||||
|
||||
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
|
||||
"""Check if API key can submit a new request."""
|
||||
config = self._get_config(api_key)
|
||||
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - 60 # 1 minute window
|
||||
|
||||
# Clean old entries
|
||||
self._request_windows[api_key] = [
|
||||
ts for ts in self._request_windows[api_key]
|
||||
if ts > window_start
|
||||
]
|
||||
|
||||
current_count = len(self._request_windows[api_key])
|
||||
|
||||
if current_count >= config.requests_per_minute:
|
||||
oldest = min(self._request_windows[api_key])
|
||||
retry_after = int(oldest + 60 - now) + 1
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=0,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
|
||||
retry_after_seconds=max(1, retry_after),
|
||||
reason="Rate limit exceeded: too many requests per minute",
|
||||
)
|
||||
|
||||
# Check concurrent jobs (query database) - inside lock for thread safety
|
||||
active_jobs = self._db.count_active_jobs(api_key)
|
||||
if active_jobs >= config.max_concurrent_jobs:
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=config.requests_per_minute - current_count,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=30),
|
||||
retry_after_seconds=30,
|
||||
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
|
||||
)
|
||||
|
||||
return RateLimitStatus(
|
||||
allowed=True,
|
||||
remaining_requests=config.requests_per_minute - current_count - 1,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=60),
|
||||
)
|
||||
|
||||
def record_request(self, api_key: str) -> None:
|
||||
"""Record a successful request submission."""
|
||||
with self._lock:
|
||||
self._request_windows[api_key].append(time.time())
|
||||
|
||||
# Also record in database for persistence
|
||||
try:
|
||||
self._db.record_rate_limit_event(api_key, "request")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record rate limit event: {e}")
|
||||
|
||||
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
|
||||
"""Check if polling is allowed (prevent abuse)."""
|
||||
config = self._get_config(api_key)
|
||||
key = (api_key, request_id)
|
||||
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
last_poll = self._poll_timestamps.get(key, 0)
|
||||
elapsed_ms = (now - last_poll) * 1000
|
||||
|
||||
if elapsed_ms < config.min_poll_interval_ms:
|
||||
# Suggest a longer wait: double the minimum interval, capped at 5 seconds
|
||||
wait_ms = min(
|
||||
config.min_poll_interval_ms * 2,
|
||||
5000, # Max 5 seconds
|
||||
)
|
||||
retry_after = int(wait_ms / 1000) + 1
|
||||
return RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=0,
|
||||
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
|
||||
retry_after_seconds=retry_after,
|
||||
reason="Polling too frequently. Please wait before retrying.",
|
||||
)
|
||||
|
||||
# Update poll timestamp
|
||||
self._poll_timestamps[key] = now
|
||||
|
||||
return RateLimitStatus(
|
||||
allowed=True,
|
||||
remaining_requests=999, # No limit on poll count, just frequency
|
||||
reset_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
def _get_config(self, api_key: str) -> RateLimitConfig:
|
||||
"""Get rate limit config for API key with caching."""
|
||||
now = time.time()
|
||||
|
||||
# Check cache
|
||||
if api_key in self._config_cache:
|
||||
cached_config, cached_at = self._config_cache[api_key]
|
||||
if now - cached_at < self._config_cache_ttl:
|
||||
return cached_config
|
||||
|
||||
# Query database
|
||||
db_config = self._db.get_api_key_config(api_key)
|
||||
if db_config:
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=db_config.requests_per_minute,
|
||||
max_concurrent_jobs=db_config.max_concurrent_jobs,
|
||||
)
|
||||
else:
|
||||
config = RateLimitConfig() # Default limits
|
||||
|
||||
# Cache result
|
||||
self._config_cache[api_key] = (config, now)
|
||||
return config
|
||||
|
||||
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
|
||||
"""Clean up old poll timestamps to prevent memory leak."""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
cutoff = now - max_age_seconds
|
||||
old_keys = [
|
||||
k for k, v in self._poll_timestamps.items()
|
||||
if v < cutoff
|
||||
]
|
||||
for key in old_keys:
|
||||
del self._poll_timestamps[key]
|
||||
return len(old_keys)
|
||||
|
||||
def cleanup_request_windows(self) -> None:
|
||||
"""Clean up expired entries from request windows."""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - 60
|
||||
|
||||
for api_key in list(self._request_windows.keys()):
|
||||
self._request_windows[api_key] = [
|
||||
ts for ts in self._request_windows[api_key]
|
||||
if ts > window_start
|
||||
]
|
||||
# Remove empty entries
|
||||
if not self._request_windows[api_key]:
|
||||
del self._request_windows[api_key]
|
||||
|
||||
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
|
||||
"""Generate rate limit headers for HTTP response."""
|
||||
headers = {
|
||||
"X-RateLimit-Remaining": str(status.remaining_requests),
|
||||
"X-RateLimit-Reset": status.reset_at.isoformat(),
|
||||
}
|
||||
if status.retry_after_seconds:
|
||||
headers["Retry-After"] = str(status.retry_after_seconds)
|
||||
return headers
|
||||
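A compact illustration of the submit flow: check the sliding window, record on success, and surface the headers on rejection. The stub database below exists only to satisfy the constructor and is not a real AsyncRequestDB:

# Hedged sketch of the check -> record flow using the class above.
from inference.web.rate_limiter import RateLimiter


class StubDB:
    def count_active_jobs(self, api_key: str) -> int:
        return 0

    def get_api_key_config(self, api_key: str):
        return None  # fall back to the default RateLimitConfig

    def record_rate_limit_event(self, api_key: str, event: str) -> None:
        pass


limiter = RateLimiter(StubDB())  # type: ignore[arg-type]
status = limiter.check_submit_limit("demo-key")
if status.allowed:
    limiter.record_request("demo-key")
else:
    print(limiter.get_rate_limit_headers(status))  # e.g. Retry-After for a 429 response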
11
packages/inference/inference/web/schemas/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
API Schemas
|
||||
|
||||
Pydantic models for request/response validation.
|
||||
"""
|
||||
|
||||
# Import everything from sub-modules for backward compatibility
|
||||
from inference.web.schemas.common import * # noqa: F401, F403
|
||||
from inference.web.schemas.admin import * # noqa: F401, F403
|
||||
from inference.web.schemas.inference import * # noqa: F401, F403
|
||||
from inference.web.schemas.labeling import * # noqa: F401, F403
|
||||
17
packages/inference/inference/web/schemas/admin/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Admin API Request/Response Schemas
|
||||
|
||||
Pydantic models for admin API validation and serialization.
|
||||
"""
|
||||
|
||||
from .enums import * # noqa: F401, F403
|
||||
from .auth import * # noqa: F401, F403
|
||||
from .documents import * # noqa: F401, F403
|
||||
from .annotations import * # noqa: F401, F403
|
||||
from .training import * # noqa: F401, F403
|
||||
from .datasets import * # noqa: F401, F403
|
||||
|
||||
# Resolve forward references for DocumentDetailResponse
|
||||
from .documents import DocumentDetailResponse
|
||||
|
||||
DocumentDetailResponse.model_rebuild()
|
||||
152
packages/inference/inference/web/schemas/admin/annotations.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Admin Annotation Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .enums import AnnotationSource
|
||||
|
||||
|
||||
class BoundingBox(BaseModel):
|
||||
"""Bounding box coordinates."""
|
||||
|
||||
x: int = Field(..., ge=0, description="X coordinate (pixels)")
|
||||
y: int = Field(..., ge=0, description="Y coordinate (pixels)")
|
||||
width: int = Field(..., ge=1, description="Width (pixels)")
|
||||
height: int = Field(..., ge=1, description="Height (pixels)")
|
||||
|
||||
|
||||
class AnnotationCreate(BaseModel):
|
||||
"""Request to create an annotation."""
|
||||
|
||||
page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
|
||||
class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)")
|
||||
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
|
||||
text_value: str | None = Field(None, description="Text value (optional)")
|
||||
|
||||
|
||||
class AnnotationUpdate(BaseModel):
|
||||
"""Request to update an annotation."""
|
||||
|
||||
class_id: int | None = Field(None, ge=0, le=9, description="New class ID")
|
||||
bbox: BoundingBox | None = Field(None, description="New bounding box")
|
||||
text_value: str | None = Field(None, description="New text value")
|
||||
|
||||
|
||||
class AnnotationItem(BaseModel):
|
||||
"""Single annotation item."""
|
||||
|
||||
annotation_id: str = Field(..., description="Annotation UUID")
|
||||
page_number: int = Field(..., ge=1, description="Page number")
|
||||
class_id: int = Field(..., ge=0, le=9, description="Class ID")
|
||||
class_name: str = Field(..., description="Class name")
|
||||
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
|
||||
normalized_bbox: dict[str, float] = Field(
|
||||
..., description="Normalized bbox (x_center, y_center, width, height)"
|
||||
)
|
||||
text_value: str | None = Field(None, description="Text value")
|
||||
confidence: float | None = Field(None, ge=0, le=1, description="Confidence score")
|
||||
source: AnnotationSource = Field(..., description="Annotation source")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class AnnotationResponse(BaseModel):
|
||||
"""Response for annotation operation."""
|
||||
|
||||
annotation_id: str = Field(..., description="Annotation UUID")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class AnnotationListResponse(BaseModel):
|
||||
"""Response for annotation list."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
page_count: int = Field(..., ge=1, description="Total pages")
|
||||
total_annotations: int = Field(..., ge=0, description="Total annotations")
|
||||
annotations: list[AnnotationItem] = Field(
|
||||
default_factory=list, description="Annotation list"
|
||||
)
|
||||
|
||||
|
||||
class AnnotationLockRequest(BaseModel):
|
||||
"""Request to acquire annotation lock."""
|
||||
|
||||
duration_seconds: int = Field(
|
||||
default=300,
|
||||
ge=60,
|
||||
le=3600,
|
||||
description="Lock duration in seconds (60-3600)",
|
||||
)
|
||||
|
||||
|
||||
class AnnotationLockResponse(BaseModel):
|
||||
"""Response for annotation lock operation."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
locked: bool = Field(..., description="Whether lock was acquired/released")
|
||||
lock_expires_at: datetime | None = Field(
|
||||
None, description="Lock expiration time"
|
||||
)
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class AutoLabelRequest(BaseModel):
|
||||
"""Request to trigger auto-labeling."""
|
||||
|
||||
field_values: dict[str, str] = Field(
|
||||
...,
|
||||
description="Field values to match (e.g., {'invoice_number': '12345'})",
|
||||
)
|
||||
replace_existing: bool = Field(
|
||||
default=False, description="Replace existing auto annotations"
|
||||
)
|
||||
|
||||
|
||||
class AutoLabelResponse(BaseModel):
|
||||
"""Response for auto-labeling."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
status: str = Field(..., description="Auto-labeling status")
|
||||
annotations_created: int = Field(
|
||||
default=0, ge=0, description="Number of annotations created"
|
||||
)
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class AnnotationVerifyRequest(BaseModel):
|
||||
"""Request to verify an annotation."""
|
||||
|
||||
pass # No body needed, just POST to verify
|
||||
|
||||
|
||||
class AnnotationVerifyResponse(BaseModel):
|
||||
"""Response for annotation verification."""
|
||||
|
||||
annotation_id: str = Field(..., description="Annotation UUID")
|
||||
is_verified: bool = Field(..., description="Verification status")
|
||||
verified_at: datetime = Field(..., description="Verification timestamp")
|
||||
verified_by: str = Field(..., description="Admin token who verified")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class AnnotationOverrideRequest(BaseModel):
|
||||
"""Request to override an annotation."""
|
||||
|
||||
bbox: dict[str, int] | None = Field(
|
||||
None, description="Updated bounding box {x, y, width, height}"
|
||||
)
|
||||
text_value: str | None = Field(None, description="Updated text value")
|
||||
class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID")
|
||||
class_name: str | None = Field(None, description="Updated class name")
|
||||
reason: str | None = Field(None, description="Reason for override")
|
||||
|
||||
|
||||
class AnnotationOverrideResponse(BaseModel):
|
||||
"""Response for annotation override."""
|
||||
|
||||
annotation_id: str = Field(..., description="Annotation UUID")
|
||||
source: str = Field(..., description="New source (manual)")
|
||||
override_source: str | None = Field(None, description="Original source (auto)")
|
||||
original_annotation_id: str | None = Field(None, description="Original annotation ID")
|
||||
message: str = Field(..., description="Status message")
|
||||
history_id: str = Field(..., description="History record UUID")
|
||||
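AnnotationItem carries both the pixel-space BoundingBox and a normalized dict in the YOLO (x_center, y_center, width, height) convention used by the training exporter. A sketch of the conversion, assuming the rendered page image dimensions are known:

# Sketch of the pixel -> normalized conversion implied by AnnotationItem.normalized_bbox;
# page_width/page_height are the rendered page image dimensions (assumed known).
def normalize_bbox(bbox: BoundingBox, page_width: int, page_height: int) -> dict[str, float]:
    return {
        "x_center": (bbox.x + bbox.width / 2) / page_width,
        "y_center": (bbox.y + bbox.height / 2) / page_height,
        "width": bbox.width / page_width,
        "height": bbox.height / page_height,
    }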
23
packages/inference/inference/web/schemas/admin/auth.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Admin Auth Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AdminTokenCreate(BaseModel):
|
||||
"""Request to create an admin token."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Token name")
|
||||
expires_in_days: int | None = Field(
|
||||
None, ge=1, le=365, description="Token expiration in days (optional)"
|
||||
)
|
||||
|
||||
|
||||
class AdminTokenResponse(BaseModel):
|
||||
"""Response with created admin token."""
|
||||
|
||||
token: str = Field(..., description="Admin token")
|
||||
name: str = Field(..., description="Token name")
|
||||
expires_at: datetime | None = Field(None, description="Token expiration time")
|
||||
message: str = Field(..., description="Status message")
|
||||
85
packages/inference/inference/web/schemas/admin/datasets.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Admin Dataset Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .training import TrainingConfig
|
||||
|
||||
|
||||
class DatasetCreateRequest(BaseModel):
|
||||
"""Request to create a training dataset."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
|
||||
description: str | None = Field(None, description="Optional description")
|
||||
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
|
||||
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
|
||||
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
|
||||
seed: int = Field(42, description="Random seed for split")
|
||||
|
||||
|
||||
class DatasetDocumentItem(BaseModel):
|
||||
"""Document within a dataset."""
|
||||
|
||||
document_id: str
|
||||
split: str
|
||||
page_count: int
|
||||
annotation_count: int
|
||||
|
||||
|
||||
class DatasetResponse(BaseModel):
|
||||
"""Response after creating a dataset."""
|
||||
|
||||
dataset_id: str
|
||||
name: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class DatasetDetailResponse(BaseModel):
|
||||
"""Detailed dataset info with documents."""
|
||||
|
||||
dataset_id: str
|
||||
name: str
|
||||
description: str | None
|
||||
status: str
|
||||
train_ratio: float
|
||||
val_ratio: float
|
||||
seed: int
|
||||
total_documents: int
|
||||
total_images: int
|
||||
total_annotations: int
|
||||
dataset_path: str | None
|
||||
error_message: str | None
|
||||
documents: list[DatasetDocumentItem]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class DatasetListItem(BaseModel):
|
||||
"""Dataset in list view."""
|
||||
|
||||
dataset_id: str
|
||||
name: str
|
||||
description: str | None
|
||||
status: str
|
||||
total_documents: int
|
||||
total_images: int
|
||||
total_annotations: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class DatasetListResponse(BaseModel):
|
||||
"""Paginated dataset list."""
|
||||
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
datasets: list[DatasetListItem]
|
||||
|
||||
|
||||
class DatasetTrainRequest(BaseModel):
|
||||
"""Request to start training from a dataset."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Training task name")
|
||||
config: TrainingConfig = Field(..., description="Training configuration")
|
||||
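An illustrative create payload: with the defaults, 80% of documents go to train and 10% to validation, and the remainder is presumably held out. The document UUIDs below are placeholders:

# Illustrative DatasetCreateRequest payload (same module, pydantic v2).
request = DatasetCreateRequest(
    name="invoices-2024-q1",
    description="Verified invoices from the Q1 batch",
    document_ids=["<document-uuid-1>", "<document-uuid-2>"],
    train_ratio=0.8,
    val_ratio=0.1,   # remaining 0.1 is presumably the held-out split
    seed=42,
)
print(request.model_dump_json(indent=2))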
103
packages/inference/inference/web/schemas/admin/documents.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Admin Document Schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .enums import AutoLabelStatus, DocumentStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .annotations import AnnotationItem
|
||||
from .training import TrainingHistoryItem
|
||||
|
||||
|
||||
class DocumentUploadResponse(BaseModel):
|
||||
"""Response for document upload."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
filename: str = Field(..., description="Original filename")
|
||||
file_size: int = Field(..., ge=0, description="File size in bytes")
|
||||
page_count: int = Field(..., ge=1, description="Number of pages")
|
||||
status: DocumentStatus = Field(..., description="Document status")
|
||||
auto_label_started: bool = Field(
|
||||
default=False, description="Whether auto-labeling was started"
|
||||
)
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class DocumentItem(BaseModel):
|
||||
"""Single document in list."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
filename: str = Field(..., description="Original filename")
|
||||
file_size: int = Field(..., ge=0, description="File size in bytes")
|
||||
page_count: int = Field(..., ge=1, description="Number of pages")
|
||||
status: DocumentStatus = Field(..., description="Document status")
|
||||
auto_label_status: AutoLabelStatus | None = Field(
|
||||
None, description="Auto-labeling status"
|
||||
)
|
||||
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
|
||||
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
|
||||
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
|
||||
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class DocumentListResponse(BaseModel):
|
||||
"""Response for document list."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total documents")
|
||||
limit: int = Field(..., ge=1, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Current offset")
|
||||
documents: list[DocumentItem] = Field(
|
||||
default_factory=list, description="Document list"
|
||||
)
|
||||
|
||||
|
||||
class DocumentDetailResponse(BaseModel):
|
||||
"""Response for document detail."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
filename: str = Field(..., description="Original filename")
|
||||
file_size: int = Field(..., ge=0, description="File size in bytes")
|
||||
content_type: str = Field(..., description="MIME type")
|
||||
page_count: int = Field(..., ge=1, description="Number of pages")
|
||||
status: DocumentStatus = Field(..., description="Document status")
|
||||
auto_label_status: AutoLabelStatus | None = Field(
|
||||
None, description="Auto-labeling status"
|
||||
)
|
||||
auto_label_error: str | None = Field(None, description="Auto-labeling error")
|
||||
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
|
||||
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
|
||||
csv_field_values: dict[str, str] | None = Field(
|
||||
None, description="CSV field values if uploaded via batch"
|
||||
)
|
||||
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
|
||||
annotation_lock_until: datetime | None = Field(
|
||||
None, description="Lock expiration time if document is locked"
|
||||
)
|
||||
annotations: list["AnnotationItem"] = Field(
|
||||
default_factory=list, description="Document annotations"
|
||||
)
|
||||
image_urls: list[str] = Field(
|
||||
default_factory=list, description="URLs to page images"
|
||||
)
|
||||
training_history: list["TrainingHistoryItem"] = Field(
|
||||
default_factory=list, description="Training tasks that used this document"
|
||||
)
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class DocumentStatsResponse(BaseModel):
|
||||
"""Document statistics response."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total documents")
|
||||
pending: int = Field(default=0, ge=0, description="Pending documents")
|
||||
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
|
||||
labeled: int = Field(default=0, ge=0, description="Labeled documents")
|
||||
exported: int = Field(default=0, ge=0, description="Exported documents")
|
||||
46
packages/inference/inference/web/schemas/admin/enums.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Admin API Enums."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DocumentStatus(str, Enum):
|
||||
"""Document status enum."""
|
||||
|
||||
PENDING = "pending"
|
||||
AUTO_LABELING = "auto_labeling"
|
||||
LABELED = "labeled"
|
||||
EXPORTED = "exported"
|
||||
|
||||
|
||||
class AutoLabelStatus(str, Enum):
|
||||
"""Auto-labeling status enum."""
|
||||
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TrainingStatus(str, Enum):
|
||||
"""Training task status enum."""
|
||||
|
||||
PENDING = "pending"
|
||||
SCHEDULED = "scheduled"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class TrainingType(str, Enum):
|
||||
"""Training task type enum."""
|
||||
|
||||
TRAIN = "train"
|
||||
FINETUNE = "finetune"
|
||||
|
||||
|
||||
class AnnotationSource(str, Enum):
|
||||
"""Annotation source enum."""
|
||||
|
||||
MANUAL = "manual"
|
||||
AUTO = "auto"
|
||||
IMPORTED = "imported"
|
||||
202
packages/inference/inference/web/schemas/admin/training.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Admin Training Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .enums import TrainingStatus, TrainingType
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Training configuration."""
|
||||
|
||||
model_name: str = Field(default="yolo11n.pt", description="Base model name")
|
||||
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
|
||||
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
|
||||
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
|
||||
learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate")
|
||||
device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
|
||||
project_name: str = Field(
|
||||
default="invoice_fields", description="Training project name"
|
||||
)
|
||||
|
||||
|
||||
class TrainingTaskCreate(BaseModel):
|
||||
"""Request to create a training task."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Task name")
|
||||
description: str | None = Field(None, max_length=1000, description="Description")
|
||||
task_type: TrainingType = Field(
|
||||
default=TrainingType.TRAIN, description="Task type"
|
||||
)
|
||||
config: TrainingConfig = Field(
|
||||
default_factory=TrainingConfig, description="Training configuration"
|
||||
)
|
||||
scheduled_at: datetime | None = Field(
|
||||
None, description="Scheduled execution time"
|
||||
)
|
||||
cron_expression: str | None = Field(
|
||||
None, max_length=50, description="Cron expression for recurring tasks"
|
||||
)
|
||||
|
||||
|
||||
class TrainingTaskItem(BaseModel):
|
||||
"""Single training task in list."""
|
||||
|
||||
task_id: str = Field(..., description="Task UUID")
|
||||
name: str = Field(..., description="Task name")
|
||||
task_type: TrainingType = Field(..., description="Task type")
|
||||
status: TrainingStatus = Field(..., description="Task status")
|
||||
scheduled_at: datetime | None = Field(None, description="Scheduled time")
|
||||
is_recurring: bool = Field(default=False, description="Is recurring task")
|
||||
started_at: datetime | None = Field(None, description="Start time")
|
||||
completed_at: datetime | None = Field(None, description="Completion time")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class TrainingTaskListResponse(BaseModel):
|
||||
"""Response for training task list."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total tasks")
|
||||
limit: int = Field(..., ge=1, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Current offset")
|
||||
tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list")
|
||||
|
||||
|
||||
class TrainingTaskDetailResponse(BaseModel):
|
||||
"""Response for training task detail."""
|
||||
|
||||
task_id: str = Field(..., description="Task UUID")
|
||||
name: str = Field(..., description="Task name")
|
||||
description: str | None = Field(None, description="Description")
|
||||
task_type: TrainingType = Field(..., description="Task type")
|
||||
status: TrainingStatus = Field(..., description="Task status")
|
||||
config: dict[str, Any] | None = Field(None, description="Training configuration")
|
||||
scheduled_at: datetime | None = Field(None, description="Scheduled time")
|
||||
cron_expression: str | None = Field(None, description="Cron expression")
|
||||
is_recurring: bool = Field(default=False, description="Is recurring task")
|
||||
started_at: datetime | None = Field(None, description="Start time")
|
||||
completed_at: datetime | None = Field(None, description="Completion time")
|
||||
error_message: str | None = Field(None, description="Error message")
|
||||
result_metrics: dict[str, Any] | None = Field(None, description="Result metrics")
|
||||
model_path: str | None = Field(None, description="Trained model path")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class TrainingTaskResponse(BaseModel):
|
||||
"""Response for training task operation."""
|
||||
|
||||
task_id: str = Field(..., description="Task UUID")
|
||||
status: TrainingStatus = Field(..., description="Task status")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class TrainingLogItem(BaseModel):
|
||||
"""Single training log entry."""
|
||||
|
||||
level: str = Field(..., description="Log level")
|
||||
message: str = Field(..., description="Log message")
|
||||
details: dict[str, Any] | None = Field(None, description="Additional details")
|
||||
created_at: datetime = Field(..., description="Timestamp")
|
||||
|
||||
|
||||
class TrainingLogsResponse(BaseModel):
|
||||
"""Response for training logs."""
|
||||
|
||||
task_id: str = Field(..., description="Task UUID")
|
||||
logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries")
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
"""Request to export annotations."""
|
||||
|
||||
format: str = Field(
|
||||
default="yolo", description="Export format (yolo, coco, voc)"
|
||||
)
|
||||
include_images: bool = Field(
|
||||
default=True, description="Include images in export"
|
||||
)
|
||||
split_ratio: float = Field(
|
||||
default=0.8, ge=0.5, le=1.0, description="Train/val split ratio"
|
||||
)
|
||||
|
||||
|
||||
class ExportResponse(BaseModel):
|
||||
"""Response for export operation."""
|
||||
|
||||
status: str = Field(..., description="Export status")
|
||||
export_path: str = Field(..., description="Path to exported dataset")
|
||||
total_images: int = Field(..., ge=0, description="Total images exported")
|
||||
total_annotations: int = Field(..., ge=0, description="Total annotations")
|
||||
train_count: int = Field(..., ge=0, description="Training set count")
|
||||
val_count: int = Field(..., ge=0, description="Validation set count")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class TrainingDocumentItem(BaseModel):
|
||||
"""Document item for training page."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
filename: str = Field(..., description="Filename")
|
||||
annotation_count: int = Field(..., ge=0, description="Total annotations")
|
||||
annotation_sources: dict[str, int] = Field(
|
||||
..., description="Annotation counts by source (manual, auto)"
|
||||
)
|
||||
used_in_training: list[str] = Field(
|
||||
default_factory=list, description="List of training task IDs that used this document"
|
||||
)
|
||||
last_modified: datetime = Field(..., description="Last modification time")
|
||||
|
||||
|
||||
class TrainingDocumentsResponse(BaseModel):
|
||||
"""Response for GET /admin/training/documents."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total document count")
|
||||
limit: int = Field(..., ge=1, le=100, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Pagination offset")
|
||||
documents: list[TrainingDocumentItem] = Field(
|
||||
default_factory=list, description="Documents available for training"
|
||||
)
|
||||
|
||||
|
||||
class ModelMetrics(BaseModel):
|
||||
"""Training model metrics."""
|
||||
|
||||
mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
|
||||
precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
|
||||
recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
|
||||
|
||||
|
||||
class TrainingModelItem(BaseModel):
|
||||
"""Trained model item for model list."""
|
||||
|
||||
task_id: str = Field(..., description="Training task UUID")
|
||||
name: str = Field(..., description="Model name")
|
||||
status: TrainingStatus = Field(..., description="Training status")
|
||||
document_count: int = Field(..., ge=0, description="Documents used in training")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
completed_at: datetime | None = Field(None, description="Completion timestamp")
|
||||
metrics: ModelMetrics = Field(..., description="Model metrics")
|
||||
model_path: str | None = Field(None, description="Path to model weights")
|
||||
download_url: str | None = Field(None, description="Download URL for model")
|
||||
|
||||
|
||||
class TrainingModelsResponse(BaseModel):
|
||||
"""Response for GET /admin/training/models."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total model count")
|
||||
limit: int = Field(..., ge=1, le=100, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Pagination offset")
|
||||
models: list[TrainingModelItem] = Field(
|
||||
default_factory=list, description="Trained models"
|
||||
)
|
||||
|
||||
|
||||
class TrainingHistoryItem(BaseModel):
|
||||
"""Training history for a document."""
|
||||
|
||||
task_id: str = Field(..., description="Training task UUID")
|
||||
name: str = Field(..., description="Training task name")
|
||||
trained_at: datetime = Field(..., description="Training timestamp")
|
||||
model_metrics: ModelMetrics | None = Field(None, description="Model metrics")
|
||||
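For reference, a task-creation payload that schedules a fine-tune run a few hours out; the values are examples only and the schedule offset is arbitrary:

# Illustrative TrainingTaskCreate payload (same module, pydantic v2).
from datetime import datetime, timedelta

task = TrainingTaskCreate(
    name="nightly-finetune",
    description="Fine-tune on newly verified annotations",
    task_type=TrainingType.FINETUNE,
    config=TrainingConfig(epochs=50, batch_size=8, device="0"),
    scheduled_at=datetime.utcnow() + timedelta(hours=6),
)
print(task.model_dump_json(indent=2))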
15
packages/inference/inference/web/schemas/common.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Common Schemas
|
||||
|
||||
Shared Pydantic models used across multiple API modules.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response."""
|
||||
|
||||
status: str = Field(default="error", description="Error status")
|
||||
message: str = Field(..., description="Error message")
|
||||
detail: str | None = Field(None, description="Detailed error information")
|
||||
196
packages/inference/inference/web/schemas/inference.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
API Request/Response Schemas
|
||||
|
||||
Pydantic models for API validation and serialization.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Enums
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AsyncStatus(str, Enum):
|
||||
"""Async request status enum."""
|
||||
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sync API Schemas (existing)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DetectionResult(BaseModel):
|
||||
"""Single detection result."""
|
||||
|
||||
field: str = Field(..., description="Field type (e.g., invoice_number, amount)")
|
||||
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
|
||||
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
|
||||
|
||||
|
||||
class ExtractedField(BaseModel):
|
||||
"""Extracted and normalized field value."""
|
||||
|
||||
field_name: str = Field(..., description="Field name")
|
||||
value: str | None = Field(None, description="Extracted value")
|
||||
confidence: float = Field(..., ge=0, le=1, description="Extraction confidence")
|
||||
is_valid: bool = Field(True, description="Whether the value passed validation")
|
||||
|
||||
|
||||
class InferenceResult(BaseModel):
|
||||
"""Complete inference result for a document."""
|
||||
|
||||
document_id: str = Field(..., description="Document identifier")
|
||||
success: bool = Field(..., description="Whether inference succeeded")
|
||||
document_type: str = Field(
|
||||
default="invoice", description="Document type: 'invoice' or 'letter'"
|
||||
)
|
||||
fields: dict[str, str | None] = Field(
|
||||
default_factory=dict, description="Extracted field values"
|
||||
)
|
||||
confidence: dict[str, float] = Field(
|
||||
default_factory=dict, description="Confidence scores per field"
|
||||
)
|
||||
detections: list[DetectionResult] = Field(
|
||||
default_factory=list, description="Raw YOLO detections"
|
||||
)
|
||||
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
|
||||
visualization_url: str | None = Field(
|
||||
None, description="URL to visualization image"
|
||||
)
|
||||
    errors: list[str] = Field(default_factory=list, description="Error messages")


class InferenceResponse(BaseModel):
    """API response for inference endpoint."""

    status: str = Field(..., description="Response status: success or error")
    message: str = Field(..., description="Response message")
    result: InferenceResult | None = Field(None, description="Inference result")


class BatchInferenceResponse(BaseModel):
    """API response for batch inference endpoint."""

    status: str = Field(..., description="Response status")
    message: str = Field(..., description="Response message")
    total: int = Field(..., description="Total documents processed")
    successful: int = Field(..., description="Number of successful extractions")
    results: list[InferenceResult] = Field(
        default_factory=list, description="Individual results"
    )


class HealthResponse(BaseModel):
    """Health check response."""

    status: str = Field(..., description="Service status")
    model_loaded: bool = Field(..., description="Whether model is loaded")
    gpu_available: bool = Field(..., description="Whether GPU is available")
    version: str = Field(..., description="API version")


class ErrorResponse(BaseModel):
    """Error response."""

    status: str = Field(default="error", description="Error status")
    message: str = Field(..., description="Error message")
    detail: str | None = Field(None, description="Detailed error information")


# =============================================================================
# Async API Schemas
# =============================================================================


class AsyncSubmitResponse(BaseModel):
    """Response for async submit endpoint."""

    status: str = Field(default="accepted", description="Response status")
    message: str = Field(..., description="Response message")
    request_id: str = Field(..., description="Unique request identifier (UUID)")
    estimated_wait_seconds: int = Field(
        ..., ge=0, description="Estimated wait time in seconds"
    )
    poll_url: str = Field(..., description="URL to poll for status updates")


class AsyncStatusResponse(BaseModel):
    """Response for async status endpoint."""

    request_id: str = Field(..., description="Unique request identifier")
    status: AsyncStatus = Field(..., description="Current processing status")
    filename: str = Field(..., description="Original filename")
    created_at: datetime = Field(..., description="Request creation timestamp")
    started_at: datetime | None = Field(
        None, description="Processing start timestamp"
    )
    completed_at: datetime | None = Field(
        None, description="Processing completion timestamp"
    )
    position_in_queue: int | None = Field(
        None, description="Position in queue (for pending status)"
    )
    error_message: str | None = Field(
        None, description="Error message (for failed status)"
    )
    result_url: str | None = Field(
        None, description="URL to fetch results (for completed status)"
    )


class AsyncResultResponse(BaseModel):
    """Response for async result endpoint."""

    request_id: str = Field(..., description="Unique request identifier")
    status: AsyncStatus = Field(..., description="Processing status")
    processing_time_ms: float = Field(
        ..., ge=0, description="Total processing time in milliseconds"
    )
    result: InferenceResult | None = Field(
        None, description="Extraction result (when completed)"
    )
    visualization_url: str | None = Field(
        None, description="URL to visualization image"
    )


class AsyncRequestItem(BaseModel):
    """Single item in async requests list."""

    request_id: str = Field(..., description="Unique request identifier")
    status: AsyncStatus = Field(..., description="Current processing status")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    created_at: datetime = Field(..., description="Request creation timestamp")
    completed_at: datetime | None = Field(
        None, description="Processing completion timestamp"
    )


class AsyncRequestsListResponse(BaseModel):
    """Response for async requests list endpoint."""

    total: int = Field(..., ge=0, description="Total number of requests")
    limit: int = Field(..., ge=1, description="Maximum items per page")
    offset: int = Field(..., ge=0, description="Current offset")
    requests: list[AsyncRequestItem] = Field(
        default_factory=list, description="List of requests"
    )


class RateLimitInfo(BaseModel):
    """Rate limit information (included in headers)."""

    limit: int = Field(..., description="Maximum requests per minute")
    remaining: int = Field(..., description="Remaining requests in current window")
    reset_at: datetime = Field(..., description="Time when limit resets")
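For orientation, a minimal sketch of how one of these response models would be built and serialized by an endpoint handler. It assumes only the classes defined above and Pydantic v2 (implied by the `BaseModel`/`Field` usage here); the sample values are invented.

# Illustrative only: construct the submit response an async endpoint would
# return. `status` defaults to "accepted"; the remaining fields are required.
resp = AsyncSubmitResponse(
    message="Request accepted for processing",
    request_id="3f2b9c1e-7a41-4c1d-9e2f-5b8d0c6a1f00",   # made-up UUID
    estimated_wait_seconds=10,
    poll_url="/async/status/3f2b9c1e-7a41-4c1d-9e2f-5b8d0c6a1f00",  # assumed route shape
)
print(resp.model_dump_json(indent=2))  # Pydantic v2 JSON serialization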
13
packages/inference/inference/web/schemas/labeling.py
Normal file
@@ -0,0 +1,13 @@
"""
Labeling API Schemas

Pydantic models for pre-labeling and label validation endpoints.
"""

from pydantic import BaseModel, Field


class PreLabelResponse(BaseModel):
    """API response for pre-label endpoint."""

    document_id: str = Field(..., description="Document identifier for retrieving results")
18
packages/inference/inference/web/services/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""
Business Logic Services

Service layer for processing requests and orchestrating data operations.
"""

from inference.web.services.autolabel import AutoLabelService, get_auto_label_service
from inference.web.services.inference import InferenceService
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.services.batch_upload import BatchUploadService

__all__ = [
    "AutoLabelService",
    "get_auto_label_service",
    "InferenceService",
    "AsyncProcessingService",
    "BatchUploadService",
]
383
packages/inference/inference/web/services/async_processing.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Async Processing Service
|
||||
|
||||
Manages async request lifecycle and background processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from threading import Event, Thread
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from inference.data.async_request_db import AsyncRequestDB
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from inference.web.core.rate_limiter import RateLimiter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inference.web.config import AsyncConfig, StorageConfig
|
||||
from inference.web.services.inference import InferenceService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AsyncSubmitResult:
|
||||
"""Result from async submit operation."""
|
||||
|
||||
success: bool
|
||||
request_id: str | None = None
|
||||
estimated_wait_seconds: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class AsyncProcessingService:
|
||||
"""
|
||||
Manages async request lifecycle and processing.
|
||||
|
||||
Coordinates between:
|
||||
- HTTP endpoints (submit/status/result)
|
||||
- Background task queue
|
||||
- Database storage
|
||||
- Rate limiting
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inference_service: "InferenceService",
|
||||
db: AsyncRequestDB,
|
||||
queue: AsyncTaskQueue,
|
||||
rate_limiter: RateLimiter,
|
||||
async_config: "AsyncConfig",
|
||||
storage_config: "StorageConfig",
|
||||
) -> None:
|
||||
self._inference = inference_service
|
||||
self._db = db
|
||||
self._queue = queue
|
||||
self._rate_limiter = rate_limiter
|
||||
self._async_config = async_config
|
||||
self._storage_config = storage_config
|
||||
|
||||
# Cleanup thread
|
||||
self._cleanup_stop_event = Event()
|
||||
self._cleanup_thread: Thread | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the async processing service."""
|
||||
# Start the task queue with our handler
|
||||
self._queue.start(self._process_task)
|
||||
|
||||
# Start cleanup thread
|
||||
self._cleanup_stop_event.clear()
|
||||
self._cleanup_thread = Thread(
|
||||
target=self._cleanup_loop,
|
||||
name="async-cleanup",
|
||||
daemon=True,
|
||||
)
|
||||
self._cleanup_thread.start()
|
||||
logger.info("AsyncProcessingService started")
|
||||
|
||||
def stop(self, timeout: float = 30.0) -> None:
|
||||
"""Stop the async processing service."""
|
||||
# Stop cleanup thread
|
||||
self._cleanup_stop_event.set()
|
||||
if self._cleanup_thread and self._cleanup_thread.is_alive():
|
||||
self._cleanup_thread.join(timeout=5.0)
|
||||
|
||||
# Stop task queue
|
||||
self._queue.stop(timeout=timeout)
|
||||
logger.info("AsyncProcessingService stopped")
|
||||
|
||||
def submit_request(
|
||||
self,
|
||||
api_key: str,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
) -> AsyncSubmitResult:
|
||||
"""
|
||||
Submit a new async processing request.
|
||||
|
||||
Args:
|
||||
api_key: API key for the request
|
||||
file_content: File content as bytes
|
||||
filename: Original filename
|
||||
content_type: File content type
|
||||
|
||||
Returns:
|
||||
AsyncSubmitResult with request_id and status
|
||||
"""
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Save file to temp storage
|
||||
file_path = self._save_upload(request_id, filename, file_content)
|
||||
file_size = len(file_content)
|
||||
|
||||
try:
|
||||
# Calculate expiration
|
||||
expires_at = datetime.utcnow() + timedelta(
|
||||
days=self._async_config.result_retention_days
|
||||
)
|
||||
|
||||
# Create database record
|
||||
self._db.create_request(
|
||||
api_key=api_key,
|
||||
filename=filename,
|
||||
file_size=file_size,
|
||||
content_type=content_type,
|
||||
expires_at=expires_at,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
# Record rate limit event
|
||||
self._rate_limiter.record_request(api_key)
|
||||
|
||||
# Create and queue task
|
||||
task = AsyncTask(
|
||||
request_id=request_id,
|
||||
api_key=api_key,
|
||||
file_path=file_path,
|
||||
filename=filename,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
if not self._queue.submit(task):
|
||||
# Queue is full
|
||||
self._db.update_status(
|
||||
request_id,
|
||||
"failed",
|
||||
error_message="Processing queue is full",
|
||||
)
|
||||
# Cleanup file
|
||||
file_path.unlink(missing_ok=True)
|
||||
return AsyncSubmitResult(
|
||||
success=False,
|
||||
request_id=request_id,
|
||||
error="Processing queue is full. Please try again later.",
|
||||
)
|
||||
|
||||
# Estimate wait time
|
||||
estimated_wait = self._estimate_wait()
|
||||
|
||||
return AsyncSubmitResult(
|
||||
success=True,
|
||||
request_id=request_id,
|
||||
estimated_wait_seconds=estimated_wait,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to submit request: {e}", exc_info=True)
|
||||
# Cleanup file on error
|
||||
file_path.unlink(missing_ok=True)
|
||||
return AsyncSubmitResult(
|
||||
success=False,
|
||||
# Return generic error message to avoid leaking implementation details
|
||||
error="Failed to process request. Please try again later.",
|
||||
)
|
||||
|
||||
# Allowed file extensions whitelist
|
||||
ALLOWED_EXTENSIONS = frozenset({".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"})
|
||||
|
||||
def _save_upload(
|
||||
self,
|
||||
request_id: str,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
) -> Path:
|
||||
"""Save uploaded file to temp storage."""
|
||||
import re
|
||||
|
||||
# Extract extension from filename
|
||||
ext = Path(filename).suffix.lower()
|
||||
|
||||
# Validate extension: must be alphanumeric only (e.g., .pdf, .png)
|
||||
if not ext or not re.match(r'^\.[a-z0-9]+$', ext):
|
||||
ext = ".pdf"
|
||||
|
||||
# Validate against whitelist
|
||||
if ext not in self.ALLOWED_EXTENSIONS:
|
||||
ext = ".pdf"
|
||||
|
||||
# Create async upload directory
|
||||
upload_dir = self._async_config.temp_upload_dir
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build file path - request_id is a UUID so it's safe
|
||||
file_path = upload_dir / f"{request_id}{ext}"
|
||||
|
||||
# Defense in depth: ensure path is within upload_dir
|
||||
if not file_path.resolve().is_relative_to(upload_dir.resolve()):
|
||||
raise ValueError("Invalid file path detected")
|
||||
|
||||
file_path.write_bytes(content)
|
||||
|
||||
return file_path
|
||||
|
||||
def _process_task(self, task: AsyncTask) -> None:
|
||||
"""
|
||||
Process a single task (called by worker thread).
|
||||
|
||||
This method is called by the AsyncTaskQueue worker threads.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Update status to processing
|
||||
self._db.update_status(task.request_id, "processing")
|
||||
|
||||
# Ensure file exists
|
||||
if not task.file_path.exists():
|
||||
raise FileNotFoundError(f"Upload file not found: {task.file_path}")
|
||||
|
||||
# Run inference based on file type
|
||||
file_ext = task.file_path.suffix.lower()
|
||||
if file_ext == ".pdf":
|
||||
result = self._inference.process_pdf(
|
||||
task.file_path,
|
||||
document_id=task.request_id[:8],
|
||||
)
|
||||
else:
|
||||
result = self._inference.process_image(
|
||||
task.file_path,
|
||||
document_id=task.request_id[:8],
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
processing_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Prepare result for storage
|
||||
result_data = {
|
||||
"document_id": result.document_id,
|
||||
"success": result.success,
|
||||
"document_type": result.document_type,
|
||||
"fields": result.fields,
|
||||
"confidence": result.confidence,
|
||||
"detections": result.detections,
|
||||
"errors": result.errors,
|
||||
}
|
||||
|
||||
# Get visualization path as string
|
||||
viz_path = None
|
||||
if result.visualization_path:
|
||||
viz_path = str(result.visualization_path.name)
|
||||
|
||||
# Store result in database
|
||||
self._db.complete_request(
|
||||
request_id=task.request_id,
|
||||
document_id=result.document_id,
|
||||
result=result_data,
|
||||
processing_time_ms=processing_time_ms,
|
||||
visualization_path=viz_path,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Task {task.request_id} completed successfully "
|
||||
f"in {processing_time_ms:.0f}ms"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Task {task.request_id} failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
self._db.update_status(
|
||||
task.request_id,
|
||||
"failed",
|
||||
error_message=str(e),
|
||||
increment_retry=True,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup uploaded file
|
||||
if task.file_path.exists():
|
||||
task.file_path.unlink(missing_ok=True)
|
||||
|
||||
def _estimate_wait(self) -> int:
|
||||
"""Estimate wait time based on queue depth."""
|
||||
queue_depth = self._queue.get_queue_depth()
|
||||
processing_count = self._queue.get_processing_count()
|
||||
total_pending = queue_depth + processing_count
|
||||
|
||||
# Estimate ~5 seconds per document
|
||||
avg_processing_time = 5
|
||||
return total_pending * avg_processing_time
|
||||
|
||||
def _cleanup_loop(self) -> None:
|
||||
"""Background cleanup loop."""
|
||||
logger.info("Cleanup thread started")
|
||||
cleanup_interval = self._async_config.cleanup_interval_hours * 3600
|
||||
|
||||
while not self._cleanup_stop_event.wait(timeout=cleanup_interval):
|
||||
try:
|
||||
self._run_cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup failed: {e}", exc_info=True)
|
||||
|
||||
logger.info("Cleanup thread stopped")
|
||||
|
||||
def _run_cleanup(self) -> None:
|
||||
"""Run cleanup operations."""
|
||||
logger.info("Running cleanup...")
|
||||
|
||||
# Delete expired requests
|
||||
deleted_requests = self._db.delete_expired_requests()
|
||||
|
||||
# Reset stale processing requests
|
||||
reset_count = self._db.reset_stale_processing_requests(
|
||||
stale_minutes=self._async_config.task_timeout_seconds // 60,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
# Cleanup old rate limit events
|
||||
deleted_events = self._db.cleanup_old_rate_limit_events(hours=1)
|
||||
|
||||
# Cleanup old poll timestamps
|
||||
cleaned_polls = self._rate_limiter.cleanup_poll_timestamps()
|
||||
|
||||
# Cleanup rate limiter request windows
|
||||
self._rate_limiter.cleanup_request_windows()
|
||||
|
||||
# Cleanup orphaned upload files
|
||||
orphan_count = self._cleanup_orphan_files()
|
||||
|
||||
logger.info(
|
||||
f"Cleanup complete: {deleted_requests} expired requests, "
|
||||
f"{reset_count} stale requests reset, "
|
||||
f"{deleted_events} rate limit events, "
|
||||
f"{cleaned_polls} poll timestamps, "
|
||||
f"{orphan_count} orphan files"
|
||||
)
|
||||
|
||||
def _cleanup_orphan_files(self) -> int:
|
||||
"""Clean up upload files that don't have matching requests."""
|
||||
upload_dir = self._async_config.temp_upload_dir
|
||||
if not upload_dir.exists():
|
||||
return 0
|
||||
|
||||
count = 0
|
||||
# Files older than 1 hour without matching request are considered orphans
|
||||
cutoff = time.time() - 3600
|
||||
|
||||
for file_path in upload_dir.iterdir():
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
# Check if file is old enough
|
||||
if file_path.stat().st_mtime > cutoff:
|
||||
continue
|
||||
|
||||
# Extract request_id from filename
|
||||
request_id = file_path.stem
|
||||
|
||||
# Check if request exists in database
|
||||
request = self._db.get_request(request_id)
|
||||
if request is None:
|
||||
file_path.unlink(missing_ok=True)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
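A hedged sketch of how this service might be wired up and exercised. Only AsyncProcessingService's own constructor and methods are taken from the code above; the no-argument constructors for the collaborators (AsyncRequestDB, AsyncTaskQueue, RateLimiter, InferenceService, the config objects) are assumptions made for illustration.

# Hypothetical wiring (not part of the commit).
from pathlib import Path

from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.core.rate_limiter import RateLimiter
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.services.inference import InferenceService
from inference.web.services.async_processing import AsyncProcessingService

service = AsyncProcessingService(
    inference_service=InferenceService(),  # assumed default constructor
    db=AsyncRequestDB(),                   # assumed default constructor
    queue=AsyncTaskQueue(),                # assumed default constructor
    rate_limiter=RateLimiter(),            # assumed default constructor
    async_config=AsyncConfig(),
    storage_config=StorageConfig(),
)
service.start()

result = service.submit_request(
    api_key="demo-key",
    file_content=Path("invoice.pdf").read_bytes(),
    filename="invoice.pdf",
    content_type="application/pdf",
)
if result.success:
    print(f"queued {result.request_id}, ~{result.estimated_wait_seconds}s wait")
else:
    print(f"rejected: {result.error}")

service.stop()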
335
packages/inference/inference/web/services/autolabel.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Admin Auto-Labeling Service
|
||||
|
||||
Uses FieldMatcher to automatically create annotations from field values.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
|
||||
from shared.matcher.field_matcher import FieldMatcher
|
||||
from shared.ocr.paddle_ocr import OCREngine, OCRToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoLabelService:
|
||||
"""Service for automatic document labeling using field matching."""
|
||||
|
||||
def __init__(self, ocr_engine: OCREngine | None = None):
|
||||
"""
|
||||
Initialize auto-label service.
|
||||
|
||||
Args:
|
||||
ocr_engine: OCR engine instance (creates one if not provided)
|
||||
"""
|
||||
self._ocr_engine = ocr_engine
|
||||
self._field_matcher = FieldMatcher()
|
||||
|
||||
@property
|
||||
def ocr_engine(self) -> OCREngine:
|
||||
"""Lazy initialization of OCR engine."""
|
||||
if self._ocr_engine is None:
|
||||
self._ocr_engine = OCREngine(lang="en")
|
||||
return self._ocr_engine
|
||||
|
||||
def auto_label_document(
|
||||
self,
|
||||
document_id: str,
|
||||
file_path: str,
|
||||
field_values: dict[str, str],
|
||||
db: AdminDB,
|
||||
replace_existing: bool = False,
|
||||
skip_lock_check: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Auto-label a document using field matching.
|
||||
|
||||
Args:
|
||||
document_id: Document UUID
|
||||
file_path: Path to document file
|
||||
field_values: Dict of field_name -> value to match
|
||||
db: Admin database instance
|
||||
replace_existing: Whether to replace existing auto annotations
|
||||
skip_lock_check: Skip annotation lock check (for batch processing)
|
||||
|
||||
Returns:
|
||||
Dict with status and annotation count
|
||||
"""
|
||||
try:
|
||||
# Get document info first
|
||||
document = db.get_document(document_id)
|
||||
if document is None:
|
||||
raise ValueError(f"Document not found: {document_id}")
|
||||
|
||||
# Check annotation lock unless explicitly skipped
|
||||
if not skip_lock_check:
|
||||
from datetime import datetime, timezone
|
||||
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
|
||||
if document.annotation_lock_until > datetime.now(timezone.utc):
|
||||
raise ValueError(
|
||||
f"Document is locked for annotation until {document.annotation_lock_until}. "
|
||||
"Auto-labeling skipped."
|
||||
)
|
||||
|
||||
# Update status to running
|
||||
db.update_document_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="running",
|
||||
)
|
||||
|
||||
# Delete existing auto annotations if requested
|
||||
if replace_existing:
|
||||
deleted = db.delete_annotations_for_document(
|
||||
document_id=document_id,
|
||||
source="auto",
|
||||
)
|
||||
logger.info(f"Deleted {deleted} existing auto annotations")
|
||||
|
||||
# Process document
|
||||
path = Path(file_path)
|
||||
annotations_created = 0
|
||||
|
||||
if path.suffix.lower() == ".pdf":
|
||||
# Process PDF (all pages)
|
||||
annotations_created = self._process_pdf(
|
||||
document_id, path, field_values, db
|
||||
)
|
||||
else:
|
||||
# Process single image
|
||||
annotations_created = self._process_image(
|
||||
document_id, path, field_values, db, page_number=1
|
||||
)
|
||||
|
||||
# Update document status
|
||||
status = "labeled" if annotations_created > 0 else "pending"
|
||||
db.update_document_status(
|
||||
document_id=document_id,
|
||||
status=status,
|
||||
auto_label_status="completed",
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"annotations_created": annotations_created,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-labeling failed for {document_id}: {e}")
|
||||
db.update_document_status(
|
||||
document_id=document_id,
|
||||
status="pending",
|
||||
auto_label_status="failed",
|
||||
auto_label_error=str(e),
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"annotations_created": 0,
|
||||
}
|
||||
|
||||
def _process_pdf(
|
||||
self,
|
||||
document_id: str,
|
||||
pdf_path: Path,
|
||||
field_values: dict[str, str],
|
||||
db: AdminDB,
|
||||
) -> int:
|
||||
"""Process PDF document and create annotations."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
import io
|
||||
|
||||
total_annotations = 0
|
||||
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
|
||||
# Convert to numpy array
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
# Extract tokens
|
||||
tokens = self.ocr_engine.extract_from_image(
|
||||
image_array,
|
||||
page_no=page_no,
|
||||
)
|
||||
|
||||
# Find matches
|
||||
annotations = self._find_annotations(
|
||||
document_id,
|
||||
tokens,
|
||||
field_values,
|
||||
page_number=page_no + 1, # 1-indexed
|
||||
image_width=image_array.shape[1],
|
||||
image_height=image_array.shape[0],
|
||||
)
|
||||
|
||||
# Save annotations
|
||||
if annotations:
|
||||
db.create_annotations_batch(annotations)
|
||||
total_annotations += len(annotations)
|
||||
|
||||
return total_annotations
|
||||
|
||||
def _process_image(
|
||||
self,
|
||||
document_id: str,
|
||||
image_path: Path,
|
||||
field_values: dict[str, str],
|
||||
db: AdminDB,
|
||||
page_number: int = 1,
|
||||
) -> int:
|
||||
"""Process single image and create annotations."""
|
||||
# Load image
|
||||
image = Image.open(image_path)
|
||||
image_array = np.array(image)
|
||||
|
||||
# Extract tokens
|
||||
tokens = self.ocr_engine.extract_from_image(
|
||||
image_array,
|
||||
page_no=0,
|
||||
)
|
||||
|
||||
# Find matches
|
||||
annotations = self._find_annotations(
|
||||
document_id,
|
||||
tokens,
|
||||
field_values,
|
||||
page_number=page_number,
|
||||
image_width=image_array.shape[1],
|
||||
image_height=image_array.shape[0],
|
||||
)
|
||||
|
||||
# Save annotations
|
||||
if annotations:
|
||||
db.create_annotations_batch(annotations)
|
||||
|
||||
return len(annotations)
|
||||
|
||||
def _find_annotations(
|
||||
self,
|
||||
document_id: str,
|
||||
tokens: list[OCRToken],
|
||||
field_values: dict[str, str],
|
||||
page_number: int,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Find annotations for field values using token matching."""
|
||||
from shared.normalize import normalize_field
|
||||
|
||||
annotations = []
|
||||
|
||||
for field_name, value in field_values.items():
|
||||
if not value or not value.strip():
|
||||
continue
|
||||
|
||||
# Map field name to class ID
|
||||
class_id = self._get_class_id(field_name)
|
||||
if class_id is None:
|
||||
logger.warning(f"Unknown field name: {field_name}")
|
||||
continue
|
||||
|
||||
class_name = FIELD_CLASSES[class_id]
|
||||
|
||||
# Normalize value
|
||||
try:
|
||||
normalized_values = normalize_field(field_name, value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to normalize {field_name}={value}: {e}")
|
||||
normalized_values = [value]
|
||||
|
||||
# Find matches
|
||||
matches = self._field_matcher.find_matches(
|
||||
tokens=tokens,
|
||||
field_name=field_name,
|
||||
normalized_values=normalized_values,
|
||||
page_no=page_number - 1, # 0-indexed for matcher
|
||||
)
|
||||
|
||||
# Take best match
|
||||
if matches:
|
||||
best_match = matches[0]
|
||||
bbox = best_match.bbox # (x0, y0, x1, y1)
|
||||
|
||||
# Calculate normalized coordinates (YOLO format)
|
||||
x_center = (bbox[0] + bbox[2]) / 2 / image_width
|
||||
y_center = (bbox[1] + bbox[3]) / 2 / image_height
|
||||
width = (bbox[2] - bbox[0]) / image_width
|
||||
height = (bbox[3] - bbox[1]) / image_height
|
||||
|
||||
# Pixel coordinates
|
||||
bbox_x = int(bbox[0])
|
||||
bbox_y = int(bbox[1])
|
||||
bbox_width = int(bbox[2] - bbox[0])
|
||||
bbox_height = int(bbox[3] - bbox[1])
|
||||
|
||||
annotations.append({
|
||||
"document_id": document_id,
|
||||
"page_number": page_number,
|
||||
"class_id": class_id,
|
||||
"class_name": class_name,
|
||||
"x_center": x_center,
|
||||
"y_center": y_center,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"bbox_x": bbox_x,
|
||||
"bbox_y": bbox_y,
|
||||
"bbox_width": bbox_width,
|
||||
"bbox_height": bbox_height,
|
||||
"text_value": best_match.matched_value,
|
||||
"confidence": best_match.score,
|
||||
"source": "auto",
|
||||
})
|
||||
|
||||
return annotations
|
||||
|
||||
def _get_class_id(self, field_name: str) -> int | None:
|
||||
"""Map field name to class ID."""
|
||||
# Direct match
|
||||
if field_name in FIELD_CLASS_IDS:
|
||||
return FIELD_CLASS_IDS[field_name]
|
||||
|
||||
# Handle alternative names
|
||||
name_mapping = {
|
||||
"InvoiceNumber": "invoice_number",
|
||||
"InvoiceDate": "invoice_date",
|
||||
"InvoiceDueDate": "invoice_due_date",
|
||||
"OCR": "ocr_number",
|
||||
"Bankgiro": "bankgiro",
|
||||
"Plusgiro": "plusgiro",
|
||||
"Amount": "amount",
|
||||
"supplier_organisation_number": "supplier_organisation_number",
|
||||
"PaymentLine": "payment_line",
|
||||
"customer_number": "customer_number",
|
||||
}
|
||||
|
||||
mapped_name = name_mapping.get(field_name)
|
||||
if mapped_name and mapped_name in FIELD_CLASS_IDS:
|
||||
return FIELD_CLASS_IDS[mapped_name]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global service instance
|
||||
_auto_label_service: AutoLabelService | None = None
|
||||
|
||||
|
||||
def get_auto_label_service() -> AutoLabelService:
|
||||
"""Get the auto-label service instance."""
|
||||
global _auto_label_service
|
||||
if _auto_label_service is None:
|
||||
_auto_label_service = AutoLabelService()
|
||||
return _auto_label_service
|
||||
|
||||
|
||||
def reset_auto_label_service() -> None:
|
||||
"""Reset the auto-label service (for testing)."""
|
||||
global _auto_label_service
|
||||
_auto_label_service = None
|
||||
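The annotation dicts built in _find_annotations store each bbox both in pixels and in YOLO-normalized form. A self-contained restatement of that conversion, with made-up numbers, for anyone checking the math:

# Illustrative conversion from an absolute bbox (x0, y0, x1, y1) in pixels to
# the YOLO-normalized center/size values used above.
def to_yolo(bbox: tuple[float, float, float, float], image_width: int, image_height: int):
    x0, y0, x1, y1 = bbox
    x_center = (x0 + x1) / 2 / image_width
    y_center = (y0 + y1) / 2 / image_height
    width = (x1 - x0) / image_width
    height = (y1 - y0) / image_height
    return x_center, y_center, width, height

# Example: a 200x40 px box at (100, 300) on a 1654x2339 px page (A4 at 200 DPI).
print(to_yolo((100, 300, 300, 340), 1654, 2339))
# -> approximately (0.1209, 0.1368, 0.1209, 0.0171)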
548
packages/inference/inference/web/services/batch_upload.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
Batch Upload Service
|
||||
|
||||
Handles ZIP file uploads with multiple PDFs and optional CSV for auto-labeling.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.admin_models import CSV_TO_CLASS_MAPPING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Security limits
|
||||
MAX_COMPRESSED_SIZE = 100 * 1024 * 1024 # 100 MB
|
||||
MAX_UNCOMPRESSED_SIZE = 200 * 1024 * 1024 # 200 MB
|
||||
MAX_INDIVIDUAL_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||
MAX_FILES_IN_ZIP = 1000
|
||||
|
||||
|
||||
class CSVRowData(BaseModel):
|
||||
"""Validated CSV row data with security checks."""
|
||||
|
||||
document_id: str = Field(..., min_length=1, max_length=255, pattern=r'^[a-zA-Z0-9\-_\.]+$')
|
||||
invoice_number: str | None = Field(None, max_length=255)
|
||||
invoice_date: str | None = Field(None, max_length=50)
|
||||
invoice_due_date: str | None = Field(None, max_length=50)
|
||||
amount: str | None = Field(None, max_length=100)
|
||||
ocr: str | None = Field(None, max_length=100)
|
||||
bankgiro: str | None = Field(None, max_length=50)
|
||||
plusgiro: str | None = Field(None, max_length=50)
|
||||
customer_number: str | None = Field(None, max_length=255)
|
||||
supplier_organisation_number: str | None = Field(None, max_length=50)
|
||||
|
||||
@field_validator('*', mode='before')
|
||||
@classmethod
|
||||
def strip_whitespace(cls, v):
|
||||
"""Strip whitespace from all string fields."""
|
||||
if isinstance(v, str):
|
||||
return v.strip()
|
||||
return v
|
||||
|
||||
@field_validator('*', mode='before')
|
||||
@classmethod
|
||||
def reject_suspicious_patterns(cls, v):
|
||||
"""Reject values with suspicious characters."""
|
||||
if isinstance(v, str):
|
||||
# Reject SQL/shell metacharacters and newlines
|
||||
dangerous_chars = [';', '|', '&', '`', '$', '\n', '\r', '\x00']
|
||||
if any(char in v for char in dangerous_chars):
|
||||
raise ValueError("Suspicious characters detected in value")
|
||||
return v
|
||||
|
||||
|
||||
class BatchUploadService:
|
||||
"""Service for handling batch uploads of documents via ZIP files."""
|
||||
|
||||
def __init__(self, admin_db: AdminDB):
|
||||
"""Initialize the batch upload service.
|
||||
|
||||
Args:
|
||||
admin_db: Admin database interface
|
||||
"""
|
||||
self.admin_db = admin_db
|
||||
|
||||
def _safe_extract_filename(self, zip_path: str) -> str:
|
||||
"""Safely extract filename from ZIP path, preventing path traversal.
|
||||
|
||||
Args:
|
||||
zip_path: Path from ZIP file entry
|
||||
|
||||
Returns:
|
||||
Safe filename
|
||||
|
||||
Raises:
|
||||
ValueError: If path contains traversal attempts or is invalid
|
||||
"""
|
||||
# Reject absolute paths
|
||||
if zip_path.startswith('/') or zip_path.startswith('\\'):
|
||||
raise ValueError(f"Absolute path rejected: {zip_path}")
|
||||
|
||||
# Reject path traversal attempts
|
||||
if '..' in zip_path:
|
||||
raise ValueError(f"Path traversal rejected: {zip_path}")
|
||||
|
||||
# Reject Windows drive letters
|
||||
if len(zip_path) >= 2 and zip_path[1] == ':':
|
||||
raise ValueError(f"Windows path rejected: {zip_path}")
|
||||
|
||||
# Get only the basename
|
||||
safe_name = Path(zip_path).name
|
||||
if not safe_name or safe_name in ['.', '..']:
|
||||
raise ValueError(f"Invalid filename: {zip_path}")
|
||||
|
||||
# Validate filename doesn't contain suspicious characters
|
||||
if any(char in safe_name for char in ['\\', '/', '\x00', '\n', '\r']):
|
||||
raise ValueError(f"Invalid characters in filename: {safe_name}")
|
||||
|
||||
return safe_name
|
||||
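A standalone restatement of the rules _safe_extract_filename enforces, so they can be read in isolation (illustration only; the real checks live on the method above):

# Same rejection rules, expressed as a free function.
from pathlib import Path

def safe_zip_name(zip_path: str) -> str:
    if zip_path.startswith(("/", "\\")) or ".." in zip_path:
        raise ValueError(f"Unsafe path: {zip_path}")
    if len(zip_path) >= 2 and zip_path[1] == ":":
        raise ValueError(f"Windows drive path: {zip_path}")
    name = Path(zip_path).name
    if not name or any(c in name for c in ("\\", "/", "\x00", "\n", "\r")):
        raise ValueError(f"Invalid filename: {zip_path}")
    return name

print(safe_zip_name("invoices/inv-001.pdf"))  # "inv-001.pdf"
# safe_zip_name("../../etc/passwd") would raise ValueError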
|
||||
def _validate_zip_safety(self, zip_file: zipfile.ZipFile) -> None:
|
||||
"""Validate ZIP file against Zip bomb and other attacks.
|
||||
|
||||
Args:
|
||||
zip_file: Opened ZIP file
|
||||
|
||||
Raises:
|
||||
ValueError: If ZIP file is unsafe
|
||||
"""
|
||||
total_uncompressed = 0
|
||||
file_count = 0
|
||||
|
||||
for zip_info in zip_file.infolist():
|
||||
file_count += 1
|
||||
|
||||
# Check file count limit
|
||||
if file_count > MAX_FILES_IN_ZIP:
|
||||
raise ValueError(
|
||||
f"ZIP contains too many files (max {MAX_FILES_IN_ZIP})"
|
||||
)
|
||||
|
||||
# Check individual file size
|
||||
if zip_info.file_size > MAX_INDIVIDUAL_FILE_SIZE:
|
||||
max_mb = MAX_INDIVIDUAL_FILE_SIZE / (1024 * 1024)
|
||||
raise ValueError(
|
||||
f"File '{zip_info.filename}' exceeds {max_mb:.0f}MB limit"
|
||||
)
|
||||
|
||||
# Accumulate uncompressed size
|
||||
total_uncompressed += zip_info.file_size
|
||||
|
||||
# Check total uncompressed size (Zip bomb protection)
|
||||
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
|
||||
max_mb = MAX_UNCOMPRESSED_SIZE / (1024 * 1024)
|
||||
raise ValueError(
|
||||
f"Total uncompressed size exceeds {max_mb:.0f}MB limit"
|
||||
)
|
||||
|
||||
# Validate filename safety
|
||||
try:
|
||||
self._safe_extract_filename(zip_info.filename)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Rejecting malicious ZIP entry: {e}")
|
||||
raise ValueError(f"Invalid file in ZIP: {zip_info.filename}")
|
||||
|
||||
def process_zip_upload(
|
||||
self,
|
||||
admin_token: str,
|
||||
zip_filename: str,
|
||||
zip_content: bytes,
|
||||
upload_source: str = "ui",
|
||||
) -> dict[str, Any]:
|
||||
"""Process a ZIP file containing PDFs and optional CSV.
|
||||
|
||||
Args:
|
||||
admin_token: Admin authentication token
|
||||
zip_filename: Name of the ZIP file
|
||||
zip_content: ZIP file content as bytes
|
||||
upload_source: Upload source (ui or api)
|
||||
|
||||
Returns:
|
||||
Dictionary with batch upload results
|
||||
"""
|
||||
batch = self.admin_db.create_batch_upload(
|
||||
admin_token=admin_token,
|
||||
filename=zip_filename,
|
||||
file_size=len(zip_content),
|
||||
upload_source=upload_source,
|
||||
)
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
|
||||
# Validate ZIP safety first
|
||||
self._validate_zip_safety(zip_file)
|
||||
|
||||
result = self._process_zip_contents(
|
||||
batch_id=batch.batch_id,
|
||||
admin_token=admin_token,
|
||||
zip_file=zip_file,
|
||||
)
|
||||
|
||||
# Update batch upload status
|
||||
self.admin_db.update_batch_upload(
|
||||
batch_id=batch.batch_id,
|
||||
status=result["status"],
|
||||
total_files=result["total_files"],
|
||||
processed_files=result["processed_files"],
|
||||
successful_files=result["successful_files"],
|
||||
failed_files=result["failed_files"],
|
||||
csv_filename=result.get("csv_filename"),
|
||||
csv_row_count=result.get("csv_row_count"),
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
**result,
|
||||
}
|
||||
|
||||
except zipfile.BadZipFile as e:
|
||||
logger.error(f"Invalid ZIP file {zip_filename}: {e}")
|
||||
self.admin_db.update_batch_upload(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="Invalid ZIP file format",
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"status": "failed",
|
||||
"error": "Invalid ZIP file format",
|
||||
}
|
||||
except ValueError as e:
|
||||
# Security validation errors
|
||||
logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
|
||||
self.admin_db.update_batch_upload(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="ZIP file validation failed",
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"status": "failed",
|
||||
"error": "ZIP file validation failed",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
|
||||
self.admin_db.update_batch_upload(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="Processing error",
|
||||
completed_at=datetime.utcnow(),
|
||||
)
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"status": "failed",
|
||||
"error": "Failed to process batch upload",
|
||||
}
|
||||
|
||||
def _process_zip_contents(
|
||||
self,
|
||||
batch_id: UUID,
|
||||
admin_token: str,
|
||||
zip_file: zipfile.ZipFile,
|
||||
) -> dict[str, Any]:
|
||||
"""Process contents of ZIP file.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
admin_token: Admin authentication token
|
||||
zip_file: Opened ZIP file
|
||||
|
||||
Returns:
|
||||
Processing results dictionary
|
||||
"""
|
||||
# Extract file lists
|
||||
pdf_files = []
|
||||
csv_file = None
|
||||
csv_data = {}
|
||||
|
||||
for file_info in zip_file.filelist:
|
||||
if file_info.is_dir():
|
||||
continue
|
||||
|
||||
try:
|
||||
# Use safe filename extraction
|
||||
filename = self._safe_extract_filename(file_info.filename)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Skipping invalid file: {e}")
|
||||
continue
|
||||
|
||||
if filename.lower().endswith('.pdf'):
|
||||
pdf_files.append(file_info)
|
||||
elif filename.lower().endswith('.csv'):
|
||||
if csv_file is None:
|
||||
csv_file = file_info
|
||||
# Parse CSV
|
||||
csv_data = self._parse_csv_file(zip_file, file_info)
|
||||
else:
|
||||
logger.warning(f"Multiple CSV files found, using first: {csv_file.filename}")
|
||||
|
||||
if not pdf_files:
|
||||
return {
|
||||
"status": "failed",
|
||||
"total_files": 0,
|
||||
"processed_files": 0,
|
||||
"successful_files": 0,
|
||||
"failed_files": 0,
|
||||
"error": "No PDF files found in ZIP",
|
||||
}
|
||||
|
||||
# Process each PDF file
|
||||
total_files = len(pdf_files)
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
for pdf_info in pdf_files:
|
||||
file_record = None
|
||||
|
||||
try:
|
||||
# Use safe filename extraction
|
||||
filename = self._safe_extract_filename(pdf_info.filename)
|
||||
|
||||
# Create batch upload file record
|
||||
file_record = self.admin_db.create_batch_upload_file(
|
||||
batch_id=batch_id,
|
||||
filename=filename,
|
||||
status="processing",
|
||||
)
|
||||
|
||||
# Get CSV data for this file if available
|
||||
document_id_base = Path(filename).stem
|
||||
csv_row_data = csv_data.get(document_id_base)
|
||||
|
||||
# Extract PDF content
|
||||
pdf_content = zip_file.read(pdf_info.filename)
|
||||
|
||||
# TODO: Save PDF file and create document
|
||||
# For now, just mark as completed
|
||||
|
||||
self.admin_db.update_batch_upload_file(
|
||||
file_id=file_record.file_id,
|
||||
status="completed",
|
||||
csv_row_data=csv_row_data,
|
||||
processed_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
successful_files += 1
|
||||
|
||||
except ValueError as e:
|
||||
# Path validation error
|
||||
logger.warning(f"Skipping invalid file: {e}")
|
||||
if file_record:
|
||||
self.admin_db.update_batch_upload_file(
|
||||
file_id=file_record.file_id,
|
||||
status="failed",
|
||||
error_message="Invalid filename",
|
||||
processed_at=datetime.utcnow(),
|
||||
)
|
||||
failed_files += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing PDF: {e}", exc_info=True)
|
||||
if file_record:
|
||||
self.admin_db.update_batch_upload_file(
|
||||
file_id=file_record.file_id,
|
||||
status="failed",
|
||||
error_message="Processing error",
|
||||
processed_at=datetime.utcnow(),
|
||||
)
|
||||
failed_files += 1
|
||||
|
||||
# Determine overall status
|
||||
if failed_files == 0:
|
||||
status = "completed"
|
||||
elif successful_files == 0:
|
||||
status = "failed"
|
||||
else:
|
||||
status = "partial"
|
||||
|
||||
result = {
|
||||
"status": status,
|
||||
"total_files": total_files,
|
||||
"processed_files": total_files,
|
||||
"successful_files": successful_files,
|
||||
"failed_files": failed_files,
|
||||
}
|
||||
|
||||
if csv_file:
|
||||
result["csv_filename"] = Path(csv_file.filename).name
|
||||
result["csv_row_count"] = len(csv_data)
|
||||
|
||||
return result
|
||||
|
||||
def _parse_csv_file(
|
||||
self,
|
||||
zip_file: zipfile.ZipFile,
|
||||
csv_file_info: zipfile.ZipInfo,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Parse CSV file and extract field values with validation.
|
||||
|
||||
Args:
|
||||
zip_file: Opened ZIP file
|
||||
csv_file_info: CSV file info
|
||||
|
||||
Returns:
|
||||
Dictionary mapping DocumentId to validated field values
|
||||
"""
|
||||
# Try multiple encodings
|
||||
csv_bytes = zip_file.read(csv_file_info.filename)
|
||||
encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'cp1252']
|
||||
csv_content = None
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
csv_content = csv_bytes.decode(encoding)
|
||||
logger.info(f"CSV decoded with {encoding}")
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
if csv_content is None:
|
||||
logger.error("Failed to decode CSV with any encoding")
|
||||
raise ValueError("Unable to decode CSV file")
|
||||
|
||||
csv_reader = csv.DictReader(io.StringIO(csv_content))
|
||||
result = {}
|
||||
|
||||
# Case-insensitive column mapping
|
||||
field_name_map = {
|
||||
'DocumentId': ['DocumentId', 'documentid', 'document_id'],
|
||||
'InvoiceNumber': ['InvoiceNumber', 'invoicenumber', 'invoice_number'],
|
||||
'InvoiceDate': ['InvoiceDate', 'invoicedate', 'invoice_date'],
|
||||
'InvoiceDueDate': ['InvoiceDueDate', 'invoiceduedate', 'invoice_due_date'],
|
||||
'Amount': ['Amount', 'amount'],
|
||||
'OCR': ['OCR', 'ocr'],
|
||||
'Bankgiro': ['Bankgiro', 'bankgiro'],
|
||||
'Plusgiro': ['Plusgiro', 'plusgiro'],
|
||||
'customer_number': ['customer_number', 'customernumber', 'CustomerNumber'],
|
||||
'supplier_organisation_number': ['supplier_organisation_number', 'supplierorganisationnumber'],
|
||||
}
|
||||
|
||||
for row_num, row in enumerate(csv_reader, start=2):
|
||||
try:
|
||||
# Create case-insensitive lookup
|
||||
row_lower = {k.lower(): v for k, v in row.items()}
|
||||
|
||||
# Find DocumentId with case-insensitive matching
|
||||
document_id = None
|
||||
for variant in field_name_map['DocumentId']:
|
||||
if variant.lower() in row_lower:
|
||||
document_id = row_lower[variant.lower()]
|
||||
break
|
||||
|
||||
if not document_id:
|
||||
logger.warning(f"Row {row_num}: No DocumentId found")
|
||||
continue
|
||||
|
||||
# Validate using Pydantic model
|
||||
csv_row_dict = {'document_id': document_id}
|
||||
|
||||
# Map CSV field names to model attribute names
|
||||
csv_to_model_attr = {
|
||||
'InvoiceNumber': 'invoice_number',
|
||||
'InvoiceDate': 'invoice_date',
|
||||
'InvoiceDueDate': 'invoice_due_date',
|
||||
'Amount': 'amount',
|
||||
'OCR': 'ocr',
|
||||
'Bankgiro': 'bankgiro',
|
||||
'Plusgiro': 'plusgiro',
|
||||
'customer_number': 'customer_number',
|
||||
'supplier_organisation_number': 'supplier_organisation_number',
|
||||
}
|
||||
|
||||
for csv_field in field_name_map.keys():
|
||||
if csv_field == 'DocumentId':
|
||||
continue
|
||||
|
||||
model_attr = csv_to_model_attr.get(csv_field)
|
||||
if not model_attr:
|
||||
continue
|
||||
|
||||
for variant in field_name_map[csv_field]:
|
||||
if variant.lower() in row_lower and row_lower[variant.lower()]:
|
||||
csv_row_dict[model_attr] = row_lower[variant.lower()]
|
||||
break
|
||||
|
||||
# Validate
|
||||
validated_row = CSVRowData(**csv_row_dict)
|
||||
|
||||
# Extract only the fields we care about (map back to CSV field names)
|
||||
field_values = {}
|
||||
model_attr_to_csv = {
|
||||
'invoice_number': 'InvoiceNumber',
|
||||
'invoice_date': 'InvoiceDate',
|
||||
'invoice_due_date': 'InvoiceDueDate',
|
||||
'amount': 'Amount',
|
||||
'ocr': 'OCR',
|
||||
'bankgiro': 'Bankgiro',
|
||||
'plusgiro': 'Plusgiro',
|
||||
'customer_number': 'customer_number',
|
||||
'supplier_organisation_number': 'supplier_organisation_number',
|
||||
}
|
||||
|
||||
for model_attr, csv_field in model_attr_to_csv.items():
|
||||
value = getattr(validated_row, model_attr, None)
|
||||
if value and csv_field in CSV_TO_CLASS_MAPPING:
|
||||
field_values[csv_field] = value
|
||||
|
||||
if field_values:
|
||||
result[document_id] = field_values
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Row {row_num}: Validation error - {e}")
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> dict[str, Any]:
|
||||
"""Get batch upload status.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
|
||||
Returns:
|
||||
Batch status dictionary
|
||||
"""
|
||||
batch = self.admin_db.get_batch_upload(UUID(batch_id))
|
||||
if not batch:
|
||||
return {
|
||||
"error": "Batch upload not found",
|
||||
}
|
||||
|
||||
files = self.admin_db.get_batch_upload_files(batch.batch_id)
|
||||
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
"filename": batch.filename,
|
||||
"status": batch.status,
|
||||
"total_files": batch.total_files,
|
||||
"processed_files": batch.processed_files,
|
||||
"successful_files": batch.successful_files,
|
||||
"failed_files": batch.failed_files,
|
||||
"csv_filename": batch.csv_filename,
|
||||
"csv_row_count": batch.csv_row_count,
|
||||
"error_message": batch.error_message,
|
||||
"created_at": batch.created_at.isoformat() if batch.created_at else None,
|
||||
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
|
||||
"files": [
|
||||
{
|
||||
"filename": f.filename,
|
||||
"status": f.status,
|
||||
"error_message": f.error_message,
|
||||
"annotation_count": f.annotation_count,
|
||||
}
|
||||
for f in files
|
||||
],
|
||||
}
|
||||
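A short sketch of the CSV row validation described above, using the CSVRowData model from this file; the sample values are invented.

from pydantic import ValidationError

ok = CSVRowData(document_id="inv-2024-001", invoice_number=" 12345 ", amount="199.00")
print(ok.invoice_number)  # "12345" -- leading/trailing whitespace stripped

try:
    CSVRowData(document_id="inv-2024-001", amount="199.00; rm -rf /")
except ValidationError as e:
    print("rejected:", e.errors()[0]["msg"])  # suspicious characters detected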
188
packages/inference/inference/web/services/dataset_builder.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Dataset Builder Service
|
||||
|
||||
Creates training datasets by copying images from admin storage,
|
||||
generating YOLO label files, and splitting into train/val/test sets.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from inference.data.admin_models import FIELD_CLASSES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetBuilder:
|
||||
"""Builds YOLO training datasets from admin documents."""
|
||||
|
||||
def __init__(self, db, base_dir: Path):
|
||||
self._db = db
|
||||
self._base_dir = Path(base_dir)
|
||||
|
||||
def build_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_ids: list[str],
|
||||
train_ratio: float,
|
||||
val_ratio: float,
|
||||
seed: int,
|
||||
admin_images_dir: Path,
|
||||
) -> dict:
|
||||
"""Build a complete YOLO dataset from document IDs.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset record.
|
||||
document_ids: List of document UUIDs to include.
|
||||
train_ratio: Fraction for training set.
|
||||
val_ratio: Fraction for validation set.
|
||||
seed: Random seed for reproducible splits.
|
||||
admin_images_dir: Root directory of admin images.
|
||||
|
||||
Returns:
|
||||
Summary dict with total_documents, total_images, total_annotations.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid documents found.
|
||||
"""
|
||||
try:
|
||||
return self._do_build(
|
||||
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
|
||||
)
|
||||
except Exception as e:
|
||||
self._db.update_dataset_status(
|
||||
dataset_id=dataset_id,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
def _do_build(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_ids: list[str],
|
||||
train_ratio: float,
|
||||
val_ratio: float,
|
||||
seed: int,
|
||||
admin_images_dir: Path,
|
||||
) -> dict:
|
||||
# 1. Fetch documents
|
||||
documents = self._db.get_documents_by_ids(document_ids)
|
||||
if not documents:
|
||||
raise ValueError("No valid documents found for the given IDs")
|
||||
|
||||
# 2. Create directory structure
|
||||
dataset_dir = self._base_dir / dataset_id
|
||||
for split in ["train", "val", "test"]:
|
||||
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
|
||||
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. Shuffle and split documents
|
||||
doc_list = list(documents)
|
||||
rng = random.Random(seed)
|
||||
rng.shuffle(doc_list)
|
||||
|
||||
n = len(doc_list)
|
||||
n_train = max(1, round(n * train_ratio))
|
||||
n_val = max(0, round(n * val_ratio))
|
||||
n_test = n - n_train - n_val
|
||||
|
||||
splits = (
|
||||
["train"] * n_train
|
||||
+ ["val"] * n_val
|
||||
+ ["test"] * n_test
|
||||
)
|
||||
|
||||
# 4. Process each document
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
dataset_docs = []
|
||||
|
||||
for doc, split in zip(doc_list, splits):
|
||||
doc_id = str(doc.document_id)
|
||||
annotations = self._db.get_annotations_for_document(doc.document_id)
|
||||
|
||||
# Group annotations by page
|
||||
page_annotations: dict[int, list] = {}
|
||||
for ann in annotations:
|
||||
page_annotations.setdefault(ann.page_number, []).append(ann)
|
||||
|
||||
doc_image_count = 0
|
||||
doc_ann_count = 0
|
||||
|
||||
# Copy images and write labels for each page
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
|
||||
if not src_image.exists():
|
||||
logger.warning("Image not found: %s", src_image)
|
||||
continue
|
||||
|
||||
dst_name = f"{doc_id}_page{page_num}"
|
||||
dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
|
||||
shutil.copy2(src_image, dst_image)
|
||||
doc_image_count += 1
|
||||
|
||||
# Write YOLO label file
|
||||
page_anns = page_annotations.get(page_num, [])
|
||||
label_lines = []
|
||||
for ann in page_anns:
|
||||
label_lines.append(
|
||||
f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
|
||||
f"{ann.width:.6f} {ann.height:.6f}"
|
||||
)
|
||||
doc_ann_count += 1
|
||||
|
||||
label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
|
||||
label_path.write_text("\n".join(label_lines))
|
||||
|
||||
total_images += doc_image_count
|
||||
total_annotations += doc_ann_count
|
||||
|
||||
dataset_docs.append({
|
||||
"document_id": doc_id,
|
||||
"split": split,
|
||||
"page_count": doc_image_count,
|
||||
"annotation_count": doc_ann_count,
|
||||
})
|
||||
|
||||
# 5. Record document-split assignments in DB
|
||||
self._db.add_dataset_documents(
|
||||
dataset_id=dataset_id,
|
||||
documents=dataset_docs,
|
||||
)
|
||||
|
||||
# 6. Generate data.yaml
|
||||
self._generate_data_yaml(dataset_dir)
|
||||
|
||||
# 7. Update dataset status
|
||||
self._db.update_dataset_status(
|
||||
dataset_id=dataset_id,
|
||||
status="ready",
|
||||
total_documents=len(doc_list),
|
||||
total_images=total_images,
|
||||
total_annotations=total_annotations,
|
||||
dataset_path=str(dataset_dir),
|
||||
)
|
||||
|
||||
return {
|
||||
"total_documents": len(doc_list),
|
||||
"total_images": total_images,
|
||||
"total_annotations": total_annotations,
|
||||
}
|
||||
|
||||
def _generate_data_yaml(self, dataset_dir: Path) -> None:
|
||||
"""Generate YOLO data.yaml configuration file."""
|
||||
data = {
|
||||
"path": str(dataset_dir.absolute()),
|
||||
"train": "images/train",
|
||||
"val": "images/val",
|
||||
"test": "images/test",
|
||||
"nc": len(FIELD_CLASSES),
|
||||
"names": FIELD_CLASSES,
|
||||
}
|
||||
yaml_path = dataset_dir / "data.yaml"
|
||||
yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
|
||||
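For reference, the split arithmetic used in _do_build with concrete numbers (10 documents, 80/10 ratios; the remainder goes to the test split):

n = 10
train_ratio, val_ratio = 0.8, 0.1
n_train = max(1, round(n * train_ratio))  # 8
n_val = max(0, round(n * val_ratio))      # 1
n_test = n - n_train - n_val              # 1
print(n_train, n_val, n_test)             # 8 1 1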
531
packages/inference/inference/web/services/db_autolabel.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
Database-based Auto-labeling Service
|
||||
|
||||
Processes documents with field values stored in the database (csv_field_values).
|
||||
Used by the pre-label API to create annotations from expected values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
|
||||
from shared.data.db import DocumentDB
|
||||
from inference.web.config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize DocumentDB for saving match reports
|
||||
_document_db: DocumentDB | None = None
|
||||
|
||||
|
||||
def get_document_db() -> DocumentDB:
|
||||
"""Get or create DocumentDB instance with connection and tables initialized.
|
||||
|
||||
Follows the same pattern as CLI autolabel (src/cli/autolabel.py lines 370-373).
|
||||
"""
|
||||
global _document_db
|
||||
if _document_db is None:
|
||||
_document_db = DocumentDB()
|
||||
_document_db.connect()
|
||||
_document_db.create_tables() # Ensure tables exist
|
||||
logger.info("Connected to PostgreSQL DocumentDB for match reports")
|
||||
return _document_db
|
||||
|
||||
|
||||
def convert_csv_field_values_to_row_dict(
|
||||
document: AdminDocument,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert AdminDocument.csv_field_values to row_dict format for autolabel.
|
||||
|
||||
Args:
|
||||
document: AdminDocument with csv_field_values
|
||||
|
||||
Returns:
|
||||
Dictionary in row_dict format compatible with autolabel_tasks
|
||||
"""
|
||||
csv_values = document.csv_field_values or {}
|
||||
|
||||
# Build row_dict with DocumentId
|
||||
row_dict = {
|
||||
"DocumentId": str(document.document_id),
|
||||
}
|
||||
|
||||
# Map csv_field_values to row_dict format
|
||||
# csv_field_values uses keys like: InvoiceNumber, InvoiceDate, Amount, OCR, Bankgiro, etc.
|
||||
# row_dict uses same keys
|
||||
for key, value in csv_values.items():
|
||||
if value is not None and value != "":
|
||||
row_dict[key] = str(value)
|
||||
|
||||
return row_dict
|
||||
|
||||
|
||||
def get_pending_autolabel_documents(
|
||||
db: AdminDB,
|
||||
limit: int = 10,
|
||||
) -> list[AdminDocument]:
|
||||
"""
|
||||
Get documents pending auto-labeling.
|
||||
|
||||
Args:
|
||||
db: AdminDB instance
|
||||
limit: Maximum number of documents to return
|
||||
|
||||
Returns:
|
||||
List of AdminDocument records with status='auto_labeling' and auto_label_status='pending'
|
||||
"""
|
||||
from sqlmodel import select
|
||||
from inference.data.database import get_session_context
|
||||
from inference.data.admin_models import AdminDocument
|
||||
|
||||
with get_session_context() as session:
|
||||
statement = select(AdminDocument).where(
|
||||
AdminDocument.status == "auto_labeling",
|
||||
AdminDocument.auto_label_status == "pending",
|
||||
).order_by(AdminDocument.created_at).limit(limit)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
|
||||
def process_document_autolabel(
|
||||
document: AdminDocument,
|
||||
db: AdminDB,
|
||||
output_dir: Path | None = None,
|
||||
dpi: int = DEFAULT_DPI,
|
||||
min_confidence: float = 0.5,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Process a single document for auto-labeling using csv_field_values.
|
||||
|
||||
Args:
|
||||
document: AdminDocument with csv_field_values and file_path
|
||||
db: AdminDB instance for updating status
|
||||
output_dir: Output directory for temp files
|
||||
dpi: Rendering DPI
|
||||
min_confidence: Minimum match confidence
|
||||
|
||||
    Returns:
        Result dictionary with success status and annotations
    """
    from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
    from shared.pdf import PDFDocument

    document_id = str(document.document_id)
    file_path = Path(document.file_path)

    if output_dir is None:
        output_dir = Path("data/autolabel_output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Mark as processing
    db.update_document_status(
        document_id=document_id,
        status="auto_labeling",
        auto_label_status="running",
    )

    try:
        # Check if file exists
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        # Convert csv_field_values to row_dict
        row_dict = convert_csv_field_values_to_row_dict(document)

        if len(row_dict) <= 1:  # Only has DocumentId
            raise ValueError("No field values to match")

        # Determine PDF type (text or scanned)
        is_scanned = False
        with PDFDocument(file_path) as pdf_doc:
            # Check if first page has extractable text
            tokens = list(pdf_doc.extract_text_tokens(0))
            is_scanned = len(tokens) < 10  # Threshold for "no text"

        # Build task data
        # Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
        # This ensures consistency with CLI autolabel for reprocess_failed.py
        storage_config = StorageConfig()
        pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"

        task_data = {
            "row_dict": row_dict,
            "pdf_path": str(pdf_path_for_report),
            "output_dir": str(output_dir),
            "dpi": dpi,
            "min_confidence": min_confidence,
        }

        # Process based on PDF type
        if is_scanned:
            result = process_scanned_pdf(task_data)
        else:
            result = process_text_pdf(task_data)

        # Save report to DocumentDB (same as CLI autolabel)
        if result.get("report"):
            try:
                doc_db = get_document_db()
                doc_db.save_document(result["report"])
                logger.info(f"Saved match report to DocumentDB for {document_id}")
            except Exception as e:
                logger.warning(f"Failed to save report to DocumentDB: {e}")

        # Save annotations to AdminDB
        if result.get("success") and result.get("report"):
            _save_annotations_to_db(
                db=db,
                document_id=document_id,
                report=result["report"],
                page_annotations=result.get("pages", []),
                dpi=dpi,
            )

            # Mark as completed
            db.update_document_status(
                document_id=document_id,
                status="labeled",
                auto_label_status="completed",
            )
        else:
            # Mark as failed
            errors = result.get("report", {}).get("errors", ["Unknown error"])
            db.update_document_status(
                document_id=document_id,
                status="pending",
                auto_label_status="failed",
                auto_label_error="; ".join(errors) if errors else "No annotations generated",
            )

        return result

    except Exception as e:
        logger.error(f"Error processing document {document_id}: {e}", exc_info=True)

        # Mark as failed
        db.update_document_status(
            document_id=document_id,
            status="pending",
            auto_label_status="failed",
            auto_label_error=str(e),
        )

        return {
            "doc_id": document_id,
            "success": False,
            "error": str(e),
        }


def _save_annotations_to_db(
    db: AdminDB,
    document_id: str,
    report: dict[str, Any],
    page_annotations: list[dict[str, Any]],
    dpi: int = 200,
) -> int:
    """
    Save generated annotations to database.

    Args:
        db: AdminDB instance
        document_id: Document ID
        report: AutoLabelReport as dict
        page_annotations: List of page annotation data
        dpi: DPI used for rendering images (for coordinate conversion)

    Returns:
        Number of annotations saved
    """
    from PIL import Image
    from inference.data.admin_models import FIELD_CLASS_IDS

    # Mapping from CSV field names to internal field names
    CSV_TO_INTERNAL_FIELD: dict[str, str] = {
        "InvoiceNumber": "invoice_number",
        "InvoiceDate": "invoice_date",
        "InvoiceDueDate": "invoice_due_date",
        "OCR": "ocr_number",
        "Bankgiro": "bankgiro",
        "Plusgiro": "plusgiro",
        "Amount": "amount",
        "supplier_organisation_number": "supplier_organisation_number",
        "customer_number": "customer_number",
        "payment_line": "payment_line",
    }

    # Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
    scale = dpi / 72.0

    # Cache for image dimensions per page
    image_dimensions: dict[int, tuple[int, int]] = {}

    def get_image_dimensions(page_no: int) -> tuple[int, int] | None:
        """Get image dimensions for a page (1-indexed)."""
        if page_no in image_dimensions:
            return image_dimensions[page_no]

        # Try to load from admin_images
        admin_images_dir = Path("data/admin_images") / document_id
        image_path = admin_images_dir / f"page_{page_no}.png"

        if image_path.exists():
            try:
                with Image.open(image_path) as img:
                    dims = img.size  # (width, height)
                    image_dimensions[page_no] = dims
                    return dims
            except Exception as e:
                logger.warning(f"Failed to read image dimensions from {image_path}: {e}")

        return None

    annotation_count = 0

    # Get field results from report (list of dicts)
    field_results = report.get("field_results", [])

    for field_info in field_results:
        if not field_info.get("matched"):
            continue

        csv_field_name = field_info.get("field_name", "")

        # Map CSV field name to internal field name
        field_name = CSV_TO_INTERNAL_FIELD.get(csv_field_name, csv_field_name)

        # Get class_id from field name
        class_id = FIELD_CLASS_IDS.get(field_name)
        if class_id is None:
            logger.warning(f"Unknown field name: {csv_field_name} -> {field_name}")
            continue

        # Get bbox info (list: [x, y, x2, y2] in PDF points - 72 DPI)
        bbox = field_info.get("bbox", [])
        if not bbox or len(bbox) < 4:
            continue

        # Convert PDF points (72 DPI) to pixel coordinates (at configured DPI)
        pdf_x1, pdf_y1, pdf_x2, pdf_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
        x1 = pdf_x1 * scale
        y1 = pdf_y1 * scale
        x2 = pdf_x2 * scale
        y2 = pdf_y2 * scale

        bbox_width = x2 - x1
        bbox_height = y2 - y1

        # Get page number (convert to 1-indexed)
        page_no = field_info.get("page_no", 0) + 1

        # Get image dimensions for normalization
        dims = get_image_dimensions(page_no)
        if dims:
            img_width, img_height = dims
            # Calculate normalized coordinates
            x_center = (x1 + x2) / 2 / img_width
            y_center = (y1 + y2) / 2 / img_height
            width = bbox_width / img_width
            height = bbox_height / img_height
        else:
            # Fallback: use pixel coordinates as-is for normalization
            # (will be slightly off but better than /1000)
            logger.warning(f"Could not get image dimensions for page {page_no}, using estimates")
            # Estimate A4 at configured DPI: 595 x 842 points * scale
            estimated_width = 595 * scale
            estimated_height = 842 * scale
            x_center = (x1 + x2) / 2 / estimated_width
            y_center = (y1 + y2) / 2 / estimated_height
            width = bbox_width / estimated_width
            height = bbox_height / estimated_height

        # Create annotation
        try:
            db.create_annotation(
                document_id=document_id,
                page_number=page_no,
                class_id=class_id,
                class_name=field_name,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
                bbox_x=int(x1),
                bbox_y=int(y1),
                bbox_width=int(bbox_width),
                bbox_height=int(bbox_height),
                text_value=field_info.get("matched_text"),
                confidence=field_info.get("score"),
                source="auto",
            )
            annotation_count += 1
            logger.info(f"Saved annotation for {field_name}: bbox=({int(x1)}, {int(y1)}, {int(bbox_width)}, {int(bbox_height)})")
        except Exception as e:
            logger.warning(f"Failed to save annotation for {field_name}: {e}")

    return annotation_count


def run_pending_autolabel_batch(
    db: AdminDB | None = None,
    batch_size: int = 10,
    output_dir: Path | None = None,
) -> dict[str, Any]:
    """
    Process a batch of pending auto-label documents.

    Args:
        db: AdminDB instance (created if None)
        batch_size: Number of documents to process
        output_dir: Output directory for temp files

    Returns:
        Summary of processing results
    """
    if db is None:
        db = AdminDB()

    documents = get_pending_autolabel_documents(db, limit=batch_size)

    results = {
        "total": len(documents),
        "successful": 0,
        "failed": 0,
        "documents": [],
    }

    for doc in documents:
        result = process_document_autolabel(
            document=doc,
            db=db,
            output_dir=output_dir,
        )

        doc_result = {
            "document_id": str(doc.document_id),
            "success": result.get("success", False),
        }

        if result.get("success"):
            results["successful"] += 1
        else:
            results["failed"] += 1
            doc_result["error"] = result.get("error") or "Unknown error"

        results["documents"].append(doc_result)

    return results


def save_manual_annotations_to_document_db(
    document: AdminDocument,
    annotations: list,
    db: AdminDB,
) -> dict[str, Any]:
    """
    Save manual annotations to PostgreSQL documents and field_results tables.

    Called when a user marks a document as 'labeled' from the web UI.
    This ensures manually labeled documents are also tracked in the same
    database as auto-labeled documents for consistency.

    Args:
        document: AdminDocument instance
        annotations: List of AdminAnnotation instances
        db: AdminDB instance

    Returns:
        Dict with success status and details
    """
    from datetime import datetime

    document_id = str(document.document_id)
    storage_config = StorageConfig()

    # Build pdf_path using admin_upload_dir (same as auto-label)
    pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"

    # Build report dict compatible with DocumentDB.save_document()
    field_results = []
    fields_total = len(annotations)
    fields_matched = 0

    for ann in annotations:
        # All manual annotations are considered "matched" since the user verified them
        field_result = {
            "field_name": ann.class_name,
            "csv_value": ann.text_value or "",  # Manual annotations may not have a CSV value
            "matched": True,
            "score": ann.confidence or 1.0,  # Manual = high confidence
            "matched_text": ann.text_value,
            "candidate_used": "manual",
            "bbox": [ann.bbox_x, ann.bbox_y, ann.bbox_x + ann.bbox_width, ann.bbox_y + ann.bbox_height],
            "page_no": ann.page_number - 1,  # Convert to 0-indexed
            "context_keywords": [],
            "error": None,
        }
        field_results.append(field_result)
        fields_matched += 1

    # Determine PDF type
    pdf_type = "unknown"
    if pdf_path.exists():
        try:
            from shared.pdf import PDFDocument
            with PDFDocument(pdf_path) as pdf_doc:
                tokens = list(pdf_doc.extract_text_tokens(0))
                pdf_type = "scanned" if len(tokens) < 10 else "text"
        except Exception as e:
            logger.warning(f"Could not determine PDF type: {e}")

    # Build report
    report = {
        "document_id": document_id,
        "pdf_path": str(pdf_path),
        "pdf_type": pdf_type,
        "success": fields_matched > 0,
        "total_pages": document.page_count,
        "fields_matched": fields_matched,
        "fields_total": fields_total,
        "annotations_generated": fields_matched,
        "processing_time_ms": 0,  # Manual labeling - no processing time
        "timestamp": datetime.utcnow().isoformat(),
        "errors": [],
        "field_results": field_results,
        # Extended fields (from CSV if available)
        "split": None,
        "customer_number": document.csv_field_values.get("customer_number") if document.csv_field_values else None,
        "supplier_name": document.csv_field_values.get("supplier_name") if document.csv_field_values else None,
        "supplier_organisation_number": document.csv_field_values.get("supplier_organisation_number") if document.csv_field_values else None,
        "supplier_accounts": document.csv_field_values.get("supplier_accounts") if document.csv_field_values else None,
    }

    # Save to PostgreSQL DocumentDB
    try:
        doc_db = get_document_db()
        doc_db.save_document(report)
        logger.info(f"Saved manual annotations to DocumentDB for {document_id}: {fields_matched} fields")

        return {
            "success": True,
            "document_id": document_id,
            "fields_saved": fields_matched,
            "message": f"Saved {fields_matched} annotations to DocumentDB",
        }

    except Exception as e:
        logger.error(f"Failed to save manual annotations to DocumentDB: {e}", exc_info=True)
        return {
            "success": False,
            "document_id": document_id,
            "error": str(e),
        }
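The bbox conversion in _save_annotations_to_db can be sanity-checked in isolation. A minimal sketch; the helper name and the numbers below are illustrative only and are not part of this commit:

# Sketch only: mirrors the point->pixel->normalized conversion above with made-up numbers.
# PDF points are 1/72 inch; pages are rendered at `dpi`, so pixel = point * dpi / 72.
def to_yolo_bbox(bbox_pts: list[float], dpi: int, img_width: int, img_height: int) -> tuple[float, float, float, float]:
    """Convert [x1, y1, x2, y2] in PDF points (72 DPI) to normalized YOLO (xc, yc, w, h)."""
    scale = dpi / 72.0
    x1, y1, x2, y2 = (v * scale for v in bbox_pts)
    return (
        (x1 + x2) / 2 / img_width,   # x_center
        (y1 + y2) / 2 / img_height,  # y_center
        (x2 - x1) / img_width,       # width
        (y2 - y1) / img_height,      # height
    )

# Example: a 100 x 20 pt box at (72, 144) pt on an A4 page rendered at 200 DPI (~1653 x 2339 px)
print(to_yolo_bbox([72, 144, 172, 164], 200, 1653, 2339))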
285
packages/inference/inference/web/services/inference.py
Normal file
285
packages/inference/inference/web/services/inference.py
Normal file
@@ -0,0 +1,285 @@
"""
Inference Service

Business logic for invoice field extraction.
"""

from __future__ import annotations

import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from PIL import Image

if TYPE_CHECKING:
    from .config import ModelConfig, StorageConfig

logger = logging.getLogger(__name__)


@dataclass
class ServiceResult:
    """Result from inference service."""

    document_id: str
    success: bool = False
    document_type: str = "invoice"  # "invoice" or "letter"
    fields: dict[str, str | None] = field(default_factory=dict)
    confidence: dict[str, float] = field(default_factory=dict)
    detections: list[dict] = field(default_factory=list)
    processing_time_ms: float = 0.0
    visualization_path: Path | None = None
    errors: list[str] = field(default_factory=list)


class InferenceService:
    """
    Service for running invoice field extraction.

    Encapsulates YOLO detection and OCR extraction logic.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        storage_config: StorageConfig,
    ) -> None:
        """
        Initialize inference service.

        Args:
            model_config: Model configuration
            storage_config: Storage configuration
        """
        self.model_config = model_config
        self.storage_config = storage_config
        self._pipeline = None
        self._detector = None
        self._is_initialized = False

    def initialize(self) -> None:
        """Initialize the inference pipeline (lazy loading)."""
        if self._is_initialized:
            return

        logger.info("Initializing inference service...")
        start_time = time.time()

        try:
            from inference.pipeline.pipeline import InferencePipeline
            from inference.pipeline.yolo_detector import YOLODetector

            # Initialize YOLO detector for visualization
            self._detector = YOLODetector(
                str(self.model_config.model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                device="cuda" if self.model_config.use_gpu else "cpu",
            )

            # Initialize full pipeline
            self._pipeline = InferencePipeline(
                model_path=str(self.model_config.model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                use_gpu=self.model_config.use_gpu,
                dpi=self.model_config.dpi,
                enable_fallback=True,
            )

            self._is_initialized = True
            elapsed = time.time() - start_time
            logger.info(f"Inference service initialized in {elapsed:.2f}s")

        except Exception as e:
            logger.error(f"Failed to initialize inference service: {e}")
            raise

    @property
    def is_initialized(self) -> bool:
        """Check if service is initialized."""
        return self._is_initialized

    @property
    def gpu_available(self) -> bool:
        """Check if GPU is available."""
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False

    def process_image(
        self,
        image_path: Path,
        document_id: str | None = None,
        save_visualization: bool = True,
    ) -> ServiceResult:
        """
        Process an image file and extract invoice fields.

        Args:
            image_path: Path to image file
            document_id: Optional document ID
            save_visualization: Whether to save visualization

        Returns:
            ServiceResult with extracted fields
        """
        if not self._is_initialized:
            self.initialize()

        doc_id = document_id or str(uuid.uuid4())[:8]
        start_time = time.time()

        result = ServiceResult(document_id=doc_id)

        try:
            # Run inference pipeline
            pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)

            result.fields = pipeline_result.fields
            result.confidence = pipeline_result.confidence
            result.success = pipeline_result.success
            result.errors = pipeline_result.errors

            # Determine document type based on payment_line presence
            # If no payment_line found, it's likely a letter, not an invoice
            if not result.fields.get('payment_line'):
                result.document_type = "letter"
            else:
                result.document_type = "invoice"

            # Get raw detections for visualization
            result.detections = [
                {
                    "field": d.class_name,
                    "confidence": d.confidence,
                    "bbox": list(d.bbox),
                }
                for d in pipeline_result.raw_detections
            ]

            # Save visualization if requested
            if save_visualization and pipeline_result.raw_detections:
                viz_path = self._save_visualization(image_path, doc_id)
                result.visualization_path = viz_path

        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            result.errors.append(str(e))
            result.success = False

        result.processing_time_ms = (time.time() - start_time) * 1000
        return result

    def process_pdf(
        self,
        pdf_path: Path,
        document_id: str | None = None,
        save_visualization: bool = True,
    ) -> ServiceResult:
        """
        Process a PDF file and extract invoice fields.

        Args:
            pdf_path: Path to PDF file
            document_id: Optional document ID
            save_visualization: Whether to save visualization

        Returns:
            ServiceResult with extracted fields
        """
        if not self._is_initialized:
            self.initialize()

        doc_id = document_id or str(uuid.uuid4())[:8]
        start_time = time.time()

        result = ServiceResult(document_id=doc_id)

        try:
            # Run inference pipeline
            pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)

            result.fields = pipeline_result.fields
            result.confidence = pipeline_result.confidence
            result.success = pipeline_result.success
            result.errors = pipeline_result.errors

            # Determine document type based on payment_line presence
            # If no payment_line found, it's likely a letter, not an invoice
            if not result.fields.get('payment_line'):
                result.document_type = "letter"
            else:
                result.document_type = "invoice"

            # Get raw detections
            result.detections = [
                {
                    "field": d.class_name,
                    "confidence": d.confidence,
                    "bbox": list(d.bbox),
                }
                for d in pipeline_result.raw_detections
            ]

            # Save visualization (render first page)
            if save_visualization and pipeline_result.raw_detections:
                viz_path = self._save_pdf_visualization(pdf_path, doc_id)
                result.visualization_path = viz_path

        except Exception as e:
            logger.error(f"Error processing PDF {pdf_path}: {e}")
            result.errors.append(str(e))
            result.success = False

        result.processing_time_ms = (time.time() - start_time) * 1000
        return result

    def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
        """Save visualization image with detections."""
        from ultralytics import YOLO

        # Load model and run prediction with visualization
        model = YOLO(str(self.model_config.model_path))
        results = model.predict(str(image_path), verbose=False)

        # Save annotated image
        output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
        for r in results:
            r.save(filename=str(output_path))

        return output_path

    def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
        """Save visualization for PDF (first page)."""
        from shared.pdf.renderer import render_pdf_to_images
        from ultralytics import YOLO
        import io

        # Render first page
        for page_no, image_bytes in render_pdf_to_images(
            pdf_path, dpi=self.model_config.dpi
        ):
            image = Image.open(io.BytesIO(image_bytes))
            temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
            image.save(temp_path)

            # Run YOLO and save visualization
            model = YOLO(str(self.model_config.model_path))
            results = model.predict(str(temp_path), verbose=False)

            output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
            for r in results:
                r.save(filename=str(output_path))

            # Cleanup temp file
            temp_path.unlink(missing_ok=True)
            return output_path

        # If no pages rendered
        return None
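A minimal usage sketch for InferenceService follows. The ModelConfig/StorageConfig keyword arguments are assumed from the attributes this file reads; the real constructors live in inference/web/services/config.py and may differ:

# Hypothetical wiring; constructor kwargs are inferred from the attributes used above.
from pathlib import Path

from inference.web.services.config import ModelConfig, StorageConfig
from inference.web.services.inference import InferenceService

model_cfg = ModelConfig(
    model_path=Path("models/invoice_fields.pt"),  # assumed field name and path
    confidence_threshold=0.5,
    use_gpu=False,
    dpi=200,
)
storage_cfg = StorageConfig()  # assumed to provide a default result_dir

service = InferenceService(model_cfg, storage_cfg)
result = service.process_pdf(Path("samples/invoice.pdf"))
print(result.document_type, result.fields, f"{result.processing_time_ms:.0f} ms")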
24
packages/inference/inference/web/workers/__init__.py
Normal file
24
packages/inference/inference/web/workers/__init__.py
Normal file
@@ -0,0 +1,24 @@
"""
Background Task Queues

Worker queues for asynchronous and batch processing.
"""

from inference.web.workers.async_queue import AsyncTaskQueue, AsyncTask
from inference.web.workers.batch_queue import (
    BatchTaskQueue,
    BatchTask,
    init_batch_queue,
    shutdown_batch_queue,
    get_batch_queue,
)

__all__ = [
    "AsyncTaskQueue",
    "AsyncTask",
    "BatchTaskQueue",
    "BatchTask",
    "init_batch_queue",
    "shutdown_batch_queue",
    "get_batch_queue",
]
181
packages/inference/inference/web/workers/async_queue.py
Normal file
181
packages/inference/inference/web/workers/async_queue.py
Normal file
@@ -0,0 +1,181 @@
"""
Async Task Queue

Thread-safe queue for background invoice processing.
"""

import logging
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from queue import Empty, Full, Queue
from threading import Event, Lock, Thread
from typing import Callable

logger = logging.getLogger(__name__)


@dataclass
class AsyncTask:
    """Task queued for background processing."""

    request_id: str
    api_key: str
    file_path: Path
    filename: str
    created_at: datetime = field(default_factory=datetime.utcnow)
    priority: int = 0  # Lower = higher priority (not implemented yet)


class AsyncTaskQueue:
    """Thread-safe queue for async invoice processing."""

    def __init__(
        self,
        max_size: int = 100,
        worker_count: int = 1,
    ) -> None:
        self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
        self._workers: list[Thread] = []
        self._stop_event = Event()
        self._worker_count = worker_count
        self._lock = Lock()
        self._processing: set[str] = set()  # Currently processing request_ids
        self._task_handler: Callable[[AsyncTask], None] | None = None
        self._started = False

    def start(self, task_handler: Callable[[AsyncTask], None]) -> None:
        """Start background worker threads."""
        if self._started:
            logger.warning("AsyncTaskQueue already started")
            return

        self._task_handler = task_handler
        self._stop_event.clear()

        for i in range(self._worker_count):
            worker = Thread(
                target=self._worker_loop,
                name=f"async-worker-{i}",
                daemon=True,
            )
            worker.start()
            self._workers.append(worker)
            logger.info(f"Started async worker thread: {worker.name}")

        self._started = True
        logger.info(f"AsyncTaskQueue started with {self._worker_count} workers")

    def stop(self, timeout: float = 30.0) -> None:
        """Gracefully stop all workers."""
        if not self._started:
            return

        logger.info("Stopping AsyncTaskQueue...")
        self._stop_event.set()

        # Wait for workers to finish
        for worker in self._workers:
            worker.join(timeout=timeout / self._worker_count)
            if worker.is_alive():
                logger.warning(f"Worker {worker.name} did not stop gracefully")

        self._workers.clear()
        self._started = False
        logger.info("AsyncTaskQueue stopped")

    def submit(self, task: AsyncTask) -> bool:
        """
        Submit a task to the queue.

        Returns:
            True if task was queued, False if queue is full
        """
        try:
            self._queue.put_nowait(task)
            logger.info(f"Task {task.request_id} queued for processing")
            return True
        except Full:
            logger.warning(f"Queue full, task {task.request_id} rejected")
            return False

    def get_queue_depth(self) -> int:
        """Get current number of tasks in queue."""
        return self._queue.qsize()

    def get_processing_count(self) -> int:
        """Get number of tasks currently being processed."""
        with self._lock:
            return len(self._processing)

    def is_processing(self, request_id: str) -> bool:
        """Check if a specific request is currently being processed."""
        with self._lock:
            return request_id in self._processing

    @property
    def is_running(self) -> bool:
        """Check if the queue is running."""
        return self._started and not self._stop_event.is_set()

    def _worker_loop(self) -> None:
        """Worker loop that processes tasks from queue."""
        thread_name = threading.current_thread().name
        logger.info(f"Worker {thread_name} started")

        while not self._stop_event.is_set():
            try:
                # Block for up to 1 second waiting for tasks
                task = self._queue.get(timeout=1.0)
            except Empty:
                continue

            try:
                with self._lock:
                    self._processing.add(task.request_id)

                logger.info(
                    f"Worker {thread_name} processing task {task.request_id}"
                )
                start_time = time.time()

                if self._task_handler:
                    self._task_handler(task)

                elapsed = time.time() - start_time
                logger.info(
                    f"Worker {thread_name} completed task {task.request_id} "
                    f"in {elapsed:.2f}s"
                )

            except Exception as e:
                logger.error(
                    f"Worker {thread_name} failed to process task "
                    f"{task.request_id}: {e}",
                    exc_info=True,
                )

            finally:
                with self._lock:
                    self._processing.discard(task.request_id)
                self._queue.task_done()

        logger.info(f"Worker {thread_name} stopped")

    def wait_for_completion(self, timeout: float | None = None) -> bool:
        """
        Wait for all queued tasks to complete.

        Args:
            timeout: Maximum time to wait in seconds (currently not enforced;
                Queue.join() blocks until all tasks are done)

        Returns:
            True if all tasks completed, False if timeout
        """
        try:
            self._queue.join()
            return True
        except Exception:
            return False
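A minimal sketch of driving AsyncTaskQueue end to end, using only the API defined above; the handler body is a placeholder and the file paths are illustrative:

# Uses only the queue API above; the handler is a stand-in for real inference work.
import uuid
from pathlib import Path

from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue

def handle(task: AsyncTask) -> None:
    # A real handler would run inference on task.file_path and persist the result
    print(f"processing {task.request_id}: {task.filename}")

queue = AsyncTaskQueue(max_size=10, worker_count=2)
queue.start(handle)

queue.submit(
    AsyncTask(
        request_id=str(uuid.uuid4())[:8],
        api_key="demo-key",  # placeholder value
        file_path=Path("samples/invoice.pdf"),
        filename="invoice.pdf",
    )
)

queue.wait_for_completion()  # blocks until the worker calls task_done()
queue.stop()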
225
packages/inference/inference/web/workers/batch_queue.py
Normal file
225
packages/inference/inference/web/workers/batch_queue.py
Normal file
@@ -0,0 +1,225 @@
"""
Batch Upload Processing Queue

Background queue for async batch upload processing.
"""

import logging
import threading
from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Full, Empty
from typing import Any
from uuid import UUID

logger = logging.getLogger(__name__)


@dataclass
class BatchTask:
    """Task for batch upload processing."""

    batch_id: UUID
    admin_token: str
    zip_content: bytes
    zip_filename: str
    upload_source: str
    auto_label: bool
    created_at: datetime


class BatchTaskQueue:
    """Thread-safe queue for async batch upload processing."""

    def __init__(self, max_size: int = 20, worker_count: int = 2):
        """Initialize the batch task queue.

        Args:
            max_size: Maximum queue size
            worker_count: Number of worker threads
        """
        self._queue: Queue[BatchTask] = Queue(maxsize=max_size)
        self._workers: list[threading.Thread] = []
        self._stop_event = threading.Event()
        self._worker_count = worker_count
        self._batch_service: Any | None = None
        self._running = False
        self._lock = threading.Lock()

    def start(self, batch_service: Any) -> None:
        """Start worker threads with batch service.

        Args:
            batch_service: BatchUploadService instance for processing
        """
        with self._lock:
            if self._running:
                logger.warning("Batch queue already running")
                return

            self._batch_service = batch_service
            self._stop_event.clear()
            self._running = True

            # Start worker threads
            for i in range(self._worker_count):
                worker = threading.Thread(
                    target=self._worker_loop,
                    name=f"BatchWorker-{i}",
                    daemon=True,
                )
                worker.start()
                self._workers.append(worker)

            logger.info(f"Started {self._worker_count} batch workers")

    def stop(self, timeout: float = 30.0) -> None:
        """Stop all worker threads gracefully.

        Args:
            timeout: Maximum time to wait for workers to finish
        """
        with self._lock:
            if not self._running:
                return

            logger.info("Stopping batch queue...")
            self._stop_event.set()
            self._running = False

            # Wait for workers to finish
            for worker in self._workers:
                worker.join(timeout=timeout)

            self._workers.clear()
            logger.info("Batch queue stopped")

    def submit(self, task: BatchTask) -> bool:
        """Submit a batch task to the queue.

        Args:
            task: Batch task to process

        Returns:
            True if task was queued, False if queue is full
        """
        try:
            self._queue.put(task, block=False)
            logger.info(f"Queued batch task: batch_id={task.batch_id}")
            return True
        except Full:
            logger.warning(f"Queue full, rejected task: batch_id={task.batch_id}")
            return False

    def get_queue_depth(self) -> int:
        """Get the number of pending tasks in queue.

        Returns:
            Number of tasks waiting to be processed
        """
        return self._queue.qsize()

    @property
    def is_running(self) -> bool:
        """Check if queue is running.

        Returns:
            True if queue is active
        """
        return self._running

    def _worker_loop(self) -> None:
        """Worker thread main loop."""
        worker_name = threading.current_thread().name
        logger.info(f"{worker_name} started")

        while not self._stop_event.is_set():
            try:
                # Get task with timeout to check stop event periodically
                task = self._queue.get(timeout=1.0)
                self._process_task(task)
                self._queue.task_done()
            except Empty:
                # No tasks, continue loop to check stop event
                continue
            except Exception as e:
                logger.error(f"{worker_name} error processing task: {e}", exc_info=True)

        logger.info(f"{worker_name} stopped")

    def _process_task(self, task: BatchTask) -> None:
        """Process a single batch task.

        Args:
            task: Batch task to process
        """
        if self._batch_service is None:
            logger.error("Batch service not initialized, cannot process task")
            return

        logger.info(
            f"Processing batch task: batch_id={task.batch_id}, "
            f"filename={task.zip_filename}"
        )

        try:
            # Process the batch upload using the service
            result = self._batch_service.process_zip_upload(
                admin_token=task.admin_token,
                zip_filename=task.zip_filename,
                zip_content=task.zip_content,
                upload_source=task.upload_source,
            )

            logger.info(
                f"Batch task completed: batch_id={task.batch_id}, "
                f"status={result.get('status')}, "
                f"successful_files={result.get('successful_files')}, "
                f"failed_files={result.get('failed_files')}"
            )

        except Exception as e:
            logger.error(
                f"Error processing batch task {task.batch_id}: {e}",
                exc_info=True,
            )


# Global batch queue instance
_batch_queue: BatchTaskQueue | None = None
_queue_lock = threading.Lock()


def get_batch_queue() -> BatchTaskQueue:
    """Get or create the global batch queue instance.

    Returns:
        Batch task queue instance
    """
    global _batch_queue

    if _batch_queue is None:
        with _queue_lock:
            if _batch_queue is None:
                _batch_queue = BatchTaskQueue(max_size=20, worker_count=2)

    return _batch_queue


def init_batch_queue(batch_service: Any) -> None:
    """Initialize and start the batch queue.

    Args:
        batch_service: BatchUploadService instance
    """
    queue = get_batch_queue()
    if not queue.is_running:
        queue.start(batch_service)


def shutdown_batch_queue() -> None:
    """Shutdown the batch queue gracefully."""
    global _batch_queue

    if _batch_queue is not None:
        _batch_queue.stop()
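A sketch of the intended lifecycle for the module-level batch queue; BatchUploadService is stubbed here, and the real wiring (for example in the web app's startup and shutdown hooks) may differ:

# BatchUploadService is stubbed; only the queue API above is exercised.
import time
from datetime import datetime
from uuid import uuid4

from inference.web.workers.batch_queue import (
    BatchTask,
    get_batch_queue,
    init_batch_queue,
    shutdown_batch_queue,
)

class StubBatchService:
    def process_zip_upload(self, admin_token, zip_filename, zip_content, upload_source):
        return {"status": "completed", "successful_files": 1, "failed_files": 0}

init_batch_queue(StubBatchService())
get_batch_queue().submit(
    BatchTask(
        batch_id=uuid4(),
        admin_token="admin-token",  # placeholder value
        zip_content=b"",            # placeholder; a real task carries the ZIP bytes
        zip_filename="upload.zip",
        upload_source="web",
        auto_label=False,
        created_at=datetime.utcnow(),
    )
)
time.sleep(2)           # give a worker time to pick up the task in this toy script
shutdown_batch_queue()  # in the app this would run at shutdown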
8
packages/inference/requirements.txt
Normal file
8
packages/inference/requirements.txt
Normal file
@@ -0,0 +1,8 @@
-e ../shared
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
sqlmodel>=0.0.22
ultralytics>=8.1.0
httpx>=0.25.0
openai>=1.0.0
14
packages/inference/run_server.py
Normal file
14
packages/inference/run_server.py
Normal file
@@ -0,0 +1,14 @@
#!/usr/bin/env python
"""
Quick start script for the web server.

Usage:
    python run_server.py
    python run_server.py --port 8080
    python run_server.py --debug --reload
"""

from inference.cli.serve import main

if __name__ == "__main__":
    main()
17
packages/inference/setup.py
Normal file
17
packages/inference/setup.py
Normal file
@@ -0,0 +1,17 @@
from setuptools import setup, find_packages

setup(
    name="invoice-inference",
    version="0.1.0",
    packages=find_packages(),
    python_requires=">=3.11",
    install_requires=[
        "invoice-shared",
        "fastapi>=0.104.0",
        "uvicorn[standard]>=0.24.0",
        "python-multipart>=0.0.6",
        "sqlmodel>=0.0.22",
        "ultralytics>=8.1.0",
        "httpx>=0.25.0",
    ],
)