Files
invoice-master-poc-v2/tests/web/test_dataset_routes.py
Yaojia Wang 33ada0350d WIP
2026-01-30 00:44:21 +01:00

321 lines
11 KiB
Python

"""
Tests for Dataset API routes in training.py.
"""
import asyncio
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import UUID
from inference.data.admin_models import TrainingDataset, DatasetDocument
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetTrainRequest,
TrainingConfig,
TrainingStatus,
)
# Fixed identifiers shared by every test in this module.
TEST_TOKEN = "test-admin-token-12345"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011"
TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
# Ten distinct document UUIDs whose final group runs 446655440010..446655440019,
# used where a request must carry at least the minimum document count.
# (TEST_DOC_UUID_1 and TEST_DOC_UUID_2 are members of this list.)
TEST_DOC_UUIDS = [
    f"990e8400-e29b-41d4-a716-{node}" for node in range(446655440010, 446655440020)
]
def _make_dataset(**overrides) -> MagicMock:
    """Build a ``TrainingDataset``-spec'd mock populated with fixed test values.

    Keyword arguments replace the matching default attribute, or add a new
    attribute if the name is not among the defaults.

    Returns:
        A ``MagicMock`` created with ``spec=TrainingDataset`` carrying the
        merged attribute values.
    """
    attributes = {
        "dataset_id": UUID(TEST_DATASET_UUID),
        "name": "test-dataset",
        "description": "Test dataset",
        "status": "ready",
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
        "total_documents": 2,
        "total_images": 4,
        "total_annotations": 10,
        "dataset_path": "/data/datasets/test-dataset",
        "error_message": None,
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        "updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        # Caller-supplied values win over the defaults above.
        **overrides,
    }
    dataset = MagicMock(spec=TrainingDataset)
    dataset.configure_mock(**attributes)
    return dataset
def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock:
    """Build a ``DatasetDocument``-spec'd mock for *doc_id* assigned to *split*.

    Page and annotation counts are fixed at 2 and 5 respectively.
    """
    document = MagicMock(spec=DatasetDocument)
    document.configure_mock(
        document_id=UUID(doc_id),
        split=split,
        page_count=2,
        annotation_count=5,
    )
    return document
def _find_endpoint(name: str):
    """Return the endpoint function called *name* from a fresh training router.

    Routes without an ``endpoint`` attribute are ignored; the first match wins.

    Raises:
        AssertionError: if no route exposes an endpoint with that ``__name__``.
    """
    matches = (
        route.endpoint
        for route in create_training_router().routes
        if hasattr(route, "endpoint") and route.endpoint.__name__ == name
    )
    endpoint = next(matches, None)
    if endpoint is None:
        raise AssertionError(f"Endpoint {name} not found")
    return endpoint
class TestCreateDatasetRoute:
    """Tests for POST /admin/training/datasets."""

    # DatasetBuilder is patched in its defining module, so the patch only
    # affects the route if it resolves the name through that module at call time.
    _BUILDER_TARGET = "inference.web.services.dataset_builder.DatasetBuilder"

    @staticmethod
    def _make_builder() -> MagicMock:
        """Return a DatasetBuilder stand-in whose build_dataset() reports fixed totals."""
        builder = MagicMock()
        builder.build_dataset.return_value = {
            "total_documents": 10,
            "total_images": 20,
            "total_annotations": 50,
        }
        return builder

    @staticmethod
    def _create(request, db):
        """Invoke the create_dataset endpoint synchronously and return its result."""
        fn = _find_endpoint("create_dataset")
        return asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=db))

    def test_router_has_dataset_endpoints(self):
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("datasets" in p for p in paths)

    def test_create_dataset_calls_builder(self):
        mock_db = MagicMock()
        mock_db.create_dataset.return_value = _make_dataset(status="building")
        mock_builder = self._make_builder()
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS,  # 10 documents: meets the minimum
        )

        with patch(self._BUILDER_TARGET, return_value=mock_builder):
            result = self._create(request, mock_db)

        mock_db.create_dataset.assert_called_once()
        mock_builder.build_dataset.assert_called_once()
        assert result.dataset_id == TEST_DATASET_UUID
        assert result.name == "test-dataset"

    def test_create_dataset_fails_with_less_than_10_documents(self):
        """Creating a dataset from fewer than 10 documents is rejected with 400."""
        from fastapi import HTTPException

        mock_db = MagicMock()
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],  # only 2 documents
        )

        with pytest.raises(HTTPException) as exc_info:
            self._create(request, mock_db)

        assert exc_info.value.status_code == 400
        assert "Minimum 10 documents required" in exc_info.value.detail
        assert "got 2" in exc_info.value.detail
        # Validation must reject the request before anything is persisted.
        mock_db.create_dataset.assert_not_called()

    def test_create_dataset_fails_with_9_documents(self):
        """Boundary: 9 documents (one below the minimum) is rejected with 400."""
        from fastapi import HTTPException

        mock_db = MagicMock()
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS[:9],
        )

        with pytest.raises(HTTPException) as exc_info:
            self._create(request, mock_db)

        assert exc_info.value.status_code == 400
        assert "Minimum 10 documents required" in exc_info.value.detail
        assert "got 9" in exc_info.value.detail
        # Same guarantee as the 2-document case: nothing is persisted.
        mock_db.create_dataset.assert_not_called()

    def test_create_dataset_succeeds_with_exactly_10_documents(self):
        """Boundary: exactly 10 documents is accepted and the dataset is built."""
        mock_db = MagicMock()
        mock_db.create_dataset.return_value = _make_dataset(status="building")
        mock_builder = self._make_builder()
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS[:10],
        )

        with patch(self._BUILDER_TARGET, return_value=mock_builder):
            result = self._create(request, mock_db)

        mock_db.create_dataset.assert_called_once()
        mock_builder.build_dataset.assert_called_once()
        assert result.dataset_id == TEST_DATASET_UUID
