This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -0,0 +1,399 @@
"""
Tests for Model Version API routes.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
)
TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_TOKEN = "test-admin-token-12345"
def _make_model_version(**overrides) -> MagicMock:
"""Create a mock ModelVersion."""
defaults = dict(
version_id=UUID(TEST_VERSION_UUID),
version="1.0.0",
name="test-model-v1",
description="Test model version",
model_path="/models/test-model-v1.pt",
status="inactive",
is_active=False,
task_id=UUID(TEST_TASK_UUID),
dataset_id=UUID(TEST_DATASET_UUID),
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=100,
training_config={"epochs": 100, "batch_size": 16},
file_size=52428800,
trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
activated_at=None,
created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
)
defaults.update(overrides)
model = MagicMock(spec=ModelVersion)
for k, v in defaults.items():
setattr(model, k, v)
return model
def _find_endpoint(name: str):
"""Find endpoint function by name."""
router = create_training_router()
for route in router.routes:
if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
return route.endpoint
raise AssertionError(f"Endpoint {name} not found")
class TestModelVersionRouterRegistration:
"""Tests that model version endpoints are registered."""
def test_router_has_model_endpoints(self):
router = create_training_router()
paths = [route.path for route in router.routes]
assert any("models" in p for p in paths)
def test_has_create_model_version_endpoint(self):
endpoint = _find_endpoint("create_model_version")
assert endpoint is not None
def test_has_list_model_versions_endpoint(self):
endpoint = _find_endpoint("list_model_versions")
assert endpoint is not None
def test_has_get_active_model_endpoint(self):
endpoint = _find_endpoint("get_active_model")
assert endpoint is not None
def test_has_activate_model_version_endpoint(self):
endpoint = _find_endpoint("activate_model_version")
assert endpoint is not None
class TestCreateModelVersionRoute:
"""Tests for POST /admin/training/models."""
def test_create_model_version(self):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model-v1",
model_path="/models/test-model-v1.pt",
description="Test model",
metrics_mAP=0.935,
document_count=100,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.create_model_version.assert_called_once()
assert result.version_id == TEST_VERSION_UUID
assert result.status == "inactive"
assert result.message == "Model version created successfully"
def test_create_model_version_with_task_and_dataset(self):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model-v1",
model_path="/models/test-model-v1.pt",
task_id=TEST_TASK_UUID,
dataset_id=TEST_DATASET_UUID,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
call_kwargs = mock_db.create_model_version.call_args[1]
assert call_kwargs["task_id"] == TEST_TASK_UUID
assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
class TestListModelVersionsRoute:
"""Tests for GET /admin/training/models."""
def test_list_model_versions(self):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = (
[_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
2,
)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
assert result.total == 2
assert len(result.models) == 2
assert result.models[0].version == "1.0.0"
def test_list_model_versions_with_status_filter(self):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
assert result.total == 1
assert result.models[0].status == "active"
class TestGetActiveModelRoute:
"""Tests for GET /admin/training/models/active."""
def test_get_active_model_when_exists(self):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
assert result.has_active_model is True
assert result.model is not None
assert result.model.is_active is True
def test_get_active_model_when_none(self):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = None
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
assert result.has_active_model is False
assert result.model is None
class TestGetModelVersionRoute:
"""Tests for GET /admin/training/models/{version_id}."""
def test_get_model_version(self):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = _make_model_version()
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.version_id == TEST_VERSION_UUID
assert result.version == "1.0.0"
assert result.name == "test-model-v1"
assert result.metrics_mAP == 0.935
def test_get_model_version_not_found(self):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestUpdateModelVersionRoute:
"""Tests for PATCH /admin/training/models/{version_id}."""
def test_update_model_version(self):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.update_model_version.assert_called_once_with(
version_id=TEST_VERSION_UUID,
name="updated-name",
description="Updated description",
status=None,
)
assert result.message == "Model version updated successfully"
def test_update_model_version_not_found(self):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = None
request = ModelVersionUpdateRequest(name="updated-name")
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestActivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/activate."""
def test_activate_model_version(self):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result.status == "active"
assert result.message == "Model version activated for inference"
def test_activate_model_version_not_found(self):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestDeactivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/deactivate."""
def test_deactivate_model_version(self):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.status == "inactive"
assert result.message == "Model version deactivated"
def test_deactivate_model_version_not_found(self):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestArchiveModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/archive."""
def test_archive_model_version(self):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = _make_model_version(status="archived")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.status == "archived"
assert result.message == "Model version archived"
def test_archive_active_model_fails(self):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
class TestDeleteModelVersionRoute:
"""Tests for DELETE /admin/training/models/{version_id}."""
def test_delete_model_version(self):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = True
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result["message"] == "Model version deleted"
def test_delete_active_model_fails(self):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = False
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
class TestModelVersionSchemas:
"""Tests for model version Pydantic schemas."""
def test_create_request_validation(self):
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model",
model_path="/models/test.pt",
)
assert request.version == "1.0.0"
assert request.name == "test-model"
assert request.document_count == 0
def test_create_request_with_metrics(self):
request = ModelVersionCreateRequest(
version="2.0.0",
name="test-model-v2",
model_path="/models/v2.pt",
metrics_mAP=0.95,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
)
assert request.metrics_mAP == 0.95
assert request.document_count == 500
def test_update_request_partial(self):
request = ModelVersionUpdateRequest(name="new-name")
assert request.name == "new-name"
assert request.description is None
assert request.status is None