""" Tests for Dataset API routes in training.py. """ import asyncio import pytest from datetime import datetime, timezone from unittest.mock import MagicMock, patch from uuid import UUID from inference.data.admin_models import TrainingDataset, DatasetDocument from inference.web.api.v1.admin.training import create_training_router from inference.web.schemas.admin import ( DatasetCreateRequest, DatasetTrainRequest, TrainingConfig, TrainingStatus, ) TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010" TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011" TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012" TEST_TOKEN = "test-admin-token-12345" TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002" # Generate 10 unique UUIDs for minimum document count tests TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)] def _make_dataset(**overrides) -> MagicMock: defaults = dict( dataset_id=UUID(TEST_DATASET_UUID), name="test-dataset", description="Test dataset", status="ready", training_status=None, active_training_task_id=None, train_ratio=0.8, val_ratio=0.1, seed=42, total_documents=2, total_images=4, total_annotations=10, dataset_path="/data/datasets/test-dataset", error_message=None, created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), ) defaults.update(overrides) ds = MagicMock(spec=TrainingDataset) for k, v in defaults.items(): setattr(ds, k, v) return ds def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock: doc = MagicMock(spec=DatasetDocument) doc.document_id = UUID(doc_id) doc.split = split doc.page_count = 2 doc.annotation_count = 5 return doc def _find_endpoint(name: str): router = create_training_router() for route in router.routes: if hasattr(route, "endpoint") and route.endpoint.__name__ == name: return route.endpoint raise AssertionError(f"Endpoint {name} not found") class TestCreateDatasetRoute: """Tests for POST /admin/training/datasets.""" def test_router_has_dataset_endpoints(self): router = create_training_router() paths = [route.path for route in router.routes] assert any("datasets" in p for p in paths) def test_create_dataset_calls_builder(self): fn = _find_endpoint("create_dataset") mock_db = MagicMock() mock_db.create_dataset.return_value = _make_dataset(status="building") mock_builder = MagicMock() mock_builder.build_dataset.return_value = { "total_documents": 10, "total_images": 20, "total_annotations": 50, } request = DatasetCreateRequest( name="test-dataset", document_ids=TEST_DOC_UUIDS, # Use 10 documents to meet minimum ) with patch( "inference.web.services.dataset_builder.DatasetBuilder", return_value=mock_builder, ) as mock_cls: result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db)) mock_db.create_dataset.assert_called_once() mock_builder.build_dataset.assert_called_once() assert result.dataset_id == TEST_DATASET_UUID assert result.name == "test-dataset" def test_create_dataset_fails_with_less_than_10_documents(self): """Test that creating dataset fails if fewer than 10 documents provided.""" fn = _find_endpoint("create_dataset") mock_db = MagicMock() # Only 2 documents - should fail request = DatasetCreateRequest( name="test-dataset", document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2], ) from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db)) assert exc_info.value.status_code == 400 assert "Minimum 10 documents required" in exc_info.value.detail assert "got 2" in exc_info.value.detail # Ensure DB was never called since validation failed first mock_db.create_dataset.assert_not_called() def test_create_dataset_fails_with_9_documents(self): """Test boundary condition: 9 documents should fail.""" fn = _find_endpoint("create_dataset") mock_db = MagicMock() # 9 documents - just under the limit request = DatasetCreateRequest( name="test-dataset", document_ids=TEST_DOC_UUIDS[:9], ) from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db)) assert exc_info.value.status_code == 400 assert "Minimum 10 documents required" in exc_info.value.detail def test_create_dataset_succeeds_with_exactly_10_documents(self): """Test boundary condition: exactly 10 documents should succeed.""" fn = _find_endpoint("create_dataset") mock_db = MagicMock() mock_db.create_dataset.return_value = _make_dataset(status="building") mock_builder = MagicMock() # Exactly 10 documents - should pass request = DatasetCreateRequest( name="test-dataset", document_ids=TEST_DOC_UUIDS[:10], ) with patch( "inference.web.services.dataset_builder.DatasetBuilder", return_value=mock_builder, ): result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db)) mock_db.create_dataset.assert_called_once() assert result.dataset_id == TEST_DATASET_UUID class TestListDatasetsRoute: """Tests for GET /admin/training/datasets.""" def test_list_datasets(self): fn = _find_endpoint("list_datasets") mock_db = MagicMock() mock_db.get_datasets.return_value = ([_make_dataset()], 1) # Mock the active training tasks lookup to return empty dict mock_db.get_active_training_tasks_for_datasets.return_value = {} result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0)) assert result.total == 1 assert len(result.datasets) == 1 assert result.datasets[0].name == "test-dataset" class TestGetDatasetRoute: """Tests for GET /admin/training/datasets/{dataset_id}.""" def test_get_dataset_returns_detail(self): fn = _find_endpoint("get_dataset") mock_db = MagicMock() mock_db.get_dataset.return_value = _make_dataset() mock_db.get_dataset_documents.return_value = [ _make_dataset_doc(TEST_DOC_UUID_1, "train"), _make_dataset_doc(TEST_DOC_UUID_2, "val"), ] result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)) assert result.dataset_id == TEST_DATASET_UUID assert len(result.documents) == 2 def test_get_dataset_not_found(self): fn = _find_endpoint("get_dataset") mock_db = MagicMock() mock_db.get_dataset.return_value = None from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)) assert exc_info.value.status_code == 404 class TestDeleteDatasetRoute: """Tests for DELETE /admin/training/datasets/{dataset_id}.""" def test_delete_dataset(self): fn = _find_endpoint("delete_dataset") mock_db = MagicMock() mock_db.get_dataset.return_value = _make_dataset(dataset_path=None) result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)) mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID) assert result["message"] == "Dataset deleted" class TestTrainFromDatasetRoute: """Tests for POST /admin/training/datasets/{dataset_id}/train.""" def test_train_from_ready_dataset(self): fn = _find_endpoint("train_from_dataset") mock_db = MagicMock() mock_db.get_dataset.return_value = _make_dataset(status="ready") mock_db.create_training_task.return_value = TEST_TASK_UUID request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig()) result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db)) assert result.task_id == TEST_TASK_UUID assert result.status == TrainingStatus.PENDING mock_db.create_training_task.assert_called_once() def test_train_from_building_dataset_fails(self): fn = _find_endpoint("train_from_dataset") mock_db = MagicMock() mock_db.get_dataset.return_value = _make_dataset(status="building") request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig()) from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db)) assert exc_info.value.status_code == 400 def test_incremental_training_with_base_model(self): """Test training with base_model_version_id for incremental training.""" fn = _find_endpoint("train_from_dataset") mock_model_version = MagicMock() mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt" mock_model_version.version = "1.0.0" mock_db = MagicMock() mock_db.get_dataset.return_value = _make_dataset(status="ready") mock_db.get_model_version.return_value = mock_model_version mock_db.create_training_task.return_value = TEST_TASK_UUID base_model_uuid = "550e8400-e29b-41d4-a716-446655440099" config = TrainingConfig(base_model_version_id=base_model_uuid) request = DatasetTrainRequest(name="incremental-train", config=config) result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db)) # Verify model version was looked up mock_db.get_model_version.assert_called_once_with(base_model_uuid) # Verify task was created with finetune type call_kwargs = mock_db.create_training_task.call_args[1] assert call_kwargs["task_type"] == "finetune" assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt" assert call_kwargs["config"]["base_model_version"] == "1.0.0" assert result.task_id == TEST_TASK_UUID assert "Incremental training" in result.message def test_incremental_training_with_invalid_base_model_fails(self): """Test that training fails if base_model_version_id doesn't exist.""" fn = _find_endpoint("train_from_dataset") mock_db = MagicMock() mock_db.get_dataset.return_value = _make_dataset(status="ready") mock_db.get_model_version.return_value = None base_model_uuid = "550e8400-e29b-41d4-a716-446655440099" config = TrainingConfig(base_model_version_id=base_model_uuid) request = DatasetTrainRequest(name="incremental-train", config=config) from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db)) assert exc_info.value.status_code == 404 assert "Base model version not found" in exc_info.value.detail