"""
|
|
Tests for Dataset API routes in training.py.
|
|
"""
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
from inference.data.admin_models import TrainingDataset, DatasetDocument
|
|
from inference.web.api.v1.admin.training import create_training_router
|
|
from inference.web.schemas.admin import (
|
|
DatasetCreateRequest,
|
|
DatasetTrainRequest,
|
|
TrainingConfig,
|
|
TrainingStatus,
|
|
)
|
|
|
|
|
|
# Fixed UUID strings reused across tests so assertions can compare literals.
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011"
TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
# Token value is arbitrary; endpoints under test receive it pre-validated.
TEST_TOKEN = "test-admin-token-12345"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"

# Generate 10 unique UUIDs for minimum document count tests
TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)]
|
|
|
|
|
|
def _make_dataset(**overrides) -> MagicMock:
    """Build a MagicMock that stands in for a TrainingDataset row.

    Any keyword argument replaces the corresponding default attribute,
    e.g. ``_make_dataset(status="building")``.
    """
    attrs = {
        "dataset_id": UUID(TEST_DATASET_UUID),
        "name": "test-dataset",
        "description": "Test dataset",
        "status": "ready",
        "training_status": None,
        "active_training_task_id": None,
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
        "total_documents": 2,
        "total_images": 4,
        "total_annotations": 10,
        "dataset_path": "/data/datasets/test-dataset",
        "error_message": None,
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        "updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    attrs.update(overrides)
    dataset = MagicMock(spec=TrainingDataset)
    for attr_name, attr_value in attrs.items():
        setattr(dataset, attr_name, attr_value)
    return dataset
|
|
|
|
|
|
def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock:
    """Build a MagicMock that stands in for a DatasetDocument row.

    ``doc_id`` must be a valid UUID string; ``split`` is the dataset
    split label ("train", "val", ...). Page/annotation counts are fixed.
    """
    document = MagicMock(spec=DatasetDocument)
    document.document_id = UUID(doc_id)
    document.split = split
    document.page_count = 2
    document.annotation_count = 5
    return document
|
|
|
|
|
|
def _find_endpoint(name: str):
    """Return the handler function registered under ``name`` on the training router.

    Raises AssertionError when no route carries an endpoint with that name.
    """
    for route in create_training_router().routes:
        endpoint = getattr(route, "endpoint", None)
        if endpoint is not None and endpoint.__name__ == name:
            return endpoint
    raise AssertionError(f"Endpoint {name} not found")
|
|
|
|
|
|
@pytest.fixture
def mock_datasets_repo():
    """Stand-in for DatasetRepository; behavior is wired per-test."""
    repo = MagicMock()
    return repo
|
|
|
|
|
|
@pytest.fixture
def mock_documents_repo():
    """Stand-in for DocumentRepository; behavior is wired per-test."""
    repo = MagicMock()
    return repo
|
|
|
|
|
|
@pytest.fixture
def mock_annotations_repo():
    """Stand-in for AnnotationRepository; behavior is wired per-test."""
    repo = MagicMock()
    return repo
|
|
|
|
|
|
@pytest.fixture
def mock_models_repo():
    """Stand-in for ModelVersionRepository; behavior is wired per-test."""
    repo = MagicMock()
    return repo
|
|
|
|
|
|
@pytest.fixture
def mock_tasks_repo():
    """Stand-in for TrainingTaskRepository; behavior is wired per-test."""
    repo = MagicMock()
    return repo
|
|
|
|
|
|
class TestCreateDatasetRoute:
    """Tests for POST /admin/training/datasets."""

    def test_router_has_dataset_endpoints(self):
        # Smoke check: the factory registers at least one dataset-related path.
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("datasets" in p for p in paths)

    def test_create_dataset_calls_builder(
        self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
    ):
        """Happy path: a valid request creates the DB row and invokes DatasetBuilder."""
        fn = _find_endpoint("create_dataset")

        # Newly created datasets start in the "building" state.
        mock_datasets_repo.create.return_value = _make_dataset(status="building")

        mock_builder = MagicMock()
        mock_builder.build_dataset.return_value = {
            "total_documents": 10,
            "total_images": 20,
            "total_annotations": 50,
        }

        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS,  # Use 10 documents to meet minimum
        )

        # Patch both the builder class and the storage helper so no filesystem
        # access happens; NOTE(review): builder is patched at its source module,
        # which only works if the route imports it lazily — confirm.
        with patch(
            "inference.web.services.dataset_builder.DatasetBuilder",
            return_value=mock_builder,
        ), patch(
            "inference.web.api.v1.admin.training.datasets.get_storage_helper"
        ) as mock_storage:
            mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
            mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
            result = asyncio.run(fn(
                request=request,
                admin_token=TEST_TOKEN,
                datasets=mock_datasets_repo,
                docs=mock_documents_repo,
                annotations=mock_annotations_repo,
            ))

        mock_datasets_repo.create.assert_called_once()
        mock_builder.build_dataset.assert_called_once()
        assert result.dataset_id == TEST_DATASET_UUID
        assert result.name == "test-dataset"

    def test_create_dataset_fails_with_less_than_10_documents(
        self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
    ):
        """Test that creating dataset fails if fewer than 10 documents provided."""
        fn = _find_endpoint("create_dataset")

        # Only 2 documents - should fail
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
        )

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(
                request=request,
                admin_token=TEST_TOKEN,
                datasets=mock_datasets_repo,
                docs=mock_documents_repo,
                annotations=mock_annotations_repo,
            ))

        assert exc_info.value.status_code == 400
        assert "Minimum 10 documents required" in exc_info.value.detail
        assert "got 2" in exc_info.value.detail
        # Ensure repo was never called since validation failed first
        mock_datasets_repo.create.assert_not_called()

    def test_create_dataset_fails_with_9_documents(
        self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
    ):
        """Test boundary condition: 9 documents should fail."""
        fn = _find_endpoint("create_dataset")

        # 9 documents - just under the limit
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS[:9],
        )

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(
                request=request,
                admin_token=TEST_TOKEN,
                datasets=mock_datasets_repo,
                docs=mock_documents_repo,
                annotations=mock_annotations_repo,
            ))

        assert exc_info.value.status_code == 400
        assert "Minimum 10 documents required" in exc_info.value.detail

    def test_create_dataset_succeeds_with_exactly_10_documents(
        self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
    ):
        """Test boundary condition: exactly 10 documents should succeed."""
        fn = _find_endpoint("create_dataset")

        mock_datasets_repo.create.return_value = _make_dataset(status="building")

        mock_builder = MagicMock()

        # Exactly 10 documents - should pass
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS[:10],
        )

        with patch(
            "inference.web.services.dataset_builder.DatasetBuilder",
            return_value=mock_builder,
        ), patch(
            "inference.web.api.v1.admin.training.datasets.get_storage_helper"
        ) as mock_storage:
            mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
            mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
            result = asyncio.run(fn(
                request=request,
                admin_token=TEST_TOKEN,
                datasets=mock_datasets_repo,
                docs=mock_documents_repo,
                annotations=mock_annotations_repo,
            ))

        mock_datasets_repo.create.assert_called_once()
        assert result.dataset_id == TEST_DATASET_UUID
