311 lines
10 KiB
Python
311 lines
10 KiB
Python
"""
Model Version Repository Integration Tests

Tests ModelVersionRepository with real database operations.
"""
|
|
|
|
from datetime import datetime, timezone
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from backend.data.repositories.model_version_repository import ModelVersionRepository
|
|
|
|
|
|
class TestModelVersionCreate:
    """Tests for model version creation."""

    def test_create_model_version(self, patched_session):
        """A newly created version persists its inputs and starts inactive."""
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.0.0",
            name="Invoice Extractor v1",
            model_path="/models/invoice_v1.pt",
            description="Initial production model",
            metrics_mAP=0.92,
            metrics_precision=0.89,
            metrics_recall=0.85,
            document_count=1000,
            file_size=50000000,
        )

        assert model is not None
        assert model.version == "1.0.0"
        assert model.name == "Invoice Extractor v1"
        assert model.model_path == "/models/invoice_v1.pt"
        assert model.description == "Initial production model"
        # Metrics round-trip through a floating-point column; exact equality
        # on floats is fragile (e.g. single-precision storage), so compare
        # with a tolerance.
        assert model.metrics_mAP == pytest.approx(0.92)
        assert model.metrics_precision == pytest.approx(0.89)
        assert model.metrics_recall == pytest.approx(0.85)
        # Creation must not implicitly activate the version.
        assert model.is_active is False
        assert model.status == "inactive"

    def test_create_model_version_with_training_info(
        self, patched_session, sample_training_task, sample_dataset
    ):
        """Test creating model version linked to training task and dataset."""
        repo = ModelVersionRepository()
        training_config = {"epochs": 100, "batch_size": 16}

        model = repo.create(
            version="1.1.0",
            name="Invoice Extractor v1.1",
            model_path="/models/invoice_v1.1.pt",
            task_id=sample_training_task.task_id,
            dataset_id=sample_dataset.dataset_id,
            training_config=training_config,
            trained_at=datetime.now(timezone.utc),
        )

        assert model is not None
        assert model.task_id == sample_training_task.task_id
        assert model.dataset_id == sample_dataset.dataset_id
        # Compare the whole config so a dropped or altered key is caught,
        # not just "epochs".
        assert model.training_config == training_config
|
|
|
|
|
|
class TestModelVersionRead:
    """Tests for model version retrieval (single lookup and paginated listing)."""

    def test_get_model_version_by_id(self, patched_session, sample_model_version):
        """Fetching by ID returns the row with that ID."""
        repo = ModelVersionRepository()

        model = repo.get(str(sample_model_version.version_id))

        assert model is not None
        assert model.version_id == sample_model_version.version_id

    def test_get_nonexistent_model_version(self, patched_session):
        """Fetching an unknown (random UUID) ID returns ``None``."""
        repo = ModelVersionRepository()

        model = repo.get(str(uuid4()))

        assert model is None

    def test_get_paginated_model_versions(self, patched_session):
        """``limit``/``offset`` bound the page while ``total`` counts every row."""
        repo = ModelVersionRepository()

        # Create multiple versions
        for i in range(5):
            repo.create(
                version=f"1.{i}.0",
                name=f"Model v1.{i}",
                model_path=f"/models/model_v1.{i}.pt",
            )

        first_page, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(first_page) == 2

        # With 5 rows and a page size of 2, offset 4 leaves only the remainder.
        last_page, last_total = repo.get_paginated(limit=2, offset=4)

        assert last_total == 5
        assert len(last_page) == 1

    def test_get_paginated_with_status_filter(self, patched_session):
        """Filtering by status returns exactly the rows in that status."""
        repo = ModelVersionRepository()

        # Create active and inactive models
        active = repo.create(
            version="1.0.0", name="Active Model", model_path="/models/active.pt"
        )
        repo.activate(str(active.version_id))

        inactive = repo.create(
            version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt"
        )

        active_models, active_total = repo.get_paginated(status="active")
        inactive_models, inactive_total = repo.get_paginated(status="inactive")

        # Counts alone would pass even if the filter returned the wrong rows;
        # check identities too.
        assert active_total == 1
        assert [m.version_id for m in active_models] == [active.version_id]
        assert inactive_total == 1
        assert [m.version_id for m in inactive_models] == [inactive.version_id]
