commit 33ada0350d (parent d2489a97d4)
Author: Yaojia Wang
Date: 2026-01-30 00:44:21 +01:00
79 changed files with 9737 additions and 297 deletions

View File

@@ -0,0 +1,261 @@
"""
Tests for augmentation API routes.
TDD Phase 5 (RED): write tests first, then implement until they pass.
"""
import pytest
from fastapi.testclient import TestClient
class TestAugmentationTypesEndpoint:
"""Tests for GET /admin/augmentation/types endpoint."""
def test_list_augmentation_types(
self, admin_client: TestClient, admin_token: str
) -> None:
"""Test listing available augmentation types."""
response = admin_client.get(
"/api/v1/admin/augmentation/types",
headers={"X-Admin-Token": admin_token},
)
assert response.status_code == 200
data = response.json()
assert "augmentation_types" in data
assert len(data["augmentation_types"]) == 12
# Check structure
aug_type = data["augmentation_types"][0]
assert "name" in aug_type
assert "description" in aug_type
assert "affects_geometry" in aug_type
assert "stage" in aug_type
def test_list_augmentation_types_unauthorized(
self, admin_client: TestClient
) -> None:
"""Test that unauthorized request is rejected."""
response = admin_client.get("/api/v1/admin/augmentation/types")
assert response.status_code == 401
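# Illustrative sketch of the shape each augmentation-type entry must have to
# satisfy the assertions above; the concrete registry of 12 types lives
# elsewhere in this commit, so the field semantics here are informed guesses.
from typing import TypedDict

class AugmentationTypeInfoSketch(TypedDict):
    name: str                # e.g. "gaussian_noise"
    description: str         # human-readable summary for the admin UI
    affects_geometry: bool   # True if bounding boxes must be transformed too
    stage: str               # pipeline stage the augmentation runs in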
class TestAugmentationPresetsEndpoint:
"""Tests for GET /admin/augmentation/presets endpoint."""
def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
"""Test listing available presets."""
response = admin_client.get(
"/api/v1/admin/augmentation/presets",
headers={"X-Admin-Token": admin_token},
)
assert response.status_code == 200
data = response.json()
assert "presets" in data
assert len(data["presets"]) >= 4
# Check expected presets exist
preset_names = [p["name"] for p in data["presets"]]
assert "conservative" in preset_names
assert "moderate" in preset_names
assert "aggressive" in preset_names
assert "scanned_document" in preset_names
class TestAugmentationPreviewEndpoint:
"""Tests for POST /admin/augmentation/preview/{document_id} endpoint."""
def test_preview_augmentation(
self,
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
) -> None:
"""Test previewing augmentation on a document."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "gaussian_noise",
"params": {"std": 15},
},
)
assert response.status_code == 200
data = response.json()
assert "preview_url" in data
assert "original_url" in data
assert "applied_params" in data
def test_preview_invalid_augmentation_type(
self,
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
) -> None:
"""Test that invalid augmentation type returns error."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "nonexistent",
"params": {},
},
)
assert response.status_code == 400
def test_preview_nonexistent_document(
self,
admin_client: TestClient,
admin_token: str,
) -> None:
"""Test that nonexistent document returns 404."""
response = admin_client.post(
"/api/v1/admin/augmentation/preview/00000000-0000-0000-0000-000000000000",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "gaussian_noise",
"params": {},
},
)
assert response.status_code == 404
class TestAugmentationPreviewConfigEndpoint:
"""Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""
def test_preview_config(
self,
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
) -> None:
"""Test previewing full config on a document."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"gaussian_noise": {"enabled": True, "probability": 1.0},
"lighting_variation": {"enabled": True, "probability": 1.0},
"preserve_bboxes": True,
"seed": 42,
},
)
assert response.status_code == 200
data = response.json()
assert "preview_url" in data
assert "original_url" in data
class TestAugmentationBatchEndpoint:
"""Tests for POST /admin/augmentation/batch endpoint."""
def test_create_augmented_dataset(
self,
admin_client: TestClient,
admin_token: str,
sample_dataset_id: str,
) -> None:
"""Test creating augmented dataset."""
response = admin_client.post(
"/api/v1/admin/augmentation/batch",
headers={"X-Admin-Token": admin_token},
json={
"dataset_id": sample_dataset_id,
"config": {
"gaussian_noise": {"enabled": True, "probability": 0.5},
"preserve_bboxes": True,
},
"output_name": "test_augmented_dataset",
"multiplier": 2,
},
)
assert response.status_code == 200
data = response.json()
assert "task_id" in data
assert "status" in data
assert "estimated_images" in data
def test_create_augmented_dataset_invalid_multiplier(
self,
admin_client: TestClient,
admin_token: str,
sample_dataset_id: str,
) -> None:
"""Test that invalid multiplier is rejected."""
response = admin_client.post(
"/api/v1/admin/augmentation/batch",
headers={"X-Admin-Token": admin_token},
json={
"dataset_id": sample_dataset_id,
"config": {},
"output_name": "test",
"multiplier": 100, # Too high
},
)
assert response.status_code == 422 # Validation error
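# Illustrative request-model sketch for the 422 expectation above: a Pydantic
# bound on multiplier makes FastAPI reject out-of-range values before the
# handler runs. Field names mirror the JSON payloads in these tests; the upper
# bound of 10 is an assumption (the tests only show that 100 is rejected).
from pydantic import BaseModel, Field

class BatchAugmentRequestSketch(BaseModel):
    dataset_id: str
    config: dict
    output_name: str
    multiplier: int = Field(default=1, ge=1, le=10)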
class TestAugmentedDatasetsListEndpoint:
"""Tests for GET /admin/augmentation/datasets endpoint."""
def test_list_augmented_datasets(
self, admin_client: TestClient, admin_token: str
) -> None:
"""Test listing augmented datasets."""
response = admin_client.get(
"/api/v1/admin/augmentation/datasets",
headers={"X-Admin-Token": admin_token},
)
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "limit" in data
assert "offset" in data
assert "datasets" in data
assert isinstance(data["datasets"], list)
def test_list_augmented_datasets_pagination(
self, admin_client: TestClient, admin_token: str
) -> None:
"""Test pagination parameters."""
response = admin_client.get(
"/api/v1/admin/augmentation/datasets",
headers={"X-Admin-Token": admin_token},
params={"limit": 5, "offset": 0},
)
assert response.status_code == 200
data = response.json()
assert data["limit"] == 5
assert data["offset"] == 0
# Fixtures for tests
@pytest.fixture
def sample_document_id() -> str:
"""Provide a sample document ID for testing."""
# This would need to be created in test setup
return "test-document-id"
@pytest.fixture
def sample_dataset_id() -> str:
"""Provide a sample dataset ID for testing."""
# This would need to be created in test setup
return "test-dataset-id"

View File

@@ -329,3 +329,414 @@ class TestDatasetBuilder:
results.append([(d["document_id"], d["split"]) for d in docs])
assert results[0] == results[1]
class TestAssignSplitsByGroup:
"""Tests for _assign_splits_by_group method with group_key logic."""
def _make_mock_doc(self, doc_id, group_key=None):
"""Create a mock AdminDocument with document_id and group_key."""
doc = MagicMock(spec=AdminDocument)
doc.document_id = doc_id
doc.group_key = group_key
doc.page_count = 1
return doc
def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
"""Documents with unique group_key are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 3 documents, each with unique group_key
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-B"),
self._make_mock_doc(uuid4(), group_key="group-C"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
# With 3 groups: 70% train = 2, 20% val = 1 (at least 1)
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 1
assert val_count >= 1 # Ensure val is not empty
def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key=None),
self._make_mock_doc(uuid4(), group_key=""),
self._make_mock_doc(uuid4(), group_key=None),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
# Each null/empty group_key doc is independent, distributed across splits
# With 3 docs: ensure at least 1 in train and 1 in val
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 1
assert val_count >= 1
def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
"""Documents with same group_key should be assigned to the same split."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 6 documents in 2 groups
docs = [
self._make_mock_doc(uuid4(), group_key="supplier-A"),
self._make_mock_doc(uuid4(), group_key="supplier-A"),
self._make_mock_doc(uuid4(), group_key="supplier-A"),
self._make_mock_doc(uuid4(), group_key="supplier-B"),
self._make_mock_doc(uuid4(), group_key="supplier-B"),
self._make_mock_doc(uuid4(), group_key="supplier-B"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
# All docs in supplier-A should have same split
splits_a = [result[str(d.document_id)] for d in docs[:3]]
assert len(set(splits_a)) == 1, "All docs in supplier-A should be in same split"
# All docs in supplier-B should have same split
splits_b = [result[str(d.document_id)] for d in docs[3:]]
assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"
def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
"""Multi-doc groups should be split according to train/val/test ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 10 groups with 2 docs each
docs = []
for i in range(10):
group_key = f"group-{i}"
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
# Count groups per split
group_splits = {}
for doc in docs:
split = result[str(doc.document_id)]
if doc.group_key not in group_splits:
group_splits[doc.group_key] = split
else:
# Verify same group has same split
assert group_splits[doc.group_key] == split
split_counts = {"train": 0, "val": 0, "test": 0}
for split in group_splits.values():
split_counts[split] += 1
# With 10 groups, 70/20/10 -> ~7 train, ~2 val, ~1 test
assert split_counts["train"] >= 6
assert split_counts["train"] <= 8
assert split_counts["val"] >= 1
assert split_counts["val"] <= 3
def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
"""Mix of single-doc and multi-doc groups should be handled correctly."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
# Single-doc groups
self._make_mock_doc(uuid4(), group_key="single-1"),
self._make_mock_doc(uuid4(), group_key="single-2"),
self._make_mock_doc(uuid4(), group_key=None),
# Multi-doc groups
self._make_mock_doc(uuid4(), group_key="multi-A"),
self._make_mock_doc(uuid4(), group_key="multi-A"),
self._make_mock_doc(uuid4(), group_key="multi-B"),
self._make_mock_doc(uuid4(), group_key="multi-B"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
# All groups are shuffled and distributed
# Ensure at least 1 in train and 1 in val
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 1
assert val_count >= 1
# Multi-doc groups stay together
assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]
def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
"""Same seed should produce same split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-B"),
self._make_mock_doc(uuid4(), group_key="group-B"),
self._make_mock_doc(uuid4(), group_key="group-C"),
self._make_mock_doc(uuid4(), group_key="group-C"),
]
result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
assert result1 == result2
def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
"""Different seeds should potentially produce different split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# Many groups to increase chance of different results
docs = []
for i in range(20):
group_key = f"group-{i}"
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=1)
result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=999)
# Results should be different (very likely with 20 groups)
assert result1 != result2
def test_all_docs_assigned(self, tmp_path, mock_admin_db):
"""Every document should be assigned a split."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key=None),
self._make_mock_doc(uuid4(), group_key="single"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
assert len(result) == len(docs)
for doc in docs:
assert str(doc.document_id) in result
assert result[str(doc.document_id)] in ["train", "val", "test"]
def test_empty_documents_list(self, tmp_path, mock_admin_db):
"""Empty document list should return empty result."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)
assert result == {}
def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
"""When all groups have multiple docs, splits should follow ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 5 groups with 3 docs each
docs = []
for i in range(5):
group_key = f"group-{i}"
for _ in range(3):
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
# Group splits
group_splits = {}
for doc in docs:
if doc.group_key not in group_splits:
group_splits[doc.group_key] = result[str(doc.document_id)]
split_counts = {"train": 0, "val": 0, "test": 0}
for split in group_splits.values():
split_counts[split] += 1
# With 5 groups, 60/20/20 -> 3 train, 1 val, 1 test
assert split_counts["train"] >= 2
assert split_counts["train"] <= 4
def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
"""When all groups have single doc, they are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key="unique-1"),
self._make_mock_doc(uuid4(), group_key="unique-2"),
self._make_mock_doc(uuid4(), group_key="unique-3"),
self._make_mock_doc(uuid4(), group_key=None),
self._make_mock_doc(uuid4(), group_key=""),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
# With 5 groups: 60% train = 3, 20% val = 1 (at least 1)
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 2
assert val_count >= 1 # Ensure val is not empty
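# Illustrative sketch of the group-aware split logic these tests pin down; the
# real DatasetBuilder._assign_splits_by_group may differ in detail. It assumes
# only that each doc exposes .document_id and .group_key, and that a null or
# empty group_key makes the document its own singleton group.
import random

def _assign_splits_by_group_sketch(docs, train_ratio, val_ratio, seed):
    # Bucket documents by group_key; null/empty keys become singleton groups.
    groups = {}
    for doc in docs:
        key = doc.group_key or f"__solo__{doc.document_id}"
        groups.setdefault(key, []).append(doc)
    # Sort for a stable base order, then shuffle with the caller's seed so the
    # same seed always reproduces the same assignment.
    keys = sorted(groups)
    random.Random(seed).shuffle(keys)
    n = len(keys)
    n_train = round(n * train_ratio)
    n_val = max(1, round(n * val_ratio)) if n > 1 else 0  # keep val non-empty
    result = {}
    for i, key in enumerate(keys):
        split = "train" if i < n_train else ("val" if i < n_train + n_val else "test")
        for doc in groups[key]:  # the whole group lands in one split
            result[str(doc.document_id)] = split
    return result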
class TestBuildDatasetWithGroupKey:
"""Integration tests for build_dataset with group_key logic."""
@pytest.fixture
def grouped_documents(self, tmp_path):
"""Create documents with various group_key configurations."""
docs = []
# Create 4 groups: 2 multi-doc groups + 2 single-doc groups
group_configs = [
("supplier-A", 3), # Multi-doc group: 3 docs
("supplier-B", 2), # Multi-doc group: 2 docs
("unique-1", 1), # Single-doc group
(None, 1), # Null group_key
]
for group_key, count in group_configs:
for _ in range(count):
doc_id = uuid4()
# Create image files
doc_dir = tmp_path / "admin_images" / str(doc_id)
doc_dir.mkdir(parents=True)
for page in range(1, 3):
(doc_dir / f"page_{page}.png").write_bytes(b"fake-png")
# Create mock document
doc = MagicMock(spec=AdminDocument)
doc.document_id = doc_id
doc.filename = f"{doc_id}.pdf"
doc.page_count = 2
doc.group_key = group_key
doc.file_path = str(doc_dir)
docs.append(doc)
return tmp_path, docs
@pytest.fixture
def grouped_annotations(self, grouped_documents):
"""Create annotations for grouped documents."""
tmp_path, docs = grouped_documents
annotations = {}
for doc in docs:
doc_anns = []
for page in range(1, 3):
ann = MagicMock(spec=AdminAnnotation)
ann.document_id = doc.document_id
ann.page_number = page
ann.class_id = 0
ann.class_name = "invoice_number"
ann.x_center = 0.5
ann.y_center = 0.3
ann.width = 0.2
ann.height = 0.05
doc_anns.append(ann)
annotations[str(doc.document_id)] = doc_anns
return annotations
def test_build_respects_group_key_splits(
self, grouped_documents, grouped_annotations, mock_admin_db
):
"""build_dataset should use group_key for split assignment."""
from inference.web.services.dataset_builder import DatasetBuilder
tmp_path, docs = grouped_documents
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = docs
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
grouped_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in docs],
train_ratio=0.5,
val_ratio=0.5,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
# Get the document splits from add_dataset_documents call
call_args = mock_admin_db.add_dataset_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
# Build mapping of doc_id -> split
doc_split_map = {d["document_id"]: d["split"] for d in docs_added}
# Verify all docs are assigned a valid split
for doc_id in doc_split_map:
assert doc_split_map[doc_id] in ("train", "val", "test")
# Verify multi-doc groups stay together
supplier_a_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-A"]
supplier_a_splits = [doc_split_map[doc_id] for doc_id in supplier_a_ids]
assert len(set(supplier_a_splits)) == 1, "supplier-A docs should be in same split"
supplier_b_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-B"]
supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"
def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
"""All docs with same group_key should go to same split."""
from inference.web.services.dataset_builder import DatasetBuilder
# Create 5 docs all with same group_key
docs = []
for i in range(5):
doc_id = uuid4()
doc_dir = tmp_path / "admin_images" / str(doc_id)
doc_dir.mkdir(parents=True)
(doc_dir / "page_1.png").write_bytes(b"fake-png")
doc = MagicMock(spec=AdminDocument)
doc.document_id = doc_id
doc.filename = f"{doc_id}.pdf"
doc.page_count = 1
doc.group_key = "same-group"
docs.append(doc)
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = docs
mock_admin_db.get_annotations_for_document.return_value = []
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in docs],
train_ratio=0.6,
val_ratio=0.2,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
call_args = mock_admin_db.add_dataset_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
splits = [d["split"] for d in docs_added]
# All should be in the same split (one group)
assert len(set(splits)) == 1, "All docs with same group_key should be in same split"

View File

@@ -25,6 +25,9 @@ TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
TEST_TOKEN = "test-admin-token-12345"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
# Generate 10 unique UUIDs for minimum document count tests
TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)]
def _make_dataset(**overrides) -> MagicMock:
defaults = dict(
@@ -83,14 +86,14 @@ class TestCreateDatasetRoute:
mock_builder = MagicMock()
mock_builder.build_dataset.return_value = {
"total_documents": 2,
"total_images": 4,
"total_annotations": 10,
"total_documents": 10,
"total_images": 20,
"total_annotations": 50,
}
request = DatasetCreateRequest(
name="test-dataset",
- document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
+ document_ids=TEST_DOC_UUIDS,  # Use 10 documents to meet minimum
)
with patch(
@@ -104,6 +107,73 @@ class TestCreateDatasetRoute:
assert result.dataset_id == TEST_DATASET_UUID
assert result.name == "test-dataset"
def test_create_dataset_fails_with_less_than_10_documents(self):
"""Test that creating dataset fails if fewer than 10 documents provided."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# Only 2 documents - should fail
request = DatasetCreateRequest(
name="test-dataset",
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
assert "got 2" in exc_info.value.detail
# Ensure DB was never called since validation failed first
mock_db.create_dataset.assert_not_called()
def test_create_dataset_fails_with_9_documents(self):
"""Test boundary condition: 9 documents should fail."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# 9 documents - just under the limit
request = DatasetCreateRequest(
name="test-dataset",
document_ids=TEST_DOC_UUIDS[:9],
)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
def test_create_dataset_succeeds_with_exactly_10_documents(self):
"""Test boundary condition: exactly 10 documents should succeed."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
mock_db.create_dataset.return_value = _make_dataset(status="building")
mock_builder = MagicMock()
# Exactly 10 documents - should pass
request = DatasetCreateRequest(
name="test-dataset",
document_ids=TEST_DOC_UUIDS[:10],
)
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
):
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.create_dataset.assert_called_once()
assert result.dataset_id == TEST_DATASET_UUID
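# Illustrative sketch of the pre-flight check the three tests above describe.
# The constant and helper names are assumptions; the status code and detail
# wording mirror the assertions exactly.
from fastapi import HTTPException

MIN_DATASET_DOCUMENTS = 10

def _require_min_documents_sketch(document_ids):
    # Runs before any DB access, so db.create_dataset is never called on bad input.
    if len(document_ids) < MIN_DATASET_DOCUMENTS:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Minimum {MIN_DATASET_DOCUMENTS} documents required, "
                f"got {len(document_ids)}"
            ),
        )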
class TestListDatasetsRoute:
"""Tests for GET /admin/training/datasets."""
@@ -198,3 +268,53 @@ class TestTrainFromDatasetRoute:
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
def test_incremental_training_with_base_model(self):
"""Test training with base_model_version_id for incremental training."""
fn = _find_endpoint("train_from_dataset")
mock_model_version = MagicMock()
mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
mock_model_version.version = "1.0.0"
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = mock_model_version
mock_db.create_training_task.return_value = TEST_TASK_UUID
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
request = DatasetTrainRequest(name="incremental-train", config=config)
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
# Verify model version was looked up
mock_db.get_model_version.assert_called_once_with(base_model_uuid)
# Verify task was created with finetune type
call_kwargs = mock_db.create_training_task.call_args[1]
assert call_kwargs["task_type"] == "finetune"
assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
assert call_kwargs["config"]["base_model_version"] == "1.0.0"
assert result.task_id == TEST_TASK_UUID
assert "Incremental training" in result.message
def test_incremental_training_with_invalid_base_model_fails(self):
"""Test that training fails if base_model_version_id doesn't exist."""
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = None
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
request = DatasetTrainRequest(name="incremental-train", config=config)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
assert "Base model version not found" in exc_info.value.detail

View File

@@ -0,0 +1,399 @@
"""
Tests for Model Version API routes.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
)
TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_TOKEN = "test-admin-token-12345"
def _make_model_version(**overrides) -> MagicMock:
"""Create a mock ModelVersion."""
defaults = dict(
version_id=UUID(TEST_VERSION_UUID),
version="1.0.0",
name="test-model-v1",
description="Test model version",
model_path="/models/test-model-v1.pt",
status="inactive",
is_active=False,
task_id=UUID(TEST_TASK_UUID),
dataset_id=UUID(TEST_DATASET_UUID),
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=100,
training_config={"epochs": 100, "batch_size": 16},
file_size=52428800,
trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
activated_at=None,
created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
)
defaults.update(overrides)
model = MagicMock(spec=ModelVersion)
for k, v in defaults.items():
setattr(model, k, v)
return model
def _find_endpoint(name: str):
"""Find endpoint function by name."""
router = create_training_router()
for route in router.routes:
if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
return route.endpoint
raise AssertionError(f"Endpoint {name} not found")
class TestModelVersionRouterRegistration:
"""Tests that model version endpoints are registered."""
def test_router_has_model_endpoints(self):
router = create_training_router()
paths = [route.path for route in router.routes]
assert any("models" in p for p in paths)
def test_has_create_model_version_endpoint(self):
endpoint = _find_endpoint("create_model_version")
assert endpoint is not None
def test_has_list_model_versions_endpoint(self):
endpoint = _find_endpoint("list_model_versions")
assert endpoint is not None
def test_has_get_active_model_endpoint(self):
endpoint = _find_endpoint("get_active_model")
assert endpoint is not None
def test_has_activate_model_version_endpoint(self):
endpoint = _find_endpoint("activate_model_version")
assert endpoint is not None
class TestCreateModelVersionRoute:
"""Tests for POST /admin/training/models."""
def test_create_model_version(self):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model-v1",
model_path="/models/test-model-v1.pt",
description="Test model",
metrics_mAP=0.935,
document_count=100,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.create_model_version.assert_called_once()
assert result.version_id == TEST_VERSION_UUID
assert result.status == "inactive"
assert result.message == "Model version created successfully"
def test_create_model_version_with_task_and_dataset(self):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model-v1",
model_path="/models/test-model-v1.pt",
task_id=TEST_TASK_UUID,
dataset_id=TEST_DATASET_UUID,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
call_kwargs = mock_db.create_model_version.call_args[1]
assert call_kwargs["task_id"] == TEST_TASK_UUID
assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
class TestListModelVersionsRoute:
"""Tests for GET /admin/training/models."""
def test_list_model_versions(self):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = (
[_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
2,
)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
assert result.total == 2
assert len(result.models) == 2
assert result.models[0].version == "1.0.0"
def test_list_model_versions_with_status_filter(self):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
assert result.total == 1
assert result.models[0].status == "active"
class TestGetActiveModelRoute:
"""Tests for GET /admin/training/models/active."""
def test_get_active_model_when_exists(self):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
assert result.has_active_model is True
assert result.model is not None
assert result.model.is_active is True
def test_get_active_model_when_none(self):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = None
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
assert result.has_active_model is False
assert result.model is None
class TestGetModelVersionRoute:
"""Tests for GET /admin/training/models/{version_id}."""
def test_get_model_version(self):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = _make_model_version()
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.version_id == TEST_VERSION_UUID
assert result.version == "1.0.0"
assert result.name == "test-model-v1"
assert result.metrics_mAP == 0.935
def test_get_model_version_not_found(self):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestUpdateModelVersionRoute:
"""Tests for PATCH /admin/training/models/{version_id}."""
def test_update_model_version(self):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.update_model_version.assert_called_once_with(
version_id=TEST_VERSION_UUID,
name="updated-name",
description="Updated description",
status=None,
)
assert result.message == "Model version updated successfully"
def test_update_model_version_not_found(self):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = None
request = ModelVersionUpdateRequest(name="updated-name")
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestActivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/activate."""
def test_activate_model_version(self):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result.status == "active"
assert result.message == "Model version activated for inference"
def test_activate_model_version_not_found(self):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestDeactivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/deactivate."""
def test_deactivate_model_version(self):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.status == "inactive"
assert result.message == "Model version deactivated"
def test_deactivate_model_version_not_found(self):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestArchiveModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/archive."""
def test_archive_model_version(self):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = _make_model_version(status="archived")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.status == "archived"
assert result.message == "Model version archived"
def test_archive_active_model_fails(self):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
class TestDeleteModelVersionRoute:
"""Tests for DELETE /admin/training/models/{version_id}."""
def test_delete_model_version(self):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = True
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result["message"] == "Model version deleted"
def test_delete_active_model_fails(self):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = False
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
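# Illustrative in-memory model of the db semantics the routes above rely on:
# activation is exclusive, and archive/delete refuse the active version (the
# routes translate None/False into HTTP errors). The real AdminDB layer is not
# shown in this excerpt.
class _ModelRegistrySketch:
    def __init__(self):
        self._versions = {}  # version_id -> {"status": str, "is_active": bool}

    def activate(self, version_id):
        if version_id not in self._versions:
            return None
        for vid, v in self._versions.items():
            v["is_active"] = vid == version_id
            v["status"] = "active" if v["is_active"] else "inactive"
        return self._versions[version_id]

    def archive(self, version_id):
        v = self._versions.get(version_id)
        if v is None or v["is_active"]:
            return None  # archiving the active model is rejected
        v["status"] = "archived"
        return v

    def delete(self, version_id):
        v = self._versions.get(version_id)
        if v is None or v["is_active"]:
            return False  # deleting the active model is rejected
        del self._versions[version_id]
        return True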
class TestModelVersionSchemas:
"""Tests for model version Pydantic schemas."""
def test_create_request_validation(self):
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model",
model_path="/models/test.pt",
)
assert request.version == "1.0.0"
assert request.name == "test-model"
assert request.document_count == 0
def test_create_request_with_metrics(self):
request = ModelVersionCreateRequest(
version="2.0.0",
name="test-model-v2",
model_path="/models/v2.pt",
metrics_mAP=0.95,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
)
assert request.metrics_mAP == 0.95
assert request.document_count == 500
def test_update_request_partial(self):
request = ModelVersionUpdateRequest(name="new-name")
assert request.name == "new-name"
assert request.description is None
assert request.status is None