WIP
This commit is contained in:
384
tests/web/test_training_phase4.py
Normal file
384
tests/web/test_training_phase4.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Tests for Phase 4: Training Data Management
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.web.api.v1.admin.training import create_training_router
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db
|
||||
|
||||
|
||||
class MockTrainingTask:
|
||||
"""Mock TrainingTask for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.task_id = kwargs.get('task_id', uuid4())
|
||||
self.admin_token = kwargs.get('admin_token', 'test-token')
|
||||
self.name = kwargs.get('name', 'Test Training')
|
||||
self.description = kwargs.get('description', None)
|
||||
self.status = kwargs.get('status', 'completed')
|
||||
self.task_type = kwargs.get('task_type', 'train')
|
||||
self.config = kwargs.get('config', {})
|
||||
self.scheduled_at = kwargs.get('scheduled_at', None)
|
||||
self.cron_expression = kwargs.get('cron_expression', None)
|
||||
self.is_recurring = kwargs.get('is_recurring', False)
|
||||
self.started_at = kwargs.get('started_at', datetime.utcnow())
|
||||
self.completed_at = kwargs.get('completed_at', datetime.utcnow())
|
||||
self.error_message = kwargs.get('error_message', None)
|
||||
self.result_metrics = kwargs.get('result_metrics', {})
|
||||
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
|
||||
self.document_count = kwargs.get('document_count', 0)
|
||||
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
|
||||
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
|
||||
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockTrainingDocumentLink:
|
||||
"""Mock TrainingDocumentLink for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.link_id = kwargs.get('link_id', uuid4())
|
||||
self.task_id = kwargs.get('task_id')
|
||||
self.document_id = kwargs.get('document_id')
|
||||
self.annotation_snapshot = kwargs.get('annotation_snapshot', None)
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
"""Mock AdminDocument for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.document_id = kwargs.get('document_id', uuid4())
|
||||
self.admin_token = kwargs.get('admin_token', 'test-token')
|
||||
self.filename = kwargs.get('filename', 'test.pdf')
|
||||
self.file_size = kwargs.get('file_size', 100000)
|
||||
self.content_type = kwargs.get('content_type', 'application/pdf')
|
||||
self.file_path = kwargs.get('file_path', 'data/admin_docs/test.pdf')
|
||||
self.page_count = kwargs.get('page_count', 1)
|
||||
self.status = kwargs.get('status', 'labeled')
|
||||
self.auto_label_status = kwargs.get('auto_label_status', None)
|
||||
self.auto_label_error = kwargs.get('auto_label_error', None)
|
||||
self.upload_source = kwargs.get('upload_source', 'ui')
|
||||
self.batch_id = kwargs.get('batch_id', None)
|
||||
self.csv_field_values = kwargs.get('csv_field_values', None)
|
||||
self.auto_label_queued_at = kwargs.get('auto_label_queued_at', None)
|
||||
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAnnotation:
|
||||
"""Mock AdminAnnotation for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.annotation_id = kwargs.get('annotation_id', uuid4())
|
||||
self.document_id = kwargs.get('document_id')
|
||||
self.page_number = kwargs.get('page_number', 1)
|
||||
self.class_id = kwargs.get('class_id', 0)
|
||||
self.class_name = kwargs.get('class_name', 'invoice_number')
|
||||
self.bbox_x = kwargs.get('bbox_x', 100)
|
||||
self.bbox_y = kwargs.get('bbox_y', 100)
|
||||
self.bbox_width = kwargs.get('bbox_width', 200)
|
||||
self.bbox_height = kwargs.get('bbox_height', 50)
|
||||
self.x_center = kwargs.get('x_center', 0.5)
|
||||
self.y_center = kwargs.get('y_center', 0.5)
|
||||
self.width = kwargs.get('width', 0.3)
|
||||
self.height = kwargs.get('height', 0.1)
|
||||
self.text_value = kwargs.get('text_value', 'INV-001')
|
||||
self.confidence = kwargs.get('confidence', 0.95)
|
||||
self.source = kwargs.get('source', 'manual')
|
||||
self.is_verified = kwargs.get('is_verified', False)
|
||||
self.verified_at = kwargs.get('verified_at', None)
|
||||
self.verified_by = kwargs.get('verified_by', None)
|
||||
self.override_source = kwargs.get('override_source', None)
|
||||
self.original_annotation_id = kwargs.get('original_annotation_id', None)
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = {}
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
|
||||
def get_documents_for_training(
|
||||
self,
|
||||
admin_token,
|
||||
status="labeled",
|
||||
has_annotations=True,
|
||||
min_annotation_count=None,
|
||||
exclude_used_in_training=False,
|
||||
limit=100,
|
||||
offset=0,
|
||||
):
|
||||
"""Get documents for training."""
|
||||
# Filter documents by criteria
|
||||
filtered = []
|
||||
for doc in self.documents.values():
|
||||
if doc.admin_token != admin_token or doc.status != status:
|
||||
continue
|
||||
|
||||
# Check annotations
|
||||
annotations = self.annotations.get(str(doc.document_id), [])
|
||||
if has_annotations and len(annotations) == 0:
|
||||
continue
|
||||
if min_annotation_count and len(annotations) < min_annotation_count:
|
||||
continue
|
||||
|
||||
# Check if used in training
|
||||
if exclude_used_in_training:
|
||||
links = self.training_links.get(str(doc.document_id), [])
|
||||
if links:
|
||||
continue
|
||||
|
||||
filtered.append(doc)
|
||||
|
||||
total = len(filtered)
|
||||
return filtered[offset:offset+limit], total
|
||||
|
||||
def get_annotations_for_document(self, document_id):
|
||||
"""Get annotations for document."""
|
||||
return self.annotations.get(str(document_id), [])
|
||||
|
||||
def get_document_training_tasks(self, document_id):
|
||||
"""Get training tasks that used this document."""
|
||||
return self.training_links.get(str(document_id), [])
|
||||
|
||||
def get_training_tasks_by_token(
|
||||
self,
|
||||
admin_token,
|
||||
status=None,
|
||||
limit=20,
|
||||
offset=0,
|
||||
):
|
||||
"""Get training tasks filtered by token."""
|
||||
tasks = [t for t in self.training_tasks.values() if t.admin_token == admin_token]
|
||||
if status:
|
||||
tasks = [t for t in tasks if t.status == status]
|
||||
|
||||
total = len(tasks)
|
||||
return tasks[offset:offset+limit], total
|
||||
|
||||
def get_training_task(self, task_id):
|
||||
"""Get training task by ID."""
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create test FastAPI app."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock DB
|
||||
mock_db = MockAdminDB()
|
||||
|
||||
# Add test documents
|
||||
doc1 = MockAdminDocument(
|
||||
filename="DOC001.pdf",
|
||||
status="labeled",
|
||||
)
|
||||
doc2 = MockAdminDocument(
|
||||
filename="DOC002.pdf",
|
||||
status="labeled",
|
||||
)
|
||||
doc3 = MockAdminDocument(
|
||||
filename="DOC003.pdf",
|
||||
status="labeled",
|
||||
)
|
||||
|
||||
mock_db.documents[str(doc1.document_id)] = doc1
|
||||
mock_db.documents[str(doc2.document_id)] = doc2
|
||||
mock_db.documents[str(doc3.document_id)] = doc3
|
||||
|
||||
# Add annotations
|
||||
mock_db.annotations[str(doc1.document_id)] = [
|
||||
MockAnnotation(document_id=doc1.document_id, source="manual"),
|
||||
MockAnnotation(document_id=doc1.document_id, source="auto"),
|
||||
]
|
||||
mock_db.annotations[str(doc2.document_id)] = [
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
]
|
||||
# doc3 has no annotations
|
||||
|
||||
# Add training tasks
|
||||
task1 = MockTrainingTask(
|
||||
name="Training Run 2024-01",
|
||||
status="completed",
|
||||
document_count=500,
|
||||
metrics_mAP=0.935,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
)
|
||||
task2 = MockTrainingTask(
|
||||
name="Training Run 2024-02",
|
||||
status="completed",
|
||||
document_count=600,
|
||||
metrics_mAP=0.951,
|
||||
metrics_precision=0.94,
|
||||
metrics_recall=0.92,
|
||||
)
|
||||
|
||||
mock_db.training_tasks[str(task1.task_id)] = task1
|
||||
mock_db.training_tasks[str(task2.task_id)] = task2
|
||||
|
||||
# Add training links (doc1 used in task1)
|
||||
link1 = MockTrainingDocumentLink(
|
||||
task_id=task1.task_id,
|
||||
document_id=doc1.document_id,
|
||||
)
|
||||
mock_db.training_links[str(doc1.document_id)] = [link1]
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
|
||||
# Include router
|
||||
router = create_training_router()
|
||||
app.include_router(router)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestTrainingDocuments:
|
||||
"""Tests for GET /admin/training/documents endpoint."""
|
||||
|
||||
def test_get_training_documents_success(self, client):
|
||||
"""Test getting documents for training."""
|
||||
response = client.get("/admin/training/documents")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "documents" in data
|
||||
assert data["total"] >= 0
|
||||
assert isinstance(data["documents"], list)
|
||||
|
||||
def test_get_training_documents_with_annotations(self, client):
|
||||
"""Test filtering documents with annotations."""
|
||||
response = client.get("/admin/training/documents?has_annotations=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should return doc1 and doc2 (both have annotations)
|
||||
assert data["total"] == 2
|
||||
|
||||
def test_get_training_documents_min_annotation_count(self, client):
|
||||
"""Test filtering by minimum annotation count."""
|
||||
response = client.get("/admin/training/documents?min_annotation_count=3")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should return only doc2 (has 3 annotations)
|
||||
assert data["total"] == 1
|
||||
|
||||
def test_get_training_documents_exclude_used(self, client):
|
||||
"""Test excluding documents already used in training."""
|
||||
response = client.get("/admin/training/documents?exclude_used_in_training=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should exclude doc1 (used in training)
|
||||
assert data["total"] == 1 # Only doc2 (doc3 has no annotations)
|
||||
|
||||
def test_get_training_documents_annotation_sources(self, client):
|
||||
"""Test that annotation sources are included."""
|
||||
response = client.get("/admin/training/documents?has_annotations=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check that documents have annotation_sources field
|
||||
for doc in data["documents"]:
|
||||
assert "annotation_sources" in doc
|
||||
assert isinstance(doc["annotation_sources"], dict)
|
||||
assert "manual" in doc["annotation_sources"]
|
||||
assert "auto" in doc["annotation_sources"]
|
||||
|
||||
def test_get_training_documents_pagination(self, client):
|
||||
"""Test pagination parameters."""
|
||||
response = client.get("/admin/training/documents?limit=1&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["limit"] == 1
|
||||
assert data["offset"] == 0
|
||||
assert len(data["documents"]) <= 1
|
||||
|
||||
|
||||
class TestTrainingModels:
|
||||
"""Tests for GET /admin/training/models endpoint."""
|
||||
|
||||
def test_get_training_models_success(self, client):
|
||||
"""Test getting trained models list."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "models" in data
|
||||
assert data["total"] == 2
|
||||
assert len(data["models"]) == 2
|
||||
|
||||
def test_get_training_models_includes_metrics(self, client):
|
||||
"""Test that models include metrics."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check first model has metrics
|
||||
model = data["models"][0]
|
||||
assert "metrics" in model
|
||||
assert "mAP" in model["metrics"]
|
||||
assert model["metrics"]["mAP"] is not None
|
||||
assert "precision" in model["metrics"]
|
||||
assert "recall" in model["metrics"]
|
||||
|
||||
def test_get_training_models_includes_download_url(self, client):
|
||||
"""Test that completed models have download URLs."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check completed models have download URLs
|
||||
for model in data["models"]:
|
||||
if model["status"] == "completed":
|
||||
assert "download_url" in model
|
||||
assert model["download_url"] is not None
|
||||
|
||||
def test_get_training_models_filter_by_status(self, client):
|
||||
"""Test filtering models by status."""
|
||||
response = client.get("/admin/training/models?status=completed")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# All returned models should be completed
|
||||
for model in data["models"]:
|
||||
assert model["status"] == "completed"
|
||||
|
||||
def test_get_training_models_pagination(self, client):
|
||||
"""Test pagination for models."""
|
||||
response = client.get("/admin/training/models?limit=1&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["limit"] == 1
|
||||
assert data["offset"] == 0
|
||||
assert len(data["models"]) == 1
|
||||
Reference in New Issue
Block a user