"""Tests for Model Version API routes.

Route handlers are looked up by function name on the router returned by
``create_training_router`` and awaited directly with a ``MagicMock`` standing
in for the ``db`` dependency, so no HTTP client or database is involved.
"""

import asyncio
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID

import pytest
from fastapi import HTTPException

from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
    ModelVersionCreateRequest,
    ModelVersionUpdateRequest,
)


TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_TOKEN = "test-admin-token-12345"


def _make_model_version(**overrides) -> MagicMock:
    """Create a mock ModelVersion.

    Every attribute below is set on a ``MagicMock(spec=ModelVersion)``;
    keyword ``overrides`` replace the corresponding default values.
    """
    defaults = dict(
        version_id=UUID(TEST_VERSION_UUID),
        version="1.0.0",
        name="test-model-v1",
        description="Test model version",
        model_path="/models/test-model-v1.pt",
        status="inactive",
        is_active=False,
        task_id=UUID(TEST_TASK_UUID),
        dataset_id=UUID(TEST_DATASET_UUID),
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
        document_count=100,
        training_config={"epochs": 100, "batch_size": 16},
        file_size=52428800,
        trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
        activated_at=None,
        created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
        updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
    )
    defaults.update(overrides)
    model = MagicMock(spec=ModelVersion)
    for k, v in defaults.items():
        setattr(model, k, v)
    return model


def _find_endpoint(name: str):
    """Return the route endpoint function whose ``__name__`` is *name*.

    Raises:
        AssertionError: if no route on the training router has that endpoint.
    """
    router = create_training_router()
    for route in router.routes:
        # Only some route types carry an ``endpoint`` attribute.
        if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
            return route.endpoint
    raise AssertionError(f"Endpoint {name} not found")


class TestModelVersionRouterRegistration:
    """Tests that model version endpoints are registered."""

    def test_router_has_model_endpoints(self):
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("models" in p for p in paths)

    # _find_endpoint raises when the endpoint is missing, so these tests check
    # that the returned object is a usable handler rather than "not None".
    def test_has_create_model_version_endpoint(self):
        endpoint = _find_endpoint("create_model_version")
        assert callable(endpoint)

    def test_has_list_model_versions_endpoint(self):
        endpoint = _find_endpoint("list_model_versions")
        assert callable(endpoint)

    def test_has_get_active_model_endpoint(self):
        endpoint = _find_endpoint("get_active_model")
        assert callable(endpoint)

    def test_has_activate_model_version_endpoint(self):
        endpoint = _find_endpoint("activate_model_version")
        assert callable(endpoint)


class TestCreateModelVersionRoute:
    """Tests for POST /admin/training/models."""

    def test_create_model_version(self):
        fn = _find_endpoint("create_model_version")
        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            description="Test model",
            metrics_mAP=0.935,
            document_count=100,
        )

        result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.create_model_version.assert_called_once()
        assert result.version_id == TEST_VERSION_UUID
        assert result.status == "inactive"
        assert result.message == "Model version created successfully"

    def test_create_model_version_with_task_and_dataset(self):
        fn = _find_endpoint("create_model_version")
        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            task_id=TEST_TASK_UUID,
            dataset_id=TEST_DATASET_UUID,
        )

        result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        # The optional task/dataset ids must be forwarded to the db layer.
        call_kwargs = mock_db.create_model_version.call_args[1]
        assert call_kwargs["task_id"] == TEST_TASK_UUID
        assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
        assert result.version_id == TEST_VERSION_UUID


class TestListModelVersionsRoute:
    """Tests for GET /admin/training/models."""

    def test_list_model_versions(self):
        fn = _find_endpoint("list_model_versions")
        mock_db = MagicMock()
        # get_model_versions returns a (rows, total_count) pair.
        mock_db.get_model_versions.return_value = (
            [
                _make_model_version(),
                _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0"),
            ],
            2,
        )

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))

        assert result.total == 2
        assert len(result.models) == 2
        assert result.models[0].version == "1.0.0"

    def test_list_model_versions_with_status_filter(self):
        fn = _find_endpoint("list_model_versions")
        mock_db = MagicMock()
        mock_db.get_model_versions.return_value = (
            [_make_model_version(status="active", is_active=True)],
            1,
        )

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))

        mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
        assert result.total == 1
        assert result.models[0].status == "active"


class TestGetActiveModelRoute:
    """Tests for GET /admin/training/models/active."""

    def test_get_active_model_when_exists(self):
        fn = _find_endpoint("get_active_model")
        mock_db = MagicMock()
        mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))

        assert result.has_active_model is True
        assert result.model is not None
        assert result.model.is_active is True

    def test_get_active_model_when_none(self):
        fn = _find_endpoint("get_active_model")
        mock_db = MagicMock()
        mock_db.get_active_model_version.return_value = None

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))

        assert result.has_active_model is False
        assert result.model is None


