"""
Tests for Phase 4: Training Data Management
"""
from datetime import datetime
from uuid import uuid4

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import (
    get_annotation_repository,
    get_document_repository,
    get_model_version_repository,
    get_training_task_repository,
    validate_admin_token,
)
|
|
|
|
|
|
class MockTrainingTask:
    """Mock TrainingTask for testing.

    Every attribute can be overridden via keyword argument; anything not
    supplied falls back to the defaults below.
    """

    def __init__(self, **overrides):
        # Defaults are built per instance so uuid4()/utcnow() yield fresh values.
        defaults = {
            'task_id': uuid4(),
            'admin_token': 'test-token',
            'name': 'Test Training',
            'description': None,
            'status': 'completed',
            'task_type': 'train',
            'config': {},
            'scheduled_at': None,
            'cron_expression': None,
            'is_recurring': False,
            'started_at': datetime.utcnow(),
            'completed_at': datetime.utcnow(),
            'error_message': None,
            'result_metrics': {},
            'model_path': 'runs/train/test/weights/best.pt',
            'document_count': 0,
            'metrics_mAP': 0.935,
            'metrics_precision': 0.92,
            'metrics_recall': 0.88,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, overrides.get(attr, fallback))
|
|
|
|
|
|
class MockTrainingDocumentLink:
    """Mock TrainingDocumentLink for testing.

    Keyword overrides win; unset attributes use the defaults below.
    """

    def __init__(self, **overrides):
        defaults = {
            'link_id': uuid4(),
            'task_id': None,
            'document_id': None,
            'annotation_snapshot': None,
            'created_at': datetime.utcnow(),
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, overrides.get(attr, fallback))
|
|
|
|
|
|
class MockAdminDocument:
    """Mock AdminDocument for testing.

    Keyword overrides win; unset attributes use the defaults below.
    """

    def __init__(self, **overrides):
        defaults = {
            'document_id': uuid4(),
            'admin_token': 'test-token',
            'filename': 'test.pdf',
            'file_size': 100000,
            'content_type': 'application/pdf',
            'file_path': 'data/admin_docs/test.pdf',
            'page_count': 1,
            'status': 'labeled',
            'auto_label_status': None,
            'auto_label_error': None,
            'upload_source': 'ui',
            'batch_id': None,
            'csv_field_values': None,
            'auto_label_queued_at': None,
            'annotation_lock_until': None,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, overrides.get(attr, fallback))
|
|
|
|
|
|
class MockAnnotation:
    """Mock AdminAnnotation for testing.

    Keyword overrides win; unset attributes use the defaults below.
    """

    def __init__(self, **overrides):
        defaults = {
            'annotation_id': uuid4(),
            'document_id': None,
            'page_number': 1,
            'class_id': 0,
            'class_name': 'invoice_number',
            'bbox_x': 100,
            'bbox_y': 100,
            'bbox_width': 200,
            'bbox_height': 50,
            'x_center': 0.5,
            'y_center': 0.5,
            'width': 0.3,
            'height': 0.1,
            'text_value': 'INV-001',
            'confidence': 0.95,
            'source': 'manual',
            'is_verified': False,
            'verified_at': None,
            'verified_by': None,
            'override_source': None,
            'original_annotation_id': None,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, overrides.get(attr, fallback))
|
|
|
|
|
|
class MockModelVersion:
    """Mock ModelVersion for testing.

    Keyword overrides win; unset attributes use the defaults below.
    """

    def __init__(self, **overrides):
        defaults = {
            'version_id': uuid4(),
            'version': '1.0.0',
            'name': 'Test Model',
            'description': None,
            'model_path': 'runs/train/test/weights/best.pt',
            'status': 'inactive',
            'is_active': False,
            'task_id': None,
            'dataset_id': None,
            'metrics_mAP': 0.935,
            'metrics_precision': 0.92,
            'metrics_recall': 0.88,
            'document_count': 100,
            'training_config': {},
            'file_size': 52428800,
            'trained_at': datetime.utcnow(),
            'activated_at': None,
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, overrides.get(attr, fallback))
|
|
|
|
|
|
class MockDocumentRepository:
    """Mock DocumentRepository for testing Phase 4."""

    def __init__(self):
        # All stores are keyed by str(document_id).
        self.documents = {}
        self.annotations = {}  # Shared reference for filtering
        self.training_links = {}  # Shared reference for filtering

    def get_for_training(
        self,
        admin_token=None,
        status="labeled",
        has_annotations=True,
        min_annotation_count=None,
        exclude_used_in_training=False,
        limit=100,
        offset=0,
    ):
        """Get documents for training.

        Returns a ``(page, total)`` tuple where ``page`` is the slice
        ``[offset:offset + limit]`` of all matching documents.
        """
        def _eligible(doc):
            # Token and status must both match.
            if doc.admin_token != admin_token or doc.status != status:
                return False
            doc_annotations = self.annotations.get(str(doc.document_id), [])
            if has_annotations and not doc_annotations:
                return False
            if min_annotation_count and len(doc_annotations) < min_annotation_count:
                return False
            # Optionally skip documents already consumed by a training run.
            if exclude_used_in_training and self.training_links.get(str(doc.document_id), []):
                return False
            return True

        matches = [doc for doc in self.documents.values() if _eligible(doc)]
        return matches[offset:offset + limit], len(matches)
