408 lines
14 KiB
Python
408 lines
14 KiB
Python
"""
Tests for Model Version API routes.
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
|
|
from inference.data.admin_models import ModelVersion
|
|
from inference.web.api.v1.admin.training import create_training_router
|
|
from inference.web.schemas.admin import (
|
|
ModelVersionCreateRequest,
|
|
ModelVersionUpdateRequest,
|
|
)
|
|
|
|
|
|
# Identifiers shared by the tests below. The version, task and dataset ids are
# UUID-formatted strings; the mock factory wraps them in ``UUID(...)``.
TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"

# Value passed as ``admin_token`` when the endpoint functions are called directly.
TEST_TOKEN = "test-admin-token-12345"
|
|
|
|
|
|
def _make_model_version(**overrides) -> MagicMock:
    """Build a ``MagicMock`` specced on ``ModelVersion`` with test defaults.

    Args:
        **overrides: Attribute values that replace, or add to, the defaults.

    Returns:
        The mock with every default/overridden attribute assigned.
    """
    base_attributes = {
        "version_id": UUID(TEST_VERSION_UUID),
        "version": "1.0.0",
        "name": "test-model-v1",
        "description": "Test model version",
        "model_path": "/models/test-model-v1.pt",
        "status": "inactive",
        "is_active": False,
        "task_id": UUID(TEST_TASK_UUID),
        "dataset_id": UUID(TEST_DATASET_UUID),
        "metrics_mAP": 0.935,
        "metrics_precision": 0.92,
        "metrics_recall": 0.88,
        "document_count": 100,
        "training_config": {"epochs": 100, "batch_size": 16},
        "file_size": 52428800,
        "trained_at": datetime(2025, 1, 15, tzinfo=timezone.utc),
        "activated_at": None,
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        "updated_at": datetime(2025, 1, 15, tzinfo=timezone.utc),
    }
    attributes = {**base_attributes, **overrides}

    mock_version = MagicMock(spec=ModelVersion)
    # Assign attributes after construction: ``name`` is a reserved MagicMock
    # constructor keyword (it names the mock itself), so it cannot be passed
    # as a constructor kwarg to set the attribute.
    for attr_name, attr_value in attributes.items():
        setattr(mock_version, attr_name, attr_value)
    return mock_version
|
|
|
|
|
|
def _find_endpoint(name: str):
    """Return the endpoint callable called ``name`` from a freshly built training router.

    Args:
        name: ``__name__`` of the endpoint function to look up.

    Returns:
        The first matching ``route.endpoint``.

    Raises:
        AssertionError: If no route carries an endpoint with that name.
    """
    candidates = (
        route.endpoint
        for route in create_training_router().routes
        # Skip route objects that do not expose an ``endpoint`` attribute.
        if hasattr(route, "endpoint") and route.endpoint.__name__ == name
    )
    endpoint = next(candidates, None)
    if endpoint is None:
        raise AssertionError(f"Endpoint {name} not found")
    return endpoint
|
|
|
|
|
|
class TestModelVersionRouterRegistration:
    """Checks that the training router registers the model version endpoints."""

    def test_router_has_model_endpoints(self):
        """At least one registered route path contains ``models``."""
        route_paths = [route.path for route in create_training_router().routes]
        assert any("models" in path for path in route_paths)

    def test_has_create_model_version_endpoint(self):
        """``create_model_version`` is registered on the router."""
        assert _find_endpoint("create_model_version") is not None

    def test_has_list_model_versions_endpoint(self):
        """``list_model_versions`` is registered on the router."""
        assert _find_endpoint("list_model_versions") is not None

    def test_has_get_active_model_endpoint(self):
        """``get_active_model`` is registered on the router."""
        assert _find_endpoint("get_active_model") is not None

    def test_has_activate_model_version_endpoint(self):
        """``activate_model_version`` is registered on the router."""
        assert _find_endpoint("activate_model_version") is not None
|
|
|
|
|
|
class TestCreateModelVersionRoute:
    """Tests for POST /admin/training/models."""

    def test_create_model_version(self):
        """The endpoint calls the DB once and returns the id, status and message."""
        fn = _find_endpoint("create_model_version")

        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            description="Test model",
            metrics_mAP=0.935,
            document_count=100,
        )

        result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.create_model_version.assert_called_once()
        assert result.version_id == TEST_VERSION_UUID
        assert result.status == "inactive"
        assert result.message == "Model version created successfully"

    def test_create_model_version_with_task_and_dataset(self):
        """``task_id`` and ``dataset_id`` from the request reach the DB call as kwargs."""
        fn = _find_endpoint("create_model_version")

        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            task_id=TEST_TASK_UUID,
            dataset_id=TEST_DATASET_UUID,
        )

        # Only the DB call is inspected here, so the response is not kept.
        asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        # Guard first: if the endpoint never called the DB, ``call_args`` is
        # None and indexing it would fail with an unhelpful TypeError.
        mock_db.create_model_version.assert_called_once()
        call_kwargs = mock_db.create_model_version.call_args[1]
        assert call_kwargs["task_id"] == TEST_TASK_UUID
        assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
|
|
|
|
|
|
class TestListModelVersionsRoute:
    """Tests for GET /admin/training/models."""

    def test_list_model_versions(self):
        """Two versions from the DB come back as two models with ``total == 2``."""
        endpoint = _find_endpoint("list_model_versions")

        db = MagicMock()
        first_version = _make_model_version()
        second_version = _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")
        db.get_model_versions.return_value = ([first_version, second_version], 2)

        pending = endpoint(admin_token=TEST_TOKEN, db=db, status=None, limit=20, offset=0)
        response = asyncio.run(pending)

        assert response.total == 2
        assert len(response.models) == 2
        assert response.models[0].version == "1.0.0"

    def test_list_model_versions_with_status_filter(self):
        """The ``status`` filter and paging arguments are forwarded to the DB."""
        endpoint = _find_endpoint("list_model_versions")

        db = MagicMock()
        active_version = _make_model_version(status="active", is_active=True)
        db.get_model_versions.return_value = ([active_version], 1)

        pending = endpoint(admin_token=TEST_TOKEN, db=db, status="active", limit=20, offset=0)
        response = asyncio.run(pending)

        db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
        assert response.total == 1
        assert response.models[0].status == "active"
|
|
|
|
|
|
class TestGetActiveModelRoute:
    """Tests for GET /admin/training/models/active."""

    def test_get_active_model_when_exists(self):
        """An active version from the DB yields ``has_active_model`` and a model."""
        endpoint = _find_endpoint("get_active_model")

        db = MagicMock()
        db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)

        response = asyncio.run(endpoint(admin_token=TEST_TOKEN, db=db))

        assert response.has_active_model is True
        assert response.model is not None
        assert response.model.is_active is True

    def test_get_active_model_when_none(self):
        """``None`` from the DB yields ``has_active_model`` False and no model."""
        endpoint = _find_endpoint("get_active_model")

        db = MagicMock()
        db.get_active_model_version.return_value = None

        response = asyncio.run(endpoint(admin_token=TEST_TOKEN, db=db))

        assert response.has_active_model is False
        assert response.model is None
|
|
|
|
|
|
class TestGetModelVersionRoute:
    """Tests for GET /admin/training/models/{version_id}."""

    def test_get_model_version(self):
        """A found version is returned with its id, version, name and mAP."""
        endpoint = _find_endpoint("get_model_version")

        db = MagicMock()
        db.get_model_version.return_value = _make_model_version()

        pending = endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db)
        response = asyncio.run(pending)

        assert response.version_id == TEST_VERSION_UUID
        assert response.version == "1.0.0"
        assert response.name == "test-model-v1"
        assert response.metrics_mAP == 0.935

    def test_get_model_version_not_found(self):
        """``None`` from the DB makes the endpoint raise an HTTP 404."""
        endpoint = _find_endpoint("get_model_version")

        db = MagicMock()
        db.get_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            asyncio.run(endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db))
        assert raised.value.status_code == 404
|
|
|
|
|
|
class TestUpdateModelVersionRoute:
    """Tests for PATCH /admin/training/models/{version_id}."""

    def test_update_model_version(self):
        """Request fields are forwarded to the DB and the success message is returned."""
        endpoint = _find_endpoint("update_model_version")

        db = MagicMock()
        db.update_model_version.return_value = _make_model_version(name="updated-name")

        payload = ModelVersionUpdateRequest(name="updated-name", description="Updated description")

        pending = endpoint(version_id=TEST_VERSION_UUID, request=payload, admin_token=TEST_TOKEN, db=db)
        response = asyncio.run(pending)

        # ``status`` was not set on the request, so the DB call is expected to get None.
        expected_kwargs = {
            "version_id": TEST_VERSION_UUID,
            "name": "updated-name",
            "description": "Updated description",
            "status": None,
        }
        db.update_model_version.assert_called_once_with(**expected_kwargs)
        assert response.message == "Model version updated successfully"

    def test_update_model_version_not_found(self):
        """``None`` from the DB makes the endpoint raise an HTTP 404."""
        endpoint = _find_endpoint("update_model_version")

        db = MagicMock()
        db.update_model_version.return_value = None

        payload = ModelVersionUpdateRequest(name="updated-name")

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            asyncio.run(endpoint(version_id=TEST_VERSION_UUID, request=payload, admin_token=TEST_TOKEN, db=db))
        assert raised.value.status_code == 404
|
|
|
|
|
|
class TestActivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/activate."""

    def test_activate_model_version(self):
        """Activation calls the DB with the version id and reports ``active``."""
        endpoint = _find_endpoint("activate_model_version")

        db = MagicMock()
        db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)

        # Stand-in request whose app state carries no inference service.
        http_request = MagicMock()
        http_request.app.state.inference_service = None

        pending = endpoint(version_id=TEST_VERSION_UUID, request=http_request, admin_token=TEST_TOKEN, db=db)
        response = asyncio.run(pending)

        db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert response.status == "active"
        assert response.message == "Model version activated for inference"

    def test_activate_model_version_not_found(self):
        """``None`` from the DB makes the endpoint raise an HTTP 404."""
        endpoint = _find_endpoint("activate_model_version")

        db = MagicMock()
        db.activate_model_version.return_value = None

        # Stand-in request whose app state carries no inference service.
        http_request = MagicMock()
        http_request.app.state.inference_service = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            asyncio.run(endpoint(version_id=TEST_VERSION_UUID, request=http_request, admin_token=TEST_TOKEN, db=db))
        assert raised.value.status_code == 404
