restructure project

Yaojia Wang, 2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions

View File

@@ -10,12 +10,12 @@ from uuid import UUID
import pytest
from src.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from src.data.models import AsyncRequest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService
from src.web.config import AsyncConfig, StorageConfig
from src.web.core.rate_limiter import RateLimiter
from inference.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from inference.data.models import AsyncRequest
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.core.rate_limiter import RateLimiter
@pytest.fixture
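The pattern in this and most of the following hunks is mechanical: top-level src.* imports become inference.* (web service, data layer, detection pipeline) or shared.* (common utilities such as PDF rendering). A throwaway rewrite script for that mapping could look like the sketch below; the mapping is inferred from the hunks in this commit, and the handful of non-mechanical moves (for example async_api.routes becoming public.async_api, or create_admin_router splitting into create_documents_router and create_locks_router) still need manual edits.

# Hypothetical one-off helper for the mechanical part of the rename; not part of the commit.
import pathlib
import re

RENAMES = {
    r"\bsrc\.data\.": "inference.data.",
    r"\bsrc\.web\.": "inference.web.",
    r"\bsrc\.inference\.": "inference.pipeline.",
    r"\bsrc\.pdf\.": "shared.pdf.",
}

for path in pathlib.Path("tests").rglob("*.py"):
    text = path.read_text()
    for pattern, replacement in RENAMES.items():
        text = re.sub(pattern, replacement, text)
    path.write_text(text)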

View File

@@ -9,9 +9,9 @@ from uuid import UUID
from fastapi import HTTPException
from src.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from src.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from src.web.schemas.admin import (
from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from inference.web.schemas.admin import (
AnnotationCreate,
AnnotationUpdate,
AutoLabelRequest,

View File

@@ -8,9 +8,9 @@ from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from src.data.admin_db import AdminDB
from src.data.admin_models import AdminToken
from src.web.core.auth import (
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminToken
from inference.web.core.auth import (
get_admin_db,
reset_admin_db,
validate_admin_token,
@@ -81,7 +81,7 @@ class TestAdminDB:
def test_is_valid_admin_token_active(self):
"""Test valid active token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -98,7 +98,7 @@ class TestAdminDB:
def test_is_valid_admin_token_inactive(self):
"""Test inactive token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -115,7 +115,7 @@ class TestAdminDB:
def test_is_valid_admin_token_expired(self):
"""Test expired token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -132,7 +132,7 @@ class TestAdminDB:
def test_is_valid_admin_token_not_found(self):
"""Test token not found."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_session.get.return_value = None

View File

@@ -12,8 +12,9 @@ from uuid import UUID
from fastapi import HTTPException
from fastapi.testclient import TestClient
from src.data.admin_models import AdminDocument, AdminToken
from src.web.api.v1.admin.documents import _validate_uuid, create_admin_router
from inference.data.admin_models import AdminDocument, AdminToken
from inference.web.api.v1.admin.documents import _validate_uuid, create_documents_router
from inference.web.config import StorageConfig
# Test UUID
@@ -42,13 +43,12 @@ class TestAdminRouter:
def test_creates_router_with_endpoints(self):
"""Test router is created with expected endpoints."""
router = create_admin_router((".pdf", ".png", ".jpg"))
router = create_documents_router(StorageConfig())
# Get route paths (include prefix from router)
paths = [route.path for route in router.routes]
# Paths include the /admin prefix
assert any("/auth/token" in p for p in paths)
# Paths include the /admin/documents prefix
assert any("/documents" in p for p in paths)
assert any("/documents/stats" in p for p in paths)
assert any("{document_id}" in p for p in paths)
@@ -66,7 +66,7 @@ class TestCreateTokenEndpoint:
def test_create_token_success(self, mock_db):
"""Test successful token creation."""
from src.web.schemas.admin import AdminTokenCreate
from inference.web.schemas.admin import AdminTokenCreate
request = AdminTokenCreate(name="Test Token", expires_in_days=30)

View File

@@ -9,8 +9,9 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
@@ -189,7 +190,7 @@ def app():
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_admin_router((".pdf", ".png", ".jpg"))
router = create_documents_router(StorageConfig())
app.include_router(router)
return app
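The two document-router test files above also pin down the shape of the new factory: create_documents_router takes a StorageConfig instead of an allowed-extensions tuple and mounts its routes under /admin/documents. A minimal sketch consistent with those assertions (handler names and bodies are placeholders, not the committed implementation):

# Sketch of a create_documents_router matching the assertions above; illustrative only.
from fastapi import APIRouter

from inference.web.config import StorageConfig

def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    router = APIRouter(prefix="/admin/documents")

    @router.get("")
    async def list_documents():
        ...

    @router.get("/stats")
    async def get_document_stats():
        ...

    @router.get("/{document_id}")
    async def get_document(document_id: str):
        ...

    return router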

View File

@@ -0,0 +1,245 @@
"""
Tests to verify admin schemas split maintains backward compatibility.
All existing imports from inference.web.schemas.admin must continue to work.
"""
import pytest
class TestEnumImports:
"""All enums importable from inference.web.schemas.admin."""
def test_document_status(self):
from inference.web.schemas.admin import DocumentStatus
assert DocumentStatus.PENDING == "pending"
def test_auto_label_status(self):
from inference.web.schemas.admin import AutoLabelStatus
assert AutoLabelStatus.RUNNING == "running"
def test_training_status(self):
from inference.web.schemas.admin import TrainingStatus
assert TrainingStatus.PENDING == "pending"
def test_training_type(self):
from inference.web.schemas.admin import TrainingType
assert TrainingType.TRAIN == "train"
def test_annotation_source(self):
from inference.web.schemas.admin import AnnotationSource
assert AnnotationSource.MANUAL == "manual"
class TestAuthImports:
"""Auth schemas importable."""
def test_admin_token_create(self):
from inference.web.schemas.admin import AdminTokenCreate
token = AdminTokenCreate(name="test")
assert token.name == "test"
def test_admin_token_response(self):
from inference.web.schemas.admin import AdminTokenResponse
assert AdminTokenResponse is not None
class TestDocumentImports:
"""Document schemas importable."""
def test_document_upload_response(self):
from inference.web.schemas.admin import DocumentUploadResponse
assert DocumentUploadResponse is not None
def test_document_item(self):
from inference.web.schemas.admin import DocumentItem
assert DocumentItem is not None
def test_document_list_response(self):
from inference.web.schemas.admin import DocumentListResponse
assert DocumentListResponse is not None
def test_document_detail_response(self):
from inference.web.schemas.admin import DocumentDetailResponse
assert DocumentDetailResponse is not None
def test_document_stats_response(self):
from inference.web.schemas.admin import DocumentStatsResponse
assert DocumentStatsResponse is not None
class TestAnnotationImports:
"""Annotation schemas importable."""
def test_bounding_box(self):
from inference.web.schemas.admin import BoundingBox
bb = BoundingBox(x=0, y=0, width=100, height=50)
assert bb.width == 100
def test_annotation_create(self):
from inference.web.schemas.admin import AnnotationCreate
assert AnnotationCreate is not None
def test_annotation_update(self):
from inference.web.schemas.admin import AnnotationUpdate
assert AnnotationUpdate is not None
def test_annotation_item(self):
from inference.web.schemas.admin import AnnotationItem
assert AnnotationItem is not None
def test_annotation_response(self):
from inference.web.schemas.admin import AnnotationResponse
assert AnnotationResponse is not None
def test_annotation_list_response(self):
from inference.web.schemas.admin import AnnotationListResponse
assert AnnotationListResponse is not None
def test_annotation_lock_request(self):
from inference.web.schemas.admin import AnnotationLockRequest
assert AnnotationLockRequest is not None
def test_annotation_lock_response(self):
from inference.web.schemas.admin import AnnotationLockResponse
assert AnnotationLockResponse is not None
def test_auto_label_request(self):
from inference.web.schemas.admin import AutoLabelRequest
assert AutoLabelRequest is not None
def test_auto_label_response(self):
from inference.web.schemas.admin import AutoLabelResponse
assert AutoLabelResponse is not None
def test_annotation_verify_request(self):
from inference.web.schemas.admin import AnnotationVerifyRequest
assert AnnotationVerifyRequest is not None
def test_annotation_verify_response(self):
from inference.web.schemas.admin import AnnotationVerifyResponse
assert AnnotationVerifyResponse is not None
def test_annotation_override_request(self):
from inference.web.schemas.admin import AnnotationOverrideRequest
assert AnnotationOverrideRequest is not None
def test_annotation_override_response(self):
from inference.web.schemas.admin import AnnotationOverrideResponse
assert AnnotationOverrideResponse is not None
class TestTrainingImports:
"""Training schemas importable."""
def test_training_config(self):
from inference.web.schemas.admin import TrainingConfig
config = TrainingConfig()
assert config.epochs == 100
def test_training_task_create(self):
from inference.web.schemas.admin import TrainingTaskCreate
assert TrainingTaskCreate is not None
def test_training_task_item(self):
from inference.web.schemas.admin import TrainingTaskItem
assert TrainingTaskItem is not None
def test_training_task_list_response(self):
from inference.web.schemas.admin import TrainingTaskListResponse
assert TrainingTaskListResponse is not None
def test_training_task_detail_response(self):
from inference.web.schemas.admin import TrainingTaskDetailResponse
assert TrainingTaskDetailResponse is not None
def test_training_task_response(self):
from inference.web.schemas.admin import TrainingTaskResponse
assert TrainingTaskResponse is not None
def test_training_log_item(self):
from inference.web.schemas.admin import TrainingLogItem
assert TrainingLogItem is not None
def test_training_logs_response(self):
from inference.web.schemas.admin import TrainingLogsResponse
assert TrainingLogsResponse is not None
def test_export_request(self):
from inference.web.schemas.admin import ExportRequest
assert ExportRequest is not None
def test_export_response(self):
from inference.web.schemas.admin import ExportResponse
assert ExportResponse is not None
def test_training_document_item(self):
from inference.web.schemas.admin import TrainingDocumentItem
assert TrainingDocumentItem is not None
def test_training_documents_response(self):
from inference.web.schemas.admin import TrainingDocumentsResponse
assert TrainingDocumentsResponse is not None
def test_model_metrics(self):
from inference.web.schemas.admin import ModelMetrics
assert ModelMetrics is not None
def test_training_model_item(self):
from inference.web.schemas.admin import TrainingModelItem
assert TrainingModelItem is not None
def test_training_models_response(self):
from inference.web.schemas.admin import TrainingModelsResponse
assert TrainingModelsResponse is not None
def test_training_history_item(self):
from inference.web.schemas.admin import TrainingHistoryItem
assert TrainingHistoryItem is not None
class TestDatasetImports:
"""Dataset schemas importable."""
def test_dataset_create_request(self):
from inference.web.schemas.admin import DatasetCreateRequest
assert DatasetCreateRequest is not None
def test_dataset_document_item(self):
from inference.web.schemas.admin import DatasetDocumentItem
assert DatasetDocumentItem is not None
def test_dataset_response(self):
from inference.web.schemas.admin import DatasetResponse
assert DatasetResponse is not None
def test_dataset_detail_response(self):
from inference.web.schemas.admin import DatasetDetailResponse
assert DatasetDetailResponse is not None
def test_dataset_list_item(self):
from inference.web.schemas.admin import DatasetListItem
assert DatasetListItem is not None
def test_dataset_list_response(self):
from inference.web.schemas.admin import DatasetListResponse
assert DatasetListResponse is not None
def test_dataset_train_request(self):
from inference.web.schemas.admin import DatasetTrainRequest
assert DatasetTrainRequest is not None
class TestForwardReferences:
"""Forward references resolve correctly."""
def test_document_detail_has_annotation_items(self):
from inference.web.schemas.admin import DocumentDetailResponse
fields = DocumentDetailResponse.model_fields
assert "annotations" in fields
assert "training_history" in fields
def test_dataset_train_request_has_config(self):
from inference.web.schemas.admin import DatasetTrainRequest, TrainingConfig
req = DatasetTrainRequest(name="test", config=TrainingConfig())
assert req.config.epochs == 100

View File

@@ -7,15 +7,15 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import UUID
from src.data.admin_models import TrainingTask, TrainingLog
from src.web.api.v1.admin.training import _validate_uuid, create_training_router
from src.web.core.scheduler import (
from inference.data.admin_models import TrainingTask, TrainingLog
from inference.web.api.v1.admin.training import _validate_uuid, create_training_router
from inference.web.core.scheduler import (
TrainingScheduler,
get_training_scheduler,
start_scheduler,
stop_scheduler,
)
from src.web.schemas.admin import (
from inference.web.schemas.admin import (
TrainingConfig,
TrainingStatus,
TrainingTaskCreate,

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
@@ -110,7 +110,7 @@ def app():
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_admin_router((".pdf", ".png", ".jpg"))
router = create_locks_router()
app.include_router(router)
return app

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:

View File

@@ -11,7 +11,7 @@ from unittest.mock import MagicMock
import pytest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
class TestAsyncTask:

View File

@@ -11,12 +11,12 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
from src.web.api.v1.async_api.routes import create_async_router, set_async_service
from src.web.services.async_processing import AsyncSubmitResult
from src.web.dependencies import init_dependencies
from src.web.rate_limiter import RateLimiter, RateLimitStatus
from src.web.schemas.inference import AsyncStatus
from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.services.async_processing import AsyncSubmitResult
from inference.web.dependencies import init_dependencies
from inference.web.rate_limiter import RateLimiter, RateLimitStatus
from inference.web.schemas.inference import AsyncStatus
# Valid UUID for testing
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"

View File

@@ -10,11 +10,11 @@ from unittest.mock import MagicMock, patch
import pytest
from src.data.async_request_db import AsyncRequest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
from src.web.config import AsyncConfig, StorageConfig
from src.web.rate_limiter import RateLimiter
from inference.data.async_request_db import AsyncRequest
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.rate_limiter import RateLimiter
@pytest.fixture

View File

@@ -8,8 +8,8 @@ from pathlib import Path
from unittest.mock import Mock, MagicMock
from uuid import uuid4
from src.web.services.autolabel import AutoLabelService
from src.data.admin_db import AdminDB
from inference.web.services.autolabel import AutoLabelService
from inference.data.admin_db import AdminDB
class MockDocument:

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from src.web.workers.batch_queue import BatchTask, BatchTaskQueue
from inference.web.workers.batch_queue import BatchTask, BatchTaskQueue
class MockBatchService:

View File

@@ -11,10 +11,10 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.batch.routes import router
from src.web.core.auth import validate_admin_token, get_admin_db
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from src.web.services.batch_upload import BatchUploadService
from inference.web.api.v1.batch.routes import router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
class MockAdminDB:

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
import pytest
from src.data.admin_db import AdminDB
from src.web.services.batch_upload import BatchUploadService
from inference.data.admin_db import AdminDB
from inference.web.services.batch_upload import BatchUploadService
@pytest.fixture

View File

@@ -0,0 +1,331 @@
"""
Tests for DatasetBuilder service.
TDD: Write tests first, then implement dataset_builder.py.
"""
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
TrainingDataset,
FIELD_CLASSES,
)
@pytest.fixture
def tmp_admin_images(tmp_path):
"""Create mock admin images directory with sample images."""
doc_ids = [uuid4() for _ in range(5)]
for doc_id in doc_ids:
doc_dir = tmp_path / "admin_images" / str(doc_id)
doc_dir.mkdir(parents=True)
# Create 2 pages per doc
for page in range(1, 3):
img_path = doc_dir / f"page_{page}.png"
img_path.write_bytes(b"fake-png-data")
return tmp_path, doc_ids
@pytest.fixture
def mock_admin_db():
"""Mock AdminDB with dataset and document methods."""
db = MagicMock()
db.create_dataset.return_value = TrainingDataset(
dataset_id=uuid4(),
name="test-dataset",
status="building",
train_ratio=0.8,
val_ratio=0.1,
seed=42,
)
return db
@pytest.fixture
def sample_documents(tmp_admin_images):
"""Create sample AdminDocument objects."""
tmp_path, doc_ids = tmp_admin_images
docs = []
for doc_id in doc_ids:
doc = MagicMock(spec=AdminDocument)
doc.document_id = doc_id
doc.filename = f"{doc_id}.pdf"
doc.page_count = 2
doc.file_path = str(tmp_path / "admin_images" / str(doc_id))
docs.append(doc)
return docs
@pytest.fixture
def sample_annotations(sample_documents):
"""Create sample annotations for each document page."""
annotations = {}
for doc in sample_documents:
doc_anns = []
for page in range(1, 3):
ann = MagicMock(spec=AdminAnnotation)
ann.document_id = doc.document_id
ann.page_number = page
ann.class_id = 0
ann.class_name = "invoice_number"
ann.x_center = 0.5
ann.y_center = 0.3
ann.width = 0.2
ann.height = 0.05
doc_anns.append(ann)
annotations[str(doc.document_id)] = doc_anns
return annotations
class TestDatasetBuilder:
"""Tests for DatasetBuilder."""
def test_build_creates_directory_structure(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
):
"""Dataset builder should create images/ and labels/ with train/val/test subdirs."""
from inference.web.services.dataset_builder import DatasetBuilder
dataset_dir = tmp_path / "datasets" / "test"
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# Mock DB calls
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
for split in ["train", "val", "test"]:
assert (result_dir / "images" / split).exists()
assert (result_dir / "labels" / split).exists()
def test_build_copies_images(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
):
"""Images should be copied from admin_images to dataset folder."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
result = builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
# Check total images copied
result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
total_images = sum(
len(list((result_dir / "images" / split).glob("*.png")))
for split in ["train", "val", "test"]
)
assert total_images == 10 # 5 docs * 2 pages
def test_build_generates_yolo_labels(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
):
"""YOLO label files should be generated with correct format."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
total_labels = sum(
len(list((result_dir / "labels" / split).glob("*.txt")))
for split in ["train", "val", "test"]
)
assert total_labels == 10 # 5 docs * 2 pages
# Check label format: "class_id x_center y_center width height"
label_files = list((result_dir / "labels").rglob("*.txt"))
content = label_files[0].read_text().strip()
parts = content.split()
assert len(parts) == 5
assert int(parts[0]) == 0 # class_id
assert 0 <= float(parts[1]) <= 1 # x_center
assert 0 <= float(parts[2]) <= 1 # y_center
def test_build_generates_data_yaml(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
):
"""data.yaml should be generated with correct field classes."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
yaml_path = tmp_path / "datasets" / str(dataset.dataset_id) / "data.yaml"
assert yaml_path.exists()
content = yaml_path.read_text()
assert "train:" in content
assert "val:" in content
assert "nc:" in content
assert "invoice_number" in content
def test_build_splits_documents_correctly(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
):
"""Documents should be split into train/val/test according to ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
# Verify add_dataset_documents was called with correct splits
call_args = mock_admin_db.add_dataset_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
splits = [d["split"] for d in docs_added]
assert "train" in splits
# With 5 docs, 80/10/10 -> 4 train, 0-1 val, 0-1 test
train_count = splits.count("train")
assert train_count >= 3 # At least 3 of 5 should be train
def test_build_updates_status_to_ready(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
):
"""After successful build, dataset status should be updated to 'ready'."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
mock_admin_db.update_dataset_status.assert_called_once()
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
assert call_kwargs["status"] == "ready"
assert call_kwargs["total_documents"] == 5
assert call_kwargs["total_images"] == 10
def test_build_sets_failed_on_error(
self, tmp_path, mock_admin_db
):
"""If build fails, dataset status should be set to 'failed'."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = [] # No docs found
dataset = mock_admin_db.create_dataset.return_value
with pytest.raises(ValueError):
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=["nonexistent-id"],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
mock_admin_db.update_dataset_status.assert_called_once()
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
assert call_kwargs["status"] == "failed"
def test_build_with_seed_produces_deterministic_splits(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
):
"""Same seed should produce same splits."""
from inference.web.services.dataset_builder import DatasetBuilder
results = []
for _ in range(2):
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
sample_annotations.get(str(doc_id), [])
)
mock_admin_db.add_dataset_documents.reset_mock()
mock_admin_db.update_dataset_status.reset_mock()
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
call_args = mock_admin_db.add_dataset_documents.call_args
docs = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
results.append([(d["document_id"], d["split"]) for d in docs])
assert results[0] == results[1]
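Taken together these tests fix the observable contract of DatasetBuilder.build_dataset: a seeded, ratio-based split per document, page images copied into images/{split}, one YOLO .txt per image under labels/{split}, a data.yaml listing the field classes, dataset-document rows recorded via add_dataset_documents, and a final status update to "ready" (or "failed" on any error). The sketch below condenses that contract; method and argument names mirror the tests, while everything else (FIELD_CLASSES being an ordered sequence of class names, the data.yaml layout, the return value) is an assumption.

# Behavioural sketch of DatasetBuilder derived from the tests above; not the committed code.
import random
import shutil
from pathlib import Path

from inference.data.admin_models import FIELD_CLASSES  # assumed: ordered sequence of class names


class DatasetBuilder:
    def __init__(self, db, base_dir: Path) -> None:
        self.db = db
        self.base_dir = Path(base_dir)

    def build_dataset(self, dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir):
        try:
            docs = self.db.get_documents_by_ids(document_ids)
            if not docs:
                raise ValueError("no documents found for dataset")

            root = self.base_dir / str(dataset_id)
            for split in ("train", "val", "test"):
                (root / "images" / split).mkdir(parents=True, exist_ok=True)
                (root / "labels" / split).mkdir(parents=True, exist_ok=True)

            # Deterministic split: sort, then shuffle with a seeded RNG.
            ordered = sorted(docs, key=lambda d: str(d.document_id))
            random.Random(seed).shuffle(ordered)
            n_train = int(len(ordered) * train_ratio)
            n_val = int(len(ordered) * val_ratio)
            split_of = {
                str(d.document_id): ("train" if i < n_train else "val" if i < n_train + n_val else "test")
                for i, d in enumerate(ordered)
            }

            total_images = 0
            for doc in ordered:
                split = split_of[str(doc.document_id)]
                anns = self.db.get_annotations_for_document(doc.document_id)
                for page in range(1, doc.page_count + 1):
                    stem = f"{doc.document_id}_page_{page}"
                    src = Path(admin_images_dir) / str(doc.document_id) / f"page_{page}.png"
                    shutil.copy(src, root / "images" / split / f"{stem}.png")
                    lines = [
                        f"{a.class_id} {a.x_center} {a.y_center} {a.width} {a.height}"
                        for a in anns
                        if a.page_number == page
                    ]
                    (root / "labels" / split / f"{stem}.txt").write_text("\n".join(lines))
                    total_images += 1

            names = "\n".join(f"  {i}: {name}" for i, name in enumerate(FIELD_CLASSES))
            (root / "data.yaml").write_text(
                f"path: {root}\ntrain: images/train\nval: images/val\ntest: images/test\n"
                f"nc: {len(FIELD_CLASSES)}\nnames:\n{names}\n"
            )

            self.db.add_dataset_documents(
                dataset_id,
                documents=[{"document_id": doc_id, "split": split} for doc_id, split in split_of.items()],
            )
            self.db.update_dataset_status(
                dataset_id, status="ready", total_documents=len(docs), total_images=total_images
            )
            return {"total_documents": len(docs), "total_images": total_images}
        except Exception:
            self.db.update_dataset_status(dataset_id, status="failed")
            raise

Sorting before the seeded shuffle is what makes test_build_with_seed_produces_deterministic_splits pass regardless of the order in which documents come back from the database.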

View File

@@ -0,0 +1,200 @@
"""
Tests for Dataset API routes in training.py.
"""
import asyncio
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import UUID
from inference.data.admin_models import TrainingDataset, DatasetDocument
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetTrainRequest,
TrainingConfig,
TrainingStatus,
)
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011"
TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
TEST_TOKEN = "test-admin-token-12345"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
def _make_dataset(**overrides) -> MagicMock:
defaults = dict(
dataset_id=UUID(TEST_DATASET_UUID),
name="test-dataset",
description="Test dataset",
status="ready",
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=2,
total_images=4,
total_annotations=10,
dataset_path="/data/datasets/test-dataset",
error_message=None,
created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
)
defaults.update(overrides)
ds = MagicMock(spec=TrainingDataset)
for k, v in defaults.items():
setattr(ds, k, v)
return ds
def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock:
doc = MagicMock(spec=DatasetDocument)
doc.document_id = UUID(doc_id)
doc.split = split
doc.page_count = 2
doc.annotation_count = 5
return doc
def _find_endpoint(name: str):
router = create_training_router()
for route in router.routes:
if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
return route.endpoint
raise AssertionError(f"Endpoint {name} not found")
class TestCreateDatasetRoute:
"""Tests for POST /admin/training/datasets."""
def test_router_has_dataset_endpoints(self):
router = create_training_router()
paths = [route.path for route in router.routes]
assert any("datasets" in p for p in paths)
def test_create_dataset_calls_builder(self):
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
mock_db.create_dataset.return_value = _make_dataset(status="building")
mock_builder = MagicMock()
mock_builder.build_dataset.return_value = {
"total_documents": 2,
"total_images": 4,
"total_annotations": 10,
}
request = DatasetCreateRequest(
name="test-dataset",
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
)
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
) as mock_cls:
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.create_dataset.assert_called_once()
mock_builder.build_dataset.assert_called_once()
assert result.dataset_id == TEST_DATASET_UUID
assert result.name == "test-dataset"
class TestListDatasetsRoute:
"""Tests for GET /admin/training/datasets."""
def test_list_datasets(self):
fn = _find_endpoint("list_datasets")
mock_db = MagicMock()
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
assert result.total == 1
assert len(result.datasets) == 1
assert result.datasets[0].name == "test-dataset"
class TestGetDatasetRoute:
"""Tests for GET /admin/training/datasets/{dataset_id}."""
def test_get_dataset_returns_detail(self):
fn = _find_endpoint("get_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset()
mock_db.get_dataset_documents.return_value = [
_make_dataset_doc(TEST_DOC_UUID_1, "train"),
_make_dataset_doc(TEST_DOC_UUID_2, "val"),
]
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.dataset_id == TEST_DATASET_UUID
assert len(result.documents) == 2
def test_get_dataset_not_found(self):
fn = _find_endpoint("get_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestDeleteDatasetRoute:
"""Tests for DELETE /admin/training/datasets/{dataset_id}."""
def test_delete_dataset(self):
fn = _find_endpoint("delete_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(dataset_path=None)
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
assert result["message"] == "Dataset deleted"
class TestTrainFromDatasetRoute:
"""Tests for POST /admin/training/datasets/{dataset_id}/train."""
def test_train_from_ready_dataset(self):
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.create_training_task.return_value = TEST_TASK_UUID
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert result.task_id == TEST_TASK_UUID
assert result.status == TrainingStatus.PENDING
mock_db.create_training_task.assert_called_once()
def test_train_from_building_dataset_fails(self):
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="building")
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
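For orientation, the train_from_dataset handler these tests target needs roughly this flow: 404 when the dataset does not exist, 400 unless its status is "ready", otherwise create a pending training task. A sketch under those assumptions; the create_training_task arguments and any response-model fields beyond task_id and status are guesses.

# Sketch of the handler contract exercised above; not the committed implementation.
from fastapi import HTTPException

from inference.web.schemas.admin import DatasetTrainRequest, TrainingStatus, TrainingTaskResponse


async def train_from_dataset(dataset_id: str, request: DatasetTrainRequest, admin_token: str, db):
    dataset = db.get_dataset(dataset_id)
    if dataset is None:
        raise HTTPException(status_code=404, detail="Dataset not found")
    if dataset.status != "ready":
        raise HTTPException(status_code=400, detail=f"Dataset is not ready (status: {dataset.status})")

    task_id = db.create_training_task(
        name=request.name,
        dataset_id=dataset_id,
        config=request.config.model_dump(),
    )
    return TrainingTaskResponse(task_id=task_id, status=TrainingStatus.PENDING)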

View File

@@ -11,8 +11,8 @@ from fastapi.testclient import TestClient
from PIL import Image
import io
from src.web.app import create_app
from src.web.config import ModelConfig, StorageConfig, AppConfig
from inference.web.app import create_app
from inference.web.config import ModelConfig, StorageConfig, AppConfig
@pytest.fixture
@@ -87,8 +87,8 @@ class TestHealthEndpoint:
class TestInferEndpoint:
"""Test /api/v1/infer endpoint."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_infer_accepts_png_file(
self,
mock_yolo_detector,
@@ -150,8 +150,8 @@ class TestInferEndpoint:
assert response.status_code == 422 # Unprocessable Entity
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_infer_returns_cross_validation_if_available(
self,
mock_yolo_detector,
@@ -210,8 +210,8 @@ class TestInferEndpoint:
# This test documents that it should be added
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_infer_handles_processing_errors_gracefully(
self,
mock_yolo_detector,
@@ -280,16 +280,16 @@ class TestInferenceServiceImports:
This test will fail if there are ImportError issues like:
- from ..inference.pipeline (wrong relative import)
- from src.web.inference (non-existent module)
- from inference.web.inference (non-existent module)
It ensures the imports are correct before runtime.
"""
from src.web.services.inference import InferenceService
from inference.web.services.inference import InferenceService
# Import the modules that InferenceService tries to import
from src.inference.pipeline import InferencePipeline
from src.inference.yolo_detector import YOLODetector
from src.pdf.renderer import render_pdf_to_images
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
from shared.pdf.renderer import render_pdf_to_images
# If we got here, all imports work correctly
assert InferencePipeline is not None

View File

@@ -10,8 +10,8 @@ from unittest.mock import Mock, patch
from PIL import Image
import io
from src.web.services.inference import InferenceService
from src.web.config import ModelConfig, StorageConfig
from inference.web.services.inference import InferenceService
from inference.web.config import ModelConfig, StorageConfig
@pytest.fixture
@@ -72,8 +72,8 @@ class TestInferenceServiceInitialization:
gpu_available = inference_service.gpu_available
assert isinstance(gpu_available, bool)
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_initialize_imports_correctly(
self,
mock_yolo_detector,
@@ -102,8 +102,8 @@ class TestInferenceServiceInitialization:
mock_yolo_detector.assert_called_once()
mock_pipeline.assert_called_once()
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_initialize_sets_up_pipeline(
self,
mock_yolo_detector,
@@ -135,8 +135,8 @@ class TestInferenceServiceInitialization:
enable_fallback=True,
)
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_initialize_idempotent(
self,
mock_yolo_detector,
@@ -161,8 +161,8 @@ class TestInferenceServiceInitialization:
class TestInferenceServiceProcessing:
"""Test inference processing methods."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('ultralytics.YOLO')
def test_process_image_basic_flow(
self,
@@ -197,8 +197,8 @@ class TestInferenceServiceProcessing:
assert result.confidence == {"InvoiceNumber": 0.95}
assert result.processing_time_ms > 0
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_process_image_handles_errors(
self,
mock_yolo_detector,
@@ -228,9 +228,9 @@ class TestInferenceServiceProcessing:
class TestInferenceServicePDFRendering:
"""Test PDF rendering imports."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('src.pdf.renderer.render_pdf_to_images')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('shared.pdf.renderer.render_pdf_to_images')
@patch('ultralytics.YOLO')
def test_pdf_visualization_imports_correctly(
self,
@@ -245,7 +245,7 @@ class TestInferenceServicePDFRendering:
Test that _save_pdf_visualization imports render_pdf_to_images correctly.
This catches the import error we had with:
from ..pdf.renderer (wrong) vs from src.pdf.renderer (correct)
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
"""
# Setup mocks
mock_detector_instance = Mock()
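One detail worth noting about the patch targets above: @patch('inference.pipeline.pipeline.InferencePipeline') only intercepts the class if the service resolves it through that module at call time, for example via a deferred import inside initialize(). A minimal sketch of that pattern (assumed, not taken from the commit; constructor arguments are guesses apart from enable_fallback, which appears in the hunk above):

class InferenceService:
    def initialize(self) -> None:
        # Deferred imports: the tests patch these classes at their definition
        # modules, so they must be looked up after the patch is active.
        from inference.pipeline.pipeline import InferencePipeline
        from inference.pipeline.yolo_detector import YOLODetector

        detector = YOLODetector(self.model_config.model_path)  # argument is a guess
        self.pipeline = InferencePipeline(detector, enable_fallback=True)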

View File

@@ -8,8 +8,8 @@ from unittest.mock import MagicMock
import pytest
from src.data.async_request_db import ApiKeyConfig
from src.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
from inference.data.async_request_db import ApiKeyConfig
from inference.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
class TestRateLimiter:

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.training import create_training_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockTrainingTask: