WIP
This commit is contained in:
@@ -103,6 +103,31 @@ class MockAnnotation:
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockModelVersion:
|
||||
"""Mock ModelVersion for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.version_id = kwargs.get('version_id', uuid4())
|
||||
self.version = kwargs.get('version', '1.0.0')
|
||||
self.name = kwargs.get('name', 'Test Model')
|
||||
self.description = kwargs.get('description', None)
|
||||
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
|
||||
self.status = kwargs.get('status', 'inactive')
|
||||
self.is_active = kwargs.get('is_active', False)
|
||||
self.task_id = kwargs.get('task_id', None)
|
||||
self.dataset_id = kwargs.get('dataset_id', None)
|
||||
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
|
||||
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
|
||||
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
|
||||
self.document_count = kwargs.get('document_count', 100)
|
||||
self.training_config = kwargs.get('training_config', {})
|
||||
self.file_size = kwargs.get('file_size', 52428800)
|
||||
self.trained_at = kwargs.get('trained_at', datetime.utcnow())
|
||||
self.activated_at = kwargs.get('activated_at', None)
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing Phase 4."""
|
||||
|
||||
@@ -111,6 +136,7 @@ class MockAdminDB:
|
||||
self.annotations = {}
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
self.model_versions = {}
|
||||
|
||||
def get_documents_for_training(
|
||||
self,
|
||||
@@ -174,6 +200,14 @@ class MockAdminDB:
|
||||
"""Get training task by ID."""
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
def get_model_versions(self, status=None, limit=20, offset=0):
|
||||
"""Get model versions with optional filtering."""
|
||||
models = list(self.model_versions.values())
|
||||
if status:
|
||||
models = [m for m in models if m.status == status]
|
||||
total = len(models)
|
||||
return models[offset:offset+limit], total
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
@@ -241,6 +275,30 @@ def app():
|
||||
)
|
||||
mock_db.training_links[str(doc1.document_id)] = [link1]
|
||||
|
||||
# Add model versions
|
||||
model1 = MockModelVersion(
|
||||
version="1.0.0",
|
||||
name="Model v1.0.0",
|
||||
status="inactive",
|
||||
is_active=False,
|
||||
metrics_mAP=0.935,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
document_count=500,
|
||||
)
|
||||
model2 = MockModelVersion(
|
||||
version="1.1.0",
|
||||
name="Model v1.1.0",
|
||||
status="active",
|
||||
is_active=True,
|
||||
metrics_mAP=0.951,
|
||||
metrics_precision=0.94,
|
||||
metrics_recall=0.92,
|
||||
document_count=600,
|
||||
)
|
||||
mock_db.model_versions[str(model1.version_id)] = model1
|
||||
mock_db.model_versions[str(model2.version_id)] = model2
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
|
||||
|
||||
|
||||
class TestTrainingModels:
|
||||
"""Tests for GET /admin/training/models endpoint."""
|
||||
"""Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
|
||||
|
||||
def test_get_training_models_success(self, client):
|
||||
"""Test getting trained models list."""
|
||||
"""Test getting model versions list."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -338,43 +396,44 @@ class TestTrainingModels:
|
||||
assert len(data["models"]) == 2
|
||||
|
||||
def test_get_training_models_includes_metrics(self, client):
|
||||
"""Test that models include metrics."""
|
||||
"""Test that model versions include metrics."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check first model has metrics
|
||||
# Check first model has metrics fields
|
||||
model = data["models"][0]
|
||||
assert "metrics" in model
|
||||
assert "mAP" in model["metrics"]
|
||||
assert model["metrics"]["mAP"] is not None
|
||||
assert "precision" in model["metrics"]
|
||||
assert "recall" in model["metrics"]
|
||||
assert "metrics_mAP" in model
|
||||
assert model["metrics_mAP"] is not None
|
||||
|
||||
def test_get_training_models_includes_download_url(self, client):
|
||||
"""Test that completed models have download URLs."""
|
||||
def test_get_training_models_includes_version_fields(self, client):
|
||||
"""Test that model versions include version fields."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check completed models have download URLs
|
||||
for model in data["models"]:
|
||||
if model["status"] == "completed":
|
||||
assert "download_url" in model
|
||||
assert model["download_url"] is not None
|
||||
# Check model has expected fields
|
||||
model = data["models"][0]
|
||||
assert "version_id" in model
|
||||
assert "version" in model
|
||||
assert "name" in model
|
||||
assert "status" in model
|
||||
assert "is_active" in model
|
||||
assert "document_count" in model
|
||||
|
||||
def test_get_training_models_filter_by_status(self, client):
|
||||
"""Test filtering models by status."""
|
||||
response = client.get("/admin/training/models?status=completed")
|
||||
"""Test filtering model versions by status."""
|
||||
response = client.get("/admin/training/models?status=active")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# All returned models should be completed
|
||||
assert data["total"] == 1
|
||||
# All returned models should be active
|
||||
for model in data["models"]:
|
||||
assert model["status"] == "completed"
|
||||
assert model["status"] == "active"
|
||||
|
||||
def test_get_training_models_pagination(self, client):
|
||||
"""Test pagination for models."""
|
||||
"""Test pagination for model versions."""
|
||||
response = client.get("/admin/training/models?limit=1&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user