""" Tests for Phase 4: Training Data Management """ import pytest from datetime import datetime from uuid import uuid4 from fastapi import FastAPI from fastapi.testclient import TestClient from inference.web.api.v1.admin.training import create_training_router from inference.web.core.auth import validate_admin_token, get_admin_db class MockTrainingTask: """Mock TrainingTask for testing.""" def __init__(self, **kwargs): self.task_id = kwargs.get('task_id', uuid4()) self.admin_token = kwargs.get('admin_token', 'test-token') self.name = kwargs.get('name', 'Test Training') self.description = kwargs.get('description', None) self.status = kwargs.get('status', 'completed') self.task_type = kwargs.get('task_type', 'train') self.config = kwargs.get('config', {}) self.scheduled_at = kwargs.get('scheduled_at', None) self.cron_expression = kwargs.get('cron_expression', None) self.is_recurring = kwargs.get('is_recurring', False) self.started_at = kwargs.get('started_at', datetime.utcnow()) self.completed_at = kwargs.get('completed_at', datetime.utcnow()) self.error_message = kwargs.get('error_message', None) self.result_metrics = kwargs.get('result_metrics', {}) self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt') self.document_count = kwargs.get('document_count', 0) self.metrics_mAP = kwargs.get('metrics_mAP', 0.935) self.metrics_precision = kwargs.get('metrics_precision', 0.92) self.metrics_recall = kwargs.get('metrics_recall', 0.88) self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) class MockTrainingDocumentLink: """Mock TrainingDocumentLink for testing.""" def __init__(self, **kwargs): self.link_id = kwargs.get('link_id', uuid4()) self.task_id = kwargs.get('task_id') self.document_id = kwargs.get('document_id') self.annotation_snapshot = kwargs.get('annotation_snapshot', None) self.created_at = kwargs.get('created_at', datetime.utcnow()) class MockAdminDocument: """Mock AdminDocument for testing.""" def __init__(self, **kwargs): self.document_id = kwargs.get('document_id', uuid4()) self.admin_token = kwargs.get('admin_token', 'test-token') self.filename = kwargs.get('filename', 'test.pdf') self.file_size = kwargs.get('file_size', 100000) self.content_type = kwargs.get('content_type', 'application/pdf') self.file_path = kwargs.get('file_path', 'data/admin_docs/test.pdf') self.page_count = kwargs.get('page_count', 1) self.status = kwargs.get('status', 'labeled') self.auto_label_status = kwargs.get('auto_label_status', None) self.auto_label_error = kwargs.get('auto_label_error', None) self.upload_source = kwargs.get('upload_source', 'ui') self.batch_id = kwargs.get('batch_id', None) self.csv_field_values = kwargs.get('csv_field_values', None) self.auto_label_queued_at = kwargs.get('auto_label_queued_at', None) self.annotation_lock_until = kwargs.get('annotation_lock_until', None) self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) class MockAnnotation: """Mock AdminAnnotation for testing.""" def __init__(self, **kwargs): self.annotation_id = kwargs.get('annotation_id', uuid4()) self.document_id = kwargs.get('document_id') self.page_number = kwargs.get('page_number', 1) self.class_id = kwargs.get('class_id', 0) self.class_name = kwargs.get('class_name', 'invoice_number') self.bbox_x = kwargs.get('bbox_x', 100) self.bbox_y = kwargs.get('bbox_y', 100) self.bbox_width = kwargs.get('bbox_width', 200) self.bbox_height = kwargs.get('bbox_height', 50) self.x_center = kwargs.get('x_center', 0.5) self.y_center = kwargs.get('y_center', 0.5) self.width = kwargs.get('width', 0.3) self.height = kwargs.get('height', 0.1) self.text_value = kwargs.get('text_value', 'INV-001') self.confidence = kwargs.get('confidence', 0.95) self.source = kwargs.get('source', 'manual') self.is_verified = kwargs.get('is_verified', False) self.verified_at = kwargs.get('verified_at', None) self.verified_by = kwargs.get('verified_by', None) self.override_source = kwargs.get('override_source', None) self.original_annotation_id = kwargs.get('original_annotation_id', None) self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) class MockModelVersion: """Mock ModelVersion for testing.""" def __init__(self, **kwargs): self.version_id = kwargs.get('version_id', uuid4()) self.version = kwargs.get('version', '1.0.0') self.name = kwargs.get('name', 'Test Model') self.description = kwargs.get('description', None) self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt') self.status = kwargs.get('status', 'inactive') self.is_active = kwargs.get('is_active', False) self.task_id = kwargs.get('task_id', None) self.dataset_id = kwargs.get('dataset_id', None) self.metrics_mAP = kwargs.get('metrics_mAP', 0.935) self.metrics_precision = kwargs.get('metrics_precision', 0.92) self.metrics_recall = kwargs.get('metrics_recall', 0.88) self.document_count = kwargs.get('document_count', 100) self.training_config = kwargs.get('training_config', {}) self.file_size = kwargs.get('file_size', 52428800) self.trained_at = kwargs.get('trained_at', datetime.utcnow()) self.activated_at = kwargs.get('activated_at', None) self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) class MockAdminDB: """Mock AdminDB for testing Phase 4.""" def __init__(self): self.documents = {} self.annotations = {} self.training_tasks = {} self.training_links = {} self.model_versions = {} def get_documents_for_training( self, admin_token, status="labeled", has_annotations=True, min_annotation_count=None, exclude_used_in_training=False, limit=100, offset=0, ): """Get documents for training.""" # Filter documents by criteria filtered = [] for doc in self.documents.values(): if doc.admin_token != admin_token or doc.status != status: continue # Check annotations annotations = self.annotations.get(str(doc.document_id), []) if has_annotations and len(annotations) == 0: continue if min_annotation_count and len(annotations) < min_annotation_count: continue # Check if used in training if exclude_used_in_training: links = self.training_links.get(str(doc.document_id), []) if links: continue filtered.append(doc) total = len(filtered) return filtered[offset:offset+limit], total def get_annotations_for_document(self, document_id): """Get annotations for document.""" return self.annotations.get(str(document_id), []) def get_document_training_tasks(self, document_id): """Get training tasks that used this document.""" return self.training_links.get(str(document_id), []) def get_training_tasks_by_token( self, admin_token, status=None, limit=20, offset=0, ): """Get training tasks filtered by token.""" tasks = [t for t in self.training_tasks.values() if t.admin_token == admin_token] if status: tasks = [t for t in tasks if t.status == status] total = len(tasks) return tasks[offset:offset+limit], total def get_training_task(self, task_id): """Get training task by ID.""" return self.training_tasks.get(str(task_id)) def get_model_versions(self, status=None, limit=20, offset=0): """Get model versions with optional filtering.""" models = list(self.model_versions.values()) if status: models = [m for m in models if m.status == status] total = len(models) return models[offset:offset+limit], total @pytest.fixture def app(): """Create test FastAPI app.""" app = FastAPI() # Create mock DB mock_db = MockAdminDB() # Add test documents doc1 = MockAdminDocument( filename="DOC001.pdf", status="labeled", ) doc2 = MockAdminDocument( filename="DOC002.pdf", status="labeled", ) doc3 = MockAdminDocument( filename="DOC003.pdf", status="labeled", ) mock_db.documents[str(doc1.document_id)] = doc1 mock_db.documents[str(doc2.document_id)] = doc2 mock_db.documents[str(doc3.document_id)] = doc3 # Add annotations mock_db.annotations[str(doc1.document_id)] = [ MockAnnotation(document_id=doc1.document_id, source="manual"), MockAnnotation(document_id=doc1.document_id, source="auto"), ] mock_db.annotations[str(doc2.document_id)] = [ MockAnnotation(document_id=doc2.document_id, source="auto"), MockAnnotation(document_id=doc2.document_id, source="auto"), MockAnnotation(document_id=doc2.document_id, source="auto"), ] # doc3 has no annotations # Add training tasks task1 = MockTrainingTask( name="Training Run 2024-01", status="completed", document_count=500, metrics_mAP=0.935, metrics_precision=0.92, metrics_recall=0.88, ) task2 = MockTrainingTask( name="Training Run 2024-02", status="completed", document_count=600, metrics_mAP=0.951, metrics_precision=0.94, metrics_recall=0.92, ) mock_db.training_tasks[str(task1.task_id)] = task1 mock_db.training_tasks[str(task2.task_id)] = task2 # Add training links (doc1 used in task1) link1 = MockTrainingDocumentLink( task_id=task1.task_id, document_id=doc1.document_id, ) mock_db.training_links[str(doc1.document_id)] = [link1] # Add model versions model1 = MockModelVersion( version="1.0.0", name="Model v1.0.0", status="inactive", is_active=False, metrics_mAP=0.935, metrics_precision=0.92, metrics_recall=0.88, document_count=500, ) model2 = MockModelVersion( version="1.1.0", name="Model v1.1.0", status="active", is_active=True, metrics_mAP=0.951, metrics_precision=0.94, metrics_recall=0.92, document_count=600, ) mock_db.model_versions[str(model1.version_id)] = model1 mock_db.model_versions[str(model2.version_id)] = model2 # Override dependencies app.dependency_overrides[validate_admin_token] = lambda: "test-token" app.dependency_overrides[get_admin_db] = lambda: mock_db # Include router router = create_training_router() app.include_router(router) return app @pytest.fixture def client(app): """Create test client.""" return TestClient(app) class TestTrainingDocuments: """Tests for GET /admin/training/documents endpoint.""" def test_get_training_documents_success(self, client): """Test getting documents for training.""" response = client.get("/admin/training/documents") assert response.status_code == 200 data = response.json() assert "total" in data assert "documents" in data assert data["total"] >= 0 assert isinstance(data["documents"], list) def test_get_training_documents_with_annotations(self, client): """Test filtering documents with annotations.""" response = client.get("/admin/training/documents?has_annotations=true") assert response.status_code == 200 data = response.json() # Should return doc1 and doc2 (both have annotations) assert data["total"] == 2 def test_get_training_documents_min_annotation_count(self, client): """Test filtering by minimum annotation count.""" response = client.get("/admin/training/documents?min_annotation_count=3") assert response.status_code == 200 data = response.json() # Should return only doc2 (has 3 annotations) assert data["total"] == 1 def test_get_training_documents_exclude_used(self, client): """Test excluding documents already used in training.""" response = client.get("/admin/training/documents?exclude_used_in_training=true") assert response.status_code == 200 data = response.json() # Should exclude doc1 (used in training) assert data["total"] == 1 # Only doc2 (doc3 has no annotations) def test_get_training_documents_annotation_sources(self, client): """Test that annotation sources are included.""" response = client.get("/admin/training/documents?has_annotations=true") assert response.status_code == 200 data = response.json() # Check that documents have annotation_sources field for doc in data["documents"]: assert "annotation_sources" in doc assert isinstance(doc["annotation_sources"], dict) assert "manual" in doc["annotation_sources"] assert "auto" in doc["annotation_sources"] def test_get_training_documents_pagination(self, client): """Test pagination parameters.""" response = client.get("/admin/training/documents?limit=1&offset=0") assert response.status_code == 200 data = response.json() assert data["limit"] == 1 assert data["offset"] == 0 assert len(data["documents"]) <= 1 class TestTrainingModels: """Tests for GET /admin/training/models endpoint (ModelVersionListResponse).""" def test_get_training_models_success(self, client): """Test getting model versions list.""" response = client.get("/admin/training/models") assert response.status_code == 200 data = response.json() assert "total" in data assert "models" in data assert data["total"] == 2 assert len(data["models"]) == 2 def test_get_training_models_includes_metrics(self, client): """Test that model versions include metrics.""" response = client.get("/admin/training/models") assert response.status_code == 200 data = response.json() # Check first model has metrics fields model = data["models"][0] assert "metrics_mAP" in model assert model["metrics_mAP"] is not None def test_get_training_models_includes_version_fields(self, client): """Test that model versions include version fields.""" response = client.get("/admin/training/models") assert response.status_code == 200 data = response.json() # Check model has expected fields model = data["models"][0] assert "version_id" in model assert "version" in model assert "name" in model assert "status" in model assert "is_active" in model assert "document_count" in model def test_get_training_models_filter_by_status(self, client): """Test filtering model versions by status.""" response = client.get("/admin/training/models?status=active") assert response.status_code == 200 data = response.json() assert data["total"] == 1 # All returned models should be active for model in data["models"]: assert model["status"] == "active" def test_get_training_models_pagination(self, client): """Test pagination for model versions.""" response = client.get("/admin/training/models?limit=1&offset=0") assert response.status_code == 200 data = response.json() assert data["limit"] == 1 assert data["offset"] == 0 assert len(data["models"]) == 1