Files
invoice-master-poc-v2/tests/web/test_model_versions.py
Yaojia Wang a564ac9d70 WIP
2026-02-01 18:51:54 +01:00

396 lines
14 KiB
Python

"""
Tests for Model Version API routes.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
)
# Admin token handed to every endpoint call. The endpoint functions are
# invoked directly in these tests, so this is simply the `admin_token` argument.
TEST_TOKEN = "test-admin-token-12345"

# Fixed identifiers used to build mock model versions and to check what
# the endpoints return or forward to the repository.
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
def _make_model_version(**overrides) -> MagicMock:
    """Build a ``MagicMock`` standing in for a ``ModelVersion`` row.

    The mock is spec'd against ``ModelVersion`` and pre-populated with a
    baseline record: version ``1.0.0``, status ``"inactive"``, linked to the
    test task and dataset, with sample metrics and timestamps.

    Args:
        **overrides: attribute values that take precedence over the
            baseline. Keys absent from the baseline are set as well.

    Returns:
        The configured mock.
    """
    baseline = {
        "version_id": UUID(TEST_VERSION_UUID),
        "version": "1.0.0",
        "name": "test-model-v1",
        "description": "Test model version",
        "model_path": "/models/test-model-v1.pt",
        "status": "inactive",
        "is_active": False,
        "task_id": UUID(TEST_TASK_UUID),
        "dataset_id": UUID(TEST_DATASET_UUID),
        "metrics_mAP": 0.935,
        "metrics_precision": 0.92,
        "metrics_recall": 0.88,
        "document_count": 100,
        "training_config": {"epochs": 100, "batch_size": 16},
        "file_size": 52428800,
        "trained_at": datetime(2025, 1, 15, tzinfo=timezone.utc),
        "activated_at": None,
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        "updated_at": datetime(2025, 1, 15, tzinfo=timezone.utc),
    }
    mock_version = MagicMock(spec=ModelVersion)
    # configure_mock() calls setattr() for each key, so the result matches
    # setting the merged attributes one by one.
    mock_version.configure_mock(**{**baseline, **overrides})
    return mock_version
def _find_endpoint(name: str):
    """Return the endpoint callable whose function name is ``name``.

    A fresh training router is built on each call and its routes are
    scanned in order. Routes without an ``endpoint`` attribute are ignored.

    Raises:
        AssertionError: if no route's endpoint has that name, so a missing
            endpoint is reported by pytest as a test failure.
    """
    candidates = (
        route.endpoint
        for route in create_training_router().routes
        if hasattr(route, "endpoint") and route.endpoint.__name__ == name
    )
    found = next(candidates, None)
    if found is None:
        raise AssertionError(f"Endpoint {name} not found")
    return found
@pytest.fixture
def mock_models_repo():
    """Provide a bare ``MagicMock`` in place of ``ModelVersionRepository``.

    The fixture has default (per-test) scope, so every test gets a new mock.
    Each test configures the repository methods it needs (``create``,
    ``get``, ``get_paginated``, ...) through ``return_value`` and inspects
    the recorded calls afterwards.

    NOTE(review): the mock is not spec'd, so a misspelled repository method
    name would still be accepted. Consider
    ``MagicMock(spec=ModelVersionRepository)`` once its import path is
    confirmed.
    """
    repository_double = MagicMock()
    return repository_double
class TestModelVersionRouterRegistration:
    """Tests that every model version endpoint is registered on the training router.

    ``_find_endpoint`` raises ``AssertionError`` when no route matches, so the
    lookup itself is the core check. The follow-up ``callable`` assertion
    confirms that what was found is a usable endpoint function.
    """

    def test_router_has_model_endpoints(self):
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("models" in p for p in paths)

    def test_has_create_model_version_endpoint(self):
        assert callable(_find_endpoint("create_model_version"))

    def test_has_list_model_versions_endpoint(self):
        assert callable(_find_endpoint("list_model_versions"))

    def test_has_get_active_model_endpoint(self):
        assert callable(_find_endpoint("get_active_model"))

    def test_has_activate_model_version_endpoint(self):
        assert callable(_find_endpoint("activate_model_version"))

    # The endpoints below are exercised by the route tests in this module
    # but previously had no registration check of their own.
    def test_has_get_model_version_endpoint(self):
        assert callable(_find_endpoint("get_model_version"))

    def test_has_update_model_version_endpoint(self):
        assert callable(_find_endpoint("update_model_version"))

    def test_has_deactivate_model_version_endpoint(self):
        assert callable(_find_endpoint("deactivate_model_version"))

    def test_has_archive_model_version_endpoint(self):
        assert callable(_find_endpoint("archive_model_version"))

    def test_has_delete_model_version_endpoint(self):
        assert callable(_find_endpoint("delete_model_version"))
class TestCreateModelVersionRoute:
    """Tests for POST /admin/training/models."""

    def test_create_model_version(self, mock_models_repo):
        fn = _find_endpoint("create_model_version")
        mock_models_repo.create.return_value = _make_model_version()
        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            description="Test model",
            metrics_mAP=0.935,
            document_count=100,
        )

        result = asyncio.run(
            fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo)
        )

        mock_models_repo.create.assert_called_once()
        assert result.version_id == TEST_VERSION_UUID
        assert result.status == "inactive"
        assert result.message == "Model version created successfully"

    def test_create_model_version_with_task_and_dataset(self, mock_models_repo):
        fn = _find_endpoint("create_model_version")
        mock_models_repo.create.return_value = _make_model_version()
        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            task_id=TEST_TASK_UUID,
            dataset_id=TEST_DATASET_UUID,
        )

        result = asyncio.run(
            fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo)
        )

        mock_models_repo.create.assert_called_once()
        # The task and dataset IDs from the request must be forwarded to the
        # repository as keyword arguments.
        call_kwargs = mock_models_repo.create.call_args.kwargs
        assert call_kwargs["task_id"] == TEST_TASK_UUID
        assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
        # Previously `result` was assigned but never checked.
        assert result.version_id == TEST_VERSION_UUID
class TestListModelVersionsRoute:
    """Tests for GET /admin/training/models."""

    def test_list_model_versions(self, mock_models_repo):
        fn = _find_endpoint("list_model_versions")
        mock_models_repo.get_paginated.return_value = (
            [
                _make_model_version(),
                _make_model_version(
                    version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0"
                ),
            ],
            2,
        )

        result = asyncio.run(
            fn(
                admin_token=TEST_TOKEN,
                models=mock_models_repo,
                status=None,
                limit=20,
                offset=0,
            )
        )

        # Mirror the filtered test: the unfiltered path must also forward the
        # query parameters, with status passed through as None.
        mock_models_repo.get_paginated.assert_called_once_with(
            status=None, limit=20, offset=0
        )
        assert result.total == 2
        assert len(result.models) == 2
        # Both items are returned, in the order the repository gave them.
        assert [m.version for m in result.models] == ["1.0.0", "1.1.0"]

    def test_list_model_versions_with_status_filter(self, mock_models_repo):
        fn = _find_endpoint("list_model_versions")
        mock_models_repo.get_paginated.return_value = (
            [_make_model_version(status="active", is_active=True)],
            1,
        )

        result = asyncio.run(
            fn(
                admin_token=TEST_TOKEN,
                models=mock_models_repo,
                status="active",
                limit=20,
                offset=0,
            )
        )

        mock_models_repo.get_paginated.assert_called_once_with(
            status="active", limit=20, offset=0
        )
        assert result.total == 1
        assert result.models[0].status == "active"
class TestGetActiveModelRoute:
    """Tests for GET /admin/training/models/active."""

    @staticmethod
    def _call(repo):
        """Run the ``get_active_model`` endpoint against ``repo``."""
        endpoint = _find_endpoint("get_active_model")
        return asyncio.run(endpoint(admin_token=TEST_TOKEN, models=repo))

    def test_get_active_model_when_exists(self, mock_models_repo):
        active_version = _make_model_version(status="active", is_active=True)
        mock_models_repo.get_active.return_value = active_version

        response = self._call(mock_models_repo)

        assert response.has_active_model is True
        assert response.model is not None
        assert response.model.is_active is True

    def test_get_active_model_when_none(self, mock_models_repo):
        # With no active version, the response reports that explicitly and
        # carries no model payload.
        mock_models_repo.get_active.return_value = None

        response = self._call(mock_models_repo)

        assert response.has_active_model is False
        assert response.model is None
class TestGetModelVersionRoute:
    """Tests for GET /admin/training/models/{version_id}."""

    @staticmethod
    def _call(repo):
        """Run ``get_model_version`` for the fixed test version ID."""
        endpoint = _find_endpoint("get_model_version")
        return asyncio.run(
            endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=repo)
        )

    def test_get_model_version(self, mock_models_repo):
        mock_models_repo.get.return_value = _make_model_version()

        response = self._call(mock_models_repo)

        # The UUID on the mock is rendered as its string form in the response.
        assert response.version_id == TEST_VERSION_UUID
        assert response.version == "1.0.0"
        assert response.name == "test-model-v1"
        assert response.metrics_mAP == 0.935

    def test_get_model_version_not_found(self, mock_models_repo):
        from fastapi import HTTPException

        # A repository miss must surface as HTTP 404.
        mock_models_repo.get.return_value = None

        with pytest.raises(HTTPException) as raised:
            self._call(mock_models_repo)
        assert raised.value.status_code == 404
class TestUpdateModelVersionRoute:
    """Tests for PATCH /admin/training/models/{version_id}."""

    @staticmethod
    def _call(repo, payload):
        """Run ``update_model_version`` for the test version with ``payload``."""
        endpoint = _find_endpoint("update_model_version")
        return asyncio.run(
            endpoint(
                version_id=TEST_VERSION_UUID,
                request=payload,
                admin_token=TEST_TOKEN,
                models=repo,
            )
        )

    def test_update_model_version(self, mock_models_repo):
        mock_models_repo.update.return_value = _make_model_version(name="updated-name")
        payload = ModelVersionUpdateRequest(
            name="updated-name", description="Updated description"
        )

        response = self._call(mock_models_repo, payload)

        # `status` was not set on the request, so it is forwarded as None.
        mock_models_repo.update.assert_called_once_with(
            version_id=TEST_VERSION_UUID,
            name="updated-name",
            description="Updated description",
            status=None,
        )
        assert response.message == "Model version updated successfully"

    def test_update_model_version_not_found(self, mock_models_repo):
        from fastapi import HTTPException

        # update() returning None is treated as "no such version" -> 404.
        mock_models_repo.update.return_value = None
        payload = ModelVersionUpdateRequest(name="updated-name")

        with pytest.raises(HTTPException) as raised:
            self._call(mock_models_repo, payload)
        assert raised.value.status_code == 404
class TestActivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/activate."""

    @staticmethod
    def _fake_request():
        """Return a stand-in request whose ``app.state.inference_service`` is None."""
        fake_request = MagicMock()
        fake_request.app.state.inference_service = None
        return fake_request

    @classmethod
    def _call(cls, repo):
        """Run ``activate_model_version`` for the test version."""
        endpoint = _find_endpoint("activate_model_version")
        return asyncio.run(
            endpoint(
                version_id=TEST_VERSION_UUID,
                request=cls._fake_request(),
                admin_token=TEST_TOKEN,
                models=repo,
            )
        )

    def test_activate_model_version(self, mock_models_repo):
        mock_models_repo.activate.return_value = _make_model_version(
            status="active", is_active=True
        )

        response = self._call(mock_models_repo)

        mock_models_repo.activate.assert_called_once_with(TEST_VERSION_UUID)
        assert response.status == "active"
        assert response.message == "Model version activated for inference"

    def test_activate_model_version_not_found(self, mock_models_repo):
        from fastapi import HTTPException

        mock_models_repo.activate.return_value = None

        with pytest.raises(HTTPException) as raised:
            self._call(mock_models_repo)
        assert raised.value.status_code == 404
class TestDeactivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/deactivate."""

    @staticmethod
    def _call(repo):
        """Run ``deactivate_model_version`` for the test version."""
        endpoint = _find_endpoint("deactivate_model_version")
        return asyncio.run(
            endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=repo)
        )

    def test_deactivate_model_version(self, mock_models_repo):
        mock_models_repo.deactivate.return_value = _make_model_version(
            status="inactive", is_active=False
        )

        response = self._call(mock_models_repo)

        assert response.status == "inactive"
        assert response.message == "Model version deactivated"

    def test_deactivate_model_version_not_found(self, mock_models_repo):
        from fastapi import HTTPException

        mock_models_repo.deactivate.return_value = None

        with pytest.raises(HTTPException) as raised:
            self._call(mock_models_repo)
        assert raised.value.status_code == 404
class TestArchiveModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/archive."""

    @staticmethod
    def _call(repo):
        """Run ``archive_model_version`` for the test version."""
        endpoint = _find_endpoint("archive_model_version")
        return asyncio.run(
            endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=repo)
        )

    def test_archive_model_version(self, mock_models_repo):
        mock_models_repo.archive.return_value = _make_model_version(status="archived")

        response = self._call(mock_models_repo)

        assert response.status == "archived"
        assert response.message == "Model version archived"

    def test_archive_active_model_fails(self, mock_models_repo):
        from fastapi import HTTPException

        # Unlike the other routes, archive() returning None maps to 400 rather
        # than 404 (per the test name, this models refusing to archive an
        # active version).
        mock_models_repo.archive.return_value = None

        with pytest.raises(HTTPException) as raised:
            self._call(mock_models_repo)
        assert raised.value.status_code == 400
