This commit is contained in:
Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -103,6 +103,31 @@ class MockAnnotation:
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockModelVersion:
"""Mock ModelVersion for testing."""
def __init__(self, **kwargs):
self.version_id = kwargs.get('version_id', uuid4())
self.version = kwargs.get('version', '1.0.0')
self.name = kwargs.get('name', 'Test Model')
self.description = kwargs.get('description', None)
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
self.status = kwargs.get('status', 'inactive')
self.is_active = kwargs.get('is_active', False)
self.task_id = kwargs.get('task_id', None)
self.dataset_id = kwargs.get('dataset_id', None)
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
self.document_count = kwargs.get('document_count', 100)
self.training_config = kwargs.get('training_config', {})
self.file_size = kwargs.get('file_size', 52428800)
self.trained_at = kwargs.get('trained_at', datetime.utcnow())
self.activated_at = kwargs.get('activated_at', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 4."""
@@ -111,6 +136,7 @@ class MockAdminDB:
self.annotations = {}
self.training_tasks = {}
self.training_links = {}
self.model_versions = {}
def get_documents_for_training(
self,
@@ -174,6 +200,14 @@ class MockAdminDB:
"""Get training task by ID."""
return self.training_tasks.get(str(task_id))
def get_model_versions(self, status=None, limit=20, offset=0):
"""Get model versions with optional filtering."""
models = list(self.model_versions.values())
if status:
models = [m for m in models if m.status == status]
total = len(models)
return models[offset:offset+limit], total
@pytest.fixture
def app():
@@ -241,6 +275,30 @@ def app():
)
mock_db.training_links[str(doc1.document_id)] = [link1]
# Add model versions
model1 = MockModelVersion(
version="1.0.0",
name="Model v1.0.0",
status="inactive",
is_active=False,
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
)
model2 = MockModelVersion(
version="1.1.0",
name="Model v1.1.0",
status="active",
is_active=True,
metrics_mAP=0.951,
metrics_precision=0.94,
metrics_recall=0.92,
document_count=600,
)
mock_db.model_versions[str(model1.version_id)] = model1
mock_db.model_versions[str(model2.version_id)] = model2
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
class TestTrainingModels:
"""Tests for GET /admin/training/models endpoint."""
"""Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
def test_get_training_models_success(self, client):
"""Test getting trained models list."""
"""Test getting model versions list."""
response = client.get("/admin/training/models")
assert response.status_code == 200
@@ -338,43 +396,44 @@ class TestTrainingModels:
assert len(data["models"]) == 2
def test_get_training_models_includes_metrics(self, client):
"""Test that models include metrics."""
"""Test that model versions include metrics."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check first model has metrics
# Check first model has metrics fields
model = data["models"][0]
assert "metrics" in model
assert "mAP" in model["metrics"]
assert model["metrics"]["mAP"] is not None
assert "precision" in model["metrics"]
assert "recall" in model["metrics"]
assert "metrics_mAP" in model
assert model["metrics_mAP"] is not None
def test_get_training_models_includes_download_url(self, client):
"""Test that completed models have download URLs."""
def test_get_training_models_includes_version_fields(self, client):
"""Test that model versions include version fields."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check completed models have download URLs
for model in data["models"]:
if model["status"] == "completed":
assert "download_url" in model
assert model["download_url"] is not None
# Check model has expected fields
model = data["models"][0]
assert "version_id" in model
assert "version" in model
assert "name" in model
assert "status" in model
assert "is_active" in model
assert "document_count" in model
def test_get_training_models_filter_by_status(self, client):
"""Test filtering models by status."""
response = client.get("/admin/training/models?status=completed")
"""Test filtering model versions by status."""
response = client.get("/admin/training/models?status=active")
assert response.status_code == 200
data = response.json()
# All returned models should be completed
assert data["total"] == 1
# All returned models should be active
for model in data["models"]:
assert model["status"] == "completed"
assert model["status"] == "active"
def test_get_training_models_pagination(self, client):
"""Test pagination for models."""
"""Test pagination for model versions."""
response = client.get("/admin/training/models?limit=1&offset=0")
assert response.status_code == 200