# Source: invoice-master-poc-v2/tests/web/test_training_phase4.py
# Snapshot metadata: 2026-01-27 23:58:17 +01:00 — 385 lines, 14 KiB, Python.
"""
Tests for Phase 4: Training Data Management
"""
import pytest
from datetime import datetime
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockTrainingTask:
    """Mock TrainingTask for testing."""

    # (attribute name, zero-arg default factory) pairs. Each factory is
    # invoked on every construction — mirroring the original per-call
    # ``kwargs.get(key, default_expr)`` evaluation — so mutable defaults
    # (dicts) and timestamps are fresh per instance.
    # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+; kept
    # here to preserve naive-UTC timestamps the rest of the suite expects.
    _FIELDS = (
        ('task_id', uuid4),
        ('admin_token', lambda: 'test-token'),
        ('name', lambda: 'Test Training'),
        ('description', lambda: None),
        ('status', lambda: 'completed'),
        ('task_type', lambda: 'train'),
        ('config', dict),
        ('scheduled_at', lambda: None),
        ('cron_expression', lambda: None),
        ('is_recurring', lambda: False),
        ('started_at', datetime.utcnow),
        ('completed_at', datetime.utcnow),
        ('error_message', lambda: None),
        ('result_metrics', dict),
        ('model_path', lambda: 'runs/train/test/weights/best.pt'),
        ('document_count', lambda: 0),
        ('metrics_mAP', lambda: 0.935),
        ('metrics_precision', lambda: 0.92),
        ('metrics_recall', lambda: 0.88),
        ('created_at', datetime.utcnow),
        ('updated_at', datetime.utcnow),
    )

    def __init__(self, **kwargs):
        # Any keyword overrides its default; unknown keywords are ignored,
        # exactly as with the original explicit kwargs.get(...) assignments.
        for attr, factory in self._FIELDS:
            setattr(self, attr, kwargs.get(attr, factory()))
class MockTrainingDocumentLink:
    """Mock TrainingDocumentLink for testing."""

    def __init__(self, **kwargs):
        # Bind the lookup once; defaults mirror the real link row:
        # task_id/document_id have no default (None when omitted).
        get = kwargs.get
        self.link_id = get('link_id', uuid4())
        self.task_id = get('task_id')
        self.document_id = get('document_id')
        self.annotation_snapshot = get('annotation_snapshot', None)
        self.created_at = get('created_at', datetime.utcnow())
class MockAdminDocument:
    """Mock AdminDocument for testing."""

    def __init__(self, **kwargs):
        # Bind the lookup once; every field falls back to a sensible
        # "labeled PDF uploaded via UI" default so tests only override
        # what they care about.
        get = kwargs.get
        self.document_id = get('document_id', uuid4())
        self.admin_token = get('admin_token', 'test-token')
        self.filename = get('filename', 'test.pdf')
        self.file_size = get('file_size', 100000)
        self.content_type = get('content_type', 'application/pdf')
        self.file_path = get('file_path', 'data/admin_docs/test.pdf')
        self.page_count = get('page_count', 1)
        self.status = get('status', 'labeled')
        self.auto_label_status = get('auto_label_status', None)
        self.auto_label_error = get('auto_label_error', None)
        self.upload_source = get('upload_source', 'ui')
        self.batch_id = get('batch_id', None)
        self.csv_field_values = get('csv_field_values', None)
        self.auto_label_queued_at = get('auto_label_queued_at', None)
        self.annotation_lock_until = get('annotation_lock_until', None)
        self.created_at = get('created_at', datetime.utcnow())
        self.updated_at = get('updated_at', datetime.utcnow())
class MockAnnotation:
    """Mock AdminAnnotation for testing."""

    # (attribute name, zero-arg default factory) pairs; factories run on
    # every construction so uuid/timestamps are unique per instance, just
    # like the original inline ``kwargs.get(key, default_expr)`` defaults.
    # Defaults model a manual, unverified invoice_number box.
    _FIELDS = (
        ('annotation_id', uuid4),
        ('document_id', lambda: None),
        ('page_number', lambda: 1),
        ('class_id', lambda: 0),
        ('class_name', lambda: 'invoice_number'),
        ('bbox_x', lambda: 100),
        ('bbox_y', lambda: 100),
        ('bbox_width', lambda: 200),
        ('bbox_height', lambda: 50),
        ('x_center', lambda: 0.5),
        ('y_center', lambda: 0.5),
        ('width', lambda: 0.3),
        ('height', lambda: 0.1),
        ('text_value', lambda: 'INV-001'),
        ('confidence', lambda: 0.95),
        ('source', lambda: 'manual'),
        ('is_verified', lambda: False),
        ('verified_at', lambda: None),
        ('verified_by', lambda: None),
        ('override_source', lambda: None),
        ('original_annotation_id', lambda: None),
        ('created_at', datetime.utcnow),
        ('updated_at', datetime.utcnow),
    )

    def __init__(self, **kwargs):
        for attr, factory in self._FIELDS:
            setattr(self, attr, kwargs.get(attr, factory()))
