restructure project
This commit is contained in:
@@ -10,12 +10,12 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from src.data.async_request_db import ApiKeyConfig, AsyncRequestDB
|
||||
from src.data.models import AsyncRequest
|
||||
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from src.web.services.async_processing import AsyncProcessingService
|
||||
from src.web.config import AsyncConfig, StorageConfig
|
||||
from src.web.core.rate_limiter import RateLimiter
|
||||
from inference.data.async_request_db import ApiKeyConfig, AsyncRequestDB
|
||||
from inference.data.models import AsyncRequest
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from inference.web.services.async_processing import AsyncProcessingService
|
||||
from inference.web.config import AsyncConfig, StorageConfig
|
||||
from inference.web.core.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -9,9 +9,9 @@ from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
|
||||
from src.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
||||
from src.web.schemas.admin import (
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
|
||||
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
AutoLabelRequest,
|
||||
|
||||
@@ -8,9 +8,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.data.admin_models import AdminToken
|
||||
from src.web.core.auth import (
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.admin_models import AdminToken
|
||||
from inference.web.core.auth import (
|
||||
get_admin_db,
|
||||
reset_admin_db,
|
||||
validate_admin_token,
|
||||
@@ -81,7 +81,7 @@ class TestAdminDB:
|
||||
|
||||
def test_is_valid_admin_token_active(self):
|
||||
"""Test valid active token."""
|
||||
with patch("src.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -98,7 +98,7 @@ class TestAdminDB:
|
||||
|
||||
def test_is_valid_admin_token_inactive(self):
|
||||
"""Test inactive token."""
|
||||
with patch("src.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -115,7 +115,7 @@ class TestAdminDB:
|
||||
|
||||
def test_is_valid_admin_token_expired(self):
|
||||
"""Test expired token."""
|
||||
with patch("src.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -132,7 +132,7 @@ class TestAdminDB:
|
||||
|
||||
def test_is_valid_admin_token_not_found(self):
|
||||
"""Test token not found."""
|
||||
with patch("src.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
mock_session.get.return_value = None
|
||||
|
||||
@@ -12,8 +12,9 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.data.admin_models import AdminDocument, AdminToken
|
||||
from src.web.api.v1.admin.documents import _validate_uuid, create_admin_router
|
||||
from inference.data.admin_models import AdminDocument, AdminToken
|
||||
from inference.web.api.v1.admin.documents import _validate_uuid, create_documents_router
|
||||
from inference.web.config import StorageConfig
|
||||
|
||||
|
||||
# Test UUID
|
||||
@@ -42,13 +43,12 @@ class TestAdminRouter:
|
||||
|
||||
def test_creates_router_with_endpoints(self):
|
||||
"""Test router is created with expected endpoints."""
|
||||
router = create_admin_router((".pdf", ".png", ".jpg"))
|
||||
router = create_documents_router(StorageConfig())
|
||||
|
||||
# Get route paths (include prefix from router)
|
||||
paths = [route.path for route in router.routes]
|
||||
|
||||
# Paths include the /admin prefix
|
||||
assert any("/auth/token" in p for p in paths)
|
||||
# Paths include the /admin/documents prefix
|
||||
assert any("/documents" in p for p in paths)
|
||||
assert any("/documents/stats" in p for p in paths)
|
||||
assert any("{document_id}" in p for p in paths)
|
||||
@@ -66,7 +66,7 @@ class TestCreateTokenEndpoint:
|
||||
|
||||
def test_create_token_success(self, mock_db):
|
||||
"""Test successful token creation."""
|
||||
from src.web.schemas.admin import AdminTokenCreate
|
||||
from inference.web.schemas.admin import AdminTokenCreate
|
||||
|
||||
request = AdminTokenCreate(name="Test Token", expires_in_days=30)
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.web.api.v1.admin.documents import create_admin_router
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.api.v1.admin.documents import create_documents_router
|
||||
from inference.web.config import StorageConfig
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
@@ -189,7 +190,7 @@ def app():
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
|
||||
# Include router
|
||||
router = create_admin_router((".pdf", ".png", ".jpg"))
|
||||
router = create_documents_router(StorageConfig())
|
||||
app.include_router(router)
|
||||
|
||||
return app
|
||||
|
||||
245
tests/web/test_admin_schemas_split.py
Normal file
245
tests/web/test_admin_schemas_split.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Tests to verify admin schemas split maintains backward compatibility.
|
||||
|
||||
All existing imports from inference.web.schemas.admin must continue to work.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEnumImports:
|
||||
"""All enums importable from inference.web.schemas.admin."""
|
||||
|
||||
def test_document_status(self):
|
||||
from inference.web.schemas.admin import DocumentStatus
|
||||
assert DocumentStatus.PENDING == "pending"
|
||||
|
||||
def test_auto_label_status(self):
|
||||
from inference.web.schemas.admin import AutoLabelStatus
|
||||
assert AutoLabelStatus.RUNNING == "running"
|
||||
|
||||
def test_training_status(self):
|
||||
from inference.web.schemas.admin import TrainingStatus
|
||||
assert TrainingStatus.PENDING == "pending"
|
||||
|
||||
def test_training_type(self):
|
||||
from inference.web.schemas.admin import TrainingType
|
||||
assert TrainingType.TRAIN == "train"
|
||||
|
||||
def test_annotation_source(self):
|
||||
from inference.web.schemas.admin import AnnotationSource
|
||||
assert AnnotationSource.MANUAL == "manual"
|
||||
|
||||
|
||||
class TestAuthImports:
|
||||
"""Auth schemas importable."""
|
||||
|
||||
def test_admin_token_create(self):
|
||||
from inference.web.schemas.admin import AdminTokenCreate
|
||||
token = AdminTokenCreate(name="test")
|
||||
assert token.name == "test"
|
||||
|
||||
def test_admin_token_response(self):
|
||||
from inference.web.schemas.admin import AdminTokenResponse
|
||||
assert AdminTokenResponse is not None
|
||||
|
||||
|
||||
class TestDocumentImports:
|
||||
"""Document schemas importable."""
|
||||
|
||||
def test_document_upload_response(self):
|
||||
from inference.web.schemas.admin import DocumentUploadResponse
|
||||
assert DocumentUploadResponse is not None
|
||||
|
||||
def test_document_item(self):
|
||||
from inference.web.schemas.admin import DocumentItem
|
||||
assert DocumentItem is not None
|
||||
|
||||
def test_document_list_response(self):
|
||||
from inference.web.schemas.admin import DocumentListResponse
|
||||
assert DocumentListResponse is not None
|
||||
|
||||
def test_document_detail_response(self):
|
||||
from inference.web.schemas.admin import DocumentDetailResponse
|
||||
assert DocumentDetailResponse is not None
|
||||
|
||||
def test_document_stats_response(self):
|
||||
from inference.web.schemas.admin import DocumentStatsResponse
|
||||
assert DocumentStatsResponse is not None
|
||||
|
||||
|
||||
class TestAnnotationImports:
|
||||
"""Annotation schemas importable."""
|
||||
|
||||
def test_bounding_box(self):
|
||||
from inference.web.schemas.admin import BoundingBox
|
||||
bb = BoundingBox(x=0, y=0, width=100, height=50)
|
||||
assert bb.width == 100
|
||||
|
||||
def test_annotation_create(self):
|
||||
from inference.web.schemas.admin import AnnotationCreate
|
||||
assert AnnotationCreate is not None
|
||||
|
||||
def test_annotation_update(self):
|
||||
from inference.web.schemas.admin import AnnotationUpdate
|
||||
assert AnnotationUpdate is not None
|
||||
|
||||
def test_annotation_item(self):
|
||||
from inference.web.schemas.admin import AnnotationItem
|
||||
assert AnnotationItem is not None
|
||||
|
||||
def test_annotation_response(self):
|
||||
from inference.web.schemas.admin import AnnotationResponse
|
||||
assert AnnotationResponse is not None
|
||||
|
||||
def test_annotation_list_response(self):
|
||||
from inference.web.schemas.admin import AnnotationListResponse
|
||||
assert AnnotationListResponse is not None
|
||||
|
||||
def test_annotation_lock_request(self):
|
||||
from inference.web.schemas.admin import AnnotationLockRequest
|
||||
assert AnnotationLockRequest is not None
|
||||
|
||||
def test_annotation_lock_response(self):
|
||||
from inference.web.schemas.admin import AnnotationLockResponse
|
||||
assert AnnotationLockResponse is not None
|
||||
|
||||
def test_auto_label_request(self):
|
||||
from inference.web.schemas.admin import AutoLabelRequest
|
||||
assert AutoLabelRequest is not None
|
||||
|
||||
def test_auto_label_response(self):
|
||||
from inference.web.schemas.admin import AutoLabelResponse
|
||||
assert AutoLabelResponse is not None
|
||||
|
||||
def test_annotation_verify_request(self):
|
||||
from inference.web.schemas.admin import AnnotationVerifyRequest
|
||||
assert AnnotationVerifyRequest is not None
|
||||
|
||||
def test_annotation_verify_response(self):
|
||||
from inference.web.schemas.admin import AnnotationVerifyResponse
|
||||
assert AnnotationVerifyResponse is not None
|
||||
|
||||
def test_annotation_override_request(self):
|
||||
from inference.web.schemas.admin import AnnotationOverrideRequest
|
||||
assert AnnotationOverrideRequest is not None
|
||||
|
||||
def test_annotation_override_response(self):
|
||||
from inference.web.schemas.admin import AnnotationOverrideResponse
|
||||
assert AnnotationOverrideResponse is not None
|
||||
|
||||
|
||||
class TestTrainingImports:
|
||||
"""Training schemas importable."""
|
||||
|
||||
def test_training_config(self):
|
||||
from inference.web.schemas.admin import TrainingConfig
|
||||
config = TrainingConfig()
|
||||
assert config.epochs == 100
|
||||
|
||||
def test_training_task_create(self):
|
||||
from inference.web.schemas.admin import TrainingTaskCreate
|
||||
assert TrainingTaskCreate is not None
|
||||
|
||||
def test_training_task_item(self):
|
||||
from inference.web.schemas.admin import TrainingTaskItem
|
||||
assert TrainingTaskItem is not None
|
||||
|
||||
def test_training_task_list_response(self):
|
||||
from inference.web.schemas.admin import TrainingTaskListResponse
|
||||
assert TrainingTaskListResponse is not None
|
||||
|
||||
def test_training_task_detail_response(self):
|
||||
from inference.web.schemas.admin import TrainingTaskDetailResponse
|
||||
assert TrainingTaskDetailResponse is not None
|
||||
|
||||
def test_training_task_response(self):
|
||||
from inference.web.schemas.admin import TrainingTaskResponse
|
||||
assert TrainingTaskResponse is not None
|
||||
|
||||
def test_training_log_item(self):
|
||||
from inference.web.schemas.admin import TrainingLogItem
|
||||
assert TrainingLogItem is not None
|
||||
|
||||
def test_training_logs_response(self):
|
||||
from inference.web.schemas.admin import TrainingLogsResponse
|
||||
assert TrainingLogsResponse is not None
|
||||
|
||||
def test_export_request(self):
|
||||
from inference.web.schemas.admin import ExportRequest
|
||||
assert ExportRequest is not None
|
||||
|
||||
def test_export_response(self):
|
||||
from inference.web.schemas.admin import ExportResponse
|
||||
assert ExportResponse is not None
|
||||
|
||||
def test_training_document_item(self):
|
||||
from inference.web.schemas.admin import TrainingDocumentItem
|
||||
assert TrainingDocumentItem is not None
|
||||
|
||||
def test_training_documents_response(self):
|
||||
from inference.web.schemas.admin import TrainingDocumentsResponse
|
||||
assert TrainingDocumentsResponse is not None
|
||||
|
||||
def test_model_metrics(self):
|
||||
from inference.web.schemas.admin import ModelMetrics
|
||||
assert ModelMetrics is not None
|
||||
|
||||
def test_training_model_item(self):
|
||||
from inference.web.schemas.admin import TrainingModelItem
|
||||
assert TrainingModelItem is not None
|
||||
|
||||
def test_training_models_response(self):
|
||||
from inference.web.schemas.admin import TrainingModelsResponse
|
||||
assert TrainingModelsResponse is not None
|
||||
|
||||
def test_training_history_item(self):
|
||||
from inference.web.schemas.admin import TrainingHistoryItem
|
||||
assert TrainingHistoryItem is not None
|
||||
|
||||
|
||||
class TestDatasetImports:
|
||||
"""Dataset schemas importable."""
|
||||
|
||||
def test_dataset_create_request(self):
|
||||
from inference.web.schemas.admin import DatasetCreateRequest
|
||||
assert DatasetCreateRequest is not None
|
||||
|
||||
def test_dataset_document_item(self):
|
||||
from inference.web.schemas.admin import DatasetDocumentItem
|
||||
assert DatasetDocumentItem is not None
|
||||
|
||||
def test_dataset_response(self):
|
||||
from inference.web.schemas.admin import DatasetResponse
|
||||
assert DatasetResponse is not None
|
||||
|
||||
def test_dataset_detail_response(self):
|
||||
from inference.web.schemas.admin import DatasetDetailResponse
|
||||
assert DatasetDetailResponse is not None
|
||||
|
||||
def test_dataset_list_item(self):
|
||||
from inference.web.schemas.admin import DatasetListItem
|
||||
assert DatasetListItem is not None
|
||||
|
||||
def test_dataset_list_response(self):
|
||||
from inference.web.schemas.admin import DatasetListResponse
|
||||
assert DatasetListResponse is not None
|
||||
|
||||
def test_dataset_train_request(self):
|
||||
from inference.web.schemas.admin import DatasetTrainRequest
|
||||
assert DatasetTrainRequest is not None
|
||||
|
||||
|
||||
class TestForwardReferences:
|
||||
"""Forward references resolve correctly."""
|
||||
|
||||
def test_document_detail_has_annotation_items(self):
|
||||
from inference.web.schemas.admin import DocumentDetailResponse
|
||||
fields = DocumentDetailResponse.model_fields
|
||||
assert "annotations" in fields
|
||||
assert "training_history" in fields
|
||||
|
||||
def test_dataset_train_request_has_config(self):
|
||||
from inference.web.schemas.admin import DatasetTrainRequest, TrainingConfig
|
||||
req = DatasetTrainRequest(name="test", config=TrainingConfig())
|
||||
assert req.config.epochs == 100
|
||||
@@ -7,15 +7,15 @@ from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from src.data.admin_models import TrainingTask, TrainingLog
|
||||
from src.web.api.v1.admin.training import _validate_uuid, create_training_router
|
||||
from src.web.core.scheduler import (
|
||||
from inference.data.admin_models import TrainingTask, TrainingLog
|
||||
from inference.web.api.v1.admin.training import _validate_uuid, create_training_router
|
||||
from inference.web.core.scheduler import (
|
||||
TrainingScheduler,
|
||||
get_training_scheduler,
|
||||
start_scheduler,
|
||||
stop_scheduler,
|
||||
)
|
||||
from src.web.schemas.admin import (
|
||||
from inference.web.schemas.admin import (
|
||||
TrainingConfig,
|
||||
TrainingStatus,
|
||||
TrainingTaskCreate,
|
||||
|
||||
@@ -9,8 +9,8 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.web.api.v1.admin.documents import create_admin_router
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.api.v1.admin.locks import create_locks_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
@@ -110,7 +110,7 @@ def app():
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
|
||||
# Include router
|
||||
router = create_admin_router((".pdf", ".png", ".jpg"))
|
||||
router = create_locks_router()
|
||||
app.include_router(router)
|
||||
|
||||
return app
|
||||
|
||||
@@ -9,8 +9,8 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.web.api.v1.admin.annotations import create_annotation_router
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.api.v1.admin.annotations import create_annotation_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
|
||||
@@ -11,7 +11,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
|
||||
|
||||
class TestAsyncTask:
|
||||
|
||||
@@ -11,12 +11,12 @@ import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
|
||||
from src.web.api.v1.async_api.routes import create_async_router, set_async_service
|
||||
from src.web.services.async_processing import AsyncSubmitResult
|
||||
from src.web.dependencies import init_dependencies
|
||||
from src.web.rate_limiter import RateLimiter, RateLimitStatus
|
||||
from src.web.schemas.inference import AsyncStatus
|
||||
from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
|
||||
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
|
||||
from inference.web.services.async_processing import AsyncSubmitResult
|
||||
from inference.web.dependencies import init_dependencies
|
||||
from inference.web.rate_limiter import RateLimiter, RateLimitStatus
|
||||
from inference.web.schemas.inference import AsyncStatus
|
||||
|
||||
# Valid UUID for testing
|
||||
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
@@ -10,11 +10,11 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.data.async_request_db import AsyncRequest
|
||||
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from src.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
|
||||
from src.web.config import AsyncConfig, StorageConfig
|
||||
from src.web.rate_limiter import RateLimiter
|
||||
from inference.data.async_request_db import AsyncRequest
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from inference.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
|
||||
from inference.web.config import AsyncConfig, StorageConfig
|
||||
from inference.web.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -8,8 +8,8 @@ from pathlib import Path
|
||||
from unittest.mock import Mock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from src.web.services.autolabel import AutoLabelService
|
||||
from src.data.admin_db import AdminDB
|
||||
from inference.web.services.autolabel import AutoLabelService
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
|
||||
class MockDocument:
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from src.web.workers.batch_queue import BatchTask, BatchTaskQueue
|
||||
from inference.web.workers.batch_queue import BatchTask, BatchTaskQueue
|
||||
|
||||
|
||||
class MockBatchService:
|
||||
|
||||
@@ -11,10 +11,10 @@ import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.web.api.v1.batch.routes import router
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db
|
||||
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from src.web.services.batch_upload import BatchUploadService
|
||||
from inference.web.api.v1.batch.routes import router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from inference.web.services.batch_upload import BatchUploadService
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
|
||||
@@ -9,8 +9,8 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.web.services.batch_upload import BatchUploadService
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.services.batch_upload import BatchUploadService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
331
tests/web/test_dataset_builder.py
Normal file
331
tests/web/test_dataset_builder.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Tests for DatasetBuilder service.
|
||||
|
||||
TDD: Write tests first, then implement dataset_builder.py.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
TrainingDataset,
|
||||
FIELD_CLASSES,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_admin_images(tmp_path):
|
||||
"""Create mock admin images directory with sample images."""
|
||||
doc_ids = [uuid4() for _ in range(5)]
|
||||
for doc_id in doc_ids:
|
||||
doc_dir = tmp_path / "admin_images" / str(doc_id)
|
||||
doc_dir.mkdir(parents=True)
|
||||
# Create 2 pages per doc
|
||||
for page in range(1, 3):
|
||||
img_path = doc_dir / f"page_{page}.png"
|
||||
img_path.write_bytes(b"fake-png-data")
|
||||
return tmp_path, doc_ids
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db():
|
||||
"""Mock AdminDB with dataset and document methods."""
|
||||
db = MagicMock()
|
||||
db.create_dataset.return_value = TrainingDataset(
|
||||
dataset_id=uuid4(),
|
||||
name="test-dataset",
|
||||
status="building",
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
)
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_documents(tmp_admin_images):
|
||||
"""Create sample AdminDocument objects."""
|
||||
tmp_path, doc_ids = tmp_admin_images
|
||||
docs = []
|
||||
for doc_id in doc_ids:
|
||||
doc = MagicMock(spec=AdminDocument)
|
||||
doc.document_id = doc_id
|
||||
doc.filename = f"{doc_id}.pdf"
|
||||
doc.page_count = 2
|
||||
doc.file_path = str(tmp_path / "admin_images" / str(doc_id))
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_annotations(sample_documents):
|
||||
"""Create sample annotations for each document page."""
|
||||
annotations = {}
|
||||
for doc in sample_documents:
|
||||
doc_anns = []
|
||||
for page in range(1, 3):
|
||||
ann = MagicMock(spec=AdminAnnotation)
|
||||
ann.document_id = doc.document_id
|
||||
ann.page_number = page
|
||||
ann.class_id = 0
|
||||
ann.class_name = "invoice_number"
|
||||
ann.x_center = 0.5
|
||||
ann.y_center = 0.3
|
||||
ann.width = 0.2
|
||||
ann.height = 0.05
|
||||
doc_anns.append(ann)
|
||||
annotations[str(doc.document_id)] = doc_anns
|
||||
return annotations
|
||||
|
||||
|
||||
class TestDatasetBuilder:
|
||||
"""Tests for DatasetBuilder."""
|
||||
|
||||
def test_build_creates_directory_structure(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
):
|
||||
"""Dataset builder should create images/ and labels/ with train/val/test subdirs."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
dataset_dir = tmp_path / "datasets" / "test"
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
|
||||
# Mock DB calls
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
|
||||
for split in ["train", "val", "test"]:
|
||||
assert (result_dir / "images" / split).exists()
|
||||
assert (result_dir / "labels" / split).exists()
|
||||
|
||||
def test_build_copies_images(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
):
|
||||
"""Images should be copied from admin_images to dataset folder."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
result = builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
# Check total images copied
|
||||
result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
|
||||
total_images = sum(
|
||||
len(list((result_dir / "images" / split).glob("*.png")))
|
||||
for split in ["train", "val", "test"]
|
||||
)
|
||||
assert total_images == 10 # 5 docs * 2 pages
|
||||
|
||||
def test_build_generates_yolo_labels(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
):
|
||||
"""YOLO label files should be generated with correct format."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
|
||||
total_labels = sum(
|
||||
len(list((result_dir / "labels" / split).glob("*.txt")))
|
||||
for split in ["train", "val", "test"]
|
||||
)
|
||||
assert total_labels == 10 # 5 docs * 2 pages
|
||||
|
||||
# Check label format: "class_id x_center y_center width height"
|
||||
label_files = list((result_dir / "labels").rglob("*.txt"))
|
||||
content = label_files[0].read_text().strip()
|
||||
parts = content.split()
|
||||
assert len(parts) == 5
|
||||
assert int(parts[0]) == 0 # class_id
|
||||
assert 0 <= float(parts[1]) <= 1 # x_center
|
||||
assert 0 <= float(parts[2]) <= 1 # y_center
|
||||
|
||||
def test_build_generates_data_yaml(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
):
|
||||
"""data.yaml should be generated with correct field classes."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
yaml_path = tmp_path / "datasets" / str(dataset.dataset_id) / "data.yaml"
|
||||
assert yaml_path.exists()
|
||||
content = yaml_path.read_text()
|
||||
assert "train:" in content
|
||||
assert "val:" in content
|
||||
assert "nc:" in content
|
||||
assert "invoice_number" in content
|
||||
|
||||
def test_build_splits_documents_correctly(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
):
|
||||
"""Documents should be split into train/val/test according to ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
# Verify add_dataset_documents was called with correct splits
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
splits = [d["split"] for d in docs_added]
|
||||
assert "train" in splits
|
||||
# With 5 docs, 80/10/10 -> 4 train, 0-1 val, 0-1 test
|
||||
train_count = splits.count("train")
|
||||
assert train_count >= 3 # At least 3 of 5 should be train
|
||||
|
||||
def test_build_updates_status_to_ready(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
):
|
||||
"""After successful build, dataset status should be updated to 'ready'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
mock_admin_db.update_dataset_status.assert_called_once()
|
||||
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
|
||||
assert call_kwargs["status"] == "ready"
|
||||
assert call_kwargs["total_documents"] == 5
|
||||
assert call_kwargs["total_images"] == 10
|
||||
|
||||
def test_build_sets_failed_on_error(
|
||||
self, tmp_path, mock_admin_db
|
||||
):
|
||||
"""If build fails, dataset status should be set to 'failed'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = [] # No docs found
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
with pytest.raises(ValueError):
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=["nonexistent-id"],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
mock_admin_db.update_dataset_status.assert_called_once()
|
||||
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
|
||||
assert call_kwargs["status"] == "failed"
|
||||
|
||||
def test_build_with_seed_produces_deterministic_splits(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
):
|
||||
"""Same seed should produce same splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
results = []
|
||||
for _ in range(2):
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
mock_admin_db.add_dataset_documents.reset_mock()
|
||||
mock_admin_db.update_dataset_status.reset_mock()
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
docs = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
results.append([(d["document_id"], d["split"]) for d in docs])
|
||||
|
||||
assert results[0] == results[1]
|
||||
200
tests/web/test_dataset_routes.py
Normal file
200
tests/web/test_dataset_routes.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for Dataset API routes in training.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from inference.data.admin_models import TrainingDataset, DatasetDocument
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.schemas.admin import (
|
||||
DatasetCreateRequest,
|
||||
DatasetTrainRequest,
|
||||
TrainingConfig,
|
||||
TrainingStatus,
|
||||
)
|
||||
|
||||
|
||||
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
|
||||
TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011"
|
||||
TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
|
||||
|
||||
|
||||
def _make_dataset(**overrides) -> MagicMock:
|
||||
defaults = dict(
|
||||
dataset_id=UUID(TEST_DATASET_UUID),
|
||||
name="test-dataset",
|
||||
description="Test dataset",
|
||||
status="ready",
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=2,
|
||||
total_images=4,
|
||||
total_annotations=10,
|
||||
dataset_path="/data/datasets/test-dataset",
|
||||
error_message=None,
|
||||
created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
ds = MagicMock(spec=TrainingDataset)
|
||||
for k, v in defaults.items():
|
||||
setattr(ds, k, v)
|
||||
return ds
|
||||
|
||||
|
||||
def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock:
|
||||
doc = MagicMock(spec=DatasetDocument)
|
||||
doc.document_id = UUID(doc_id)
|
||||
doc.split = split
|
||||
doc.page_count = 2
|
||||
doc.annotation_count = 5
|
||||
return doc
|
||||
|
||||
|
||||
def _find_endpoint(name: str):
|
||||
router = create_training_router()
|
||||
for route in router.routes:
|
||||
if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"Endpoint {name} not found")
|
||||
|
||||
|
||||
class TestCreateDatasetRoute:
|
||||
"""Tests for POST /admin/training/datasets."""
|
||||
|
||||
def test_router_has_dataset_endpoints(self):
|
||||
router = create_training_router()
|
||||
paths = [route.path for route in router.routes]
|
||||
assert any("datasets" in p for p in paths)
|
||||
|
||||
def test_create_dataset_calls_builder(self):
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_dataset.return_value = _make_dataset(status="building")
|
||||
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_dataset.return_value = {
|
||||
"total_documents": 2,
|
||||
"total_images": 4,
|
||||
"total_annotations": 10,
|
||||
}
|
||||
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||
return_value=mock_builder,
|
||||
) as mock_cls:
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.create_dataset.assert_called_once()
|
||||
mock_builder.build_dataset.assert_called_once()
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
assert result.name == "test-dataset"
|
||||
|
||||
|
||||
class TestListDatasetsRoute:
|
||||
"""Tests for GET /admin/training/datasets."""
|
||||
|
||||
def test_list_datasets(self):
|
||||
fn = _find_endpoint("list_datasets")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.datasets) == 1
|
||||
assert result.datasets[0].name == "test-dataset"
|
||||
|
||||
|
||||
class TestGetDatasetRoute:
|
||||
"""Tests for GET /admin/training/datasets/{dataset_id}."""
|
||||
|
||||
def test_get_dataset_returns_detail(self):
|
||||
fn = _find_endpoint("get_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset()
|
||||
mock_db.get_dataset_documents.return_value = [
|
||||
_make_dataset_doc(TEST_DOC_UUID_1, "train"),
|
||||
_make_dataset_doc(TEST_DOC_UUID_2, "val"),
|
||||
]
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
assert len(result.documents) == 2
|
||||
|
||||
def test_get_dataset_not_found(self):
|
||||
fn = _find_endpoint("get_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestDeleteDatasetRoute:
|
||||
"""Tests for DELETE /admin/training/datasets/{dataset_id}."""
|
||||
|
||||
def test_delete_dataset(self):
|
||||
fn = _find_endpoint("delete_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(dataset_path=None)
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
|
||||
assert result["message"] == "Dataset deleted"
|
||||
|
||||
|
||||
class TestTrainFromDatasetRoute:
|
||||
"""Tests for POST /admin/training/datasets/{dataset_id}/train."""
|
||||
|
||||
def test_train_from_ready_dataset(self):
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="ready")
|
||||
mock_db.create_training_task.return_value = TEST_TASK_UUID
|
||||
|
||||
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
assert result.task_id == TEST_TASK_UUID
|
||||
assert result.status == TrainingStatus.PENDING
|
||||
mock_db.create_training_task.assert_called_once()
|
||||
|
||||
def test_train_from_building_dataset_fails(self):
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="building")
|
||||
|
||||
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 400
|
||||
@@ -11,8 +11,8 @@ from fastapi.testclient import TestClient
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from src.web.app import create_app
|
||||
from src.web.config import ModelConfig, StorageConfig, AppConfig
|
||||
from inference.web.app import create_app
|
||||
from inference.web.config import ModelConfig, StorageConfig, AppConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -87,8 +87,8 @@ class TestHealthEndpoint:
|
||||
class TestInferEndpoint:
|
||||
"""Test /api/v1/infer endpoint."""
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
def test_infer_accepts_png_file(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -150,8 +150,8 @@ class TestInferEndpoint:
|
||||
|
||||
assert response.status_code == 422 # Unprocessable Entity
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
def test_infer_returns_cross_validation_if_available(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -210,8 +210,8 @@ class TestInferEndpoint:
|
||||
# This test documents that it should be added
|
||||
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
def test_infer_handles_processing_errors_gracefully(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -280,16 +280,16 @@ class TestInferenceServiceImports:
|
||||
|
||||
This test will fail if there are ImportError issues like:
|
||||
- from ..inference.pipeline (wrong relative import)
|
||||
- from src.web.inference (non-existent module)
|
||||
- from inference.web.inference (non-existent module)
|
||||
|
||||
It ensures the imports are correct before runtime.
|
||||
"""
|
||||
from src.web.services.inference import InferenceService
|
||||
from inference.web.services.inference import InferenceService
|
||||
|
||||
# Import the modules that InferenceService tries to import
|
||||
from src.inference.pipeline import InferencePipeline
|
||||
from src.inference.yolo_detector import YOLODetector
|
||||
from src.pdf.renderer import render_pdf_to_images
|
||||
from inference.pipeline.pipeline import InferencePipeline
|
||||
from inference.pipeline.yolo_detector import YOLODetector
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
|
||||
# If we got here, all imports work correctly
|
||||
assert InferencePipeline is not None
|
||||
|
||||
@@ -10,8 +10,8 @@ from unittest.mock import Mock, patch
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from src.web.services.inference import InferenceService
|
||||
from src.web.config import ModelConfig, StorageConfig
|
||||
from inference.web.services.inference import InferenceService
|
||||
from inference.web.config import ModelConfig, StorageConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -72,8 +72,8 @@ class TestInferenceServiceInitialization:
|
||||
gpu_available = inference_service.gpu_available
|
||||
assert isinstance(gpu_available, bool)
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
def test_initialize_imports_correctly(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -102,8 +102,8 @@ class TestInferenceServiceInitialization:
|
||||
mock_yolo_detector.assert_called_once()
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
def test_initialize_sets_up_pipeline(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -135,8 +135,8 @@ class TestInferenceServiceInitialization:
|
||||
enable_fallback=True,
|
||||
)
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
def test_initialize_idempotent(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -161,8 +161,8 @@ class TestInferenceServiceInitialization:
|
||||
class TestInferenceServiceProcessing:
|
||||
"""Test inference processing methods."""
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_process_image_basic_flow(
|
||||
self,
|
||||
@@ -197,8 +197,8 @@ class TestInferenceServiceProcessing:
|
||||
assert result.confidence == {"InvoiceNumber": 0.95}
|
||||
assert result.processing_time_ms > 0
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
def test_process_image_handles_errors(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -228,9 +228,9 @@ class TestInferenceServiceProcessing:
|
||||
class TestInferenceServicePDFRendering:
|
||||
"""Test PDF rendering imports."""
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('src.pdf.renderer.render_pdf_to_images')
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('shared.pdf.renderer.render_pdf_to_images')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_pdf_visualization_imports_correctly(
|
||||
self,
|
||||
@@ -245,7 +245,7 @@ class TestInferenceServicePDFRendering:
|
||||
Test that _save_pdf_visualization imports render_pdf_to_images correctly.
|
||||
|
||||
This catches the import error we had with:
|
||||
from ..pdf.renderer (wrong) vs from src.pdf.renderer (correct)
|
||||
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
|
||||
@@ -8,8 +8,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.data.async_request_db import ApiKeyConfig
|
||||
from src.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
|
||||
from inference.data.async_request_db import ApiKeyConfig
|
||||
from inference.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
|
||||
@@ -9,8 +9,8 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.web.api.v1.admin.training import create_training_router
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
|
||||
|
||||
class MockTrainingTask:
|
||||
|
||||
Reference in New Issue
Block a user