|
|
|
|
|
|
class TestDeactivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/deactivate."""

    def test_deactivate_model_version(self):
        """A deactivated version is reported as ``inactive`` with the expected message."""
        endpoint = _find_endpoint("deactivate_model_version")

        db = MagicMock()
        db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)

        pending = endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db)
        response = asyncio.run(pending)

        assert response.status == "inactive"
        assert response.message == "Model version deactivated"

    def test_deactivate_model_version_not_found(self):
        """``None`` from the DB makes the endpoint raise an HTTP 404."""
        endpoint = _find_endpoint("deactivate_model_version")

        db = MagicMock()
        db.deactivate_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            asyncio.run(endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db))
        assert raised.value.status_code == 404
|
|
|
|
|
|
class TestArchiveModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/archive."""

    def test_archive_model_version(self):
        """An archived version is reported as ``archived`` with the expected message."""
        endpoint = _find_endpoint("archive_model_version")

        db = MagicMock()
        db.archive_model_version.return_value = _make_model_version(status="archived")

        pending = endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db)
        response = asyncio.run(pending)

        assert response.status == "archived"
        assert response.message == "Model version archived"

    def test_archive_active_model_fails(self):
        """``None`` from the DB archive call makes the endpoint raise an HTTP 400."""
        endpoint = _find_endpoint("archive_model_version")

        db = MagicMock()
        db.archive_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            asyncio.run(endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db))
        assert raised.value.status_code == 400
