Files
invoice-master-poc-v2/tests/data/repositories/test_model_version_repository.py
Yaojia Wang a564ac9d70 WIP
2026-02-01 18:51:54 +01:00

583 lines
26 KiB
Python

"""
Tests for ModelVersionRepository
100% coverage tests for model version management.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import ModelVersion
from inference.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionRepository:
    """Tests for ModelVersionRepository.

    Every test that calls the repository requests the ``mock_session``
    fixture, which patches ``get_session_context`` in the repository module
    so the repository receives a ``MagicMock`` in place of a session.
    """

    # Patch target: the name ``get_session_context`` inside the repository module.
    _SESSION_CTX = "inference.data.repositories.model_version_repository.get_session_context"

    # =========================================================================
    # fixtures
    # =========================================================================

    @pytest.fixture
    def sample_model(self) -> ModelVersion:
        """Create a sample, non-active model version for testing."""
        now = datetime.now(timezone.utc)
        return ModelVersion(
            version_id=uuid4(),
            version="v1.0.0",
            name="Test Model",
            description="A test model",
            model_path="/path/to/model.pt",
            status="ready",
            is_active=False,
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=100,
            training_config={"epochs": 100},
            file_size=1024000,
            created_at=now,
            updated_at=now,
        )

    @pytest.fixture
    def active_model(self) -> ModelVersion:
        """Create an active model version for testing."""
        now = datetime.now(timezone.utc)
        return ModelVersion(
            version_id=uuid4(),
            version="v1.0.0",
            name="Active Model",
            model_path="/path/to/active_model.pt",
            status="active",
            is_active=True,
            created_at=now,
            updated_at=now,
        )

    @pytest.fixture
    def repo(self) -> ModelVersionRepository:
        """Create a ModelVersionRepository instance."""
        return ModelVersionRepository()

    @pytest.fixture
    def mock_session(self):
        """Patch ``get_session_context`` and yield the mocked session.

        The patched callable returns a context manager whose ``__enter__``
        hands back the yielded ``MagicMock`` and whose ``__exit__`` returns
        ``False`` (exceptions raised inside the ``with`` block propagate).
        The patch stays active for the whole test and is undone on teardown.

        Tests request ``repo`` *before* this fixture so the repository is
        constructed before the patch is applied.
        """
        with patch(self._SESSION_CTX) as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__ = MagicMock(return_value=session)
            mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
            yield session

    # =========================================================================
    # create() tests
    # =========================================================================

    def test_create_returns_model(self, repo, mock_session):
        """Test create adds, commits and returns the created model version."""
        result = repo.create(
            version="v1.0.0",
            name="Test Model",
            model_path="/path/to/model.pt",
        )

        # The test name promises a returned model, so actually check it.
        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_create_with_all_params(self, repo, mock_session):
        """Test create with all parameters, passing the ids as strings."""
        task_id = uuid4()
        dataset_id = uuid4()
        trained_at = datetime.now(timezone.utc)

        repo.create(
            version="v2.0.0",
            name="Full Model",
            model_path="/path/to/full_model.pt",
            description="A complete model",
            task_id=str(task_id),
            dataset_id=str(dataset_id),
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=500,
            training_config={"epochs": 200},
            file_size=2048000,
            trained_at=trained_at,
        )

        # Inspect the object that was handed to session.add().
        added_model = mock_session.add.call_args[0][0]
        assert added_model.version == "v2.0.0"
        assert added_model.description == "A complete model"
        # String ids are expected to come out equal to the original UUIDs.
        assert added_model.task_id == task_id
        assert added_model.dataset_id == dataset_id
        assert added_model.metrics_mAP == 0.95

    def test_create_with_uuid_objects(self, repo, mock_session):
        """Test create works with UUID objects."""
        task_id = uuid4()
        dataset_id = uuid4()

        repo.create(
            version="v1.0.0",
            name="Test Model",
            model_path="/path/to/model.pt",
            task_id=task_id,
            dataset_id=dataset_id,
        )

        added_model = mock_session.add.call_args[0][0]
        assert added_model.task_id == task_id
        assert added_model.dataset_id == dataset_id

    def test_create_without_optional_ids(self, repo, mock_session):
        """Test create without task_id and dataset_id."""
        repo.create(
            version="v1.0.0",
            name="Test Model",
            model_path="/path/to/model.pt",
        )

        added_model = mock_session.add.call_args[0][0]
        assert added_model.task_id is None
        assert added_model.dataset_id is None

    # =========================================================================
    # get() tests
    # =========================================================================

    def test_get_returns_model(self, repo, sample_model, mock_session):
        """Test get returns model when exists."""
        mock_session.get.return_value = sample_model

        result = repo.get(str(sample_model.version_id))

        assert result is not None
        assert result.name == "Test Model"
        mock_session.expunge.assert_called_once()

    def test_get_with_uuid(self, repo, sample_model, mock_session):
        """Test get works with UUID object."""
        mock_session.get.return_value = sample_model

        result = repo.get(sample_model.version_id)

        assert result is not None

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # get_paginated() tests
    # =========================================================================

    def test_get_paginated_returns_models_and_total(self, repo, sample_model, mock_session):
        """Test get_paginated returns list of models and total count."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_model]

        models, total = repo.get_paginated()

        assert len(models) == 1
        assert total == 1

    def test_get_paginated_with_status_filter(self, repo, sample_model, mock_session):
        """Test get_paginated accepts a status filter.

        The mocked ``exec`` returns the same rows for any statement, so this
        only checks that the call succeeds with ``status`` supplied.
        """
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_model]

        models, _ = repo.get_paginated(status="ready")

        assert len(models) == 1

    def test_get_paginated_with_pagination(self, repo, sample_model, mock_session):
        """Test get_paginated with limit and offset."""
        # The total is whatever the mocked count query reports; the mock
        # returns the same page for any limit/offset.
        mock_session.exec.return_value.one.return_value = 50
        mock_session.exec.return_value.all.return_value = [sample_model]

        _, total = repo.get_paginated(limit=10, offset=20)

        assert total == 50

    def test_get_paginated_empty_results(self, repo, mock_session):
        """Test get_paginated with no results."""
        mock_session.exec.return_value.one.return_value = 0
        mock_session.exec.return_value.all.return_value = []

        models, total = repo.get_paginated()

        assert models == []
        assert total == 0

    # =========================================================================
    # get_active() tests
    # =========================================================================

    def test_get_active_returns_active_model(self, repo, active_model, mock_session):
        """Test get_active returns the active model."""
        mock_session.exec.return_value.first.return_value = active_model

        result = repo.get_active()

        assert result is not None
        assert result.is_active is True
        mock_session.expunge.assert_called_once()

    def test_get_active_returns_none(self, repo, mock_session):
        """Test get_active returns None when no active model."""
        mock_session.exec.return_value.first.return_value = None

        result = repo.get_active()

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # activate() tests
    # =========================================================================

    def test_activate_activates_model(self, repo, sample_model, active_model, mock_session):
        """Test activate sets model as active and deactivates others."""
        # exec(...).all() supplies the currently active models;
        # get(...) supplies the model being activated.
        mock_session.exec.return_value.all.return_value = [active_model]
        mock_session.get.return_value = sample_model

        result = repo.activate(str(sample_model.version_id))

        assert result is not None
        assert sample_model.is_active is True
        assert sample_model.status == "active"
        assert active_model.is_active is False
        assert active_model.status == "inactive"

    def test_activate_with_uuid(self, repo, sample_model, mock_session):
        """Test activate works with UUID object."""
        mock_session.exec.return_value.all.return_value = []
        mock_session.get.return_value = sample_model

        result = repo.activate(sample_model.version_id)

        assert result is not None
        assert sample_model.is_active is True

    def test_activate_returns_none_when_not_found(self, repo, mock_session):
        """Test activate returns None when model not found."""
        mock_session.exec.return_value.all.return_value = []
        mock_session.get.return_value = None

        result = repo.activate(str(uuid4()))

        assert result is None

    def test_activate_sets_activated_at(self, repo, sample_model, mock_session):
        """Test activate sets activated_at timestamp."""
        sample_model.activated_at = None
        mock_session.exec.return_value.all.return_value = []
        mock_session.get.return_value = sample_model

        repo.activate(str(sample_model.version_id))

        assert sample_model.activated_at is not None

    # =========================================================================
    # deactivate() tests
    # =========================================================================

    def test_deactivate_deactivates_model(self, repo, active_model, mock_session):
        """Test deactivate sets model as inactive."""
        mock_session.get.return_value = active_model

        result = repo.deactivate(str(active_model.version_id))

        assert result is not None
        assert active_model.is_active is False
        assert active_model.status == "inactive"
        mock_session.commit.assert_called_once()

    def test_deactivate_with_uuid(self, repo, active_model, mock_session):
        """Test deactivate works with UUID object."""
        mock_session.get.return_value = active_model

        result = repo.deactivate(active_model.version_id)

        assert result is not None

    def test_deactivate_returns_none_when_not_found(self, repo, mock_session):
        """Test deactivate returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.deactivate(str(uuid4()))

        assert result is None

    # =========================================================================
    # update() tests
    # =========================================================================

    def test_update_updates_model(self, repo, sample_model, mock_session):
        """Test update updates model metadata."""
        mock_session.get.return_value = sample_model

        result = repo.update(
            str(sample_model.version_id),
            name="Updated Model",
        )

        assert result is not None
        assert sample_model.name == "Updated Model"
        mock_session.commit.assert_called_once()

    def test_update_all_fields(self, repo, sample_model, mock_session):
        """Test update can update name, description and status together."""
        mock_session.get.return_value = sample_model

        repo.update(
            str(sample_model.version_id),
            name="New Name",
            description="New Description",
            status="archived",
        )

        assert sample_model.name == "New Name"
        assert sample_model.description == "New Description"
        assert sample_model.status == "archived"

    def test_update_with_uuid(self, repo, sample_model, mock_session):
        """Test update works with UUID object."""
        mock_session.get.return_value = sample_model

        result = repo.update(sample_model.version_id, name="Updated")

        assert result is not None

    def test_update_returns_none_when_not_found(self, repo, mock_session):
        """Test update returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.update(str(uuid4()), name="New Name")

        assert result is None

    def test_update_partial_fields(self, repo, sample_model, mock_session):
        """Test update only updates provided fields."""
        original_name = sample_model.name
        mock_session.get.return_value = sample_model

        repo.update(
            str(sample_model.version_id),
            description="Only description changed",
        )

        # name was not passed, so it must be left untouched.
        assert sample_model.name == original_name
        assert sample_model.description == "Only description changed"

    # =========================================================================
    # archive() tests
    # =========================================================================

    def test_archive_archives_model(self, repo, sample_model, mock_session):
        """Test archive sets model status to archived."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.archive(str(sample_model.version_id))

        assert result is not None
        assert sample_model.status == "archived"
        mock_session.commit.assert_called_once()

    def test_archive_with_uuid(self, repo, sample_model, mock_session):
        """Test archive works with UUID object."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.archive(sample_model.version_id)

        assert result is not None

    def test_archive_returns_none_when_not_found(self, repo, mock_session):
        """Test archive returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.archive(str(uuid4()))

        assert result is None

    def test_archive_returns_none_when_active(self, repo, active_model, mock_session):
        """Test archive returns None when model is active."""
        mock_session.get.return_value = active_model

        result = repo.archive(str(active_model.version_id))

        assert result is None

    # =========================================================================
    # delete() tests
    # =========================================================================

    def test_delete_returns_true(self, repo, sample_model, mock_session):
        """Test delete returns True when model exists and not active."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.delete(str(sample_model.version_id))

        assert result is True
        mock_session.delete.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_delete_with_uuid(self, repo, sample_model, mock_session):
        """Test delete works with UUID object."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.delete(sample_model.version_id)

        assert result is True

    def test_delete_returns_false_when_not_found(self, repo, mock_session):
        """Test delete returns False when model not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False
        mock_session.delete.assert_not_called()

    def test_delete_returns_false_when_active(self, repo, active_model, mock_session):
        """Test delete returns False when model is active."""
        mock_session.get.return_value = active_model

        result = repo.delete(str(active_model.version_id))

        assert result is False
        mock_session.delete.assert_not_called()