This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -63,6 +63,12 @@ def _find_endpoint(name: str):
raise AssertionError(f"Endpoint {name} not found")
@pytest.fixture
def mock_models_repo():
"""Mock ModelVersionRepository."""
return MagicMock()
class TestModelVersionRouterRegistration:
"""Tests that model version endpoints are registered."""
@@ -91,11 +97,10 @@ class TestModelVersionRouterRegistration:
class TestCreateModelVersionRoute:
"""Tests for POST /admin/training/models."""
def test_create_model_version(self):
def test_create_model_version(self, mock_models_repo):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
mock_models_repo.create.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
@@ -106,18 +111,17 @@ class TestCreateModelVersionRoute:
document_count=100,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.create_model_version.assert_called_once()
mock_models_repo.create.assert_called_once()
assert result.version_id == TEST_VERSION_UUID
assert result.status == "inactive"
assert result.message == "Model version created successfully"
def test_create_model_version_with_task_and_dataset(self):
def test_create_model_version_with_task_and_dataset(self, mock_models_repo):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
mock_models_repo.create.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
@@ -127,9 +131,9 @@ class TestCreateModelVersionRoute:
dataset_id=TEST_DATASET_UUID,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
call_kwargs = mock_db.create_model_version.call_args[1]
call_kwargs = mock_models_repo.create.call_args[1]
assert call_kwargs["task_id"] == TEST_TASK_UUID
assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
@@ -137,30 +141,28 @@ class TestCreateModelVersionRoute:
class TestListModelVersionsRoute:
"""Tests for GET /admin/training/models."""
def test_list_model_versions(self):
def test_list_model_versions(self, mock_models_repo):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = (
mock_models_repo.get_paginated.return_value = (
[_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
2,
)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo, status=None, limit=20, offset=0))
assert result.total == 2
assert len(result.models) == 2
assert result.models[0].version == "1.0.0"
def test_list_model_versions_with_status_filter(self):
def test_list_model_versions_with_status_filter(self, mock_models_repo):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
mock_models_repo.get_paginated.return_value = ([_make_model_version(status="active", is_active=True)], 1)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo, status="active", limit=20, offset=0))
mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
mock_models_repo.get_paginated.assert_called_once_with(status="active", limit=20, offset=0)
assert result.total == 1
assert result.models[0].status == "active"
@@ -168,25 +170,23 @@ class TestListModelVersionsRoute:
class TestGetActiveModelRoute:
"""Tests for GET /admin/training/models/active."""
def test_get_active_model_when_exists(self):
def test_get_active_model_when_exists(self, mock_models_repo):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
mock_models_repo.get_active.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.has_active_model is True
assert result.model is not None
assert result.model.is_active is True
def test_get_active_model_when_none(self):
def test_get_active_model_when_none(self, mock_models_repo):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = None
mock_models_repo.get_active.return_value = None
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.has_active_model is False
assert result.model is None
@@ -195,46 +195,43 @@ class TestGetActiveModelRoute:
class TestGetModelVersionRoute:
"""Tests for GET /admin/training/models/{version_id}."""
def test_get_model_version(self):
def test_get_model_version(self, mock_models_repo):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = _make_model_version()
mock_models_repo.get.return_value = _make_model_version()
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.version_id == TEST_VERSION_UUID
assert result.version == "1.0.0"
assert result.name == "test-model-v1"
assert result.metrics_mAP == 0.935
def test_get_model_version_not_found(self):
def test_get_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = None
mock_models_repo.get.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestUpdateModelVersionRoute:
"""Tests for PATCH /admin/training/models/{version_id}."""
def test_update_model_version(self):
def test_update_model_version(self, mock_models_repo):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
mock_models_repo.update.return_value = _make_model_version(name="updated-name")
request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.update_model_version.assert_called_once_with(
mock_models_repo.update.assert_called_once_with(
version_id=TEST_VERSION_UUID,
name="updated-name",
description="Updated description",
@@ -242,45 +239,42 @@ class TestUpdateModelVersionRoute:
)
assert result.message == "Model version updated successfully"
def test_update_model_version_not_found(self):
def test_update_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = None
mock_models_repo.update.return_value = None
request = ModelVersionUpdateRequest(name="updated-name")
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestActivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/activate."""
def test_activate_model_version(self):
def test_activate_model_version(self, mock_models_repo):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
mock_models_repo.activate.return_value = _make_model_version(status="active", is_active=True)
# Create mock request with app state
mock_request = MagicMock()
mock_request.app.state.inference_service = None
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
mock_models_repo.activate.assert_called_once_with(TEST_VERSION_UUID)
assert result.status == "active"
assert result.message == "Model version activated for inference"
def test_activate_model_version_not_found(self):
def test_activate_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = None
mock_models_repo.activate.return_value = None
# Create mock request with app state
mock_request = MagicMock()
@@ -289,88 +283,82 @@ class TestActivateModelVersionRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestDeactivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/deactivate."""
def test_deactivate_model_version(self):
def test_deactivate_model_version(self, mock_models_repo):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
mock_models_repo.deactivate.return_value = _make_model_version(status="inactive", is_active=False)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.status == "inactive"
assert result.message == "Model version deactivated"
def test_deactivate_model_version_not_found(self):
def test_deactivate_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = None
mock_models_repo.deactivate.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestArchiveModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/archive."""
def test_archive_model_version(self):
def test_archive_model_version(self, mock_models_repo):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = _make_model_version(status="archived")
mock_models_repo.archive.return_value = _make_model_version(status="archived")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.status == "archived"
assert result.message == "Model version archived"
def test_archive_active_model_fails(self):
def test_archive_active_model_fails(self, mock_models_repo):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = None
mock_models_repo.archive.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 400
class TestDeleteModelVersionRoute:
"""Tests for DELETE /admin/training/models/{version_id}."""
def test_delete_model_version(self):
def test_delete_model_version(self, mock_models_repo):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = True
mock_models_repo.delete.return_value = True
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
mock_models_repo.delete.assert_called_once_with(TEST_VERSION_UUID)
assert result["message"] == "Model version deleted"
def test_delete_active_model_fails(self):
def test_delete_active_model_fails(self, mock_models_repo):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = False
mock_models_repo.delete.return_value = False
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 400