Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

File diff suppressed because it is too large.


@@ -0,0 +1,26 @@
"""
Repository Pattern Implementation
Provides domain-specific repository classes to replace the monolithic AdminDB.
Each repository handles a single domain, following the Single Responsibility Principle.
"""
from inference.data.repositories.base import BaseRepository
from inference.data.repositories.token_repository import TokenRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.training_task_repository import TrainingTaskRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.model_version_repository import ModelVersionRepository
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
__all__ = [
"BaseRepository",
"TokenRepository",
"DocumentRepository",
"AnnotationRepository",
"TrainingTaskRepository",
"DatasetRepository",
"ModelVersionRepository",
"BatchUploadRepository",
]
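
Taken together, the package is meant to be consumed one domain at a time instead of through the monolithic AdminDB facade. A minimal usage sketch, assuming the package is importable as above; the filenames, sizes, class IDs, and box coordinates are illustrative placeholders, not values from this commit:

from inference.data.repositories import AnnotationRepository, DocumentRepository

documents = DocumentRepository()
annotations = AnnotationRepository()

# Register an uploaded file, then attach one manually drawn annotation to it.
doc_id = documents.create(
    filename="invoice-001.pdf",
    file_size=123_456,
    content_type="application/pdf",
    file_path="uploads/invoice-001.pdf",
    page_count=2,
)
annotation_id = annotations.create(
    document_id=doc_id,
    page_number=1,
    class_id=0,               # placeholder; real IDs follow the project's class mapping
    class_name="InvoiceNumber",
    x_center=0.42,
    y_center=0.11,
    width=0.18,
    height=0.03,
    bbox_x=500,
    bbox_y=120,
    bbox_width=210,
    bbox_height=36,
    text_value="A3861",
    source="manual",
)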


@@ -0,0 +1,355 @@
"""
Annotation Repository
Handles annotation operations, following the Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class AnnotationRepository(BaseRepository[AdminAnnotation]):
"""Repository for annotation management.
Handles:
- Annotation CRUD operations
- Batch annotation creation
- Annotation verification
- Annotation override tracking
"""
def create(
self,
document_id: str,
page_number: int,
class_id: int,
class_name: str,
x_center: float,
y_center: float,
width: float,
height: float,
bbox_x: int,
bbox_y: int,
bbox_width: int,
bbox_height: int,
text_value: str | None = None,
confidence: float | None = None,
source: str = "manual",
) -> str:
"""Create a new annotation.
Returns:
Annotation ID as string
"""
with get_session_context() as session:
annotation = AdminAnnotation(
document_id=UUID(document_id),
page_number=page_number,
class_id=class_id,
class_name=class_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=bbox_x,
bbox_y=bbox_y,
bbox_width=bbox_width,
bbox_height=bbox_height,
text_value=text_value,
confidence=confidence,
source=source,
)
session.add(annotation)
session.flush()
return str(annotation.annotation_id)
def create_batch(
self,
annotations: list[dict[str, Any]],
) -> list[str]:
"""Create multiple annotations in a batch.
Args:
annotations: List of annotation data dicts
Returns:
List of annotation IDs
"""
with get_session_context() as session:
ids = []
for ann_data in annotations:
annotation = AdminAnnotation(
document_id=UUID(ann_data["document_id"]),
page_number=ann_data.get("page_number", 1),
class_id=ann_data["class_id"],
class_name=ann_data["class_name"],
x_center=ann_data["x_center"],
y_center=ann_data["y_center"],
width=ann_data["width"],
height=ann_data["height"],
bbox_x=ann_data["bbox_x"],
bbox_y=ann_data["bbox_y"],
bbox_width=ann_data["bbox_width"],
bbox_height=ann_data["bbox_height"],
text_value=ann_data.get("text_value"),
confidence=ann_data.get("confidence"),
source=ann_data.get("source", "auto"),
)
session.add(annotation)
session.flush()
ids.append(str(annotation.annotation_id))
return ids
def get(self, annotation_id: str) -> AdminAnnotation | None:
"""Get an annotation by ID."""
with get_session_context() as session:
result = session.get(AdminAnnotation, UUID(annotation_id))
if result:
session.expunge(result)
return result
def get_for_document(
self,
document_id: str,
page_number: int | None = None,
) -> list[AdminAnnotation]:
"""Get all annotations for a document."""
with get_session_context() as session:
statement = select(AdminAnnotation).where(
AdminAnnotation.document_id == UUID(document_id)
)
if page_number is not None:
statement = statement.where(AdminAnnotation.page_number == page_number)
statement = statement.order_by(AdminAnnotation.class_id)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def update(
self,
annotation_id: str,
x_center: float | None = None,
y_center: float | None = None,
width: float | None = None,
height: float | None = None,
bbox_x: int | None = None,
bbox_y: int | None = None,
bbox_width: int | None = None,
bbox_height: int | None = None,
text_value: str | None = None,
class_id: int | None = None,
class_name: str | None = None,
) -> bool:
"""Update an annotation.
Returns:
True if updated, False if not found
"""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if annotation:
if x_center is not None:
annotation.x_center = x_center
if y_center is not None:
annotation.y_center = y_center
if width is not None:
annotation.width = width
if height is not None:
annotation.height = height
if bbox_x is not None:
annotation.bbox_x = bbox_x
if bbox_y is not None:
annotation.bbox_y = bbox_y
if bbox_width is not None:
annotation.bbox_width = bbox_width
if bbox_height is not None:
annotation.bbox_height = bbox_height
if text_value is not None:
annotation.text_value = text_value
if class_id is not None:
annotation.class_id = class_id
if class_name is not None:
annotation.class_name = class_name
annotation.updated_at = datetime.utcnow()
session.add(annotation)
return True
return False
def delete(self, annotation_id: str) -> bool:
"""Delete an annotation."""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if annotation:
session.delete(annotation)
return True
return False
def delete_for_document(
self,
document_id: str,
source: str | None = None,
) -> int:
"""Delete all annotations for a document.
Returns:
Count of deleted annotations
"""
with get_session_context() as session:
statement = select(AdminAnnotation).where(
AdminAnnotation.document_id == UUID(document_id)
)
if source:
statement = statement.where(AdminAnnotation.source == source)
annotations = session.exec(statement).all()
count = len(annotations)
for ann in annotations:
session.delete(ann)
return count
def verify(
self,
annotation_id: str,
admin_token: str,
) -> AdminAnnotation | None:
"""Mark an annotation as verified."""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if not annotation:
return None
annotation.is_verified = True
annotation.verified_at = datetime.utcnow()
annotation.verified_by = admin_token
annotation.updated_at = datetime.utcnow()
session.add(annotation)
session.commit()
session.refresh(annotation)
session.expunge(annotation)
return annotation
def override(
self,
annotation_id: str,
admin_token: str,
change_reason: str | None = None,
**updates: Any,
) -> AdminAnnotation | None:
"""Override an auto-generated annotation.
Creates a history record and updates the annotation.
"""
with get_session_context() as session:
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if not annotation:
return None
previous_value = {
"class_id": annotation.class_id,
"class_name": annotation.class_name,
"bbox": {
"x": annotation.bbox_x,
"y": annotation.bbox_y,
"width": annotation.bbox_width,
"height": annotation.bbox_height,
},
"normalized": {
"x_center": annotation.x_center,
"y_center": annotation.y_center,
"width": annotation.width,
"height": annotation.height,
},
"text_value": annotation.text_value,
"confidence": annotation.confidence,
"source": annotation.source,
}
for key, value in updates.items():
if hasattr(annotation, key):
setattr(annotation, key, value)
if annotation.source == "auto":
annotation.override_source = "auto"
annotation.source = "manual"
annotation.updated_at = datetime.utcnow()
session.add(annotation)
history = AnnotationHistory(
annotation_id=UUID(annotation_id),
document_id=annotation.document_id,
action="override",
previous_value=previous_value,
new_value=updates,
changed_by=admin_token,
change_reason=change_reason,
)
session.add(history)
session.commit()
session.refresh(annotation)
session.expunge(annotation)
return annotation
def create_history(
self,
annotation_id: UUID,
document_id: UUID,
action: str,
previous_value: dict[str, Any] | None = None,
new_value: dict[str, Any] | None = None,
changed_by: str | None = None,
change_reason: str | None = None,
) -> AnnotationHistory:
"""Create an annotation history record."""
with get_session_context() as session:
history = AnnotationHistory(
annotation_id=annotation_id,
document_id=document_id,
action=action,
previous_value=previous_value,
new_value=new_value,
changed_by=changed_by,
change_reason=change_reason,
)
session.add(history)
session.commit()
session.refresh(history)
session.expunge(history)
return history
def get_history(self, annotation_id: UUID) -> list[AnnotationHistory]:
"""Get history for a specific annotation."""
with get_session_context() as session:
statement = select(AnnotationHistory).where(
AnnotationHistory.annotation_id == annotation_id
).order_by(AnnotationHistory.created_at.desc())
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_document_history(self, document_id: UUID) -> list[AnnotationHistory]:
"""Get all annotation history for a document."""
with get_session_context() as session:
statement = select(AnnotationHistory).where(
AnnotationHistory.document_id == document_id
).order_by(AnnotationHistory.created_at.desc())
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
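
A hedged sketch of the verify/override flow handled by this repository; the annotation ID, admin token, and class values below are placeholders:

from uuid import UUID

from inference.data.repositories import AnnotationRepository

repo = AnnotationRepository()
annotation_id = "00000000-0000-0000-0000-000000000000"  # placeholder ID

# Mark an auto-generated box as reviewed, then correct its class.
repo.verify(annotation_id, admin_token="reviewer-token")
updated = repo.override(
    annotation_id,
    admin_token="reviewer-token",
    change_reason="Auto-label picked the wrong class",
    class_id=4,                # placeholder class mapping
    class_name="Amount",
)
# Each override also writes an AnnotationHistory row with the previous values.
if updated is not None:
    for entry in repo.get_history(UUID(annotation_id)):
        print(entry.action, entry.changed_by, entry.change_reason)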


@@ -0,0 +1,75 @@
"""
Base Repository
Provides common functionality for all repositories.
"""
import logging
from abc import ABC
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Generator, TypeVar, Generic
from uuid import UUID
from sqlmodel import Session
from inference.data.database import get_session_context
logger = logging.getLogger(__name__)
T = TypeVar("T")
class BaseRepository(ABC, Generic[T]):
"""Base class for all repositories.
Provides:
- Session management via context manager
- Logging infrastructure
- Common query patterns
- Utility methods for datetime and UUID handling
"""
@contextmanager
def _session(self) -> Generator[Session, None, None]:
"""Get a database session with auto-commit/rollback."""
with get_session_context() as session:
yield session
def _expunge(self, session: Session, entity: T) -> T:
"""Detach entity from session for safe return."""
session.expunge(entity)
return entity
def _expunge_all(self, session: Session, entities: list[T]) -> list[T]:
"""Detach multiple entities from session."""
for entity in entities:
session.expunge(entity)
return entities
@staticmethod
def _now() -> datetime:
"""Get current UTC time as timezone-aware datetime.
Use this instead of datetime.utcnow() which is deprecated in Python 3.12+.
"""
return datetime.now(timezone.utc)
@staticmethod
def _validate_uuid(value: str, field_name: str = "id") -> UUID:
"""Validate and convert string to UUID.
Args:
value: String to convert to UUID
field_name: Name of field for error message
Returns:
Validated UUID
Raises:
ValueError: If value is not a valid UUID
"""
try:
return UUID(value)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid {field_name}: {value}") from e


@@ -0,0 +1,136 @@
"""
Batch Upload Repository
Handles batch upload operations, following the Single Responsibility Principle.
"""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import BatchUpload, BatchUploadFile
from inference.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class BatchUploadRepository(BaseRepository[BatchUpload]):
"""Repository for batch upload management.
Handles:
- Batch upload CRUD operations
- Batch file tracking
- Progress monitoring
"""
def create(
self,
admin_token: str,
filename: str,
file_size: int,
upload_source: str = "ui",
) -> BatchUpload:
"""Create a new batch upload record."""
with get_session_context() as session:
batch = BatchUpload(
admin_token=admin_token,
filename=filename,
file_size=file_size,
upload_source=upload_source,
)
session.add(batch)
session.commit()
session.refresh(batch)
session.expunge(batch)
return batch
def get(self, batch_id: UUID) -> BatchUpload | None:
"""Get batch upload by ID."""
with get_session_context() as session:
result = session.get(BatchUpload, batch_id)
if result:
session.expunge(result)
return result
def update(
self,
batch_id: UUID,
**kwargs: Any,
) -> None:
"""Update batch upload fields."""
with get_session_context() as session:
batch = session.get(BatchUpload, batch_id)
if batch:
for key, value in kwargs.items():
if hasattr(batch, key):
setattr(batch, key, value)
session.add(batch)
def create_file(
self,
batch_id: UUID,
filename: str,
**kwargs: Any,
) -> BatchUploadFile:
"""Create a batch upload file record."""
with get_session_context() as session:
file_record = BatchUploadFile(
batch_id=batch_id,
filename=filename,
**kwargs,
)
session.add(file_record)
session.commit()
session.refresh(file_record)
session.expunge(file_record)
return file_record
def update_file(
self,
file_id: UUID,
**kwargs: Any,
) -> None:
"""Update batch upload file fields."""
with get_session_context() as session:
file_record = session.get(BatchUploadFile, file_id)
if file_record:
for key, value in kwargs.items():
if hasattr(file_record, key):
setattr(file_record, key, value)
session.add(file_record)
def get_files(self, batch_id: UUID) -> list[BatchUploadFile]:
"""Get all files for a batch upload."""
with get_session_context() as session:
statement = select(BatchUploadFile).where(
BatchUploadFile.batch_id == batch_id
).order_by(BatchUploadFile.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_paginated(
self,
admin_token: str | None = None,
limit: int = 50,
offset: int = 0,
) -> tuple[list[BatchUpload], int]:
"""Get paginated batch uploads."""
with get_session_context() as session:
count_stmt = select(func.count()).select_from(BatchUpload)
total = session.exec(count_stmt).one()
statement = select(BatchUpload).order_by(
BatchUpload.created_at.desc()
).offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
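
A sketch of the intended progress-tracking flow; the status values and the batch_id/file_id attribute names are assumptions, since the BatchUpload and BatchUploadFile columns are not shown in this diff:

from inference.data.repositories import BatchUploadRepository

batches = BatchUploadRepository()
batch = batches.create(
    admin_token="uploader-token",          # placeholder token
    filename="invoices.zip",
    file_size=10_485_760,
)

# Track each extracted file, then roll the outcome up to the batch record.
file_record = batches.create_file(batch.batch_id, filename="invoice-001.pdf")
batches.update_file(file_record.file_id, status="processed")   # assumed column name/value
batches.update(batch.batch_id, status="completed")             # assumed column name/value

recent, total = batches.get_paginated(limit=10)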


@@ -0,0 +1,208 @@
"""
Dataset Repository
Handles training dataset operations, following the Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
from inference.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class DatasetRepository(BaseRepository[TrainingDataset]):
"""Repository for training dataset management.
Handles:
- Dataset CRUD operations
- Dataset status management
- Dataset document linking
- Training status tracking
"""
def create(
self,
name: str,
description: str | None = None,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
) -> TrainingDataset:
"""Create a new training dataset."""
with get_session_context() as session:
dataset = TrainingDataset(
name=name,
description=description,
train_ratio=train_ratio,
val_ratio=val_ratio,
seed=seed,
)
session.add(dataset)
session.commit()
session.refresh(dataset)
session.expunge(dataset)
return dataset
def get(self, dataset_id: str | UUID) -> TrainingDataset | None:
"""Get a dataset by ID."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if dataset:
session.expunge(dataset)
return dataset
def get_paginated(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[TrainingDataset], int]:
"""List datasets with optional status filter."""
with get_session_context() as session:
query = select(TrainingDataset)
count_query = select(func.count()).select_from(TrainingDataset)
if status:
query = query.where(TrainingDataset.status == status)
count_query = count_query.where(TrainingDataset.status == status)
total = session.exec(count_query).one()
datasets = session.exec(
query.order_by(TrainingDataset.created_at.desc()).offset(offset).limit(limit)
).all()
for d in datasets:
session.expunge(d)
return list(datasets), total
def get_active_training_tasks(
self, dataset_ids: list[str]
) -> dict[str, dict[str, str]]:
"""Get active training tasks for datasets.
Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
"""
if not dataset_ids:
return {}
valid_uuids = []
for d in dataset_ids:
try:
valid_uuids.append(UUID(d))
except ValueError:
logger.warning("Invalid UUID in get_active_training_tasks: %s", d)
continue
if not valid_uuids:
return {}
with get_session_context() as session:
statement = select(TrainingTask).where(
TrainingTask.dataset_id.in_(valid_uuids),
TrainingTask.status.in_(["pending", "scheduled", "running"]),
)
results = session.exec(statement).all()
return {
str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
for t in results
}
def update_status(
self,
dataset_id: str | UUID,
status: str,
error_message: str | None = None,
total_documents: int | None = None,
total_images: int | None = None,
total_annotations: int | None = None,
dataset_path: str | None = None,
) -> None:
"""Update dataset status and optional totals."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.status = status
dataset.updated_at = datetime.utcnow()
if error_message is not None:
dataset.error_message = error_message
if total_documents is not None:
dataset.total_documents = total_documents
if total_images is not None:
dataset.total_images = total_images
if total_annotations is not None:
dataset.total_annotations = total_annotations
if dataset_path is not None:
dataset.dataset_path = dataset_path
session.add(dataset)
session.commit()
def update_training_status(
self,
dataset_id: str | UUID,
training_status: str | None,
active_training_task_id: str | UUID | None = None,
update_main_status: bool = False,
) -> None:
"""Update dataset training status."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.training_status = training_status
dataset.active_training_task_id = (
UUID(str(active_training_task_id)) if active_training_task_id else None
)
dataset.updated_at = datetime.utcnow()
if update_main_status and training_status == "completed":
dataset.status = "trained"
session.add(dataset)
session.commit()
def add_documents(
self,
dataset_id: str | UUID,
documents: list[dict[str, Any]],
) -> None:
"""Batch insert documents into a dataset.
Each dict: {document_id, split, page_count, annotation_count}
"""
with get_session_context() as session:
for doc in documents:
dd = DatasetDocument(
dataset_id=UUID(str(dataset_id)),
document_id=UUID(str(doc["document_id"])),
split=doc["split"],
page_count=doc.get("page_count", 0),
annotation_count=doc.get("annotation_count", 0),
)
session.add(dd)
session.commit()
def get_documents(self, dataset_id: str | UUID) -> list[DatasetDocument]:
"""Get all documents in a dataset."""
with get_session_context() as session:
results = session.exec(
select(DatasetDocument)
.where(DatasetDocument.dataset_id == UUID(str(dataset_id)))
).all()
for r in results:
session.expunge(r)
return list(results)
def delete(self, dataset_id: str | UUID) -> bool:
"""Delete a dataset and its document links."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return False
session.delete(dataset)
session.commit()
return True
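
A sketch of assembling a dataset from labeled documents; the document IDs are placeholders, and the "ready" status string and the dataset_id attribute name are assumptions mirroring the repository's parameters:

from inference.data.repositories import DatasetRepository

datasets = DatasetRepository()
dataset = datasets.create(name="invoices-2026-01", description="January labeling batch")

# Link documents with their split assignment, then record the final counts.
datasets.add_documents(dataset.dataset_id, [
    {"document_id": "11111111-1111-1111-1111-111111111111", "split": "train", "page_count": 2, "annotation_count": 9},
    {"document_id": "22222222-2222-2222-2222-222222222222", "split": "val"},
])
datasets.update_status(dataset.dataset_id, status="ready", total_documents=2)  # assumed status value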


@@ -0,0 +1,444 @@
"""
Document Repository
Handles document operations, following the Single Responsibility Principle.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument, AdminAnnotation
from inference.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class DocumentRepository(BaseRepository[AdminDocument]):
"""Repository for document management.
Handles:
- Document CRUD operations
- Document status management
- Document filtering and pagination
- Document category management
"""
def create(
self,
filename: str,
file_size: int,
content_type: str,
file_path: str,
page_count: int = 1,
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
category: str = "invoice",
admin_token: str | None = None,
) -> str:
"""Create a new document record.
Args:
filename: Original filename
file_size: File size in bytes
content_type: MIME type
file_path: Storage path
page_count: Number of pages
upload_source: Upload source (ui/api)
csv_field_values: CSV field values for reference
group_key: User-defined grouping key
category: Document category
admin_token: Deprecated, kept for compatibility
Returns:
Document ID as string
"""
with get_session_context() as session:
document = AdminDocument(
filename=filename,
file_size=file_size,
content_type=content_type,
file_path=file_path,
page_count=page_count,
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
category=category,
)
session.add(document)
session.flush()
return str(document.document_id)
def get(self, document_id: str) -> AdminDocument | None:
"""Get a document by ID.
Args:
document_id: Document UUID as string
Returns:
AdminDocument if found, None otherwise
"""
with get_session_context() as session:
result = session.get(AdminDocument, UUID(document_id))
if result:
session.expunge(result)
return result
def get_by_token(
self,
document_id: str,
admin_token: str | None = None,
) -> AdminDocument | None:
"""Get a document by ID. Token parameter is deprecated."""
return self.get(document_id)
def get_paginated(
self,
admin_token: str | None = None,
status: str | None = None,
upload_source: str | None = None,
has_annotations: bool | None = None,
auto_label_status: str | None = None,
batch_id: str | None = None,
category: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AdminDocument], int]:
"""Get paginated documents with optional filters.
Args:
admin_token: Deprecated, kept for compatibility
status: Filter by status
upload_source: Filter by upload source
has_annotations: Filter by annotation presence
auto_label_status: Filter by auto-label status
batch_id: Filter by batch ID
category: Filter by category
limit: Page size
offset: Pagination offset
Returns:
Tuple of (documents, total_count)
"""
with get_session_context() as session:
where_clauses = []
if status:
where_clauses.append(AdminDocument.status == status)
if upload_source:
where_clauses.append(AdminDocument.upload_source == upload_source)
if auto_label_status:
where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
if batch_id:
where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
if category:
where_clauses.append(AdminDocument.category == category)
count_stmt = select(func.count()).select_from(AdminDocument)
if where_clauses:
count_stmt = count_stmt.where(*where_clauses)
if has_annotations is not None:
if has_annotations:
count_stmt = (
count_stmt
.join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
.group_by(AdminDocument.document_id)
)
else:
count_stmt = (
count_stmt
.outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
.where(AdminAnnotation.annotation_id.is_(None))
)
total = session.exec(count_stmt).one()
statement = select(AdminDocument)
if where_clauses:
statement = statement.where(*where_clauses)
if has_annotations is not None:
if has_annotations:
statement = (
statement
.join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
.group_by(AdminDocument.document_id)
)
else:
statement = (
statement
.outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
.where(AdminAnnotation.annotation_id.is_(None))
)
statement = statement.order_by(AdminDocument.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
def update_status(
self,
document_id: str,
status: str,
auto_label_status: str | None = None,
auto_label_error: str | None = None,
) -> None:
"""Update document status.
Args:
document_id: Document UUID as string
status: New status
auto_label_status: Auto-label status
auto_label_error: Auto-label error message
"""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.status = status
document.updated_at = datetime.now(timezone.utc)
if auto_label_status is not None:
document.auto_label_status = auto_label_status
if auto_label_error is not None:
document.auto_label_error = auto_label_error
session.add(document)
def update_file_path(self, document_id: str, file_path: str) -> None:
"""Update document file path."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.file_path = file_path
document.updated_at = datetime.now(timezone.utc)
session.add(document)
def update_group_key(self, document_id: str, group_key: str | None) -> bool:
"""Update document group key."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.group_key = group_key
document.updated_at = datetime.now(timezone.utc)
session.add(document)
return True
return False
def update_category(self, document_id: str, category: str) -> AdminDocument | None:
"""Update document category."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.category = category
document.updated_at = datetime.now(timezone.utc)
session.add(document)
session.commit()
session.refresh(document)
return document
return None
def delete(self, document_id: str) -> bool:
"""Delete a document and its annotations.
Args:
document_id: Document UUID as string
Returns:
True if deleted, False if not found
"""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
ann_stmt = select(AdminAnnotation).where(
AdminAnnotation.document_id == UUID(document_id)
)
annotations = session.exec(ann_stmt).all()
for ann in annotations:
session.delete(ann)
session.delete(document)
return True
return False
def get_categories(self) -> list[str]:
"""Get list of unique document categories."""
with get_session_context() as session:
statement = (
select(AdminDocument.category)
.distinct()
.order_by(AdminDocument.category)
)
categories = session.exec(statement).all()
return [c for c in categories if c is not None]
def get_labeled_for_export(
self,
admin_token: str | None = None,
) -> list[AdminDocument]:
"""Get all labeled documents ready for export."""
with get_session_context() as session:
statement = select(AdminDocument).where(
AdminDocument.status == "labeled"
)
if admin_token:
statement = statement.where(AdminDocument.admin_token == admin_token)
statement = statement.order_by(AdminDocument.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def count_by_status(
self,
admin_token: str | None = None,
) -> dict[str, int]:
"""Count documents by status."""
with get_session_context() as session:
statement = select(
AdminDocument.status,
func.count(AdminDocument.document_id),
).group_by(AdminDocument.status)
results = session.exec(statement).all()
return {status: count for status, count in results}
def get_by_ids(self, document_ids: list[str]) -> list[AdminDocument]:
"""Get documents by list of IDs."""
with get_session_context() as session:
uuids = [UUID(str(did)) for did in document_ids]
results = session.exec(
select(AdminDocument).where(AdminDocument.document_id.in_(uuids))
).all()
for r in results:
session.expunge(r)
return list(results)
def get_for_training(
self,
admin_token: str | None = None,
status: str = "labeled",
has_annotations: bool = True,
min_annotation_count: int | None = None,
exclude_used_in_training: bool = False,
limit: int = 100,
offset: int = 0,
) -> tuple[list[AdminDocument], int]:
"""Get documents suitable for training with filtering."""
from inference.data.admin_models import TrainingDocumentLink
with get_session_context() as session:
statement = select(AdminDocument).where(
AdminDocument.status == status,
)
if has_annotations or min_annotation_count:
annotation_subq = (
select(func.count(AdminAnnotation.annotation_id))
.where(AdminAnnotation.document_id == AdminDocument.document_id)
.correlate(AdminDocument)
.scalar_subquery()
)
if has_annotations:
statement = statement.where(annotation_subq > 0)
if min_annotation_count:
statement = statement.where(annotation_subq >= min_annotation_count)
if exclude_used_in_training:
from sqlalchemy import exists
training_subq = exists(
select(1)
.select_from(TrainingDocumentLink)
.where(TrainingDocumentLink.document_id == AdminDocument.document_id)
)
statement = statement.where(~training_subq)
count_statement = select(func.count()).select_from(statement.subquery())
total = session.exec(count_statement).one()
statement = statement.order_by(AdminDocument.created_at.desc())
statement = statement.limit(limit).offset(offset)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
def acquire_annotation_lock(
self,
document_id: str,
admin_token: str | None = None,
duration_seconds: int = 300,
) -> AdminDocument | None:
"""Acquire annotation lock for a document."""
from datetime import timedelta
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if not doc:
return None
now = datetime.now(timezone.utc)
if doc.annotation_lock_until and doc.annotation_lock_until > now:
return None
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
session.add(doc)
session.commit()
session.refresh(doc)
session.expunge(doc)
return doc
def release_annotation_lock(
self,
document_id: str,
admin_token: str | None = None,
force: bool = False,
) -> AdminDocument | None:
"""Release annotation lock for a document."""
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if not doc:
return None
doc.annotation_lock_until = None
session.add(doc)
session.commit()
session.refresh(doc)
session.expunge(doc)
return doc
def extend_annotation_lock(
self,
document_id: str,
admin_token: str | None = None,
additional_seconds: int = 300,
) -> AdminDocument | None:
"""Extend an existing annotation lock."""
from datetime import timedelta
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if not doc:
return None
now = datetime.now(timezone.utc)
if not doc.annotation_lock_until or doc.annotation_lock_until <= now:
return None
doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds)
session.add(doc)
session.commit()
session.refresh(doc)
session.expunge(doc)
return doc
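
A sketch of the annotation-lock lifecycle implemented by the last three methods; the document ID is a placeholder:

from inference.data.repositories import DocumentRepository

docs = DocumentRepository()
doc_id = "33333333-3333-3333-3333-333333333333"  # placeholder document ID

locked = docs.acquire_annotation_lock(doc_id, duration_seconds=300)
if locked is None:
    # Either the document does not exist or another editor holds an unexpired lock.
    pass
else:
    docs.extend_annotation_lock(doc_id, additional_seconds=120)  # only extends a still-active lock
    docs.release_annotation_lock(doc_id)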


@@ -0,0 +1,200 @@
"""
Model Version Repository
Handles model version operations, following the Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import ModelVersion
from inference.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class ModelVersionRepository(BaseRepository[ModelVersion]):
"""Repository for model version management.
Handles:
- Model version CRUD operations
- Model activation/deactivation
- Active model resolution
"""
def create(
self,
version: str,
name: str,
model_path: str,
description: str | None = None,
task_id: str | UUID | None = None,
dataset_id: str | UUID | None = None,
metrics_mAP: float | None = None,
metrics_precision: float | None = None,
metrics_recall: float | None = None,
document_count: int = 0,
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
model = ModelVersion(
version=version,
name=name,
model_path=model_path,
description=description,
task_id=UUID(str(task_id)) if task_id else None,
dataset_id=UUID(str(dataset_id)) if dataset_id else None,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
)
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def get(self, version_id: str | UUID) -> ModelVersion | None:
"""Get a model version by ID."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if model:
session.expunge(model)
return model
def get_paginated(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[ModelVersion], int]:
"""List model versions with optional status filter."""
with get_session_context() as session:
query = select(ModelVersion)
count_query = select(func.count()).select_from(ModelVersion)
if status:
query = query.where(ModelVersion.status == status)
count_query = count_query.where(ModelVersion.status == status)
total = session.exec(count_query).one()
models = session.exec(
query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
).all()
for m in models:
session.expunge(m)
return list(models), total
def get_active(self) -> ModelVersion | None:
"""Get the currently active model version for inference."""
with get_session_context() as session:
result = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).first()
if result:
session.expunge(result)
return result
def activate(self, version_id: str | UUID) -> ModelVersion | None:
"""Activate a model version for inference (deactivates all others)."""
with get_session_context() as session:
all_versions = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).all()
for v in all_versions:
v.is_active = False
v.status = "inactive"
v.updated_at = datetime.utcnow()
session.add(v)
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = True
model.status = "active"
model.activated_at = datetime.utcnow()
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def deactivate(self, version_id: str | UUID) -> ModelVersion | None:
"""Deactivate a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = False
model.status = "inactive"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def update(
self,
version_id: str | UUID,
name: str | None = None,
description: str | None = None,
status: str | None = None,
) -> ModelVersion | None:
"""Update model version metadata."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
if name is not None:
model.name = name
if description is not None:
model.description = description
if status is not None:
model.status = status
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def archive(self, version_id: str | UUID) -> ModelVersion | None:
"""Archive a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
if model.is_active:
return None
model.status = "archived"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def delete(self, version_id: str | UUID) -> bool:
"""Delete a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return False
if model.is_active:
return False
session.delete(model)
session.commit()
return True
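
A sketch of promoting a freshly trained model to inference; the version string, path, and metric value are illustrative, and the version_id attribute name is assumed to mirror the repository's parameter name:

from inference.data.repositories import ModelVersionRepository

models = ModelVersionRepository()
candidate = models.create(
    version="v1.4.0",
    name="yolo-invoice-fields",
    model_path="models/v1.4.0/best.pt",
    metrics_mAP=0.91,
)
models.activate(candidate.version_id)   # deactivates every other version first
active = models.get_active()            # single source of truth for inference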


@@ -0,0 +1,117 @@
"""
Token Repository
Handles admin token operations, following the Single Responsibility Principle.
"""
import logging
from datetime import datetime
from inference.data.admin_models import AdminToken
from inference.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class TokenRepository(BaseRepository[AdminToken]):
"""Repository for admin token management.
Handles:
- Token validation (active status, expiration)
- Token CRUD operations
- Usage tracking
"""
def is_valid(self, token: str) -> bool:
"""Check if admin token exists and is active.
Args:
token: The token string to validate
Returns:
True if token exists, is active, and not expired
"""
with self._session() as session:
result = session.get(AdminToken, token)
if result is None:
return False
if not result.is_active:
return False
if result.expires_at and result.expires_at < self._now():
return False
return True
def get(self, token: str) -> AdminToken | None:
"""Get admin token details.
Args:
token: The token string
Returns:
AdminToken if found, None otherwise
"""
with self._session() as session:
result = session.get(AdminToken, token)
if result:
session.expunge(result)
return result
def create(
self,
token: str,
name: str,
expires_at: datetime | None = None,
) -> None:
"""Create or update an admin token.
If the token already exists, updates its name and expires_at and reactivates it.
Otherwise creates a new token.
Args:
token: The token string
name: Display name for the token
expires_at: Optional expiration datetime
"""
with self._session() as session:
existing = session.get(AdminToken, token)
if existing:
existing.name = name
existing.expires_at = expires_at
existing.is_active = True
session.add(existing)
else:
new_token = AdminToken(
token=token,
name=name,
expires_at=expires_at,
)
session.add(new_token)
def update_usage(self, token: str) -> None:
"""Update admin token last used timestamp.
Args:
token: The token string
"""
with self._session() as session:
admin_token = session.get(AdminToken, token)
if admin_token:
admin_token.last_used_at = self._now()
session.add(admin_token)
def deactivate(self, token: str) -> bool:
"""Deactivate an admin token.
Args:
token: The token string
Returns:
True if token was deactivated, False if not found
"""
with self._session() as session:
admin_token = session.get(AdminToken, token)
if admin_token:
admin_token.is_active = False
session.add(admin_token)
return True
return False
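
A sketch of the token lifecycle; the token value below is a placeholder, not a real credential:

from datetime import datetime, timedelta, timezone

from inference.data.repositories import TokenRepository

tokens = TokenRepository()
tokens.create(
    token="example-admin-token",
    name="CI pipeline",
    expires_at=datetime.now(timezone.utc) + timedelta(days=30),
)
if tokens.is_valid("example-admin-token"):
    tokens.update_usage("example-admin-token")  # record last_used_at
tokens.deactivate("example-admin-token")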


@@ -0,0 +1,233 @@
"""
Training Task Repository
Handles training task operations, following the Single Responsibility Principle.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
from inference.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class TrainingTaskRepository(BaseRepository[TrainingTask]):
"""Repository for training task management.
Handles:
- Training task CRUD operations
- Task status management
- Training logs
- Training document links
"""
def create(
self,
admin_token: str,
name: str,
task_type: str = "train",
description: str | None = None,
config: dict[str, Any] | None = None,
scheduled_at: datetime | None = None,
cron_expression: str | None = None,
is_recurring: bool = False,
dataset_id: str | None = None,
) -> str:
"""Create a new training task.
Returns:
Task ID as string
"""
with get_session_context() as session:
task = TrainingTask(
admin_token=admin_token,
name=name,
task_type=task_type,
description=description,
config=config,
scheduled_at=scheduled_at,
cron_expression=cron_expression,
is_recurring=is_recurring,
status="scheduled" if scheduled_at else "pending",
dataset_id=dataset_id,
)
session.add(task)
session.flush()
return str(task.task_id)
def get(self, task_id: str) -> TrainingTask | None:
"""Get a training task by ID."""
with get_session_context() as session:
result = session.get(TrainingTask, UUID(task_id))
if result:
session.expunge(result)
return result
def get_by_token(
self,
task_id: str,
admin_token: str | None = None,
) -> TrainingTask | None:
"""Get a training task by ID. Token parameter is deprecated."""
return self.get(task_id)
def get_paginated(
self,
admin_token: str | None = None,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[TrainingTask], int]:
"""Get paginated training tasks."""
with get_session_context() as session:
count_stmt = select(func.count()).select_from(TrainingTask)
if status:
count_stmt = count_stmt.where(TrainingTask.status == status)
total = session.exec(count_stmt).one()
statement = select(TrainingTask)
if status:
statement = statement.where(TrainingTask.status == status)
statement = statement.order_by(TrainingTask.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results), total
def get_pending(self) -> list[TrainingTask]:
"""Get pending training tasks ready to run."""
with get_session_context() as session:
now = datetime.utcnow()
statement = select(TrainingTask).where(
TrainingTask.status.in_(["pending", "scheduled"]),
(TrainingTask.scheduled_at == None) | (TrainingTask.scheduled_at <= now),
).order_by(TrainingTask.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def update_status(
self,
task_id: str,
status: str,
error_message: str | None = None,
result_metrics: dict[str, Any] | None = None,
model_path: str | None = None,
) -> None:
"""Update training task status."""
with get_session_context() as session:
task = session.get(TrainingTask, UUID(task_id))
if task:
task.status = status
task.updated_at = datetime.utcnow()
if status == "running":
task.started_at = datetime.utcnow()
elif status in ("completed", "failed"):
task.completed_at = datetime.utcnow()
if error_message is not None:
task.error_message = error_message
if result_metrics is not None:
task.result_metrics = result_metrics
if model_path is not None:
task.model_path = model_path
session.add(task)
def cancel(self, task_id: str) -> bool:
"""Cancel a training task."""
with get_session_context() as session:
task = session.get(TrainingTask, UUID(task_id))
if task and task.status in ("pending", "scheduled"):
task.status = "cancelled"
task.updated_at = datetime.utcnow()
session.add(task)
return True
return False
def add_log(
self,
task_id: str,
level: str,
message: str,
details: dict[str, Any] | None = None,
) -> None:
"""Add a training log entry."""
with get_session_context() as session:
log = TrainingLog(
task_id=UUID(task_id),
level=level,
message=message,
details=details,
)
session.add(log)
def get_logs(
self,
task_id: str,
limit: int = 100,
offset: int = 0,
) -> list[TrainingLog]:
"""Get training logs for a task."""
with get_session_context() as session:
statement = select(TrainingLog).where(
TrainingLog.task_id == UUID(task_id)
).order_by(TrainingLog.created_at.desc()).offset(offset).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def create_document_link(
self,
task_id: UUID,
document_id: UUID,
annotation_snapshot: dict[str, Any] | None = None,
) -> TrainingDocumentLink:
"""Create a training document link."""
with get_session_context() as session:
link = TrainingDocumentLink(
task_id=task_id,
document_id=document_id,
annotation_snapshot=annotation_snapshot,
)
session.add(link)
session.commit()
session.refresh(link)
session.expunge(link)
return link
def get_document_links(self, task_id: UUID) -> list[TrainingDocumentLink]:
"""Get all document links for a training task."""
with get_session_context() as session:
statement = select(TrainingDocumentLink).where(
TrainingDocumentLink.task_id == task_id
).order_by(TrainingDocumentLink.created_at)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_document_training_tasks(self, document_id: UUID) -> list[TrainingDocumentLink]:
"""Get all training tasks that used this document."""
with get_session_context() as session:
statement = select(TrainingDocumentLink).where(
TrainingDocumentLink.document_id == document_id
).order_by(TrainingDocumentLink.created_at.desc())
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
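
A sketch of a worker driving a task through its states; the config keys and metric names are assumptions about what the trainer consumes, not fields defined in this diff:

from inference.data.repositories import TrainingTaskRepository

tasks = TrainingTaskRepository()
task_id = tasks.create(
    admin_token="scheduler-token",                 # placeholder token
    name="nightly-train",
    config={"epochs": 50, "imgsz": 640},           # assumed config keys
    dataset_id="44444444-4444-4444-4444-444444444444",
)
tasks.update_status(task_id, status="running")
tasks.add_log(task_id, level="info", message="Training started")
tasks.update_status(
    task_id,
    status="completed",
    result_metrics={"mAP50": 0.91},                # assumed metric key
    model_path="models/nightly/best.pt",
)
for log in tasks.get_logs(task_id, limit=20):
    print(log.level, log.message)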


@@ -11,11 +11,11 @@ Enhanced features:
- Smart amount parsing with multiple strategies
- Enhanced date format unification
- OCR error correction integration
Refactored to use modular normalizers for each field type.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from collections import defaultdict
import re
import numpy as np
@@ -25,15 +25,22 @@ from shared.fields import CLASS_TO_FIELD
from .yolo_detector import Detection
# Import shared utilities for text cleaning and validation
from shared.utils.text_cleaner import TextCleaner
from shared.utils.validators import FieldValidators
from shared.utils.fuzzy_matcher import FuzzyMatcher
from shared.utils.ocr_corrections import OCRCorrections
# Import new unified parsers
from .payment_line_parser import PaymentLineParser
from .customer_number_parser import CustomerNumberParser
# Import normalizers
from .normalizers import (
BaseNormalizer,
NormalizationResult,
create_normalizer_registry,
EnhancedAmountNormalizer,
EnhancedDateNormalizer,
)
@dataclass
class ExtractedField:
@@ -80,7 +87,8 @@ class FieldExtractor:
ocr_lang: str = 'en',
use_gpu: bool = False,
bbox_padding: float = 0.1,
dpi: int = 300
dpi: int = 300,
use_enhanced_parsing: bool = False
):
"""
Initialize field extractor.
@@ -90,17 +98,22 @@ class FieldExtractor:
use_gpu: Whether to use GPU for OCR
bbox_padding: Padding to add around bboxes (as fraction)
dpi: DPI used for rendering (for coordinate conversion)
use_enhanced_parsing: Whether to use enhanced normalizers
"""
self.ocr_lang = ocr_lang
self.use_gpu = use_gpu
self.bbox_padding = bbox_padding
self.dpi = dpi
self._ocr_engine = None # Lazy init
self.use_enhanced_parsing = use_enhanced_parsing
# Initialize new unified parsers
self.payment_line_parser = PaymentLineParser()
self.customer_number_parser = CustomerNumberParser()
# Initialize normalizer registry
self._normalizers = create_normalizer_registry(use_enhanced=use_enhanced_parsing)
@property
def ocr_engine(self):
"""Lazy-load OCR engine only when needed."""
@@ -246,6 +259,9 @@ class FieldExtractor:
"""
Normalize and validate extracted text for a field.
Uses modular normalizers for each field type.
Falls back to legacy methods for payment_line and customer_number.
Returns:
(normalized_value, is_valid, validation_error)
"""
@@ -254,389 +270,21 @@ class FieldExtractor:
if not text:
return None, False, "Empty text"
if field_name == 'InvoiceNumber':
return self._normalize_invoice_number(text)
elif field_name == 'OCR':
return self._normalize_ocr_number(text)
elif field_name == 'Bankgiro':
return self._normalize_bankgiro(text)
elif field_name == 'Plusgiro':
return self._normalize_plusgiro(text)
elif field_name == 'Amount':
return self._normalize_amount(text)
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
return self._normalize_date(text)
elif field_name == 'payment_line':
# Special handling for payment_line and customer_number (use unified parsers)
if field_name == 'payment_line':
return self._normalize_payment_line(text)
elif field_name == 'supplier_org_number':
return self._normalize_supplier_org_number(text)
elif field_name == 'customer_number':
if field_name == 'customer_number':
return self._normalize_customer_number(text)
else:
return text, True, None
# Use normalizer registry for other fields
normalizer = self._normalizers.get(field_name)
if normalizer:
result = normalizer.normalize(text)
return result.to_tuple()
def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize invoice number.
Invoice numbers can be:
- Pure digits: 12345678
- Alphanumeric: A3861, INV-2024-001, F12345
- With separators: 2024/001, 2024-001
Strategy:
1. Look for common invoice number patterns
2. Prefer shorter, more specific matches over long digit sequences
"""
# Pattern 1: Alphanumeric invoice number (letter + digits or digits + letter)
# Examples: A3861, F12345, INV001
alpha_patterns = [
r'\b([A-Z]{1,3}\d{3,10})\b', # A3861, INV12345
r'\b(\d{3,10}[A-Z]{1,3})\b', # 12345A
r'\b([A-Z]{2,5}[-/]?\d{3,10})\b', # INV-12345, FAK12345
]
for pattern in alpha_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
return match.group(1).upper(), True, None
# Pattern 2: Invoice number with year prefix (2024-001, 2024/12345)
year_pattern = r'\b(20\d{2}[-/]\d{3,8})\b'
match = re.search(year_pattern, text)
if match:
return match.group(1), True, None
# Pattern 3: Short digit sequence (3-10 digits) - prefer shorter sequences
# This avoids capturing long OCR numbers
digit_sequences = re.findall(r'\b(\d{3,10})\b', text)
if digit_sequences:
# Prefer shorter sequences (more likely to be invoice number)
# Also filter out sequences that look like dates (8 digits starting with 20)
valid_sequences = []
for seq in digit_sequences:
# Skip if it looks like a date (YYYYMMDD)
if len(seq) == 8 and seq.startswith('20'):
continue
# Skip if too long (likely OCR number)
if len(seq) > 10:
continue
valid_sequences.append(seq)
if valid_sequences:
# Return shortest valid sequence
return min(valid_sequences, key=len), True, None
# Fallback: extract all digits if nothing else works
digits = re.sub(r'\D', '', text)
if len(digits) >= 3:
# Limit to first 15 digits to avoid very long sequences
return digits[:15], True, "Fallback extraction"
return None, False, f"Cannot extract invoice number from: {text[:50]}"
def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize OCR number."""
digits = re.sub(r'\D', '', text)
if len(digits) < 5:
return None, False, f"Too few digits for OCR: {len(digits)}"
return digits, True, None
def _luhn_checksum(self, digits: str) -> bool:
"""
Validate using Luhn (Mod10) algorithm.
Used for Bankgiro, Plusgiro, and OCR number validation.
Delegates to shared FieldValidators for consistency.
"""
return FieldValidators.luhn_checksum(digits)
def _detect_giro_type(self, text: str) -> str | None:
"""
Detect if text matches BG or PG display format pattern.
BG typical format: ^\d{3,4}-\d{4}$ (e.g., 123-4567, 1234-5678)
PG typical format: ^\d{1,7}-\d$ (e.g., 1-8, 12345-6, 1234567-8)
Returns: 'BG', 'PG', or None if the type cannot be determined
"""
text = text.strip()
# BG pattern: 3-4 digits, dash, 4 digits (total 7-8 digits)
if re.match(r'^\d{3,4}-\d{4}$', text):
return 'BG'
# PG pattern: 1-7 digits, dash, 1 digit (total 2-8 digits)
if re.match(r'^\d{1,7}-\d$', text):
return 'PG'
return None
def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize Bankgiro number.
Bankgiro rules:
- 7 or 8 digits only
- Last digit is Luhn (Mod10) check digit
- Display format: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits)
Display pattern: ^\d{3,4}-\d{4}$
Normalized pattern: ^\d{7,8}$
Note: Text may contain both BG and PG numbers. We specifically look for
BG display format (XXX-XXXX or XXXX-XXXX) to extract the correct one.
"""
# Look for BG display format pattern: 3-4 digits, dash, 4 digits
# This distinguishes BG from PG which uses X-X format (digits-single digit)
bg_matches = re.findall(r'(\d{3,4})-(\d{4})', text)
if bg_matches:
# Try each match and find one with valid Luhn
for match in bg_matches:
digits = match[0] + match[1]
if len(digits) in (7, 8) and self._luhn_checksum(digits):
# Valid BG found
if len(digits) == 8:
formatted = f"{digits[:4]}-{digits[4:]}"
else:
formatted = f"{digits[:3]}-{digits[3:]}"
return formatted, True, None
# No valid Luhn, use first match
digits = bg_matches[0][0] + bg_matches[0][1]
if len(digits) in (7, 8):
if len(digits) == 8:
formatted = f"{digits[:4]}-{digits[4:]}"
else:
formatted = f"{digits[:3]}-{digits[3:]}"
return formatted, True, f"Luhn checksum failed (possible OCR error)"
# Fallback: try to find 7-8 consecutive digits
# But first check if text contains PG format (XXXXXXX-X), if so don't use fallback
# to avoid misinterpreting PG as BG
pg_format_present = re.search(r'(?<![0-9])\d{1,7}-\d(?!\d)', text)
if pg_format_present:
return None, False, f"No valid Bankgiro found in text"
digit_match = re.search(r'\b(\d{7,8})\b', text)
if digit_match:
digits = digit_match.group(1)
if len(digits) in (7, 8):
luhn_ok = self._luhn_checksum(digits)
if len(digits) == 8:
formatted = f"{digits[:4]}-{digits[4:]}"
else:
formatted = f"{digits[:3]}-{digits[3:]}"
if luhn_ok:
return formatted, True, None
else:
return formatted, True, f"Luhn checksum failed (possible OCR error)"
return None, False, f"No valid Bankgiro found in text"
def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize Plusgiro number.
Plusgiro rules:
- 2 to 8 digits
- Last digit is Luhn (Mod10) check digit
- Display format: XXXXXXX-X (all digits except last, dash, last digit)
Display pattern: ^\d{1,7}-\d$
Normalized pattern: ^\d{2,8}$
Note: Text may contain both BG and PG numbers. We specifically look for
PG display format (X-X, XX-X, ..., XXXXXXX-X) to extract the correct one.
"""
# First look for PG display format: 1-7 digits (possibly with spaces), dash, 1 digit
# This is distinct from BG format which has 4 digits after the dash
# Pattern allows spaces within the number like "486 98 63-6"
# (?<![0-9]) ensures we don't start from within another number (like BG)
pg_matches = re.findall(r'(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)', text)
if pg_matches:
# Try each match and find one with valid Luhn
for match in pg_matches:
# Remove spaces from the first part
digits = re.sub(r'\s', '', match[0]) + match[1]
if 2 <= len(digits) <= 8 and self._luhn_checksum(digits):
# Valid PG found
formatted = f"{digits[:-1]}-{digits[-1]}"
return formatted, True, None
# No valid Luhn, use first match with most digits
best_match = max(pg_matches, key=lambda m: len(re.sub(r'\s', '', m[0])))
digits = re.sub(r'\s', '', best_match[0]) + best_match[1]
if 2 <= len(digits) <= 8:
formatted = f"{digits[:-1]}-{digits[-1]}"
return formatted, True, f"Luhn checksum failed (possible OCR error)"
# If no PG format found, extract all digits and format as PG
# This handles cases where the number might be in BG format or raw digits
all_digits = re.sub(r'\D', '', text)
# Try to find a valid 2-8 digit sequence
if 2 <= len(all_digits) <= 8:
luhn_ok = self._luhn_checksum(all_digits)
formatted = f"{all_digits[:-1]}-{all_digits[-1]}"
if luhn_ok:
return formatted, True, None
else:
return formatted, True, f"Luhn checksum failed (possible OCR error)"
# Try to find any 2-8 digit sequence in text
digit_match = re.search(r'\b(\d{2,8})\b', text)
if digit_match:
digits = digit_match.group(1)
luhn_ok = self._luhn_checksum(digits)
formatted = f"{digits[:-1]}-{digits[-1]}"
if luhn_ok:
return formatted, True, None
else:
return formatted, True, f"Luhn checksum failed (possible OCR error)"
return None, False, f"No valid Plusgiro found in text"
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize monetary amount.
Uses shared TextCleaner for preprocessing and FieldValidators for parsing.
If multiple amounts are found, returns the last one (usually the total).
"""
# Split by newlines and process line by line to get the last valid amount
lines = text.split('\n')
# Collect all valid amounts from all lines
all_amounts = []
# Pattern for Swedish amount format (with decimals)
amount_pattern = r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?'
for line in lines:
line = line.strip()
if not line:
continue
# Find all amounts in this line
matches = re.findall(amount_pattern, line, re.IGNORECASE)
for match in matches:
amount_str = match.replace(' ', '').replace(',', '.')
try:
amount = float(amount_str)
if amount > 0:
all_amounts.append(amount)
except ValueError:
continue
# Return the last amount found (usually the total)
if all_amounts:
return f"{all_amounts[-1]:.2f}", True, None
# Fallback: try shared validator on cleaned text
cleaned = TextCleaner.normalize_amount_text(text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and amount > 0:
return f"{amount:.2f}", True, None
# Try to find any decimal number
simple_pattern = r'(\d+[,\.]\d{2})'
matches = re.findall(simple_pattern, text)
if matches:
amount_str = matches[-1].replace(',', '.')
try:
amount = float(amount_str)
if amount > 0:
return f"{amount:.2f}", True, None
except ValueError:
pass
# Last resort: try to find integer amount (no decimals)
# Look for patterns like "Amount: 11699" or standalone numbers
int_pattern = r'(?:amount|belopp|summa|total)[:\s]*(\d+)'
match = re.search(int_pattern, text, re.IGNORECASE)
if match:
try:
amount = float(match.group(1))
if amount > 0:
return f"{amount:.2f}", True, None
except ValueError:
pass
# Very last resort: find any standalone number >= 3 digits
standalone_pattern = r'\b(\d{3,})\b'
matches = re.findall(standalone_pattern, text)
if matches:
# Take the last/largest number
try:
amount = float(matches[-1])
if amount > 0:
return f"{amount:.2f}", True, None
except ValueError:
pass
return None, False, f"Cannot parse amount: {text}"
def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize date from text that may contain surrounding text.
Uses shared FieldValidators for date parsing and validation.
Handles various date formats found in Swedish invoices:
- 2025-08-29 (ISO format)
- 2025.08.29 (dot separator)
- 29/08/2025 (European format)
- 29.08.2025 (European with dots)
- 20250829 (compact format)
"""
# First, try using shared validator
iso_date = FieldValidators.format_date_iso(text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return iso_date, True, None
# Fallback: try original patterns for edge cases
from datetime import datetime
patterns = [
# ISO format: 2025-08-29
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
# Dot format: 2025.08.29 (common in Swedish)
(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
# European slash format: 29/08/2025
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
# European dot format: 29.08.2025
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
# Compact format: 20250829
(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', lambda m: f"{m.group(1)}-{m.group(2)}-{m.group(3)}"),
]
for pattern, formatter in patterns:
match = re.search(pattern, text)
if match:
try:
date_str = formatter(match)
# Validate date
parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
# Sanity check: year should be reasonable (2000-2100)
if 2000 <= parsed_date.year <= 2100:
return date_str, True, None
except ValueError:
continue
return None, False, f"Cannot parse date: {text}"
# Fallback for unknown fields
return text, True, None
def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
"""
@@ -657,44 +305,6 @@ class FieldExtractor:
self.payment_line_parser.parse(text)
)
def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize Swedish supplier organization number.
Extracts organization number in format: NNNNNN-NNNN (10 digits)
Also handles VAT numbers: SE + 10 digits + 01
Examples:
'org.nr. 516406-1102, Filialregistret...' -> '516406-1102'
'Momsreg.nr SE556123456701' -> '556123-4567'
"""
# Pattern 1: Standard org number format: NNNNNN-NNNN
org_pattern = r'\b(\d{6})-?(\d{4})\b'
match = re.search(org_pattern, text)
if match:
org_num = f"{match.group(1)}-{match.group(2)}"
return org_num, True, None
# Pattern 2: VAT number format: SE + 10 digits + 01
vat_pattern = r'SE\s*(\d{10})01'
match = re.search(vat_pattern, text, re.IGNORECASE)
if match:
digits = match.group(1)
org_num = f"{digits[:6]}-{digits[6:]}"
return org_num, True, None
# Pattern 3: Just 10 consecutive digits
digits_pattern = r'\b(\d{10})\b'
match = re.search(digits_pattern, text)
if match:
digits = match.group(1)
# Validate: first digit should be 1-9 for Swedish org numbers
if digits[0] in '123456789':
org_num = f"{digits[:6]}-{digits[6:]}"
return org_num, True, None
return None, False, f"Cannot extract org number from: {text[:100]}"
def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize customer number text using unified CustomerNumberParser.
@@ -908,175 +518,6 @@ class FieldExtractor:
best = max(items, key=lambda x: x[1][0])
return best[0], best[1][1]
# =========================================================================
# Enhanced Amount Parsing
# =========================================================================
def _normalize_amount_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Enhanced amount parsing with multiple strategies.
Strategies:
1. Pattern matching for Swedish formats
2. Context-aware extraction (look for keywords like "Total", "Summa")
3. OCR error correction for common digit errors
4. Multi-amount handling (prefer last/largest as total)
This method replaces the original _normalize_amount when enhanced mode is enabled.
"""
# Strategy 1: Apply OCR corrections first
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Strategy 2: Look for labeled amounts (highest priority)
labeled_patterns = [
# Swedish patterns
(r'(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 1.0),
(r'(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 0.8), # Lower priority for VAT
# Generic pattern
(r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?', 0.7),
]
candidates = []
for pattern, priority in labeled_patterns:
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
amount_str = match.group(1).replace(' ', '').replace(',', '.')
try:
amount = float(amount_str)
if 0 < amount < 10_000_000: # Reasonable range
candidates.append((amount, priority, match.start()))
except ValueError:
continue
if candidates:
# Sort by priority (desc), then by position (later is usually total)
candidates.sort(key=lambda x: (-x[1], -x[2]))
best_amount = candidates[0][0]
return f"{best_amount:.2f}", True, None
# Strategy 3: Parse with shared validator
cleaned = TextCleaner.normalize_amount_text(corrected_text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and 0 < amount < 10_000_000:
return f"{amount:.2f}", True, None
# Strategy 4: Try to extract any decimal number as fallback
decimal_pattern = r'(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})'
matches = re.findall(decimal_pattern, corrected_text)
if matches:
# Clean and parse each match
amounts = []
for m in matches:
cleaned_m = m.replace(' ', '').replace('.', '').replace(',', '.')
# Handle Swedish format: "1 234,56" -> "1234.56"
if ',' in m and '.' not in m:
cleaned_m = m.replace(' ', '').replace(',', '.')
try:
amt = float(cleaned_m)
if 0 < amt < 10_000_000:
amounts.append(amt)
except ValueError:
continue
if amounts:
# Return the last/largest amount (usually the total)
return f"{max(amounts):.2f}", True, None
return None, False, f"Cannot parse amount: {text[:50]}"
# =========================================================================
# Enhanced Date Parsing
# =========================================================================
def _normalize_date_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Enhanced date parsing with comprehensive format support.
Supports:
- ISO: 2024-12-29, 2024/12/29
- European: 29.12.2024, 29/12/2024, 29-12-2024
- Swedish text: "29 december 2024", "29 dec 2024"
- Compact: 20241229
- With OCR corrections: 2O24-12-29 -> 2024-12-29
"""
from datetime import datetime
# Apply OCR corrections
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Try shared validator first
iso_date = FieldValidators.format_date_iso(corrected_text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return iso_date, True, None
# Swedish month names
swedish_months = {
'januari': 1, 'jan': 1,
'februari': 2, 'feb': 2,
'mars': 3, 'mar': 3,
'april': 4, 'apr': 4,
'maj': 5,
'juni': 6, 'jun': 6,
'juli': 7, 'jul': 7,
'augusti': 8, 'aug': 8,
'september': 9, 'sep': 9, 'sept': 9,
'oktober': 10, 'okt': 10,
'november': 11, 'nov': 11,
'december': 12, 'dec': 12,
}
# Pattern for Swedish text dates: "29 december 2024" or "29 dec 2024"
swedish_pattern = r'(\d{1,2})\s+([a-zåäö]+)\s+(\d{4})'
match = re.search(swedish_pattern, corrected_text.lower())
if match:
day = int(match.group(1))
month_name = match.group(2)
year = int(match.group(3))
if month_name in swedish_months:
month = swedish_months[month_name]
try:
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return dt.strftime('%Y-%m-%d'), True, None
except ValueError:
pass
# Extended patterns
patterns = [
# ISO format: 2025-08-29, 2025/08/29
(r'(\d{4})[-/](\d{1,2})[-/](\d{1,2})', 'ymd'),
# Dot format: 2025.08.29
(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', 'ymd'),
# European slash: 29/08/2025
(r'(\d{1,2})/(\d{1,2})/(\d{4})', 'dmy'),
# European dot: 29.08.2025
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', 'dmy'),
# European dash: 29-08-2025
(r'(\d{1,2})-(\d{1,2})-(\d{4})', 'dmy'),
# Compact: 20250829
(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', 'ymd_compact'),
]
for pattern, fmt in patterns:
match = re.search(pattern, corrected_text)
if match:
try:
if fmt == 'ymd':
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
elif fmt == 'dmy':
day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
elif fmt == 'ymd_compact':
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
else:
continue
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return dt.strftime('%Y-%m-%d'), True, None
except ValueError:
continue
return None, False, f"Cannot parse date: {text[:50]}"
# =========================================================================
# Apply OCR Corrections to Raw Text
# =========================================================================
@@ -1162,10 +603,15 @@ class FieldExtractor:
# Re-normalize with enhanced methods if corrections were applied
if corrections or base_result.normalized_value is None:
# Use enhanced normalizers for Amount and Date fields
if base_result.field_name == 'Amount':
normalized, is_valid, error = self._normalize_amount_enhanced(corrected_text)
enhanced_normalizer = EnhancedAmountNormalizer()
result = enhanced_normalizer.normalize(corrected_text)
normalized, is_valid, error = result.to_tuple()
elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
normalized, is_valid, error = self._normalize_date_enhanced(corrected_text)
enhanced_normalizer = EnhancedDateNormalizer()
result = enhanced_normalizer.normalize(corrected_text)
normalized, is_valid, error = result.to_tuple()
else:
# Re-run standard normalization with corrected text
normalized, is_valid, error = self._normalize_and_validate(

View File

@@ -0,0 +1,59 @@
"""
Normalizers Package
Provides field-specific normalizers for invoice data extraction.
Each normalizer handles a specific field type's normalization and validation.
"""
from .base import BaseNormalizer, NormalizationResult
from .invoice_number import InvoiceNumberNormalizer
from .ocr_number import OcrNumberNormalizer
from .bankgiro import BankgiroNormalizer
from .plusgiro import PlusgiroNormalizer
from .amount import AmountNormalizer, EnhancedAmountNormalizer
from .date import DateNormalizer, EnhancedDateNormalizer
from .supplier_org_number import SupplierOrgNumberNormalizer
__all__ = [
# Base
"BaseNormalizer",
"NormalizationResult",
# Normalizers
"InvoiceNumberNormalizer",
"OcrNumberNormalizer",
"BankgiroNormalizer",
"PlusgiroNormalizer",
"AmountNormalizer",
"EnhancedAmountNormalizer",
"DateNormalizer",
"EnhancedDateNormalizer",
"SupplierOrgNumberNormalizer",
]
# Registry of all normalizers by field name
def create_normalizer_registry(
use_enhanced: bool = False,
) -> dict[str, BaseNormalizer]:
"""
Create a registry mapping field names to normalizer instances.
Args:
use_enhanced: Whether to use enhanced normalizers for amount/date
Returns:
Dictionary mapping field names to normalizer instances
"""
amount_normalizer = EnhancedAmountNormalizer() if use_enhanced else AmountNormalizer()
date_normalizer = EnhancedDateNormalizer() if use_enhanced else DateNormalizer()
return {
"InvoiceNumber": InvoiceNumberNormalizer(),
"OCR": OcrNumberNormalizer(),
"Bankgiro": BankgiroNormalizer(),
"Plusgiro": PlusgiroNormalizer(),
"Amount": amount_normalizer,
"InvoiceDate": date_normalizer,
"InvoiceDueDate": date_normalizer,
"supplier_org_number": SupplierOrgNumberNormalizer(),
}
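A minimal usage sketch of the registry (the top-level import path is an assumption; only relative imports appear in this diff, and the expected values assume the shared validators behave as documented):
from normalizers import create_normalizer_registry  # assumed package path

registry = create_normalizer_registry(use_enhanced=True)

result = registry["Amount"].normalize("Att betala: 1 234,56 kr")
print(result.value, result.is_valid)  # expected: 1234.56 True

# Legacy call sites can keep their (value, is_valid, error) tuples;
# "765-4320" is a synthetic, Mod10-valid Bankgiro used only for illustration.
value, is_valid, error = registry["Bankgiro"].normalize("Bankgiro 765-4320").to_tuple()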

View File

@@ -0,0 +1,185 @@
"""
Amount Normalizer
Handles normalization and validation of monetary amounts.
"""
import re
from shared.utils.text_cleaner import TextCleaner
from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections
from .base import BaseNormalizer, NormalizationResult
class AmountNormalizer(BaseNormalizer):
"""
Normalizes monetary amounts from Swedish invoices.
Handles various Swedish amount formats:
- With decimal: 1 234,56 kr
- With SEK suffix: 1234.56 SEK
- Multiple amounts (returns the last one, usually the total)
"""
@property
def field_name(self) -> str:
return "Amount"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Split by newlines and process line by line to get the last valid amount
lines = text.split("\n")
# Collect all valid amounts from all lines
all_amounts: list[float] = []
# Pattern for Swedish amount format (with decimals)
amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"
for line in lines:
line = line.strip()
if not line:
continue
# Find all amounts in this line
matches = re.findall(amount_pattern, line, re.IGNORECASE)
for match in matches:
amount_str = match.replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
all_amounts.append(amount)
except ValueError:
continue
# Return the last amount found (usually the total)
if all_amounts:
return NormalizationResult.success(f"{all_amounts[-1]:.2f}")
# Fallback: try shared validator on cleaned text
cleaned = TextCleaner.normalize_amount_text(text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
# Try to find any decimal number
simple_pattern = r"(\d+[,\.]\d{2})"
matches = re.findall(simple_pattern, text)
if matches:
amount_str = matches[-1].replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
# Last resort: try to find integer amount (no decimals)
# Look for patterns like "Amount: 11699" or standalone numbers
int_pattern = r"(?:amount|belopp|summa|total)[:\s]*(\d+)"
match = re.search(int_pattern, text, re.IGNORECASE)
if match:
try:
amount = float(match.group(1))
if amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
# Very last resort: find any standalone number >= 3 digits
standalone_pattern = r"\b(\d{3,})\b"
matches = re.findall(standalone_pattern, text)
if matches:
# Take the last/largest number
try:
amount = float(matches[-1])
if amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
return NormalizationResult.failure(f"Cannot parse amount: {text}")
class EnhancedAmountNormalizer(AmountNormalizer):
"""
Enhanced amount parsing with multiple strategies.
Strategies:
1. Pattern matching for Swedish formats
2. Context-aware extraction (look for keywords like "Total", "Summa")
3. OCR error correction for common digit errors
4. Multi-amount handling (prefer last/largest as total)
"""
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Strategy 1: Apply OCR corrections first
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Strategy 2: Look for labeled amounts (highest priority)
labeled_patterns = [
# Swedish patterns
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
(
r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
0.8,
), # Lower priority for VAT
# Generic pattern
(r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
]
candidates: list[tuple[float, float, int]] = []
for pattern, priority in labeled_patterns:
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
amount_str = match.group(1).replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if 0 < amount < 10_000_000: # Reasonable range
candidates.append((amount, priority, match.start()))
except ValueError:
continue
if candidates:
# Sort by priority (desc), then by position (later is usually total)
candidates.sort(key=lambda x: (-x[1], -x[2]))
best_amount = candidates[0][0]
return NormalizationResult.success(f"{best_amount:.2f}")
# Strategy 3: Parse with shared validator
cleaned = TextCleaner.normalize_amount_text(corrected_text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
# Strategy 4: Try to extract any decimal number as fallback
decimal_pattern = r"(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})"
matches = re.findall(decimal_pattern, corrected_text)
if matches:
# Clean and parse each match
amounts: list[float] = []
for m in matches:
# Treat a trailing ',' as the Swedish decimal separator ("1 234,56" / "1.234,56" -> "1234.56");
# keep a lone '.' as the decimal point so "1234.56" is not misread as 123456.
cleaned_m = m.replace(" ", "")
if "," in cleaned_m:
cleaned_m = cleaned_m.replace(".", "").replace(",", ".")
try:
amt = float(cleaned_m)
if 0 < amt < 10_000_000:
amounts.append(amt)
except ValueError:
continue
if amounts:
# Return the last/largest amount (usually the total)
return NormalizationResult.success(f"{max(amounts):.2f}")
return NormalizationResult.failure(f"Cannot parse amount: {text[:50]}")
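A small sketch of the labelled-amount priority, assuming the shared OCRCorrections helper leaves clean digits unchanged (its implementation is not part of this diff):
normalizer = EnhancedAmountNormalizer()
result = normalizer.normalize("Moms 250,00 kr\nAtt betala: 1 250,00 kr")
print(result.value)  # expected: 1250.00 -- the "att betala" label outranks "moms"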

View File

@@ -0,0 +1,87 @@
"""
Bankgiro Normalizer
Handles normalization and validation of Swedish Bankgiro numbers.
"""
import re
from shared.utils.validators import FieldValidators
from .base import BaseNormalizer, NormalizationResult
class BankgiroNormalizer(BaseNormalizer):
"""
Normalizes Swedish Bankgiro numbers.
Bankgiro rules:
- 7 or 8 digits only
- Last digit is Luhn (Mod10) check digit
- Display format: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits)
Display pattern: ^\\d{3,4}-\\d{4}$
Normalized pattern: ^\\d{7,8}$
Note: Text may contain both BG and PG numbers. We specifically look for
BG display format (XXX-XXXX or XXXX-XXXX) to extract the correct one.
"""
@property
def field_name(self) -> str:
return "Bankgiro"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Look for BG display format pattern: 3-4 digits, dash, 4 digits
# This distinguishes BG from PG which uses X-X format (digits-single digit)
bg_matches = re.findall(r"(\d{3,4})-(\d{4})", text)
if bg_matches:
# Try each match and find one with valid Luhn
for match in bg_matches:
digits = match[0] + match[1]
if len(digits) in (7, 8) and FieldValidators.luhn_checksum(digits):
# Valid BG found
formatted = self._format_bankgiro(digits)
return NormalizationResult.success(formatted)
# No valid Luhn, use first match
digits = bg_matches[0][0] + bg_matches[0][1]
if len(digits) in (7, 8):
formatted = self._format_bankgiro(digits)
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
# Fallback: try to find 7-8 consecutive digits
# But first check if text contains PG format (XXXXXXX-X), if so don't use fallback
# to avoid misinterpreting PG as BG
pg_format_present = re.search(r"(?<![0-9])\d{1,7}-\d(?!\d)", text)
if pg_format_present:
return NormalizationResult.failure("No valid Bankgiro found in text")
digit_match = re.search(r"\b(\d{7,8})\b", text)
if digit_match:
digits = digit_match.group(1)
if len(digits) in (7, 8):
formatted = self._format_bankgiro(digits)
if FieldValidators.luhn_checksum(digits):
return NormalizationResult.success(formatted)
else:
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
return NormalizationResult.failure("No valid Bankgiro found in text")
@staticmethod
def _format_bankgiro(digits: str) -> str:
"""Format Bankgiro number with dash."""
if len(digits) == 8:
return f"{digits[:4]}-{digits[4:]}"
else:
return f"{digits[:3]}-{digits[3:]}"
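The Mod10 check itself is delegated to FieldValidators.luhn_checksum, which is not shown in this diff. For reference, a standard Luhn check (assumed to be equivalent) looks like this:
def luhn_checksum_ok(digits: str) -> bool:
    """Return True if the trailing digit is a valid Luhn (Mod10) check digit."""
    total = 0
    for i, ch in enumerate(reversed(digits)):
        d = int(ch)
        if i % 2 == 1:  # double every second digit, counting from the right
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return total % 10 == 0

# luhn_checksum_ok("7654320") -> True (synthetic 7-digit Bankgiro, displayed as "765-4320")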

View File

@@ -0,0 +1,71 @@
"""
Base Normalizer Interface
Defines the contract for all field normalizers.
Each normalizer handles a specific field type's normalization and validation.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass(frozen=True)
class NormalizationResult:
"""Result of a normalization operation."""
value: str | None
is_valid: bool
error: str | None = None
@classmethod
def success(cls, value: str) -> "NormalizationResult":
"""Create a successful result."""
return cls(value=value, is_valid=True, error=None)
@classmethod
def success_with_warning(cls, value: str, warning: str) -> "NormalizationResult":
"""Create a successful result with a warning."""
return cls(value=value, is_valid=True, error=warning)
@classmethod
def failure(cls, error: str) -> "NormalizationResult":
"""Create a failed result."""
return cls(value=None, is_valid=False, error=error)
def to_tuple(self) -> tuple[str | None, bool, str | None]:
"""Convert to legacy tuple format for backward compatibility."""
return (self.value, self.is_valid, self.error)
class BaseNormalizer(ABC):
"""
Abstract base class for field normalizers.
Each normalizer is responsible for:
1. Cleaning and normalizing raw text
2. Validating the normalized value
3. Returning a standardized result
"""
@property
@abstractmethod
def field_name(self) -> str:
"""The field name this normalizer handles."""
pass
@abstractmethod
def normalize(self, text: str) -> NormalizationResult:
"""
Normalize and validate the input text.
Args:
text: Raw text to normalize
Returns:
NormalizationResult with normalized value or error
"""
pass
def __call__(self, text: str) -> NormalizationResult:
"""Allow using the normalizer as a callable."""
return self.normalize(text)
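Adding a new field normalizer only requires implementing the two abstract members. A hypothetical example (the Currency field and its rule are illustrative, not part of this commit; the relative import assumes the module lives inside the normalizers package):
import re

from .base import BaseNormalizer, NormalizationResult


class CurrencyNormalizer(BaseNormalizer):
    """Illustrative normalizer that extracts a three-letter currency code."""

    @property
    def field_name(self) -> str:
        return "Currency"

    def normalize(self, text: str) -> NormalizationResult:
        match = re.search(r"\b(SEK|EUR|USD)\b", text.upper())
        if match:
            return NormalizationResult.success(match.group(1))
        return NormalizationResult.failure(f"No currency code in: {text[:50]}")


# normalize() and __call__ are interchangeable:
# CurrencyNormalizer()("Att betala 100 kr (SEK)").to_tuple() -> ("SEK", True, None)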

View File

@@ -0,0 +1,200 @@
"""
Date Normalizer
Handles normalization and validation of invoice dates.
"""
import re
from datetime import datetime
from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections
from .base import BaseNormalizer, NormalizationResult
class DateNormalizer(BaseNormalizer):
"""
Normalizes dates from Swedish invoices.
Handles various date formats:
- 2025-08-29 (ISO format)
- 2025.08.29 (dot separator)
- 29/08/2025 (European format)
- 29.08.2025 (European with dots)
- 20250829 (compact format)
Output format: YYYY-MM-DD (ISO 8601)
"""
# Date patterns with their parsing logic
PATTERNS = [
# ISO format: 2025-08-29
(
r"(\d{4})-(\d{1,2})-(\d{1,2})",
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
),
# Dot format: 2025.08.29 (common in Swedish)
(
r"(\d{4})\.(\d{1,2})\.(\d{1,2})",
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
),
# European slash format: 29/08/2025
(
r"(\d{1,2})/(\d{1,2})/(\d{4})",
lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
),
# European dot format: 29.08.2025
(
r"(\d{1,2})\.(\d{1,2})\.(\d{4})",
lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
),
# Compact format: 20250829
(
r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)",
lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
),
]
@property
def field_name(self) -> str:
return "Date"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# First, try using shared validator
iso_date = FieldValidators.format_date_iso(text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return NormalizationResult.success(iso_date)
# Fallback: try original patterns for edge cases
for pattern, extractor in self.PATTERNS:
match = re.search(pattern, text)
if match:
try:
year, month, day = extractor(match)
# Validate date
parsed_date = datetime(year, month, day)
# Sanity check: year should be reasonable (2000-2100)
if 2000 <= parsed_date.year <= 2100:
return NormalizationResult.success(
parsed_date.strftime("%Y-%m-%d")
)
except ValueError:
continue
return NormalizationResult.failure(f"Cannot parse date: {text}")
class EnhancedDateNormalizer(DateNormalizer):
"""
Enhanced date parsing with comprehensive format support.
Additional support for:
- Swedish text: "29 december 2024", "29 dec 2024"
- OCR error correction: 2O24-12-29 -> 2024-12-29
"""
# Swedish month names
SWEDISH_MONTHS = {
"januari": 1,
"jan": 1,
"februari": 2,
"feb": 2,
"mars": 3,
"mar": 3,
"april": 4,
"apr": 4,
"maj": 5,
"juni": 6,
"jun": 6,
"juli": 7,
"jul": 7,
"augusti": 8,
"aug": 8,
"september": 9,
"sep": 9,
"sept": 9,
"oktober": 10,
"okt": 10,
"november": 11,
"nov": 11,
"december": 12,
"dec": 12,
}
# Extended patterns
EXTENDED_PATTERNS = [
# ISO format: 2025-08-29, 2025/08/29
("ymd", r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})"),
# Dot format: 2025.08.29
("ymd", r"(\d{4})\.(\d{1,2})\.(\d{1,2})"),
# European slash: 29/08/2025
("dmy", r"(\d{1,2})/(\d{1,2})/(\d{4})"),
# European dot: 29.08.2025
("dmy", r"(\d{1,2})\.(\d{1,2})\.(\d{4})"),
# European dash: 29-08-2025
("dmy", r"(\d{1,2})-(\d{1,2})-(\d{4})"),
# Compact: 20250829
("ymd_compact", r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)"),
]
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Apply OCR corrections
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Try shared validator first
iso_date = FieldValidators.format_date_iso(corrected_text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return NormalizationResult.success(iso_date)
# Try Swedish text date pattern: "29 december 2024" or "29 dec 2024"
swedish_pattern = r"(\d{1,2})\s+([a-z\u00e5\u00e4\u00f6]+)\s+(\d{4})"
match = re.search(swedish_pattern, corrected_text.lower())
if match:
day = int(match.group(1))
month_name = match.group(2)
year = int(match.group(3))
if month_name in self.SWEDISH_MONTHS:
month = self.SWEDISH_MONTHS[month_name]
try:
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
except ValueError:
pass
# Extended patterns
for fmt, pattern in self.EXTENDED_PATTERNS:
match = re.search(pattern, corrected_text)
if match:
try:
if fmt == "ymd":
year = int(match.group(1))
month = int(match.group(2))
day = int(match.group(3))
elif fmt == "dmy":
day = int(match.group(1))
month = int(match.group(2))
year = int(match.group(3))
elif fmt == "ymd_compact":
year = int(match.group(1))
month = int(match.group(2))
day = int(match.group(3))
else:
continue
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
except ValueError:
continue
return NormalizationResult.failure(f"Cannot parse date: {text[:50]}")
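Usage sketch; the expected outputs assume the shared FieldValidators/OCRCorrections helpers behave as the docstrings above describe (their implementations are not in this diff):
print(DateNormalizer().normalize("Fakturadatum: 2025.08.29").value)             # expected: 2025-08-29
print(EnhancedDateNormalizer().normalize("Betalas senast 29 dec 2024").value)   # expected: 2024-12-29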

View File

@@ -0,0 +1,84 @@
"""
Invoice Number Normalizer
Handles normalization and validation of invoice numbers.
"""
import re
from .base import BaseNormalizer, NormalizationResult
class InvoiceNumberNormalizer(BaseNormalizer):
"""
Normalizes invoice numbers from Swedish invoices.
Invoice numbers can be:
- Pure digits: 12345678
- Alphanumeric: A3861, INV-2024-001, F12345
- With separators: 2024/001, 2024-001
Strategy:
1. Look for common invoice number patterns
2. Prefer shorter, more specific matches over long digit sequences
"""
@property
def field_name(self) -> str:
return "InvoiceNumber"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Pattern 1: Alphanumeric invoice number (letter + digits or digits + letter)
# Examples: A3861, F12345, INV001
alpha_patterns = [
r"\b([A-Z]{1,3}\d{3,10})\b", # A3861, INV12345
r"\b(\d{3,10}[A-Z]{1,3})\b", # 12345A
r"\b([A-Z]{2,5}[-/]?\d{3,10})\b", # INV-12345, FAK12345
]
for pattern in alpha_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
return NormalizationResult.success(match.group(1).upper())
# Pattern 2: Invoice number with year prefix (2024-001, 2024/12345)
year_pattern = r"\b(20\d{2}[-/]\d{3,8})\b"
match = re.search(year_pattern, text)
if match:
return NormalizationResult.success(match.group(1))
# Pattern 3: Short digit sequence (3-10 digits) - prefer shorter sequences
# This avoids capturing long OCR numbers
digit_sequences = re.findall(r"\b(\d{3,10})\b", text)
if digit_sequences:
# Prefer shorter sequences (more likely to be invoice number)
# Also filter out sequences that look like dates (8 digits starting with 20)
valid_sequences = []
for seq in digit_sequences:
# Skip if it looks like a date (YYYYMMDD)
if len(seq) == 8 and seq.startswith("20"):
continue
# Skip if too long (likely OCR number)
if len(seq) > 10:
continue
valid_sequences.append(seq)
if valid_sequences:
# Return shortest valid sequence
return NormalizationResult.success(min(valid_sequences, key=len))
# Fallback: extract all digits if nothing else works
digits = re.sub(r"\D", "", text)
if len(digits) >= 3:
# Limit to first 15 digits to avoid very long sequences
return NormalizationResult.success_with_warning(
digits[:15], "Fallback extraction"
)
return NormalizationResult.failure(
f"Cannot extract invoice number from: {text[:50]}"
)
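A sketch of the "prefer shorter sequences" strategy on a typical mixed line (synthetic values):
result = InvoiceNumberNormalizer().normalize("Fakturanr 45871  OCR 94228110015392")
print(result.value)  # expected: 45871 -- the 14-digit OCR reference exceeds the 10-digit bound and is skipped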

View File

@@ -0,0 +1,37 @@
"""
OCR Number Normalizer
Handles normalization and validation of OCR reference numbers.
"""
import re
from .base import BaseNormalizer, NormalizationResult
class OcrNumberNormalizer(BaseNormalizer):
"""
Normalizes OCR (Optical Character Recognition) reference numbers.
OCR numbers in Swedish payment systems:
- Minimum 5 digits
- Used for automated payment matching
"""
@property
def field_name(self) -> str:
return "OCR"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
digits = re.sub(r"\D", "", text)
if len(digits) < 5:
return NormalizationResult.failure(
f"Too few digits for OCR: {len(digits)}"
)
return NormalizationResult.success(digits)

View File

@@ -0,0 +1,90 @@
"""
Plusgiro Normalizer
Handles normalization and validation of Swedish Plusgiro numbers.
"""
import re
from shared.utils.validators import FieldValidators
from .base import BaseNormalizer, NormalizationResult
class PlusgiroNormalizer(BaseNormalizer):
"""
Normalizes Swedish Plusgiro numbers.
Plusgiro rules:
- 2 to 8 digits
- Last digit is Luhn (Mod10) check digit
- Display format: XXXXXXX-X (all digits except last, dash, last digit)
Display pattern: ^\\d{1,7}-\\d$
Normalized pattern: ^\\d{2,8}$
Note: Text may contain both BG and PG numbers. We specifically look for
PG display format (X-X, XX-X, ..., XXXXXXX-X) to extract the correct one.
"""
@property
def field_name(self) -> str:
return "Plusgiro"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# First look for PG display format: 1-7 digits (possibly with spaces), dash, 1 digit
# This is distinct from BG format which has 4 digits after the dash
# Pattern allows spaces within the number like "486 98 63-6"
# (?<![0-9]) ensures we don't start from within another number (like BG)
pg_matches = re.findall(r"(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)", text)
if pg_matches:
# Try each match and find one with valid Luhn
for match in pg_matches:
# Remove spaces from the first part
digits = re.sub(r"\s", "", match[0]) + match[1]
if 2 <= len(digits) <= 8 and FieldValidators.luhn_checksum(digits):
# Valid PG found
formatted = f"{digits[:-1]}-{digits[-1]}"
return NormalizationResult.success(formatted)
# No valid Luhn, use first match with most digits
best_match = max(pg_matches, key=lambda m: len(re.sub(r"\s", "", m[0])))
digits = re.sub(r"\s", "", best_match[0]) + best_match[1]
if 2 <= len(digits) <= 8:
formatted = f"{digits[:-1]}-{digits[-1]}"
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
# If no PG format found, extract all digits and format as PG
# This handles cases where the number might be in BG format or raw digits
all_digits = re.sub(r"\D", "", text)
# Try to find a valid 2-8 digit sequence
if 2 <= len(all_digits) <= 8:
formatted = f"{all_digits[:-1]}-{all_digits[-1]}"
if FieldValidators.luhn_checksum(all_digits):
return NormalizationResult.success(formatted)
else:
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
# Try to find any 2-8 digit sequence in text
digit_match = re.search(r"\b(\d{2,8})\b", text)
if digit_match:
digits = digit_match.group(1)
formatted = f"{digits[:-1]}-{digits[-1]}"
if FieldValidators.luhn_checksum(digits):
return NormalizationResult.success(formatted)
else:
return NormalizationResult.success_with_warning(
formatted, "Luhn checksum failed (possible OCR error)"
)
return NormalizationResult.failure("No valid Plusgiro found in text")
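A sketch of the BG/PG disambiguation on mixed payment text. The giro numbers are synthetic and chosen to pass a standard Mod10 check, which FieldValidators.luhn_checksum is assumed to implement:
payment_text = "Bankgiro 765-4320  Plusgiro 123456-6"
print(BankgiroNormalizer().normalize(payment_text).value)   # expected: 765-4320
print(PlusgiroNormalizer().normalize(payment_text).value)   # expected: 123456-6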

View File

@@ -0,0 +1,60 @@
"""
Supplier Organization Number Normalizer
Handles normalization and validation of Swedish organization numbers.
"""
import re
from .base import BaseNormalizer, NormalizationResult
class SupplierOrgNumberNormalizer(BaseNormalizer):
"""
Normalizes Swedish supplier organization numbers.
Extracts organization number in format: NNNNNN-NNNN (10 digits)
Also handles VAT numbers: SE + 10 digits + 01
Examples:
'org.nr. 516406-1102, Filialregistret...' -> '516406-1102'
'Momsreg.nr SE556123456701' -> '556123-4567'
"""
@property
def field_name(self) -> str:
return "supplier_org_number"
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Pattern 1: Standard org number format: NNNNNN-NNNN
org_pattern = r"\b(\d{6})-?(\d{4})\b"
match = re.search(org_pattern, text)
if match:
org_num = f"{match.group(1)}-{match.group(2)}"
return NormalizationResult.success(org_num)
# Pattern 2: VAT number format: SE + 10 digits + 01
vat_pattern = r"SE\s*(\d{10})01"
match = re.search(vat_pattern, text, re.IGNORECASE)
if match:
digits = match.group(1)
org_num = f"{digits[:6]}-{digits[6:]}"
return NormalizationResult.success(org_num)
# Pattern 3: Just 10 consecutive digits
digits_pattern = r"\b(\d{10})\b"
match = re.search(digits_pattern, text)
if match:
digits = match.group(1)
# Validate: first digit should be 1-9 for Swedish org numbers
if digits[0] in "123456789":
org_num = f"{digits[:6]}-{digits[6:]}"
return NormalizationResult.success(org_num)
return NormalizationResult.failure(
f"Cannot extract org number from: {text[:100]}"
)

View File

@@ -9,12 +9,12 @@ import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse
from inference.data.admin_db import AdminDB
from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.data.repositories import DocumentRepository, AnnotationRepository
from inference.web.core.auth import AdminTokenDep
from inference.web.services.autolabel import get_auto_label_service
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
@@ -36,6 +36,31 @@ from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
# Global repository instances
_doc_repo: DocumentRepository | None = None
_ann_repo: AnnotationRepository | None = None
def get_doc_repository() -> DocumentRepository:
"""Get the DocumentRepository instance."""
global _doc_repo
if _doc_repo is None:
_doc_repo = DocumentRepository()
return _doc_repo
def get_ann_repository() -> AnnotationRepository:
"""Get the AnnotationRepository instance."""
global _ann_repo
if _ann_repo is None:
_ann_repo = AnnotationRepository()
return _ann_repo
# Type aliases for dependency injection
DocRepoDep = Annotated[DocumentRepository, Depends(get_doc_repository)]
AnnRepoDep = Annotated[AnnotationRepository, Depends(get_ann_repository)]
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
@@ -71,17 +96,17 @@ def create_annotation_router() -> APIRouter:
document_id: str,
page_number: int,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
) -> FileResponse | StreamingResponse:
"""Get page image."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Validate page number
@@ -137,7 +162,8 @@ def create_annotation_router() -> APIRouter:
async def list_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
page_number: Annotated[
int | None,
Query(ge=1, description="Filter by page number"),
@@ -146,16 +172,16 @@ def create_annotation_router() -> APIRouter:
"""List annotations for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id, page_number)
raw_annotations = ann_repo.get_for_document(document_id, page_number)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
@@ -204,17 +230,18 @@ def create_annotation_router() -> APIRouter:
document_id: str,
request: AnnotationCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AnnotationResponse:
"""Create a new annotation."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Validate page number
@@ -244,7 +271,7 @@ def create_annotation_router() -> APIRouter:
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
# Create annotation
annotation_id = db.create_annotation(
annotation_id = ann_repo.create(
document_id=document_id,
page_number=request.page_number,
class_id=request.class_id,
@@ -285,22 +312,23 @@ def create_annotation_router() -> APIRouter:
annotation_id: str,
request: AnnotationUpdate,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AnnotationResponse:
"""Update an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
annotation = ann_repo.get(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
@@ -349,7 +377,7 @@ def create_annotation_router() -> APIRouter:
# Update annotation
if update_kwargs:
success = db.update_annotation(annotation_id, **update_kwargs)
success = ann_repo.update(annotation_id, **update_kwargs)
if not success:
raise HTTPException(
status_code=500,
@@ -374,22 +402,23 @@ def create_annotation_router() -> APIRouter:
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> dict:
"""Delete an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
annotation = ann_repo.get(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
@@ -404,7 +433,7 @@ def create_annotation_router() -> APIRouter:
)
# Delete annotation
db.delete_annotation(annotation_id)
ann_repo.delete(annotation_id)
return {
"status": "deleted",
@@ -431,17 +460,18 @@ def create_annotation_router() -> APIRouter:
document_id: str,
request: AutoLabelRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AutoLabelResponse:
"""Trigger auto-labeling for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Validate field values
@@ -457,7 +487,8 @@ def create_annotation_router() -> APIRouter:
document_id=document_id,
file_path=document.file_path,
field_values=request.field_values,
db=db,
doc_repo=doc_repo,
ann_repo=ann_repo,
replace_existing=request.replace_existing,
)
@@ -486,7 +517,8 @@ def create_annotation_router() -> APIRouter:
async def delete_all_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
source: Annotated[
str | None,
Query(description="Filter by source (manual, auto, imported)"),
@@ -502,21 +534,21 @@ def create_annotation_router() -> APIRouter:
detail=f"Invalid source: {source}",
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Delete annotations
deleted_count = db.delete_annotations_for_document(document_id, source)
deleted_count = ann_repo.delete_for_document(document_id, source)
# Update document status if all annotations deleted
remaining = db.get_annotations_for_document(document_id)
remaining = ann_repo.get_for_document(document_id)
if not remaining:
db.update_document_status(document_id, "pending")
doc_repo.update_status(document_id, "pending")
return {
"status": "deleted",
@@ -543,23 +575,24 @@ def create_annotation_router() -> APIRouter:
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
) -> AnnotationVerifyResponse:
"""Verify an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Verify the annotation
annotation = db.verify_annotation(annotation_id, admin_token)
annotation = ann_repo.verify(annotation_id, admin_token)
if annotation is None:
raise HTTPException(
status_code=404,
@@ -589,18 +622,19 @@ def create_annotation_router() -> APIRouter:
annotation_id: str,
request: AnnotationOverrideRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
doc_repo: DocRepoDep,
ann_repo: AnnRepoDep,
) -> AnnotationOverrideResponse:
"""Override an auto-generated annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
# Get document
document = doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
detail="Document not found",
)
# Build updates dict from request
@@ -632,7 +666,7 @@ def create_annotation_router() -> APIRouter:
)
# Override the annotation
annotation = db.override_annotation(
annotation = ann_repo.override(
annotation_id=annotation_id,
admin_token=admin_token,
change_reason=request.reason,
@@ -646,7 +680,7 @@ def create_annotation_router() -> APIRouter:
)
# Get history to return history_id
history_records = db.get_annotation_history(UUID(annotation_id))
history_records = ann_repo.get_history(UUID(annotation_id))
latest_history = history_records[0] if history_records else None
return AnnotationOverrideResponse(

View File

@@ -1,10 +1,8 @@
"""Augmentation API routes."""
from typing import Annotated
from fastapi import APIRouter, Query
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminDBDep, AdminTokenDep
from inference.web.core.auth import AdminTokenDep, DocumentRepoDep, DatasetRepoDep
from inference.web.schemas.admin.augmentation import (
AugmentationBatchRequest,
AugmentationBatchResponse,
@@ -13,7 +11,6 @@ from inference.web.schemas.admin.augmentation import (
AugmentationPreviewResponse,
AugmentationTypeInfo,
AugmentationTypesResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
PresetInfo,
PresetsResponse,
@@ -78,7 +75,7 @@ def register_augmentation_routes(router: APIRouter) -> None:
document_id: str,
request: AugmentationPreviewRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""
@@ -88,7 +85,7 @@ def register_augmentation_routes(router: APIRouter) -> None:
"""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
service = AugmentationService(doc_repo=docs)
return await service.preview_single(
document_id=document_id,
page=page,
@@ -105,13 +102,13 @@ def register_augmentation_routes(router: APIRouter) -> None:
document_id: str,
config: AugmentationConfigSchema,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""Preview complete augmentation pipeline on a document page."""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
service = AugmentationService(doc_repo=docs)
return await service.preview_config(
document_id=document_id,
page=page,
@@ -126,7 +123,8 @@ def register_augmentation_routes(router: APIRouter) -> None:
async def create_augmented_dataset(
request: AugmentationBatchRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
datasets: DatasetRepoDep,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.
@@ -136,7 +134,7 @@ def register_augmentation_routes(router: APIRouter) -> None:
"""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
service = AugmentationService(doc_repo=docs, dataset_repo=datasets)
return await service.create_augmented_dataset(
source_dataset_id=request.dataset_id,
config=request.config,
@@ -151,12 +149,12 @@ def register_augmentation_routes(router: APIRouter) -> None:
)
async def list_augmented_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
datasets: DatasetRepoDep,
limit: int = Query(default=20, ge=1, le=100, description="Page size"),
offset: int = Query(default=0, ge=0, description="Offset"),
) -> AugmentedDatasetListResponse:
"""List all augmented datasets."""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
service = AugmentationService(dataset_repo=datasets)
return await service.list_augmented_datasets(limit=limit, offset=offset)

View File

@@ -10,7 +10,7 @@ from datetime import datetime, timedelta
from fastapi import APIRouter
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, TokenRepoDep
from inference.web.schemas.admin import (
AdminTokenCreate,
AdminTokenResponse,
@@ -35,7 +35,7 @@ def create_auth_router() -> APIRouter:
)
async def create_token(
request: AdminTokenCreate,
db: AdminDBDep,
tokens: TokenRepoDep,
) -> AdminTokenResponse:
"""Create a new admin token."""
# Generate secure token
@@ -47,7 +47,7 @@ def create_auth_router() -> APIRouter:
expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)
# Create token in database
db.create_admin_token(
tokens.create(
token=token,
name=request.name,
expires_at=expires_at,
@@ -70,10 +70,10 @@ def create_auth_router() -> APIRouter:
)
async def revoke_token(
admin_token: AdminTokenDep,
db: AdminDBDep,
tokens: TokenRepoDep,
) -> dict:
"""Revoke the current admin token."""
db.deactivate_admin_token(admin_token)
tokens.deactivate(admin_token)
return {
"status": "revoked",
"message": "Admin token has been revoked",

View File

@@ -12,7 +12,12 @@ from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
AdminTokenDep,
DocumentRepoDep,
AnnotationRepoDep,
TrainingTaskRepoDep,
)
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
AnnotationItem,
@@ -87,7 +92,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
)
async def upload_document(
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
file: UploadFile = File(..., description="PDF or image file"),
auto_label: Annotated[
bool,
@@ -142,7 +147,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record (token only used for auth, not stored)
document_id = db.create_document(
document_id = docs.create(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
@@ -184,7 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
auto_label_started = False
if auto_label:
# Auto-labeling will be triggered by a background task
db.update_document_status(
docs.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
@@ -214,7 +219,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
)
async def list_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
@@ -270,7 +276,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
detail=f"Invalid auto_label_status: {auto_label_status}",
)
documents, total = db.get_documents_by_token(
documents, total = docs.get_paginated(
admin_token=admin_token,
status=status,
upload_source=upload_source,
@@ -285,7 +291,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
# Get annotation counts and build items
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
doc_annotations = annotations.get_for_document(str(doc.document_id))
# Determine if document can be annotated (not locked)
can_annotate = True
@@ -301,7 +307,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
page_count=doc.page_count,
status=DocumentStatus(doc.status),
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
annotation_count=len(annotations),
annotation_count=len(doc_annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
@@ -330,10 +336,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
)
async def get_document_stats(
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
) -> DocumentStatsResponse:
"""Get document statistics."""
counts = db.count_documents_by_status(admin_token)
counts = docs.count_by_status(admin_token)
return DocumentStatsResponse(
total=sum(counts.values()),
@@ -343,6 +349,26 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
exported=counts.get("exported", 0),
)
@router.get(
"/categories",
response_model=DocumentCategoriesResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get available categories",
description="Get list of all available document categories.",
)
async def get_categories(
admin_token: AdminTokenDep,
docs: DocumentRepoDep,
) -> DocumentCategoriesResponse:
"""Get all available document categories."""
categories = docs.get_categories()
return DocumentCategoriesResponse(
categories=categories,
total=len(categories),
)
@router.get(
"/{document_id}",
response_model=DocumentDetailResponse,
@@ -356,12 +382,14 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
async def get_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
tasks: TrainingTaskRepoDep,
) -> DocumentDetailResponse:
"""Get document details."""
_validate_uuid(document_id, "document_id")
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -369,8 +397,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id)
annotations = [
raw_annotations = annotations.get_for_document(document_id)
annotation_items = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
@@ -416,10 +444,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
# Get training history (Phase 5)
training_history = []
training_links = db.get_document_training_tasks(document.document_id)
training_links = tasks.get_document_training_tasks(document.document_id)
for link in training_links:
# Get task details
task = db.get_training_task(str(link.task_id))
task = tasks.get(str(link.task_id))
if task:
# Build metrics
metrics = None
@@ -455,7 +483,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
annotations=annotations,
annotations=annotation_items,
image_urls=image_urls,
training_history=training_history,
created_at=document.created_at,
@@ -474,13 +502,13 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
async def delete_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
) -> dict:
"""Delete a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -505,7 +533,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
logger.warning(f"Failed to delete admin images: {e}")
# Delete from database
db.delete_document(document_id)
docs.delete(document_id)
return {
"status": "deleted",
@@ -525,7 +553,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
async def update_document_status(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
status: Annotated[
str,
Query(description="New status"),
@@ -547,7 +576,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -560,16 +589,15 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
from inference.web.services.db_autolabel import save_manual_annotations_to_document_db
# Get all annotations for this document
annotations = db.get_annotations_for_document(document_id)
doc_annotations = annotations.get_for_document(document_id)
if annotations:
if doc_annotations:
db_save_result = save_manual_annotations_to_document_db(
document=document,
annotations=annotations,
db=db,
annotations=doc_annotations,
)
db.update_document_status(document_id, status)
docs.update_status(document_id, status)
response = {
"status": "updated",
@@ -597,7 +625,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
async def update_document_group_key(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
group_key: Annotated[
str | None,
Query(description="New group key (null to clear)"),
@@ -614,7 +642,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
)
# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -622,7 +650,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
)
# Update group key
db.update_document_group_key(document_id, group_key)
docs.update_group_key(document_id, group_key)
return {
"status": "updated",
@@ -631,26 +659,6 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
"message": "Document group key updated",
}
@router.get(
"/categories",
response_model=DocumentCategoriesResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get available categories",
description="Get list of all available document categories.",
)
async def get_categories(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentCategoriesResponse:
"""Get all available document categories."""
categories = db.get_document_categories()
return DocumentCategoriesResponse(
categories=categories,
total=len(categories),
)
@router.patch(
"/{document_id}/category",
responses={
@@ -663,14 +671,14 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
async def update_document_category(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
request: DocumentUpdateRequest,
) -> dict:
"""Update document category."""
_validate_uuid(document_id, "document_id")
# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -679,7 +687,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
# Update category if provided
if request.category is not None:
db.update_document_category(document_id, request.category)
docs.update_category(document_id, request.category)
return {
"status": "updated",

View File

@@ -4,21 +4,18 @@ Admin Document Lock Routes
FastAPI endpoints for annotation lock management.
"""
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, DocumentRepoDep
from inference.web.schemas.admin import (
AnnotationLockRequest,
AnnotationLockResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
@@ -49,14 +46,14 @@ def create_locks_router() -> APIRouter:
async def acquire_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Acquire annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -64,7 +61,7 @@ def create_locks_router() -> APIRouter:
)
# Attempt to acquire lock
updated_doc = db.acquire_annotation_lock(
updated_doc = docs.acquire_annotation_lock(
document_id=document_id,
admin_token=admin_token,
duration_seconds=request.duration_seconds,
@@ -96,7 +93,7 @@ def create_locks_router() -> APIRouter:
async def release_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
force: Annotated[
bool,
Query(description="Force release (admin override)"),
@@ -106,7 +103,7 @@ def create_locks_router() -> APIRouter:
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -114,7 +111,7 @@ def create_locks_router() -> APIRouter:
)
# Release lock
updated_doc = db.release_annotation_lock(
updated_doc = docs.release_annotation_lock(
document_id=document_id,
admin_token=admin_token,
force=force,
@@ -147,14 +144,14 @@ def create_locks_router() -> APIRouter:
async def extend_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Extend annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
document = docs.get_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
@@ -162,7 +159,7 @@ def create_locks_router() -> APIRouter:
)
# Attempt to extend lock
updated_doc = db.extend_annotation_lock(
updated_doc = docs.extend_annotation_lock(
document_id=document_id,
admin_token=admin_token,
additional_seconds=request.duration_seconds,

View File

@@ -5,7 +5,14 @@ from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
AdminTokenDep,
DatasetRepoDep,
DocumentRepoDep,
AnnotationRepoDep,
ModelVersionRepoDep,
TrainingTaskRepoDep,
)
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetDetailResponse,
@@ -36,7 +43,9 @@ def register_dataset_routes(router: APIRouter) -> None:
async def create_dataset(
request: DatasetCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
datasets: DatasetRepoDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
) -> DatasetResponse:
"""Create a training dataset from document IDs."""
from inference.web.services.dataset_builder import DatasetBuilder
@@ -48,7 +57,7 @@ def register_dataset_routes(router: APIRouter) -> None:
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
)
dataset = db.create_dataset(
dataset = datasets.create(
name=request.name,
description=request.description,
train_ratio=request.train_ratio,
@@ -67,7 +76,12 @@ def register_dataset_routes(router: APIRouter) -> None:
detail="Storage not configured for local access",
)
builder = DatasetBuilder(db=db, base_dir=datasets_dir)
builder = DatasetBuilder(
datasets_repo=datasets,
documents_repo=docs,
annotations_repo=annotations,
base_dir=datasets_dir,
)
try:
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
@@ -94,18 +108,18 @@ def register_dataset_routes(router: APIRouter) -> None:
)
async def list_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
datasets_repo: DatasetRepoDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
datasets_list, total = datasets_repo.get_paginated(status=status, limit=limit, offset=offset)
# Get active training tasks for each dataset (graceful degradation on error)
dataset_ids = [str(d.dataset_id) for d in datasets]
dataset_ids = [str(d.dataset_id) for d in datasets_list]
try:
active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
active_tasks = datasets_repo.get_active_training_tasks(dataset_ids)
except Exception:
logger.exception("Failed to fetch active training tasks")
active_tasks = {}
@@ -127,7 +141,7 @@ def register_dataset_routes(router: APIRouter) -> None:
total_annotations=d.total_annotations,
created_at=d.created_at,
)
for d in datasets
for d in datasets_list
],
)
@@ -139,15 +153,15 @@ def register_dataset_routes(router: APIRouter) -> None:
async def get_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
datasets_repo: DatasetRepoDep,
) -> DatasetDetailResponse:
"""Get dataset details with document list."""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
dataset = datasets_repo.get(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
docs = db.get_dataset_documents(dataset_id)
docs = datasets_repo.get_documents(dataset_id)
return DatasetDetailResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
@@ -187,14 +201,14 @@ def register_dataset_routes(router: APIRouter) -> None:
async def delete_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
datasets_repo: DatasetRepoDep,
) -> dict:
"""Delete a dataset and its files."""
import shutil
from pathlib import Path
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
dataset = datasets_repo.get(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
@@ -203,7 +217,7 @@ def register_dataset_routes(router: APIRouter) -> None:
if dataset_dir.exists():
shutil.rmtree(dataset_dir)
db.delete_dataset(dataset_id)
datasets_repo.delete(dataset_id)
return {"message": "Dataset deleted"}
@router.post(
@@ -216,7 +230,9 @@ def register_dataset_routes(router: APIRouter) -> None:
dataset_id: str,
request: DatasetTrainRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
datasets_repo: DatasetRepoDep,
models: ModelVersionRepoDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset.
@@ -224,7 +240,7 @@ def register_dataset_routes(router: APIRouter) -> None:
The training will use that model as the starting point instead of a pretrained model.
"""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
dataset = datasets_repo.get(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.status != "ready":
@@ -239,7 +255,7 @@ def register_dataset_routes(router: APIRouter) -> None:
base_model_version_id = config_dict.get("base_model_version_id")
if base_model_version_id:
_validate_uuid(base_model_version_id, "base_model_version_id")
base_model = db.get_model_version(base_model_version_id)
base_model = models.get(base_model_version_id)
if not base_model:
raise HTTPException(
status_code=404,
@@ -254,7 +270,7 @@ def register_dataset_routes(router: APIRouter) -> None:
base_model.model_path,
)
task_id = db.create_training_task(
task_id = tasks.create(
admin_token=admin_token,
name=request.name,
task_type="finetune" if base_model_version_id else "train",

View File

@@ -5,7 +5,12 @@ from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
AdminTokenDep,
DocumentRepoDep,
AnnotationRepoDep,
TrainingTaskRepoDep,
)
from inference.web.schemas.admin import (
ModelMetrics,
TrainingDocumentItem,
@@ -35,7 +40,9 @@ def register_document_routes(router: APIRouter) -> None:
)
async def get_training_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
tasks: TrainingTaskRepoDep,
has_annotations: Annotated[
bool,
Query(description="Only include documents with annotations"),
@@ -58,7 +65,7 @@ def register_document_routes(router: APIRouter) -> None:
] = 0,
) -> TrainingDocumentsResponse:
"""Get documents available for training."""
documents, total = db.get_documents_for_training(
documents, total = docs.get_for_training(
admin_token=admin_token,
status="labeled",
has_annotations=has_annotations,
@@ -70,21 +77,21 @@ def register_document_routes(router: APIRouter) -> None:
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
doc_annotations = annotations.get_for_document(str(doc.document_id))
sources = {"manual": 0, "auto": 0}
for ann in annotations:
for ann in doc_annotations:
if ann.source in sources:
sources[ann.source] += 1
training_links = db.get_document_training_tasks(doc.document_id)
training_links = tasks.get_document_training_tasks(doc.document_id)
used_in_training = [str(link.task_id) for link in training_links]
items.append(
TrainingDocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
annotation_count=len(annotations),
annotation_count=len(doc_annotations),
annotation_sources=sources,
used_in_training=used_in_training,
last_modified=doc.updated_at,
@@ -110,7 +117,7 @@ def register_document_routes(router: APIRouter) -> None:
async def download_model(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
tasks: TrainingTaskRepoDep,
):
"""Download trained model."""
from fastapi.responses import FileResponse
@@ -118,7 +125,7 @@ def register_document_routes(router: APIRouter) -> None:
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
@@ -155,7 +162,7 @@ def register_document_routes(router: APIRouter) -> None:
)
async def get_completed_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
tasks_repo: TrainingTaskRepoDep,
status: Annotated[
str | None,
Query(description="Filter by status (completed, failed, etc.)"),
@@ -170,7 +177,7 @@ def register_document_routes(router: APIRouter) -> None:
] = 0,
) -> TrainingModelsResponse:
"""Get list of trained models."""
tasks, total = db.get_training_tasks_by_token(
task_list, total = tasks_repo.get_paginated(
admin_token=admin_token,
status=status if status else "completed",
limit=limit,
@@ -178,7 +185,7 @@ def register_document_routes(router: APIRouter) -> None:
)
items = []
for task in tasks:
for task in task_list:
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from fastapi import APIRouter, HTTPException
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, DocumentRepoDep, AnnotationRepoDep
from inference.web.schemas.admin import (
ExportRequest,
ExportResponse,
@@ -31,7 +31,8 @@ def register_export_routes(router: APIRouter) -> None:
async def export_annotations(
request: ExportRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
docs: DocumentRepoDep,
annotations: AnnotationRepoDep,
) -> ExportResponse:
"""Export annotations for training."""
from inference.web.services.storage_helpers import get_storage_helper
@@ -45,7 +46,7 @@ def register_export_routes(router: APIRouter) -> None:
detail=f"Unsupported export format: {request.format}",
)
documents = db.get_labeled_documents_for_export(admin_token)
documents = docs.get_labeled_for_export(admin_token)
if not documents:
raise HTTPException(
@@ -78,13 +79,13 @@ def register_export_routes(router: APIRouter) -> None:
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = db.get_annotations_for_document(str(doc.document_id))
doc_annotations = annotations.get_for_document(str(doc.document_id))
if not annotations:
if not doc_annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
page_annotations = [a for a in doc_annotations if a.page_number == page_num]
if not page_annotations and not request.include_images:
continue

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import APIRouter, HTTPException, Query, Request
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, ModelVersionRepoDep
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
@@ -33,7 +33,7 @@ def register_model_routes(router: APIRouter) -> None:
async def create_model_version(
request: ModelVersionCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Create a new model version."""
if request.task_id:
@@ -41,7 +41,7 @@ def register_model_routes(router: APIRouter) -> None:
if request.dataset_id:
_validate_uuid(request.dataset_id, "dataset_id")
model = db.create_model_version(
model = models.create(
version=request.version,
name=request.name,
model_path=request.model_path,
@@ -70,13 +70,13 @@ def register_model_routes(router: APIRouter) -> None:
)
async def list_model_versions(
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> ModelVersionListResponse:
"""List model versions with optional status filter."""
models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
model_list, total = models.get_paginated(status=status, limit=limit, offset=offset)
return ModelVersionListResponse(
total=total,
limit=limit,
@@ -94,7 +94,7 @@ def register_model_routes(router: APIRouter) -> None:
activated_at=m.activated_at,
created_at=m.created_at,
)
for m in models
for m in model_list
],
)
@@ -106,10 +106,10 @@ def register_model_routes(router: APIRouter) -> None:
)
async def get_active_model(
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> ActiveModelResponse:
"""Get the currently active model version."""
model = db.get_active_model_version()
model = models.get_active()
if not model:
return ActiveModelResponse(has_active_model=False, model=None)
@@ -137,11 +137,11 @@ def register_model_routes(router: APIRouter) -> None:
async def get_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> ModelVersionDetailResponse:
"""Get detailed model version information."""
_validate_uuid(version_id, "version_id")
model = db.get_model_version(version_id)
model = models.get(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
@@ -176,11 +176,11 @@ def register_model_routes(router: APIRouter) -> None:
version_id: str,
request: ModelVersionUpdateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Update model version metadata."""
_validate_uuid(version_id, "version_id")
model = db.update_model_version(
model = models.update(
version_id=version_id,
name=request.name,
description=request.description,
@@ -205,11 +205,11 @@ def register_model_routes(router: APIRouter) -> None:
version_id: str,
request: Request,
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Activate a model version for inference."""
_validate_uuid(version_id, "version_id")
model = db.activate_model_version(version_id)
model = models.activate(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
@@ -242,11 +242,11 @@ def register_model_routes(router: APIRouter) -> None:
async def deactivate_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Deactivate a model version."""
_validate_uuid(version_id, "version_id")
model = db.deactivate_model_version(version_id)
model = models.deactivate(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
@@ -264,11 +264,11 @@ def register_model_routes(router: APIRouter) -> None:
async def archive_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> ModelVersionResponse:
"""Archive a model version."""
_validate_uuid(version_id, "version_id")
model = db.archive_model_version(version_id)
model = models.archive(version_id)
if not model:
raise HTTPException(
status_code=400,
@@ -288,11 +288,11 @@ def register_model_routes(router: APIRouter) -> None:
async def delete_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
models: ModelVersionRepoDep,
) -> dict:
"""Delete a model version."""
_validate_uuid(version_id, "version_id")
success = db.delete_model_version(version_id)
success = models.delete(version_id)
if not success:
raise HTTPException(
status_code=400,

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, TrainingTaskRepoDep
from inference.web.schemas.admin import (
TrainingLogItem,
TrainingLogsResponse,
@@ -40,12 +40,12 @@ def register_task_routes(router: APIRouter) -> None:
async def create_training_task(
request: TrainingTaskCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskResponse:
"""Create a new training task."""
config_dict = request.config.model_dump() if request.config else {}
task_id = db.create_training_task(
task_id = tasks.create(
admin_token=admin_token,
name=request.name,
task_type=request.task_type.value,
@@ -73,7 +73,7 @@ def register_task_routes(router: APIRouter) -> None:
)
async def list_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
tasks_repo: TrainingTaskRepoDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
@@ -95,7 +95,7 @@ def register_task_routes(router: APIRouter) -> None:
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
)
tasks, total = db.get_training_tasks_by_token(
task_list, total = tasks_repo.get_paginated(
admin_token=admin_token,
status=status,
limit=limit,
@@ -114,7 +114,7 @@ def register_task_routes(router: APIRouter) -> None:
completed_at=task.completed_at,
created_at=task.created_at,
)
for task in tasks
for task in task_list
]
return TrainingTaskListResponse(
@@ -137,12 +137,12 @@ def register_task_routes(router: APIRouter) -> None:
async def get_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskDetailResponse:
"""Get training task details."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
@@ -181,12 +181,12 @@ def register_task_routes(router: APIRouter) -> None:
async def cancel_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
tasks: TrainingTaskRepoDep,
) -> TrainingTaskResponse:
"""Cancel a training task."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
@@ -199,7 +199,7 @@ def register_task_routes(router: APIRouter) -> None:
detail=f"Cannot cancel task with status: {task.status}",
)
success = db.cancel_training_task(task_id)
success = tasks.cancel(task_id)
if not success:
raise HTTPException(
status_code=500,
@@ -225,7 +225,7 @@ def register_task_routes(router: APIRouter) -> None:
async def get_training_logs(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
tasks: TrainingTaskRepoDep,
limit: Annotated[
int,
Query(ge=1, le=500, description="Maximum logs to return"),
@@ -238,14 +238,14 @@ def register_task_routes(router: APIRouter) -> None:
"""Get training logs."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
task = tasks.get_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
logs = db.get_training_logs(task_id, limit, offset)
logs = tasks.get_logs(task_id, limit, offset)
items = [
TrainingLogItem(

View File

@@ -14,13 +14,25 @@ from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse
from inference.data.admin_db import AdminDB
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.data.repositories import BatchUploadRepository
from inference.web.core.auth import validate_admin_token
from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from inference.web.workers.batch_queue import BatchTask, get_batch_queue
logger = logging.getLogger(__name__)
# Global repository instance
_batch_repo: BatchUploadRepository | None = None
def get_batch_repository() -> BatchUploadRepository:
"""Get the BatchUploadRepository instance."""
global _batch_repo
if _batch_repo is None:
_batch_repo = BatchUploadRepository()
return _batch_repo
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
@@ -31,7 +43,7 @@ async def upload_batch(
async_mode: bool = Form(default=True),
auto_label: bool = Form(default=True),
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
) -> dict:
"""Upload a batch of documents via ZIP file.
@@ -119,7 +131,7 @@ async def upload_batch(
)
else:
# Sync mode: Process immediately and return results
service = BatchUploadService(admin_db)
service = BatchUploadService(batch_repo)
result = service.process_zip_upload(
admin_token=admin_token,
zip_filename=file.filename,
@@ -148,14 +160,14 @@ async def upload_batch(
async def get_batch_status(
batch_id: str,
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
) -> dict:
"""Get batch upload status and file processing details.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
admin_db: Admin database interface
batch_repo: Batch upload repository
Returns:
Batch status with file processing details
@@ -167,7 +179,7 @@ async def get_batch_status(
raise HTTPException(status_code=400, detail="Invalid batch ID format")
# Check batch exists and verify ownership
batch = admin_db.get_batch_upload(batch_uuid)
batch = batch_repo.get(batch_uuid)
if not batch:
raise HTTPException(status_code=404, detail="Batch not found")
@@ -179,7 +191,7 @@ async def get_batch_status(
)
# Now safe to return details
service = BatchUploadService(admin_db)
service = BatchUploadService(batch_repo)
result = service.get_batch_status(batch_id)
return result
@@ -188,7 +200,7 @@ async def get_batch_status(
@router.get("/list")
async def list_batch_uploads(
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
limit: int = 50,
offset: int = 0,
) -> dict:
@@ -196,7 +208,7 @@ async def list_batch_uploads(
Args:
admin_token: Admin authentication token
admin_db: Admin database interface
batch_repo: Batch upload repository
limit: Maximum number of results
offset: Offset for pagination
@@ -210,7 +222,7 @@ async def list_batch_uploads(
raise HTTPException(status_code=400, detail="Offset must be non-negative")
# Get batch uploads filtered by admin token
batches, total = admin_db.get_batch_uploads_by_token(
batches, total = batch_repo.get_paginated(
admin_token=admin_token,
limit=limit,
offset=offset,
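
The module-level get_batch_repository() getter exists so the batch routes resolve the repository through Depends instead of constructing it inline, which also lets tests swap it out. A hedged sketch of such a test follows; FakeBatchRepo, the test function, and the import location of create_app are assumptions, while dependency_overrides is standard FastAPI and get/get_paginated mirror the calls made by the routes above:

from fastapi.testclient import TestClient

from inference.web.api.v1.batch import routes as batch_routes
from inference.web.core.auth import validate_admin_token
from inference.web.app import create_app  # assumed location of the create_app factory


class FakeBatchRepo:
    """In-memory stand-in matching the repository calls used by the batch routes."""

    def get(self, batch_uuid):
        return None  # behave as if no batch exists

    def get_paginated(self, admin_token, limit, offset):
        return [], 0


def test_list_batch_uploads_empty():
    app = create_app()
    # Bypass token validation and inject the fake repository.
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[batch_routes.get_batch_repository] = FakeBatchRepo
    with TestClient(app) as client:  # runs startup/shutdown; heavier side effects are out of scope for this sketch
        resp = client.get("/api/v1/admin/batch/list")
        assert resp.status_code == 200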

View File

@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from inference.data.admin_db import AdminDB
from inference.data.repositories import DocumentRepository
from inference.web.schemas.labeling import PreLabelResponse
from inference.web.schemas.common import ErrorResponse
from inference.web.services.storage_helpers import get_storage_helper
@@ -46,9 +46,9 @@ def _convert_pdf_to_images(
pdf_doc.close()
def get_admin_db() -> AdminDB:
"""Get admin database instance."""
return AdminDB()
def get_doc_repository() -> DocumentRepository:
"""Get document repository instance."""
return DocumentRepository()
def create_labeling_router(
@@ -85,7 +85,7 @@ def create_labeling_router(
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
),
db: AdminDB = Depends(get_admin_db),
doc_repo: DocumentRepository = Depends(get_doc_repository),
) -> PreLabelResponse:
"""
Upload a document with expected field values for pre-labeling.
@@ -149,7 +149,7 @@ def create_labeling_router(
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record with field_values
document_id = db.create_document(
document_id = doc_repo.create(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
@@ -172,7 +172,7 @@ def create_labeling_router(
)
# Update file path in database (using storage path)
db.update_document_file_path(document_id, storage_path)
doc_repo.update_file_path(document_id, storage_path)
# Convert PDF to images for annotation UI
if file_ext == ".pdf":
@@ -184,7 +184,7 @@ def create_labeling_router(
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="pending",

View File

@@ -51,7 +51,7 @@ from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, st
from inference.web.api.v1.batch.routes import router as batch_upload_router
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
from inference.data.admin_db import AdminDB
from inference.data.repositories import ModelVersionRepository
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
@@ -75,8 +75,8 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
def get_active_model_path():
"""Resolve active model path from database."""
try:
db = AdminDB()
active_model = db.get_active_model_version()
model_repo = ModelVersionRepository()
active_model = model_repo.get_active()
if active_model and active_model.model_path:
return active_model.model_path
except Exception as e:
@@ -139,8 +139,7 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
# Start batch upload queue
try:
admin_db = AdminDB()
batch_service = BatchUploadService(admin_db)
batch_service = BatchUploadService()
init_batch_queue(batch_service)
logger.info("Batch upload queue started")
except Exception as e:

View File

@@ -4,7 +4,24 @@ Core Components
Reusable core functionality: authentication, rate limiting, scheduling.
"""
from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
validate_admin_token,
get_token_repository,
get_document_repository,
get_annotation_repository,
get_dataset_repository,
get_training_task_repository,
get_model_version_repository,
get_batch_upload_repository,
AdminTokenDep,
TokenRepoDep,
DocumentRepoDep,
AnnotationRepoDep,
DatasetRepoDep,
TrainingTaskRepoDep,
ModelVersionRepoDep,
BatchUploadRepoDep,
)
from inference.web.core.rate_limiter import RateLimiter
from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from inference.web.core.autolabel_scheduler import (
@@ -12,12 +29,25 @@ from inference.web.core.autolabel_scheduler import (
stop_autolabel_scheduler,
get_autolabel_scheduler,
)
from inference.web.core.task_interface import TaskRunner, TaskStatus, TaskManager
__all__ = [
"validate_admin_token",
"get_admin_db",
"get_token_repository",
"get_document_repository",
"get_annotation_repository",
"get_dataset_repository",
"get_training_task_repository",
"get_model_version_repository",
"get_batch_upload_repository",
"AdminTokenDep",
"AdminDBDep",
"TokenRepoDep",
"DocumentRepoDep",
"AnnotationRepoDep",
"DatasetRepoDep",
"TrainingTaskRepoDep",
"ModelVersionRepoDep",
"BatchUploadRepoDep",
"RateLimiter",
"start_scheduler",
"stop_scheduler",
@@ -25,4 +55,7 @@ __all__ = [
"start_autolabel_scheduler",
"stop_autolabel_scheduler",
"get_autolabel_scheduler",
"TaskRunner",
"TaskStatus",
"TaskManager",
]

View File

@@ -1,40 +1,39 @@
"""
Admin Authentication
FastAPI dependencies for admin token authentication.
FastAPI dependencies for admin token authentication and repository access.
"""
import logging
from functools import lru_cache
from typing import Annotated
from fastapi import Depends, Header, HTTPException
from inference.data.admin_db import AdminDB
from inference.data.database import get_session_context
logger = logging.getLogger(__name__)
# Global AdminDB instance
_admin_db: AdminDB | None = None
from inference.data.repositories import (
TokenRepository,
DocumentRepository,
AnnotationRepository,
DatasetRepository,
TrainingTaskRepository,
ModelVersionRepository,
BatchUploadRepository,
)
def get_admin_db() -> AdminDB:
"""Get the AdminDB instance."""
global _admin_db
if _admin_db is None:
_admin_db = AdminDB()
return _admin_db
@lru_cache(maxsize=1)
def get_token_repository() -> TokenRepository:
"""Get the TokenRepository instance (thread-safe singleton)."""
return TokenRepository()
def reset_admin_db() -> None:
"""Reset the AdminDB instance (for testing)."""
global _admin_db
_admin_db = None
def reset_token_repository() -> None:
"""Reset the TokenRepository instance (for testing)."""
get_token_repository.cache_clear()
async def validate_admin_token(
x_admin_token: Annotated[str | None, Header()] = None,
admin_db: AdminDB = Depends(get_admin_db),
token_repo: TokenRepository = Depends(get_token_repository),
) -> str:
"""Validate admin token from header."""
if not x_admin_token:
@@ -43,18 +42,74 @@ async def validate_admin_token(
detail="Admin token required. Provide X-Admin-Token header.",
)
if not admin_db.is_valid_admin_token(x_admin_token):
if not token_repo.is_valid(x_admin_token):
raise HTTPException(
status_code=401,
detail="Invalid or expired admin token.",
)
# Update last used timestamp
admin_db.update_admin_token_usage(x_admin_token)
token_repo.update_usage(x_admin_token)
return x_admin_token
# Type alias for dependency injection
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
TokenRepoDep = Annotated[TokenRepository, Depends(get_token_repository)]
@lru_cache(maxsize=1)
def get_document_repository() -> DocumentRepository:
"""Get the DocumentRepository instance (thread-safe singleton)."""
return DocumentRepository()
@lru_cache(maxsize=1)
def get_annotation_repository() -> AnnotationRepository:
"""Get the AnnotationRepository instance (thread-safe singleton)."""
return AnnotationRepository()
@lru_cache(maxsize=1)
def get_dataset_repository() -> DatasetRepository:
"""Get the DatasetRepository instance (thread-safe singleton)."""
return DatasetRepository()
@lru_cache(maxsize=1)
def get_training_task_repository() -> TrainingTaskRepository:
"""Get the TrainingTaskRepository instance (thread-safe singleton)."""
return TrainingTaskRepository()
@lru_cache(maxsize=1)
def get_model_version_repository() -> ModelVersionRepository:
"""Get the ModelVersionRepository instance (thread-safe singleton)."""
return ModelVersionRepository()
@lru_cache(maxsize=1)
def get_batch_upload_repository() -> BatchUploadRepository:
"""Get the BatchUploadRepository instance (thread-safe singleton)."""
return BatchUploadRepository()
def reset_all_repositories() -> None:
"""Reset all repository instances (for testing)."""
get_token_repository.cache_clear()
get_document_repository.cache_clear()
get_annotation_repository.cache_clear()
get_dataset_repository.cache_clear()
get_training_task_repository.cache_clear()
get_model_version_repository.cache_clear()
get_batch_upload_repository.cache_clear()
# Repository dependency type aliases
DocumentRepoDep = Annotated[DocumentRepository, Depends(get_document_repository)]
AnnotationRepoDep = Annotated[AnnotationRepository, Depends(get_annotation_repository)]
DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)]
TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)]
ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)]
BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)]
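
A minimal sketch of a route consuming these dependency aliases; the router prefix and endpoint here are illustrative only, while get_by_token and get_for_document are the methods the document routes in this commit already call:

from fastapi import APIRouter, HTTPException

from inference.web.core.auth import AdminTokenDep, AnnotationRepoDep, DocumentRepoDep

example_router = APIRouter(prefix="/api/v1/admin/example", tags=["example"])


@example_router.get("/{document_id}/annotation-count")
async def get_annotation_count(
    document_id: str,
    admin_token: AdminTokenDep,      # validated against TokenRepository; usage timestamp updated
    docs: DocumentRepoDep,           # lru_cache-backed DocumentRepository singleton
    annotations: AnnotationRepoDep,  # lru_cache-backed AnnotationRepository singleton
) -> dict:
    # Ownership check mirrors the document routes above.
    document = docs.get_by_token(document_id, admin_token)
    if document is None:
        raise HTTPException(status_code=404, detail="Document not found")
    return {
        "document_id": document_id,
        "annotation_count": len(annotations.get_for_document(document_id)),
    }

Tests can clear the cached singletons via reset_all_repositories(), or override an individual getter (for example get_document_repository) through FastAPI's dependency_overrides.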

View File

@@ -8,7 +8,8 @@ import logging
import threading
from pathlib import Path
from inference.data.admin_db import AdminDB
from inference.data.repositories import DocumentRepository, AnnotationRepository
from inference.web.core.task_interface import TaskRunner, TaskStatus
from inference.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
@@ -18,7 +19,7 @@ from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
class AutoLabelScheduler:
class AutoLabelScheduler(TaskRunner):
"""Scheduler for auto-labeling tasks."""
def __init__(
@@ -47,39 +48,73 @@ class AutoLabelScheduler:
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
self._lock = threading.Lock()
self._doc_repo = DocumentRepository()
self._ann_repo = AnnotationRepository()
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("AutoLabel scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("AutoLabel scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("AutoLabel scheduler stopped")
@property
def name(self) -> str:
"""Unique identifier for this runner."""
return "autolabel_scheduler"
@property
def is_running(self) -> bool:
"""Check if scheduler is running."""
return self._running
def get_status(self) -> TaskStatus:
"""Get current status of the scheduler."""
try:
pending_docs = get_pending_autolabel_documents(limit=1000)
pending_count = len(pending_docs)
except Exception:
pending_count = 0
return TaskStatus(
name=self.name,
is_running=self._running,
pending_count=pending_count,
processing_count=1 if self._running else 0,
)
def start(self) -> None:
"""Start the scheduler."""
with self._lock:
if self._running:
logger.warning("AutoLabel scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("AutoLabel scheduler started")
def stop(self, timeout: float | None = None) -> None:
"""Stop the scheduler.
Args:
timeout: Maximum time to wait for graceful shutdown.
                If None, a default of 5 seconds is used.
"""
# Minimize lock scope to avoid potential deadlock
with self._lock:
if not self._running:
return
self._running = False
self._stop_event.set()
thread_to_join = self._thread
effective_timeout = timeout if timeout is not None else 5.0
if thread_to_join:
thread_to_join.join(timeout=effective_timeout)
with self._lock:
self._thread = None
logger.info("AutoLabel scheduler stopped")
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
@@ -94,9 +129,7 @@ class AutoLabelScheduler:
def _process_pending_documents(self) -> None:
"""Check and process pending auto-label documents."""
try:
documents = get_pending_autolabel_documents(
self._db, limit=self._batch_size
)
documents = get_pending_autolabel_documents(limit=self._batch_size)
if not documents:
return
@@ -110,8 +143,9 @@ class AutoLabelScheduler:
try:
result = process_document_autolabel(
document=doc,
db=self._db,
output_dir=self._output_dir,
doc_repo=self._doc_repo,
ann_repo=self._ann_repo,
)
if result.get("success"):
@@ -136,13 +170,21 @@ class AutoLabelScheduler:
# Global scheduler instance
_autolabel_scheduler: AutoLabelScheduler | None = None
_autolabel_lock = threading.Lock()
def get_autolabel_scheduler() -> AutoLabelScheduler:
"""Get the auto-label scheduler instance."""
"""Get the auto-label scheduler instance.
    Uses a double-checked locking pattern for thread safety.
"""
global _autolabel_scheduler
if _autolabel_scheduler is None:
_autolabel_scheduler = AutoLabelScheduler()
with _autolabel_lock:
if _autolabel_scheduler is None:
_autolabel_scheduler = AutoLabelScheduler()
return _autolabel_scheduler

View File

@@ -10,13 +10,20 @@ from datetime import datetime
from pathlib import Path
from typing import Any
from inference.data.admin_db import AdminDB
from inference.data.repositories import (
TrainingTaskRepository,
DatasetRepository,
ModelVersionRepository,
DocumentRepository,
AnnotationRepository,
)
from inference.web.core.task_interface import TaskRunner, TaskStatus
from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
class TrainingScheduler:
class TrainingScheduler(TaskRunner):
"""Scheduler for training tasks."""
def __init__(
@@ -33,30 +40,73 @@ class TrainingScheduler:
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
self._lock = threading.Lock()
# Repositories
self._training_tasks = TrainingTaskRepository()
self._datasets = DatasetRepository()
self._model_versions = ModelVersionRepository()
self._documents = DocumentRepository()
self._annotations = AnnotationRepository()
@property
def name(self) -> str:
"""Unique identifier for this runner."""
return "training_scheduler"
@property
def is_running(self) -> bool:
"""Check if the scheduler is currently active."""
return self._running
def get_status(self) -> TaskStatus:
"""Get current status of the scheduler."""
try:
pending_tasks = self._training_tasks.get_pending()
pending_count = len(pending_tasks)
except Exception:
pending_count = 0
return TaskStatus(
name=self.name,
is_running=self._running,
pending_count=pending_count,
processing_count=1 if self._running else 0,
)
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("Training scheduler already running")
return
with self._lock:
if self._running:
logger.warning("Training scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("Training scheduler started")
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("Training scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
def stop(self, timeout: float | None = None) -> None:
"""Stop the scheduler.
self._running = False
self._stop_event.set()
Args:
timeout: Maximum time to wait for graceful shutdown.
                If None, a default of 5 seconds is used.
"""
# Minimize lock scope to avoid potential deadlock
with self._lock:
if not self._running:
return
if self._thread:
self._thread.join(timeout=5)
self._running = False
self._stop_event.set()
thread_to_join = self._thread
effective_timeout = timeout if timeout is not None else 5.0
if thread_to_join:
thread_to_join.join(timeout=effective_timeout)
with self._lock:
self._thread = None
logger.info("Training scheduler stopped")
@@ -75,7 +125,7 @@ class TrainingScheduler:
def _check_pending_tasks(self) -> None:
"""Check and execute pending training tasks."""
try:
tasks = self._db.get_pending_training_tasks()
tasks = self._training_tasks.get_pending()
for task in tasks:
task_id = str(task.task_id)
@@ -91,7 +141,7 @@ class TrainingScheduler:
self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.update_training_task_status(
self._training_tasks.update_status(
task_id=task_id,
status="failed",
error_message=str(e),
@@ -105,12 +155,12 @@ class TrainingScheduler:
) -> None:
"""Execute a training task."""
# Update status to running
self._db.update_training_task_status(task_id, "running")
self._db.add_training_log(task_id, "INFO", "Training task started")
self._training_tasks.update_status(task_id, "running")
self._training_tasks.add_log(task_id, "INFO", "Training task started")
# Update dataset training status to running
if dataset_id:
self._db.update_dataset_training_status(
self._datasets.update_training_status(
dataset_id,
training_status="running",
active_training_task_id=task_id,
@@ -137,7 +187,7 @@ class TrainingScheduler:
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
)
@@ -147,12 +197,12 @@ class TrainingScheduler:
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
dataset = self._datasets.get(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
dataset_path = Path(dataset.dataset_path)
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
)
@@ -162,7 +212,7 @@ class TrainingScheduler:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
dataset_path = Path(data_yaml).parent
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
@@ -173,7 +223,7 @@ class TrainingScheduler:
task_id, dataset_path, augmentation_config, augmentation_multiplier
)
if aug_result:
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "INFO",
f"Augmentation complete: {aug_result['augmented_images']} new images "
f"(total: {aug_result['total_images']})",
@@ -193,17 +243,17 @@ class TrainingScheduler:
)
# Update task with results
self._db.update_training_task_status(
self._training_tasks.update_status(
task_id=task_id,
status="completed",
result_metrics=result.get("metrics"),
model_path=result.get("model_path"),
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
self._training_tasks.add_log(task_id, "INFO", "Training completed successfully")
# Update dataset training status to completed and main status to trained
if dataset_id:
self._db.update_dataset_training_status(
self._datasets.update_training_status(
dataset_id,
training_status="completed",
active_training_task_id=None,
@@ -220,10 +270,10 @@ class TrainingScheduler:
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}")
# Update dataset training status to failed
if dataset_id:
self._db.update_dataset_training_status(
self._datasets.update_training_status(
dataset_id,
training_status="failed",
active_training_task_id=None,
@@ -245,11 +295,11 @@ class TrainingScheduler:
return
# Get task info for name
task = self._db.get_training_task(task_id)
task = self._training_tasks.get(task_id)
task_name = task.name if task else f"Task {task_id[:8]}"
# Generate version number based on existing versions
existing_versions = self._db.get_model_versions(limit=1, offset=0)
existing_versions = self._model_versions.get_paginated(limit=1, offset=0)
version_count = existing_versions[1] if existing_versions else 0
version = f"v{version_count + 1}.0"
@@ -268,12 +318,12 @@ class TrainingScheduler:
# Get document count from dataset if available
document_count = 0
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
dataset = self._datasets.get(dataset_id)
if dataset:
document_count = dataset.total_documents
# Create model version
model_version = self._db.create_model_version(
model_version = self._model_versions.create(
version=version,
name=task_name,
model_path=str(model_path),
@@ -294,14 +344,14 @@ class TrainingScheduler:
f"from training task {task_id}"
)
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "INFO",
f"Model version {version} created (mAP: {mAP_display})",
)
except Exception as e:
logger.error(f"Failed to create model version for task {task_id}: {e}")
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "WARNING",
f"Failed to auto-create model version: {e}",
)
@@ -316,16 +366,16 @@ class TrainingScheduler:
storage = get_storage_helper()
# Get all labeled documents
documents = self._db.get_labeled_documents_for_export()
documents = self._documents.get_labeled_for_export()
if not documents:
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
self._training_tasks.add_log(task_id, "ERROR", "No labeled documents available")
return None
# Create export directory using StorageHelper
training_base = storage.get_training_data_path()
if training_base is None:
self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access")
self._training_tasks.add_log(task_id, "ERROR", "Storage not configured for local access")
return None
export_dir = training_base / task_id
export_dir.mkdir(parents=True, exist_ok=True)
@@ -348,7 +398,7 @@ class TrainingScheduler:
# Export documents
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = self._db.get_annotations_for_document(str(doc.document_id))
annotations = self._annotations.get_for_document(str(doc.document_id))
if not annotations:
continue
@@ -412,7 +462,7 @@ names: {list(FIELD_CLASSES.values())}
# Create log callback that writes to DB
def log_callback(level: str, message: str) -> None:
self._db.add_training_log(task_id, level, message)
self._training_tasks.add_log(task_id, level, message)
# Create shared training config
# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
@@ -468,7 +518,7 @@ names: {list(FIELD_CLASSES.values())}
try:
from shared.augmentation import DatasetAugmenter
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "INFO",
f"Applying augmentation with multiplier={multiplier}",
)
@@ -480,7 +530,7 @@ names: {list(FIELD_CLASSES.values())}
except Exception as e:
logger.error(f"Augmentation failed for task {task_id}: {e}")
self._db.add_training_log(
self._training_tasks.add_log(
task_id, "WARNING",
f"Augmentation failed: {e}. Continuing with original dataset.",
)
@@ -489,13 +539,21 @@ names: {list(FIELD_CLASSES.values())}
# Global scheduler instance
_scheduler: TrainingScheduler | None = None
_scheduler_lock = threading.Lock()
def get_training_scheduler() -> TrainingScheduler:
"""Get the training scheduler instance."""
"""Get the training scheduler instance.
Uses double-checked locking pattern for thread safety.
"""
global _scheduler
if _scheduler is None:
_scheduler = TrainingScheduler()
with _scheduler_lock:
if _scheduler is None:
_scheduler = TrainingScheduler()
return _scheduler

View File

@@ -0,0 +1,161 @@
"""Unified task management interface.
Provides abstract base class for all task runners (schedulers and queues)
and a TaskManager facade for unified lifecycle management.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass(frozen=True)
class TaskStatus:
"""Status of a task runner.
Attributes:
name: Unique identifier for the runner.
is_running: Whether the runner is currently active.
pending_count: Number of tasks waiting to be processed.
processing_count: Number of tasks currently being processed.
error: Optional error message if runner is in error state.
"""
name: str
is_running: bool
pending_count: int
processing_count: int
error: str | None = None
class TaskRunner(ABC):
"""Abstract base class for all task runners.
All schedulers and task queues should implement this interface
to enable unified lifecycle management and monitoring.
Note:
Implementations may have different `start()` signatures based on
their initialization needs (e.g., handler functions, services).
Use the implementation-specific start methods for initialization,
and use TaskManager for unified status monitoring.
"""
@property
@abstractmethod
def name(self) -> str:
"""Unique identifier for this runner."""
pass
@abstractmethod
def start(self, *args, **kwargs) -> None:
"""Start the task runner.
Should be idempotent - calling start on an already running
runner should have no effect.
Note:
Implementations may require additional parameters (handlers,
services). See implementation-specific documentation.
"""
pass
@abstractmethod
def stop(self, timeout: float | None = None) -> None:
"""Stop the task runner gracefully.
Args:
timeout: Maximum time to wait for graceful shutdown in seconds.
                If None, the implementation default is used.
"""
pass
@property
@abstractmethod
def is_running(self) -> bool:
"""Check if the runner is currently active."""
pass
@abstractmethod
def get_status(self) -> TaskStatus:
"""Get current status of the runner.
Returns:
TaskStatus with current state information.
"""
pass
class TaskManager:
"""Unified manager for all task runners.
Provides centralized lifecycle management and monitoring
for all registered task runners.
"""
def __init__(self) -> None:
"""Initialize the task manager."""
self._runners: dict[str, TaskRunner] = {}
def register(self, runner: TaskRunner) -> None:
"""Register a task runner.
Args:
runner: TaskRunner instance to register.
"""
self._runners[runner.name] = runner
def get_runner(self, name: str) -> TaskRunner | None:
"""Get a specific runner by name.
Args:
name: Name of the runner to retrieve.
Returns:
TaskRunner if found, None otherwise.
"""
return self._runners.get(name)
@property
def runner_names(self) -> list[str]:
"""Get names of all registered runners.
Returns:
List of runner names.
"""
return list(self._runners.keys())
def start_all(self) -> None:
"""Start all registered runners that support no-argument start.
Note:
Runners requiring initialization parameters (like AsyncTaskQueue
or BatchTaskQueue) should be started individually before
registering with TaskManager.
"""
for runner in self._runners.values():
try:
runner.start()
except TypeError:
# Runner requires arguments - skip (should be started individually)
pass
def stop_all(self, timeout: float = 30.0) -> None:
"""Stop all registered runners gracefully.
Args:
timeout: Total timeout to distribute across all runners.
"""
if not self._runners:
return
per_runner_timeout = timeout / len(self._runners)
for runner in self._runners.values():
runner.stop(timeout=per_runner_timeout)
def get_all_status(self) -> dict[str, TaskStatus]:
"""Get status of all registered runners.
Returns:
Dict mapping runner names to their status.
"""
return {name: runner.get_status() for name, runner in self._runners.items()}
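
The TaskRunner implementations in this commit are TrainingScheduler ("training_scheduler") and AutoLabelScheduler ("autolabel_scheduler"). An illustrative wiring of TaskManager with those two runners; where this code runs (for example in the app's startup and shutdown hooks) is an assumption, not something the commit itself adds:

from inference.web.core.autolabel_scheduler import get_autolabel_scheduler
from inference.web.core.scheduler import get_training_scheduler
from inference.web.core.task_interface import TaskManager

manager = TaskManager()
manager.register(get_training_scheduler())   # registered under "training_scheduler"
manager.register(get_autolabel_scheduler())  # registered under "autolabel_scheduler"

# Both schedulers expose a no-argument start(), so start_all() covers them;
# queue-style runners that need handlers would be started individually first.
manager.start_all()

# Aggregate status, e.g. for a health or admin endpoint.
for name, status in manager.get_all_status().items():
    print(f"{name}: running={status.is_running}, pending={status.pending_count}")

# Graceful shutdown: the 30-second budget is split evenly across registered runners.
manager.stop_all(timeout=30.0)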

View File

@@ -11,7 +11,7 @@ import numpy as np
from fastapi import HTTPException
from PIL import Image
from inference.data.admin_db import AdminDB
from inference.data.repositories import DocumentRepository, DatasetRepository
from inference.web.schemas.admin.augmentation import (
AugmentationBatchResponse,
AugmentationConfigSchema,
@@ -32,9 +32,14 @@ UUID_PATTERN = re.compile(
class AugmentationService:
"""Service for augmentation operations."""
def __init__(self, db: AdminDB) -> None:
"""Initialize service with database connection."""
self.db = db
def __init__(
self,
doc_repo: DocumentRepository | None = None,
dataset_repo: DatasetRepository | None = None,
) -> None:
"""Initialize service with repository connections."""
self.doc_repo = doc_repo or DocumentRepository()
self.dataset_repo = dataset_repo or DatasetRepository()
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
"""
@@ -179,7 +184,7 @@ class AugmentationService:
"""
# Validate source dataset exists
try:
source_dataset = self.db.get_dataset(source_dataset_id)
source_dataset = self.dataset_repo.get(source_dataset_id)
if source_dataset is None:
raise HTTPException(
status_code=404,
@@ -259,7 +264,7 @@ class AugmentationService:
# Get document from database
try:
document = self.db.get_document(document_id)
document = self.doc_repo.get(document_id)
if document is None:
raise HTTPException(
status_code=404,
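
A short sketch of the two wiring modes the new constructor allows; the service import path and the MissingDatasetRepo stub are assumptions for illustration, not part of the change:

from inference.web.services.augmentation import AugmentationService  # assumed module path


class MissingDatasetRepo:
    """Duck-typed stand-in providing only the .get() the service calls (hypothetical)."""

    def get(self, dataset_id):
        return None  # drives the 404 "source dataset not found" branch


# Default wiring: repositories are created internally.
service = AugmentationService()

# Test wiring: substitute the dataset lookup without touching the database layer.
service_under_test = AugmentationService(dataset_repo=MissingDatasetRepo())  # type: ignore[arg-type]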

View File

@@ -12,7 +12,7 @@ import numpy as np
from PIL import Image
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.repositories import DocumentRepository, AnnotationRepository
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken
@@ -45,7 +45,8 @@ class AutoLabelService:
document_id: str,
file_path: str,
field_values: dict[str, str],
db: AdminDB,
doc_repo: DocumentRepository | None = None,
ann_repo: AnnotationRepository | None = None,
replace_existing: bool = False,
skip_lock_check: bool = False,
) -> dict[str, Any]:
@@ -56,16 +57,23 @@ class AutoLabelService:
document_id: Document UUID
file_path: Path to document file
field_values: Dict of field_name -> value to match
db: Admin database instance
doc_repo: Document repository (created if None)
ann_repo: Annotation repository (created if None)
replace_existing: Whether to replace existing auto annotations
skip_lock_check: Skip annotation lock check (for batch processing)
Returns:
Dict with status and annotation count
"""
# Initialize repositories if not provided
if doc_repo is None:
doc_repo = DocumentRepository()
if ann_repo is None:
ann_repo = AnnotationRepository()
try:
# Get document info first
document = db.get_document(document_id)
document = doc_repo.get(document_id)
if document is None:
raise ValueError(f"Document not found: {document_id}")
@@ -80,7 +88,7 @@ class AutoLabelService:
)
# Update status to running
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
@@ -88,7 +96,7 @@ class AutoLabelService:
# Delete existing auto annotations if requested
if replace_existing:
deleted = db.delete_annotations_for_document(
deleted = ann_repo.delete_for_document(
document_id=document_id,
source="auto",
)
@@ -101,17 +109,17 @@ class AutoLabelService:
if path.suffix.lower() == ".pdf":
# Process PDF (all pages)
annotations_created = self._process_pdf(
document_id, path, field_values, db
document_id, path, field_values, ann_repo
)
else:
# Process single image
annotations_created = self._process_image(
document_id, path, field_values, db, page_number=1
document_id, path, field_values, ann_repo, page_number=1
)
# Update document status
status = "labeled" if annotations_created > 0 else "pending"
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status=status,
auto_label_status="completed",
@@ -124,7 +132,7 @@ class AutoLabelService:
except Exception as e:
logger.error(f"Auto-labeling failed for {document_id}: {e}")
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
@@ -141,7 +149,7 @@ class AutoLabelService:
document_id: str,
pdf_path: Path,
field_values: dict[str, str],
db: AdminDB,
ann_repo: AnnotationRepository,
) -> int:
"""Process PDF document and create annotations."""
from shared.pdf.renderer import render_pdf_to_images
@@ -172,7 +180,7 @@ class AutoLabelService:
# Save annotations
if annotations:
db.create_annotations_batch(annotations)
ann_repo.create_batch(annotations)
total_annotations += len(annotations)
return total_annotations
@@ -182,7 +190,7 @@ class AutoLabelService:
document_id: str,
image_path: Path,
field_values: dict[str, str],
db: AdminDB,
ann_repo: AnnotationRepository,
page_number: int = 1,
) -> int:
"""Process single image and create annotations."""
@@ -208,7 +216,7 @@ class AutoLabelService:
# Save annotations
if annotations:
db.create_annotations_batch(annotations)
ann_repo.create_batch(annotations)
return len(annotations)

View File

@@ -15,7 +15,7 @@ from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from inference.data.admin_db import AdminDB
from inference.data.repositories import BatchUploadRepository
from shared.fields import CSV_TO_CLASS_MAPPING
logger = logging.getLogger(__name__)
@@ -64,13 +64,13 @@ class CSVRowData(BaseModel):
class BatchUploadService:
"""Service for handling batch uploads of documents via ZIP files."""
def __init__(self, admin_db: AdminDB):
def __init__(self, batch_repo: BatchUploadRepository | None = None):
"""Initialize the batch upload service.
Args:
admin_db: Admin database interface
batch_repo: Batch upload repository (created if None)
"""
self.admin_db = admin_db
self.batch_repo = batch_repo or BatchUploadRepository()
def _safe_extract_filename(self, zip_path: str) -> str:
"""Safely extract filename from ZIP path, preventing path traversal.
@@ -170,7 +170,7 @@ class BatchUploadService:
Returns:
Dictionary with batch upload results
"""
batch = self.admin_db.create_batch_upload(
batch = self.batch_repo.create(
admin_token=admin_token,
filename=zip_filename,
file_size=len(zip_content),
@@ -189,7 +189,7 @@ class BatchUploadService:
)
# Update batch upload status
self.admin_db.update_batch_upload(
self.batch_repo.update(
batch_id=batch.batch_id,
status=result["status"],
total_files=result["total_files"],
@@ -208,7 +208,7 @@ class BatchUploadService:
except zipfile.BadZipFile as e:
logger.error(f"Invalid ZIP file {zip_filename}: {e}")
self.admin_db.update_batch_upload(
self.batch_repo.update(
batch_id=batch.batch_id,
status="failed",
error_message="Invalid ZIP file format",
@@ -222,7 +222,7 @@ class BatchUploadService:
except ValueError as e:
# Security validation errors
logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
self.admin_db.update_batch_upload(
self.batch_repo.update(
batch_id=batch.batch_id,
status="failed",
error_message="ZIP file validation failed",
@@ -235,7 +235,7 @@ class BatchUploadService:
}
except Exception as e:
logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
self.admin_db.update_batch_upload(
self.batch_repo.update(
batch_id=batch.batch_id,
status="failed",
error_message="Processing error",
@@ -312,7 +312,7 @@ class BatchUploadService:
filename = self._safe_extract_filename(pdf_info.filename)
# Create batch upload file record
file_record = self.admin_db.create_batch_upload_file(
file_record = self.batch_repo.create_file(
batch_id=batch_id,
filename=filename,
status="processing",
@@ -328,7 +328,7 @@ class BatchUploadService:
# TODO: Save PDF file and create document
# For now, just mark as completed
self.admin_db.update_batch_upload_file(
self.batch_repo.update_file(
file_id=file_record.file_id,
status="completed",
csv_row_data=csv_row_data,
@@ -341,7 +341,7 @@ class BatchUploadService:
# Path validation error
logger.warning(f"Skipping invalid file: {e}")
if file_record:
self.admin_db.update_batch_upload_file(
self.batch_repo.update_file(
file_id=file_record.file_id,
status="failed",
error_message="Invalid filename",
@@ -352,7 +352,7 @@ class BatchUploadService:
except Exception as e:
logger.error(f"Error processing PDF: {e}", exc_info=True)
if file_record:
self.admin_db.update_batch_upload_file(
self.batch_repo.update_file(
file_id=file_record.file_id,
status="failed",
error_message="Processing error",
@@ -515,13 +515,13 @@ class BatchUploadService:
Returns:
Batch status dictionary
"""
batch = self.admin_db.get_batch_upload(UUID(batch_id))
batch = self.batch_repo.get(UUID(batch_id))
if not batch:
return {
"error": "Batch upload not found",
}
files = self.admin_db.get_batch_upload_files(batch.batch_id)
files = self.batch_repo.get_files(batch.batch_id)
return {
"batch_id": str(batch.batch_id),

View File

@@ -20,8 +20,16 @@ logger = logging.getLogger(__name__)
class DatasetBuilder:
"""Builds YOLO training datasets from admin documents."""
def __init__(self, db, base_dir: Path):
self._db = db
def __init__(
self,
datasets_repo,
documents_repo,
annotations_repo,
base_dir: Path,
):
self._datasets_repo = datasets_repo
self._documents_repo = documents_repo
self._annotations_repo = annotations_repo
self._base_dir = Path(base_dir)
def build_dataset(
@@ -54,7 +62,7 @@ class DatasetBuilder:
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
)
except Exception as e:
self._db.update_dataset_status(
self._datasets_repo.update_status(
dataset_id=dataset_id,
status="failed",
error_message=str(e),
@@ -71,7 +79,7 @@ class DatasetBuilder:
admin_images_dir: Path,
) -> dict:
# 1. Fetch documents
documents = self._db.get_documents_by_ids(document_ids)
documents = self._documents_repo.get_by_ids(document_ids)
if not documents:
raise ValueError("No valid documents found for the given IDs")
@@ -93,7 +101,7 @@ class DatasetBuilder:
for doc in doc_list:
doc_id = str(doc.document_id)
split = doc_splits[doc_id]
annotations = self._db.get_annotations_for_document(doc.document_id)
annotations = self._annotations_repo.get_for_document(str(doc.document_id))
# Group annotations by page
page_annotations: dict[int, list] = {}
@@ -139,7 +147,7 @@ class DatasetBuilder:
})
# 5. Record document-split assignments in DB
self._db.add_dataset_documents(
self._datasets_repo.add_documents(
dataset_id=dataset_id,
documents=dataset_docs,
)
@@ -148,7 +156,7 @@ class DatasetBuilder:
self._generate_data_yaml(dataset_dir)
# 7. Update dataset status
self._db.update_dataset_status(
self._datasets_repo.update_status(
dataset_id=dataset_id,
status="ready",
total_documents=len(doc_list),
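
A construction sketch for the new three-repository signature; the DatasetBuilder import path and the base directory are placeholders, not part of the change:

from pathlib import Path

from inference.data.repositories import (
    AnnotationRepository,
    DatasetRepository,
    DocumentRepository,
)
from training.datasets.builder import DatasetBuilder  # assumed module path

builder = DatasetBuilder(
    datasets_repo=DatasetRepository(),
    documents_repo=DocumentRepository(),
    annotations_repo=AnnotationRepository(),
    base_dir=Path("/data/yolo_datasets"),  # placeholder directory
)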

View File

@@ -12,9 +12,9 @@ from pathlib import Path
from typing import Any
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from shared.fields import CSV_TO_CLASS_MAPPING
from inference.data.admin_models import AdminDocument
from inference.data.repositories import DocumentRepository, AnnotationRepository
from shared.data.db import DocumentDB
from inference.web.services.storage_helpers import get_storage_helper
@@ -68,14 +68,12 @@ def convert_csv_field_values_to_row_dict(
def get_pending_autolabel_documents(
db: AdminDB,
limit: int = 10,
) -> list[AdminDocument]:
"""
Get documents pending auto-labeling.
Args:
db: AdminDB instance
limit: Maximum number of documents to return
Returns:
@@ -99,20 +97,22 @@ def get_pending_autolabel_documents(
def process_document_autolabel(
document: AdminDocument,
db: AdminDB,
output_dir: Path | None = None,
dpi: int = DEFAULT_DPI,
min_confidence: float = 0.5,
doc_repo: DocumentRepository | None = None,
ann_repo: AnnotationRepository | None = None,
) -> dict[str, Any]:
"""
Process a single document for auto-labeling using csv_field_values.
Args:
document: AdminDocument with csv_field_values and file_path
db: AdminDB instance for updating status
output_dir: Output directory for temp files
dpi: Rendering DPI
min_confidence: Minimum match confidence
doc_repo: Document repository (created if None)
ann_repo: Annotation repository (created if None)
Returns:
Result dictionary with success status and annotations
@@ -120,6 +120,12 @@ def process_document_autolabel(
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
from shared.pdf import PDFDocument
# Initialize repositories if not provided
if doc_repo is None:
doc_repo = DocumentRepository()
if ann_repo is None:
ann_repo = AnnotationRepository()
document_id = str(document.document_id)
file_path = Path(document.file_path)
@@ -132,7 +138,7 @@ def process_document_autolabel(
output_dir.mkdir(parents=True, exist_ok=True)
# Mark as processing
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
@@ -187,10 +193,10 @@ def process_document_autolabel(
except Exception as e:
logger.warning(f"Failed to save report to DocumentDB: {e}")
# Save annotations to AdminDB
# Save annotations to database
if result.get("success") and result.get("report"):
_save_annotations_to_db(
db=db,
ann_repo=ann_repo,
document_id=document_id,
report=result["report"],
page_annotations=result.get("pages", []),
@@ -198,7 +204,7 @@ def process_document_autolabel(
)
# Mark as completed
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status="labeled",
auto_label_status="completed",
@@ -206,7 +212,7 @@ def process_document_autolabel(
else:
# Mark as failed
errors = result.get("report", {}).get("errors", ["Unknown error"])
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
@@ -219,7 +225,7 @@ def process_document_autolabel(
logger.error(f"Error processing document {document_id}: {e}", exc_info=True)
# Mark as failed
db.update_document_status(
doc_repo.update_status(
document_id=document_id,
status="pending",
auto_label_status="failed",
@@ -234,7 +240,7 @@ def process_document_autolabel(
def _save_annotations_to_db(
db: AdminDB,
ann_repo: AnnotationRepository,
document_id: str,
report: dict[str, Any],
page_annotations: list[dict[str, Any]],
@@ -244,7 +250,7 @@ def _save_annotations_to_db(
Save generated annotations to database.
Args:
db: AdminDB instance
ann_repo: Annotation repository instance
document_id: Document ID
report: AutoLabelReport as dict
page_annotations: List of page annotation data
@@ -353,7 +359,7 @@ def _save_annotations_to_db(
# Create annotation
try:
db.create_annotation(
ann_repo.create(
document_id=document_id,
page_number=page_no,
class_id=class_id,
@@ -379,25 +385,29 @@ def _save_annotations_to_db(
def run_pending_autolabel_batch(
db: AdminDB | None = None,
batch_size: int = 10,
output_dir: Path | None = None,
doc_repo: DocumentRepository | None = None,
ann_repo: AnnotationRepository | None = None,
) -> dict[str, Any]:
"""
Process a batch of pending auto-label documents.
Args:
db: AdminDB instance (created if None)
batch_size: Number of documents to process
output_dir: Output directory for temp files
doc_repo: Document repository (created if None)
ann_repo: Annotation repository (created if None)
Returns:
Summary of processing results
"""
if db is None:
db = AdminDB()
if doc_repo is None:
doc_repo = DocumentRepository()
if ann_repo is None:
ann_repo = AnnotationRepository()
documents = get_pending_autolabel_documents(db, limit=batch_size)
documents = get_pending_autolabel_documents(limit=batch_size)
results = {
"total": len(documents),
@@ -409,8 +419,9 @@ def run_pending_autolabel_batch(
for doc in documents:
result = process_document_autolabel(
document=doc,
db=db,
output_dir=output_dir,
doc_repo=doc_repo,
ann_repo=ann_repo,
)
doc_result = {
@@ -432,7 +443,6 @@ def run_pending_autolabel_batch(
def save_manual_annotations_to_document_db(
document: AdminDocument,
annotations: list,
db: AdminDB,
) -> dict[str, Any]:
"""
Save manual annotations to PostgreSQL documents and field_results tables.
@@ -444,7 +454,6 @@ def save_manual_annotations_to_document_db(
Args:
document: AdminDocument instance
annotations: List of AdminAnnotation instances
db: AdminDB instance
Returns:
Dict with success status and details
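
A sketch of invoking the batch entry point after the change; the module path is an assumption, and both repositories fall back to fresh instances when omitted:

from pathlib import Path

from inference.web.services.autolabel_helpers import run_pending_autolabel_batch  # assumed path

summary = run_pending_autolabel_batch(
    batch_size=5,
    output_dir=Path("/tmp/autolabel"),  # placeholder temp directory
)
print(summary["total"], "documents picked up for auto-labeling")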

View File

@@ -14,6 +14,8 @@ import threading
from threading import Event, Lock, Thread
from typing import Callable
from inference.web.core.task_interface import TaskRunner, TaskStatus
logger = logging.getLogger(__name__)
@@ -29,7 +31,7 @@ class AsyncTask:
priority: int = 0 # Lower = higher priority (not implemented yet)
class AsyncTaskQueue:
class AsyncTaskQueue(TaskRunner):
"""Thread-safe queue for async invoice processing."""
def __init__(
@@ -46,44 +48,78 @@ class AsyncTaskQueue:
self._task_handler: Callable[[AsyncTask], None] | None = None
self._started = False
@property
def name(self) -> str:
"""Unique identifier for this runner."""
return "async_task_queue"
@property
def is_running(self) -> bool:
"""Check if the queue is running."""
return self._started and not self._stop_event.is_set()
def get_status(self) -> TaskStatus:
"""Get current status of the queue."""
with self._lock:
processing_count = len(self._processing)
return TaskStatus(
name=self.name,
is_running=self.is_running,
pending_count=self._queue.qsize(),
processing_count=processing_count,
)
def start(self, task_handler: Callable[[AsyncTask], None]) -> None:
"""Start background worker threads."""
if self._started:
logger.warning("AsyncTaskQueue already started")
return
with self._lock:
if self._started:
logger.warning("AsyncTaskQueue already started")
return
self._task_handler = task_handler
self._stop_event.clear()
self._task_handler = task_handler
self._stop_event.clear()
for i in range(self._worker_count):
worker = Thread(
target=self._worker_loop,
name=f"async-worker-{i}",
daemon=True,
)
worker.start()
self._workers.append(worker)
logger.info(f"Started async worker thread: {worker.name}")
for i in range(self._worker_count):
worker = Thread(
target=self._worker_loop,
name=f"async-worker-{i}",
daemon=True,
)
worker.start()
self._workers.append(worker)
logger.info(f"Started async worker thread: {worker.name}")
self._started = True
logger.info(f"AsyncTaskQueue started with {self._worker_count} workers")
self._started = True
logger.info(f"AsyncTaskQueue started with {self._worker_count} workers")
def stop(self, timeout: float = 30.0) -> None:
"""Gracefully stop all workers."""
if not self._started:
return
def stop(self, timeout: float | None = None) -> None:
"""Gracefully stop all workers.
logger.info("Stopping AsyncTaskQueue...")
self._stop_event.set()
Args:
timeout: Maximum time to wait for graceful shutdown.
If None, uses default of 30 seconds.
"""
# Minimize lock scope to avoid potential deadlock
with self._lock:
if not self._started:
return
# Wait for workers to finish
for worker in self._workers:
worker.join(timeout=timeout / self._worker_count)
logger.info("Stopping AsyncTaskQueue...")
self._stop_event.set()
workers_to_join = list(self._workers)
effective_timeout = timeout if timeout is not None else 30.0
# Wait for workers to finish outside the lock
for worker in workers_to_join:
worker.join(timeout=effective_timeout / self._worker_count)
if worker.is_alive():
logger.warning(f"Worker {worker.name} did not stop gracefully")
self._workers.clear()
self._started = False
with self._lock:
self._workers.clear()
self._started = False
logger.info("AsyncTaskQueue stopped")
def submit(self, task: AsyncTask) -> bool:
@@ -115,11 +151,6 @@ class AsyncTaskQueue:
with self._lock:
return request_id in self._processing
@property
def is_running(self) -> bool:
"""Check if the queue is running."""
return self._started and not self._stop_event.is_set()
def _worker_loop(self) -> None:
"""Worker loop that processes tasks from queue."""
thread_name = threading.current_thread().name
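
Because both queues now expose get_status() through the shared TaskRunner interface, a single helper can report on any of them; log_runner_status is a hypothetical utility, not part of the change:

import logging

from inference.web.core.task_interface import TaskRunner

logger = logging.getLogger(__name__)


def log_runner_status(runner: TaskRunner) -> None:
    """Log a one-line summary for any registered TaskRunner (hypothetical helper)."""
    status = runner.get_status()
    logger.info(
        "%s running=%s pending=%d processing=%d",
        status.name,
        status.is_running,
        status.pending_count,
        status.processing_count,
    )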

View File

@@ -12,6 +12,8 @@ from queue import Queue, Full, Empty
from typing import Any
from uuid import UUID
from inference.web.core.task_interface import TaskRunner, TaskStatus
logger = logging.getLogger(__name__)
@@ -28,7 +30,7 @@ class BatchTask:
created_at: datetime
class BatchTaskQueue:
class BatchTaskQueue(TaskRunner):
"""Thread-safe queue for async batch upload processing."""
def __init__(self, max_size: int = 20, worker_count: int = 2):
@@ -45,6 +47,29 @@ class BatchTaskQueue:
self._batch_service: Any | None = None
self._running = False
self._lock = threading.Lock()
self._processing: set[UUID] = set() # Currently processing batch_ids
@property
def name(self) -> str:
"""Unique identifier for this runner."""
return "batch_task_queue"
@property
def is_running(self) -> bool:
"""Check if queue is running."""
return self._running
def get_status(self) -> TaskStatus:
"""Get current status of the queue."""
with self._lock:
processing_count = len(self._processing)
return TaskStatus(
name=self.name,
is_running=self._running,
pending_count=self._queue.qsize(),
processing_count=processing_count,
)
def start(self, batch_service: Any) -> None:
"""Start worker threads with batch service.
@@ -73,12 +98,14 @@ class BatchTaskQueue:
logger.info(f"Started {self._worker_count} batch workers")
def stop(self, timeout: float = 30.0) -> None:
def stop(self, timeout: float | None = None) -> None:
"""Stop all worker threads gracefully.
Args:
timeout: Maximum time to wait for workers to finish
timeout: Maximum time to wait for workers to finish.
If None, uses default of 30 seconds.
"""
# Minimize lock scope to avoid potential deadlock
with self._lock:
if not self._running:
return
@@ -86,13 +113,17 @@ class BatchTaskQueue:
logger.info("Stopping batch queue...")
self._stop_event.set()
self._running = False
workers_to_join = list(self._workers)
# Wait for workers to finish
for worker in self._workers:
worker.join(timeout=timeout)
effective_timeout = timeout if timeout is not None else 30.0
# Wait for workers to finish outside the lock
for worker in workers_to_join:
worker.join(timeout=effective_timeout)
with self._lock:
self._workers.clear()
logger.info("Batch queue stopped")
logger.info("Batch queue stopped")
def submit(self, task: BatchTask) -> bool:
"""Submit a batch task to the queue.
@@ -119,15 +150,6 @@ class BatchTaskQueue:
"""
return self._queue.qsize()
@property
def is_running(self) -> bool:
"""Check if queue is running.
Returns:
True if queue is active
"""
return self._running
def _worker_loop(self) -> None:
"""Worker thread main loop."""
worker_name = threading.current_thread().name
@@ -157,6 +179,9 @@ class BatchTaskQueue:
logger.error("Batch service not initialized, cannot process task")
return
with self._lock:
self._processing.add(task.batch_id)
logger.info(
f"Processing batch task: batch_id={task.batch_id}, "
f"filename={task.zip_filename}"
@@ -183,6 +208,9 @@ class BatchTaskQueue:
f"Error processing batch task {task.batch_id}: {e}",
exc_info=True,
)
finally:
with self._lock:
self._processing.discard(task.batch_id)
# Global batch queue instance
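
A startup-wiring sketch tying the pieces together, with assumed module paths for the queue and service; runners that need start-time arguments are started before registration, so TaskManager.start_all() skips them via the TypeError path while stop_all() still shuts them down:

from inference.web.core.batch_queue import BatchTaskQueue            # assumed module path
from inference.web.core.task_interface import TaskManager
from inference.web.services.batch_upload import BatchUploadService   # assumed module path

manager = TaskManager()

batch_queue = BatchTaskQueue(max_size=20, worker_count=2)
batch_queue.start(BatchUploadService())  # needs its service, so start it here
manager.register(batch_queue)            # start_all() will skip it (TypeError path)

manager.start_all()                      # starts any no-argument runners
# ... application runs ...
manager.stop_all(timeout=30.0)           # graceful shutdown for every registered runner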