|
|
|
|
|
|
class TestModelVersionActivation:
    """Activation, deactivation and active-model lookup behaviour."""

    def test_activate_model_version(self, patched_session, sample_model_version):
        """Activating a version marks it active and records ``activated_at``."""
        repo = ModelVersionRepository()
        target_id = str(sample_model_version.version_id)

        activated = repo.activate(target_id)

        assert activated is not None
        assert activated.is_active is True
        assert activated.status == "active"
        assert activated.activated_at is not None

    def test_activate_deactivates_others(self, patched_session):
        """Activating a second version demotes the previously active one."""
        repo = ModelVersionRepository()

        first = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
        repo.activate(str(first.version_id))

        second = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
        repo.activate(str(second.version_id))

        # Reload both rows so the assertions see persisted state.
        first_reloaded = repo.get(str(first.version_id))
        assert first_reloaded.is_active is False
        assert first_reloaded.status == "inactive"

        second_reloaded = repo.get(str(second.version_id))
        assert second_reloaded.is_active is True

    def test_get_active_model(self, patched_session, sample_model_version):
        """``get_active`` is empty until a version is activated, then returns it."""
        repo = ModelVersionRepository()

        assert repo.get_active() is None

        repo.activate(str(sample_model_version.version_id))

        current = repo.get_active()
        assert current is not None
        assert current.version_id == sample_model_version.version_id

    def test_deactivate_model_version(self, patched_session, sample_model_version):
        """Deactivating a previously active version returns it to inactive."""
        repo = ModelVersionRepository()
        target_id = str(sample_model_version.version_id)

        repo.activate(target_id)
        deactivated = repo.deactivate(target_id)

        assert deactivated is not None
        assert deactivated.is_active is False
        assert deactivated.status == "inactive"
|
|
|
|
|
|
class TestModelVersionUpdate:
    """Partial updates through ``ModelVersionRepository.update``."""

    def test_update_model_metadata(self, patched_session, sample_model_version):
        """Name and description can be changed in a single update call."""
        repo = ModelVersionRepository()
        new_name = "Updated Model Name"
        new_description = "Updated description"

        updated = repo.update(
            str(sample_model_version.version_id),
            name=new_name,
            description=new_description,
        )

        assert updated is not None
        assert updated.name == new_name
        assert updated.description == new_description

    def test_update_model_status(self, patched_session, sample_model_version):
        """The status field can be set directly via ``update``."""
        repo = ModelVersionRepository()
        target_id = str(sample_model_version.version_id)

        updated = repo.update(target_id, status="deprecated")

        assert updated is not None
        assert updated.status == "deprecated"

    def test_update_nonexistent_model(self, patched_session):
        """Updating an unknown ID yields ``None``."""
        repo = ModelVersionRepository()
        missing_id = str(uuid4())

        assert repo.update(missing_id, name="New Name") is None
|
|
|
|
|
|
class TestModelVersionArchive:
    """Archiving rules for model versions."""

    def test_archive_model_version(self, patched_session, sample_model_version):
        """An inactive version can be archived."""
        repo = ModelVersionRepository()
        target_id = str(sample_model_version.version_id)

        archived = repo.archive(target_id)

        assert archived is not None
        assert archived.status == "archived"

    def test_cannot_archive_active_model(self, patched_session, sample_model_version):
        """Archiving is refused for the active version, which stays active."""
        repo = ModelVersionRepository()
        target_id = str(sample_model_version.version_id)

        repo.activate(target_id)

        # The refusal is signalled by a ``None`` return.
        assert repo.archive(target_id) is None

        reloaded = repo.get(target_id)
        assert reloaded.status == "active"
|
|
|
|
|
|
class TestModelVersionDelete:
    """Deletion rules for model versions."""

    def test_delete_inactive_model(self, patched_session, sample_model_version):
        """Deleting an inactive version succeeds and removes the row."""
        repo = ModelVersionRepository()
        target_id = str(sample_model_version.version_id)

        deleted = repo.delete(target_id)
        assert deleted is True

        assert repo.get(target_id) is None

    def test_cannot_delete_active_model(self, patched_session, sample_model_version):
        """Deleting the active version is refused and the row survives."""
        repo = ModelVersionRepository()
        target_id = str(sample_model_version.version_id)

        repo.activate(target_id)

        deleted = repo.delete(target_id)
        assert deleted is False

        assert repo.get(target_id) is not None

    def test_delete_nonexistent_model(self, patched_session):
        """Deleting an unknown ID reports ``False``."""
        repo = ModelVersionRepository()
        missing_id = str(uuid4())

        assert repo.delete(missing_id) is False
|
|
|
|
|
|
class TestOnlyOneActiveModel:
    """Tests to verify only one model can be active at a time."""

    def test_single_active_model_constraint(self, patched_session):
        """After activating several versions in turn, only the last stays active."""
        repo = ModelVersionRepository()

        # Create multiple models
        models = [
            repo.create(
                version=f"1.{i}.0",
                name=f"Model {i}",
                model_path=f"/models/model_{i}.pt",
            )
            for i in range(3)
        ]

        # Activate each model in sequence
        for model in models:
            repo.activate(str(model.version_id))

        active_models, active_total = repo.get_paginated(status="active")

        # Assert on the filtered total, not only the page length: a page size
        # smaller than the number of active rows would otherwise hide a
        # violation of the single-active constraint.
        assert active_total == 1
        assert len(active_models) == 1

        # Verify it's the last one activated
        assert active_models[0].version_id == models[-1].version_id

        # The dedicated lookup must agree with the filtered listing.
        current = repo.get_active()
        assert current is not None
        assert current.version_id == models[-1].version_id
|