class TestListDatasetsRoute:
    """Tests for GET /admin/training/datasets."""

    def test_list_datasets(self):
        """One stored dataset is reported with total=1 and its name."""
        endpoint = _find_endpoint("list_datasets")
        db = MagicMock()
        db.get_datasets.return_value = ([_make_dataset()], 1)

        page = asyncio.run(
            endpoint(admin_token=TEST_TOKEN, db=db, status=None, limit=20, offset=0)
        )

        assert page.total == 1
        assert len(page.datasets) == 1
        first = page.datasets[0]
        assert first.name == "test-dataset"
class TestGetDatasetRoute:
    """Tests for GET /admin/training/datasets/{dataset_id}."""

    def test_get_dataset_returns_detail(self):
        """The detail response carries the dataset id and its two documents."""
        endpoint = _find_endpoint("get_dataset")
        db = MagicMock()
        db.get_dataset.return_value = _make_dataset()
        db.get_dataset_documents.return_value = [
            _make_dataset_doc(doc_id, split)
            for doc_id, split in ((TEST_DOC_UUID_1, "train"), (TEST_DOC_UUID_2, "val"))
        ]

        detail = asyncio.run(
            endpoint(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=db)
        )

        assert detail.dataset_id == TEST_DATASET_UUID
        assert len(detail.documents) == 2

    def test_get_dataset_not_found(self):
        """An unknown dataset id yields HTTP 404."""
        from fastapi import HTTPException

        endpoint = _find_endpoint("get_dataset")
        db = MagicMock()
        db.get_dataset.return_value = None

        with pytest.raises(HTTPException) as raised:
            asyncio.run(
                endpoint(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=db)
            )

        assert raised.value.status_code == 404
class TestDeleteDatasetRoute:
    """Tests for DELETE /admin/training/datasets/{dataset_id}."""

    def test_delete_dataset(self):
        """Deleting a dataset whose dataset_path is None removes its DB record."""
        endpoint = _find_endpoint("delete_dataset")
        db = MagicMock()
        # dataset_path=None, presumably so the route has no files to clean up — TODO confirm
        db.get_dataset.return_value = _make_dataset(dataset_path=None)

        response = asyncio.run(
            endpoint(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=db)
        )

        db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
        assert response["message"] == "Dataset deleted"
class TestTrainFromDatasetRoute:
    """Tests for POST /admin/training/datasets/{dataset_id}/train."""

    _BASE_MODEL_UUID = "550e8400-e29b-41d4-a716-446655440099"

    @staticmethod
    def _train(request, db):
        """Invoke train_from_dataset synchronously for TEST_DATASET_UUID."""
        fn = _find_endpoint("train_from_dataset")
        return asyncio.run(
            fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=db)
        )

    def test_train_from_ready_dataset(self):
        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.create_training_task.return_value = TEST_TASK_UUID
        request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())

        result = self._train(request, mock_db)

        assert result.task_id == TEST_TASK_UUID
        assert result.status == TrainingStatus.PENDING
        mock_db.create_training_task.assert_called_once()

    def test_train_from_building_dataset_fails(self):
        """A dataset that is still building cannot be trained on (400)."""
        from fastapi import HTTPException

        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="building")
        request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())

        with pytest.raises(HTTPException) as exc_info:
            self._train(request, mock_db)

        assert exc_info.value.status_code == 400
        # The request must be rejected before any training task is created.
        mock_db.create_training_task.assert_not_called()

    def test_incremental_training_with_base_model(self):
        """Test training with base_model_version_id for incremental training."""
        mock_model_version = MagicMock()
        mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
        mock_model_version.version = "1.0.0"
        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.get_model_version.return_value = mock_model_version
        mock_db.create_training_task.return_value = TEST_TASK_UUID
        config = TrainingConfig(base_model_version_id=self._BASE_MODEL_UUID)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        result = self._train(request, mock_db)

        # The base model version was looked up by the requested id.
        mock_db.get_model_version.assert_called_once_with(self._BASE_MODEL_UUID)
        # The task was created as a finetune carrying the base model's path/version.
        call_kwargs = mock_db.create_training_task.call_args.kwargs
        assert call_kwargs["task_type"] == "finetune"
        assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
        assert call_kwargs["config"]["base_model_version"] == "1.0.0"
        assert result.task_id == TEST_TASK_UUID
        assert "Incremental training" in result.message

    def test_incremental_training_with_invalid_base_model_fails(self):
        """Training fails with 404 if base_model_version_id doesn't exist."""
        from fastapi import HTTPException

        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.get_model_version.return_value = None
        config = TrainingConfig(base_model_version_id=self._BASE_MODEL_UUID)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        with pytest.raises(HTTPException) as exc_info:
            self._train(request, mock_db)

        assert exc_info.value.status_code == 404
        assert "Base model version not found" in exc_info.value.detail
        mock_db.get_model_version.assert_called_once_with(self._BASE_MODEL_UUID)
        # No task may be created for a base model that does not exist.
        mock_db.create_training_task.assert_not_called()