Files
invoice-master-poc-v2/tests/web/test_training_phase4.py
Yaojia Wang a564ac9d70 WIP
2026-02-01 18:51:54 +01:00

482 lines
18 KiB
Python

"""
Tests for Phase 4: Training Data Management
"""
import pytest
from datetime import datetime
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
get_annotation_repository,
get_training_task_repository,
get_model_version_repository,
)
class MockTrainingTask:
    """Lightweight stand-in for a TrainingTask record.

    Any attribute may be supplied as a keyword argument; attributes that are
    not supplied receive the fallback listed in ``__init__``. Keyword
    arguments that do not match a known attribute are ignored.
    """

    def __init__(self, **kwargs):
        # The table is rebuilt on every call so each instance gets its own
        # UUID, timestamps and mutable containers.
        fallbacks = {
            'task_id': uuid4(),
            'admin_token': 'test-token',
            'name': 'Test Training',
            'description': None,
            'status': 'completed',
            'task_type': 'train',
            'config': {},
            'scheduled_at': None,
            'cron_expression': None,
            'is_recurring': False,
            'started_at': datetime.utcnow(),
            'completed_at': datetime.utcnow(),
            'error_message': None,
            'result_metrics': {},
            'model_path': 'runs/train/test/weights/best.pt',
            'document_count': 0,
            'metrics_mAP': 0.935,
            'metrics_precision': 0.92,
            'metrics_recall': 0.88,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in fallbacks.items():
            setattr(self, attr, kwargs.get(attr, fallback))
class MockTrainingDocumentLink:
    """Lightweight stand-in for a TrainingDocumentLink record.

    ``task_id`` and ``document_id`` default to ``None`` when not supplied;
    unknown keyword arguments are ignored.
    """

    def __init__(self, **kwargs):
        fallbacks = {
            'link_id': uuid4(),
            'task_id': None,
            'document_id': None,
            'annotation_snapshot': None,
            'created_at': datetime.utcnow(),
        }
        for attr, fallback in fallbacks.items():
            setattr(self, attr, kwargs.get(attr, fallback))
class MockAdminDocument:
    """Lightweight stand-in for an AdminDocument record.

    Any attribute may be overridden through a keyword argument; the rest
    take the fallbacks listed in ``__init__``. Unknown keyword arguments are
    ignored.
    """

    def __init__(self, **kwargs):
        # Rebuilt per call so every document gets a fresh UUID and timestamps.
        fallbacks = {
            'document_id': uuid4(),
            'admin_token': 'test-token',
            'filename': 'test.pdf',
            'file_size': 100000,
            'content_type': 'application/pdf',
            'file_path': 'data/admin_docs/test.pdf',
            'page_count': 1,
            'status': 'labeled',
            'auto_label_status': None,
            'auto_label_error': None,
            'upload_source': 'ui',
            'batch_id': None,
            'csv_field_values': None,
            'auto_label_queued_at': None,
            'annotation_lock_until': None,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in fallbacks.items():
            setattr(self, attr, kwargs.get(attr, fallback))
class MockAnnotation:
    """Lightweight stand-in for an AdminAnnotation record.

    ``document_id`` defaults to ``None``; every other attribute has the
    fallback listed in ``__init__``. Unknown keyword arguments are ignored.
    """

    def __init__(self, **kwargs):
        fallbacks = {
            'annotation_id': uuid4(),
            'document_id': None,
            'page_number': 1,
            'class_id': 0,
            'class_name': 'invoice_number',
            'bbox_x': 100,
            'bbox_y': 100,
            'bbox_width': 200,
            'bbox_height': 50,
            'x_center': 0.5,
            'y_center': 0.5,
            'width': 0.3,
            'height': 0.1,
            'text_value': 'INV-001',
            'confidence': 0.95,
            'source': 'manual',
            'is_verified': False,
            'verified_at': None,
            'verified_by': None,
            'override_source': None,
            'original_annotation_id': None,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in fallbacks.items():
            setattr(self, attr, kwargs.get(attr, fallback))
class MockModelVersion:
    """Lightweight stand-in for a ModelVersion record.

    Any attribute may be overridden through a keyword argument; the rest
    take the fallbacks listed in ``__init__``. Unknown keyword arguments are
    ignored.
    """

    def __init__(self, **kwargs):
        # Rebuilt per call so each version has its own UUID, timestamps and
        # ``training_config`` dict.
        fallbacks = {
            'version_id': uuid4(),
            'version': '1.0.0',
            'name': 'Test Model',
            'description': None,
            'model_path': 'runs/train/test/weights/best.pt',
            'status': 'inactive',
            'is_active': False,
            'task_id': None,
            'dataset_id': None,
            'metrics_mAP': 0.935,
            'metrics_precision': 0.92,
            'metrics_recall': 0.88,
            'document_count': 100,
            'training_config': {},
            'file_size': 52428800,
            'trained_at': datetime.utcnow(),
            'activated_at': None,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in fallbacks.items():
            setattr(self, attr, kwargs.get(attr, fallback))
class MockDocumentRepository:
    """In-memory stand-in for the document repository (Phase 4 tests).

    ``documents`` holds the document objects; ``annotations`` and
    ``training_links`` are dictionaries keyed by ``str(document_id)`` that
    the filtering in ``get_for_training`` consults.
    """

    def __init__(self):
        self.documents = {}
        self.annotations = {}  # Shared reference for filtering
        self.training_links = {}  # Shared reference for filtering

    def get_for_training(
        self,
        admin_token=None,
        status="labeled",
        has_annotations=True,
        min_annotation_count=None,
        exclude_used_in_training=False,
        limit=100,
        offset=0,
    ):
        """Return one page of documents eligible for training.

        A document is kept when its ``admin_token`` and ``status`` match the
        arguments, it has at least one annotation (if ``has_annotations``),
        it has at least ``min_annotation_count`` annotations (if that value
        is truthy), and it has no training links (if
        ``exclude_used_in_training``).

        Returns:
            A ``(page, total)`` tuple: the slice ``[offset:offset + limit]``
            of the matching documents and the number of matches before
            slicing.
        """

        def is_eligible(candidate):
            if candidate.admin_token != admin_token:
                return False
            if candidate.status != status:
                return False

            key = str(candidate.document_id)
            annotation_count = len(self.annotations.get(key, []))
            if has_annotations and annotation_count == 0:
                return False
            if min_annotation_count and annotation_count < min_annotation_count:
                return False

            # Documents already linked to a training task are dropped on request.
            if exclude_used_in_training and self.training_links.get(key, []):
                return False
            return True

        matches = [doc for doc in self.documents.values() if is_eligible(doc)]
        page = matches[offset:offset + limit]
        return page, len(matches)
class MockAnnotationRepository:
    """In-memory stand-in for the annotation repository (Phase 4 tests).

    ``annotations`` maps ``str(document_id)`` to the list of annotation
    objects stored for that document.
    """

    def __init__(self):
        self.annotations = {}

    def get_for_document(self, document_id, page_number=None):
        """Return the annotations stored for ``document_id``.

        Args:
            document_id: Document identifier; converted with ``str()`` before
                the lookup.
            page_number: When given, only annotations whose ``page_number``
                attribute equals this value are returned. ``None`` (the
                default) returns every stored annotation.

        Returns:
            A list of annotations; empty when nothing is stored for the
            document or no annotation matches the requested page.
        """
        stored = self.annotations.get(str(document_id), [])
        if page_number is None:
            return stored
        # Honour the page filter instead of silently ignoring the argument.
        return [ann for ann in stored if ann.page_number == page_number]
class MockTrainingTaskRepository:
    """In-memory stand-in for the training task repository (Phase 4 tests).

    ``training_tasks`` maps ``str(task_id)`` to task objects and
    ``training_links`` maps ``str(document_id)`` to a list of link objects.
    """

    def __init__(self):
        self.training_tasks = {}
        self.training_links = {}

    def get_paginated(
        self,
        admin_token=None,
        status=None,
        limit=20,
        offset=0,
    ):
        """Return one page of the tasks owned by ``admin_token``.

        When ``status`` is truthy, only tasks with that status are kept.

        Returns:
            A ``(page, total)`` tuple: the slice ``[offset:offset + limit]``
            of the matching tasks and the number of matches before slicing.
        """
        matching = [
            task
            for task in self.training_tasks.values()
            if task.admin_token == admin_token
            and (not status or task.status == status)
        ]
        return matching[offset:offset + limit], len(matching)

    def get(self, task_id):
        """Return the task stored under ``str(task_id)``, or ``None``."""
        key = str(task_id)
        return self.training_tasks.get(key)

    def get_document_training_tasks(self, document_id):
        """Return the links stored for ``document_id`` (empty list if none)."""
        key = str(document_id)
        return self.training_links.get(key, [])
class MockModelVersionRepository:
    """In-memory stand-in for the model version repository (Phase 4 tests).

    ``model_versions`` maps ``str(version_id)`` to model version objects.
    """

    def __init__(self):
        self.model_versions = {}

    def get_paginated(self, status=None, limit=20, offset=0):
        """Return one page of model versions, optionally filtered by status.

        A falsy ``status`` disables the filter.

        Returns:
            A ``(page, total)`` tuple: the slice ``[offset:offset + limit]``
            of the matching versions and the number of matches before
            slicing.
        """
        selected = [
            version
            for version in self.model_versions.values()
            if not status or version.status == status
        ]
        return selected[offset:offset + limit], len(selected)
@pytest.fixture
def app():
    """Build a FastAPI app wired to seeded in-memory mock repositories.

    Seed data:
      * three labeled documents; the first has two annotations (manual and
        auto), the second has three auto annotations, the third has none;
      * two completed training tasks, with the first document linked to the
        first task;
      * two model versions, one inactive and one active.

    The admin-token and repository dependencies are overridden with the
    mocks before the training router is mounted.
    """
    test_app = FastAPI()

    document_repo = MockDocumentRepository()
    annotation_repo = MockAnnotationRepository()
    task_repo = MockTrainingTaskRepository()
    model_repo = MockModelVersionRepository()

    # --- documents ---------------------------------------------------------
    doc1, doc2, doc3 = (
        MockAdminDocument(filename=filename, status="labeled")
        for filename in ("DOC001.pdf", "DOC002.pdf", "DOC003.pdf")
    )
    for document in (doc1, doc2, doc3):
        document_repo.documents[str(document.document_id)] = document

    # --- annotations (doc3 intentionally has none) -------------------------
    annotation_repo.annotations[str(doc1.document_id)] = [
        MockAnnotation(document_id=doc1.document_id, source=origin)
        for origin in ("manual", "auto")
    ]
    annotation_repo.annotations[str(doc2.document_id)] = [
        MockAnnotation(document_id=doc2.document_id, source="auto")
        for _ in range(3)
    ]
    # The document repo filters on the very same dict object.
    document_repo.annotations = annotation_repo.annotations

    # --- training tasks ----------------------------------------------------
    task1 = MockTrainingTask(
        name="Training Run 2024-01",
        status="completed",
        document_count=500,
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
    )
    task2 = MockTrainingTask(
        name="Training Run 2024-02",
        status="completed",
        document_count=600,
        metrics_mAP=0.951,
        metrics_precision=0.94,
        metrics_recall=0.92,
    )
    for task in (task1, task2):
        task_repo.training_tasks[str(task.task_id)] = task

    # --- training links: doc1 was used in task1 ----------------------------
    task_repo.training_links[str(doc1.document_id)] = [
        MockTrainingDocumentLink(
            task_id=task1.task_id,
            document_id=doc1.document_id,
        )
    ]
    # The document repo filters on the very same dict object.
    document_repo.training_links = task_repo.training_links

    # --- model versions ----------------------------------------------------
    model1 = MockModelVersion(
        version="1.0.0",
        name="Model v1.0.0",
        status="inactive",
        is_active=False,
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
        document_count=500,
    )
    model2 = MockModelVersion(
        version="1.1.0",
        name="Model v1.1.0",
        status="active",
        is_active=True,
        metrics_mAP=0.951,
        metrics_precision=0.94,
        metrics_recall=0.92,
        document_count=600,
    )
    for model in (model1, model2):
        model_repo.model_versions[str(model.version_id)] = model

    # --- dependency overrides ----------------------------------------------
    # Each provider is a parameterless lambda closing over its own variable,
    # so the override callables expose no parameters of their own.
    test_app.dependency_overrides.update(
        {
            validate_admin_token: lambda: "test-token",
            get_document_repository: lambda: document_repo,
            get_annotation_repository: lambda: annotation_repo,
            get_training_task_repository: lambda: task_repo,
            get_model_version_repository: lambda: model_repo,
        }
    )

    test_app.include_router(create_training_router())
    return test_app
@pytest.fixture
def client(app):
    """Return a ``TestClient`` bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
class TestTrainingDocuments:
    """Tests for GET /admin/training/documents endpoint."""

    @staticmethod
    def _fetch_ok(client, url):
        """GET ``url``, require HTTP 200 and return the decoded JSON body."""
        reply = client.get(url)
        assert reply.status_code == 200
        return reply.json()

    def test_get_training_documents_success(self, client):
        """The bare endpoint answers with a total and a document list."""
        payload = self._fetch_ok(client, "/admin/training/documents")
        assert "total" in payload
        assert "documents" in payload
        assert payload["total"] >= 0
        assert isinstance(payload["documents"], list)

    def test_get_training_documents_with_annotations(self, client):
        """``has_annotations=true`` keeps only annotated documents."""
        payload = self._fetch_ok(
            client, "/admin/training/documents?has_annotations=true"
        )
        # doc1 and doc2 are the only seeded documents with annotations.
        assert payload["total"] == 2

    def test_get_training_documents_min_annotation_count(self, client):
        """``min_annotation_count`` filters on the number of annotations."""
        payload = self._fetch_ok(
            client, "/admin/training/documents?min_annotation_count=3"
        )
        # Only doc2 was seeded with three annotations.
        assert payload["total"] == 1

    def test_get_training_documents_exclude_used(self, client):
        """``exclude_used_in_training=true`` drops already-used documents."""
        payload = self._fetch_ok(
            client, "/admin/training/documents?exclude_used_in_training=true"
        )
        # doc1 is linked to a training task and doc3 has no annotations,
        # which leaves doc2 alone.
        assert payload["total"] == 1

    def test_get_training_documents_annotation_sources(self, client):
        """Each returned document reports its annotation sources."""
        payload = self._fetch_ok(
            client, "/admin/training/documents?has_annotations=true"
        )
        for entry in payload["documents"]:
            assert "annotation_sources" in entry
            sources = entry["annotation_sources"]
            assert isinstance(sources, dict)
            assert "manual" in sources
            assert "auto" in sources

    def test_get_training_documents_pagination(self, client):
        """``limit`` and ``offset`` are echoed back and bound the page size."""
        payload = self._fetch_ok(
            client, "/admin/training/documents?limit=1&offset=0"
        )
        assert payload["limit"] == 1
        assert payload["offset"] == 0
        assert len(payload["documents"]) <= 1
class TestTrainingModels:
    """Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""

    @staticmethod
    def _fetch_ok(client, url):
        """GET ``url``, require HTTP 200 and return the decoded JSON body."""
        reply = client.get(url)
        assert reply.status_code == 200
        return reply.json()

    def test_get_training_models_success(self, client):
        """The listing returns both seeded model versions."""
        payload = self._fetch_ok(client, "/admin/training/models")
        assert "total" in payload
        assert "models" in payload
        assert payload["total"] == 2
        assert len(payload["models"]) == 2

    def test_get_training_models_includes_metrics(self, client):
        """The first listed model carries a non-null ``metrics_mAP``."""
        payload = self._fetch_ok(client, "/admin/training/models")
        first = payload["models"][0]
        assert "metrics_mAP" in first
        assert first["metrics_mAP"] is not None

    def test_get_training_models_includes_version_fields(self, client):
        """The first listed model exposes the expected version fields."""
        payload = self._fetch_ok(client, "/admin/training/models")
        first = payload["models"][0]
        for field in (
            "version_id",
            "version",
            "name",
            "status",
            "is_active",
            "document_count",
        ):
            assert field in first

    def test_get_training_models_filter_by_status(self, client):
        """``status=active`` returns only the active model version."""
        payload = self._fetch_ok(client, "/admin/training/models?status=active")
        assert payload["total"] == 1
        # Every returned model must carry the requested status.
        for entry in payload["models"]:
            assert entry["status"] == "active"

    def test_get_training_models_pagination(self, client):
        """``limit`` and ``offset`` are echoed back and bound the page size."""
        payload = self._fetch_ok(client, "/admin/training/models?limit=1&offset=0")
        assert payload["limit"] == 1
        assert payload["offset"] == 0
        assert len(payload["models"]) == 1