# File: invoice-master-poc-v2/tests/web/test_dataset_routes.py
# Snapshot: 2026-01-27 23:58:17 +01:00
#
# 201 lines, 6.4 KiB, Python

"""
Tests for Dataset API routes in training.py.
"""
import asyncio
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import UUID
from inference.data.admin_models import TrainingDataset, DatasetDocument
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetTrainRequest,
TrainingConfig,
TrainingStatus,
)
# Fixed identifiers shared by the tests in this module so that assertions can
# compare endpoint results against known values.

# Dataset id given to the mocked dataset and passed as ``dataset_id`` to endpoints.
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
# Document ids used for mocked dataset documents and for the create request.
TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011"
TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
# Value passed as the ``admin_token`` argument when calling endpoints directly.
TEST_TOKEN = "test-admin-token-12345"
# Task id the mocked ``create_training_task`` returns.
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
def _make_dataset(**overrides) -> MagicMock:
    """Build a mock dataset constrained to the ``TrainingDataset`` spec.

    The mock is populated with a fixed set of default attribute values;
    any keyword argument replaces (or adds to) the attribute of the same name.

    Returns:
        A ``MagicMock(spec=TrainingDataset)`` with every attribute assigned.
    """
    attributes = {
        "dataset_id": UUID(TEST_DATASET_UUID),
        "name": "test-dataset",
        "description": "Test dataset",
        "status": "ready",
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
        "total_documents": 2,
        "total_images": 4,
        "total_annotations": 10,
        "dataset_path": "/data/datasets/test-dataset",
        "error_message": None,
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        "updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        # Caller-supplied values win over the defaults above.
        **overrides,
    }

    dataset = MagicMock(spec=TrainingDataset)
    for attr_name, attr_value in attributes.items():
        setattr(dataset, attr_name, attr_value)
    return dataset
def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock:
    """Build a mock dataset document constrained to the ``DatasetDocument`` spec.

    Args:
        doc_id: UUID string; stored on the mock as a ``UUID`` instance.
        split: Value assigned to the mock's ``split`` attribute.

    Returns:
        The mock, with ``page_count`` fixed at 2 and ``annotation_count`` at 5.
    """
    document = MagicMock(spec=DatasetDocument)
    field_values = {
        "document_id": UUID(doc_id),
        "split": split,
        "page_count": 2,
        "annotation_count": 5,
    }
    for field_name, field_value in field_values.items():
        setattr(document, field_name, field_value)
    return document
def _find_endpoint(name: str):
    """Return the endpoint callable named *name* from a fresh training router.

    Routes that have no ``endpoint`` attribute are skipped.

    Raises:
        AssertionError: If no route's endpoint function has that name.
    """
    endpoints = (
        route.endpoint
        for route in create_training_router().routes
        if hasattr(route, "endpoint")
    )
    for endpoint in endpoints:
        if endpoint.__name__ == name:
            return endpoint
    raise AssertionError(f"Endpoint {name} not found")
class TestCreateDatasetRoute:
    """Tests for POST /admin/training/datasets."""

    def test_router_has_dataset_endpoints(self):
        """At least one route path on the training router contains ``datasets``."""
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("datasets" in p for p in paths)

    def test_create_dataset_calls_builder(self):
        """Calling ``create_dataset`` hits the db once and the builder once.

        The endpoint function is invoked directly (no HTTP layer) with a mocked
        db and a patched ``DatasetBuilder`` class, then the returned object's
        ``dataset_id`` and ``name`` are checked.
        """
        fn = _find_endpoint("create_dataset")

        mock_db = MagicMock()
        mock_db.create_dataset.return_value = _make_dataset(status="building")

        mock_builder = MagicMock()
        mock_builder.build_dataset.return_value = {
            "total_documents": 2,
            "total_images": 4,
            "total_annotations": 10,
        }

        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
        )

        # Patched at the defining module; presumably the endpoint imports
        # DatasetBuilder from there at call time -- confirm against the route.
        # The patcher's return value is not bound because nothing reads it.
        with patch(
            "inference.web.services.dataset_builder.DatasetBuilder",
            return_value=mock_builder,
        ):
            result = asyncio.run(
                fn(request=request, admin_token=TEST_TOKEN, db=mock_db)
            )

        mock_db.create_dataset.assert_called_once()
        mock_builder.build_dataset.assert_called_once()
        assert result.dataset_id == TEST_DATASET_UUID
        assert result.name == "test-dataset"
