This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -72,6 +72,36 @@ def _find_endpoint(name: str):
raise AssertionError(f"Endpoint {name} not found")
@pytest.fixture
def mock_datasets_repo():
"""Mock DatasetRepository."""
return MagicMock()
@pytest.fixture
def mock_documents_repo():
"""Mock DocumentRepository."""
return MagicMock()
@pytest.fixture
def mock_annotations_repo():
"""Mock AnnotationRepository."""
return MagicMock()
@pytest.fixture
def mock_models_repo():
"""Mock ModelVersionRepository."""
return MagicMock()
@pytest.fixture
def mock_tasks_repo():
"""Mock TrainingTaskRepository."""
return MagicMock()
class TestCreateDatasetRoute:
"""Tests for POST /admin/training/datasets."""
@@ -80,11 +110,12 @@ class TestCreateDatasetRoute:
paths = [route.path for route in router.routes]
assert any("datasets" in p for p in paths)
def test_create_dataset_calls_builder(self):
def test_create_dataset_calls_builder(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
mock_db.create_dataset.return_value = _make_dataset(status="building")
mock_datasets_repo.create.return_value = _make_dataset(status="building")
mock_builder = MagicMock()
mock_builder.build_dataset.return_value = {
@@ -101,20 +132,30 @@ class TestCreateDatasetRoute:
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
) as mock_cls:
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
), patch(
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
) as mock_storage:
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
result = asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
mock_db.create_dataset.assert_called_once()
mock_datasets_repo.create.assert_called_once()
mock_builder.build_dataset.assert_called_once()
assert result.dataset_id == TEST_DATASET_UUID
assert result.name == "test-dataset"
def test_create_dataset_fails_with_less_than_10_documents(self):
def test_create_dataset_fails_with_less_than_10_documents(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Test that creating dataset fails if fewer than 10 documents provided."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# Only 2 documents - should fail
request = DatasetCreateRequest(
name="test-dataset",
@@ -124,20 +165,26 @@ class TestCreateDatasetRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
assert "got 2" in exc_info.value.detail
# Ensure DB was never called since validation failed first
mock_db.create_dataset.assert_not_called()
# Ensure repo was never called since validation failed first
mock_datasets_repo.create.assert_not_called()
def test_create_dataset_fails_with_9_documents(self):
def test_create_dataset_fails_with_9_documents(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Test boundary condition: 9 documents should fail."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# 9 documents - just under the limit
request = DatasetCreateRequest(
name="test-dataset",
@@ -147,17 +194,24 @@ class TestCreateDatasetRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
def test_create_dataset_succeeds_with_exactly_10_documents(self):
def test_create_dataset_succeeds_with_exactly_10_documents(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Test boundary condition: exactly 10 documents should succeed."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
mock_db.create_dataset.return_value = _make_dataset(status="building")
mock_datasets_repo.create.return_value = _make_dataset(status="building")
mock_builder = MagicMock()
@@ -170,25 +224,40 @@ class TestCreateDatasetRoute:
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
):
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
), patch(
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
) as mock_storage:
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
result = asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
mock_db.create_dataset.assert_called_once()
mock_datasets_repo.create.assert_called_once()
assert result.dataset_id == TEST_DATASET_UUID
class TestListDatasetsRoute:
"""Tests for GET /admin/training/datasets."""
def test_list_datasets(self):
def test_list_datasets(self, mock_datasets_repo):
fn = _find_endpoint("list_datasets")
mock_db = MagicMock()
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
mock_datasets_repo.get_paginated.return_value = ([_make_dataset()], 1)
# Mock the active training tasks lookup to return empty dict
mock_db.get_active_training_tasks_for_datasets.return_value = {}
mock_datasets_repo.get_active_training_tasks.return_value = {}
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
result = asyncio.run(fn(
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
status=None,
limit=20,
offset=0,
))
assert result.total == 1
assert len(result.datasets) == 1
@@ -198,82 +267,103 @@ class TestListDatasetsRoute:
class TestGetDatasetRoute:
"""Tests for GET /admin/training/datasets/{dataset_id}."""
def test_get_dataset_returns_detail(self):
def test_get_dataset_returns_detail(self, mock_datasets_repo):
fn = _find_endpoint("get_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset()
mock_db.get_dataset_documents.return_value = [
mock_datasets_repo.get.return_value = _make_dataset()
mock_datasets_repo.get_documents.return_value = [
_make_dataset_doc(TEST_DOC_UUID_1, "train"),
_make_dataset_doc(TEST_DOC_UUID_2, "val"),
]
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
))
assert result.dataset_id == TEST_DATASET_UUID
assert len(result.documents) == 2
def test_get_dataset_not_found(self):
def test_get_dataset_not_found(self, mock_datasets_repo):
fn = _find_endpoint("get_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = None
mock_datasets_repo.get.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
))
assert exc_info.value.status_code == 404
class TestDeleteDatasetRoute:
"""Tests for DELETE /admin/training/datasets/{dataset_id}."""
def test_delete_dataset(self):
def test_delete_dataset(self, mock_datasets_repo):
fn = _find_endpoint("delete_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(dataset_path=None)
mock_datasets_repo.get.return_value = _make_dataset(dataset_path=None)
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
))
mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
mock_datasets_repo.delete.assert_called_once_with(TEST_DATASET_UUID)
assert result["message"] == "Dataset deleted"
class TestTrainFromDatasetRoute:
"""Tests for POST /admin/training/datasets/{dataset_id}/train."""
def test_train_from_ready_dataset(self):
def test_train_from_ready_dataset(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.create_training_task.return_value = TEST_TASK_UUID
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
mock_tasks_repo.create.return_value = TEST_TASK_UUID
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
assert result.task_id == TEST_TASK_UUID
assert result.status == TrainingStatus.PENDING
mock_db.create_training_task.assert_called_once()
mock_tasks_repo.create.assert_called_once()
def test_train_from_building_dataset_fails(self):
def test_train_from_building_dataset_fails(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="building")
mock_datasets_repo.get.return_value = _make_dataset(status="building")
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
assert exc_info.value.status_code == 400
def test_incremental_training_with_base_model(self):
def test_incremental_training_with_base_model(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
"""Test training with base_model_version_id for incremental training."""
fn = _find_endpoint("train_from_dataset")
@@ -281,22 +371,28 @@ class TestTrainFromDatasetRoute:
mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
mock_model_version.version = "1.0.0"
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = mock_model_version
mock_db.create_training_task.return_value = TEST_TASK_UUID
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
mock_models_repo.get.return_value = mock_model_version
mock_tasks_repo.create.return_value = TEST_TASK_UUID
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
request = DatasetTrainRequest(name="incremental-train", config=config)
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
# Verify model version was looked up
mock_db.get_model_version.assert_called_once_with(base_model_uuid)
mock_models_repo.get.assert_called_once_with(base_model_uuid)
# Verify task was created with finetune type
call_kwargs = mock_db.create_training_task.call_args[1]
call_kwargs = mock_tasks_repo.create.call_args[1]
assert call_kwargs["task_type"] == "finetune"
assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
assert call_kwargs["config"]["base_model_version"] == "1.0.0"
@@ -304,13 +400,14 @@ class TestTrainFromDatasetRoute:
assert result.task_id == TEST_TASK_UUID
assert "Incremental training" in result.message
def test_incremental_training_with_invalid_base_model_fails(self):
def test_incremental_training_with_invalid_base_model_fails(
self, mock_datasets_repo, mock_models_repo, mock_tasks_repo
):
"""Test that training fails if base_model_version_id doesn't exist."""
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = None
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
mock_models_repo.get.return_value = None
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
@@ -319,6 +416,13 @@ class TestTrainFromDatasetRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
assert exc_info.value.status_code == 404
assert "Base model version not found" in exc_info.value.detail