|
|
|
|
|
class TestListDatasetsRoute:
    """Tests for GET /admin/training/datasets."""

    def test_list_datasets(self, mock_datasets_repo):
        """A single stored dataset comes back with the right total and name."""
        endpoint = _find_endpoint("list_datasets")

        mock_datasets_repo.get_paginated.return_value = ([_make_dataset()], 1)
        # Mock the active training tasks lookup to return empty dict
        mock_datasets_repo.get_active_training_tasks.return_value = {}

        response = asyncio.run(
            endpoint(
                admin_token=TEST_TOKEN,
                datasets_repo=mock_datasets_repo,
                status=None,
                limit=20,
                offset=0,
            )
        )

        assert response.total == 1
        assert len(response.datasets) == 1
        assert response.datasets[0].name == "test-dataset"
|
|
|
|
|
class TestGetDatasetRoute:
    """Tests for GET /admin/training/datasets/{dataset_id}."""

    def test_get_dataset_returns_detail(self, mock_datasets_repo):
        """Detail view returns the dataset plus its per-split documents."""
        endpoint = _find_endpoint("get_dataset")

        mock_datasets_repo.get.return_value = _make_dataset()
        mock_datasets_repo.get_documents.return_value = [
            _make_dataset_doc(TEST_DOC_UUID_1, "train"),
            _make_dataset_doc(TEST_DOC_UUID_2, "val"),
        ]

        response = asyncio.run(
            endpoint(
                dataset_id=TEST_DATASET_UUID,
                admin_token=TEST_TOKEN,
                datasets_repo=mock_datasets_repo,
            )
        )

        assert response.dataset_id == TEST_DATASET_UUID
        assert len(response.documents) == 2

    def test_get_dataset_not_found(self, mock_datasets_repo):
        """An unknown dataset id results in HTTP 404."""
        endpoint = _find_endpoint("get_dataset")

        mock_datasets_repo.get.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(
                endpoint(
                    dataset_id=TEST_DATASET_UUID,
                    admin_token=TEST_TOKEN,
                    datasets_repo=mock_datasets_repo,
                )
            )
        assert exc_info.value.status_code == 404