|
|
|
|
|
|
class TestDeleteModelVersionRoute:
    """Tests for DELETE /admin/training/models/{version_id}."""

    def test_delete_model_version(self):
        """A successful delete calls the DB with the id and returns the message."""
        endpoint = _find_endpoint("delete_model_version")

        db = MagicMock()
        db.delete_model_version.return_value = True

        pending = endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db)
        response = asyncio.run(pending)

        db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert response["message"] == "Model version deleted"

    def test_delete_active_model_fails(self):
        """``False`` from the DB delete call makes the endpoint raise an HTTP 400."""
        endpoint = _find_endpoint("delete_model_version")

        db = MagicMock()
        db.delete_model_version.return_value = False

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            asyncio.run(endpoint(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=db))
        assert raised.value.status_code == 400
|
|
|
|
|
|
class TestModelVersionSchemas:
    """Tests for model version Pydantic schemas."""

    def test_create_request_validation(self):
        """A create request with only the three given fields reports ``document_count`` 0."""
        payload = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model",
            model_path="/models/test.pt",
        )

        assert payload.version == "1.0.0"
        assert payload.name == "test-model"
        assert payload.document_count == 0

    def test_create_request_with_metrics(self):
        """Metric fields and ``document_count`` given to the request are kept."""
        payload = ModelVersionCreateRequest(
            version="2.0.0",
            name="test-model-v2",
            model_path="/models/v2.pt",
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=500,
        )

        assert payload.metrics_mAP == 0.95
        assert payload.document_count == 500

    def test_update_request_partial(self):
        """An update request with only ``name`` leaves the other fields as None."""
        payload = ModelVersionUpdateRequest(name="new-name")

        assert payload.name == "new-name"
        assert payload.description is None
        assert payload.status is None
|