WIP
26
packages/inference/inference/data/repositories/__init__.py
Normal file
@@ -0,0 +1,26 @@
"""
Repository Pattern Implementation

Provides domain-specific repository classes to replace the monolithic AdminDB.
Each repository handles a single domain following Single Responsibility Principle.
"""

from inference.data.repositories.base import BaseRepository
from inference.data.repositories.token_repository import TokenRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.training_task_repository import TrainingTaskRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.model_version_repository import ModelVersionRepository
from inference.data.repositories.batch_upload_repository import BatchUploadRepository

__all__ = [
    "BaseRepository",
    "TokenRepository",
    "DocumentRepository",
    "AnnotationRepository",
    "TrainingTaskRepository",
    "DatasetRepository",
    "ModelVersionRepository",
    "BatchUploadRepository",
]
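A minimal usage sketch (editor's illustration, not part of this commit) of how a caller might compose these per-domain repositories in place of the old AdminDB facade; the DocumentService wrapper shown here is hypothetical:

from inference.data.repositories import AnnotationRepository, DocumentRepository

class DocumentService:  # hypothetical caller, not included in this commit
    """Coordinates per-domain repositories instead of one monolithic AdminDB."""

    def __init__(self) -> None:
        self.documents = DocumentRepository()
        self.annotations = AnnotationRepository()

    def annotate(self, document_id: str, boxes: list[dict]) -> list[str]:
        # Each repository opens and commits its own session context internally.
        ids = self.annotations.create_batch(boxes)
        self.documents.update_status(document_id, status="labeled")
        return ids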
@@ -0,0 +1,355 @@
"""
Annotation Repository

Handles annotation operations following Single Responsibility Principle.
"""

import logging
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlmodel import select

from inference.data.database import get_session_context
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.repositories.base import BaseRepository

logger = logging.getLogger(__name__)


class AnnotationRepository(BaseRepository[AdminAnnotation]):
    """Repository for annotation management.

    Handles:
    - Annotation CRUD operations
    - Batch annotation creation
    - Annotation verification
    - Annotation override tracking
    """

    def create(
        self,
        document_id: str,
        page_number: int,
        class_id: int,
        class_name: str,
        x_center: float,
        y_center: float,
        width: float,
        height: float,
        bbox_x: int,
        bbox_y: int,
        bbox_width: int,
        bbox_height: int,
        text_value: str | None = None,
        confidence: float | None = None,
        source: str = "manual",
    ) -> str:
        """Create a new annotation.

        Returns:
            Annotation ID as string
        """
        with get_session_context() as session:
            annotation = AdminAnnotation(
                document_id=UUID(document_id),
                page_number=page_number,
                class_id=class_id,
                class_name=class_name,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
                bbox_x=bbox_x,
                bbox_y=bbox_y,
                bbox_width=bbox_width,
                bbox_height=bbox_height,
                text_value=text_value,
                confidence=confidence,
                source=source,
            )
            session.add(annotation)
            session.flush()
            return str(annotation.annotation_id)

    def create_batch(
        self,
        annotations: list[dict[str, Any]],
    ) -> list[str]:
        """Create multiple annotations in a batch.

        Args:
            annotations: List of annotation data dicts

        Returns:
            List of annotation IDs
        """
        with get_session_context() as session:
            ids = []
            for ann_data in annotations:
                annotation = AdminAnnotation(
                    document_id=UUID(ann_data["document_id"]),
                    page_number=ann_data.get("page_number", 1),
                    class_id=ann_data["class_id"],
                    class_name=ann_data["class_name"],
                    x_center=ann_data["x_center"],
                    y_center=ann_data["y_center"],
                    width=ann_data["width"],
                    height=ann_data["height"],
                    bbox_x=ann_data["bbox_x"],
                    bbox_y=ann_data["bbox_y"],
                    bbox_width=ann_data["bbox_width"],
                    bbox_height=ann_data["bbox_height"],
                    text_value=ann_data.get("text_value"),
                    confidence=ann_data.get("confidence"),
                    source=ann_data.get("source", "auto"),
                )
                session.add(annotation)
                session.flush()
                ids.append(str(annotation.annotation_id))
            return ids

    def get(self, annotation_id: str) -> AdminAnnotation | None:
        """Get an annotation by ID."""
        with get_session_context() as session:
            result = session.get(AdminAnnotation, UUID(annotation_id))
            if result:
                session.expunge(result)
            return result

    def get_for_document(
        self,
        document_id: str,
        page_number: int | None = None,
    ) -> list[AdminAnnotation]:
        """Get all annotations for a document."""
        with get_session_context() as session:
            statement = select(AdminAnnotation).where(
                AdminAnnotation.document_id == UUID(document_id)
            )
            if page_number is not None:
                statement = statement.where(AdminAnnotation.page_number == page_number)
            statement = statement.order_by(AdminAnnotation.class_id)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def update(
        self,
        annotation_id: str,
        x_center: float | None = None,
        y_center: float | None = None,
        width: float | None = None,
        height: float | None = None,
        bbox_x: int | None = None,
        bbox_y: int | None = None,
        bbox_width: int | None = None,
        bbox_height: int | None = None,
        text_value: str | None = None,
        class_id: int | None = None,
        class_name: str | None = None,
    ) -> bool:
        """Update an annotation.

        Returns:
            True if updated, False if not found
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if annotation:
                if x_center is not None:
                    annotation.x_center = x_center
                if y_center is not None:
                    annotation.y_center = y_center
                if width is not None:
                    annotation.width = width
                if height is not None:
                    annotation.height = height
                if bbox_x is not None:
                    annotation.bbox_x = bbox_x
                if bbox_y is not None:
                    annotation.bbox_y = bbox_y
                if bbox_width is not None:
                    annotation.bbox_width = bbox_width
                if bbox_height is not None:
                    annotation.bbox_height = bbox_height
                if text_value is not None:
                    annotation.text_value = text_value
                if class_id is not None:
                    annotation.class_id = class_id
                if class_name is not None:
                    annotation.class_name = class_name
                annotation.updated_at = datetime.utcnow()
                session.add(annotation)
                return True
            return False

    def delete(self, annotation_id: str) -> bool:
        """Delete an annotation."""
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if annotation:
                session.delete(annotation)
                return True
            return False

    def delete_for_document(
        self,
        document_id: str,
        source: str | None = None,
    ) -> int:
        """Delete all annotations for a document.

        Returns:
            Count of deleted annotations
        """
        with get_session_context() as session:
            statement = select(AdminAnnotation).where(
                AdminAnnotation.document_id == UUID(document_id)
            )
            if source:
                statement = statement.where(AdminAnnotation.source == source)
            annotations = session.exec(statement).all()
            count = len(annotations)
            for ann in annotations:
                session.delete(ann)
            return count

    def verify(
        self,
        annotation_id: str,
        admin_token: str,
    ) -> AdminAnnotation | None:
        """Mark an annotation as verified."""
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if not annotation:
                return None

            annotation.is_verified = True
            annotation.verified_at = datetime.utcnow()
            annotation.verified_by = admin_token
            annotation.updated_at = datetime.utcnow()

            session.add(annotation)
            session.commit()
            session.refresh(annotation)
            session.expunge(annotation)
            return annotation

    def override(
        self,
        annotation_id: str,
        admin_token: str,
        change_reason: str | None = None,
        **updates: Any,
    ) -> AdminAnnotation | None:
        """Override an auto-generated annotation.

        Creates a history record and updates the annotation.
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if not annotation:
                return None

            previous_value = {
                "class_id": annotation.class_id,
                "class_name": annotation.class_name,
                "bbox": {
                    "x": annotation.bbox_x,
                    "y": annotation.bbox_y,
                    "width": annotation.bbox_width,
                    "height": annotation.bbox_height,
                },
                "normalized": {
                    "x_center": annotation.x_center,
                    "y_center": annotation.y_center,
                    "width": annotation.width,
                    "height": annotation.height,
                },
                "text_value": annotation.text_value,
                "confidence": annotation.confidence,
                "source": annotation.source,
            }

            for key, value in updates.items():
                if hasattr(annotation, key):
                    setattr(annotation, key, value)

            if annotation.source == "auto":
                annotation.override_source = "auto"
                annotation.source = "manual"

            annotation.updated_at = datetime.utcnow()
            session.add(annotation)

            history = AnnotationHistory(
                annotation_id=UUID(annotation_id),
                document_id=annotation.document_id,
                action="override",
                previous_value=previous_value,
                new_value=updates,
                changed_by=admin_token,
                change_reason=change_reason,
            )
            session.add(history)

            session.commit()
            session.refresh(annotation)
            session.expunge(annotation)
            return annotation

    def create_history(
        self,
        annotation_id: UUID,
        document_id: UUID,
        action: str,
        previous_value: dict[str, Any] | None = None,
        new_value: dict[str, Any] | None = None,
        changed_by: str | None = None,
        change_reason: str | None = None,
    ) -> AnnotationHistory:
        """Create an annotation history record."""
        with get_session_context() as session:
            history = AnnotationHistory(
                annotation_id=annotation_id,
                document_id=document_id,
                action=action,
                previous_value=previous_value,
                new_value=new_value,
                changed_by=changed_by,
                change_reason=change_reason,
            )
            session.add(history)
            session.commit()
            session.refresh(history)
            session.expunge(history)
            return history

    def get_history(self, annotation_id: UUID) -> list[AnnotationHistory]:
        """Get history for a specific annotation."""
        with get_session_context() as session:
            statement = select(AnnotationHistory).where(
                AnnotationHistory.annotation_id == annotation_id
            ).order_by(AnnotationHistory.created_at.desc())

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def get_document_history(self, document_id: UUID) -> list[AnnotationHistory]:
        """Get all annotation history for a document."""
        with get_session_context() as session:
            statement = select(AnnotationHistory).where(
                AnnotationHistory.document_id == document_id
            ).order_by(AnnotationHistory.created_at.desc())

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)
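A short sketch (editor's illustration, not part of this commit) of the override flow above: the repository snapshots the previous values, flips source from "auto" to "manual", and writes an AnnotationHistory row. The ID and token strings are placeholders:

repo = AnnotationRepository()

updated = repo.override(
    annotation_id="00000000-0000-0000-0000-000000000000",  # placeholder UUID string
    admin_token="admin-token",                              # placeholder token
    change_reason="bbox was clipped",
    bbox_x=120, bbox_y=80, bbox_width=240, bbox_height=32,  # keyword updates applied via setattr
)
if updated is not None:
    # History rows are ordered newest first; the override record is at index 0.
    history = repo.get_history(updated.annotation_id)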
75
packages/inference/inference/data/repositories/base.py
Normal file
@@ -0,0 +1,75 @@
"""
Base Repository

Provides common functionality for all repositories.
"""

import logging
from abc import ABC
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Generator, TypeVar, Generic
from uuid import UUID

from sqlmodel import Session

from inference.data.database import get_session_context

logger = logging.getLogger(__name__)

T = TypeVar("T")


class BaseRepository(ABC, Generic[T]):
    """Base class for all repositories.

    Provides:
    - Session management via context manager
    - Logging infrastructure
    - Common query patterns
    - Utility methods for datetime and UUID handling
    """

    @contextmanager
    def _session(self) -> Generator[Session, None, None]:
        """Get a database session with auto-commit/rollback."""
        with get_session_context() as session:
            yield session

    def _expunge(self, session: Session, entity: T) -> T:
        """Detach entity from session for safe return."""
        session.expunge(entity)
        return entity

    def _expunge_all(self, session: Session, entities: list[T]) -> list[T]:
        """Detach multiple entities from session."""
        for entity in entities:
            session.expunge(entity)
        return entities

    @staticmethod
    def _now() -> datetime:
        """Get current UTC time as timezone-aware datetime.

        Use this instead of datetime.utcnow() which is deprecated in Python 3.12+.
        """
        return datetime.now(timezone.utc)

    @staticmethod
    def _validate_uuid(value: str, field_name: str = "id") -> UUID:
        """Validate and convert string to UUID.

        Args:
            value: String to convert to UUID
            field_name: Name of field for error message

        Returns:
            Validated UUID

        Raises:
            ValueError: If value is not a valid UUID
        """
        try:
            return UUID(value)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Invalid {field_name}: {value}") from e
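A minimal sketch (editor's illustration) of how a concrete repository could lean on the base helpers; NoteRepository and AdminNote are hypothetical stand-ins, not models from this commit:

class NoteRepository(BaseRepository["AdminNote"]):  # AdminNote is a hypothetical SQLModel table
    def touch(self, note_id: str) -> bool:
        nid = self._validate_uuid(note_id, field_name="note_id")  # raises ValueError on bad input
        with self._session() as session:        # commit/rollback handled by the base context manager
            note = session.get(AdminNote, nid)  # hypothetical model lookup
            if note is None:
                return False
            note.updated_at = self._now()       # timezone-aware; avoids deprecated utcnow()
            session.add(note)
            return True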
@@ -0,0 +1,136 @@
"""
Batch Upload Repository

Handles batch upload operations following Single Responsibility Principle.
"""

import logging
from typing import Any
from uuid import UUID

from sqlalchemy import func
from sqlmodel import select

from inference.data.database import get_session_context
from inference.data.admin_models import BatchUpload, BatchUploadFile
from inference.data.repositories.base import BaseRepository

logger = logging.getLogger(__name__)


class BatchUploadRepository(BaseRepository[BatchUpload]):
    """Repository for batch upload management.

    Handles:
    - Batch upload CRUD operations
    - Batch file tracking
    - Progress monitoring
    """

    def create(
        self,
        admin_token: str,
        filename: str,
        file_size: int,
        upload_source: str = "ui",
    ) -> BatchUpload:
        """Create a new batch upload record."""
        with get_session_context() as session:
            batch = BatchUpload(
                admin_token=admin_token,
                filename=filename,
                file_size=file_size,
                upload_source=upload_source,
            )
            session.add(batch)
            session.commit()
            session.refresh(batch)
            session.expunge(batch)
            return batch

    def get(self, batch_id: UUID) -> BatchUpload | None:
        """Get batch upload by ID."""
        with get_session_context() as session:
            result = session.get(BatchUpload, batch_id)
            if result:
                session.expunge(result)
            return result

    def update(
        self,
        batch_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Update batch upload fields."""
        with get_session_context() as session:
            batch = session.get(BatchUpload, batch_id)
            if batch:
                for key, value in kwargs.items():
                    if hasattr(batch, key):
                        setattr(batch, key, value)
                session.add(batch)

    def create_file(
        self,
        batch_id: UUID,
        filename: str,
        **kwargs: Any,
    ) -> BatchUploadFile:
        """Create a batch upload file record."""
        with get_session_context() as session:
            file_record = BatchUploadFile(
                batch_id=batch_id,
                filename=filename,
                **kwargs,
            )
            session.add(file_record)
            session.commit()
            session.refresh(file_record)
            session.expunge(file_record)
            return file_record

    def update_file(
        self,
        file_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Update batch upload file fields."""
        with get_session_context() as session:
            file_record = session.get(BatchUploadFile, file_id)
            if file_record:
                for key, value in kwargs.items():
                    if hasattr(file_record, key):
                        setattr(file_record, key, value)
                session.add(file_record)

    def get_files(self, batch_id: UUID) -> list[BatchUploadFile]:
        """Get all files for a batch upload."""
        with get_session_context() as session:
            statement = select(BatchUploadFile).where(
                BatchUploadFile.batch_id == batch_id
            ).order_by(BatchUploadFile.created_at)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def get_paginated(
        self,
        admin_token: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[BatchUpload], int]:
        """Get paginated batch uploads."""
        with get_session_context() as session:
            count_stmt = select(func.count()).select_from(BatchUpload)
            total = session.exec(count_stmt).one()

            statement = select(BatchUpload).order_by(
                BatchUpload.created_at.desc()
            ).offset(offset).limit(limit)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results), total
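A sketch (editor's illustration) of the intended upload-tracking flow. The "status" keyword passed to update() is an assumption about the BatchUpload model; unknown keys are silently skipped by the hasattr check:

repo = BatchUploadRepository()

batch = repo.create(admin_token="admin-token", filename="invoices.zip", file_size=10_485_760)
for name in ("inv_001.pdf", "inv_002.pdf"):
    repo.create_file(batch.batch_id, filename=name)

repo.update(batch.batch_id, status="processing")  # "status" assumed to exist on the model
uploads, total = repo.get_paginated(limit=10)     # newest uploads first, plus total count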
@@ -0,0 +1,208 @@
"""
Dataset Repository

Handles training dataset operations following Single Responsibility Principle.
"""

import logging
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import func
from sqlmodel import select

from inference.data.database import get_session_context
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
from inference.data.repositories.base import BaseRepository

logger = logging.getLogger(__name__)


class DatasetRepository(BaseRepository[TrainingDataset]):
    """Repository for training dataset management.

    Handles:
    - Dataset CRUD operations
    - Dataset status management
    - Dataset document linking
    - Training status tracking
    """

    def create(
        self,
        name: str,
        description: str | None = None,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        seed: int = 42,
    ) -> TrainingDataset:
        """Create a new training dataset."""
        with get_session_context() as session:
            dataset = TrainingDataset(
                name=name,
                description=description,
                train_ratio=train_ratio,
                val_ratio=val_ratio,
                seed=seed,
            )
            session.add(dataset)
            session.commit()
            session.refresh(dataset)
            session.expunge(dataset)
            return dataset

    def get(self, dataset_id: str | UUID) -> TrainingDataset | None:
        """Get a dataset by ID."""
        with get_session_context() as session:
            dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
            if dataset:
                session.expunge(dataset)
            return dataset

    def get_paginated(
        self,
        status: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[TrainingDataset], int]:
        """List datasets with optional status filter."""
        with get_session_context() as session:
            query = select(TrainingDataset)
            count_query = select(func.count()).select_from(TrainingDataset)
            if status:
                query = query.where(TrainingDataset.status == status)
                count_query = count_query.where(TrainingDataset.status == status)
            total = session.exec(count_query).one()
            datasets = session.exec(
                query.order_by(TrainingDataset.created_at.desc()).offset(offset).limit(limit)
            ).all()
            for d in datasets:
                session.expunge(d)
            return list(datasets), total

    def get_active_training_tasks(
        self, dataset_ids: list[str]
    ) -> dict[str, dict[str, str]]:
        """Get active training tasks for datasets.

        Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
        """
        if not dataset_ids:
            return {}

        valid_uuids = []
        for d in dataset_ids:
            try:
                valid_uuids.append(UUID(d))
            except ValueError:
                logger.warning("Invalid UUID in get_active_training_tasks: %s", d)
                continue

        if not valid_uuids:
            return {}

        with get_session_context() as session:
            statement = select(TrainingTask).where(
                TrainingTask.dataset_id.in_(valid_uuids),
                TrainingTask.status.in_(["pending", "scheduled", "running"]),
            )
            results = session.exec(statement).all()
            return {
                str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
                for t in results
            }

    def update_status(
        self,
        dataset_id: str | UUID,
        status: str,
        error_message: str | None = None,
        total_documents: int | None = None,
        total_images: int | None = None,
        total_annotations: int | None = None,
        dataset_path: str | None = None,
    ) -> None:
        """Update dataset status and optional totals."""
        with get_session_context() as session:
            dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
            if not dataset:
                return
            dataset.status = status
            dataset.updated_at = datetime.utcnow()
            if error_message is not None:
                dataset.error_message = error_message
            if total_documents is not None:
                dataset.total_documents = total_documents
            if total_images is not None:
                dataset.total_images = total_images
            if total_annotations is not None:
                dataset.total_annotations = total_annotations
            if dataset_path is not None:
                dataset.dataset_path = dataset_path
            session.add(dataset)
            session.commit()

    def update_training_status(
        self,
        dataset_id: str | UUID,
        training_status: str | None,
        active_training_task_id: str | UUID | None = None,
        update_main_status: bool = False,
    ) -> None:
        """Update dataset training status."""
        with get_session_context() as session:
            dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
            if not dataset:
                return
            dataset.training_status = training_status
            dataset.active_training_task_id = (
                UUID(str(active_training_task_id)) if active_training_task_id else None
            )
            dataset.updated_at = datetime.utcnow()
            if update_main_status and training_status == "completed":
                dataset.status = "trained"
            session.add(dataset)
            session.commit()

    def add_documents(
        self,
        dataset_id: str | UUID,
        documents: list[dict[str, Any]],
    ) -> None:
        """Batch insert documents into a dataset.

        Each dict: {document_id, split, page_count, annotation_count}
        """
        with get_session_context() as session:
            for doc in documents:
                dd = DatasetDocument(
                    dataset_id=UUID(str(dataset_id)),
                    document_id=UUID(str(doc["document_id"])),
                    split=doc["split"],
                    page_count=doc.get("page_count", 0),
                    annotation_count=doc.get("annotation_count", 0),
                )
                session.add(dd)
            session.commit()

    def get_documents(self, dataset_id: str | UUID) -> list[DatasetDocument]:
        """Get all documents in a dataset."""
        with get_session_context() as session:
            results = session.exec(
                select(DatasetDocument)
                .where(DatasetDocument.dataset_id == UUID(str(dataset_id)))
            ).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def delete(self, dataset_id: str | UUID) -> bool:
        """Delete a dataset and its document links."""
        with get_session_context() as session:
            dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
            if not dataset:
                return False
            session.delete(dataset)
            session.commit()
            return True
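A sketch (editor's illustration) of the dataset lifecycle these methods imply: create, attach documents with their split assignments, then track build and training status. The dataset_id attribute name and the train_ids list are assumptions for illustration:

repo = DatasetRepository()

ds = repo.create(name="invoices-v1", train_ratio=0.8, val_ratio=0.1, seed=42)
repo.add_documents(ds.dataset_id, [
    {"document_id": doc_id, "split": "train", "page_count": 1, "annotation_count": 12}
    for doc_id in train_ids   # train_ids assumed to come from DocumentRepository.get_for_training()
])
repo.update_status(ds.dataset_id, status="ready", total_documents=len(train_ids))
active = repo.get_active_training_tasks([str(ds.dataset_id)])  # {} when nothing is queued or running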
@@ -0,0 +1,444 @@
"""
Document Repository

Handles document operations following Single Responsibility Principle.
"""

import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID

from sqlalchemy import func
from sqlmodel import select

from inference.data.database import get_session_context
from inference.data.admin_models import AdminDocument, AdminAnnotation
from inference.data.repositories.base import BaseRepository

logger = logging.getLogger(__name__)


class DocumentRepository(BaseRepository[AdminDocument]):
    """Repository for document management.

    Handles:
    - Document CRUD operations
    - Document status management
    - Document filtering and pagination
    - Document category management
    """

    def create(
        self,
        filename: str,
        file_size: int,
        content_type: str,
        file_path: str,
        page_count: int = 1,
        upload_source: str = "ui",
        csv_field_values: dict[str, Any] | None = None,
        group_key: str | None = None,
        category: str = "invoice",
        admin_token: str | None = None,
    ) -> str:
        """Create a new document record.

        Args:
            filename: Original filename
            file_size: File size in bytes
            content_type: MIME type
            file_path: Storage path
            page_count: Number of pages
            upload_source: Upload source (ui/api)
            csv_field_values: CSV field values for reference
            group_key: User-defined grouping key
            category: Document category
            admin_token: Deprecated, kept for compatibility

        Returns:
            Document ID as string
        """
        with get_session_context() as session:
            document = AdminDocument(
                filename=filename,
                file_size=file_size,
                content_type=content_type,
                file_path=file_path,
                page_count=page_count,
                upload_source=upload_source,
                csv_field_values=csv_field_values,
                group_key=group_key,
                category=category,
            )
            session.add(document)
            session.flush()
            return str(document.document_id)

    def get(self, document_id: str) -> AdminDocument | None:
        """Get a document by ID.

        Args:
            document_id: Document UUID as string

        Returns:
            AdminDocument if found, None otherwise
        """
        with get_session_context() as session:
            result = session.get(AdminDocument, UUID(document_id))
            if result:
                session.expunge(result)
            return result

    def get_by_token(
        self,
        document_id: str,
        admin_token: str | None = None,
    ) -> AdminDocument | None:
        """Get a document by ID. Token parameter is deprecated."""
        return self.get(document_id)

    def get_paginated(
        self,
        admin_token: str | None = None,
        status: str | None = None,
        upload_source: str | None = None,
        has_annotations: bool | None = None,
        auto_label_status: str | None = None,
        batch_id: str | None = None,
        category: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[AdminDocument], int]:
        """Get paginated documents with optional filters.

        Args:
            admin_token: Deprecated, kept for compatibility
            status: Filter by status
            upload_source: Filter by upload source
            has_annotations: Filter by annotation presence
            auto_label_status: Filter by auto-label status
            batch_id: Filter by batch ID
            category: Filter by category
            limit: Page size
            offset: Pagination offset

        Returns:
            Tuple of (documents, total_count)
        """
        with get_session_context() as session:
            where_clauses = []

            if status:
                where_clauses.append(AdminDocument.status == status)
            if upload_source:
                where_clauses.append(AdminDocument.upload_source == upload_source)
            if auto_label_status:
                where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
            if batch_id:
                where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
            if category:
                where_clauses.append(AdminDocument.category == category)

            count_stmt = select(func.count()).select_from(AdminDocument)
            if where_clauses:
                count_stmt = count_stmt.where(*where_clauses)

            if has_annotations is not None:
                if has_annotations:
                    count_stmt = (
                        count_stmt
                        .join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
                        .group_by(AdminDocument.document_id)
                    )
                else:
                    count_stmt = (
                        count_stmt
                        .outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
                        .where(AdminAnnotation.annotation_id.is_(None))
                    )

            total = session.exec(count_stmt).one()

            statement = select(AdminDocument)
            if where_clauses:
                statement = statement.where(*where_clauses)

            if has_annotations is not None:
                if has_annotations:
                    statement = (
                        statement
                        .join(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
                        .group_by(AdminDocument.document_id)
                    )
                else:
                    statement = (
                        statement
                        .outerjoin(AdminAnnotation, AdminAnnotation.document_id == AdminDocument.document_id)
                        .where(AdminAnnotation.annotation_id.is_(None))
                    )

            statement = statement.order_by(AdminDocument.created_at.desc())
            statement = statement.offset(offset).limit(limit)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results), total

    def update_status(
        self,
        document_id: str,
        status: str,
        auto_label_status: str | None = None,
        auto_label_error: str | None = None,
    ) -> None:
        """Update document status.

        Args:
            document_id: Document UUID as string
            status: New status
            auto_label_status: Auto-label status
            auto_label_error: Auto-label error message
        """
        with get_session_context() as session:
            document = session.get(AdminDocument, UUID(document_id))
            if document:
                document.status = status
                document.updated_at = datetime.now(timezone.utc)
                if auto_label_status is not None:
                    document.auto_label_status = auto_label_status
                if auto_label_error is not None:
                    document.auto_label_error = auto_label_error
                session.add(document)

    def update_file_path(self, document_id: str, file_path: str) -> None:
        """Update document file path."""
        with get_session_context() as session:
            document = session.get(AdminDocument, UUID(document_id))
            if document:
                document.file_path = file_path
                document.updated_at = datetime.now(timezone.utc)
                session.add(document)

    def update_group_key(self, document_id: str, group_key: str | None) -> bool:
        """Update document group key."""
        with get_session_context() as session:
            document = session.get(AdminDocument, UUID(document_id))
            if document:
                document.group_key = group_key
                document.updated_at = datetime.now(timezone.utc)
                session.add(document)
                return True
            return False

    def update_category(self, document_id: str, category: str) -> AdminDocument | None:
        """Update document category."""
        with get_session_context() as session:
            document = session.get(AdminDocument, UUID(document_id))
            if document:
                document.category = category
                document.updated_at = datetime.now(timezone.utc)
                session.add(document)
                session.commit()
                session.refresh(document)
                return document
            return None

    def delete(self, document_id: str) -> bool:
        """Delete a document and its annotations.

        Args:
            document_id: Document UUID as string

        Returns:
            True if deleted, False if not found
        """
        with get_session_context() as session:
            document = session.get(AdminDocument, UUID(document_id))
            if document:
                ann_stmt = select(AdminAnnotation).where(
                    AdminAnnotation.document_id == UUID(document_id)
                )
                annotations = session.exec(ann_stmt).all()
                for ann in annotations:
                    session.delete(ann)
                session.delete(document)
                return True
            return False

    def get_categories(self) -> list[str]:
        """Get list of unique document categories."""
        with get_session_context() as session:
            statement = (
                select(AdminDocument.category)
                .distinct()
                .order_by(AdminDocument.category)
            )
            categories = session.exec(statement).all()
            return [c for c in categories if c is not None]

    def get_labeled_for_export(
        self,
        admin_token: str | None = None,
    ) -> list[AdminDocument]:
        """Get all labeled documents ready for export."""
        with get_session_context() as session:
            statement = select(AdminDocument).where(
                AdminDocument.status == "labeled"
            )
            if admin_token:
                statement = statement.where(AdminDocument.admin_token == admin_token)
            statement = statement.order_by(AdminDocument.created_at)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def count_by_status(
        self,
        admin_token: str | None = None,
    ) -> dict[str, int]:
        """Count documents by status."""
        with get_session_context() as session:
            statement = select(
                AdminDocument.status,
                func.count(AdminDocument.document_id),
            ).group_by(AdminDocument.status)

            results = session.exec(statement).all()
            return {status: count for status, count in results}

    def get_by_ids(self, document_ids: list[str]) -> list[AdminDocument]:
        """Get documents by list of IDs."""
        with get_session_context() as session:
            uuids = [UUID(str(did)) for did in document_ids]
            results = session.exec(
                select(AdminDocument).where(AdminDocument.document_id.in_(uuids))
            ).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def get_for_training(
        self,
        admin_token: str | None = None,
        status: str = "labeled",
        has_annotations: bool = True,
        min_annotation_count: int | None = None,
        exclude_used_in_training: bool = False,
        limit: int = 100,
        offset: int = 0,
    ) -> tuple[list[AdminDocument], int]:
        """Get documents suitable for training with filtering."""
        from inference.data.admin_models import TrainingDocumentLink

        with get_session_context() as session:
            statement = select(AdminDocument).where(
                AdminDocument.status == status,
            )

            if has_annotations or min_annotation_count:
                annotation_subq = (
                    select(func.count(AdminAnnotation.annotation_id))
                    .where(AdminAnnotation.document_id == AdminDocument.document_id)
                    .correlate(AdminDocument)
                    .scalar_subquery()
                )

                if has_annotations:
                    statement = statement.where(annotation_subq > 0)

                if min_annotation_count:
                    statement = statement.where(annotation_subq >= min_annotation_count)

            if exclude_used_in_training:
                from sqlalchemy import exists
                training_subq = exists(
                    select(1)
                    .select_from(TrainingDocumentLink)
                    .where(TrainingDocumentLink.document_id == AdminDocument.document_id)
                )
                statement = statement.where(~training_subq)

            count_statement = select(func.count()).select_from(statement.subquery())
            total = session.exec(count_statement).one()

            statement = statement.order_by(AdminDocument.created_at.desc())
            statement = statement.limit(limit).offset(offset)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)

            return list(results), total

    def acquire_annotation_lock(
        self,
        document_id: str,
        admin_token: str | None = None,
        duration_seconds: int = 300,
    ) -> AdminDocument | None:
        """Acquire annotation lock for a document."""
        from datetime import timedelta

        with get_session_context() as session:
            doc = session.get(AdminDocument, UUID(document_id))
            if not doc:
                return None

            now = datetime.now(timezone.utc)
            if doc.annotation_lock_until and doc.annotation_lock_until > now:
                return None

            doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
            session.add(doc)
            session.commit()
            session.refresh(doc)
            session.expunge(doc)
            return doc

    def release_annotation_lock(
        self,
        document_id: str,
        admin_token: str | None = None,
        force: bool = False,
    ) -> AdminDocument | None:
        """Release annotation lock for a document."""
        with get_session_context() as session:
            doc = session.get(AdminDocument, UUID(document_id))
            if not doc:
                return None

            doc.annotation_lock_until = None
            session.add(doc)
            session.commit()
            session.refresh(doc)
            session.expunge(doc)
            return doc

    def extend_annotation_lock(
        self,
        document_id: str,
        admin_token: str | None = None,
        additional_seconds: int = 300,
    ) -> AdminDocument | None:
        """Extend an existing annotation lock."""
        from datetime import timedelta

        with get_session_context() as session:
            doc = session.get(AdminDocument, UUID(document_id))
            if not doc:
                return None

            now = datetime.now(timezone.utc)
            if not doc.annotation_lock_until or doc.annotation_lock_until <= now:
                return None

            doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds)
            session.add(doc)
            session.commit()
            session.refresh(doc)
            session.expunge(doc)
            return doc
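A sketch (editor's illustration) of the annotation-lock handshake above: acquire_annotation_lock returns None while another editor holds an unexpired lock, so callers back off instead of editing. The document_id value is a placeholder:

repo = DocumentRepository()
document_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID string

doc = repo.acquire_annotation_lock(document_id, duration_seconds=300)
if doc is None:
    # Another editor holds the lock (or the document is missing); report "busy" to the caller.
    pass
else:
    try:
        # ... save annotations; extend only works while the lock is still unexpired ...
        repo.extend_annotation_lock(document_id, additional_seconds=300)
    finally:
        repo.release_annotation_lock(document_id)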
@@ -0,0 +1,200 @@
"""
Model Version Repository

Handles model version operations following Single Responsibility Principle.
"""

import logging
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import func
from sqlmodel import select

from inference.data.database import get_session_context
from inference.data.admin_models import ModelVersion
from inference.data.repositories.base import BaseRepository

logger = logging.getLogger(__name__)


class ModelVersionRepository(BaseRepository[ModelVersion]):
    """Repository for model version management.

    Handles:
    - Model version CRUD operations
    - Model activation/deactivation
    - Active model resolution
    """

    def create(
        self,
        version: str,
        name: str,
        model_path: str,
        description: str | None = None,
        task_id: str | UUID | None = None,
        dataset_id: str | UUID | None = None,
        metrics_mAP: float | None = None,
        metrics_precision: float | None = None,
        metrics_recall: float | None = None,
        document_count: int = 0,
        training_config: dict[str, Any] | None = None,
        file_size: int | None = None,
        trained_at: datetime | None = None,
    ) -> ModelVersion:
        """Create a new model version."""
        with get_session_context() as session:
            model = ModelVersion(
                version=version,
                name=name,
                model_path=model_path,
                description=description,
                task_id=UUID(str(task_id)) if task_id else None,
                dataset_id=UUID(str(dataset_id)) if dataset_id else None,
                metrics_mAP=metrics_mAP,
                metrics_precision=metrics_precision,
                metrics_recall=metrics_recall,
                document_count=document_count,
                training_config=training_config,
                file_size=file_size,
                trained_at=trained_at,
            )
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def get(self, version_id: str | UUID) -> ModelVersion | None:
        """Get a model version by ID."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if model:
                session.expunge(model)
            return model

    def get_paginated(
        self,
        status: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[ModelVersion], int]:
        """List model versions with optional status filter."""
        with get_session_context() as session:
            query = select(ModelVersion)
            count_query = select(func.count()).select_from(ModelVersion)
            if status:
                query = query.where(ModelVersion.status == status)
                count_query = count_query.where(ModelVersion.status == status)
            total = session.exec(count_query).one()
            models = session.exec(
                query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
            ).all()
            for m in models:
                session.expunge(m)
            return list(models), total

    def get_active(self) -> ModelVersion | None:
        """Get the currently active model version for inference."""
        with get_session_context() as session:
            result = session.exec(
                select(ModelVersion).where(ModelVersion.is_active == True)
            ).first()
            if result:
                session.expunge(result)
            return result

    def activate(self, version_id: str | UUID) -> ModelVersion | None:
        """Activate a model version for inference (deactivates all others)."""
        with get_session_context() as session:
            all_versions = session.exec(
                select(ModelVersion).where(ModelVersion.is_active == True)
            ).all()
            for v in all_versions:
                v.is_active = False
                v.status = "inactive"
                v.updated_at = datetime.utcnow()
                session.add(v)

            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            model.is_active = True
            model.status = "active"
            model.activated_at = datetime.utcnow()
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def deactivate(self, version_id: str | UUID) -> ModelVersion | None:
        """Deactivate a model version."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            model.is_active = False
            model.status = "inactive"
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def update(
        self,
        version_id: str | UUID,
        name: str | None = None,
        description: str | None = None,
        status: str | None = None,
    ) -> ModelVersion | None:
        """Update model version metadata."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            if name is not None:
                model.name = name
            if description is not None:
                model.description = description
            if status is not None:
                model.status = status
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def archive(self, version_id: str | UUID) -> ModelVersion | None:
        """Archive a model version."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            if model.is_active:
                return None
            model.status = "archived"
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def delete(self, version_id: str | UUID) -> bool:
        """Delete a model version."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return False
            if model.is_active:
                return False
            session.delete(model)
            session.commit()
            return True
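A sketch (editor's illustration) of the single-active-model invariant: activate() marks every currently active version inactive in the same session before flagging the new one, so get_active() resolves to at most one model. The version_id attribute name and the path/version strings are assumptions:

repo = ModelVersionRepository()

mv = repo.create(version="2024.06.1", name="yolo-invoices", model_path="/models/best.pt")
repo.activate(mv.version_id)           # any previously active version becomes "inactive"

current = repo.get_active()
assert current is not None and current.version_id == mv.version_id
# archive() and delete() both refuse to touch the active version and return None / False.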
@@ -0,0 +1,117 @@
"""
Token Repository

Handles admin token operations following Single Responsibility Principle.
"""

import logging
from datetime import datetime

from inference.data.admin_models import AdminToken
from inference.data.repositories.base import BaseRepository

logger = logging.getLogger(__name__)


class TokenRepository(BaseRepository[AdminToken]):
    """Repository for admin token management.

    Handles:
    - Token validation (active status, expiration)
    - Token CRUD operations
    - Usage tracking
    """

    def is_valid(self, token: str) -> bool:
        """Check if admin token exists and is active.

        Args:
            token: The token string to validate

        Returns:
            True if token exists, is active, and not expired
        """
        with self._session() as session:
            result = session.get(AdminToken, token)
            if result is None:
                return False
            if not result.is_active:
                return False
            if result.expires_at and result.expires_at < self._now():
                return False
            return True

    def get(self, token: str) -> AdminToken | None:
        """Get admin token details.

        Args:
            token: The token string

        Returns:
            AdminToken if found, None otherwise
        """
        with self._session() as session:
            result = session.get(AdminToken, token)
            if result:
                session.expunge(result)
            return result

    def create(
        self,
        token: str,
        name: str,
        expires_at: datetime | None = None,
    ) -> None:
        """Create or update an admin token.

        If token exists, updates name, expires_at, and reactivates it.
        Otherwise creates a new token.

        Args:
            token: The token string
            name: Display name for the token
            expires_at: Optional expiration datetime
        """
        with self._session() as session:
            existing = session.get(AdminToken, token)
            if existing:
                existing.name = name
                existing.expires_at = expires_at
                existing.is_active = True
                session.add(existing)
            else:
                new_token = AdminToken(
                    token=token,
                    name=name,
                    expires_at=expires_at,
                )
                session.add(new_token)

    def update_usage(self, token: str) -> None:
        """Update admin token last used timestamp.

        Args:
            token: The token string
        """
        with self._session() as session:
            admin_token = session.get(AdminToken, token)
            if admin_token:
                admin_token.last_used_at = self._now()
                session.add(admin_token)

    def deactivate(self, token: str) -> bool:
        """Deactivate an admin token.

        Args:
            token: The token string

        Returns:
            True if token was deactivated, False if not found
        """
        with self._session() as session:
            admin_token = session.get(AdminToken, token)
            if admin_token:
                admin_token.is_active = False
                session.add(admin_token)
                return True
            return False
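A sketch (editor's illustration) of how an auth guard might use this repository; the require_admin helper and the exception type are illustrative, not part of this commit:

repo = TokenRepository()

def require_admin(token: str) -> None:
    # is_valid() checks existence, is_active, and expires_at in one call.
    if not repo.is_valid(token):
        raise PermissionError("invalid or expired admin token")
    repo.update_usage(token)   # records last_used_at for auditing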
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Training Task Repository
|
||||
|
||||
Handles training task operations following Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from inference.data.database import get_session_context
|
||||
from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
|
||||
from inference.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingTaskRepository(BaseRepository[TrainingTask]):
|
||||
"""Repository for training task management.
|
||||
|
||||
Handles:
|
||||
- Training task CRUD operations
|
||||
- Task status management
|
||||
- Training logs
|
||||
- Training document links
|
||||
"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
admin_token: str,
|
||||
name: str,
|
||||
task_type: str = "train",
|
||||
description: str | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
scheduled_at: datetime | None = None,
|
||||
cron_expression: str | None = None,
|
||||
is_recurring: bool = False,
|
||||
dataset_id: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new training task.
|
||||
|
||||
Returns:
|
||||
Task ID as string
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
task = TrainingTask(
|
||||
admin_token=admin_token,
|
||||
name=name,
|
||||
task_type=task_type,
|
||||
description=description,
|
||||
config=config,
|
||||
scheduled_at=scheduled_at,
|
||||
cron_expression=cron_expression,
|
||||
is_recurring=is_recurring,
|
||||
status="scheduled" if scheduled_at else "pending",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
session.add(task)
|
||||
session.flush()
|
||||
return str(task.task_id)
|
||||
|
||||
def get(self, task_id: str) -> TrainingTask | None:
|
||||
"""Get a training task by ID."""
|
||||
with get_session_context() as session:
|
||||
result = session.get(TrainingTask, UUID(task_id))
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def get_by_token(
|
||||
self,
|
||||
task_id: str,
|
||||
admin_token: str | None = None,
|
||||
) -> TrainingTask | None:
|
||||
"""Get a training task by ID. Token parameter is deprecated."""
|
||||
return self.get(task_id)
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[TrainingTask], int]:
|
||||
"""Get paginated training tasks."""
|
||||
with get_session_context() as session:
|
||||
            count_stmt = select(func.count()).select_from(TrainingTask)
            if status:
                count_stmt = count_stmt.where(TrainingTask.status == status)
            total = session.exec(count_stmt).one()

            statement = select(TrainingTask)
            if status:
                statement = statement.where(TrainingTask.status == status)
            statement = statement.order_by(TrainingTask.created_at.desc())
            statement = statement.offset(offset).limit(limit)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results), total

    def get_pending(self) -> list[TrainingTask]:
        """Get pending training tasks ready to run."""
        with get_session_context() as session:
            now = datetime.utcnow()
            statement = select(TrainingTask).where(
                TrainingTask.status.in_(["pending", "scheduled"]),
                (TrainingTask.scheduled_at == None) | (TrainingTask.scheduled_at <= now),
            ).order_by(TrainingTask.created_at)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def update_status(
        self,
        task_id: str,
        status: str,
        error_message: str | None = None,
        result_metrics: dict[str, Any] | None = None,
        model_path: str | None = None,
    ) -> None:
        """Update training task status."""
        with get_session_context() as session:
            task = session.get(TrainingTask, UUID(task_id))
            if task:
                task.status = status
                task.updated_at = datetime.utcnow()
                if status == "running":
                    task.started_at = datetime.utcnow()
                elif status in ("completed", "failed"):
                    task.completed_at = datetime.utcnow()
                if error_message is not None:
                    task.error_message = error_message
                if result_metrics is not None:
                    task.result_metrics = result_metrics
                if model_path is not None:
                    task.model_path = model_path
                session.add(task)

    def cancel(self, task_id: str) -> bool:
        """Cancel a training task."""
        with get_session_context() as session:
            task = session.get(TrainingTask, UUID(task_id))
            if task and task.status in ("pending", "scheduled"):
                task.status = "cancelled"
                task.updated_at = datetime.utcnow()
                session.add(task)
                return True
            return False

    def add_log(
        self,
        task_id: str,
        level: str,
        message: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        """Add a training log entry."""
        with get_session_context() as session:
            log = TrainingLog(
                task_id=UUID(task_id),
                level=level,
                message=message,
                details=details,
            )
            session.add(log)

    def get_logs(
        self,
        task_id: str,
        limit: int = 100,
        offset: int = 0,
    ) -> list[TrainingLog]:
        """Get training logs for a task."""
        with get_session_context() as session:
            statement = select(TrainingLog).where(
                TrainingLog.task_id == UUID(task_id)
            ).order_by(TrainingLog.created_at.desc()).offset(offset).limit(limit)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def create_document_link(
        self,
        task_id: UUID,
        document_id: UUID,
        annotation_snapshot: dict[str, Any] | None = None,
    ) -> TrainingDocumentLink:
        """Create a training document link."""
        with get_session_context() as session:
            link = TrainingDocumentLink(
                task_id=task_id,
                document_id=document_id,
                annotation_snapshot=annotation_snapshot,
            )
            session.add(link)
            session.commit()
            session.refresh(link)
            session.expunge(link)
            return link

    def get_document_links(self, task_id: UUID) -> list[TrainingDocumentLink]:
        """Get all document links for a training task."""
        with get_session_context() as session:
            statement = select(TrainingDocumentLink).where(
                TrainingDocumentLink.task_id == task_id
            ).order_by(TrainingDocumentLink.created_at)

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def get_document_training_tasks(self, document_id: UUID) -> list[TrainingDocumentLink]:
        """Get all training tasks that used this document."""
        with get_session_context() as session:
            statement = select(TrainingDocumentLink).where(
                TrainingDocumentLink.document_id == document_id
            ).order_by(TrainingDocumentLink.created_at.desc())

            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)
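A minimal driver sketch for the task-lifecycle methods above. Only the method signatures visible in this diff are taken as given; the zero-argument constructor, the task_id attribute, and the metric/path values are assumptions for illustration.

from inference.data.repositories import TrainingTaskRepository

repo = TrainingTaskRepository()  # assumed: no constructor arguments
for task in repo.get_pending():
    task_id = str(task.task_id)  # assumed attribute name on TrainingTask
    repo.update_status(task_id, "running")
    repo.add_log(task_id, "INFO", "training started")
    # ... run the actual training here ...
    repo.update_status(
        task_id,
        "completed",
        result_metrics={"mAP50": 0.91},    # placeholder metrics
        model_path="runs/train/best.pt",   # placeholder path
    )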
@@ -11,11 +11,11 @@ Enhanced features:
- Smart amount parsing with multiple strategies
- Enhanced date format unification
- OCR error correction integration

Refactored to use modular normalizers for each field type.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from collections import defaultdict
import re
import numpy as np
@@ -25,15 +25,22 @@ from shared.fields import CLASS_TO_FIELD
from .yolo_detector import Detection

# Import shared utilities for text cleaning and validation
from shared.utils.text_cleaner import TextCleaner
from shared.utils.validators import FieldValidators
from shared.utils.fuzzy_matcher import FuzzyMatcher
from shared.utils.ocr_corrections import OCRCorrections

# Import new unified parsers
from .payment_line_parser import PaymentLineParser
from .customer_number_parser import CustomerNumberParser

# Import normalizers
from .normalizers import (
    BaseNormalizer,
    NormalizationResult,
    create_normalizer_registry,
    EnhancedAmountNormalizer,
    EnhancedDateNormalizer,
)


@dataclass
class ExtractedField:
@@ -80,7 +87,8 @@ class FieldExtractor:
        ocr_lang: str = 'en',
        use_gpu: bool = False,
        bbox_padding: float = 0.1,
        dpi: int = 300
        dpi: int = 300,
        use_enhanced_parsing: bool = False
    ):
        """
        Initialize field extractor.
@@ -90,17 +98,22 @@ class FieldExtractor:
            use_gpu: Whether to use GPU for OCR
            bbox_padding: Padding to add around bboxes (as fraction)
            dpi: DPI used for rendering (for coordinate conversion)
            use_enhanced_parsing: Whether to use enhanced normalizers
        """
        self.ocr_lang = ocr_lang
        self.use_gpu = use_gpu
        self.bbox_padding = bbox_padding
        self.dpi = dpi
        self._ocr_engine = None  # Lazy init
        self.use_enhanced_parsing = use_enhanced_parsing

        # Initialize new unified parsers
        self.payment_line_parser = PaymentLineParser()
        self.customer_number_parser = CustomerNumberParser()

        # Initialize normalizer registry
        self._normalizers = create_normalizer_registry(use_enhanced=use_enhanced_parsing)

    @property
    def ocr_engine(self):
        """Lazy-load OCR engine only when needed."""
@@ -246,6 +259,9 @@ class FieldExtractor:
        """
        Normalize and validate extracted text for a field.

        Uses modular normalizers for each field type.
        Falls back to legacy methods for payment_line and customer_number.

        Returns:
            (normalized_value, is_valid, validation_error)
        """
@@ -254,389 +270,21 @@ class FieldExtractor:
        if not text:
            return None, False, "Empty text"

        if field_name == 'InvoiceNumber':
            return self._normalize_invoice_number(text)

        elif field_name == 'OCR':
            return self._normalize_ocr_number(text)

        elif field_name == 'Bankgiro':
            return self._normalize_bankgiro(text)

        elif field_name == 'Plusgiro':
            return self._normalize_plusgiro(text)

        elif field_name == 'Amount':
            return self._normalize_amount(text)

        elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
            return self._normalize_date(text)

        elif field_name == 'payment_line':
        # Special handling for payment_line and customer_number (use unified parsers)
        if field_name == 'payment_line':
            return self._normalize_payment_line(text)

        elif field_name == 'supplier_org_number':
            return self._normalize_supplier_org_number(text)

        elif field_name == 'customer_number':
        if field_name == 'customer_number':
            return self._normalize_customer_number(text)

        else:
            return text, True, None
        # Use normalizer registry for other fields
        normalizer = self._normalizers.get(field_name)
        if normalizer:
            result = normalizer.normalize(text)
            return result.to_tuple()
    def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize invoice number.

        Invoice numbers can be:
        - Pure digits: 12345678
        - Alphanumeric: A3861, INV-2024-001, F12345
        - With separators: 2024/001, 2024-001

        Strategy:
        1. Look for common invoice number patterns
        2. Prefer shorter, more specific matches over long digit sequences
        """
        # Pattern 1: Alphanumeric invoice number (letter + digits or digits + letter)
        # Examples: A3861, F12345, INV001
        alpha_patterns = [
            r'\b([A-Z]{1,3}\d{3,10})\b',  # A3861, INV12345
            r'\b(\d{3,10}[A-Z]{1,3})\b',  # 12345A
            r'\b([A-Z]{2,5}[-/]?\d{3,10})\b',  # INV-12345, FAK12345
        ]

        for pattern in alpha_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return match.group(1).upper(), True, None

        # Pattern 2: Invoice number with year prefix (2024-001, 2024/12345)
        year_pattern = r'\b(20\d{2}[-/]\d{3,8})\b'
        match = re.search(year_pattern, text)
        if match:
            return match.group(1), True, None

        # Pattern 3: Short digit sequence (3-10 digits) - prefer shorter sequences
        # This avoids capturing long OCR numbers
        digit_sequences = re.findall(r'\b(\d{3,10})\b', text)
        if digit_sequences:
            # Prefer shorter sequences (more likely to be invoice number)
            # Also filter out sequences that look like dates (8 digits starting with 20)
            valid_sequences = []
            for seq in digit_sequences:
                # Skip if it looks like a date (YYYYMMDD)
                if len(seq) == 8 and seq.startswith('20'):
                    continue
                # Skip if too long (likely OCR number)
                if len(seq) > 10:
                    continue
                valid_sequences.append(seq)

            if valid_sequences:
                # Return shortest valid sequence
                return min(valid_sequences, key=len), True, None

        # Fallback: extract all digits if nothing else works
        digits = re.sub(r'\D', '', text)
        if len(digits) >= 3:
            # Limit to first 15 digits to avoid very long sequences
            return digits[:15], True, "Fallback extraction"

        return None, False, f"Cannot extract invoice number from: {text[:50]}"

    def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize OCR number."""
        digits = re.sub(r'\D', '', text)

        if len(digits) < 5:
            return None, False, f"Too few digits for OCR: {len(digits)}"

        return digits, True, None

    def _luhn_checksum(self, digits: str) -> bool:
        """
        Validate using Luhn (Mod10) algorithm.
        Used for Bankgiro, Plusgiro, and OCR number validation.

        Delegates to shared FieldValidators for consistency.
        """
        return FieldValidators.luhn_checksum(digits)

    def _detect_giro_type(self, text: str) -> str | None:
        """
        Detect if text matches BG or PG display format pattern.

        BG typical format: ^\d{3,4}-\d{4}$ (e.g., 123-4567, 1234-5678)
        PG typical format: ^\d{1,7}-\d$ (e.g., 1-8, 12345-6, 1234567-8)

        Returns: 'BG', 'PG', or None if cannot determine
        """
        text = text.strip()

        # BG pattern: 3-4 digits, dash, 4 digits (total 7-8 digits)
        if re.match(r'^\d{3,4}-\d{4}$', text):
            return 'BG'

        # PG pattern: 1-7 digits, dash, 1 digit (total 2-8 digits)
        if re.match(r'^\d{1,7}-\d$', text):
            return 'PG'

        return None

    def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize Bankgiro number.

        Bankgiro rules:
        - 7 or 8 digits only
        - Last digit is Luhn (Mod10) check digit
        - Display format: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits)

        Display pattern: ^\d{3,4}-\d{4}$
        Normalized pattern: ^\d{7,8}$

        Note: Text may contain both BG and PG numbers. We specifically look for
        BG display format (XXX-XXXX or XXXX-XXXX) to extract the correct one.
        """
        # Look for BG display format pattern: 3-4 digits, dash, 4 digits
        # This distinguishes BG from PG which uses X-X format (digits-single digit)
        bg_matches = re.findall(r'(\d{3,4})-(\d{4})', text)

        if bg_matches:
            # Try each match and find one with valid Luhn
            for match in bg_matches:
                digits = match[0] + match[1]
                if len(digits) in (7, 8) and self._luhn_checksum(digits):
                    # Valid BG found
                    if len(digits) == 8:
                        formatted = f"{digits[:4]}-{digits[4:]}"
                    else:
                        formatted = f"{digits[:3]}-{digits[3:]}"
                    return formatted, True, None

            # No valid Luhn, use first match
            digits = bg_matches[0][0] + bg_matches[0][1]
            if len(digits) in (7, 8):
                if len(digits) == 8:
                    formatted = f"{digits[:4]}-{digits[4:]}"
                else:
                    formatted = f"{digits[:3]}-{digits[3:]}"
                return formatted, True, f"Luhn checksum failed (possible OCR error)"

        # Fallback: try to find 7-8 consecutive digits
        # But first check if text contains PG format (XXXXXXX-X), if so don't use fallback
        # to avoid misinterpreting PG as BG
        pg_format_present = re.search(r'(?<![0-9])\d{1,7}-\d(?!\d)', text)
        if pg_format_present:
            return None, False, f"No valid Bankgiro found in text"

        digit_match = re.search(r'\b(\d{7,8})\b', text)
        if digit_match:
            digits = digit_match.group(1)
            if len(digits) in (7, 8):
                luhn_ok = self._luhn_checksum(digits)
                if len(digits) == 8:
                    formatted = f"{digits[:4]}-{digits[4:]}"
                else:
                    formatted = f"{digits[:3]}-{digits[3:]}"
                if luhn_ok:
                    return formatted, True, None
                else:
                    return formatted, True, f"Luhn checksum failed (possible OCR error)"

        return None, False, f"No valid Bankgiro found in text"
    def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize Plusgiro number.

        Plusgiro rules:
        - 2 to 8 digits
        - Last digit is Luhn (Mod10) check digit
        - Display format: XXXXXXX-X (all digits except last, dash, last digit)

        Display pattern: ^\d{1,7}-\d$
        Normalized pattern: ^\d{2,8}$

        Note: Text may contain both BG and PG numbers. We specifically look for
        PG display format (X-X, XX-X, ..., XXXXXXX-X) to extract the correct one.
        """
        # First look for PG display format: 1-7 digits (possibly with spaces), dash, 1 digit
        # This is distinct from BG format which has 4 digits after the dash
        # Pattern allows spaces within the number like "486 98 63-6"
        # (?<![0-9]) ensures we don't start from within another number (like BG)
        pg_matches = re.findall(r'(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)', text)

        if pg_matches:
            # Try each match and find one with valid Luhn
            for match in pg_matches:
                # Remove spaces from the first part
                digits = re.sub(r'\s', '', match[0]) + match[1]
                if 2 <= len(digits) <= 8 and self._luhn_checksum(digits):
                    # Valid PG found
                    formatted = f"{digits[:-1]}-{digits[-1]}"
                    return formatted, True, None

            # No valid Luhn, use first match with most digits
            best_match = max(pg_matches, key=lambda m: len(re.sub(r'\s', '', m[0])))
            digits = re.sub(r'\s', '', best_match[0]) + best_match[1]
            if 2 <= len(digits) <= 8:
                formatted = f"{digits[:-1]}-{digits[-1]}"
                return formatted, True, f"Luhn checksum failed (possible OCR error)"

        # If no PG format found, extract all digits and format as PG
        # This handles cases where the number might be in BG format or raw digits
        all_digits = re.sub(r'\D', '', text)

        # Try to find a valid 2-8 digit sequence
        if 2 <= len(all_digits) <= 8:
            luhn_ok = self._luhn_checksum(all_digits)
            formatted = f"{all_digits[:-1]}-{all_digits[-1]}"
            if luhn_ok:
                return formatted, True, None
            else:
                return formatted, True, f"Luhn checksum failed (possible OCR error)"

        # Try to find any 2-8 digit sequence in text
        digit_match = re.search(r'\b(\d{2,8})\b', text)
        if digit_match:
            digits = digit_match.group(1)
            luhn_ok = self._luhn_checksum(digits)
            formatted = f"{digits[:-1]}-{digits[-1]}"
            if luhn_ok:
                return formatted, True, None
            else:
                return formatted, True, f"Luhn checksum failed (possible OCR error)"

        return None, False, f"No valid Plusgiro found in text"
    def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize monetary amount.

        Uses shared TextCleaner for preprocessing and FieldValidators for parsing.
        If multiple amounts are found, returns the last one (usually the total).
        """
        # Split by newlines and process line by line to get the last valid amount
        lines = text.split('\n')

        # Collect all valid amounts from all lines
        all_amounts = []

        # Pattern for Swedish amount format (with decimals)
        amount_pattern = r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?'

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Find all amounts in this line
            matches = re.findall(amount_pattern, line, re.IGNORECASE)
            for match in matches:
                amount_str = match.replace(' ', '').replace(',', '.')
                try:
                    amount = float(amount_str)
                    if amount > 0:
                        all_amounts.append(amount)
                except ValueError:
                    continue

        # Return the last amount found (usually the total)
        if all_amounts:
            return f"{all_amounts[-1]:.2f}", True, None

        # Fallback: try shared validator on cleaned text
        cleaned = TextCleaner.normalize_amount_text(text)
        amount = FieldValidators.parse_amount(cleaned)
        if amount is not None and amount > 0:
            return f"{amount:.2f}", True, None

        # Try to find any decimal number
        simple_pattern = r'(\d+[,\.]\d{2})'
        matches = re.findall(simple_pattern, text)
        if matches:
            amount_str = matches[-1].replace(',', '.')
            try:
                amount = float(amount_str)
                if amount > 0:
                    return f"{amount:.2f}", True, None
            except ValueError:
                pass

        # Last resort: try to find integer amount (no decimals)
        # Look for patterns like "Amount: 11699" or standalone numbers
        int_pattern = r'(?:amount|belopp|summa|total)[:\s]*(\d+)'
        match = re.search(int_pattern, text, re.IGNORECASE)
        if match:
            try:
                amount = float(match.group(1))
                if amount > 0:
                    return f"{amount:.2f}", True, None
            except ValueError:
                pass

        # Very last resort: find any standalone number >= 3 digits
        standalone_pattern = r'\b(\d{3,})\b'
        matches = re.findall(standalone_pattern, text)
        if matches:
            # Take the last/largest number
            try:
                amount = float(matches[-1])
                if amount > 0:
                    return f"{amount:.2f}", True, None
            except ValueError:
                pass

        return None, False, f"Cannot parse amount: {text}"
    def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize date from text that may contain surrounding text.

        Uses shared FieldValidators for date parsing and validation.

        Handles various date formats found in Swedish invoices:
        - 2025-08-29 (ISO format)
        - 2025.08.29 (dot separator)
        - 29/08/2025 (European format)
        - 29.08.2025 (European with dots)
        - 20250829 (compact format)
        """
        # First, try using shared validator
        iso_date = FieldValidators.format_date_iso(text)
        if iso_date and FieldValidators.is_valid_date(iso_date):
            return iso_date, True, None

        # Fallback: try original patterns for edge cases
        from datetime import datetime

        patterns = [
            # ISO format: 2025-08-29
            (r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
            # Dot format: 2025.08.29 (common in Swedish)
            (r'(\d{4})\.(\d{1,2})\.(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
            # European slash format: 29/08/2025
            (r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
            # European dot format: 29.08.2025
            (r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
            # Compact format: 20250829
            (r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', lambda m: f"{m.group(1)}-{m.group(2)}-{m.group(3)}"),
        ]

        for pattern, formatter in patterns:
            match = re.search(pattern, text)
            if match:
                try:
                    date_str = formatter(match)
                    # Validate date
                    parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
                    # Sanity check: year should be reasonable (2000-2100)
                    if 2000 <= parsed_date.year <= 2100:
                        return date_str, True, None
                except ValueError:
                    continue

        return None, False, f"Cannot parse date: {text}"
        # Fallback for unknown fields
        return text, True, None

    def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
        """
@@ -657,44 +305,6 @@ class FieldExtractor:
            self.payment_line_parser.parse(text)
        )

    def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize Swedish supplier organization number.

        Extracts organization number in format: NNNNNN-NNNN (10 digits)
        Also handles VAT numbers: SE + 10 digits + 01

        Examples:
            'org.nr. 516406-1102, Filialregistret...' -> '516406-1102'
            'Momsreg.nr SE556123456701' -> '556123-4567'
        """
        # Pattern 1: Standard org number format: NNNNNN-NNNN
        org_pattern = r'\b(\d{6})-?(\d{4})\b'
        match = re.search(org_pattern, text)
        if match:
            org_num = f"{match.group(1)}-{match.group(2)}"
            return org_num, True, None

        # Pattern 2: VAT number format: SE + 10 digits + 01
        vat_pattern = r'SE\s*(\d{10})01'
        match = re.search(vat_pattern, text, re.IGNORECASE)
        if match:
            digits = match.group(1)
            org_num = f"{digits[:6]}-{digits[6:]}"
            return org_num, True, None

        # Pattern 3: Just 10 consecutive digits
        digits_pattern = r'\b(\d{10})\b'
        match = re.search(digits_pattern, text)
        if match:
            digits = match.group(1)
            # Validate: first digit should be 1-9 for Swedish org numbers
            if digits[0] in '123456789':
                org_num = f"{digits[:6]}-{digits[6:]}"
                return org_num, True, None

        return None, False, f"Cannot extract org number from: {text[:100]}"

    def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize customer number text using unified CustomerNumberParser.
@@ -908,175 +518,6 @@ class FieldExtractor:
        best = max(items, key=lambda x: x[1][0])
        return best[0], best[1][1]

    # =========================================================================
    # Enhanced Amount Parsing
    # =========================================================================

    def _normalize_amount_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Enhanced amount parsing with multiple strategies.

        Strategies:
        1. Pattern matching for Swedish formats
        2. Context-aware extraction (look for keywords like "Total", "Summa")
        3. OCR error correction for common digit errors
        4. Multi-amount handling (prefer last/largest as total)

        This method replaces the original _normalize_amount when enhanced mode is enabled.
        """
        # Strategy 1: Apply OCR corrections first
        corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected

        # Strategy 2: Look for labeled amounts (highest priority)
        labeled_patterns = [
            # Swedish patterns
            (r'(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 1.0),
            (r'(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 0.8),  # Lower priority for VAT
            # Generic pattern
            (r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?', 0.7),
        ]

        candidates = []
        for pattern, priority in labeled_patterns:
            for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
                amount_str = match.group(1).replace(' ', '').replace(',', '.')
                try:
                    amount = float(amount_str)
                    if 0 < amount < 10_000_000:  # Reasonable range
                        candidates.append((amount, priority, match.start()))
                except ValueError:
                    continue

        if candidates:
            # Sort by priority (desc), then by position (later is usually total)
            candidates.sort(key=lambda x: (-x[1], -x[2]))
            best_amount = candidates[0][0]
            return f"{best_amount:.2f}", True, None

        # Strategy 3: Parse with shared validator
        cleaned = TextCleaner.normalize_amount_text(corrected_text)
        amount = FieldValidators.parse_amount(cleaned)
        if amount is not None and 0 < amount < 10_000_000:
            return f"{amount:.2f}", True, None

        # Strategy 4: Try to extract any decimal number as fallback
        decimal_pattern = r'(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})'
        matches = re.findall(decimal_pattern, corrected_text)
        if matches:
            # Clean and parse each match
            amounts = []
            for m in matches:
                cleaned_m = m.replace(' ', '').replace('.', '').replace(',', '.')
                # Handle Swedish format: "1 234,56" -> "1234.56"
                if ',' in m and '.' not in m:
                    cleaned_m = m.replace(' ', '').replace(',', '.')
                try:
                    amt = float(cleaned_m)
                    if 0 < amt < 10_000_000:
                        amounts.append(amt)
                except ValueError:
                    continue

            if amounts:
                # Return the last/largest amount (usually the total)
                return f"{max(amounts):.2f}", True, None

        return None, False, f"Cannot parse amount: {text[:50]}"

    # =========================================================================
    # Enhanced Date Parsing
    # =========================================================================

    def _normalize_date_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Enhanced date parsing with comprehensive format support.

        Supports:
        - ISO: 2024-12-29, 2024/12/29
        - European: 29.12.2024, 29/12/2024, 29-12-2024
        - Swedish text: "29 december 2024", "29 dec 2024"
        - Compact: 20241229
        - With OCR corrections: 2O24-12-29 -> 2024-12-29
        """
        from datetime import datetime

        # Apply OCR corrections
        corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected

        # Try shared validator first
        iso_date = FieldValidators.format_date_iso(corrected_text)
        if iso_date and FieldValidators.is_valid_date(iso_date):
            return iso_date, True, None

        # Swedish month names
        swedish_months = {
            'januari': 1, 'jan': 1,
            'februari': 2, 'feb': 2,
            'mars': 3, 'mar': 3,
            'april': 4, 'apr': 4,
            'maj': 5,
            'juni': 6, 'jun': 6,
            'juli': 7, 'jul': 7,
            'augusti': 8, 'aug': 8,
            'september': 9, 'sep': 9, 'sept': 9,
            'oktober': 10, 'okt': 10,
            'november': 11, 'nov': 11,
            'december': 12, 'dec': 12,
        }

        # Pattern for Swedish text dates: "29 december 2024" or "29 dec 2024"
        swedish_pattern = r'(\d{1,2})\s+([a-zåäö]+)\s+(\d{4})'
        match = re.search(swedish_pattern, corrected_text.lower())
        if match:
            day = int(match.group(1))
            month_name = match.group(2)
            year = int(match.group(3))
            if month_name in swedish_months:
                month = swedish_months[month_name]
                try:
                    dt = datetime(year, month, day)
                    if 2000 <= dt.year <= 2100:
                        return dt.strftime('%Y-%m-%d'), True, None
                except ValueError:
                    pass

        # Extended patterns
        patterns = [
            # ISO format: 2025-08-29, 2025/08/29
            (r'(\d{4})[-/](\d{1,2})[-/](\d{1,2})', 'ymd'),
            # Dot format: 2025.08.29
            (r'(\d{4})\.(\d{1,2})\.(\d{1,2})', 'ymd'),
            # European slash: 29/08/2025
            (r'(\d{1,2})/(\d{1,2})/(\d{4})', 'dmy'),
            # European dot: 29.08.2025
            (r'(\d{1,2})\.(\d{1,2})\.(\d{4})', 'dmy'),
            # European dash: 29-08-2025
            (r'(\d{1,2})-(\d{1,2})-(\d{4})', 'dmy'),
            # Compact: 20250829
            (r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', 'ymd_compact'),
        ]

        for pattern, fmt in patterns:
            match = re.search(pattern, corrected_text)
            if match:
                try:
                    if fmt == 'ymd':
                        year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
                    elif fmt == 'dmy':
                        day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
                    elif fmt == 'ymd_compact':
                        year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
                    else:
                        continue

                    dt = datetime(year, month, day)
                    if 2000 <= dt.year <= 2100:
                        return dt.strftime('%Y-%m-%d'), True, None
                except ValueError:
                    continue

        return None, False, f"Cannot parse date: {text[:50]}"

    # =========================================================================
    # Apply OCR Corrections to Raw Text
    # =========================================================================
@@ -1162,10 +603,15 @@ class FieldExtractor:

        # Re-normalize with enhanced methods if corrections were applied
        if corrections or base_result.normalized_value is None:
            # Use enhanced normalizers for Amount and Date fields
            if base_result.field_name == 'Amount':
                normalized, is_valid, error = self._normalize_amount_enhanced(corrected_text)
                enhanced_normalizer = EnhancedAmountNormalizer()
                result = enhanced_normalizer.normalize(corrected_text)
                normalized, is_valid, error = result.to_tuple()
            elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
                normalized, is_valid, error = self._normalize_date_enhanced(corrected_text)
                enhanced_normalizer = EnhancedDateNormalizer()
                result = enhanced_normalizer.normalize(corrected_text)
                normalized, is_valid, error = result.to_tuple()
            else:
                # Re-run standard normalization with corrected text
                normalized, is_valid, error = self._normalize_and_validate(
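A construction sketch for the new use_enhanced_parsing flag wired up above. The module path is a guess from the file layout; only the constructor parameters visible in these hunks are taken as given.

from inference.pipeline.field_extractor import FieldExtractor  # assumed module path

# With the flag on, create_normalizer_registry(use_enhanced=True) swaps in
# EnhancedAmountNormalizer and EnhancedDateNormalizer for Amount and date fields.
extractor = FieldExtractor(ocr_lang='sv', use_gpu=False, dpi=300, use_enhanced_parsing=True)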
@@ -0,0 +1,59 @@
"""
Normalizers Package

Provides field-specific normalizers for invoice data extraction.
Each normalizer handles a specific field type's normalization and validation.
"""

from .base import BaseNormalizer, NormalizationResult
from .invoice_number import InvoiceNumberNormalizer
from .ocr_number import OcrNumberNormalizer
from .bankgiro import BankgiroNormalizer
from .plusgiro import PlusgiroNormalizer
from .amount import AmountNormalizer, EnhancedAmountNormalizer
from .date import DateNormalizer, EnhancedDateNormalizer
from .supplier_org_number import SupplierOrgNumberNormalizer

__all__ = [
    # Base
    "BaseNormalizer",
    "NormalizationResult",
    # Normalizers
    "InvoiceNumberNormalizer",
    "OcrNumberNormalizer",
    "BankgiroNormalizer",
    "PlusgiroNormalizer",
    "AmountNormalizer",
    "EnhancedAmountNormalizer",
    "DateNormalizer",
    "EnhancedDateNormalizer",
    "SupplierOrgNumberNormalizer",
]


# Registry of all normalizers by field name
def create_normalizer_registry(
    use_enhanced: bool = False,
) -> dict[str, BaseNormalizer]:
    """
    Create a registry mapping field names to normalizer instances.

    Args:
        use_enhanced: Whether to use enhanced normalizers for amount/date

    Returns:
        Dictionary mapping field names to normalizer instances
    """
    amount_normalizer = EnhancedAmountNormalizer() if use_enhanced else AmountNormalizer()
    date_normalizer = EnhancedDateNormalizer() if use_enhanced else DateNormalizer()

    return {
        "InvoiceNumber": InvoiceNumberNormalizer(),
        "OCR": OcrNumberNormalizer(),
        "Bankgiro": BankgiroNormalizer(),
        "Plusgiro": PlusgiroNormalizer(),
        "Amount": amount_normalizer,
        "InvoiceDate": date_normalizer,
        "InvoiceDueDate": date_normalizer,
        "supplier_org_number": SupplierOrgNumberNormalizer(),
    }
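A usage sketch of the registry factory defined above. The package import path is assumed from the file layout; the expected values follow from the normalizer code later in this commit.

from inference.pipeline.normalizers import create_normalizer_registry  # assumed import path

registry = create_normalizer_registry(use_enhanced=True)
result = registry["Amount"].normalize("Att betala: 1 234,56 kr")
print(result.value, result.is_valid)                              # expected: 1234.56 True
print(registry["OCR"].normalize("OCR 12345 678 90").to_tuple())   # expected: ('1234567890', True, None)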
185
packages/inference/inference/pipeline/normalizers/amount.py
Normal file
@@ -0,0 +1,185 @@
"""
Amount Normalizer

Handles normalization and validation of monetary amounts.
"""

import re

from shared.utils.text_cleaner import TextCleaner
from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections

from .base import BaseNormalizer, NormalizationResult


class AmountNormalizer(BaseNormalizer):
    """
    Normalizes monetary amounts from Swedish invoices.

    Handles various Swedish amount formats:
    - With decimal: 1 234,56 kr
    - With SEK suffix: 1234.56 SEK
    - Multiple amounts (returns the last one, usually the total)
    """

    @property
    def field_name(self) -> str:
        return "Amount"

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        # Split by newlines and process line by line to get the last valid amount
        lines = text.split("\n")

        # Collect all valid amounts from all lines
        all_amounts: list[float] = []

        # Pattern for Swedish amount format (with decimals)
        amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Find all amounts in this line
            matches = re.findall(amount_pattern, line, re.IGNORECASE)
            for match in matches:
                amount_str = match.replace(" ", "").replace(",", ".")
                try:
                    amount = float(amount_str)
                    if amount > 0:
                        all_amounts.append(amount)
                except ValueError:
                    continue

        # Return the last amount found (usually the total)
        if all_amounts:
            return NormalizationResult.success(f"{all_amounts[-1]:.2f}")

        # Fallback: try shared validator on cleaned text
        cleaned = TextCleaner.normalize_amount_text(text)
        amount = FieldValidators.parse_amount(cleaned)
        if amount is not None and amount > 0:
            return NormalizationResult.success(f"{amount:.2f}")

        # Try to find any decimal number
        simple_pattern = r"(\d+[,\.]\d{2})"
        matches = re.findall(simple_pattern, text)
        if matches:
            amount_str = matches[-1].replace(",", ".")
            try:
                amount = float(amount_str)
                if amount > 0:
                    return NormalizationResult.success(f"{amount:.2f}")
            except ValueError:
                pass

        # Last resort: try to find integer amount (no decimals)
        # Look for patterns like "Amount: 11699" or standalone numbers
        int_pattern = r"(?:amount|belopp|summa|total)[:\s]*(\d+)"
        match = re.search(int_pattern, text, re.IGNORECASE)
        if match:
            try:
                amount = float(match.group(1))
                if amount > 0:
                    return NormalizationResult.success(f"{amount:.2f}")
            except ValueError:
                pass

        # Very last resort: find any standalone number >= 3 digits
        standalone_pattern = r"\b(\d{3,})\b"
        matches = re.findall(standalone_pattern, text)
        if matches:
            # Take the last/largest number
            try:
                amount = float(matches[-1])
                if amount > 0:
                    return NormalizationResult.success(f"{amount:.2f}")
            except ValueError:
                pass

        return NormalizationResult.failure(f"Cannot parse amount: {text}")


class EnhancedAmountNormalizer(AmountNormalizer):
    """
    Enhanced amount parsing with multiple strategies.

    Strategies:
    1. Pattern matching for Swedish formats
    2. Context-aware extraction (look for keywords like "Total", "Summa")
    3. OCR error correction for common digit errors
    4. Multi-amount handling (prefer last/largest as total)
    """

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        # Strategy 1: Apply OCR corrections first
        corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected

        # Strategy 2: Look for labeled amounts (highest priority)
        labeled_patterns = [
            # Swedish patterns
            (r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
            (
                r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
                0.8,
            ),  # Lower priority for VAT
            # Generic pattern
            (r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
        ]

        candidates: list[tuple[float, float, int]] = []
        for pattern, priority in labeled_patterns:
            for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
                amount_str = match.group(1).replace(" ", "").replace(",", ".")
                try:
                    amount = float(amount_str)
                    if 0 < amount < 10_000_000:  # Reasonable range
                        candidates.append((amount, priority, match.start()))
                except ValueError:
                    continue

        if candidates:
            # Sort by priority (desc), then by position (later is usually total)
            candidates.sort(key=lambda x: (-x[1], -x[2]))
            best_amount = candidates[0][0]
            return NormalizationResult.success(f"{best_amount:.2f}")

        # Strategy 3: Parse with shared validator
        cleaned = TextCleaner.normalize_amount_text(corrected_text)
        amount = FieldValidators.parse_amount(cleaned)
        if amount is not None and 0 < amount < 10_000_000:
            return NormalizationResult.success(f"{amount:.2f}")

        # Strategy 4: Try to extract any decimal number as fallback
        decimal_pattern = r"(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})"
        matches = re.findall(decimal_pattern, corrected_text)
        if matches:
            # Clean and parse each match
            amounts: list[float] = []
            for m in matches:
                cleaned_m = m.replace(" ", "").replace(".", "").replace(",", ".")
                # Handle Swedish format: "1 234,56" -> "1234.56"
                if "," in m and "." not in m:
                    cleaned_m = m.replace(" ", "").replace(",", ".")
                try:
                    amt = float(cleaned_m)
                    if 0 < amt < 10_000_000:
                        amounts.append(amt)
                except ValueError:
                    continue

            if amounts:
                # Return the last/largest amount (usually the total)
                return NormalizationResult.success(f"{max(amounts):.2f}")

        return NormalizationResult.failure(f"Cannot parse amount: {text[:50]}")
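A behavioural sketch of the two amount classes above. The values follow from the regexes and fallbacks in this file, assuming the shared TextCleaner/FieldValidators helpers do not parse these strings differently and OCRCorrections leaves already-clean digits unchanged.

plain = AmountNormalizer()
print(plain.normalize("1 234,56 kr").value)     # "1234.56"
print(plain.normalize("Summa 11699").value)     # "11699.00" via the integer fallbacks

enhanced = EnhancedAmountNormalizer()
# The labeled total outranks the VAT line regardless of order.
print(enhanced.normalize("Moms: 250,00\nAtt betala: 1 250,00").value)   # "1250.00"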
@@ -0,0 +1,87 @@
"""
Bankgiro Normalizer

Handles normalization and validation of Swedish Bankgiro numbers.
"""

import re

from shared.utils.validators import FieldValidators

from .base import BaseNormalizer, NormalizationResult


class BankgiroNormalizer(BaseNormalizer):
    """
    Normalizes Swedish Bankgiro numbers.

    Bankgiro rules:
    - 7 or 8 digits only
    - Last digit is Luhn (Mod10) check digit
    - Display format: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits)

    Display pattern: ^\\d{3,4}-\\d{4}$
    Normalized pattern: ^\\d{7,8}$

    Note: Text may contain both BG and PG numbers. We specifically look for
    BG display format (XXX-XXXX or XXXX-XXXX) to extract the correct one.
    """

    @property
    def field_name(self) -> str:
        return "Bankgiro"

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        # Look for BG display format pattern: 3-4 digits, dash, 4 digits
        # This distinguishes BG from PG which uses X-X format (digits-single digit)
        bg_matches = re.findall(r"(\d{3,4})-(\d{4})", text)

        if bg_matches:
            # Try each match and find one with valid Luhn
            for match in bg_matches:
                digits = match[0] + match[1]
                if len(digits) in (7, 8) and FieldValidators.luhn_checksum(digits):
                    # Valid BG found
                    formatted = self._format_bankgiro(digits)
                    return NormalizationResult.success(formatted)

            # No valid Luhn, use first match
            digits = bg_matches[0][0] + bg_matches[0][1]
            if len(digits) in (7, 8):
                formatted = self._format_bankgiro(digits)
                return NormalizationResult.success_with_warning(
                    formatted, "Luhn checksum failed (possible OCR error)"
                )

        # Fallback: try to find 7-8 consecutive digits
        # But first check if text contains PG format (XXXXXXX-X), if so don't use fallback
        # to avoid misinterpreting PG as BG
        pg_format_present = re.search(r"(?<![0-9])\d{1,7}-\d(?!\d)", text)
        if pg_format_present:
            return NormalizationResult.failure("No valid Bankgiro found in text")

        digit_match = re.search(r"\b(\d{7,8})\b", text)
        if digit_match:
            digits = digit_match.group(1)
            if len(digits) in (7, 8):
                formatted = self._format_bankgiro(digits)
                if FieldValidators.luhn_checksum(digits):
                    return NormalizationResult.success(formatted)
                else:
                    return NormalizationResult.success_with_warning(
                        formatted, "Luhn checksum failed (possible OCR error)"
                    )

        return NormalizationResult.failure("No valid Bankgiro found in text")

    @staticmethod
    def _format_bankgiro(digits: str) -> str:
        """Format Bankgiro number with dash."""
        if len(digits) == 8:
            return f"{digits[:4]}-{digits[4:]}"
        else:
            return f"{digits[:3]}-{digits[3:]}"
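BankgiroNormalizer delegates the Mod10 check to FieldValidators.luhn_checksum. For reference, a standalone sketch of the standard Luhn algorithm it relies on; this is not a copy of the shared implementation.

def luhn_checksum(digits: str) -> bool:
    """Return True if the digit string passes the Luhn (Mod10) check."""
    total = 0
    for i, ch in enumerate(reversed(digits)):
        d = int(ch)
        if i % 2 == 1:      # double every second digit, counting from the right
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return total % 10 == 0

# A Bankgiro shown as 123-4567 is checked as the digit string "1234567".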
71
packages/inference/inference/pipeline/normalizers/base.py
Normal file
@@ -0,0 +1,71 @@
"""
Base Normalizer Interface

Defines the contract for all field normalizers.
Each normalizer handles a specific field type's normalization and validation.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class NormalizationResult:
    """Result of a normalization operation."""

    value: str | None
    is_valid: bool
    error: str | None = None

    @classmethod
    def success(cls, value: str) -> "NormalizationResult":
        """Create a successful result."""
        return cls(value=value, is_valid=True, error=None)

    @classmethod
    def success_with_warning(cls, value: str, warning: str) -> "NormalizationResult":
        """Create a successful result with a warning."""
        return cls(value=value, is_valid=True, error=warning)

    @classmethod
    def failure(cls, error: str) -> "NormalizationResult":
        """Create a failed result."""
        return cls(value=None, is_valid=False, error=error)

    def to_tuple(self) -> tuple[str | None, bool, str | None]:
        """Convert to legacy tuple format for backward compatibility."""
        return (self.value, self.is_valid, self.error)


class BaseNormalizer(ABC):
    """
    Abstract base class for field normalizers.

    Each normalizer is responsible for:
    1. Cleaning and normalizing raw text
    2. Validating the normalized value
    3. Returning a standardized result
    """

    @property
    @abstractmethod
    def field_name(self) -> str:
        """The field name this normalizer handles."""
        pass

    @abstractmethod
    def normalize(self, text: str) -> NormalizationResult:
        """
        Normalize and validate the input text.

        Args:
            text: Raw text to normalize

        Returns:
            NormalizationResult with normalized value or error
        """
        pass

    def __call__(self, text: str) -> NormalizationResult:
        """Allow using the normalizer as a callable."""
        return self.normalize(text)
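A minimal sketch of plugging a new field type into this contract. The CustomerNumberNormalizer name and its trivial cleanup rule are purely illustrative; in this commit, customer numbers are actually handled by CustomerNumberParser inside the field extractor.

import re

class CustomerNumberNormalizer(BaseNormalizer):
    """Illustrative normalizer: collapses whitespace and keeps the reference as-is."""

    @property
    def field_name(self) -> str:
        return "customer_number"

    def normalize(self, text: str) -> NormalizationResult:
        cleaned = re.sub(r"\s+", " ", text).strip()
        if not cleaned:
            return NormalizationResult.failure("Empty text")
        return NormalizationResult.success(cleaned)

# Callable protocol via __call__:
# CustomerNumberNormalizer()("KU 12345").to_tuple() -> ("KU 12345", True, None)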
200
packages/inference/inference/pipeline/normalizers/date.py
Normal file
@@ -0,0 +1,200 @@
"""
Date Normalizer

Handles normalization and validation of invoice dates.
"""

import re
from datetime import datetime

from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections

from .base import BaseNormalizer, NormalizationResult


class DateNormalizer(BaseNormalizer):
    """
    Normalizes dates from Swedish invoices.

    Handles various date formats:
    - 2025-08-29 (ISO format)
    - 2025.08.29 (dot separator)
    - 29/08/2025 (European format)
    - 29.08.2025 (European with dots)
    - 20250829 (compact format)

    Output format: YYYY-MM-DD (ISO 8601)
    """

    # Date patterns with their parsing logic
    PATTERNS = [
        # ISO format: 2025-08-29
        (
            r"(\d{4})-(\d{1,2})-(\d{1,2})",
            lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
        ),
        # Dot format: 2025.08.29 (common in Swedish)
        (
            r"(\d{4})\.(\d{1,2})\.(\d{1,2})",
            lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
        ),
        # European slash format: 29/08/2025
        (
            r"(\d{1,2})/(\d{1,2})/(\d{4})",
            lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
        ),
        # European dot format: 29.08.2025
        (
            r"(\d{1,2})\.(\d{1,2})\.(\d{4})",
            lambda m: (int(m.group(3)), int(m.group(2)), int(m.group(1))),
        ),
        # Compact format: 20250829
        (
            r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)",
            lambda m: (int(m.group(1)), int(m.group(2)), int(m.group(3))),
        ),
    ]

    @property
    def field_name(self) -> str:
        return "Date"

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        # First, try using shared validator
        iso_date = FieldValidators.format_date_iso(text)
        if iso_date and FieldValidators.is_valid_date(iso_date):
            return NormalizationResult.success(iso_date)

        # Fallback: try original patterns for edge cases
        for pattern, extractor in self.PATTERNS:
            match = re.search(pattern, text)
            if match:
                try:
                    year, month, day = extractor(match)
                    # Validate date
                    parsed_date = datetime(year, month, day)
                    # Sanity check: year should be reasonable (2000-2100)
                    if 2000 <= parsed_date.year <= 2100:
                        return NormalizationResult.success(
                            parsed_date.strftime("%Y-%m-%d")
                        )
                except ValueError:
                    continue

        return NormalizationResult.failure(f"Cannot parse date: {text}")


class EnhancedDateNormalizer(DateNormalizer):
    """
    Enhanced date parsing with comprehensive format support.

    Additional support for:
    - Swedish text: "29 december 2024", "29 dec 2024"
    - OCR error correction: 2O24-12-29 -> 2024-12-29
    """

    # Swedish month names
    SWEDISH_MONTHS = {
        "januari": 1,
        "jan": 1,
        "februari": 2,
        "feb": 2,
        "mars": 3,
        "mar": 3,
        "april": 4,
        "apr": 4,
        "maj": 5,
        "juni": 6,
        "jun": 6,
        "juli": 7,
        "jul": 7,
        "augusti": 8,
        "aug": 8,
        "september": 9,
        "sep": 9,
        "sept": 9,
        "oktober": 10,
        "okt": 10,
        "november": 11,
        "nov": 11,
        "december": 12,
        "dec": 12,
    }

    # Extended patterns
    EXTENDED_PATTERNS = [
        # ISO format: 2025-08-29, 2025/08/29
        ("ymd", r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})"),
        # Dot format: 2025.08.29
        ("ymd", r"(\d{4})\.(\d{1,2})\.(\d{1,2})"),
        # European slash: 29/08/2025
        ("dmy", r"(\d{1,2})/(\d{1,2})/(\d{4})"),
        # European dot: 29.08.2025
        ("dmy", r"(\d{1,2})\.(\d{1,2})\.(\d{4})"),
        # European dash: 29-08-2025
        ("dmy", r"(\d{1,2})-(\d{1,2})-(\d{4})"),
        # Compact: 20250829
        ("ymd_compact", r"(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)"),
    ]

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        # Apply OCR corrections
        corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected

        # Try shared validator first
        iso_date = FieldValidators.format_date_iso(corrected_text)
        if iso_date and FieldValidators.is_valid_date(iso_date):
            return NormalizationResult.success(iso_date)

        # Try Swedish text date pattern: "29 december 2024" or "29 dec 2024"
        swedish_pattern = r"(\d{1,2})\s+([a-z\u00e5\u00e4\u00f6]+)\s+(\d{4})"
        match = re.search(swedish_pattern, corrected_text.lower())
        if match:
            day = int(match.group(1))
            month_name = match.group(2)
            year = int(match.group(3))
            if month_name in self.SWEDISH_MONTHS:
                month = self.SWEDISH_MONTHS[month_name]
                try:
                    dt = datetime(year, month, day)
                    if 2000 <= dt.year <= 2100:
                        return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
                except ValueError:
                    pass

        # Extended patterns
        for fmt, pattern in self.EXTENDED_PATTERNS:
            match = re.search(pattern, corrected_text)
            if match:
                try:
                    if fmt == "ymd":
                        year = int(match.group(1))
                        month = int(match.group(2))
                        day = int(match.group(3))
                    elif fmt == "dmy":
                        day = int(match.group(1))
                        month = int(match.group(2))
                        year = int(match.group(3))
                    elif fmt == "ymd_compact":
                        year = int(match.group(1))
                        month = int(match.group(2))
                        day = int(match.group(3))
                    else:
                        continue

                    dt = datetime(year, month, day)
                    if 2000 <= dt.year <= 2100:
                        return NormalizationResult.success(dt.strftime("%Y-%m-%d"))
                except ValueError:
                    continue

        return NormalizationResult.failure(f"Cannot parse date: {text[:50]}")
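A behavioural sketch of the two date classes above, assuming the shared FieldValidators/OCRCorrections helpers neither resolve these strings earlier nor alter the clean digits.

base = DateNormalizer()
print(base.normalize("Förfallodatum 29.08.2025").value)   # "2025-08-29" (European dot pattern)

enhanced = EnhancedDateNormalizer()
print(enhanced.normalize("29 dec 2024").value)            # "2024-12-29" (SWEDISH_MONTHS)
print(enhanced.normalize("20250829").value)               # "2025-08-29" (compact pattern)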
@@ -0,0 +1,84 @@
"""
Invoice Number Normalizer

Handles normalization and validation of invoice numbers.
"""

import re

from .base import BaseNormalizer, NormalizationResult


class InvoiceNumberNormalizer(BaseNormalizer):
    """
    Normalizes invoice numbers from Swedish invoices.

    Invoice numbers can be:
    - Pure digits: 12345678
    - Alphanumeric: A3861, INV-2024-001, F12345
    - With separators: 2024/001, 2024-001

    Strategy:
    1. Look for common invoice number patterns
    2. Prefer shorter, more specific matches over long digit sequences
    """

    @property
    def field_name(self) -> str:
        return "InvoiceNumber"

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        # Pattern 1: Alphanumeric invoice number (letter + digits or digits + letter)
        # Examples: A3861, F12345, INV001
        alpha_patterns = [
            r"\b([A-Z]{1,3}\d{3,10})\b",  # A3861, INV12345
            r"\b(\d{3,10}[A-Z]{1,3})\b",  # 12345A
            r"\b([A-Z]{2,5}[-/]?\d{3,10})\b",  # INV-12345, FAK12345
        ]

        for pattern in alpha_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return NormalizationResult.success(match.group(1).upper())

        # Pattern 2: Invoice number with year prefix (2024-001, 2024/12345)
        year_pattern = r"\b(20\d{2}[-/]\d{3,8})\b"
        match = re.search(year_pattern, text)
        if match:
            return NormalizationResult.success(match.group(1))

        # Pattern 3: Short digit sequence (3-10 digits) - prefer shorter sequences
        # This avoids capturing long OCR numbers
        digit_sequences = re.findall(r"\b(\d{3,10})\b", text)
        if digit_sequences:
            # Prefer shorter sequences (more likely to be invoice number)
            # Also filter out sequences that look like dates (8 digits starting with 20)
            valid_sequences = []
            for seq in digit_sequences:
                # Skip if it looks like a date (YYYYMMDD)
                if len(seq) == 8 and seq.startswith("20"):
                    continue
                # Skip if too long (likely OCR number)
                if len(seq) > 10:
                    continue
                valid_sequences.append(seq)

            if valid_sequences:
                # Return shortest valid sequence
                return NormalizationResult.success(min(valid_sequences, key=len))

        # Fallback: extract all digits if nothing else works
        digits = re.sub(r"\D", "", text)
        if len(digits) >= 3:
            # Limit to first 15 digits to avoid very long sequences
            return NormalizationResult.success_with_warning(
                digits[:15], "Fallback extraction"
            )

        return NormalizationResult.failure(
            f"Cannot extract invoice number from: {text[:50]}"
        )
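A quick sketch of the pattern priority above; the outputs follow from the regexes in this file.

n = InvoiceNumberNormalizer()
print(n.normalize("Fakturanr: A3861").value)       # "A3861" - alphanumeric pattern
print(n.normalize("Faktura 2024/001").value)       # "2024/001" - year-prefix pattern
print(n.normalize("1234567890123456").to_tuple())  # ('123456789012345', True, 'Fallback extraction')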
@@ -0,0 +1,37 @@
"""
OCR Number Normalizer

Handles normalization and validation of OCR reference numbers.
"""

import re

from .base import BaseNormalizer, NormalizationResult


class OcrNumberNormalizer(BaseNormalizer):
    """
    Normalizes OCR (Optical Character Recognition) reference numbers.

    OCR numbers in Swedish payment systems:
    - Minimum 5 digits
    - Used for automated payment matching
    """

    @property
    def field_name(self) -> str:
        return "OCR"

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        digits = re.sub(r"\D", "", text)

        if len(digits) < 5:
            return NormalizationResult.failure(
                f"Too few digits for OCR: {len(digits)}"
            )

        return NormalizationResult.success(digits)
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Plusgiro Normalizer
|
||||
|
||||
Handles normalization and validation of Swedish Plusgiro numbers.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from shared.utils.validators import FieldValidators
|
||||
|
||||
from .base import BaseNormalizer, NormalizationResult
|
||||
|
||||
|
||||
class PlusgiroNormalizer(BaseNormalizer):
|
||||
"""
|
||||
Normalizes Swedish Plusgiro numbers.
|
||||
|
||||
Plusgiro rules:
|
||||
- 2 to 8 digits
|
||||
- Last digit is Luhn (Mod10) check digit
|
||||
- Display format: XXXXXXX-X (all digits except last, dash, last digit)
|
||||
|
||||
Display pattern: ^\\d{1,7}-\\d$
|
||||
Normalized pattern: ^\\d{2,8}$
|
||||
|
||||
Note: Text may contain both BG and PG numbers. We specifically look for
|
||||
PG display format (X-X, XX-X, ..., XXXXXXX-X) to extract the correct one.
|
||||
"""
|
||||
|
||||
@property
|
||||
def field_name(self) -> str:
|
||||
return "Plusgiro"
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# First look for PG display format: 1-7 digits (possibly with spaces), dash, 1 digit
|
||||
# This is distinct from BG format which has 4 digits after the dash
|
||||
# Pattern allows spaces within the number like "486 98 63-6"
|
||||
# (?<![0-9]) ensures we don't start from within another number (like BG)
|
||||
pg_matches = re.findall(r"(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)", text)
|
||||
|
||||
if pg_matches:
|
||||
# Try each match and find one with valid Luhn
|
||||
for match in pg_matches:
|
||||
# Remove spaces from the first part
|
||||
digits = re.sub(r"\s", "", match[0]) + match[1]
|
||||
if 2 <= len(digits) <= 8 and FieldValidators.luhn_checksum(digits):
|
||||
# Valid PG found
|
||||
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||
return NormalizationResult.success(formatted)
|
||||
|
||||
# No valid Luhn, use first match with most digits
|
||||
best_match = max(pg_matches, key=lambda m: len(re.sub(r"\s", "", m[0])))
|
||||
digits = re.sub(r"\s", "", best_match[0]) + best_match[1]
|
||||
if 2 <= len(digits) <= 8:
|
||||
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
# If no PG format found, extract all digits and format as PG
|
||||
# This handles cases where the number might be in BG format or raw digits
|
||||
all_digits = re.sub(r"\D", "", text)
|
||||
|
||||
# Try to find a valid 2-8 digit sequence
|
||||
if 2 <= len(all_digits) <= 8:
|
||||
formatted = f"{all_digits[:-1]}-{all_digits[-1]}"
|
||||
if FieldValidators.luhn_checksum(all_digits):
|
||||
return NormalizationResult.success(formatted)
|
||||
else:
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
# Try to find any 2-8 digit sequence in text
|
||||
digit_match = re.search(r"\b(\d{2,8})\b", text)
|
||||
if digit_match:
|
||||
digits = digit_match.group(1)
|
||||
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||
if FieldValidators.luhn_checksum(digits):
|
||||
return NormalizationResult.success(formatted)
|
||||
else:
|
||||
return NormalizationResult.success_with_warning(
|
||||
formatted, "Luhn checksum failed (possible OCR error)"
|
||||
)
|
||||
|
||||
return NormalizationResult.failure("No valid Plusgiro found in text")
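
A note on the check digit: FieldValidators.luhn_checksum is referenced above but is not part of this commit. As a minimal sketch, and only an assumption about that helper's behaviour, a Mod10/Luhn validation over the full digit string (check digit included) works like this:

# Sketch only: assumed behaviour of FieldValidators.luhn_checksum, not code from this commit.
def luhn_checksum(digits: str) -> bool:
    """Return True when the digit string, check digit included, passes Mod10."""
    total = 0
    for position, char in enumerate(reversed(digits)):
        value = int(char)
        if position % 2 == 1:      # double every second digit from the right
            value *= 2
            if value > 9:
                value -= 9         # same as summing the two digits of the product
        total += value
    return total % 10 == 0

# "18": 8 + (1*2) = 10, so the trailing 8 is a valid check digit.
assert luhn_checksum("18")
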
@@ -0,0 +1,60 @@
"""
Supplier Organization Number Normalizer

Handles normalization and validation of Swedish organization numbers.
"""

import re

from .base import BaseNormalizer, NormalizationResult


class SupplierOrgNumberNormalizer(BaseNormalizer):
    """
    Normalizes Swedish supplier organization numbers.

    Extracts organization number in format: NNNNNN-NNNN (10 digits)
    Also handles VAT numbers: SE + 10 digits + 01

    Examples:
        'org.nr. 516406-1102, Filialregistret...' -> '516406-1102'
        'Momsreg.nr SE556123456701' -> '556123-4567'
    """

    @property
    def field_name(self) -> str:
        return "supplier_org_number"

    def normalize(self, text: str) -> NormalizationResult:
        text = text.strip()
        if not text:
            return NormalizationResult.failure("Empty text")

        # Pattern 1: Standard org number format: NNNNNN-NNNN
        org_pattern = r"\b(\d{6})-?(\d{4})\b"
        match = re.search(org_pattern, text)
        if match:
            org_num = f"{match.group(1)}-{match.group(2)}"
            return NormalizationResult.success(org_num)

        # Pattern 2: VAT number format: SE + 10 digits + 01
        vat_pattern = r"SE\s*(\d{10})01"
        match = re.search(vat_pattern, text, re.IGNORECASE)
        if match:
            digits = match.group(1)
            org_num = f"{digits[:6]}-{digits[6:]}"
            return NormalizationResult.success(org_num)

        # Pattern 3: Just 10 consecutive digits
        digits_pattern = r"\b(\d{10})\b"
        match = re.search(digits_pattern, text)
        if match:
            digits = match.group(1)
            # Validate: first digit should be 1-9 for Swedish org numbers
            if digits[0] in "123456789":
                org_num = f"{digits[:6]}-{digits[6:]}"
                return NormalizationResult.success(org_num)

        return NormalizationResult.failure(
            f"Cannot extract org number from: {text[:100]}"
        )
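
For reference, the VAT branch above (pattern 2) reduces a Momsreg number to the display format given in the docstring. A standalone sketch using only the standard library, with the docstring's own example string:

# Sketch only: pattern 2 from SupplierOrgNumberNormalizer, run outside the class.
import re

text = "Momsreg.nr SE556123456701"
match = re.search(r"SE\s*(\d{10})01", text, re.IGNORECASE)
if match:
    digits = match.group(1)                 # "5561234567"
    print(f"{digits[:6]}-{digits[6:]}")     # prints 556123-4567
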
@@ -9,12 +9,12 @@ import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse

from inference.data.admin_db import AdminDB
from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.data.repositories import DocumentRepository, AnnotationRepository
from inference.web.core.auth import AdminTokenDep
from inference.web.services.autolabel import get_auto_label_service
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
@@ -36,6 +36,31 @@ from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

# Global repository instances
_doc_repo: DocumentRepository | None = None
_ann_repo: AnnotationRepository | None = None


def get_doc_repository() -> DocumentRepository:
    """Get the DocumentRepository instance."""
    global _doc_repo
    if _doc_repo is None:
        _doc_repo = DocumentRepository()
    return _doc_repo


def get_ann_repository() -> AnnotationRepository:
    """Get the AnnotationRepository instance."""
    global _ann_repo
    if _ann_repo is None:
        _ann_repo = AnnotationRepository()
    return _ann_repo


# Type aliases for dependency injection
DocRepoDep = Annotated[DocumentRepository, Depends(get_doc_repository)]
AnnRepoDep = Annotated[AnnotationRepository, Depends(get_ann_repository)]


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
@@ -71,17 +96,17 @@ def create_annotation_router() -> APIRouter:
        document_id: str,
        page_number: int,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
    ) -> FileResponse | StreamingResponse:
        """Get page image."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Validate page number
@@ -137,7 +162,8 @@ def create_annotation_router() -> APIRouter:
    async def list_annotations(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
        page_number: Annotated[
            int | None,
            Query(ge=1, description="Filter by page number"),
@@ -146,16 +172,16 @@ def create_annotation_router() -> APIRouter:
        """List annotations for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Get annotations
        raw_annotations = db.get_annotations_for_document(document_id, page_number)
        raw_annotations = ann_repo.get_for_document(document_id, page_number)
        annotations = [
            AnnotationItem(
                annotation_id=str(ann.annotation_id),
@@ -204,17 +230,18 @@ def create_annotation_router() -> APIRouter:
        document_id: str,
        request: AnnotationCreate,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
    ) -> AnnotationResponse:
        """Create a new annotation."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Validate page number
@@ -244,7 +271,7 @@ def create_annotation_router() -> APIRouter:
        class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")

        # Create annotation
        annotation_id = db.create_annotation(
        annotation_id = ann_repo.create(
            document_id=document_id,
            page_number=request.page_number,
            class_id=request.class_id,
@@ -285,22 +312,23 @@ def create_annotation_router() -> APIRouter:
        annotation_id: str,
        request: AnnotationUpdate,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
    ) -> AnnotationResponse:
        """Update an annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Get existing annotation
        annotation = db.get_annotation(annotation_id)
        annotation = ann_repo.get(annotation_id)
        if annotation is None:
            raise HTTPException(
                status_code=404,
@@ -349,7 +377,7 @@ def create_annotation_router() -> APIRouter:

        # Update annotation
        if update_kwargs:
            success = db.update_annotation(annotation_id, **update_kwargs)
            success = ann_repo.update(annotation_id, **update_kwargs)
            if not success:
                raise HTTPException(
                    status_code=500,
@@ -374,22 +402,23 @@ def create_annotation_router() -> APIRouter:
        document_id: str,
        annotation_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
    ) -> dict:
        """Delete an annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Get existing annotation
        annotation = db.get_annotation(annotation_id)
        annotation = ann_repo.get(annotation_id)
        if annotation is None:
            raise HTTPException(
                status_code=404,
@@ -404,7 +433,7 @@ def create_annotation_router() -> APIRouter:
            )

        # Delete annotation
        db.delete_annotation(annotation_id)
        ann_repo.delete(annotation_id)

        return {
            "status": "deleted",
@@ -431,17 +460,18 @@ def create_annotation_router() -> APIRouter:
        document_id: str,
        request: AutoLabelRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
    ) -> AutoLabelResponse:
        """Trigger auto-labeling for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Validate field values
@@ -457,7 +487,8 @@ def create_annotation_router() -> APIRouter:
            document_id=document_id,
            file_path=document.file_path,
            field_values=request.field_values,
            db=db,
            doc_repo=doc_repo,
            ann_repo=ann_repo,
            replace_existing=request.replace_existing,
        )

@@ -486,7 +517,8 @@ def create_annotation_router() -> APIRouter:
    async def delete_all_annotations(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
        source: Annotated[
            str | None,
            Query(description="Filter by source (manual, auto, imported)"),
@@ -502,21 +534,21 @@ def create_annotation_router() -> APIRouter:
                detail=f"Invalid source: {source}",
            )

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Delete annotations
        deleted_count = db.delete_annotations_for_document(document_id, source)
        deleted_count = ann_repo.delete_for_document(document_id, source)

        # Update document status if all annotations deleted
        remaining = db.get_annotations_for_document(document_id)
        remaining = ann_repo.get_for_document(document_id)
        if not remaining:
            db.update_document_status(document_id, "pending")
            doc_repo.update_status(document_id, "pending")

        return {
            "status": "deleted",
@@ -543,23 +575,24 @@ def create_annotation_router() -> APIRouter:
        document_id: str,
        annotation_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
        request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
    ) -> AnnotationVerifyResponse:
        """Verify an annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership of document
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Verify the annotation
        annotation = db.verify_annotation(annotation_id, admin_token)
        annotation = ann_repo.verify(annotation_id, admin_token)
        if annotation is None:
            raise HTTPException(
                status_code=404,
@@ -589,18 +622,19 @@ def create_annotation_router() -> APIRouter:
        annotation_id: str,
        request: AnnotationOverrideRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        doc_repo: DocRepoDep,
        ann_repo: AnnRepoDep,
    ) -> AnnotationOverrideResponse:
        """Override an auto-generated annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership of document
        document = db.get_document_by_token(document_id, admin_token)
        # Get document
        document = doc_repo.get(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
                detail="Document not found",
            )

        # Build updates dict from request
@@ -632,7 +666,7 @@ def create_annotation_router() -> APIRouter:
        )

        # Override the annotation
        annotation = db.override_annotation(
        annotation = ann_repo.override(
            annotation_id=annotation_id,
            admin_token=admin_token,
            change_reason=request.reason,
@@ -646,7 +680,7 @@ def create_annotation_router() -> APIRouter:
        )

        # Get history to return history_id
        history_records = db.get_annotation_history(UUID(annotation_id))
        history_records = ann_repo.get_history(UUID(annotation_id))
        latest_history = history_records[0] if history_records else None

        return AnnotationOverrideResponse(
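
The annotation router above wires the new repositories in with a module-level singleton getter plus an Annotated[..., Depends(...)] alias. A condensed, self-contained sketch of that same wiring follows; WidgetRepository is a hypothetical stand-in, not a class from this codebase.

# Sketch of the dependency-injection pattern used in the routers above.
from typing import Annotated

from fastapi import APIRouter, Depends, FastAPI


class WidgetRepository:
    """Hypothetical repository used only for this illustration."""

    def get(self, widget_id: str) -> dict:
        return {"widget_id": widget_id}


_widget_repo: WidgetRepository | None = None


def get_widget_repository() -> WidgetRepository:
    """Create the repository once and reuse it for every request."""
    global _widget_repo
    if _widget_repo is None:
        _widget_repo = WidgetRepository()
    return _widget_repo


WidgetRepoDep = Annotated[WidgetRepository, Depends(get_widget_repository)]

router = APIRouter()


@router.get("/widgets/{widget_id}")
async def get_widget(widget_id: str, repo: WidgetRepoDep) -> dict:
    # FastAPI resolves WidgetRepoDep through Depends() before calling the handler,
    # so the handler never constructs its own data-access object.
    return repo.get(widget_id)


app = FastAPI()
app.include_router(router)

The lazy singleton keeps one repository per process; a per-request factory would be the alternative if the repository held request-scoped state.
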

@@ -1,10 +1,8 @@
"""Augmentation API routes."""

from typing import Annotated
from fastapi import APIRouter, Query

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminDBDep, AdminTokenDep
from inference.web.core.auth import AdminTokenDep, DocumentRepoDep, DatasetRepoDep
from inference.web.schemas.admin.augmentation import (
    AugmentationBatchRequest,
    AugmentationBatchResponse,
@@ -13,7 +11,6 @@ from inference.web.schemas.admin.augmentation import (
    AugmentationPreviewResponse,
    AugmentationTypeInfo,
    AugmentationTypesResponse,
    AugmentedDatasetItem,
    AugmentedDatasetListResponse,
    PresetInfo,
    PresetsResponse,
@@ -78,7 +75,7 @@ def register_augmentation_routes(router: APIRouter) -> None:
        document_id: str,
        request: AugmentationPreviewRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        page: int = Query(default=1, ge=1, description="Page number"),
    ) -> AugmentationPreviewResponse:
        """
@@ -88,7 +85,7 @@ def register_augmentation_routes(router: APIRouter) -> None:
        """
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        service = AugmentationService(doc_repo=docs)
        return await service.preview_single(
            document_id=document_id,
            page=page,
@@ -105,13 +102,13 @@ def register_augmentation_routes(router: APIRouter) -> None:
        document_id: str,
        config: AugmentationConfigSchema,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        page: int = Query(default=1, ge=1, description="Page number"),
    ) -> AugmentationPreviewResponse:
        """Preview complete augmentation pipeline on a document page."""
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        service = AugmentationService(doc_repo=docs)
        return await service.preview_config(
            document_id=document_id,
            page=page,
@@ -126,7 +123,8 @@ def register_augmentation_routes(router: APIRouter) -> None:
    async def create_augmented_dataset(
        request: AugmentationBatchRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        datasets: DatasetRepoDep,
    ) -> AugmentationBatchResponse:
        """
        Create a new augmented dataset from an existing dataset.
@@ -136,7 +134,7 @@ def register_augmentation_routes(router: APIRouter) -> None:
        """
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        service = AugmentationService(doc_repo=docs, dataset_repo=datasets)
        return await service.create_augmented_dataset(
            source_dataset_id=request.dataset_id,
            config=request.config,
@@ -151,12 +149,12 @@ def register_augmentation_routes(router: APIRouter) -> None:
    )
    async def list_augmented_datasets(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        datasets: DatasetRepoDep,
        limit: int = Query(default=20, ge=1, le=100, description="Page size"),
        offset: int = Query(default=0, ge=0, description="Offset"),
    ) -> AugmentedDatasetListResponse:
        """List all augmented datasets."""
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        service = AugmentationService(dataset_repo=datasets)
        return await service.list_augmented_datasets(limit=limit, offset=offset)

@@ -10,7 +10,7 @@ from datetime import datetime, timedelta

from fastapi import APIRouter

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, TokenRepoDep
from inference.web.schemas.admin import (
    AdminTokenCreate,
    AdminTokenResponse,
@@ -35,7 +35,7 @@ def create_auth_router() -> APIRouter:
    )
    async def create_token(
        request: AdminTokenCreate,
        db: AdminDBDep,
        tokens: TokenRepoDep,
    ) -> AdminTokenResponse:
        """Create a new admin token."""
        # Generate secure token
@@ -47,7 +47,7 @@ def create_auth_router() -> APIRouter:
        expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)

        # Create token in database
        db.create_admin_token(
        tokens.create(
            token=token,
            name=request.name,
            expires_at=expires_at,
@@ -70,10 +70,10 @@ def create_auth_router() -> APIRouter:
    )
    async def revoke_token(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        tokens: TokenRepoDep,
    ) -> dict:
        """Revoke the current admin token."""
        db.deactivate_admin_token(admin_token)
        tokens.deactivate(admin_token)
        return {
            "status": "revoked",
            "message": "Admin token has been revoked",

@@ -12,7 +12,12 @@ from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile

from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
    AdminTokenDep,
    DocumentRepoDep,
    AnnotationRepoDep,
    TrainingTaskRepoDep,
)
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
    AnnotationItem,
@@ -87,7 +92,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    )
    async def upload_document(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        file: UploadFile = File(..., description="PDF or image file"),
        auto_label: Annotated[
            bool,
@@ -142,7 +147,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            logger.warning(f"Failed to get PDF page count: {e}")

        # Create document record (token only used for auth, not stored)
        document_id = db.create_document(
        document_id = docs.create(
            filename=file.filename,
            file_size=len(content),
            content_type=file.content_type or "application/octet-stream",
@@ -184,7 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
        auto_label_started = False
        if auto_label:
            # Auto-labeling will be triggered by a background task
            db.update_document_status(
            docs.update_status(
                document_id=document_id,
                status="auto_labeling",
                auto_label_status="running",
@@ -214,7 +219,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    )
    async def list_documents(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        annotations: AnnotationRepoDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status"),
@@ -270,7 +276,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
                detail=f"Invalid auto_label_status: {auto_label_status}",
            )

        documents, total = db.get_documents_by_token(
        documents, total = docs.get_paginated(
            admin_token=admin_token,
            status=status,
            upload_source=upload_source,
@@ -285,7 +291,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
        # Get annotation counts and build items
        items = []
        for doc in documents:
            annotations = db.get_annotations_for_document(str(doc.document_id))
            doc_annotations = annotations.get_for_document(str(doc.document_id))

            # Determine if document can be annotated (not locked)
            can_annotate = True
@@ -301,7 +307,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
                    page_count=doc.page_count,
                    status=DocumentStatus(doc.status),
                    auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
                    annotation_count=len(annotations),
                    annotation_count=len(doc_annotations),
                    upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
                    batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
                    group_key=doc.group_key if hasattr(doc, 'group_key') else None,
@@ -330,10 +336,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    )
    async def get_document_stats(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
    ) -> DocumentStatsResponse:
        """Get document statistics."""
        counts = db.count_documents_by_status(admin_token)
        counts = docs.count_by_status(admin_token)

        return DocumentStatsResponse(
            total=sum(counts.values()),
@@ -343,6 +349,26 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            exported=counts.get("exported", 0),
        )

    @router.get(
        "/categories",
        response_model=DocumentCategoriesResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get available categories",
        description="Get list of all available document categories.",
    )
    async def get_categories(
        admin_token: AdminTokenDep,
        docs: DocumentRepoDep,
    ) -> DocumentCategoriesResponse:
        """Get all available document categories."""
        categories = docs.get_categories()
        return DocumentCategoriesResponse(
            categories=categories,
            total=len(categories),
        )

    @router.get(
        "/{document_id}",
        response_model=DocumentDetailResponse,
@@ -356,12 +382,14 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    async def get_document(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        annotations: AnnotationRepoDep,
        tasks: TrainingTaskRepoDep,
    ) -> DocumentDetailResponse:
        """Get document details."""
        _validate_uuid(document_id, "document_id")

        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -369,8 +397,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            )

        # Get annotations
        raw_annotations = db.get_annotations_for_document(document_id)
        annotations = [
        raw_annotations = annotations.get_for_document(document_id)
        annotation_items = [
            AnnotationItem(
                annotation_id=str(ann.annotation_id),
                page_number=ann.page_number,
@@ -416,10 +444,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:

        # Get training history (Phase 5)
        training_history = []
        training_links = db.get_document_training_tasks(document.document_id)
        training_links = tasks.get_document_training_tasks(document.document_id)
        for link in training_links:
            # Get task details
            task = db.get_training_task(str(link.task_id))
            task = tasks.get(str(link.task_id))
            if task:
                # Build metrics
                metrics = None
@@ -455,7 +483,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            csv_field_values=csv_field_values,
            can_annotate=can_annotate,
            annotation_lock_until=annotation_lock_until,
            annotations=annotations,
            annotations=annotation_items,
            image_urls=image_urls,
            training_history=training_history,
            created_at=document.created_at,
@@ -474,13 +502,13 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    async def delete_document(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
    ) -> dict:
        """Delete a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -505,7 +533,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            logger.warning(f"Failed to delete admin images: {e}")

        # Delete from database
        db.delete_document(document_id)
        docs.delete(document_id)

        return {
            "status": "deleted",
@@ -525,7 +553,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    async def update_document_status(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        annotations: AnnotationRepoDep,
        status: Annotated[
            str,
            Query(description="New status"),
@@ -547,7 +576,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            )

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -560,16 +589,15 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            from inference.web.services.db_autolabel import save_manual_annotations_to_document_db

            # Get all annotations for this document
            annotations = db.get_annotations_for_document(document_id)
            doc_annotations = annotations.get_for_document(document_id)

            if annotations:
            if doc_annotations:
                db_save_result = save_manual_annotations_to_document_db(
                    document=document,
                    annotations=annotations,
                    db=db,
                    annotations=doc_annotations,
                )

        db.update_document_status(document_id, status)
        docs.update_status(document_id, status)

        response = {
            "status": "updated",
@@ -597,7 +625,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    async def update_document_group_key(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        group_key: Annotated[
            str | None,
            Query(description="New group key (null to clear)"),
@@ -614,7 +642,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            )

        # Verify document exists
        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -622,7 +650,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            )

        # Update group key
        db.update_document_group_key(document_id, group_key)
        docs.update_group_key(document_id, group_key)

        return {
            "status": "updated",
@@ -631,26 +659,6 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            "message": "Document group key updated",
        }

    @router.get(
        "/categories",
        response_model=DocumentCategoriesResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get available categories",
        description="Get list of all available document categories.",
    )
    async def get_categories(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> DocumentCategoriesResponse:
        """Get all available document categories."""
        categories = db.get_document_categories()
        return DocumentCategoriesResponse(
            categories=categories,
            total=len(categories),
        )

    @router.patch(
        "/{document_id}/category",
        responses={
@@ -663,14 +671,14 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    async def update_document_category(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        request: DocumentUpdateRequest,
    ) -> dict:
        """Update document category."""
        _validate_uuid(document_id, "document_id")

        # Verify document exists
        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -679,7 +687,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:

        # Update category if provided
        if request.category is not None:
            db.update_document_category(document_id, request.category)
            docs.update_category(document_id, request.category)

        return {
            "status": "updated",

@@ -4,21 +4,18 @@ Admin Document Lock Routes
FastAPI endpoints for annotation lock management.
"""

import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, DocumentRepoDep
from inference.web.schemas.admin import (
    AnnotationLockRequest,
    AnnotationLockResponse,
)
from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
@@ -49,14 +46,14 @@ def create_locks_router() -> APIRouter:
    async def acquire_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        request: AnnotationLockRequest = AnnotationLockRequest(),
    ) -> AnnotationLockResponse:
        """Acquire annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -64,7 +61,7 @@ def create_locks_router() -> APIRouter:
            )

        # Attempt to acquire lock
        updated_doc = db.acquire_annotation_lock(
        updated_doc = docs.acquire_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            duration_seconds=request.duration_seconds,
@@ -96,7 +93,7 @@ def create_locks_router() -> APIRouter:
    async def release_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        force: Annotated[
            bool,
            Query(description="Force release (admin override)"),
@@ -106,7 +103,7 @@ def create_locks_router() -> APIRouter:
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -114,7 +111,7 @@ def create_locks_router() -> APIRouter:
            )

        # Release lock
        updated_doc = db.release_annotation_lock(
        updated_doc = docs.release_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            force=force,
@@ -147,14 +144,14 @@ def create_locks_router() -> APIRouter:
    async def extend_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        request: AnnotationLockRequest = AnnotationLockRequest(),
    ) -> AnnotationLockResponse:
        """Extend annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        document = docs.get_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
@@ -162,7 +159,7 @@ def create_locks_router() -> APIRouter:
            )

        # Attempt to extend lock
        updated_doc = db.extend_annotation_lock(
        updated_doc = docs.extend_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            additional_seconds=request.duration_seconds,

@@ -5,7 +5,12 @@ from typing import Annotated

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
    AdminTokenDep,
    DatasetRepoDep,
    DocumentRepoDep,
    AnnotationRepoDep,
    ModelVersionRepoDep,
    TrainingTaskRepoDep,
)
from inference.web.schemas.admin import (
    DatasetCreateRequest,
    DatasetDetailResponse,
@@ -36,7 +43,9 @@ def register_dataset_routes(router: APIRouter) -> None:
    async def create_dataset(
        request: DatasetCreateRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        datasets: DatasetRepoDep,
        docs: DocumentRepoDep,
        annotations: AnnotationRepoDep,
    ) -> DatasetResponse:
        """Create a training dataset from document IDs."""
        from inference.web.services.dataset_builder import DatasetBuilder
@@ -48,7 +57,7 @@ def register_dataset_routes(router: APIRouter) -> None:
                detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
            )

        dataset = db.create_dataset(
        dataset = datasets.create(
            name=request.name,
            description=request.description,
            train_ratio=request.train_ratio,
@@ -67,7 +76,12 @@ def register_dataset_routes(router: APIRouter) -> None:
                detail="Storage not configured for local access",
            )

        builder = DatasetBuilder(db=db, base_dir=datasets_dir)
        builder = DatasetBuilder(
            datasets_repo=datasets,
            documents_repo=docs,
            annotations_repo=annotations,
            base_dir=datasets_dir,
        )
        try:
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
@@ -94,18 +108,18 @@ def register_dataset_routes(router: APIRouter) -> None:
    )
    async def list_datasets(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        datasets_repo: DatasetRepoDep,
        status: Annotated[str | None, Query(description="Filter by status")] = None,
        limit: Annotated[int, Query(ge=1, le=100)] = 20,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> DatasetListResponse:
        """List training datasets."""
        datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
        datasets_list, total = datasets_repo.get_paginated(status=status, limit=limit, offset=offset)

        # Get active training tasks for each dataset (graceful degradation on error)
        dataset_ids = [str(d.dataset_id) for d in datasets]
        dataset_ids = [str(d.dataset_id) for d in datasets_list]
        try:
            active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
            active_tasks = datasets_repo.get_active_training_tasks(dataset_ids)
        except Exception:
            logger.exception("Failed to fetch active training tasks")
            active_tasks = {}
@@ -127,7 +141,7 @@ def register_dataset_routes(router: APIRouter) -> None:
                    total_annotations=d.total_annotations,
                    created_at=d.created_at,
                )
                for d in datasets
                for d in datasets_list
            ],
        )

@@ -139,15 +153,15 @@ def register_dataset_routes(router: APIRouter) -> None:
    async def get_dataset(
        dataset_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        datasets_repo: DatasetRepoDep,
    ) -> DatasetDetailResponse:
        """Get dataset details with document list."""
        _validate_uuid(dataset_id, "dataset_id")
        dataset = db.get_dataset(dataset_id)
        dataset = datasets_repo.get(dataset_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="Dataset not found")

        docs = db.get_dataset_documents(dataset_id)
        docs = datasets_repo.get_documents(dataset_id)
        return DatasetDetailResponse(
            dataset_id=str(dataset.dataset_id),
            name=dataset.name,
@@ -187,14 +201,14 @@ def register_dataset_routes(router: APIRouter) -> None:
    async def delete_dataset(
        dataset_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        datasets_repo: DatasetRepoDep,
    ) -> dict:
        """Delete a dataset and its files."""
        import shutil
        from pathlib import Path

        _validate_uuid(dataset_id, "dataset_id")
        dataset = db.get_dataset(dataset_id)
        dataset = datasets_repo.get(dataset_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="Dataset not found")

@@ -203,7 +217,7 @@ def register_dataset_routes(router: APIRouter) -> None:
        if dataset_dir.exists():
            shutil.rmtree(dataset_dir)

        db.delete_dataset(dataset_id)
        datasets_repo.delete(dataset_id)
        return {"message": "Dataset deleted"}

    @router.post(
@@ -216,7 +230,9 @@ def register_dataset_routes(router: APIRouter) -> None:
        dataset_id: str,
        request: DatasetTrainRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        datasets_repo: DatasetRepoDep,
        models: ModelVersionRepoDep,
        tasks: TrainingTaskRepoDep,
    ) -> TrainingTaskResponse:
        """Create a training task from a dataset.

@@ -224,7 +240,7 @@ def register_dataset_routes(router: APIRouter) -> None:
        The training will use that model as the starting point instead of a pretrained model.
        """
        _validate_uuid(dataset_id, "dataset_id")
        dataset = db.get_dataset(dataset_id)
        dataset = datasets_repo.get(dataset_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="Dataset not found")
        if dataset.status != "ready":
@@ -239,7 +255,7 @@ def register_dataset_routes(router: APIRouter) -> None:
        base_model_version_id = config_dict.get("base_model_version_id")
        if base_model_version_id:
            _validate_uuid(base_model_version_id, "base_model_version_id")
            base_model = db.get_model_version(base_model_version_id)
            base_model = models.get(base_model_version_id)
            if not base_model:
                raise HTTPException(
                    status_code=404,
@@ -254,7 +270,7 @@ def register_dataset_routes(router: APIRouter) -> None:
                base_model.model_path,
            )

        task_id = db.create_training_task(
        task_id = tasks.create(
            admin_token=admin_token,
            name=request.name,
            task_type="finetune" if base_model_version_id else "train",

@@ -5,7 +5,12 @@ from typing import Annotated

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
    AdminTokenDep,
    DocumentRepoDep,
    AnnotationRepoDep,
    TrainingTaskRepoDep,
)
from inference.web.schemas.admin import (
    ModelMetrics,
    TrainingDocumentItem,
@@ -35,7 +40,9 @@ def register_document_routes(router: APIRouter) -> None:
    )
    async def get_training_documents(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        annotations: AnnotationRepoDep,
        tasks: TrainingTaskRepoDep,
        has_annotations: Annotated[
            bool,
            Query(description="Only include documents with annotations"),
@@ -58,7 +65,7 @@ def register_document_routes(router: APIRouter) -> None:
        ] = 0,
    ) -> TrainingDocumentsResponse:
        """Get documents available for training."""
        documents, total = db.get_documents_for_training(
        documents, total = docs.get_for_training(
            admin_token=admin_token,
            status="labeled",
            has_annotations=has_annotations,
@@ -70,21 +77,21 @@ def register_document_routes(router: APIRouter) -> None:

        items = []
        for doc in documents:
            annotations = db.get_annotations_for_document(str(doc.document_id))
            doc_annotations = annotations.get_for_document(str(doc.document_id))

            sources = {"manual": 0, "auto": 0}
            for ann in annotations:
            for ann in doc_annotations:
                if ann.source in sources:
                    sources[ann.source] += 1

            training_links = db.get_document_training_tasks(doc.document_id)
            training_links = tasks.get_document_training_tasks(doc.document_id)
            used_in_training = [str(link.task_id) for link in training_links]

            items.append(
                TrainingDocumentItem(
                    document_id=str(doc.document_id),
                    filename=doc.filename,
                    annotation_count=len(annotations),
                    annotation_count=len(doc_annotations),
                    annotation_sources=sources,
                    used_in_training=used_in_training,
                    last_modified=doc.updated_at,
@@ -110,7 +117,7 @@ def register_document_routes(router: APIRouter) -> None:
    async def download_model(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        tasks: TrainingTaskRepoDep,
    ):
        """Download trained model."""
        from fastapi.responses import FileResponse
@@ -118,7 +125,7 @@ def register_document_routes(router: APIRouter) -> None:

        _validate_uuid(task_id, "task_id")

        task = db.get_training_task_by_token(task_id, admin_token)
        task = tasks.get_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
@@ -155,7 +162,7 @@ def register_document_routes(router: APIRouter) -> None:
    )
    async def get_completed_training_tasks(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        tasks_repo: TrainingTaskRepoDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status (completed, failed, etc.)"),
@@ -170,7 +177,7 @@ def register_document_routes(router: APIRouter) -> None:
        ] = 0,
    ) -> TrainingModelsResponse:
        """Get list of trained models."""
        tasks, total = db.get_training_tasks_by_token(
        task_list, total = tasks_repo.get_paginated(
            admin_token=admin_token,
            status=status if status else "completed",
            limit=limit,
@@ -178,7 +185,7 @@ def register_document_routes(router: APIRouter) -> None:
        )

        items = []
        for task in tasks:
        for task in task_list:
            metrics = ModelMetrics(
                mAP=task.metrics_mAP,
                precision=task.metrics_precision,

@@ -5,7 +5,7 @@ from datetime import datetime

from fastapi import APIRouter, HTTPException

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.core.auth import AdminTokenDep, DocumentRepoDep, AnnotationRepoDep
from inference.web.schemas.admin import (
    ExportRequest,
    ExportResponse,
@@ -31,7 +31,8 @@ def register_export_routes(router: APIRouter) -> None:
    async def export_annotations(
        request: ExportRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        docs: DocumentRepoDep,
        annotations: AnnotationRepoDep,
    ) -> ExportResponse:
        """Export annotations for training."""
        from inference.web.services.storage_helpers import get_storage_helper
@@ -45,7 +46,7 @@ def register_export_routes(router: APIRouter) -> None:
                detail=f"Unsupported export format: {request.format}",
            )

        documents = db.get_labeled_documents_for_export(admin_token)
        documents = docs.get_labeled_for_export(admin_token)

        if not documents:
            raise HTTPException(
@@ -78,13 +79,13 @@ def register_export_routes(router: APIRouter) -> None:

        for split, docs in [("train", train_docs), ("val", val_docs)]:
            for doc in docs:
                annotations = db.get_annotations_for_document(str(doc.document_id))
                doc_annotations = annotations.get_for_document(str(doc.document_id))

                if not annotations:
                if not doc_annotations:
                    continue

                for page_num in range(1, doc.page_count + 1):
                    page_annotations = [a for a in annotations if a.page_number == page_num]
                    page_annotations = [a for a in doc_annotations if a.page_number == page_num]

                    if not page_annotations and not request.include_images:
                        continue
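
The export hunk iterates over ("train", train_docs) and ("val", val_docs), but the split itself is outside this diff. As a sketch only, with the function name and the 0.8 default being assumptions, a deterministic ratio-based split could look like this:

# Sketch only: the commit does not show how train_docs / val_docs are produced.
import random


def split_documents(documents: list, train_ratio: float = 0.8, seed: int = 42) -> tuple[list, list]:
    """Shuffle with a fixed seed, then cut the list at train_ratio."""
    shuffled = list(documents)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]


train_docs, val_docs = split_documents(list(range(10)))
print(len(train_docs), len(val_docs))   # 8 2
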
||||
@@ -5,7 +5,7 @@ from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.core.auth import AdminTokenDep, ModelVersionRepoDep
|
||||
from inference.web.schemas.admin import (
|
||||
ModelVersionCreateRequest,
|
||||
ModelVersionUpdateRequest,
|
||||
@@ -33,7 +33,7 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
async def create_model_version(
|
||||
request: ModelVersionCreateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Create a new model version."""
|
||||
if request.task_id:
|
||||
@@ -41,7 +41,7 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
if request.dataset_id:
|
||||
_validate_uuid(request.dataset_id, "dataset_id")
|
||||
|
||||
model = db.create_model_version(
|
||||
model = models.create(
|
||||
version=request.version,
|
||||
name=request.name,
|
||||
model_path=request.model_path,
|
||||
@@ -70,13 +70,13 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
)
|
||||
async def list_model_versions(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
status: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> ModelVersionListResponse:
|
||||
"""List model versions with optional status filter."""
|
||||
models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
|
||||
model_list, total = models.get_paginated(status=status, limit=limit, offset=offset)
|
||||
return ModelVersionListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
@@ -94,7 +94,7 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
activated_at=m.activated_at,
|
||||
created_at=m.created_at,
|
||||
)
|
||||
for m in models
|
||||
for m in model_list
|
||||
],
|
||||
)
|
||||
|
||||
@@ -106,10 +106,10 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
)
|
||||
async def get_active_model(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ActiveModelResponse:
|
||||
"""Get the currently active model version."""
|
||||
model = db.get_active_model_version()
|
||||
model = models.get_active()
|
||||
if not model:
|
||||
return ActiveModelResponse(has_active_model=False, model=None)
|
||||
|
||||
@@ -137,11 +137,11 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
async def get_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionDetailResponse:
|
||||
"""Get detailed model version information."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.get_model_version(version_id)
|
||||
model = models.get(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
@@ -176,11 +176,11 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
version_id: str,
|
||||
request: ModelVersionUpdateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Update model version metadata."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.update_model_version(
|
||||
model = models.update(
|
||||
version_id=version_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
@@ -205,11 +205,11 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
version_id: str,
|
||||
request: Request,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Activate a model version for inference."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.activate_model_version(version_id)
|
||||
model = models.activate(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
@@ -242,11 +242,11 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
async def deactivate_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Deactivate a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.deactivate_model_version(version_id)
|
||||
model = models.deactivate(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
@@ -264,11 +264,11 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
async def archive_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Archive a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.archive_model_version(version_id)
|
||||
model = models.archive(version_id)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -288,11 +288,11 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
async def delete_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
models: ModelVersionRepoDep,
|
||||
) -> dict:
|
||||
"""Delete a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
success = db.delete_model_version(version_id)
|
||||
success = models.delete(version_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.core.auth import AdminTokenDep, TrainingTaskRepoDep
|
||||
from inference.web.schemas.admin import (
|
||||
TrainingLogItem,
|
||||
TrainingLogsResponse,
|
||||
@@ -40,12 +40,12 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
async def create_training_task(
|
||||
request: TrainingTaskCreate,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a new training task."""
|
||||
config_dict = request.config.model_dump() if request.config else {}
|
||||
|
||||
task_id = db.create_training_task(
|
||||
task_id = tasks.create(
|
||||
admin_token=admin_token,
|
||||
name=request.name,
|
||||
task_type=request.task_type.value,
|
||||
@@ -73,7 +73,7 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
)
|
||||
async def list_training_tasks(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
tasks_repo: TrainingTaskRepoDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status"),
|
||||
@@ -95,7 +95,7 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
tasks, total = db.get_training_tasks_by_token(
|
||||
task_list, total = tasks_repo.get_paginated(
|
||||
admin_token=admin_token,
|
||||
status=status,
|
||||
limit=limit,
|
||||
@@ -114,7 +114,7 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
completed_at=task.completed_at,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
for task in tasks
|
||||
for task in task_list
|
||||
]
|
||||
|
||||
return TrainingTaskListResponse(
|
||||
@@ -137,12 +137,12 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
async def get_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> TrainingTaskDetailResponse:
|
||||
"""Get training task details."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
task = tasks.get_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@@ -181,12 +181,12 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
async def cancel_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Cancel a training task."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
task = tasks.get_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@@ -199,7 +199,7 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
detail=f"Cannot cancel task with status: {task.status}",
|
||||
)
|
||||
|
||||
success = db.cancel_training_task(task_id)
|
||||
success = tasks.cancel(task_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -225,7 +225,7 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
async def get_training_logs(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=500, description="Maximum logs to return"),
|
||||
@@ -238,14 +238,14 @@ def register_task_routes(router: APIRouter) -> None:
|
||||
"""Get training logs."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
task = tasks.get_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
logs = db.get_training_logs(task_id, limit, offset)
|
||||
logs = tasks.get_logs(task_id, limit, offset)
|
||||
|
||||
items = [
|
||||
TrainingLogItem(
|
||||
|
||||
@@ -14,13 +14,25 @@ from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse

from inference.data.admin_db import AdminDB
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.data.repositories import BatchUploadRepository
from inference.web.core.auth import validate_admin_token
from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from inference.web.workers.batch_queue import BatchTask, get_batch_queue

logger = logging.getLogger(__name__)

# Global repository instance
_batch_repo: BatchUploadRepository | None = None


def get_batch_repository() -> BatchUploadRepository:
"""Get the BatchUploadRepository instance."""
global _batch_repo
if _batch_repo is None:
_batch_repo = BatchUploadRepository()
return _batch_repo


router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])

@@ -31,7 +43,7 @@ async def upload_batch(
|
||||
async_mode: bool = Form(default=True),
|
||||
auto_label: bool = Form(default=True),
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
|
||||
) -> dict:
|
||||
"""Upload a batch of documents via ZIP file.
|
||||
|
||||
@@ -119,7 +131,7 @@ async def upload_batch(
|
||||
)
|
||||
else:
|
||||
# Sync mode: Process immediately and return results
|
||||
service = BatchUploadService(admin_db)
|
||||
service = BatchUploadService(batch_repo)
|
||||
result = service.process_zip_upload(
|
||||
admin_token=admin_token,
|
||||
zip_filename=file.filename,
|
||||
@@ -148,14 +160,14 @@ async def upload_batch(
|
||||
async def get_batch_status(
|
||||
batch_id: str,
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
|
||||
) -> dict:
|
||||
"""Get batch upload status and file processing details.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
batch_repo: Batch upload repository
|
||||
|
||||
Returns:
|
||||
Batch status with file processing details
|
||||
@@ -167,7 +179,7 @@ async def get_batch_status(
|
||||
raise HTTPException(status_code=400, detail="Invalid batch ID format")
|
||||
|
||||
# Check batch exists and verify ownership
|
||||
batch = admin_db.get_batch_upload(batch_uuid)
|
||||
batch = batch_repo.get(batch_uuid)
|
||||
if not batch:
|
||||
raise HTTPException(status_code=404, detail="Batch not found")
|
||||
|
||||
@@ -179,7 +191,7 @@ async def get_batch_status(
|
||||
)
|
||||
|
||||
# Now safe to return details
|
||||
service = BatchUploadService(admin_db)
|
||||
service = BatchUploadService(batch_repo)
|
||||
result = service.get_batch_status(batch_id)
|
||||
|
||||
return result
|
||||
@@ -188,7 +200,7 @@ async def get_batch_status(
|
||||
@router.get("/list")
|
||||
async def list_batch_uploads(
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
batch_repo: Annotated[BatchUploadRepository, Depends(get_batch_repository)] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> dict:
|
||||
@@ -196,7 +208,7 @@ async def list_batch_uploads(
|
||||
|
||||
Args:
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
batch_repo: Batch upload repository
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
@@ -210,7 +222,7 @@ async def list_batch_uploads(
|
||||
raise HTTPException(status_code=400, detail="Offset must be non-negative")
|
||||
|
||||
# Get batch uploads filtered by admin token
|
||||
batches, total = admin_db.get_batch_uploads_by_token(
|
||||
batches, total = batch_repo.get_paginated(
|
||||
admin_token=admin_token,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DocumentRepository
|
||||
from inference.web.schemas.labeling import PreLabelResponse
|
||||
from inference.web.schemas.common import ErrorResponse
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
@@ -46,9 +46,9 @@ def _convert_pdf_to_images(
|
||||
pdf_doc.close()
|
||||
|
||||
|
||||
def get_admin_db() -> AdminDB:
|
||||
"""Get admin database instance."""
|
||||
return AdminDB()
|
||||
def get_doc_repository() -> DocumentRepository:
|
||||
"""Get document repository instance."""
|
||||
return DocumentRepository()
|
||||
|
||||
|
||||
def create_labeling_router(
|
||||
@@ -85,7 +85,7 @@ def create_labeling_router(
|
||||
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
|
||||
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
|
||||
),
|
||||
db: AdminDB = Depends(get_admin_db),
|
||||
doc_repo: DocumentRepository = Depends(get_doc_repository),
|
||||
) -> PreLabelResponse:
|
||||
"""
|
||||
Upload a document with expected field values for pre-labeling.
|
||||
@@ -149,7 +149,7 @@ def create_labeling_router(
|
||||
logger.warning(f"Failed to get PDF page count: {e}")
|
||||
|
||||
# Create document record with field_values
|
||||
document_id = db.create_document(
|
||||
document_id = doc_repo.create(
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
@@ -172,7 +172,7 @@ def create_labeling_router(
|
||||
)
|
||||
|
||||
# Update file path in database (using storage path)
|
||||
db.update_document_file_path(document_id, storage_path)
|
||||
doc_repo.update_file_path(document_id, storage_path)
|
||||
|
||||
# Convert PDF to images for annotation UI
|
||||
if file_ext == ".pdf":
|
||||
@@ -184,7 +184,7 @@ def create_labeling_router(
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
# Trigger auto-labeling
|
||||
db.update_document_status(
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="pending",
|
||||
|
||||
@@ -51,7 +51,7 @@ from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, st
from inference.web.api.v1.batch.routes import router as batch_upload_router
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
from inference.data.admin_db import AdminDB
from inference.data.repositories import ModelVersionRepository

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
@@ -75,8 +75,8 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
def get_active_model_path():
"""Resolve active model path from database."""
try:
db = AdminDB()
active_model = db.get_active_model_version()
model_repo = ModelVersionRepository()
active_model = model_repo.get_active()
if active_model and active_model.model_path:
return active_model.model_path
except Exception as e:
@@ -139,8 +139,7 @@ def create_app(config: AppConfig | None = None) -> FastAPI:

# Start batch upload queue
try:
admin_db = AdminDB()
batch_service = BatchUploadService(admin_db)
batch_service = BatchUploadService()
init_batch_queue(batch_service)
logger.info("Batch upload queue started")
except Exception as e:

@@ -4,7 +4,24 @@ Core Components
Reusable core functionality: authentication, rate limiting, scheduling.
"""

from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from inference.web.core.auth import (
validate_admin_token,
get_token_repository,
get_document_repository,
get_annotation_repository,
get_dataset_repository,
get_training_task_repository,
get_model_version_repository,
get_batch_upload_repository,
AdminTokenDep,
TokenRepoDep,
DocumentRepoDep,
AnnotationRepoDep,
DatasetRepoDep,
TrainingTaskRepoDep,
ModelVersionRepoDep,
BatchUploadRepoDep,
)
from inference.web.core.rate_limiter import RateLimiter
from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from inference.web.core.autolabel_scheduler import (
@@ -12,12 +29,25 @@ from inference.web.core.autolabel_scheduler import (
stop_autolabel_scheduler,
get_autolabel_scheduler,
)
from inference.web.core.task_interface import TaskRunner, TaskStatus, TaskManager

__all__ = [
"validate_admin_token",
"get_admin_db",
"get_token_repository",
"get_document_repository",
"get_annotation_repository",
"get_dataset_repository",
"get_training_task_repository",
"get_model_version_repository",
"get_batch_upload_repository",
"AdminTokenDep",
"AdminDBDep",
"TokenRepoDep",
"DocumentRepoDep",
"AnnotationRepoDep",
"DatasetRepoDep",
"TrainingTaskRepoDep",
"ModelVersionRepoDep",
"BatchUploadRepoDep",
"RateLimiter",
"start_scheduler",
"stop_scheduler",
@@ -25,4 +55,7 @@ __all__ = [
"start_autolabel_scheduler",
"stop_autolabel_scheduler",
"get_autolabel_scheduler",
"TaskRunner",
"TaskStatus",
"TaskManager",
]

@@ -1,40 +1,39 @@
"""
Admin Authentication

FastAPI dependencies for admin token authentication.
FastAPI dependencies for admin token authentication and repository access.
"""

import logging
from functools import lru_cache
from typing import Annotated

from fastapi import Depends, Header, HTTPException

from inference.data.admin_db import AdminDB
from inference.data.database import get_session_context

logger = logging.getLogger(__name__)

# Global AdminDB instance
_admin_db: AdminDB | None = None
from inference.data.repositories import (
TokenRepository,
DocumentRepository,
AnnotationRepository,
DatasetRepository,
TrainingTaskRepository,
ModelVersionRepository,
BatchUploadRepository,
)


def get_admin_db() -> AdminDB:
"""Get the AdminDB instance."""
global _admin_db
if _admin_db is None:
_admin_db = AdminDB()
return _admin_db
@lru_cache(maxsize=1)
def get_token_repository() -> TokenRepository:
"""Get the TokenRepository instance (thread-safe singleton)."""
return TokenRepository()


def reset_admin_db() -> None:
"""Reset the AdminDB instance (for testing)."""
global _admin_db
_admin_db = None
def reset_token_repository() -> None:
"""Reset the TokenRepository instance (for testing)."""
get_token_repository.cache_clear()

async def validate_admin_token(
|
||||
x_admin_token: Annotated[str | None, Header()] = None,
|
||||
admin_db: AdminDB = Depends(get_admin_db),
|
||||
token_repo: TokenRepository = Depends(get_token_repository),
|
||||
) -> str:
|
||||
"""Validate admin token from header."""
|
||||
if not x_admin_token:
|
||||
@@ -43,18 +42,74 @@ async def validate_admin_token(
|
||||
detail="Admin token required. Provide X-Admin-Token header.",
|
||||
)
|
||||
|
||||
if not admin_db.is_valid_admin_token(x_admin_token):
|
||||
if not token_repo.is_valid(x_admin_token):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or expired admin token.",
|
||||
)
|
||||
|
||||
# Update last used timestamp
|
||||
admin_db.update_admin_token_usage(x_admin_token)
|
||||
token_repo.update_usage(x_admin_token)
|
||||
|
||||
return x_admin_token
|
||||
|
||||
|
||||
# Type alias for dependency injection
|
||||
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
|
||||
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
|
||||
TokenRepoDep = Annotated[TokenRepository, Depends(get_token_repository)]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_document_repository() -> DocumentRepository:
|
||||
"""Get the DocumentRepository instance (thread-safe singleton)."""
|
||||
return DocumentRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_annotation_repository() -> AnnotationRepository:
|
||||
"""Get the AnnotationRepository instance (thread-safe singleton)."""
|
||||
return AnnotationRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_dataset_repository() -> DatasetRepository:
|
||||
"""Get the DatasetRepository instance (thread-safe singleton)."""
|
||||
return DatasetRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_training_task_repository() -> TrainingTaskRepository:
|
||||
"""Get the TrainingTaskRepository instance (thread-safe singleton)."""
|
||||
return TrainingTaskRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_model_version_repository() -> ModelVersionRepository:
|
||||
"""Get the ModelVersionRepository instance (thread-safe singleton)."""
|
||||
return ModelVersionRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_batch_upload_repository() -> BatchUploadRepository:
|
||||
"""Get the BatchUploadRepository instance (thread-safe singleton)."""
|
||||
return BatchUploadRepository()
|
||||
|
||||
|
||||
def reset_all_repositories() -> None:
|
||||
"""Reset all repository instances (for testing)."""
|
||||
get_token_repository.cache_clear()
|
||||
get_document_repository.cache_clear()
|
||||
get_annotation_repository.cache_clear()
|
||||
get_dataset_repository.cache_clear()
|
||||
get_training_task_repository.cache_clear()
|
||||
get_model_version_repository.cache_clear()
|
||||
get_batch_upload_repository.cache_clear()
|
||||
|
||||
|
||||
# Repository dependency type aliases
|
||||
DocumentRepoDep = Annotated[DocumentRepository, Depends(get_document_repository)]
|
||||
AnnotationRepoDep = Annotated[AnnotationRepository, Depends(get_annotation_repository)]
|
||||
DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)]
|
||||
TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)]
|
||||
ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)]
|
||||
BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)]
|
||||
|
||||
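
Illustrative aside, not part of this commit: a minimal sketch of how a route handler could consume the dependency aliases declared above. The endpoint path and handler name are hypothetical; only the imports, alias names, and DocumentRepository.get() come from this diff.

from fastapi import APIRouter
from inference.web.core.auth import AdminTokenDep, DocumentRepoDep

router = APIRouter()

@router.get("/api/v1/admin/documents/{document_id}")
async def get_document_status(
    document_id: str,
    admin_token: AdminTokenDep,   # validates the X-Admin-Token header via validate_admin_token
    documents: DocumentRepoDep,   # cached DocumentRepository singleton
) -> dict:
    # DocumentRepository.get() is used elsewhere in this diff and returns None when missing
    document = documents.get(document_id)
    return {"document_id": document_id, "found": document is not None}
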
@@ -8,7 +8,8 @@ import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DocumentRepository, AnnotationRepository
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from inference.web.services.db_autolabel import (
|
||||
get_pending_autolabel_documents,
|
||||
process_document_autolabel,
|
||||
@@ -18,7 +19,7 @@ from inference.web.services.storage_helpers import get_storage_helper
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoLabelScheduler:
|
||||
class AutoLabelScheduler(TaskRunner):
|
||||
"""Scheduler for auto-labeling tasks."""
|
||||
|
||||
def __init__(
|
||||
@@ -47,39 +48,73 @@ class AutoLabelScheduler:
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._db = AdminDB()
|
||||
self._lock = threading.Lock()
|
||||
self._doc_repo = DocumentRepository()
|
||||
self._ann_repo = AnnotationRepository()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
if self._running:
|
||||
logger.warning("AutoLabel scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("AutoLabel scheduler started")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the scheduler."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
self._thread = None
|
||||
|
||||
logger.info("AutoLabel scheduler stopped")
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Unique identifier for this runner."""
|
||||
return "autolabel_scheduler"
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if scheduler is running."""
|
||||
return self._running
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
"""Get current status of the scheduler."""
|
||||
try:
|
||||
pending_docs = get_pending_autolabel_documents(limit=1000)
|
||||
pending_count = len(pending_docs)
|
||||
except Exception:
|
||||
pending_count = 0
|
||||
|
||||
return TaskStatus(
|
||||
name=self.name,
|
||||
is_running=self._running,
|
||||
pending_count=pending_count,
|
||||
processing_count=1 if self._running else 0,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
logger.warning("AutoLabel scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("AutoLabel scheduler started")
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
"""Stop the scheduler.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for graceful shutdown.
|
||||
If None, uses default of 5 seconds.
|
||||
"""
|
||||
# Minimize lock scope to avoid potential deadlock
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
thread_to_join = self._thread
|
||||
|
||||
effective_timeout = timeout if timeout is not None else 5.0
|
||||
if thread_to_join:
|
||||
thread_to_join.join(timeout=effective_timeout)
|
||||
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
logger.info("AutoLabel scheduler stopped")
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
"""Main scheduler loop."""
|
||||
while self._running:
|
||||
@@ -94,9 +129,7 @@ class AutoLabelScheduler:
|
||||
def _process_pending_documents(self) -> None:
|
||||
"""Check and process pending auto-label documents."""
|
||||
try:
|
||||
documents = get_pending_autolabel_documents(
|
||||
self._db, limit=self._batch_size
|
||||
)
|
||||
documents = get_pending_autolabel_documents(limit=self._batch_size)
|
||||
|
||||
if not documents:
|
||||
return
|
||||
@@ -110,8 +143,9 @@ class AutoLabelScheduler:
|
||||
try:
|
||||
result = process_document_autolabel(
|
||||
document=doc,
|
||||
db=self._db,
|
||||
output_dir=self._output_dir,
|
||||
doc_repo=self._doc_repo,
|
||||
ann_repo=self._ann_repo,
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
@@ -136,13 +170,21 @@ class AutoLabelScheduler:
|
||||
|
||||
# Global scheduler instance
|
||||
_autolabel_scheduler: AutoLabelScheduler | None = None
|
||||
_autolabel_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_autolabel_scheduler() -> AutoLabelScheduler:
|
||||
"""Get the auto-label scheduler instance."""
|
||||
"""Get the auto-label scheduler instance.
|
||||
|
||||
Uses double-checked locking pattern for thread safety.
|
||||
"""
|
||||
global _autolabel_scheduler
|
||||
|
||||
if _autolabel_scheduler is None:
|
||||
_autolabel_scheduler = AutoLabelScheduler()
|
||||
with _autolabel_lock:
|
||||
if _autolabel_scheduler is None:
|
||||
_autolabel_scheduler = AutoLabelScheduler()
|
||||
|
||||
return _autolabel_scheduler
|
||||
|
||||
|
||||
|
||||
@@ -10,13 +10,20 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import (
|
||||
TrainingTaskRepository,
|
||||
DatasetRepository,
|
||||
ModelVersionRepository,
|
||||
DocumentRepository,
|
||||
AnnotationRepository,
|
||||
)
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingScheduler:
|
||||
class TrainingScheduler(TaskRunner):
|
||||
"""Scheduler for training tasks."""
|
||||
|
||||
def __init__(
|
||||
@@ -33,30 +40,73 @@ class TrainingScheduler:
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._db = AdminDB()
|
||||
self._lock = threading.Lock()
|
||||
# Repositories
|
||||
self._training_tasks = TrainingTaskRepository()
|
||||
self._datasets = DatasetRepository()
|
||||
self._model_versions = ModelVersionRepository()
|
||||
self._documents = DocumentRepository()
|
||||
self._annotations = AnnotationRepository()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Unique identifier for this runner."""
|
||||
return "training_scheduler"
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the scheduler is currently active."""
|
||||
return self._running
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
"""Get current status of the scheduler."""
|
||||
try:
|
||||
pending_tasks = self._training_tasks.get_pending()
|
||||
pending_count = len(pending_tasks)
|
||||
except Exception:
|
||||
pending_count = 0
|
||||
|
||||
return TaskStatus(
|
||||
name=self.name,
|
||||
is_running=self._running,
|
||||
pending_count=pending_count,
|
||||
processing_count=1 if self._running else 0,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
if self._running:
|
||||
logger.warning("Training scheduler already running")
|
||||
return
|
||||
with self._lock:
|
||||
if self._running:
|
||||
logger.warning("Training scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Training scheduler started")
|
||||
self._running = True
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Training scheduler started")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the scheduler."""
|
||||
if not self._running:
|
||||
return
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
"""Stop the scheduler.
|
||||
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
Args:
|
||||
timeout: Maximum time to wait for graceful shutdown.
|
||||
If None, uses default of 5 seconds.
|
||||
"""
|
||||
# Minimize lock scope to avoid potential deadlock
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
thread_to_join = self._thread
|
||||
|
||||
effective_timeout = timeout if timeout is not None else 5.0
|
||||
if thread_to_join:
|
||||
thread_to_join.join(timeout=effective_timeout)
|
||||
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
logger.info("Training scheduler stopped")
|
||||
@@ -75,7 +125,7 @@ class TrainingScheduler:
|
||||
def _check_pending_tasks(self) -> None:
|
||||
"""Check and execute pending training tasks."""
|
||||
try:
|
||||
tasks = self._db.get_pending_training_tasks()
|
||||
tasks = self._training_tasks.get_pending()
|
||||
|
||||
for task in tasks:
|
||||
task_id = str(task.task_id)
|
||||
@@ -91,7 +141,7 @@ class TrainingScheduler:
|
||||
self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._db.update_training_task_status(
|
||||
self._training_tasks.update_status(
|
||||
task_id=task_id,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
@@ -105,12 +155,12 @@ class TrainingScheduler:
|
||||
) -> None:
|
||||
"""Execute a training task."""
|
||||
# Update status to running
|
||||
self._db.update_training_task_status(task_id, "running")
|
||||
self._db.add_training_log(task_id, "INFO", "Training task started")
|
||||
self._training_tasks.update_status(task_id, "running")
|
||||
self._training_tasks.add_log(task_id, "INFO", "Training task started")
|
||||
|
||||
# Update dataset training status to running
|
||||
if dataset_id:
|
||||
self._db.update_dataset_training_status(
|
||||
self._datasets.update_training_status(
|
||||
dataset_id,
|
||||
training_status="running",
|
||||
active_training_task_id=task_id,
|
||||
@@ -137,7 +187,7 @@ class TrainingScheduler:
|
||||
if not Path(base_model_path).exists():
|
||||
raise ValueError(f"Base model not found: {base_model_path}")
|
||||
effective_model = base_model_path
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Incremental training from: {base_model_path}",
|
||||
)
|
||||
@@ -147,12 +197,12 @@ class TrainingScheduler:
|
||||
|
||||
# Use dataset if available, otherwise export from scratch
|
||||
if dataset_id:
|
||||
dataset = self._db.get_dataset(dataset_id)
|
||||
dataset = self._datasets.get(dataset_id)
|
||||
if not dataset or not dataset.dataset_path:
|
||||
raise ValueError(f"Dataset {dataset_id} not found or has no path")
|
||||
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
|
||||
dataset_path = Path(dataset.dataset_path)
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
|
||||
)
|
||||
@@ -162,7 +212,7 @@ class TrainingScheduler:
|
||||
raise ValueError("Failed to export training data")
|
||||
data_yaml = export_result["data_yaml"]
|
||||
dataset_path = Path(data_yaml).parent
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Exported {export_result['total_images']} images for training",
|
||||
)
|
||||
@@ -173,7 +223,7 @@ class TrainingScheduler:
|
||||
task_id, dataset_path, augmentation_config, augmentation_multiplier
|
||||
)
|
||||
if aug_result:
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Augmentation complete: {aug_result['augmented_images']} new images "
|
||||
f"(total: {aug_result['total_images']})",
|
||||
@@ -193,17 +243,17 @@ class TrainingScheduler:
|
||||
)
|
||||
|
||||
# Update task with results
|
||||
self._db.update_training_task_status(
|
||||
self._training_tasks.update_status(
|
||||
task_id=task_id,
|
||||
status="completed",
|
||||
result_metrics=result.get("metrics"),
|
||||
model_path=result.get("model_path"),
|
||||
)
|
||||
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
|
||||
self._training_tasks.add_log(task_id, "INFO", "Training completed successfully")
|
||||
|
||||
# Update dataset training status to completed and main status to trained
|
||||
if dataset_id:
|
||||
self._db.update_dataset_training_status(
|
||||
self._datasets.update_training_status(
|
||||
dataset_id,
|
||||
training_status="completed",
|
||||
active_training_task_id=None,
|
||||
@@ -220,10 +270,10 @@ class TrainingScheduler:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
|
||||
self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}")
|
||||
# Update dataset training status to failed
|
||||
if dataset_id:
|
||||
self._db.update_dataset_training_status(
|
||||
self._datasets.update_training_status(
|
||||
dataset_id,
|
||||
training_status="failed",
|
||||
active_training_task_id=None,
|
||||
@@ -245,11 +295,11 @@ class TrainingScheduler:
|
||||
return
|
||||
|
||||
# Get task info for name
|
||||
task = self._db.get_training_task(task_id)
|
||||
task = self._training_tasks.get(task_id)
|
||||
task_name = task.name if task else f"Task {task_id[:8]}"
|
||||
|
||||
# Generate version number based on existing versions
|
||||
existing_versions = self._db.get_model_versions(limit=1, offset=0)
|
||||
existing_versions = self._model_versions.get_paginated(limit=1, offset=0)
|
||||
version_count = existing_versions[1] if existing_versions else 0
|
||||
version = f"v{version_count + 1}.0"
|
||||
|
||||
@@ -268,12 +318,12 @@ class TrainingScheduler:
|
||||
# Get document count from dataset if available
|
||||
document_count = 0
|
||||
if dataset_id:
|
||||
dataset = self._db.get_dataset(dataset_id)
|
||||
dataset = self._datasets.get(dataset_id)
|
||||
if dataset:
|
||||
document_count = dataset.total_documents
|
||||
|
||||
# Create model version
|
||||
model_version = self._db.create_model_version(
|
||||
model_version = self._model_versions.create(
|
||||
version=version,
|
||||
name=task_name,
|
||||
model_path=str(model_path),
|
||||
@@ -294,14 +344,14 @@ class TrainingScheduler:
|
||||
f"from training task {task_id}"
|
||||
)
|
||||
mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Model version {version} created (mAP: {mAP_display})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create model version for task {task_id}: {e}")
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "WARNING",
|
||||
f"Failed to auto-create model version: {e}",
|
||||
)
|
||||
@@ -316,16 +366,16 @@ class TrainingScheduler:
|
||||
storage = get_storage_helper()
|
||||
|
||||
# Get all labeled documents
|
||||
documents = self._db.get_labeled_documents_for_export()
|
||||
documents = self._documents.get_labeled_for_export()
|
||||
|
||||
if not documents:
|
||||
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
|
||||
self._training_tasks.add_log(task_id, "ERROR", "No labeled documents available")
|
||||
return None
|
||||
|
||||
# Create export directory using StorageHelper
|
||||
training_base = storage.get_training_data_path()
|
||||
if training_base is None:
|
||||
self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access")
|
||||
self._training_tasks.add_log(task_id, "ERROR", "Storage not configured for local access")
|
||||
return None
|
||||
export_dir = training_base / task_id
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -348,7 +398,7 @@ class TrainingScheduler:
|
||||
# Export documents
|
||||
for split, docs in [("train", train_docs), ("val", val_docs)]:
|
||||
for doc in docs:
|
||||
annotations = self._db.get_annotations_for_document(str(doc.document_id))
|
||||
annotations = self._annotations.get_for_document(str(doc.document_id))
|
||||
|
||||
if not annotations:
|
||||
continue
|
||||
@@ -412,7 +462,7 @@ names: {list(FIELD_CLASSES.values())}
|
||||
|
||||
# Create log callback that writes to DB
|
||||
def log_callback(level: str, message: str) -> None:
|
||||
self._db.add_training_log(task_id, level, message)
|
||||
self._training_tasks.add_log(task_id, level, message)
|
||||
|
||||
# Create shared training config
|
||||
# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
|
||||
@@ -468,7 +518,7 @@ names: {list(FIELD_CLASSES.values())}
|
||||
try:
|
||||
from shared.augmentation import DatasetAugmenter
|
||||
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Applying augmentation with multiplier={multiplier}",
|
||||
)
|
||||
@@ -480,7 +530,7 @@ names: {list(FIELD_CLASSES.values())}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Augmentation failed for task {task_id}: {e}")
|
||||
self._db.add_training_log(
|
||||
self._training_tasks.add_log(
|
||||
task_id, "WARNING",
|
||||
f"Augmentation failed: {e}. Continuing with original dataset.",
|
||||
)
|
||||
@@ -489,13 +539,21 @@ names: {list(FIELD_CLASSES.values())}
|
||||
|
||||
# Global scheduler instance
|
||||
_scheduler: TrainingScheduler | None = None
|
||||
_scheduler_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_training_scheduler() -> TrainingScheduler:
|
||||
"""Get the training scheduler instance."""
|
||||
"""Get the training scheduler instance.
|
||||
|
||||
Uses double-checked locking pattern for thread safety.
|
||||
"""
|
||||
global _scheduler
|
||||
|
||||
if _scheduler is None:
|
||||
_scheduler = TrainingScheduler()
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None:
|
||||
_scheduler = TrainingScheduler()
|
||||
|
||||
return _scheduler
161
packages/inference/inference/web/core/task_interface.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Unified task management interface.
|
||||
|
||||
Provides abstract base class for all task runners (schedulers and queues)
|
||||
and a TaskManager facade for unified lifecycle management.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskStatus:
|
||||
"""Status of a task runner.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for the runner.
|
||||
is_running: Whether the runner is currently active.
|
||||
pending_count: Number of tasks waiting to be processed.
|
||||
processing_count: Number of tasks currently being processed.
|
||||
error: Optional error message if runner is in error state.
|
||||
"""
|
||||
|
||||
name: str
|
||||
is_running: bool
|
||||
pending_count: int
|
||||
processing_count: int
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class TaskRunner(ABC):
|
||||
"""Abstract base class for all task runners.
|
||||
|
||||
All schedulers and task queues should implement this interface
|
||||
to enable unified lifecycle management and monitoring.
|
||||
|
||||
Note:
|
||||
Implementations may have different `start()` signatures based on
|
||||
their initialization needs (e.g., handler functions, services).
|
||||
Use the implementation-specific start methods for initialization,
|
||||
and use TaskManager for unified status monitoring.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Unique identifier for this runner."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self, *args, **kwargs) -> None:
|
||||
"""Start the task runner.
|
||||
|
||||
Should be idempotent - calling start on an already running
|
||||
runner should have no effect.
|
||||
|
||||
Note:
|
||||
Implementations may require additional parameters (handlers,
|
||||
services). See implementation-specific documentation.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
"""Stop the task runner gracefully.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for graceful shutdown in seconds.
|
||||
If None, use implementation default.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the runner is currently active."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(self) -> TaskStatus:
|
||||
"""Get current status of the runner.
|
||||
|
||||
Returns:
|
||||
TaskStatus with current state information.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""Unified manager for all task runners.
|
||||
|
||||
Provides centralized lifecycle management and monitoring
|
||||
for all registered task runners.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the task manager."""
|
||||
self._runners: dict[str, TaskRunner] = {}
|
||||
|
||||
def register(self, runner: TaskRunner) -> None:
|
||||
"""Register a task runner.
|
||||
|
||||
Args:
|
||||
runner: TaskRunner instance to register.
|
||||
"""
|
||||
self._runners[runner.name] = runner
|
||||
|
||||
def get_runner(self, name: str) -> TaskRunner | None:
|
||||
"""Get a specific runner by name.
|
||||
|
||||
Args:
|
||||
name: Name of the runner to retrieve.
|
||||
|
||||
Returns:
|
||||
TaskRunner if found, None otherwise.
|
||||
"""
|
||||
return self._runners.get(name)
|
||||
|
||||
@property
|
||||
def runner_names(self) -> list[str]:
|
||||
"""Get names of all registered runners.
|
||||
|
||||
Returns:
|
||||
List of runner names.
|
||||
"""
|
||||
return list(self._runners.keys())
|
||||
|
||||
def start_all(self) -> None:
|
||||
"""Start all registered runners that support no-argument start.
|
||||
|
||||
Note:
|
||||
Runners requiring initialization parameters (like AsyncTaskQueue
|
||||
or BatchTaskQueue) should be started individually before
|
||||
registering with TaskManager.
|
||||
"""
|
||||
for runner in self._runners.values():
|
||||
try:
|
||||
runner.start()
|
||||
except TypeError:
|
||||
# Runner requires arguments - skip (should be started individually)
|
||||
pass
|
||||
|
||||
def stop_all(self, timeout: float = 30.0) -> None:
|
||||
"""Stop all registered runners gracefully.
|
||||
|
||||
Args:
|
||||
timeout: Total timeout to distribute across all runners.
|
||||
"""
|
||||
if not self._runners:
|
||||
return
|
||||
|
||||
per_runner_timeout = timeout / len(self._runners)
|
||||
for runner in self._runners.values():
|
||||
runner.stop(timeout=per_runner_timeout)
|
||||
|
||||
def get_all_status(self) -> dict[str, TaskStatus]:
|
||||
"""Get status of all registered runners.
|
||||
|
||||
Returns:
|
||||
Dict mapping runner names to their status.
|
||||
"""
|
||||
return {name: runner.get_status() for name, runner in self._runners.items()}
|
||||
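
Illustrative aside, not part of this commit: a minimal sketch of how the TaskManager facade above could be wired to the two schedulers refactored in this commit; the variable names are hypothetical, the import paths come from this diff.

from inference.web.core.task_interface import TaskManager
from inference.web.core.scheduler import get_training_scheduler
from inference.web.core.autolabel_scheduler import get_autolabel_scheduler

manager = TaskManager()
manager.register(get_training_scheduler())   # registered under name "training_scheduler"
manager.register(get_autolabel_scheduler())  # registered under name "autolabel_scheduler"
manager.start_all()                          # both schedulers support no-argument start()
statuses = manager.get_all_status()          # dict[str, TaskStatus]
manager.stop_all(timeout=10.0)               # split evenly: 5 seconds per registered runner
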
@@ -11,7 +11,7 @@ import numpy as np
|
||||
from fastapi import HTTPException
|
||||
from PIL import Image
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DocumentRepository, DatasetRepository
|
||||
from inference.web.schemas.admin.augmentation import (
|
||||
AugmentationBatchResponse,
|
||||
AugmentationConfigSchema,
|
||||
@@ -32,9 +32,14 @@ UUID_PATTERN = re.compile(
|
||||
class AugmentationService:
|
||||
"""Service for augmentation operations."""
|
||||
|
||||
def __init__(self, db: AdminDB) -> None:
|
||||
"""Initialize service with database connection."""
|
||||
self.db = db
|
||||
def __init__(
|
||||
self,
|
||||
doc_repo: DocumentRepository | None = None,
|
||||
dataset_repo: DatasetRepository | None = None,
|
||||
) -> None:
|
||||
"""Initialize service with repository connections."""
|
||||
self.doc_repo = doc_repo or DocumentRepository()
|
||||
self.dataset_repo = dataset_repo or DatasetRepository()
|
||||
|
||||
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
|
||||
"""
|
||||
@@ -179,7 +184,7 @@ class AugmentationService:
|
||||
"""
|
||||
# Validate source dataset exists
|
||||
try:
|
||||
source_dataset = self.db.get_dataset(source_dataset_id)
|
||||
source_dataset = self.dataset_repo.get(source_dataset_id)
|
||||
if source_dataset is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@@ -259,7 +264,7 @@ class AugmentationService:
|
||||
|
||||
# Get document from database
|
||||
try:
|
||||
document = self.db.get_document(document_id)
|
||||
document = self.doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
||||
@@ -12,7 +12,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DocumentRepository, AnnotationRepository
|
||||
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
|
||||
from shared.matcher.field_matcher import FieldMatcher
|
||||
from shared.ocr.paddle_ocr import OCREngine, OCRToken
|
||||
@@ -45,7 +45,8 @@ class AutoLabelService:
|
||||
document_id: str,
|
||||
file_path: str,
|
||||
field_values: dict[str, str],
|
||||
db: AdminDB,
|
||||
doc_repo: DocumentRepository | None = None,
|
||||
ann_repo: AnnotationRepository | None = None,
|
||||
replace_existing: bool = False,
|
||||
skip_lock_check: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
@@ -56,16 +57,23 @@ class AutoLabelService:
|
||||
document_id: Document UUID
|
||||
file_path: Path to document file
|
||||
field_values: Dict of field_name -> value to match
|
||||
db: Admin database instance
|
||||
doc_repo: Document repository (created if None)
|
||||
ann_repo: Annotation repository (created if None)
|
||||
replace_existing: Whether to replace existing auto annotations
|
||||
skip_lock_check: Skip annotation lock check (for batch processing)
|
||||
|
||||
Returns:
|
||||
Dict with status and annotation count
|
||||
"""
|
||||
# Initialize repositories if not provided
|
||||
if doc_repo is None:
|
||||
doc_repo = DocumentRepository()
|
||||
if ann_repo is None:
|
||||
ann_repo = AnnotationRepository()
|
||||
|
||||
try:
|
||||
# Get document info first
|
||||
document = db.get_document(document_id)
|
||||
document = doc_repo.get(document_id)
|
||||
if document is None:
|
||||
raise ValueError(f"Document not found: {document_id}")
|
||||
|
||||
@@ -80,7 +88,7 @@ class AutoLabelService:
|
||||
)
|
||||
|
||||
# Update status to running
|
||||
db.update_document_status(
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="running",
|
||||
@@ -88,7 +96,7 @@ class AutoLabelService:
|
||||
|
||||
# Delete existing auto annotations if requested
|
||||
if replace_existing:
|
||||
deleted = db.delete_annotations_for_document(
|
||||
deleted = ann_repo.delete_for_document(
|
||||
document_id=document_id,
|
||||
source="auto",
|
||||
)
|
||||
@@ -101,17 +109,17 @@ class AutoLabelService:
|
||||
if path.suffix.lower() == ".pdf":
|
||||
# Process PDF (all pages)
|
||||
annotations_created = self._process_pdf(
|
||||
document_id, path, field_values, db
|
||||
document_id, path, field_values, ann_repo
|
||||
)
|
||||
else:
|
||||
# Process single image
|
||||
annotations_created = self._process_image(
|
||||
document_id, path, field_values, db, page_number=1
|
||||
document_id, path, field_values, ann_repo, page_number=1
|
||||
)
|
||||
|
||||
# Update document status
|
||||
status = "labeled" if annotations_created > 0 else "pending"
|
||||
db.update_document_status(
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status=status,
|
||||
auto_label_status="completed",
|
||||
@@ -124,7 +132,7 @@ class AutoLabelService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-labeling failed for {document_id}: {e}")
|
||||
db.update_document_status(
|
||||
doc_repo.update_status(
|
||||
document_id=document_id,
|
||||
status="pending",
|
||||
auto_label_status="failed",
|
||||
@@ -141,7 +149,7 @@ class AutoLabelService:
|
||||
document_id: str,
|
||||
pdf_path: Path,
|
||||
field_values: dict[str, str],
|
||||
db: AdminDB,
|
||||
ann_repo: AnnotationRepository,
|
||||
) -> int:
|
||||
"""Process PDF document and create annotations."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
@@ -172,7 +180,7 @@ class AutoLabelService:
|
||||
|
||||
# Save annotations
|
||||
if annotations:
|
||||
db.create_annotations_batch(annotations)
|
||||
ann_repo.create_batch(annotations)
|
||||
total_annotations += len(annotations)
|
||||
|
||||
return total_annotations
|
||||
@@ -182,7 +190,7 @@ class AutoLabelService:
|
||||
document_id: str,
|
||||
image_path: Path,
|
||||
field_values: dict[str, str],
|
||||
db: AdminDB,
|
||||
ann_repo: AnnotationRepository,
|
||||
page_number: int = 1,
|
||||
) -> int:
|
||||
"""Process single image and create annotations."""
|
||||
@@ -208,7 +216,7 @@ class AutoLabelService:
|
||||
|
||||
# Save annotations
|
||||
if annotations:
|
||||
db.create_annotations_batch(annotations)
|
||||
ann_repo.create_batch(annotations)
|
||||
|
||||
return len(annotations)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import BatchUploadRepository
|
||||
from shared.fields import CSV_TO_CLASS_MAPPING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,13 +64,13 @@ class CSVRowData(BaseModel):
|
||||
class BatchUploadService:
|
||||
"""Service for handling batch uploads of documents via ZIP files."""
|
||||
|
||||
def __init__(self, admin_db: AdminDB):
|
||||
def __init__(self, batch_repo: BatchUploadRepository | None = None):
|
||||
"""Initialize the batch upload service.
|
||||
|
||||
Args:
|
||||
admin_db: Admin database interface
|
||||
batch_repo: Batch upload repository (created if None)
|
||||
"""
|
||||
self.admin_db = admin_db
|
||||
self.batch_repo = batch_repo or BatchUploadRepository()
|
||||
|
||||
def _safe_extract_filename(self, zip_path: str) -> str:
|
||||
"""Safely extract filename from ZIP path, preventing path traversal.
|
||||
@@ -170,7 +170,7 @@ class BatchUploadService:
|
||||
Returns:
|
||||
Dictionary with batch upload results
|
||||
"""
|
||||
batch = self.admin_db.create_batch_upload(
|
||||
batch = self.batch_repo.create(
|
||||
admin_token=admin_token,
|
||||
filename=zip_filename,
|
||||
file_size=len(zip_content),
|
||||
@@ -189,7 +189,7 @@ class BatchUploadService:
|
||||
)
|
||||
|
||||
# Update batch upload status
|
||||
self.admin_db.update_batch_upload(
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status=result["status"],
|
||||
total_files=result["total_files"],
|
||||
@@ -208,7 +208,7 @@ class BatchUploadService:
|
||||
|
||||
except zipfile.BadZipFile as e:
|
||||
logger.error(f"Invalid ZIP file {zip_filename}: {e}")
|
||||
self.admin_db.update_batch_upload(
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="Invalid ZIP file format",
|
||||
@@ -222,7 +222,7 @@ class BatchUploadService:
|
||||
except ValueError as e:
|
||||
# Security validation errors
|
||||
logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
|
||||
self.admin_db.update_batch_upload(
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="ZIP file validation failed",
|
||||
@@ -235,7 +235,7 @@ class BatchUploadService:
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
|
||||
self.admin_db.update_batch_upload(
|
||||
self.batch_repo.update(
|
||||
batch_id=batch.batch_id,
|
||||
status="failed",
|
||||
error_message="Processing error",
|
||||
@@ -312,7 +312,7 @@ class BatchUploadService:
|
||||
filename = self._safe_extract_filename(pdf_info.filename)
|
||||
|
||||
# Create batch upload file record
|
||||
file_record = self.admin_db.create_batch_upload_file(
|
||||
file_record = self.batch_repo.create_file(
|
||||
batch_id=batch_id,
|
||||
filename=filename,
|
||||
status="processing",
|
||||
@@ -328,7 +328,7 @@ class BatchUploadService:
|
||||
# TODO: Save PDF file and create document
|
||||
# For now, just mark as completed
|
||||
|
||||
self.admin_db.update_batch_upload_file(
|
||||
self.batch_repo.update_file(
|
||||
file_id=file_record.file_id,
|
||||
status="completed",
|
||||
csv_row_data=csv_row_data,
|
||||
@@ -341,7 +341,7 @@ class BatchUploadService:
|
||||
# Path validation error
|
||||
logger.warning(f"Skipping invalid file: {e}")
|
||||
if file_record:
|
||||
self.admin_db.update_batch_upload_file(
|
||||
self.batch_repo.update_file(
|
||||
file_id=file_record.file_id,
|
||||
status="failed",
|
||||
error_message="Invalid filename",
|
||||
@@ -352,7 +352,7 @@ class BatchUploadService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing PDF: {e}", exc_info=True)
|
||||
if file_record:
|
||||
self.admin_db.update_batch_upload_file(
|
||||
self.batch_repo.update_file(
|
||||
file_id=file_record.file_id,
|
||||
status="failed",
|
||||
error_message="Processing error",
|
||||
@@ -515,13 +515,13 @@ class BatchUploadService:
|
||||
Returns:
|
||||
Batch status dictionary
|
||||
"""
|
||||
batch = self.admin_db.get_batch_upload(UUID(batch_id))
|
||||
batch = self.batch_repo.get(UUID(batch_id))
|
||||
if not batch:
|
||||
return {
|
||||
"error": "Batch upload not found",
|
||||
}
|
||||
|
||||
files = self.admin_db.get_batch_upload_files(batch.batch_id)
|
||||
files = self.batch_repo.get_files(batch.batch_id)
|
||||
|
||||
return {
|
||||
"batch_id": str(batch.batch_id),
|
||||
|
||||
@@ -20,8 +20,16 @@ logger = logging.getLogger(__name__)
|
||||
class DatasetBuilder:
|
||||
"""Builds YOLO training datasets from admin documents."""
|
||||
|
||||
def __init__(self, db, base_dir: Path):
|
||||
self._db = db
|
||||
def __init__(
|
||||
self,
|
||||
datasets_repo,
|
||||
documents_repo,
|
||||
annotations_repo,
|
||||
base_dir: Path,
|
||||
):
|
||||
self._datasets_repo = datasets_repo
|
||||
self._documents_repo = documents_repo
|
||||
self._annotations_repo = annotations_repo
|
||||
self._base_dir = Path(base_dir)
|
||||
|
||||
def build_dataset(
|
||||
@@ -54,7 +62,7 @@ class DatasetBuilder:
|
||||
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
|
||||
)
|
||||
except Exception as e:
|
||||
self._db.update_dataset_status(
|
||||
self._datasets_repo.update_status(
|
||||
dataset_id=dataset_id,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
@@ -71,7 +79,7 @@ class DatasetBuilder:
|
||||
admin_images_dir: Path,
|
||||
) -> dict:
|
||||
# 1. Fetch documents
|
||||
documents = self._db.get_documents_by_ids(document_ids)
|
||||
documents = self._documents_repo.get_by_ids(document_ids)
|
||||
if not documents:
|
||||
raise ValueError("No valid documents found for the given IDs")
|
||||
|
||||
@@ -93,7 +101,7 @@ class DatasetBuilder:
|
||||
for doc in doc_list:
|
||||
doc_id = str(doc.document_id)
|
||||
split = doc_splits[doc_id]
|
||||
annotations = self._db.get_annotations_for_document(doc.document_id)
|
||||
annotations = self._annotations_repo.get_for_document(str(doc.document_id))
|
||||
|
||||
# Group annotations by page
|
||||
page_annotations: dict[int, list] = {}
|
||||
@@ -139,7 +147,7 @@ class DatasetBuilder:
|
||||
})
|
||||
|
||||
# 5. Record document-split assignments in DB
|
||||
self._db.add_dataset_documents(
|
||||
self._datasets_repo.add_documents(
|
||||
dataset_id=dataset_id,
|
||||
documents=dataset_docs,
|
||||
)
|
||||
@@ -148,7 +156,7 @@ class DatasetBuilder:
|
||||
self._generate_data_yaml(dataset_dir)
|
||||
|
||||
# 7. Update dataset status
|
||||
self._db.update_dataset_status(
|
||||
self._datasets_repo.update_status(
|
||||
dataset_id=dataset_id,
|
||||
status="ready",
|
||||
total_documents=len(doc_list),
|
||||
|
||||
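
Illustrative aside, not part of this commit: a minimal sketch of constructing the refactored DatasetBuilder with explicit repositories. The constructor parameters come from this diff; the DatasetBuilder import path and the base directory are assumptions.

from pathlib import Path

from inference.data.repositories import (
    DatasetRepository,
    DocumentRepository,
    AnnotationRepository,
)
# Module path assumed; this diff does not show which module defines DatasetBuilder.
from inference.web.services.dataset_builder import DatasetBuilder

builder = DatasetBuilder(
    datasets_repo=DatasetRepository(),
    documents_repo=DocumentRepository(),
    annotations_repo=AnnotationRepository(),
    base_dir=Path("/data/datasets"),  # hypothetical location
)
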
@@ -12,9 +12,9 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from inference.data.admin_db import AdminDB
|
||||
from shared.fields import CSV_TO_CLASS_MAPPING
|
||||
from inference.data.admin_models import AdminDocument
|
||||
from inference.data.repositories import DocumentRepository, AnnotationRepository
|
||||
from shared.data.db import DocumentDB
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
@@ -68,14 +68,12 @@ def convert_csv_field_values_to_row_dict(
|
||||
|
||||
|
||||
def get_pending_autolabel_documents(
|
||||
db: AdminDB,
|
||||
limit: int = 10,
|
||||
) -> list[AdminDocument]:
|
||||
"""
|
||||
Get documents pending auto-labeling.
|
||||
|
||||
Args:
|
||||
db: AdminDB instance
|
||||
limit: Maximum number of documents to return
|
||||
|
||||
Returns:
|
||||
@@ -99,20 +97,22 @@ def get_pending_autolabel_documents(
|
||||
|
||||
def process_document_autolabel(
|
||||
document: AdminDocument,
|
||||
db: AdminDB,
|
||||
output_dir: Path | None = None,
|
||||
dpi: int = DEFAULT_DPI,
|
||||
min_confidence: float = 0.5,
|
||||
doc_repo: DocumentRepository | None = None,
|
||||
ann_repo: AnnotationRepository | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Process a single document for auto-labeling using csv_field_values.
|
||||
|
||||
Args:
|
||||
document: AdminDocument with csv_field_values and file_path
|
||||
db: AdminDB instance for updating status
|
||||
output_dir: Output directory for temp files
|
||||
dpi: Rendering DPI
|
||||
min_confidence: Minimum match confidence
|
||||
doc_repo: Document repository (created if None)
|
||||
ann_repo: Annotation repository (created if None)
|
||||
|
||||
Returns:
|
||||
Result dictionary with success status and annotations
|
||||
@@ -120,6 +120,12 @@ def process_document_autolabel(
    from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
    from shared.pdf import PDFDocument

+    # Initialize repositories if not provided
+    if doc_repo is None:
+        doc_repo = DocumentRepository()
+    if ann_repo is None:
+        ann_repo = AnnotationRepository()
+
    document_id = str(document.document_id)
    file_path = Path(document.file_path)

@@ -132,7 +138,7 @@ def process_document_autolabel(
    output_dir.mkdir(parents=True, exist_ok=True)

    # Mark as processing
-    db.update_document_status(
+    doc_repo.update_status(
        document_id=document_id,
        status="auto_labeling",
        auto_label_status="running",
@@ -187,10 +193,10 @@ def process_document_autolabel(
    except Exception as e:
        logger.warning(f"Failed to save report to DocumentDB: {e}")

-    # Save annotations to AdminDB
+    # Save annotations to database
    if result.get("success") and result.get("report"):
        _save_annotations_to_db(
-            db=db,
+            ann_repo=ann_repo,
            document_id=document_id,
            report=result["report"],
            page_annotations=result.get("pages", []),
@@ -198,7 +204,7 @@ def process_document_autolabel(
        )

        # Mark as completed
-        db.update_document_status(
+        doc_repo.update_status(
            document_id=document_id,
            status="labeled",
            auto_label_status="completed",
@@ -206,7 +212,7 @@ def process_document_autolabel(
    else:
        # Mark as failed
        errors = result.get("report", {}).get("errors", ["Unknown error"])
-        db.update_document_status(
+        doc_repo.update_status(
            document_id=document_id,
            status="pending",
            auto_label_status="failed",
@@ -219,7 +225,7 @@ def process_document_autolabel(
        logger.error(f"Error processing document {document_id}: {e}", exc_info=True)

        # Mark as failed
-        db.update_document_status(
+        doc_repo.update_status(
            document_id=document_id,
            status="pending",
            auto_label_status="failed",
@@ -234,7 +240,7 @@ def process_document_autolabel(


def _save_annotations_to_db(
-    db: AdminDB,
+    ann_repo: AnnotationRepository,
    document_id: str,
    report: dict[str, Any],
    page_annotations: list[dict[str, Any]],
@@ -244,7 +250,7 @@ def _save_annotations_to_db(
    Save generated annotations to database.

    Args:
-        db: AdminDB instance
+        ann_repo: Annotation repository instance
        document_id: Document ID
        report: AutoLabelReport as dict
        page_annotations: List of page annotation data
@@ -353,7 +359,7 @@ def _save_annotations_to_db(

        # Create annotation
        try:
-            db.create_annotation(
+            ann_repo.create(
                document_id=document_id,
                page_number=page_no,
                class_id=class_id,
@@ -379,25 +385,29 @@ def _save_annotations_to_db(


def run_pending_autolabel_batch(
-    db: AdminDB | None = None,
    batch_size: int = 10,
    output_dir: Path | None = None,
+    doc_repo: DocumentRepository | None = None,
+    ann_repo: AnnotationRepository | None = None,
) -> dict[str, Any]:
    """
    Process a batch of pending auto-label documents.

    Args:
-        db: AdminDB instance (created if None)
        batch_size: Number of documents to process
        output_dir: Output directory for temp files
+        doc_repo: Document repository (created if None)
+        ann_repo: Annotation repository (created if None)

    Returns:
        Summary of processing results
    """
-    if db is None:
-        db = AdminDB()
+    if doc_repo is None:
+        doc_repo = DocumentRepository()
+    if ann_repo is None:
+        ann_repo = AnnotationRepository()

-    documents = get_pending_autolabel_documents(db, limit=batch_size)
+    documents = get_pending_autolabel_documents(limit=batch_size)

    results = {
        "total": len(documents),
@@ -409,8 +419,9 @@ def run_pending_autolabel_batch(
    for doc in documents:
        result = process_document_autolabel(
            document=doc,
-            db=db,
            output_dir=output_dir,
+            doc_repo=doc_repo,
+            ann_repo=ann_repo,
        )

        doc_result = {
@@ -432,7 +443,6 @@ def run_pending_autolabel_batch(
def save_manual_annotations_to_document_db(
    document: AdminDocument,
    annotations: list,
-    db: AdminDB,
) -> dict[str, Any]:
    """
    Save manual annotations to PostgreSQL documents and field_results tables.
@@ -444,7 +454,6 @@ def save_manual_annotations_to_document_db(
    Args:
        document: AdminDocument instance
        annotations: List of AdminAnnotation instances
-        db: AdminDB instance

    Returns:
        Dict with success status and details
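For completeness, a sketch of the batch entry point after this refactor; omitting the repository arguments lets the function construct its own, which mirrors the defaults shown above.

# Assumes run_pending_autolabel_batch and the repositories are in scope (same module as above).
summary = run_pending_autolabel_batch(batch_size=5)
print(f"processed {summary['total']} documents")

# Explicit injection, e.g. to share repositories across calls or stub them in tests:
summary = run_pending_autolabel_batch(
    batch_size=5,
    doc_repo=DocumentRepository(),
    ann_repo=AnnotationRepository(),
)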
@@ -14,6 +14,8 @@ import threading
from threading import Event, Lock, Thread
from typing import Callable

+from inference.web.core.task_interface import TaskRunner, TaskStatus
+
logger = logging.getLogger(__name__)


@@ -29,7 +31,7 @@ class AsyncTask:
    priority: int = 0  # Lower = higher priority (not implemented yet)


-class AsyncTaskQueue:
+class AsyncTaskQueue(TaskRunner):
    """Thread-safe queue for async invoice processing."""

    def __init__(
@@ -46,44 +48,78 @@ class AsyncTaskQueue:
        self._task_handler: Callable[[AsyncTask], None] | None = None
        self._started = False

+    @property
+    def name(self) -> str:
+        """Unique identifier for this runner."""
+        return "async_task_queue"
+
+    @property
+    def is_running(self) -> bool:
+        """Check if the queue is running."""
+        return self._started and not self._stop_event.is_set()
+
+    def get_status(self) -> TaskStatus:
+        """Get current status of the queue."""
+        with self._lock:
+            processing_count = len(self._processing)
+
+        return TaskStatus(
+            name=self.name,
+            is_running=self.is_running,
+            pending_count=self._queue.qsize(),
+            processing_count=processing_count,
+        )
+
    def start(self, task_handler: Callable[[AsyncTask], None]) -> None:
        """Start background worker threads."""
-        if self._started:
-            logger.warning("AsyncTaskQueue already started")
-            return
+        with self._lock:
+            if self._started:
+                logger.warning("AsyncTaskQueue already started")
+                return

-        self._task_handler = task_handler
-        self._stop_event.clear()
+            self._task_handler = task_handler
+            self._stop_event.clear()

-        for i in range(self._worker_count):
-            worker = Thread(
-                target=self._worker_loop,
-                name=f"async-worker-{i}",
-                daemon=True,
-            )
-            worker.start()
-            self._workers.append(worker)
-            logger.info(f"Started async worker thread: {worker.name}")
+            for i in range(self._worker_count):
+                worker = Thread(
+                    target=self._worker_loop,
+                    name=f"async-worker-{i}",
+                    daemon=True,
+                )
+                worker.start()
+                self._workers.append(worker)
+                logger.info(f"Started async worker thread: {worker.name}")

-        self._started = True
-        logger.info(f"AsyncTaskQueue started with {self._worker_count} workers")
+            self._started = True
+            logger.info(f"AsyncTaskQueue started with {self._worker_count} workers")

-    def stop(self, timeout: float = 30.0) -> None:
-        """Gracefully stop all workers."""
-        if not self._started:
-            return
-
-        logger.info("Stopping AsyncTaskQueue...")
-        self._stop_event.set()
-
-        # Wait for workers to finish
-        for worker in self._workers:
-            worker.join(timeout=timeout / self._worker_count)
+    def stop(self, timeout: float | None = None) -> None:
+        """Gracefully stop all workers.
+
+        Args:
+            timeout: Maximum time to wait for graceful shutdown.
+                If None, uses default of 30 seconds.
+        """
+        # Minimize lock scope to avoid potential deadlock
+        with self._lock:
+            if not self._started:
+                return
+
+            logger.info("Stopping AsyncTaskQueue...")
+            self._stop_event.set()
+            workers_to_join = list(self._workers)
+
+        effective_timeout = timeout if timeout is not None else 30.0
+
+        # Wait for workers to finish outside the lock
+        for worker in workers_to_join:
+            worker.join(timeout=effective_timeout / self._worker_count)
            if worker.is_alive():
                logger.warning(f"Worker {worker.name} did not stop gracefully")

-        self._workers.clear()
-        self._started = False
+        with self._lock:
+            self._workers.clear()
+            self._started = False
        logger.info("AsyncTaskQueue stopped")

    def submit(self, task: AsyncTask) -> bool:
@@ -115,11 +151,6 @@ class AsyncTaskQueue:
        with self._lock:
            return request_id in self._processing

-    @property
-    def is_running(self) -> bool:
-        """Check if the queue is running."""
-        return self._started and not self._stop_event.is_set()
-
    def _worker_loop(self) -> None:
        """Worker loop that processes tasks from queue."""
        thread_name = threading.current_thread().name
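Since AsyncTaskQueue now implements TaskRunner, callers can report on it generically. A sketch under the assumption that TaskStatus exposes the keyword arguments above as attributes and that the queue needs no required constructor arguments:

from inference.web.core.task_interface import TaskRunner, TaskStatus

def describe(runner: TaskRunner) -> dict:
    status: TaskStatus = runner.get_status()
    return {
        "name": status.name,
        "running": status.is_running,
        "pending": status.pending_count,
        "processing": status.processing_count,
    }

def handle_task(task) -> None:  # hypothetical handler: Callable[[AsyncTask], None]
    ...

queue = AsyncTaskQueue()   # constructor arguments not shown in this diff
queue.start(handle_task)
print(describe(queue))
queue.stop()               # falls back to the 30-second default when timeout is None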
@@ -12,6 +12,8 @@ from queue import Queue, Full, Empty
from typing import Any
from uuid import UUID

+from inference.web.core.task_interface import TaskRunner, TaskStatus
+
logger = logging.getLogger(__name__)


@@ -28,7 +30,7 @@ class BatchTask:
    created_at: datetime


-class BatchTaskQueue:
+class BatchTaskQueue(TaskRunner):
    """Thread-safe queue for async batch upload processing."""

    def __init__(self, max_size: int = 20, worker_count: int = 2):
@@ -45,6 +47,29 @@ class BatchTaskQueue:
        self._batch_service: Any | None = None
        self._running = False
        self._lock = threading.Lock()
+        self._processing: set[UUID] = set()  # Currently processing batch_ids
+
+    @property
+    def name(self) -> str:
+        """Unique identifier for this runner."""
+        return "batch_task_queue"
+
+    @property
+    def is_running(self) -> bool:
+        """Check if queue is running."""
+        return self._running
+
+    def get_status(self) -> TaskStatus:
+        """Get current status of the queue."""
+        with self._lock:
+            processing_count = len(self._processing)
+
+        return TaskStatus(
+            name=self.name,
+            is_running=self._running,
+            pending_count=self._queue.qsize(),
+            processing_count=processing_count,
+        )

    def start(self, batch_service: Any) -> None:
        """Start worker threads with batch service.
@@ -73,12 +98,14 @@ class BatchTaskQueue:

        logger.info(f"Started {self._worker_count} batch workers")

-    def stop(self, timeout: float = 30.0) -> None:
+    def stop(self, timeout: float | None = None) -> None:
        """Stop all worker threads gracefully.

        Args:
-            timeout: Maximum time to wait for workers to finish
+            timeout: Maximum time to wait for workers to finish.
+                If None, uses default of 30 seconds.
        """
+        # Minimize lock scope to avoid potential deadlock
        with self._lock:
            if not self._running:
                return
@@ -86,13 +113,17 @@ class BatchTaskQueue:
            logger.info("Stopping batch queue...")
            self._stop_event.set()
            self._running = False
+            workers_to_join = list(self._workers)

-            # Wait for workers to finish
-            for worker in self._workers:
-                worker.join(timeout=timeout)
+        effective_timeout = timeout if timeout is not None else 30.0
+
+        # Wait for workers to finish outside the lock
+        for worker in workers_to_join:
+            worker.join(timeout=effective_timeout)

+        with self._lock:
            self._workers.clear()
-            logger.info("Batch queue stopped")
+        logger.info("Batch queue stopped")

    def submit(self, task: BatchTask) -> bool:
        """Submit a batch task to the queue.
@@ -119,15 +150,6 @@ class BatchTaskQueue:
        """
        return self._queue.qsize()

-    @property
-    def is_running(self) -> bool:
-        """Check if queue is running.
-
-        Returns:
-            True if queue is active
-        """
-        return self._running
-
    def _worker_loop(self) -> None:
        """Worker thread main loop."""
        worker_name = threading.current_thread().name
@@ -157,6 +179,9 @@ class BatchTaskQueue:
            logger.error("Batch service not initialized, cannot process task")
            return

+        with self._lock:
+            self._processing.add(task.batch_id)
+
        logger.info(
            f"Processing batch task: batch_id={task.batch_id}, "
            f"filename={task.zip_filename}"
@@ -183,6 +208,9 @@ class BatchTaskQueue:
                f"Error processing batch task {task.batch_id}: {e}",
                exc_info=True,
            )
+        finally:
+            with self._lock:
+                self._processing.discard(task.batch_id)


# Global batch queue instance
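A closing sketch of the uniform shutdown path this interface enables: both stop() methods now accept timeout=None, fall back to 30 seconds, and join workers outside the lock to avoid the deadlock noted in the comments above. The runner instances are assumed to be registered elsewhere in the app.

def shutdown_runners(runners: list) -> None:
    # runners: e.g. [async_task_queue, batch_task_queue] collected at app startup
    for runner in runners:
        if runner.is_running:
            runner.stop(timeout=10.0)  # or runner.stop() for the 30-second default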