class TestGetModelVersionRoute:
    """Tests for GET /admin/training/models/{version_id}."""

    def test_get_model_version(self):
        fn = _find_endpoint("get_model_version")
        mock_db = MagicMock()
        mock_db.get_model_version.return_value = _make_model_version()

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.version_id == TEST_VERSION_UUID
        assert result.version == "1.0.0"
        assert result.name == "test-model-v1"
        assert result.metrics_mAP == 0.935

    def test_get_model_version_not_found(self):
        fn = _find_endpoint("get_model_version")
        mock_db = MagicMock()
        mock_db.get_model_version.return_value = None

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 404


class TestUpdateModelVersionRoute:
    """Tests for PATCH /admin/training/models/{version_id}."""

    def test_update_model_version(self):
        fn = _find_endpoint("update_model_version")
        mock_db = MagicMock()
        mock_db.update_model_version.return_value = _make_model_version(name="updated-name")

        request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))

        # Fields omitted from the request are forwarded as None.
        mock_db.update_model_version.assert_called_once_with(
            version_id=TEST_VERSION_UUID,
            name="updated-name",
            description="Updated description",
            status=None,
        )
        assert result.message == "Model version updated successfully"

    def test_update_model_version_not_found(self):
        fn = _find_endpoint("update_model_version")
        mock_db = MagicMock()
        mock_db.update_model_version.return_value = None

        request = ModelVersionUpdateRequest(name="updated-name")

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 404


class TestActivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/activate."""

    def test_activate_model_version(self):
        fn = _find_endpoint("activate_model_version")
        mock_db = MagicMock()
        mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)

        # Create mock request with app state; no inference service is attached.
        mock_request = MagicMock()
        mock_request.app.state.inference_service = None

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert result.status == "active"
        assert result.message == "Model version activated for inference"

    def test_activate_model_version_not_found(self):
        fn = _find_endpoint("activate_model_version")
        mock_db = MagicMock()
        mock_db.activate_model_version.return_value = None

        # Create mock request with app state; no inference service is attached.
        mock_request = MagicMock()
        mock_request.app.state.inference_service = None

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 404


class TestDeactivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/deactivate."""

    def test_deactivate_model_version(self):
        fn = _find_endpoint("deactivate_model_version")
        mock_db = MagicMock()
        mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.status == "inactive"
        assert result.message == "Model version deactivated"

    def test_deactivate_model_version_not_found(self):
        fn = _find_endpoint("deactivate_model_version")
        mock_db = MagicMock()
        mock_db.deactivate_model_version.return_value = None

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 404


class TestArchiveModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/archive."""

    def test_archive_model_version(self):
        fn = _find_endpoint("archive_model_version")
        mock_db = MagicMock()
        mock_db.archive_model_version.return_value = _make_model_version(status="archived")

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.status == "archived"
        assert result.message == "Model version archived"

    def test_archive_active_model_fails(self):
        fn = _find_endpoint("archive_model_version")
        mock_db = MagicMock()
        # A None result from the db layer is expected to surface as HTTP 400.
        mock_db.archive_model_version.return_value = None

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 400


class TestDeleteModelVersionRoute:
    """Tests for DELETE /admin/training/models/{version_id}."""

    def test_delete_model_version(self):
        fn = _find_endpoint("delete_model_version")
        mock_db = MagicMock()
        mock_db.delete_model_version.return_value = True

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert result["message"] == "Model version deleted"

    def test_delete_active_model_fails(self):
        fn = _find_endpoint("delete_model_version")
        mock_db = MagicMock()
        # A False result from the db layer is expected to surface as HTTP 400.
        mock_db.delete_model_version.return_value = False

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 400


class TestModelVersionSchemas:
    """Tests for model version Pydantic schemas."""

    def test_create_request_validation(self):
        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model",
            model_path="/models/test.pt",
        )
        assert request.version == "1.0.0"
        assert request.name == "test-model"
        assert request.document_count == 0

    def test_create_request_with_metrics(self):
        request = ModelVersionCreateRequest(
            version="2.0.0",
            name="test-model-v2",
            model_path="/models/v2.pt",
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=500,
        )
        assert request.metrics_mAP == 0.95
        assert request.document_count == 500

    def test_update_request_partial(self):
        request = ModelVersionUpdateRequest(name="new-name")
        assert request.name == "new-name"
        assert request.description is None
        assert request.status is None