583 lines
26 KiB
Python
583 lines
26 KiB
Python
"""
|
|
Tests for ModelVersionRepository
|
|
|
|
100% coverage tests for model version management.
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import uuid4, UUID
|
|
|
|
from inference.data.admin_models import ModelVersion
|
|
from inference.data.repositories.model_version_repository import ModelVersionRepository
|
|
|
|
|
|
class TestModelVersionRepository:
    """Tests for ModelVersionRepository.

    Every test runs against a mocked database session: the ``mock_session``
    fixture patches ``get_session_context`` in the repository module so that
    entering the context manager yields a ``MagicMock`` session. Tests then
    configure that session's return values and inspect the calls made on it.
    """

    # Dotted path of the session factory as looked up by the repository module.
    SESSION_CONTEXT_TARGET = (
        "inference.data.repositories.model_version_repository.get_session_context"
    )

    # =========================================================================
    # Fixtures
    # =========================================================================

    @pytest.fixture
    def sample_model(self) -> ModelVersion:
        """Create a sample (inactive, status ``ready``) model version for testing."""
        return ModelVersion(
            version_id=uuid4(),
            version="v1.0.0",
            name="Test Model",
            description="A test model",
            model_path="/path/to/model.pt",
            status="ready",
            is_active=False,
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=100,
            training_config={"epochs": 100},
            file_size=1024000,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def active_model(self) -> ModelVersion:
        """Create an active model version for testing."""
        return ModelVersion(
            version_id=uuid4(),
            version="v1.0.0",
            name="Active Model",
            model_path="/path/to/active_model.pt",
            status="active",
            is_active=True,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def repo(self) -> ModelVersionRepository:
        """Create a ModelVersionRepository instance."""
        return ModelVersionRepository()

    @pytest.fixture
    def mock_session(self):
        """Patch ``get_session_context`` and yield the mocked session.

        The patched ``get_session_context()`` returns a context manager whose
        ``__enter__`` hands back a fresh ``MagicMock`` session and whose
        ``__exit__`` returns ``False`` so exceptions are never suppressed.
        The patch remains active until the requesting test has finished.
        """
        with patch(self.SESSION_CONTEXT_TARGET) as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__ = MagicMock(return_value=session)
            mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
            yield session

    # =========================================================================
    # create() tests
    # =========================================================================

    def test_create_returns_model(self, repo, mock_session):
        """Test create adds the new model version to the session and commits."""
        # NOTE(review): the value returned by create() is not asserted here;
        # consider asserting it once the return contract is confirmed.
        repo.create(
            version="v1.0.0",
            name="Test Model",
            model_path="/path/to/model.pt",
        )

        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_create_with_all_params(self, repo, mock_session):
        """Test create with all parameters."""
        task_id = uuid4()
        dataset_id = uuid4()
        trained_at = datetime.now(timezone.utc)

        repo.create(
            version="v2.0.0",
            name="Full Model",
            model_path="/path/to/full_model.pt",
            description="A complete model",
            task_id=str(task_id),
            dataset_id=str(dataset_id),
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=500,
            training_config={"epochs": 200},
            file_size=2048000,
            trained_at=trained_at,
        )

        # The object handed to session.add() is the model built by create().
        added_model = mock_session.add.call_args[0][0]
        assert added_model.version == "v2.0.0"
        assert added_model.description == "A complete model"
        # String IDs passed in must compare equal to the original UUID objects.
        assert added_model.task_id == task_id
        assert added_model.dataset_id == dataset_id
        assert added_model.metrics_mAP == 0.95

    def test_create_with_uuid_objects(self, repo, mock_session):
        """Test create works with UUID objects."""
        task_id = uuid4()
        dataset_id = uuid4()

        repo.create(
            version="v1.0.0",
            name="Test Model",
            model_path="/path/to/model.pt",
            task_id=task_id,
            dataset_id=dataset_id,
        )

        added_model = mock_session.add.call_args[0][0]
        assert added_model.task_id == task_id
        assert added_model.dataset_id == dataset_id

    def test_create_without_optional_ids(self, repo, mock_session):
        """Test create without task_id and dataset_id."""
        repo.create(
            version="v1.0.0",
            name="Test Model",
            model_path="/path/to/model.pt",
        )

        added_model = mock_session.add.call_args[0][0]
        assert added_model.task_id is None
        assert added_model.dataset_id is None

    # =========================================================================
    # get() tests
    # =========================================================================

    def test_get_returns_model(self, repo, sample_model, mock_session):
        """Test get returns model when exists."""
        mock_session.get.return_value = sample_model

        result = repo.get(str(sample_model.version_id))

        assert result is not None
        assert result.name == "Test Model"
        mock_session.expunge.assert_called_once()

    def test_get_with_uuid(self, repo, sample_model, mock_session):
        """Test get works with UUID object."""
        mock_session.get.return_value = sample_model

        result = repo.get(sample_model.version_id)

        assert result is not None

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # get_paginated() tests
    # =========================================================================

    def test_get_paginated_returns_models_and_total(
        self, repo, sample_model, mock_session
    ):
        """Test get_paginated returns list of models and total count."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_model]

        models, total = repo.get_paginated()

        assert len(models) == 1
        assert total == 1

    def test_get_paginated_with_status_filter(self, repo, sample_model, mock_session):
        """Test get_paginated filters by status."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_model]

        # Unpacking is kept so the (models, total) return shape is still exercised.
        models, _ = repo.get_paginated(status="ready")

        assert len(models) == 1

    def test_get_paginated_with_pagination(self, repo, sample_model, mock_session):
        """Test get_paginated with limit and offset."""
        mock_session.exec.return_value.one.return_value = 50
        mock_session.exec.return_value.all.return_value = [sample_model]

        _, total = repo.get_paginated(limit=10, offset=20)

        assert total == 50

    def test_get_paginated_empty_results(self, repo, mock_session):
        """Test get_paginated with no results."""
        mock_session.exec.return_value.one.return_value = 0
        mock_session.exec.return_value.all.return_value = []

        models, total = repo.get_paginated()

        assert models == []
        assert total == 0

    # =========================================================================
    # get_active() tests
    # =========================================================================

    def test_get_active_returns_active_model(self, repo, active_model, mock_session):
        """Test get_active returns the active model."""
        mock_session.exec.return_value.first.return_value = active_model

        result = repo.get_active()

        assert result is not None
        assert result.is_active is True
        mock_session.expunge.assert_called_once()

    def test_get_active_returns_none(self, repo, mock_session):
        """Test get_active returns None when no active model."""
        mock_session.exec.return_value.first.return_value = None

        result = repo.get_active()

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # activate() tests
    # =========================================================================

    def test_activate_activates_model(
        self, repo, sample_model, active_model, mock_session
    ):
        """Test activate sets model as active and deactivates others."""
        # exec(...).all() supplies the currently active models; get() the target.
        mock_session.exec.return_value.all.return_value = [active_model]
        mock_session.get.return_value = sample_model

        result = repo.activate(str(sample_model.version_id))

        assert result is not None
        assert sample_model.is_active is True
        assert sample_model.status == "active"
        assert active_model.is_active is False
        assert active_model.status == "inactive"

    def test_activate_with_uuid(self, repo, sample_model, mock_session):
        """Test activate works with UUID object."""
        mock_session.exec.return_value.all.return_value = []
        mock_session.get.return_value = sample_model

        result = repo.activate(sample_model.version_id)

        assert result is not None
        assert sample_model.is_active is True

    def test_activate_returns_none_when_not_found(self, repo, mock_session):
        """Test activate returns None when model not found."""
        mock_session.exec.return_value.all.return_value = []
        mock_session.get.return_value = None

        result = repo.activate(str(uuid4()))

        assert result is None

    def test_activate_sets_activated_at(self, repo, sample_model, mock_session):
        """Test activate sets activated_at timestamp."""
        sample_model.activated_at = None
        mock_session.exec.return_value.all.return_value = []
        mock_session.get.return_value = sample_model

        repo.activate(str(sample_model.version_id))

        assert sample_model.activated_at is not None

    # =========================================================================
    # deactivate() tests
    # =========================================================================

    def test_deactivate_deactivates_model(self, repo, active_model, mock_session):
        """Test deactivate sets model as inactive."""
        mock_session.get.return_value = active_model

        result = repo.deactivate(str(active_model.version_id))

        assert result is not None
        assert active_model.is_active is False
        assert active_model.status == "inactive"
        mock_session.commit.assert_called_once()

    def test_deactivate_with_uuid(self, repo, active_model, mock_session):
        """Test deactivate works with UUID object."""
        mock_session.get.return_value = active_model

        result = repo.deactivate(active_model.version_id)

        assert result is not None

    def test_deactivate_returns_none_when_not_found(self, repo, mock_session):
        """Test deactivate returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.deactivate(str(uuid4()))

        assert result is None

    # =========================================================================
    # update() tests
    # =========================================================================

    def test_update_updates_model(self, repo, sample_model, mock_session):
        """Test update updates model metadata."""
        mock_session.get.return_value = sample_model

        result = repo.update(
            str(sample_model.version_id),
            name="Updated Model",
        )

        assert result is not None
        assert sample_model.name == "Updated Model"
        mock_session.commit.assert_called_once()

    def test_update_all_fields(self, repo, sample_model, mock_session):
        """Test update can update all fields."""
        mock_session.get.return_value = sample_model

        repo.update(
            str(sample_model.version_id),
            name="New Name",
            description="New Description",
            status="archived",
        )

        assert sample_model.name == "New Name"
        assert sample_model.description == "New Description"
        assert sample_model.status == "archived"

    def test_update_with_uuid(self, repo, sample_model, mock_session):
        """Test update works with UUID object."""
        mock_session.get.return_value = sample_model

        result = repo.update(sample_model.version_id, name="Updated")

        assert result is not None

    def test_update_returns_none_when_not_found(self, repo, mock_session):
        """Test update returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.update(str(uuid4()), name="New Name")

        assert result is None

    def test_update_partial_fields(self, repo, sample_model, mock_session):
        """Test update only updates provided fields."""
        original_name = sample_model.name
        mock_session.get.return_value = sample_model

        repo.update(
            str(sample_model.version_id),
            description="Only description changed",
        )

        assert sample_model.name == original_name
        assert sample_model.description == "Only description changed"

    # =========================================================================
    # archive() tests
    # =========================================================================

    def test_archive_archives_model(self, repo, sample_model, mock_session):
        """Test archive sets model status to archived."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.archive(str(sample_model.version_id))

        assert result is not None
        assert sample_model.status == "archived"
        mock_session.commit.assert_called_once()

    def test_archive_with_uuid(self, repo, sample_model, mock_session):
        """Test archive works with UUID object."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.archive(sample_model.version_id)

        assert result is not None

    def test_archive_returns_none_when_not_found(self, repo, mock_session):
        """Test archive returns None when model not found."""
        mock_session.get.return_value = None

        result = repo.archive(str(uuid4()))

        assert result is None

    def test_archive_returns_none_when_active(self, repo, active_model, mock_session):
        """Test archive returns None when model is active."""
        mock_session.get.return_value = active_model

        result = repo.archive(str(active_model.version_id))

        assert result is None

    # =========================================================================
    # delete() tests
    # =========================================================================

    def test_delete_returns_true(self, repo, sample_model, mock_session):
        """Test delete returns True when model exists and not active."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.delete(str(sample_model.version_id))

        assert result is True
        mock_session.delete.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_delete_with_uuid(self, repo, sample_model, mock_session):
        """Test delete works with UUID object."""
        sample_model.is_active = False
        mock_session.get.return_value = sample_model

        result = repo.delete(sample_model.version_id)

        assert result is True

    def test_delete_returns_false_when_not_found(self, repo, mock_session):
        """Test delete returns False when model not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False
        mock_session.delete.assert_not_called()

    def test_delete_returns_false_when_active(self, repo, active_model, mock_session):
        """Test delete returns False when model is active."""
        mock_session.get.return_value = active_model

        result = repo.delete(str(active_model.version_id))

        assert result is False
        mock_session.delete.assert_not_called()