Add more tests

Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions


@@ -175,6 +175,80 @@ def run_migrations() -> None:
);
""",
),
# Migration 007: Add extra columns to training_tasks
(
"training_tasks_name",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
""",
),
(
"training_tasks_description",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
""",
),
(
"training_tasks_admin_token",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
""",
),
(
"training_tasks_task_type",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
""",
),
(
"training_tasks_recurring",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
""",
),
(
"training_tasks_metrics",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
""",
),
(
"training_tasks_updated_at",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
""",
),
# Migration 008: Fix model_versions foreign key constraints
(
"model_versions_fk_fix",
"""
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_dataset_id_fkey
FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_task_id_fkey
FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
""",
),
# Migration 006b: Ensure only one active model at a time
(
"model_versions_single_active",
"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
ON model_versions(is_active) WHERE is_active = TRUE;
""",
),
]
with engine.connect() as conn:
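The hunk cuts off at the loop that executes these entries; a minimal sketch of applying such a (name, SQL) list — assuming idempotency comes from the IF NOT EXISTS / DROP ... IF EXISTS guards in the statements themselves rather than from a bookkeeping table — could look like this:

import logging

from sqlalchemy import text


def apply_migrations(engine, migrations: list[tuple[str, str]]) -> None:
    # Sketch only, not the project's actual runner: each entry is executed
    # in order and committed; a failure rolls back and aborts the run.
    with engine.connect() as conn:
        for name, sql in migrations:
            try:
                conn.execute(text(sql))
                conn.commit()
                logging.info("Applied migration %s", name)
            except Exception:
                conn.rollback()
                logging.exception("Migration %s failed", name)
                raise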


@@ -193,6 +193,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]):
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if annotation:
session.delete(annotation)
session.commit()
return True
return False
@@ -216,6 +217,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]):
count = len(annotations)
for ann in annotations:
session.delete(ann)
session.commit()
return count
def verify(


@@ -203,6 +203,14 @@ class DatasetRepository(BaseRepository[TrainingDataset]):
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return False
# Delete associated document links first
doc_links = session.exec(
select(DatasetDocument).where(
DatasetDocument.dataset_id == UUID(str(dataset_id))
)
).all()
for link in doc_links:
session.delete(link)
session.delete(dataset)
session.commit()
return True


@@ -264,6 +264,7 @@ class DocumentRepository(BaseRepository[AdminDocument]):
for ann in annotations:
session.delete(ann)
session.delete(document)
session.commit()
return True
return False
@@ -389,7 +390,11 @@ class DocumentRepository(BaseRepository[AdminDocument]):
return None
now = datetime.now(timezone.utc)
if doc.annotation_lock_until and doc.annotation_lock_until > now:
lock_until = doc.annotation_lock_until
# Handle PostgreSQL returning offset-naive datetimes
if lock_until and lock_until.tzinfo is None:
lock_until = lock_until.replace(tzinfo=timezone.utc)
if lock_until and lock_until > now:
return None
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
@@ -433,10 +438,14 @@ class DocumentRepository(BaseRepository[AdminDocument]):
return None
now = datetime.now(timezone.utc)
if not doc.annotation_lock_until or doc.annotation_lock_until <= now:
lock_until = doc.annotation_lock_until
# Handle PostgreSQL returning offset-naive datetimes
if lock_until and lock_until.tzinfo is None:
lock_until = lock_until.replace(tzinfo=timezone.utc)
if not lock_until or lock_until <= now:
return None
doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds)
doc.annotation_lock_until = lock_until + timedelta(seconds=additional_seconds)
session.add(doc)
session.commit()
session.refresh(doc)
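The same offset-naive guard now appears in both the lock-acquire and lock-extend paths; a small helper along these lines (a sketch, not part of this commit) would express the normalization once:

from datetime import datetime, timezone


def ensure_utc(dt: datetime | None) -> datetime | None:
    # Some PostgreSQL driver/column combinations return offset-naive
    # timestamps; treat them as UTC so comparisons against
    # datetime.now(timezone.utc) do not raise TypeError.
    if dt is not None and dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt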


@@ -118,6 +118,22 @@ class TrainingTaskRepository(BaseRepository[TrainingTask]):
session.expunge(r)
return list(results)
def get_running(self) -> TrainingTask | None:
"""Get currently running training task.
Returns:
Running task or None if no task is running
"""
with get_session_context() as session:
result = session.exec(
select(TrainingTask)
.where(TrainingTask.status == "running")
.order_by(TrainingTask.started_at.desc())
).first()
if result:
session.expunge(result)
return result
def update_status(
self,
task_id: str,


@@ -55,5 +55,6 @@ def create_normalizer_registry(
"Amount": amount_normalizer,
"InvoiceDate": date_normalizer,
"InvoiceDueDate": date_normalizer,
"supplier_org_number": SupplierOrgNumberNormalizer(),
# Note: field_name is "supplier_organisation_number" (from CLASS_TO_FIELD mapping)
"supplier_organisation_number": SupplierOrgNumberNormalizer(),
}


@@ -481,11 +481,22 @@ def create_annotation_router() -> APIRouter:
detail="At least one field value is required",
)
# Get the actual file path from storage
# document.file_path is a relative storage path like "raw_pdfs/uuid.pdf"
storage = get_storage_helper()
filename = document.file_path.split("/")[-1] if "/" in document.file_path else document.file_path
file_path = storage.get_raw_pdf_local_path(filename)
if file_path is None:
raise HTTPException(
status_code=500,
detail=f"Cannot find PDF file: {document.file_path}",
)
# Run auto-labeling
service = get_auto_label_service()
result = service.auto_label_document(
document_id=document_id,
file_path=document.file_path,
file_path=str(file_path),
field_values=request.field_values,
doc_repo=doc_repo,
ann_repo=ann_repo,


@@ -6,7 +6,7 @@ FastAPI endpoints for admin token management.
import logging
import secrets
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter
@@ -41,10 +41,10 @@ def create_auth_router() -> APIRouter:
# Generate secure token
token = secrets.token_urlsafe(32)
# Calculate expiration
# Calculate expiration (use timezone-aware datetime)
expires_at = None
if request.expires_in_days:
expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)
expires_at = datetime.now(timezone.utc) + timedelta(days=request.expires_in_days)
# Create token in database
tokens.create(


@@ -0,0 +1,135 @@
"""
Dashboard API Routes
FastAPI endpoints for dashboard statistics and activity.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from inference.web.core.auth import (
AdminTokenDep,
get_model_version_repository,
get_training_task_repository,
ModelVersionRepoDep,
TrainingTaskRepoDep,
)
from inference.web.schemas.admin import (
DashboardStatsResponse,
ActiveModelResponse,
ActiveModelInfo,
RunningTrainingInfo,
RecentActivityResponse,
ActivityItem,
)
from inference.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
)
logger = logging.getLogger(__name__)
def create_dashboard_router() -> APIRouter:
"""Create dashboard API router."""
router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"])
@router.get(
"/stats",
response_model=DashboardStatsResponse,
summary="Get dashboard statistics",
description="Returns document counts and annotation completeness metrics.",
)
async def get_dashboard_stats(
admin_token: AdminTokenDep,
) -> DashboardStatsResponse:
"""Get dashboard statistics."""
service = DashboardStatsService()
stats = service.get_stats()
return DashboardStatsResponse(
total_documents=stats["total_documents"],
annotation_complete=stats["annotation_complete"],
annotation_incomplete=stats["annotation_incomplete"],
pending=stats["pending"],
completeness_rate=stats["completeness_rate"],
)
@router.get(
"/active-model",
response_model=ActiveModelResponse,
summary="Get active model info",
description="Returns current active model and running training status.",
)
async def get_active_model(
admin_token: AdminTokenDep,
model_repo: ModelVersionRepoDep,
task_repo: TrainingTaskRepoDep,
) -> ActiveModelResponse:
"""Get active model and training status."""
# Get active model
active_model = model_repo.get_active()
model_info = None
if active_model:
model_info = ActiveModelInfo(
version_id=str(active_model.version_id),
version=active_model.version,
name=active_model.name,
metrics_mAP=active_model.metrics_mAP,
metrics_precision=active_model.metrics_precision,
metrics_recall=active_model.metrics_recall,
document_count=active_model.document_count,
activated_at=active_model.activated_at,
)
# Get running training task
running_task = task_repo.get_running()
training_info = None
if running_task:
training_info = RunningTrainingInfo(
task_id=str(running_task.task_id),
name=running_task.name,
status=running_task.status,
started_at=running_task.started_at,
progress=running_task.progress or 0,
)
return ActiveModelResponse(
model=model_info,
running_training=training_info,
)
@router.get(
"/activity",
response_model=RecentActivityResponse,
summary="Get recent activity",
description="Returns recent system activities sorted by timestamp.",
)
async def get_recent_activity(
admin_token: AdminTokenDep,
limit: Annotated[
int,
Query(ge=1, le=50, description="Maximum number of activities"),
] = 10,
) -> RecentActivityResponse:
"""Get recent system activity."""
service = DashboardActivityService()
activities = service.get_recent_activities(limit=limit)
return RecentActivityResponse(
activities=[
ActivityItem(
type=act["type"],
description=act["description"],
timestamp=act["timestamp"],
metadata=act["metadata"],
)
for act in activities
]
)
return router


@@ -44,6 +44,7 @@ from inference.web.api.v1.admin import (
create_locks_router,
create_training_router,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.scheduler import start_scheduler, stop_scheduler
from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
@@ -115,13 +116,21 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
"""Application lifespan manager."""
logger.info("Starting Invoice Inference API...")
# Initialize database tables
# Initialize async request database tables
try:
async_db.create_tables()
logger.info("Async database tables ready")
except Exception as e:
logger.error(f"Failed to initialize async database: {e}")
# Initialize admin database tables (admin_tokens, admin_documents, training_tasks, etc.)
try:
from inference.data.database import create_db_and_tables
create_db_and_tables()
logger.info("Admin database tables ready")
except Exception as e:
logger.error(f"Failed to initialize admin database: {e}")
# Initialize inference service on startup
try:
inference_service.initialize()
@@ -279,6 +288,10 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
augmentation_router = create_augmentation_router()
app.include_router(augmentation_router, prefix="/api/v1/admin")
# Include dashboard routes
dashboard_router = create_dashboard_router()
app.include_router(dashboard_router, prefix="/api/v1")
# Include batch upload routes
app.include_router(batch_upload_router)


@@ -11,6 +11,7 @@ from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
from .dashboard import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse


@@ -0,0 +1,92 @@
"""
Dashboard API Schemas
Pydantic models for dashboard statistics and activity endpoints.
"""
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, Field
# Activity type literals for type safety
ActivityType = Literal[
"document_uploaded",
"annotation_modified",
"training_completed",
"training_failed",
"model_activated",
]
class DashboardStatsResponse(BaseModel):
"""Response for dashboard statistics."""
total_documents: int = Field(..., description="Total number of documents")
annotation_complete: int = Field(
..., description="Documents with complete annotations"
)
annotation_incomplete: int = Field(
..., description="Documents with incomplete annotations"
)
pending: int = Field(..., description="Documents pending processing")
completeness_rate: float = Field(
..., description="Annotation completeness percentage"
)
class ActiveModelInfo(BaseModel):
"""Active model information."""
version_id: str = Field(..., description="Model version UUID")
version: str = Field(..., description="Model version string")
name: str = Field(..., description="Model name")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
metrics_precision: float | None = Field(None, description="Precision score")
metrics_recall: float | None = Field(None, description="Recall score")
document_count: int = Field(0, description="Number of training documents")
activated_at: datetime | None = Field(None, description="Activation timestamp")
class RunningTrainingInfo(BaseModel):
"""Running training task information."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Training task name")
status: str = Field(..., description="Training status")
started_at: datetime | None = Field(None, description="Start timestamp")
progress: int = Field(0, description="Training progress percentage")
class ActiveModelResponse(BaseModel):
"""Response for active model endpoint."""
model: ActiveModelInfo | None = Field(
None, description="Active model info, null if none"
)
running_training: RunningTrainingInfo | None = Field(
None, description="Running training task, null if none"
)
class ActivityItem(BaseModel):
"""Single activity item."""
type: ActivityType = Field(
...,
description="Activity type: document_uploaded, annotation_modified, training_completed, training_failed, model_activated",
)
description: str = Field(..., description="Human-readable description")
timestamp: datetime = Field(..., description="Activity timestamp")
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
class RecentActivityResponse(BaseModel):
"""Response for recent activity endpoint."""
activities: list[ActivityItem] = Field(
default_factory=list, description="List of recent activities"
)
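A quick construction sketch for these schemas (illustrative values; assumes Pydantic v2 for model_dump_json):

from datetime import datetime, timezone

item = ActivityItem(
    type="model_activated",
    description="Activated model v3",
    timestamp=datetime.now(timezone.utc),
    metadata={"version": "v3"},
)
print(RecentActivityResponse(activities=[item]).model_dump_json(indent=2))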


@@ -291,7 +291,7 @@ class AutoLabelService:
"bbox_y": bbox_y,
"bbox_width": bbox_width,
"bbox_height": bbox_height,
"text_value": best_match.matched_value,
"text_value": best_match.matched_text,
"confidence": best_match.score,
"source": "auto",
})


@@ -0,0 +1,276 @@
"""
Dashboard Service
Business logic for dashboard statistics and activity aggregation.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import func, exists, and_, or_
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import (
AdminDocument,
AdminAnnotation,
AnnotationHistory,
TrainingTask,
ModelVersion,
)
logger = logging.getLogger(__name__)
# Field class IDs for completeness calculation
# Identifiers: invoice_number (0) or ocr_number (3)
IDENTIFIER_CLASS_IDS = {0, 3}
# Payment accounts: bankgiro (4) or plusgiro (5)
PAYMENT_CLASS_IDS = {4, 5}
def is_annotation_complete(annotations: list[dict[str, Any]]) -> bool:
"""Check if a document's annotations are complete.
A document is complete if it has:
- At least one identifier field (invoice_number OR ocr_number)
- At least one payment field (bankgiro OR plusgiro)
Args:
annotations: List of annotation dicts with class_id
Returns:
True if document has required fields
"""
class_ids = {ann.get("class_id") for ann in annotations}
has_identifier = bool(class_ids & IDENTIFIER_CLASS_IDS)
has_payment = bool(class_ids & PAYMENT_CLASS_IDS)
return has_identifier and has_payment
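A quick usage sketch of the rule, using the class IDs defined above:

# Complete: an identifier (class 0, invoice_number) plus a payment
# account (class 4, bankgiro).
assert is_annotation_complete([{"class_id": 0}, {"class_id": 4}])
# Incomplete: identifier present, but no bankgiro/plusgiro annotation.
assert not is_annotation_complete([{"class_id": 0}, {"class_id": 1}])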
class DashboardStatsService:
"""Service for computing dashboard statistics."""
def get_stats(self) -> dict[str, Any]:
"""Get dashboard statistics.
Returns:
Dict with total_documents, annotation_complete, annotation_incomplete,
pending, and completeness_rate
"""
with get_session_context() as session:
# Total documents
total = session.exec(
select(func.count()).select_from(AdminDocument)
).one()
# Pending documents (status in ['pending', 'auto_labeling'])
pending = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(AdminDocument.status.in_(["pending", "auto_labeling"]))
).one()
# Complete annotations: labeled + has identifier + has payment
complete = self._count_complete(session)
# Incomplete: labeled but not complete
labeled_count = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(AdminDocument.status == "labeled")
).one()
incomplete = labeled_count - complete
# Calculate completeness rate
total_assessed = complete + incomplete
completeness_rate = (
round(complete / total_assessed * 100, 2)
if total_assessed > 0
else 0.0
)
return {
"total_documents": total,
"annotation_complete": complete,
"annotation_incomplete": incomplete,
"pending": pending,
"completeness_rate": completeness_rate,
}
def _count_complete(self, session) -> int:
"""Count documents with complete annotations.
A document is complete if it:
1. Has status = 'labeled'
2. Has at least one identifier annotation (class_id 0 or 3)
3. Has at least one payment annotation (class_id 4 or 5)
"""
# Subquery for documents with identifier
has_identifier = exists(
select(1)
.select_from(AdminAnnotation)
.where(
and_(
AdminAnnotation.document_id == AdminDocument.document_id,
AdminAnnotation.class_id.in_(IDENTIFIER_CLASS_IDS),
)
)
)
# Subquery for documents with payment
has_payment = exists(
select(1)
.select_from(AdminAnnotation)
.where(
and_(
AdminAnnotation.document_id == AdminDocument.document_id,
AdminAnnotation.class_id.in_(PAYMENT_CLASS_IDS),
)
)
)
count = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(
and_(
AdminDocument.status == "labeled",
has_identifier,
has_payment,
)
)
).one()
return count
class DashboardActivityService:
"""Service for aggregating recent activities."""
def get_recent_activities(self, limit: int = 10) -> list[dict[str, Any]]:
"""Get recent system activities.
Aggregates from:
- Document uploads
- Annotation modifications
- Training completions/failures
- Model activations
Args:
limit: Maximum number of activities to return
Returns:
List of activity dicts sorted by timestamp DESC
"""
activities = []
with get_session_context() as session:
# Document uploads (most recent, up to the requested limit)
uploads = session.exec(
select(AdminDocument)
.order_by(AdminDocument.created_at.desc())
.limit(limit)
).all()
for doc in uploads:
activities.append({
"type": "document_uploaded",
"description": f"Uploaded {doc.filename}",
"timestamp": doc.created_at,
"metadata": {
"document_id": str(doc.document_id),
"filename": doc.filename,
},
})
# Annotation modifications (from history)
modifications = session.exec(
select(AnnotationHistory)
.where(AnnotationHistory.action == "override")
.order_by(AnnotationHistory.created_at.desc())
.limit(limit)
).all()
for mod in modifications:
# Get document filename
doc = session.get(AdminDocument, mod.document_id)
filename = doc.filename if doc else "Unknown"
field_name = ""
if mod.new_value and isinstance(mod.new_value, dict):
field_name = mod.new_value.get("class_name", "")
activities.append({
"type": "annotation_modified",
"description": f"Modified {filename} {field_name}".strip(),
"timestamp": mod.created_at,
"metadata": {
"annotation_id": str(mod.annotation_id),
"document_id": str(mod.document_id),
"field_name": field_name,
},
})
# Training completions and failures
training_tasks = session.exec(
select(TrainingTask)
.where(TrainingTask.status.in_(["completed", "failed"]))
.order_by(TrainingTask.updated_at.desc())
.limit(limit)
).all()
for task in training_tasks:
if task.updated_at is None:
continue
if task.status == "completed":
# Use metrics_mAP field directly
mAP = task.metrics_mAP or 0.0
activities.append({
"type": "training_completed",
"description": f"Training complete: {task.name}, mAP {mAP:.1%}",
"timestamp": task.updated_at,
"metadata": {
"task_id": str(task.task_id),
"task_name": task.name,
"mAP": mAP,
},
})
else:
activities.append({
"type": "training_failed",
"description": f"Training failed: {task.name}",
"timestamp": task.updated_at,
"metadata": {
"task_id": str(task.task_id),
"task_name": task.name,
"error": task.error_message or "",
},
})
# Model activations
model_versions = session.exec(
select(ModelVersion)
.where(ModelVersion.activated_at.is_not(None))
.order_by(ModelVersion.activated_at.desc())
.limit(limit)
).all()
for model in model_versions:
if model.activated_at is None:
continue
activities.append({
"type": "model_activated",
"description": f"Activated model {model.version}",
"timestamp": model.activated_at,
"metadata": {
"version_id": str(model.version_id),
"version": model.version,
},
})
# Sort all activities by timestamp DESC and return top N
activities.sort(key=lambda x: x["timestamp"], reverse=True)
return activities[:limit]