|
|
|
|
|
|
class MockAnnotationRepository:
    """Mock AnnotationRepository for testing Phase 4."""

    def __init__(self):
        # Maps str(document_id) -> list of annotation objects.
        self.annotations = {}

    def get_for_document(self, document_id, page_number=None):
        """Get annotations for document.

        When ``page_number`` is given, only annotations on that page are
        returned. (The previous mock accepted the argument but silently
        ignored it, which could mask per-page filtering bugs in callers.)
        Unknown documents yield an empty list.
        """
        doc_annotations = self.annotations.get(str(document_id), [])
        if page_number is None:
            return doc_annotations
        return [a for a in doc_annotations if a.page_number == page_number]
|
|
|
|
|
|
class MockTrainingTaskRepository:
    """Mock TrainingTaskRepository for testing Phase 4."""

    def __init__(self):
        self.training_tasks = {}   # str(task_id) -> task
        self.training_links = {}   # str(document_id) -> list of links

    def get_paginated(
        self,
        admin_token=None,
        status=None,
        limit=20,
        offset=0,
    ):
        """Get training tasks filtered by token.

        Returns ``(page, total)``; ``status`` filters only when truthy.
        """
        matching = [
            task for task in self.training_tasks.values()
            if task.admin_token == admin_token and (not status or task.status == status)
        ]
        return matching[offset:offset + limit], len(matching)

    def get(self, task_id):
        """Get training task by ID (None when absent)."""
        return self.training_tasks.get(str(task_id))

    def get_document_training_tasks(self, document_id):
        """Get training tasks that used this document."""
        return self.training_links.get(str(document_id), [])
|
|
|
|
|
|
class MockModelVersionRepository:
    """Mock ModelVersionRepository for testing Phase 4."""

    def __init__(self):
        # Maps str(version_id) -> model version object.
        self.model_versions = {}

    def get_paginated(self, status=None, limit=20, offset=0):
        """Get model versions with optional filtering.

        Returns ``(page, total)``; ``status`` filters only when truthy.
        """
        matching = [
            version for version in self.model_versions.values()
            if not status or version.status == status
        ]
        return matching[offset:offset + limit], len(matching)
