Files
invoice-master-poc-v2/tests/integration/repositories/test_model_version_repo_integration.py
2026-02-01 22:40:41 +01:00

311 lines
10 KiB
Python

"""
Model Version Repository Integration Tests
Tests ModelVersionRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionCreate:
    """Integration tests for ``ModelVersionRepository.create``.

    Each test requests the ``patched_session`` fixture (defined outside this
    file; presumably it points the repository at the test database -- confirm
    in conftest).
    """

    def test_create_model_version(self, patched_session):
        """Create a version with metrics and check the stored fields.

        A newly created version is expected to come back inactive.
        """
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.0.0",
            name="Invoice Extractor v1",
            model_path="/models/invoice_v1.pt",
            description="Initial production model",
            metrics_mAP=0.92,
            metrics_precision=0.89,
            metrics_recall=0.85,
            document_count=1000,
            file_size=50000000,
        )

        assert model is not None
        assert model.version == "1.0.0"
        assert model.name == "Invoice Extractor v1"
        assert model.model_path == "/models/invoice_v1.pt"
        # Float metrics go through the storage layer; compare with a
        # tolerance so a harmless representation change cannot break the test.
        assert model.metrics_mAP == pytest.approx(0.92)
        assert model.is_active is False
        assert model.status == "inactive"

    def test_create_model_version_with_training_info(
        self, patched_session, sample_training_task, sample_dataset
    ):
        """Create a version linked to a training task and a dataset.

        Checks that the task/dataset identifiers and the training config
        passed to ``create`` are readable from the returned model.
        """
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.1.0",
            name="Invoice Extractor v1.1",
            model_path="/models/invoice_v1.1.pt",
            task_id=sample_training_task.task_id,
            dataset_id=sample_dataset.dataset_id,
            training_config={"epochs": 100, "batch_size": 16},
            trained_at=datetime.now(timezone.utc),
        )

        assert model is not None
        assert model.task_id == sample_training_task.task_id
        assert model.dataset_id == sample_dataset.dataset_id
        assert model.training_config["epochs"] == 100
class TestModelVersionRead:
    """Integration tests for single-row and paginated retrieval.

    Each test requests the ``patched_session`` fixture (defined outside this
    file; presumably it points the repository at the test database -- confirm
    in conftest).
    """

    def test_get_model_version_by_id(self, patched_session, sample_model_version):
        """``get`` returns the row whose ID matches the fixture's version."""
        repo = ModelVersionRepository()

        model = repo.get(str(sample_model_version.version_id))

        assert model is not None
        assert model.version_id == sample_model_version.version_id

    def test_get_nonexistent_model_version(self, patched_session):
        """``get`` returns ``None`` for a random, never-stored UUID."""
        repo = ModelVersionRepository()

        model = repo.get(str(uuid4()))

        assert model is None

    def test_get_paginated_model_versions(self, patched_session):
        """``get_paginated`` reports the full total but honours ``limit``."""
        repo = ModelVersionRepository()

        # Create multiple versions
        for i in range(5):
            repo.create(
                version=f"1.{i}.0",
                name=f"Model v1.{i}",
                model_path=f"/models/model_v1.{i}.pt",
            )

        models, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(models) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """``get_paginated(status=...)`` returns only rows with that status.

        Both the reported totals and the returned rows are checked, so a
        filter that counts correctly but returns the wrong rows is caught.
        """
        repo = ModelVersionRepository()

        # Create one model that gets activated and one that stays inactive.
        active = repo.create(
            version="1.0.0", name="Active Model", model_path="/models/active.pt"
        )
        repo.activate(str(active.version_id))
        inactive = repo.create(
            version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt"
        )

        active_models, active_total = repo.get_paginated(status="active")
        inactive_models, inactive_total = repo.get_paginated(status="inactive")

        assert active_total == 1
        assert inactive_total == 1
        # The totals alone do not prove the right rows were selected.
        assert [m.version_id for m in active_models] == [active.version_id]
        assert [m.version_id for m in inactive_models] == [inactive.version_id]