|
|
|
|
|
class TestDeleteDatasetRoute:
    """Tests for DELETE /admin/training/datasets/{dataset_id}."""

    def test_delete_dataset(self, mock_datasets_repo):
        """Deleting an existing dataset removes the row and returns confirmation."""
        endpoint = _find_endpoint("delete_dataset")

        # dataset_path is None so no on-disk cleanup path is exercised here.
        mock_datasets_repo.get.return_value = _make_dataset(dataset_path=None)

        response = asyncio.run(
            endpoint(
                dataset_id=TEST_DATASET_UUID,
                admin_token=TEST_TOKEN,
                datasets_repo=mock_datasets_repo,
            )
        )

        mock_datasets_repo.delete.assert_called_once_with(TEST_DATASET_UUID)
        assert response["message"] == "Dataset deleted"
|
|
|
|
|
class TestTrainFromDatasetRoute:
    """Tests for POST /admin/training/datasets/{dataset_id}/train."""

    def test_train_from_ready_dataset(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
        """A dataset in "ready" state spawns a pending training task."""
        fn = _find_endpoint("train_from_dataset")

        mock_datasets_repo.get.return_value = _make_dataset(status="ready")
        mock_tasks_repo.create.return_value = TEST_TASK_UUID

        request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())

        result = asyncio.run(fn(
            dataset_id=TEST_DATASET_UUID,
            request=request,
            admin_token=TEST_TOKEN,
            datasets_repo=mock_datasets_repo,
            models=mock_models_repo,
            tasks=mock_tasks_repo,
        ))

        assert result.task_id == TEST_TASK_UUID
        assert result.status == TrainingStatus.PENDING
        mock_tasks_repo.create.assert_called_once()

    def test_train_from_building_dataset_fails(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
        """A dataset still in "building" state cannot be trained from (HTTP 400)."""
        fn = _find_endpoint("train_from_dataset")

        mock_datasets_repo.get.return_value = _make_dataset(status="building")

        request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(
                dataset_id=TEST_DATASET_UUID,
                request=request,
                admin_token=TEST_TOKEN,
                datasets_repo=mock_datasets_repo,
                models=mock_models_repo,
                tasks=mock_tasks_repo,
            ))
        assert exc_info.value.status_code == 400

    def test_incremental_training_with_base_model(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
        """Test training with base_model_version_id for incremental training."""
        fn = _find_endpoint("train_from_dataset")

        # Base model version the route is expected to resolve and embed
        # into the created task's config.
        mock_model_version = MagicMock()
        mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
        mock_model_version.version = "1.0.0"

        mock_datasets_repo.get.return_value = _make_dataset(status="ready")
        mock_models_repo.get.return_value = mock_model_version
        mock_tasks_repo.create.return_value = TEST_TASK_UUID

        base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
        config = TrainingConfig(base_model_version_id=base_model_uuid)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        result = asyncio.run(fn(
            dataset_id=TEST_DATASET_UUID,
            request=request,
            admin_token=TEST_TOKEN,
            datasets_repo=mock_datasets_repo,
            models=mock_models_repo,
            tasks=mock_tasks_repo,
        ))

        # Verify model version was looked up
        mock_models_repo.get.assert_called_once_with(base_model_uuid)

        # Verify task was created with finetune type
        call_kwargs = mock_tasks_repo.create.call_args[1]
        assert call_kwargs["task_type"] == "finetune"
        assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
        assert call_kwargs["config"]["base_model_version"] == "1.0.0"

        assert result.task_id == TEST_TASK_UUID
        assert "Incremental training" in result.message

    def test_incremental_training_with_invalid_base_model_fails(
        self, mock_datasets_repo, mock_models_repo, mock_tasks_repo
    ):
        """Test that training fails if base_model_version_id doesn't exist."""
        fn = _find_endpoint("train_from_dataset")

        mock_datasets_repo.get.return_value = _make_dataset(status="ready")
        # Repo lookup misses: the referenced base model does not exist.
        mock_models_repo.get.return_value = None

        base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
        config = TrainingConfig(base_model_version_id=base_model_uuid)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(
                dataset_id=TEST_DATASET_UUID,
                request=request,
                admin_token=TEST_TOKEN,
                datasets_repo=mock_datasets_repo,
                models=mock_models_repo,
                tasks=mock_tasks_repo,
            ))
        assert exc_info.value.status_code == 404
        assert "Base model version not found" in exc_info.value.detail
|