class TestListDatasetsRoute:
    """Tests for GET /admin/training/datasets."""

    def test_list_datasets(self):
        """A single dataset from the db yields ``total == 1`` and one listed entry."""
        endpoint = _find_endpoint("list_datasets")

        db = MagicMock()
        # The mocked db call hands back a (datasets, total) pair.
        db.get_datasets.return_value = ([_make_dataset()], 1)

        call_kwargs = {
            "admin_token": TEST_TOKEN,
            "db": db,
            "status": None,
            "limit": 20,
            "offset": 0,
        }
        response = asyncio.run(endpoint(**call_kwargs))

        assert response.total == 1
        assert len(response.datasets) == 1
        assert response.datasets[0].name == "test-dataset"
class TestGetDatasetRoute:
    """Tests for GET /admin/training/datasets/{dataset_id}."""

    def test_get_dataset_returns_detail(self):
        """An existing dataset is returned with its id and both documents."""
        endpoint = _find_endpoint("get_dataset")

        db = MagicMock()
        db.get_dataset.return_value = _make_dataset()
        db.get_dataset_documents.return_value = [
            _make_dataset_doc(doc_id, split)
            for doc_id, split in (
                (TEST_DOC_UUID_1, "train"),
                (TEST_DOC_UUID_2, "val"),
            )
        ]

        coroutine = endpoint(
            dataset_id=TEST_DATASET_UUID,
            admin_token=TEST_TOKEN,
            db=db,
        )
        detail = asyncio.run(coroutine)

        assert detail.dataset_id == TEST_DATASET_UUID
        assert len(detail.documents) == 2

    def test_get_dataset_not_found(self):
        """A db lookup returning ``None`` makes the endpoint raise HTTP 404."""
        from fastapi import HTTPException

        endpoint = _find_endpoint("get_dataset")

        db = MagicMock()
        db.get_dataset.return_value = None

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(
                endpoint(
                    dataset_id=TEST_DATASET_UUID,
                    admin_token=TEST_TOKEN,
                    db=db,
                )
            )

        # Checked after the ``with`` block: inside it, the line would never
        # run because the call above raises first.
        assert exc_info.value.status_code == 404
class TestDeleteDatasetRoute:
    """Tests for DELETE /admin/training/datasets/{dataset_id}."""

    def test_delete_dataset(self):
        """Deleting a dataset calls the db with its id and returns a message."""
        endpoint = _find_endpoint("delete_dataset")

        db = MagicMock()
        # dataset_path=None -- presumably so the endpoint has no on-disk
        # directory to clean up; confirm against the route implementation.
        db.get_dataset.return_value = _make_dataset(dataset_path=None)

        response = asyncio.run(
            endpoint(
                dataset_id=TEST_DATASET_UUID,
                admin_token=TEST_TOKEN,
                db=db,
            )
        )

        db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
        assert response["message"] == "Dataset deleted"
class TestTrainFromDatasetRoute:
    """Tests for POST /admin/training/datasets/{dataset_id}/train."""

    def test_train_from_ready_dataset(self):
        """A ``ready`` dataset produces a pending task with the db-issued id."""
        endpoint = _find_endpoint("train_from_dataset")

        db = MagicMock()
        db.get_dataset.return_value = _make_dataset(status="ready")
        db.create_training_task.return_value = TEST_TASK_UUID

        train_request = DatasetTrainRequest(
            name="train-from-dataset",
            config=TrainingConfig(),
        )
        response = asyncio.run(
            endpoint(
                dataset_id=TEST_DATASET_UUID,
                request=train_request,
                admin_token=TEST_TOKEN,
                db=db,
            )
        )

        assert response.task_id == TEST_TASK_UUID
        assert response.status == TrainingStatus.PENDING
        db.create_training_task.assert_called_once()

    def test_train_from_building_dataset_fails(self):
        """A dataset still in ``building`` status is rejected with HTTP 400."""
        from fastapi import HTTPException

        endpoint = _find_endpoint("train_from_dataset")

        db = MagicMock()
        db.get_dataset.return_value = _make_dataset(status="building")

        train_request = DatasetTrainRequest(
            name="train-from-dataset",
            config=TrainingConfig(),
        )

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(
                endpoint(
                    dataset_id=TEST_DATASET_UUID,
                    request=train_request,
                    admin_token=TEST_TOKEN,
                    db=db,
                )
            )

        # Checked after the ``with`` block so that it actually executes.
        assert exc_info.value.status_code == 400