class MockAdminDB:
    """Mock AdminDB for testing Phase 4.

    In-memory stores are plain dicts keyed by stringified UUIDs:
    documents/training_tasks map id -> object, annotations/training_links
    map document id -> list.
    """

    def __init__(self):
        self.documents = {}
        self.annotations = {}
        self.training_tasks = {}
        self.training_links = {}

    def get_documents_for_training(
        self,
        admin_token,
        status="labeled",
        has_annotations=True,
        min_annotation_count=None,
        exclude_used_in_training=False,
        limit=100,
        offset=0,
    ):
        """Return (page_of_documents, total_match_count) for training selection."""

        def _eligible(doc):
            # Token and status must both match.
            if doc.admin_token != admin_token or doc.status != status:
                return False
            key = str(doc.document_id)
            n_annotations = len(self.annotations.get(key, []))
            if has_annotations and n_annotations == 0:
                return False
            # Falsy (None/0) min_annotation_count disables the threshold,
            # matching the original truthiness check.
            if min_annotation_count and n_annotations < min_annotation_count:
                return False
            # Optionally skip documents already linked to a training task.
            if exclude_used_in_training and self.training_links.get(key, []):
                return False
            return True

        matches = [doc for doc in self.documents.values() if _eligible(doc)]
        return matches[offset:offset + limit], len(matches)

    def get_annotations_for_document(self, document_id):
        """Get annotations for document (empty list when unknown)."""
        return self.annotations.get(str(document_id), [])

    def get_document_training_tasks(self, document_id):
        """Get training tasks that used this document."""
        return self.training_links.get(str(document_id), [])

    def get_training_tasks_by_token(
        self,
        admin_token,
        status=None,
        limit=20,
        offset=0,
    ):
        """Return (page_of_tasks, total_match_count) for one admin token.

        A falsy status means "any status", matching the original check.
        """
        matching = [
            task for task in self.training_tasks.values()
            if task.admin_token == admin_token
            and (not status or task.status == status)
        ]
        return matching[offset:offset + limit], len(matching)

    def get_training_task(self, task_id):
        """Get training task by ID (None when unknown)."""
        return self.training_tasks.get(str(task_id))
@pytest.fixture
def app():
    """Create a test FastAPI app wired to a fully populated MockAdminDB.

    Fixture data: three labeled documents (DOC001 with 1 manual + 1 auto
    annotation, DOC002 with 3 auto annotations, DOC003 with none), two
    completed training tasks, and one link marking DOC001 as already used
    in training. Auth and DB dependencies are overridden for isolation.
    """
    test_app = FastAPI()
    db = MockAdminDB()

    # Three labeled documents.
    docs = [
        MockAdminDocument(filename=f"DOC00{i}.pdf", status="labeled")
        for i in (1, 2, 3)
    ]
    for doc in docs:
        db.documents[str(doc.document_id)] = doc
    doc1, doc2, doc3 = docs

    # doc1: one manual + one auto annotation; doc2: three auto; doc3: none.
    db.annotations[str(doc1.document_id)] = [
        MockAnnotation(document_id=doc1.document_id, source=src)
        for src in ("manual", "auto")
    ]
    db.annotations[str(doc2.document_id)] = [
        MockAnnotation(document_id=doc2.document_id, source="auto")
        for _ in range(3)
    ]

    # Two completed training runs with improving metrics.
    task1 = MockTrainingTask(
        name="Training Run 2024-01",
        status="completed",
        document_count=500,
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
    )
    task2 = MockTrainingTask(
        name="Training Run 2024-02",
        status="completed",
        document_count=600,
        metrics_mAP=0.951,
        metrics_precision=0.94,
        metrics_recall=0.92,
    )
    for task in (task1, task2):
        db.training_tasks[str(task.task_id)] = task

    # doc1 was consumed by task1.
    db.training_links[str(doc1.document_id)] = [
        MockTrainingDocumentLink(task_id=task1.task_id, document_id=doc1.document_id)
    ]

    # Bypass real auth/DB and mount the router under test.
    test_app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    test_app.dependency_overrides[get_admin_db] = lambda: db
    test_app.include_router(create_training_router())
    return test_app
