Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -0,0 +1,310 @@
"""
Model Version Repository Integration Tests
Tests ModelVersionRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionCreate:
"""Tests for model version creation."""
def test_create_model_version(self, patched_session):
"""Test creating a model version."""
repo = ModelVersionRepository()
model = repo.create(
version="1.0.0",
name="Invoice Extractor v1",
model_path="/models/invoice_v1.pt",
description="Initial production model",
metrics_mAP=0.92,
metrics_precision=0.89,
metrics_recall=0.85,
document_count=1000,
file_size=50000000,
)
assert model is not None
assert model.version == "1.0.0"
assert model.name == "Invoice Extractor v1"
assert model.model_path == "/models/invoice_v1.pt"
assert model.metrics_mAP == 0.92
assert model.is_active is False
assert model.status == "inactive"
def test_create_model_version_with_training_info(
self, patched_session, sample_training_task, sample_dataset
):
"""Test creating model version linked to training task and dataset."""
repo = ModelVersionRepository()
model = repo.create(
version="1.1.0",
name="Invoice Extractor v1.1",
model_path="/models/invoice_v1.1.pt",
task_id=sample_training_task.task_id,
dataset_id=sample_dataset.dataset_id,
training_config={"epochs": 100, "batch_size": 16},
trained_at=datetime.now(timezone.utc),
)
assert model is not None
assert model.task_id == sample_training_task.task_id
assert model.dataset_id == sample_dataset.dataset_id
assert model.training_config["epochs"] == 100
class TestModelVersionRead:
"""Tests for model version retrieval."""
def test_get_model_version_by_id(self, patched_session, sample_model_version):
"""Test getting model version by ID."""
repo = ModelVersionRepository()
model = repo.get(str(sample_model_version.version_id))
assert model is not None
assert model.version_id == sample_model_version.version_id
def test_get_nonexistent_model_version(self, patched_session):
"""Test getting model version that doesn't exist."""
repo = ModelVersionRepository()
model = repo.get(str(uuid4()))
assert model is None
def test_get_paginated_model_versions(self, patched_session):
"""Test paginated model version listing."""
repo = ModelVersionRepository()
# Create multiple versions
for i in range(5):
repo.create(
version=f"1.{i}.0",
name=f"Model v1.{i}",
model_path=f"/models/model_v1.{i}.pt",
)
models, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(models) == 2
def test_get_paginated_with_status_filter(self, patched_session):
"""Test filtering model versions by status."""
repo = ModelVersionRepository()
# Create active and inactive models
m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt")
repo.activate(str(m1.version_id))
repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt")
active_models, active_total = repo.get_paginated(status="active")
inactive_models, inactive_total = repo.get_paginated(status="inactive")
assert active_total == 1
assert inactive_total == 1
class TestModelVersionActivation:
"""Tests for model version activation."""
def test_activate_model_version(self, patched_session, sample_model_version):
"""Test activating a model version."""
repo = ModelVersionRepository()
model = repo.activate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is True
assert model.status == "active"
assert model.activated_at is not None
def test_activate_deactivates_others(self, patched_session):
"""Test that activating one version deactivates others."""
repo = ModelVersionRepository()
# Create and activate first model
m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
repo.activate(str(m1.version_id))
# Create and activate second model
m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
repo.activate(str(m2.version_id))
# Check first model is now inactive
m1_after = repo.get(str(m1.version_id))
assert m1_after.is_active is False
assert m1_after.status == "inactive"
# Check second model is active
m2_after = repo.get(str(m2.version_id))
assert m2_after.is_active is True
def test_get_active_model(self, patched_session, sample_model_version):
"""Test getting the currently active model."""
repo = ModelVersionRepository()
# Initially no active model
active = repo.get_active()
assert active is None
# Activate model
repo.activate(str(sample_model_version.version_id))
# Now should return active model
active = repo.get_active()
assert active is not None
assert active.version_id == sample_model_version.version_id
def test_deactivate_model_version(self, patched_session, sample_model_version):
"""Test deactivating a model version."""
repo = ModelVersionRepository()
# First activate
repo.activate(str(sample_model_version.version_id))
# Then deactivate
model = repo.deactivate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is False
assert model.status == "inactive"
class TestModelVersionUpdate:
"""Tests for model version updates."""
def test_update_model_metadata(self, patched_session, sample_model_version):
"""Test updating model version metadata."""
repo = ModelVersionRepository()
model = repo.update(
str(sample_model_version.version_id),
name="Updated Model Name",
description="Updated description",
)
assert model is not None
assert model.name == "Updated Model Name"
assert model.description == "Updated description"
def test_update_model_status(self, patched_session, sample_model_version):
"""Test updating model version status."""
repo = ModelVersionRepository()
model = repo.update(str(sample_model_version.version_id), status="deprecated")
assert model is not None
assert model.status == "deprecated"
def test_update_nonexistent_model(self, patched_session):
"""Test updating model that doesn't exist."""
repo = ModelVersionRepository()
model = repo.update(str(uuid4()), name="New Name")
assert model is None
class TestModelVersionArchive:
"""Tests for model version archiving."""
def test_archive_model_version(self, patched_session, sample_model_version):
"""Test archiving an inactive model version."""
repo = ModelVersionRepository()
model = repo.archive(str(sample_model_version.version_id))
assert model is not None
assert model.status == "archived"
def test_cannot_archive_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be archived."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to archive
model = repo.archive(str(sample_model_version.version_id))
assert model is None
# Verify model is still active
current = repo.get(str(sample_model_version.version_id))
assert current.status == "active"
class TestModelVersionDelete:
"""Tests for model version deletion."""
def test_delete_inactive_model(self, patched_session, sample_model_version):
"""Test deleting an inactive model version."""
repo = ModelVersionRepository()
result = repo.delete(str(sample_model_version.version_id))
assert result is True
model = repo.get(str(sample_model_version.version_id))
assert model is None
def test_cannot_delete_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be deleted."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to delete
result = repo.delete(str(sample_model_version.version_id))
assert result is False
# Verify model still exists
model = repo.get(str(sample_model_version.version_id))
assert model is not None
def test_delete_nonexistent_model(self, patched_session):
"""Test deleting model that doesn't exist."""
repo = ModelVersionRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestOnlyOneActiveModel:
"""Tests to verify only one model can be active at a time."""
def test_single_active_model_constraint(self, patched_session):
"""Test that only one model can be active at any time."""
repo = ModelVersionRepository()
# Create multiple models
models = []
for i in range(3):
m = repo.create(
version=f"1.{i}.0",
name=f"Model {i}",
model_path=f"/models/model_{i}.pt",
)
models.append(m)
# Activate each model in sequence
for model in models:
repo.activate(str(model.version_id))
# Count active models
all_models, _ = repo.get_paginated(status="active")
assert len(all_models) == 1
# Verify it's the last one activated
assert all_models[0].version_id == models[-1].version_id