restructure project
This commit is contained in:
200
tests/web/test_dataset_routes.py
Normal file
200
tests/web/test_dataset_routes.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for Dataset API routes in training.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from inference.data.admin_models import TrainingDataset, DatasetDocument
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.schemas.admin import (
|
||||
DatasetCreateRequest,
|
||||
DatasetTrainRequest,
|
||||
TrainingConfig,
|
||||
TrainingStatus,
|
||||
)
|
||||
|
||||
|
||||
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
|
||||
TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011"
|
||||
TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
|
||||
|
||||
|
||||
def _make_dataset(**overrides) -> MagicMock:
|
||||
defaults = dict(
|
||||
dataset_id=UUID(TEST_DATASET_UUID),
|
||||
name="test-dataset",
|
||||
description="Test dataset",
|
||||
status="ready",
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=2,
|
||||
total_images=4,
|
||||
total_annotations=10,
|
||||
dataset_path="/data/datasets/test-dataset",
|
||||
error_message=None,
|
||||
created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
ds = MagicMock(spec=TrainingDataset)
|
||||
for k, v in defaults.items():
|
||||
setattr(ds, k, v)
|
||||
return ds
|
||||
|
||||
|
||||
def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock:
|
||||
doc = MagicMock(spec=DatasetDocument)
|
||||
doc.document_id = UUID(doc_id)
|
||||
doc.split = split
|
||||
doc.page_count = 2
|
||||
doc.annotation_count = 5
|
||||
return doc
|
||||
|
||||
|
||||
def _find_endpoint(name: str):
|
||||
router = create_training_router()
|
||||
for route in router.routes:
|
||||
if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"Endpoint {name} not found")
|
||||
|
||||
|
||||
class TestCreateDatasetRoute:
|
||||
"""Tests for POST /admin/training/datasets."""
|
||||
|
||||
def test_router_has_dataset_endpoints(self):
|
||||
router = create_training_router()
|
||||
paths = [route.path for route in router.routes]
|
||||
assert any("datasets" in p for p in paths)
|
||||
|
||||
def test_create_dataset_calls_builder(self):
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_dataset.return_value = _make_dataset(status="building")
|
||||
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_dataset.return_value = {
|
||||
"total_documents": 2,
|
||||
"total_images": 4,
|
||||
"total_annotations": 10,
|
||||
}
|
||||
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||
return_value=mock_builder,
|
||||
) as mock_cls:
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.create_dataset.assert_called_once()
|
||||
mock_builder.build_dataset.assert_called_once()
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
assert result.name == "test-dataset"
|
||||
|
||||
|
||||
class TestListDatasetsRoute:
|
||||
"""Tests for GET /admin/training/datasets."""
|
||||
|
||||
def test_list_datasets(self):
|
||||
fn = _find_endpoint("list_datasets")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.datasets) == 1
|
||||
assert result.datasets[0].name == "test-dataset"
|
||||
|
||||
|
||||
class TestGetDatasetRoute:
|
||||
"""Tests for GET /admin/training/datasets/{dataset_id}."""
|
||||
|
||||
def test_get_dataset_returns_detail(self):
|
||||
fn = _find_endpoint("get_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset()
|
||||
mock_db.get_dataset_documents.return_value = [
|
||||
_make_dataset_doc(TEST_DOC_UUID_1, "train"),
|
||||
_make_dataset_doc(TEST_DOC_UUID_2, "val"),
|
||||
]
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
assert len(result.documents) == 2
|
||||
|
||||
def test_get_dataset_not_found(self):
|
||||
fn = _find_endpoint("get_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestDeleteDatasetRoute:
|
||||
"""Tests for DELETE /admin/training/datasets/{dataset_id}."""
|
||||
|
||||
def test_delete_dataset(self):
|
||||
fn = _find_endpoint("delete_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(dataset_path=None)
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
|
||||
assert result["message"] == "Dataset deleted"
|
||||
|
||||
|
||||
class TestTrainFromDatasetRoute:
|
||||
"""Tests for POST /admin/training/datasets/{dataset_id}/train."""
|
||||
|
||||
def test_train_from_ready_dataset(self):
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="ready")
|
||||
mock_db.create_training_task.return_value = TEST_TASK_UUID
|
||||
|
||||
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.task_id == TEST_TASK_UUID
|
||||
assert result.status == TrainingStatus.PENDING
|
||||
mock_db.create_training_task.assert_called_once()
|
||||
|
||||
def test_train_from_building_dataset_fails(self):
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="building")
|
||||
|
||||
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 400
|
||||
Reference in New Issue
Block a user