@pytest.fixture
def client(app):
    """Wrap the app fixture in a synchronous TestClient."""
    test_client = TestClient(app)
    return test_client
class TestTrainingDocuments:
    """Tests for GET /admin/training/documents endpoint."""

    def test_get_training_documents_success(self, client):
        """Listing documents returns a well-formed payload."""
        resp = client.get("/admin/training/documents")
        assert resp.status_code == 200
        payload = resp.json()
        assert "total" in payload
        assert "documents" in payload
        assert payload["total"] >= 0
        assert isinstance(payload["documents"], list)

    def test_get_training_documents_with_annotations(self, client):
        """has_annotations=true keeps only annotated documents (doc1, doc2)."""
        resp = client.get("/admin/training/documents?has_annotations=true")
        assert resp.status_code == 200
        assert resp.json()["total"] == 2

    def test_get_training_documents_min_annotation_count(self, client):
        """min_annotation_count=3 keeps only doc2 (three annotations)."""
        resp = client.get("/admin/training/documents?min_annotation_count=3")
        assert resp.status_code == 200
        assert resp.json()["total"] == 1

    def test_get_training_documents_exclude_used(self, client):
        """exclude_used_in_training drops doc1; doc3 is unannotated, leaving doc2."""
        resp = client.get("/admin/training/documents?exclude_used_in_training=true")
        assert resp.status_code == 200
        assert resp.json()["total"] == 1

    def test_get_training_documents_annotation_sources(self, client):
        """Each document carries a per-source annotation breakdown."""
        resp = client.get("/admin/training/documents?has_annotations=true")
        assert resp.status_code == 200
        for doc in resp.json()["documents"]:
            assert "annotation_sources" in doc
            sources = doc["annotation_sources"]
            assert isinstance(sources, dict)
            assert "manual" in sources
            assert "auto" in sources

    def test_get_training_documents_pagination(self, client):
        """limit/offset are echoed back and bound the page size."""
        resp = client.get("/admin/training/documents?limit=1&offset=0")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["limit"] == 1
        assert payload["offset"] == 0
        assert len(payload["documents"]) <= 1
class TestTrainingModels:
    """Tests for GET /admin/training/models endpoint."""

    def test_get_training_models_success(self, client):
        """Both fixture training runs are listed."""
        resp = client.get("/admin/training/models")
        assert resp.status_code == 200
        payload = resp.json()
        assert "total" in payload
        assert "models" in payload
        assert payload["total"] == 2
        assert len(payload["models"]) == 2

    def test_get_training_models_includes_metrics(self, client):
        """Each model exposes mAP/precision/recall metrics."""
        resp = client.get("/admin/training/models")
        assert resp.status_code == 200
        first = resp.json()["models"][0]
        assert "metrics" in first
        metrics = first["metrics"]
        assert "mAP" in metrics
        assert metrics["mAP"] is not None
        assert "precision" in metrics
        assert "recall" in metrics

    def test_get_training_models_includes_download_url(self, client):
        """Completed models carry a non-null download URL."""
        resp = client.get("/admin/training/models")
        assert resp.status_code == 200
        completed = [m for m in resp.json()["models"] if m["status"] == "completed"]
        for model in completed:
            assert "download_url" in model
            assert model["download_url"] is not None

    def test_get_training_models_filter_by_status(self, client):
        """status=completed returns only completed models."""
        resp = client.get("/admin/training/models?status=completed")
        assert resp.status_code == 200
        statuses = [m["status"] for m in resp.json()["models"]]
        assert all(s == "completed" for s in statuses)

    def test_get_training_models_pagination(self, client):
        """limit/offset are echoed back and enforced."""
        resp = client.get("/admin/training/models?limit=1&offset=0")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["limit"] == 1
        assert payload["offset"] == 0
        assert len(payload["models"]) == 1