WIP
This commit is contained in:
399
tests/web/test_model_versions.py
Normal file
399
tests/web/test_model_versions.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Tests for Model Version API routes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_models import ModelVersion
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.schemas.admin import (
|
||||
ModelVersionCreateRequest,
|
||||
ModelVersionUpdateRequest,
|
||||
)
|
||||
|
||||
|
||||
TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
|
||||
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
|
||||
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
|
||||
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
def _make_model_version(**overrides) -> MagicMock:
|
||||
"""Create a mock ModelVersion."""
|
||||
defaults = dict(
|
||||
version_id=UUID(TEST_VERSION_UUID),
|
||||
version="1.0.0",
|
||||
name="test-model-v1",
|
||||
description="Test model version",
|
||||
model_path="/models/test-model-v1.pt",
|
||||
status="inactive",
|
||||
is_active=False,
|
||||
task_id=UUID(TEST_TASK_UUID),
|
||||
dataset_id=UUID(TEST_DATASET_UUID),
|
||||
metrics_mAP=0.935,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
document_count=100,
|
||||
training_config={"epochs": 100, "batch_size": 16},
|
||||
file_size=52428800,
|
||||
trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
|
||||
activated_at=None,
|
||||
created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
model = MagicMock(spec=ModelVersion)
|
||||
for k, v in defaults.items():
|
||||
setattr(model, k, v)
|
||||
return model
|
||||
|
||||
|
||||
def _find_endpoint(name: str):
|
||||
"""Find endpoint function by name."""
|
||||
router = create_training_router()
|
||||
for route in router.routes:
|
||||
if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"Endpoint {name} not found")
|
||||
|
||||
|
||||
class TestModelVersionRouterRegistration:
|
||||
"""Tests that model version endpoints are registered."""
|
||||
|
||||
def test_router_has_model_endpoints(self):
|
||||
router = create_training_router()
|
||||
paths = [route.path for route in router.routes]
|
||||
assert any("models" in p for p in paths)
|
||||
|
||||
def test_has_create_model_version_endpoint(self):
|
||||
endpoint = _find_endpoint("create_model_version")
|
||||
assert endpoint is not None
|
||||
|
||||
def test_has_list_model_versions_endpoint(self):
|
||||
endpoint = _find_endpoint("list_model_versions")
|
||||
assert endpoint is not None
|
||||
|
||||
def test_has_get_active_model_endpoint(self):
|
||||
endpoint = _find_endpoint("get_active_model")
|
||||
assert endpoint is not None
|
||||
|
||||
def test_has_activate_model_version_endpoint(self):
|
||||
endpoint = _find_endpoint("activate_model_version")
|
||||
assert endpoint is not None
|
||||
|
||||
|
||||
class TestCreateModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models."""
|
||||
|
||||
def test_create_model_version(self):
|
||||
fn = _find_endpoint("create_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_model_version.return_value = _make_model_version()
|
||||
|
||||
request = ModelVersionCreateRequest(
|
||||
version="1.0.0",
|
||||
name="test-model-v1",
|
||||
model_path="/models/test-model-v1.pt",
|
||||
description="Test model",
|
||||
metrics_mAP=0.935,
|
||||
document_count=100,
|
||||
)
|
||||
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.create_model_version.assert_called_once()
|
||||
assert result.version_id == TEST_VERSION_UUID
|
||||
assert result.status == "inactive"
|
||||
assert result.message == "Model version created successfully"
|
||||
|
||||
def test_create_model_version_with_task_and_dataset(self):
|
||||
fn = _find_endpoint("create_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_model_version.return_value = _make_model_version()
|
||||
|
||||
request = ModelVersionCreateRequest(
|
||||
version="1.0.0",
|
||||
name="test-model-v1",
|
||||
model_path="/models/test-model-v1.pt",
|
||||
task_id=TEST_TASK_UUID,
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
)
|
||||
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
call_kwargs = mock_db.create_model_version.call_args[1]
|
||||
assert call_kwargs["task_id"] == TEST_TASK_UUID
|
||||
assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
|
||||
|
||||
|
||||
class TestListModelVersionsRoute:
|
||||
"""Tests for GET /admin/training/models."""
|
||||
|
||||
def test_list_model_versions(self):
|
||||
fn = _find_endpoint("list_model_versions")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_versions.return_value = (
|
||||
[_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
|
||||
2,
|
||||
)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.models) == 2
|
||||
assert result.models[0].version == "1.0.0"
|
||||
|
||||
def test_list_model_versions_with_status_filter(self):
|
||||
fn = _find_endpoint("list_model_versions")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
|
||||
|
||||
mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
|
||||
assert result.total == 1
|
||||
assert result.models[0].status == "active"
|
||||
|
||||
|
||||
class TestGetActiveModelRoute:
|
||||
"""Tests for GET /admin/training/models/active."""
|
||||
|
||||
def test_get_active_model_when_exists(self):
|
||||
fn = _find_endpoint("get_active_model")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.has_active_model is True
|
||||
assert result.model is not None
|
||||
assert result.model.is_active is True
|
||||
|
||||
def test_get_active_model_when_none(self):
|
||||
fn = _find_endpoint("get_active_model")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_active_model_version.return_value = None
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.has_active_model is False
|
||||
assert result.model is None
|
||||
|
||||
|
||||
class TestGetModelVersionRoute:
|
||||
"""Tests for GET /admin/training/models/{version_id}."""
|
||||
|
||||
def test_get_model_version(self):
|
||||
fn = _find_endpoint("get_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_version.return_value = _make_model_version()
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.version_id == TEST_VERSION_UUID
|
||||
assert result.version == "1.0.0"
|
||||
assert result.name == "test-model-v1"
|
||||
assert result.metrics_mAP == 0.935
|
||||
|
||||
def test_get_model_version_not_found(self):
|
||||
fn = _find_endpoint("get_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_version.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestUpdateModelVersionRoute:
|
||||
"""Tests for PATCH /admin/training/models/{version_id}."""
|
||||
|
||||
def test_update_model_version(self):
|
||||
fn = _find_endpoint("update_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
|
||||
|
||||
request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.update_model_version.assert_called_once_with(
|
||||
version_id=TEST_VERSION_UUID,
|
||||
name="updated-name",
|
||||
description="Updated description",
|
||||
status=None,
|
||||
)
|
||||
assert result.message == "Model version updated successfully"
|
||||
|
||||
def test_update_model_version_not_found(self):
|
||||
fn = _find_endpoint("update_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.update_model_version.return_value = None
|
||||
|
||||
request = ModelVersionUpdateRequest(name="updated-name")
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestActivateModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models/{version_id}/activate."""
|
||||
|
||||
def test_activate_model_version(self):
|
||||
fn = _find_endpoint("activate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
|
||||
assert result.status == "active"
|
||||
assert result.message == "Model version activated for inference"
|
||||
|
||||
def test_activate_model_version_not_found(self):
|
||||
fn = _find_endpoint("activate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestDeactivateModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models/{version_id}/deactivate."""
|
||||
|
||||
def test_deactivate_model_version(self):
|
||||
fn = _find_endpoint("deactivate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.status == "inactive"
|
||||
assert result.message == "Model version deactivated"
|
||||
|
||||
def test_deactivate_model_version_not_found(self):
|
||||
fn = _find_endpoint("deactivate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.deactivate_model_version.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestArchiveModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models/{version_id}/archive."""
|
||||
|
||||
def test_archive_model_version(self):
|
||||
fn = _find_endpoint("archive_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.archive_model_version.return_value = _make_model_version(status="archived")
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.status == "archived"
|
||||
assert result.message == "Model version archived"
|
||||
|
||||
def test_archive_active_model_fails(self):
|
||||
fn = _find_endpoint("archive_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.archive_model_version.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
class TestDeleteModelVersionRoute:
|
||||
"""Tests for DELETE /admin/training/models/{version_id}."""
|
||||
|
||||
def test_delete_model_version(self):
|
||||
fn = _find_endpoint("delete_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.delete_model_version.return_value = True
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
|
||||
assert result["message"] == "Model version deleted"
|
||||
|
||||
def test_delete_active_model_fails(self):
|
||||
fn = _find_endpoint("delete_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.delete_model_version.return_value = False
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
class TestModelVersionSchemas:
|
||||
"""Tests for model version Pydantic schemas."""
|
||||
|
||||
def test_create_request_validation(self):
|
||||
request = ModelVersionCreateRequest(
|
||||
version="1.0.0",
|
||||
name="test-model",
|
||||
model_path="/models/test.pt",
|
||||
)
|
||||
assert request.version == "1.0.0"
|
||||
assert request.name == "test-model"
|
||||
assert request.document_count == 0
|
||||
|
||||
def test_create_request_with_metrics(self):
|
||||
request = ModelVersionCreateRequest(
|
||||
version="2.0.0",
|
||||
name="test-model-v2",
|
||||
model_path="/models/v2.pt",
|
||||
metrics_mAP=0.95,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
document_count=500,
|
||||
)
|
||||
assert request.metrics_mAP == 0.95
|
||||
assert request.document_count == 500
|
||||
|
||||
def test_update_request_partial(self):
|
||||
request = ModelVersionUpdateRequest(name="new-name")
|
||||
assert request.name == "new-name"
|
||||
assert request.description is None
|
||||
assert request.status is None
|
||||
Reference in New Issue
Block a user