|
|
|
|
|
|
@pytest.fixture
def app():
    """Create test FastAPI app.

    Seeds all four mocked repositories with a small, known dataset that
    the assertions in the test classes below depend on:

    - 3 labeled documents: doc1 (2 annotations, used in training),
      doc2 (3 annotations), doc3 (no annotations)
    - 2 completed training tasks; doc1 is linked to task1
    - 2 model versions: v1.0.0 inactive, v1.1.0 active
    """
    app = FastAPI()

    # Create mock repositories
    mock_document_repo = MockDocumentRepository()
    mock_annotation_repo = MockAnnotationRepository()
    mock_training_task_repo = MockTrainingTaskRepository()
    mock_model_version_repo = MockModelVersionRepository()

    # Add test documents
    doc1 = MockAdminDocument(
        filename="DOC001.pdf",
        status="labeled",
    )
    doc2 = MockAdminDocument(
        filename="DOC002.pdf",
        status="labeled",
    )
    doc3 = MockAdminDocument(
        filename="DOC003.pdf",
        status="labeled",
    )

    mock_document_repo.documents[str(doc1.document_id)] = doc1
    mock_document_repo.documents[str(doc2.document_id)] = doc2
    mock_document_repo.documents[str(doc3.document_id)] = doc3

    # Add annotations: doc1 has one manual + one auto, doc2 has three auto.
    mock_annotation_repo.annotations[str(doc1.document_id)] = [
        MockAnnotation(document_id=doc1.document_id, source="manual"),
        MockAnnotation(document_id=doc1.document_id, source="auto"),
    ]
    mock_annotation_repo.annotations[str(doc2.document_id)] = [
        MockAnnotation(document_id=doc2.document_id, source="auto"),
        MockAnnotation(document_id=doc2.document_id, source="auto"),
        MockAnnotation(document_id=doc2.document_id, source="auto"),
    ]
    # doc3 has no annotations

    # Share annotation data with document repo for filtering.
    # NOTE: this aliases the dict (not a copy), so both repos see the
    # same annotation store.
    mock_document_repo.annotations = mock_annotation_repo.annotations

    # Add training tasks
    task1 = MockTrainingTask(
        name="Training Run 2024-01",
        status="completed",
        document_count=500,
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
    )
    task2 = MockTrainingTask(
        name="Training Run 2024-02",
        status="completed",
        document_count=600,
        metrics_mAP=0.951,
        metrics_precision=0.94,
        metrics_recall=0.92,
    )

    mock_training_task_repo.training_tasks[str(task1.task_id)] = task1
    mock_training_task_repo.training_tasks[str(task2.task_id)] = task2

    # Add training links (doc1 used in task1)
    link1 = MockTrainingDocumentLink(
        task_id=task1.task_id,
        document_id=doc1.document_id,
    )
    mock_training_task_repo.training_links[str(doc1.document_id)] = [link1]

    # Share training links with document repo for filtering (aliased dict).
    mock_document_repo.training_links = mock_training_task_repo.training_links

    # Add model versions
    model1 = MockModelVersion(
        version="1.0.0",
        name="Model v1.0.0",
        status="inactive",
        is_active=False,
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
        document_count=500,
    )
    model2 = MockModelVersion(
        version="1.1.0",
        name="Model v1.1.0",
        status="active",
        is_active=True,
        metrics_mAP=0.951,
        metrics_precision=0.94,
        metrics_recall=0.92,
        document_count=600,
    )
    mock_model_version_repo.model_versions[str(model1.version_id)] = model1
    mock_model_version_repo.model_versions[str(model2.version_id)] = model2

    # Override dependencies so the router uses the mocks and a fixed token.
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
    app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
    app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
    app.dependency_overrides[get_model_version_repository] = lambda: mock_model_version_repo

    # Include router
    router = create_training_router()
    app.include_router(router)

    return app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """HTTP test client bound to the fixture app."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
class TestTrainingDocuments:
    """Tests for GET /admin/training/documents endpoint."""

    def test_get_training_documents_success(self, client):
        """Listing training documents returns a well-formed payload."""
        resp = client.get("/admin/training/documents")

        assert resp.status_code == 200
        payload = resp.json()
        assert "total" in payload
        assert "documents" in payload
        assert payload["total"] >= 0
        assert isinstance(payload["documents"], list)

    def test_get_training_documents_with_annotations(self, client):
        """Only annotated documents survive the has_annotations filter."""
        resp = client.get("/admin/training/documents?has_annotations=true")

        assert resp.status_code == 200
        payload = resp.json()
        # Fixture seeds doc1 and doc2 with annotations; doc3 has none.
        assert payload["total"] == 2

    def test_get_training_documents_min_annotation_count(self, client):
        """min_annotation_count drops documents with too few annotations."""
        resp = client.get("/admin/training/documents?min_annotation_count=3")

        assert resp.status_code == 200
        payload = resp.json()
        # Only doc2 carries 3 annotations in the fixture.
        assert payload["total"] == 1

    def test_get_training_documents_exclude_used(self, client):
        """exclude_used_in_training removes already-trained documents."""
        resp = client.get("/admin/training/documents?exclude_used_in_training=true")

        assert resp.status_code == 200
        payload = resp.json()
        # doc1 is linked to a training run; doc3 lacks annotations.
        assert payload["total"] == 1  # Only doc2 (doc3 has no annotations)

    def test_get_training_documents_annotation_sources(self, client):
        """Each returned document reports its annotation source breakdown."""
        resp = client.get("/admin/training/documents?has_annotations=true")

        assert resp.status_code == 200
        payload = resp.json()
        for entry in payload["documents"]:
            assert "annotation_sources" in entry
            assert isinstance(entry["annotation_sources"], dict)
            assert "manual" in entry["annotation_sources"]
            assert "auto" in entry["annotation_sources"]

    def test_get_training_documents_pagination(self, client):
        """limit/offset are echoed back and bound the page size."""
        resp = client.get("/admin/training/documents?limit=1&offset=0")

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["limit"] == 1
        assert payload["offset"] == 0
        assert len(payload["documents"]) <= 1
|
|
|
|
|
|
class TestTrainingModels:
    """Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""

    def test_get_training_models_success(self, client):
        """Listing model versions returns both seeded models."""
        resp = client.get("/admin/training/models")

        assert resp.status_code == 200
        payload = resp.json()
        assert "total" in payload
        assert "models" in payload
        assert payload["total"] == 2
        assert len(payload["models"]) == 2

    def test_get_training_models_includes_metrics(self, client):
        """Model versions expose their training metrics."""
        resp = client.get("/admin/training/models")

        assert resp.status_code == 200
        payload = resp.json()
        first = payload["models"][0]
        assert "metrics_mAP" in first
        assert first["metrics_mAP"] is not None

    def test_get_training_models_includes_version_fields(self, client):
        """Model versions expose identity and status fields."""
        resp = client.get("/admin/training/models")

        assert resp.status_code == 200
        payload = resp.json()
        first = payload["models"][0]
        for field in (
            "version_id",
            "version",
            "name",
            "status",
            "is_active",
            "document_count",
        ):
            assert field in first

    def test_get_training_models_filter_by_status(self, client):
        """status query param narrows the listing."""
        resp = client.get("/admin/training/models?status=active")

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["total"] == 1
        for entry in payload["models"]:
            assert entry["status"] == "active"

    def test_get_training_models_pagination(self, client):
        """limit/offset are echoed back and bound the page size."""
        resp = client.get("/admin/training/models?limit=1&offset=0")

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["limit"] == 1
        assert payload["offset"] == 0
        assert len(payload["models"]) == 1
|