class TestModelVersionActivation:
    """Integration tests covering activate / deactivate / get_active."""

    def test_activate_model_version(self, patched_session, sample_model_version):
        """Activating a version flips its flags and stamps ``activated_at``."""
        version_id = str(sample_model_version.version_id)

        activated = ModelVersionRepository().activate(version_id)

        assert activated is not None
        assert activated.is_active is True
        assert activated.status == "active"
        assert activated.activated_at is not None

    def test_activate_deactivates_others(self, patched_session):
        """Activating a second version leaves the first one inactive."""
        repo = ModelVersionRepository()

        # First model: create, then activate.
        first = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
        first_id = str(first.version_id)
        repo.activate(first_id)

        # Second model: create, then activate.
        second = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
        second_id = str(second.version_id)
        repo.activate(second_id)

        # Re-read the first model: it must have been switched off.
        first_reloaded = repo.get(first_id)
        assert first_reloaded.is_active is False
        assert first_reloaded.status == "inactive"

        # Re-read the second model: it is the active one now.
        second_reloaded = repo.get(second_id)
        assert second_reloaded.is_active is True

    def test_get_active_model(self, patched_session, sample_model_version):
        """``get_active`` is ``None`` before activation and the model after."""
        repo = ModelVersionRepository()

        # Nothing has been activated yet.
        before = repo.get_active()
        assert before is None

        repo.activate(str(sample_model_version.version_id))

        # The fixture's version should now be reported as active.
        after = repo.get_active()
        assert after is not None
        assert after.version_id == sample_model_version.version_id

    def test_deactivate_model_version(self, patched_session, sample_model_version):
        """Deactivating an active version returns it in the inactive state."""
        repo = ModelVersionRepository()
        version_id = str(sample_model_version.version_id)

        # Activate first so there is something to deactivate.
        repo.activate(version_id)
        deactivated = repo.deactivate(version_id)

        assert deactivated is not None
        assert deactivated.is_active is False
        assert deactivated.status == "inactive"
class TestModelVersionUpdate:
    """Integration tests for ``ModelVersionRepository.update``."""

    def test_update_model_metadata(self, patched_session, sample_model_version):
        """Name and description changes are visible on the returned model."""
        version_id = str(sample_model_version.version_id)
        changes = {
            "name": "Updated Model Name",
            "description": "Updated description",
        }

        updated = ModelVersionRepository().update(version_id, **changes)

        assert updated is not None
        assert updated.name == "Updated Model Name"
        assert updated.description == "Updated description"

    def test_update_model_status(self, patched_session, sample_model_version):
        """A status passed to ``update`` is visible on the returned model."""
        version_id = str(sample_model_version.version_id)

        updated = ModelVersionRepository().update(version_id, status="deprecated")

        assert updated is not None
        assert updated.status == "deprecated"

    def test_update_nonexistent_model(self, patched_session):
        """Updating a random, never-stored UUID yields ``None``."""
        missing_id = str(uuid4())

        updated = ModelVersionRepository().update(missing_id, name="New Name")

        assert updated is None
class TestModelVersionArchive:
    """Integration tests for ``ModelVersionRepository.archive``."""

    def test_archive_model_version(self, patched_session, sample_model_version):
        """Archiving the (not activated) fixture version marks it archived."""
        version_id = str(sample_model_version.version_id)

        archived = ModelVersionRepository().archive(version_id)

        assert archived is not None
        assert archived.status == "archived"

    def test_cannot_archive_active_model(self, patched_session, sample_model_version):
        """``archive`` returns ``None`` for an active version and leaves it active."""
        repo = ModelVersionRepository()
        version_id = str(sample_model_version.version_id)

        # Make the version active, then attempt the archive.
        repo.activate(version_id)
        outcome = repo.archive(version_id)
        assert outcome is None

        # The stored row must still report the active status.
        reloaded = repo.get(version_id)
        assert reloaded.status == "active"
class TestModelVersionDelete:
    """Integration tests for ``ModelVersionRepository.delete``."""

    def test_delete_inactive_model(self, patched_session, sample_model_version):
        """Deleting the (not activated) fixture version returns ``True`` and removes it."""
        repo = ModelVersionRepository()
        version_id = str(sample_model_version.version_id)

        deleted = repo.delete(version_id)
        assert deleted is True

        # A follow-up lookup must come back empty.
        remaining = repo.get(version_id)
        assert remaining is None

    def test_cannot_delete_active_model(self, patched_session, sample_model_version):
        """``delete`` returns ``False`` for an active version and keeps the row."""
        repo = ModelVersionRepository()
        version_id = str(sample_model_version.version_id)

        # Make the version active, then attempt the delete.
        repo.activate(version_id)
        deleted = repo.delete(version_id)
        assert deleted is False

        # The row must still be retrievable.
        remaining = repo.get(version_id)
        assert remaining is not None

    def test_delete_nonexistent_model(self, patched_session):
        """Deleting a random, never-stored UUID returns ``False``."""
        missing_id = str(uuid4())

        deleted = ModelVersionRepository().delete(missing_id)

        assert deleted is False
class TestOnlyOneActiveModel:
    """Checks that at most one model version is active at a time."""

    def test_single_active_model_constraint(self, patched_session):
        """After activating three versions in turn, only the last is active."""
        repo = ModelVersionRepository()

        # Build three versions up front, in creation order.
        created = [
            repo.create(
                version=f"1.{i}.0",
                name=f"Model {i}",
                model_path=f"/models/model_{i}.pt",
            )
            for i in range(3)
        ]

        # Activate them one after another.
        for candidate in created:
            repo.activate(str(candidate.version_id))

        # Exactly one row should be returned by the "active" filter...
        active_models, _ = repo.get_paginated(status="active")
        assert len(active_models) == 1

        # ...and it must be the version that was activated last.
        last_activated = created[-1]
        assert active_models[0].version_id == last_activated.version_id