WIP
This commit is contained in:
@@ -25,6 +25,9 @@ TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
|
||||
|
||||
# Generate 10 unique UUIDs for minimum document count tests
|
||||
TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)]
|
||||
|
||||
|
||||
def _make_dataset(**overrides) -> MagicMock:
|
||||
defaults = dict(
|
||||
@@ -83,14 +86,14 @@ class TestCreateDatasetRoute:
|
||||
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_dataset.return_value = {
|
||||
"total_documents": 2,
|
||||
"total_images": 4,
|
||||
"total_annotations": 10,
|
||||
"total_documents": 10,
|
||||
"total_images": 20,
|
||||
"total_annotations": 50,
|
||||
}
|
||||
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
|
||||
document_ids=TEST_DOC_UUIDS, # Use 10 documents to meet minimum
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -104,6 +107,73 @@ class TestCreateDatasetRoute:
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
assert result.name == "test-dataset"
|
||||
|
||||
def test_create_dataset_fails_with_less_than_10_documents(self):
|
||||
"""Test that creating dataset fails if fewer than 10 documents provided."""
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# Only 2 documents - should fail
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
|
||||
)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Minimum 10 documents required" in exc_info.value.detail
|
||||
assert "got 2" in exc_info.value.detail
|
||||
# Ensure DB was never called since validation failed first
|
||||
mock_db.create_dataset.assert_not_called()
|
||||
|
||||
def test_create_dataset_fails_with_9_documents(self):
|
||||
"""Test boundary condition: 9 documents should fail."""
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# 9 documents - just under the limit
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
document_ids=TEST_DOC_UUIDS[:9],
|
||||
)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Minimum 10 documents required" in exc_info.value.detail
|
||||
|
||||
def test_create_dataset_succeeds_with_exactly_10_documents(self):
|
||||
"""Test boundary condition: exactly 10 documents should succeed."""
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_dataset.return_value = _make_dataset(status="building")
|
||||
|
||||
mock_builder = MagicMock()
|
||||
|
||||
# Exactly 10 documents - should pass
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
document_ids=TEST_DOC_UUIDS[:10],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||
return_value=mock_builder,
|
||||
):
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.create_dataset.assert_called_once()
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
|
||||
|
||||
class TestListDatasetsRoute:
|
||||
"""Tests for GET /admin/training/datasets."""
|
||||
@@ -198,3 +268,53 @@ class TestTrainFromDatasetRoute:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
def test_incremental_training_with_base_model(self):
|
||||
"""Test training with base_model_version_id for incremental training."""
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_model_version = MagicMock()
|
||||
mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
|
||||
mock_model_version.version = "1.0.0"
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="ready")
|
||||
mock_db.get_model_version.return_value = mock_model_version
|
||||
mock_db.create_training_task.return_value = TEST_TASK_UUID
|
||||
|
||||
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
|
||||
config = TrainingConfig(base_model_version_id=base_model_uuid)
|
||||
request = DatasetTrainRequest(name="incremental-train", config=config)
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
# Verify model version was looked up
|
||||
mock_db.get_model_version.assert_called_once_with(base_model_uuid)
|
||||
|
||||
# Verify task was created with finetune type
|
||||
call_kwargs = mock_db.create_training_task.call_args[1]
|
||||
assert call_kwargs["task_type"] == "finetune"
|
||||
assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
|
||||
assert call_kwargs["config"]["base_model_version"] == "1.0.0"
|
||||
|
||||
assert result.task_id == TEST_TASK_UUID
|
||||
assert "Incremental training" in result.message
|
||||
|
||||
def test_incremental_training_with_invalid_base_model_fails(self):
|
||||
"""Test that training fails if base_model_version_id doesn't exist."""
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="ready")
|
||||
mock_db.get_model_version.return_value = None
|
||||
|
||||
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
|
||||
config = TrainingConfig(base_model_version_id=base_model_uuid)
|
||||
request = DatasetTrainRequest(name="incremental-train", config=config)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "Base model version not found" in exc_info.value.detail
|
||||
|
||||
Reference in New Issue
Block a user