class TestDeleteModelVersionRoute:
    """Tests for DELETE /admin/training/models/{version_id}."""

    @staticmethod
    def _call(repo):
        """Run ``delete_model_version`` for the test version."""
        endpoint = _find_endpoint("delete_model_version")
        return asyncio.run(
            endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=repo)
        )

    def test_delete_model_version(self, mock_models_repo):
        mock_models_repo.delete.return_value = True

        response = self._call(mock_models_repo)

        mock_models_repo.delete.assert_called_once_with(TEST_VERSION_UUID)
        # This endpoint returns a plain dict rather than a response model.
        assert response["message"] == "Model version deleted"

    def test_delete_active_model_fails(self, mock_models_repo):
        from fastapi import HTTPException

        # delete() returning False is reported as 400.
        mock_models_repo.delete.return_value = False

        with pytest.raises(HTTPException) as raised:
            self._call(mock_models_repo)
        assert raised.value.status_code == 400
class TestModelVersionSchemas:
    """Tests for model version Pydantic schemas."""

    def test_create_request_validation(self):
        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model",
            model_path="/models/test.pt",
        )
        assert request.version == "1.0.0"
        assert request.name == "test-model"
        assert request.model_path == "/models/test.pt"
        # document_count is optional and defaults to 0 when omitted.
        assert request.document_count == 0

    def test_create_request_with_metrics(self):
        request = ModelVersionCreateRequest(
            version="2.0.0",
            name="test-model-v2",
            model_path="/models/v2.pt",
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=500,
        )
        assert request.metrics_mAP == 0.95
        # Precision and recall were passed in but previously never checked.
        assert request.metrics_precision == 0.92
        assert request.metrics_recall == 0.88
        assert request.document_count == 500

    def test_update_request_partial(self):
        # Only `name` is given; the other updatable fields stay None so the
        # route can tell "not provided" apart from a real value.
        request = ModelVersionUpdateRequest(name="new-name")
        assert request.name == "new-name"
        assert request.description is None
        assert request.status is None