restructure project

This commit is contained in:
Yaojia Wang
2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions

View File

@@ -9,7 +9,7 @@ import pytest
from datetime import datetime
from uuid import UUID, uuid4
from src.data.admin_models import (
from inference.data.admin_models import (
BatchUpload,
BatchUploadFile,
TrainingDocumentLink,

View File

@@ -12,7 +12,7 @@ import tempfile
from pathlib import Path
from datetime import date
from decimal import Decimal
from src.data.csv_loader import (
from shared.data.csv_loader import (
InvoiceRow,
CSVLoader,
load_invoice_csv,

View File

@@ -11,7 +11,7 @@ Tests field normalization functions:
"""
import pytest
from src.inference.field_extractor import FieldExtractor
from inference.pipeline.field_extractor import FieldExtractor
class TestFieldExtractorInit:

View File

@@ -10,7 +10,7 @@ Tests the cross-validation logic between payment_line and detected fields:
import pytest
from unittest.mock import MagicMock, patch
from src.inference.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
from inference.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
class TestCrossValidationResult:

View File

@@ -7,7 +7,7 @@ Usage:
import pytest
from dataclasses import dataclass
from src.matcher.strategies.exact_matcher import ExactMatcher
from shared.matcher.strategies.exact_matcher import ExactMatcher
@dataclass

View File

@@ -9,13 +9,13 @@ Usage:
import pytest
from dataclasses import dataclass
from src.matcher.field_matcher import FieldMatcher, find_field_matches
from src.matcher.models import Match
from src.matcher.token_index import TokenIndex
from src.matcher.context import CONTEXT_KEYWORDS, find_context_keywords
from src.matcher import utils as matcher_utils
from src.matcher.utils import normalize_dashes as _normalize_dashes
from src.matcher.strategies import (
from shared.matcher.field_matcher import FieldMatcher, find_field_matches
from shared.matcher.models import Match
from shared.matcher.token_index import TokenIndex
from shared.matcher.context import CONTEXT_KEYWORDS, find_context_keywords
from shared.matcher import utils as matcher_utils
from shared.matcher.utils import normalize_dashes as _normalize_dashes
from shared.matcher.strategies import (
SubstringMatcher,
FlexibleDateMatcher,
FuzzyMatcher,

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.amount_normalizer import AmountNormalizer
from shared.normalize.normalizers.amount_normalizer import AmountNormalizer
class TestAmountNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.bankgiro_normalizer import BankgiroNormalizer
from shared.normalize.normalizers.bankgiro_normalizer import BankgiroNormalizer
class TestBankgiroNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.customer_number_normalizer import CustomerNumberNormalizer
from shared.normalize.normalizers.customer_number_normalizer import CustomerNumberNormalizer
class TestCustomerNumberNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.date_normalizer import DateNormalizer
from shared.normalize.normalizers.date_normalizer import DateNormalizer
class TestDateNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.invoice_number_normalizer import InvoiceNumberNormalizer
from shared.normalize.normalizers.invoice_number_normalizer import InvoiceNumberNormalizer
class TestInvoiceNumberNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.ocr_normalizer import OCRNormalizer
from shared.normalize.normalizers.ocr_normalizer import OCRNormalizer
class TestOCRNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.organisation_number_normalizer import OrganisationNumberNormalizer
from shared.normalize.normalizers.organisation_number_normalizer import OrganisationNumberNormalizer
class TestOrganisationNumberNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.plusgiro_normalizer import PlusgiroNormalizer
from shared.normalize.normalizers.plusgiro_normalizer import PlusgiroNormalizer
class TestPlusgiroNormalizer:

View File

@@ -6,7 +6,7 @@ Usage:
"""
import pytest
from src.normalize.normalizers.supplier_accounts_normalizer import SupplierAccountsNormalizer
from shared.normalize.normalizers.supplier_accounts_normalizer import SupplierAccountsNormalizer
class TestSupplierAccountsNormalizer:

View File

@@ -8,7 +8,7 @@ Usage:
"""
import pytest
from src.normalize.normalizer import (
from shared.normalize.normalizer import (
FieldNormalizer,
NormalizedValue,
normalize_field,

View File

@@ -9,8 +9,8 @@ Tests the parsing of Swedish invoice payment lines including:
"""
import pytest
from src.ocr.machine_code_parser import MachineCodeParser, MachineCodeResult
from src.pdf.extractor import Token as TextToken
from shared.ocr.machine_code_parser import MachineCodeParser, MachineCodeResult
from shared.pdf.extractor import Token as TextToken
class TestParseStandardPaymentLine:

View File

@@ -13,7 +13,7 @@ Usage:
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from src.pdf.detector import (
from shared.pdf.detector import (
extract_text_first_page,
is_text_pdf,
get_pdf_type,
@@ -54,12 +54,12 @@ class TestIsTextPDF:
def test_empty_pdf_returns_false(self):
"""Should return False for PDF with no text."""
with patch("src.pdf.detector.extract_text_first_page", return_value=""):
with patch("shared.pdf.detector.extract_text_first_page", return_value=""):
assert is_text_pdf("test.pdf") is False
def test_short_text_returns_false(self):
"""Should return False for PDF with very short text."""
with patch("src.pdf.detector.extract_text_first_page", return_value="Hello"):
with patch("shared.pdf.detector.extract_text_first_page", return_value="Hello"):
assert is_text_pdf("test.pdf") is False
def test_readable_text_with_keywords_returns_true(self):
@@ -72,7 +72,7 @@ class TestIsTextPDF:
Moms: 25%
""" + "a" * 200 # Ensure > 200 chars
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
assert is_text_pdf("test.pdf") is True
def test_garbled_text_returns_false(self):
@@ -80,7 +80,7 @@ class TestIsTextPDF:
# Simulate garbled text (lots of non-printable characters)
garbled = "\x00\x01\x02" * 100 + "abc" * 20 # Low readable ratio
with patch("src.pdf.detector.extract_text_first_page", return_value=garbled):
with patch("shared.pdf.detector.extract_text_first_page", return_value=garbled):
assert is_text_pdf("test.pdf") is False
def test_text_without_keywords_needs_high_readability(self):
@@ -88,7 +88,7 @@ class TestIsTextPDF:
# Text without invoice keywords
text = "The quick brown fox jumps over the lazy dog. " * 10
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
# Should pass if readable ratio is high enough
result = is_text_pdf("test.pdf")
# Result depends on character ratio - ASCII text should pass
@@ -98,7 +98,7 @@ class TestIsTextPDF:
"""Should respect custom min_chars parameter."""
text = "Short text here" # 15 chars
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
# Default min_chars=30 - should fail
assert is_text_pdf("test.pdf", min_chars=30) is False
# Custom min_chars=10 - should pass basic length check
@@ -273,7 +273,7 @@ class TestIsTextPDFKeywordDetection:
# Create text with keyword and enough content
text = f"Document with {keyword} keyword here" + " more text" * 50
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
# Need at least 2 keywords for is_text_pdf to return True
# So this tests if keyword is recognized when combined with others
pass
@@ -282,7 +282,7 @@ class TestIsTextPDFKeywordDetection:
"""Should detect English invoice keywords."""
text = "Invoice document with date and amount information" + " x" * 100
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
# invoice + date = 2 keywords
result = is_text_pdf("test.pdf")
assert result is True
@@ -292,7 +292,7 @@ class TestIsTextPDFKeywordDetection:
# Only one keyword
text = "This is a faktura document" + " x" * 200
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
# With only 1 keyword, falls back to other checks
# Should still pass if readability is high
pass
@@ -306,7 +306,7 @@ class TestReadabilityChecks:
# Pure ASCII text
text = "This is a normal document with only ASCII characters. " * 10
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
result = is_text_pdf("test.pdf")
assert result is True
@@ -314,7 +314,7 @@ class TestReadabilityChecks:
"""Should accept Swedish characters as readable."""
text = "Fakturadatum för årets moms på öre belopp" + " normal" * 50
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
result = is_text_pdf("test.pdf")
assert result is True
@@ -326,7 +326,7 @@ class TestReadabilityChecks:
unreadable = "\x80\x81\x82" * 50 # 150 unreadable chars
text = readable + unreadable
with patch("src.pdf.detector.extract_text_first_page", return_value=text):
with patch("shared.pdf.detector.extract_text_first_page", return_value=text):
result = is_text_pdf("test.pdf")
assert result is False

View File

@@ -12,7 +12,7 @@ Usage:
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from src.pdf.extractor import (
from shared.pdf.extractor import (
Token,
PDFDocument,
extract_text_tokens,
@@ -509,7 +509,7 @@ class TestPDFDocumentIsTextPDF:
mock_doc = MagicMock()
with patch("fitz.open", return_value=mock_doc):
with patch("src.pdf.extractor._is_text_pdf_standalone", return_value=True) as mock_check:
with patch("shared.pdf.extractor._is_text_pdf_standalone", return_value=True) as mock_check:
with PDFDocument("test.pdf") as pdf:
result = pdf.is_text_pdf(min_chars=50)

View File

@@ -18,7 +18,7 @@ class TestDatabaseConfig:
def test_config_loads_from_env(self):
"""Test that config loads successfully from .env file."""
# Import config (should load .env automatically)
from src import config
from shared import config
# Verify database config is loaded
assert config.DATABASE is not None
@@ -30,7 +30,7 @@ class TestDatabaseConfig:
def test_database_password_loaded(self):
"""Test that database password is loaded from environment."""
from src import config
from shared import config
# Password should be loaded from .env
assert config.DATABASE['password'] is not None
@@ -38,7 +38,7 @@ class TestDatabaseConfig:
def test_database_connection_string(self):
"""Test database connection string generation."""
from src import config
from shared import config
conn_str = config.get_db_connection_string()
@@ -71,7 +71,7 @@ class TestPathsConfig:
def test_paths_config_exists(self):
"""Test that PATHS configuration exists."""
from src import config
from shared import config
assert config.PATHS is not None
assert 'csv_dir' in config.PATHS
@@ -85,7 +85,7 @@ class TestAutolabelConfig:
def test_autolabel_config_exists(self):
"""Test that AUTOLABEL configuration exists."""
from src import config
from shared import config
assert config.AUTOLABEL is not None
assert 'workers' in config.AUTOLABEL
@@ -95,7 +95,7 @@ class TestAutolabelConfig:
def test_autolabel_ratios_sum_to_one(self):
"""Test that train/val/test ratios sum to 1.0."""
from src import config
from shared import config
total = (
config.AUTOLABEL['train_ratio'] +

View File

@@ -10,7 +10,7 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.inference.customer_number_parser import (
from inference.pipeline.customer_number_parser import (
CustomerNumberParser,
DashFormatPattern,
NoDashFormatPattern,

View File

@@ -11,7 +11,7 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.data.db import DocumentDB
from shared.data.db import DocumentDB
class TestSQLInjectionPrevention:

View File

@@ -10,7 +10,7 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.exceptions import (
from shared.exceptions import (
InvoiceExtractionError,
PDFProcessingError,
OCRError,

48
tests/test_imports.py Normal file
View File

@@ -0,0 +1,48 @@
"""Import validation tests.
Ensures all lazy imports across packages resolve correctly,
catching cross-package import errors that mocks would hide.
"""
import importlib
import pkgutil
import pytest
def _collect_modules(package_name: str) -> list[str]:
"""Recursively collect all module names under a package."""
try:
package = importlib.import_module(package_name)
except Exception:
return [package_name]
modules = [package_name]
if hasattr(package, "__path__"):
for _importer, modname, _ispkg in pkgutil.walk_packages(
package.__path__, prefix=package_name + "."
):
modules.append(modname)
return modules
# Module lists are computed once at collection time; each entry becomes
# an individual parametrized test case with the module name as its id.
SHARED_MODULES = _collect_modules("shared")
INFERENCE_MODULES = _collect_modules("inference")
TRAINING_MODULES = _collect_modules("training")


@pytest.mark.parametrize("module_name", SHARED_MODULES)
def test_shared_module_imports(module_name: str) -> None:
    """Every module in the shared package should import without error."""
    # A plain import is the whole assertion: any ImportError/SyntaxError
    # raised here fails the test for exactly the broken module.
    importlib.import_module(module_name)


@pytest.mark.parametrize("module_name", INFERENCE_MODULES)
def test_inference_module_imports(module_name: str) -> None:
    """Every module in the inference package should import without error."""
    importlib.import_module(module_name)


@pytest.mark.parametrize("module_name", TRAINING_MODULES)
def test_training_module_imports(module_name: str) -> None:
    """Every module in the training package should import without error."""
    importlib.import_module(module_name)

View File

@@ -10,7 +10,7 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.inference.payment_line_parser import PaymentLineParser, PaymentLineData
from inference.pipeline.payment_line_parser import PaymentLineParser, PaymentLineData
class TestPaymentLineParser:

View File

@@ -6,9 +6,9 @@ Tests for advanced utility modules:
"""
import pytest
from src.utils.fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
from src.utils.ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants
from src.utils.context_extractor import ContextExtractor, extract_field_with_context
from shared.utils.fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
from shared.utils.ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants
from shared.utils.context_extractor import ContextExtractor, extract_field_with_context
class TestFuzzyMatcher:

View File

@@ -3,9 +3,9 @@ Tests for shared utility modules.
"""
import pytest
from src.utils.text_cleaner import TextCleaner
from src.utils.format_variants import FormatVariants
from src.utils.validators import FieldValidators
from shared.utils.text_cleaner import TextCleaner
from shared.utils.format_variants import FormatVariants
from shared.utils.validators import FieldValidators
class TestTextCleaner:

View File

@@ -10,12 +10,12 @@ from uuid import UUID
import pytest
from src.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from src.data.models import AsyncRequest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService
from src.web.config import AsyncConfig, StorageConfig
from src.web.core.rate_limiter import RateLimiter
from inference.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from inference.data.models import AsyncRequest
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.core.rate_limiter import RateLimiter
@pytest.fixture

View File

@@ -9,9 +9,9 @@ from uuid import UUID
from fastapi import HTTPException
from src.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from src.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from src.web.schemas.admin import (
from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from inference.web.schemas.admin import (
AnnotationCreate,
AnnotationUpdate,
AutoLabelRequest,

View File

@@ -8,9 +8,9 @@ from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from src.data.admin_db import AdminDB
from src.data.admin_models import AdminToken
from src.web.core.auth import (
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminToken
from inference.web.core.auth import (
get_admin_db,
reset_admin_db,
validate_admin_token,
@@ -81,7 +81,7 @@ class TestAdminDB:
def test_is_valid_admin_token_active(self):
"""Test valid active token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -98,7 +98,7 @@ class TestAdminDB:
def test_is_valid_admin_token_inactive(self):
"""Test inactive token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -115,7 +115,7 @@ class TestAdminDB:
def test_is_valid_admin_token_expired(self):
"""Test expired token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -132,7 +132,7 @@ class TestAdminDB:
def test_is_valid_admin_token_not_found(self):
"""Test token not found."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_session.get.return_value = None

View File

@@ -12,8 +12,9 @@ from uuid import UUID
from fastapi import HTTPException
from fastapi.testclient import TestClient
from src.data.admin_models import AdminDocument, AdminToken
from src.web.api.v1.admin.documents import _validate_uuid, create_admin_router
from inference.data.admin_models import AdminDocument, AdminToken
from inference.web.api.v1.admin.documents import _validate_uuid, create_documents_router
from inference.web.config import StorageConfig
# Test UUID
@@ -42,13 +43,12 @@ class TestAdminRouter:
def test_creates_router_with_endpoints(self):
"""Test router is created with expected endpoints."""
router = create_admin_router((".pdf", ".png", ".jpg"))
router = create_documents_router(StorageConfig())
# Get route paths (include prefix from router)
paths = [route.path for route in router.routes]
# Paths include the /admin prefix
assert any("/auth/token" in p for p in paths)
# Paths include the /admin/documents prefix
assert any("/documents" in p for p in paths)
assert any("/documents/stats" in p for p in paths)
assert any("{document_id}" in p for p in paths)
@@ -66,7 +66,7 @@ class TestCreateTokenEndpoint:
def test_create_token_success(self, mock_db):
"""Test successful token creation."""
from src.web.schemas.admin import AdminTokenCreate
from inference.web.schemas.admin import AdminTokenCreate
request = AdminTokenCreate(name="Test Token", expires_in_days=30)

View File

@@ -9,8 +9,9 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
@@ -189,7 +190,7 @@ def app():
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_admin_router((".pdf", ".png", ".jpg"))
router = create_documents_router(StorageConfig())
app.include_router(router)
return app

View File

@@ -0,0 +1,245 @@
"""
Tests to verify admin schemas split maintains backward compatibility.
All existing imports from inference.web.schemas.admin must continue to work.
"""
import pytest
class TestEnumImports:
    """Every admin enum resolves from inference.web.schemas.admin with its expected value."""

    def test_document_status(self):
        """DocumentStatus.PENDING keeps its string value."""
        from inference.web.schemas.admin import DocumentStatus

        pending = DocumentStatus.PENDING
        assert pending == "pending"

    def test_auto_label_status(self):
        """AutoLabelStatus.RUNNING keeps its string value."""
        from inference.web.schemas.admin import AutoLabelStatus

        running = AutoLabelStatus.RUNNING
        assert running == "running"

    def test_training_status(self):
        """TrainingStatus.PENDING keeps its string value."""
        from inference.web.schemas.admin import TrainingStatus

        pending = TrainingStatus.PENDING
        assert pending == "pending"

    def test_training_type(self):
        """TrainingType.TRAIN keeps its string value."""
        from inference.web.schemas.admin import TrainingType

        train = TrainingType.TRAIN
        assert train == "train"

    def test_annotation_source(self):
        """AnnotationSource.MANUAL keeps its string value."""
        from inference.web.schemas.admin import AnnotationSource

        manual = AnnotationSource.MANUAL
        assert manual == "manual"
class TestAuthImports:
    """Auth-related schemas remain importable from the admin schema package."""

    def test_admin_token_create(self):
        """AdminTokenCreate can be constructed and round-trips its name."""
        from inference.web.schemas.admin import AdminTokenCreate

        created = AdminTokenCreate(name="test")
        assert created.name == "test"

    def test_admin_token_response(self):
        """AdminTokenResponse resolves from the package."""
        from inference.web.schemas.admin import AdminTokenResponse

        assert AdminTokenResponse is not None
class TestDocumentImports:
    """Document schemas remain importable from the admin schema package."""

    def test_document_upload_response(self):
        """DocumentUploadResponse resolves."""
        from inference.web.schemas.admin import DocumentUploadResponse

        assert DocumentUploadResponse is not None

    def test_document_item(self):
        """DocumentItem resolves."""
        from inference.web.schemas.admin import DocumentItem

        assert DocumentItem is not None

    def test_document_list_response(self):
        """DocumentListResponse resolves."""
        from inference.web.schemas.admin import DocumentListResponse

        assert DocumentListResponse is not None

    def test_document_detail_response(self):
        """DocumentDetailResponse resolves."""
        from inference.web.schemas.admin import DocumentDetailResponse

        assert DocumentDetailResponse is not None

    def test_document_stats_response(self):
        """DocumentStatsResponse resolves."""
        from inference.web.schemas.admin import DocumentStatsResponse

        assert DocumentStatsResponse is not None
class TestAnnotationImports:
    """Annotation schemas remain importable from the admin schema package."""

    def test_bounding_box(self):
        """BoundingBox can be constructed and keeps its field values."""
        from inference.web.schemas.admin import BoundingBox

        box = BoundingBox(x=0, y=0, width=100, height=50)
        assert box.width == 100

    def test_annotation_create(self):
        """AnnotationCreate resolves."""
        from inference.web.schemas.admin import AnnotationCreate

        assert AnnotationCreate is not None

    def test_annotation_update(self):
        """AnnotationUpdate resolves."""
        from inference.web.schemas.admin import AnnotationUpdate

        assert AnnotationUpdate is not None

    def test_annotation_item(self):
        """AnnotationItem resolves."""
        from inference.web.schemas.admin import AnnotationItem

        assert AnnotationItem is not None

    def test_annotation_response(self):
        """AnnotationResponse resolves."""
        from inference.web.schemas.admin import AnnotationResponse

        assert AnnotationResponse is not None

    def test_annotation_list_response(self):
        """AnnotationListResponse resolves."""
        from inference.web.schemas.admin import AnnotationListResponse

        assert AnnotationListResponse is not None

    def test_annotation_lock_request(self):
        """AnnotationLockRequest resolves."""
        from inference.web.schemas.admin import AnnotationLockRequest

        assert AnnotationLockRequest is not None

    def test_annotation_lock_response(self):
        """AnnotationLockResponse resolves."""
        from inference.web.schemas.admin import AnnotationLockResponse

        assert AnnotationLockResponse is not None

    def test_auto_label_request(self):
        """AutoLabelRequest resolves."""
        from inference.web.schemas.admin import AutoLabelRequest

        assert AutoLabelRequest is not None

    def test_auto_label_response(self):
        """AutoLabelResponse resolves."""
        from inference.web.schemas.admin import AutoLabelResponse

        assert AutoLabelResponse is not None

    def test_annotation_verify_request(self):
        """AnnotationVerifyRequest resolves."""
        from inference.web.schemas.admin import AnnotationVerifyRequest

        assert AnnotationVerifyRequest is not None

    def test_annotation_verify_response(self):
        """AnnotationVerifyResponse resolves."""
        from inference.web.schemas.admin import AnnotationVerifyResponse

        assert AnnotationVerifyResponse is not None

    def test_annotation_override_request(self):
        """AnnotationOverrideRequest resolves."""
        from inference.web.schemas.admin import AnnotationOverrideRequest

        assert AnnotationOverrideRequest is not None

    def test_annotation_override_response(self):
        """AnnotationOverrideResponse resolves."""
        from inference.web.schemas.admin import AnnotationOverrideResponse

        assert AnnotationOverrideResponse is not None
class TestTrainingImports:
    """Training schemas remain importable from the admin schema package."""

    def test_training_config(self):
        """TrainingConfig constructs with its documented default epochs."""
        from inference.web.schemas.admin import TrainingConfig

        cfg = TrainingConfig()
        assert cfg.epochs == 100

    def test_training_task_create(self):
        """TrainingTaskCreate resolves."""
        from inference.web.schemas.admin import TrainingTaskCreate

        assert TrainingTaskCreate is not None

    def test_training_task_item(self):
        """TrainingTaskItem resolves."""
        from inference.web.schemas.admin import TrainingTaskItem

        assert TrainingTaskItem is not None

    def test_training_task_list_response(self):
        """TrainingTaskListResponse resolves."""
        from inference.web.schemas.admin import TrainingTaskListResponse

        assert TrainingTaskListResponse is not None

    def test_training_task_detail_response(self):
        """TrainingTaskDetailResponse resolves."""
        from inference.web.schemas.admin import TrainingTaskDetailResponse

        assert TrainingTaskDetailResponse is not None

    def test_training_task_response(self):
        """TrainingTaskResponse resolves."""
        from inference.web.schemas.admin import TrainingTaskResponse

        assert TrainingTaskResponse is not None

    def test_training_log_item(self):
        """TrainingLogItem resolves."""
        from inference.web.schemas.admin import TrainingLogItem

        assert TrainingLogItem is not None

    def test_training_logs_response(self):
        """TrainingLogsResponse resolves."""
        from inference.web.schemas.admin import TrainingLogsResponse

        assert TrainingLogsResponse is not None

    def test_export_request(self):
        """ExportRequest resolves."""
        from inference.web.schemas.admin import ExportRequest

        assert ExportRequest is not None

    def test_export_response(self):
        """ExportResponse resolves."""
        from inference.web.schemas.admin import ExportResponse

        assert ExportResponse is not None

    def test_training_document_item(self):
        """TrainingDocumentItem resolves."""
        from inference.web.schemas.admin import TrainingDocumentItem

        assert TrainingDocumentItem is not None

    def test_training_documents_response(self):
        """TrainingDocumentsResponse resolves."""
        from inference.web.schemas.admin import TrainingDocumentsResponse

        assert TrainingDocumentsResponse is not None

    def test_model_metrics(self):
        """ModelMetrics resolves."""
        from inference.web.schemas.admin import ModelMetrics

        assert ModelMetrics is not None

    def test_training_model_item(self):
        """TrainingModelItem resolves."""
        from inference.web.schemas.admin import TrainingModelItem

        assert TrainingModelItem is not None

    def test_training_models_response(self):
        """TrainingModelsResponse resolves."""
        from inference.web.schemas.admin import TrainingModelsResponse

        assert TrainingModelsResponse is not None

    def test_training_history_item(self):
        """TrainingHistoryItem resolves."""
        from inference.web.schemas.admin import TrainingHistoryItem

        assert TrainingHistoryItem is not None
class TestDatasetImports:
    """Dataset schemas remain importable from the admin schema package."""

    def test_dataset_create_request(self):
        """DatasetCreateRequest resolves."""
        from inference.web.schemas.admin import DatasetCreateRequest

        assert DatasetCreateRequest is not None

    def test_dataset_document_item(self):
        """DatasetDocumentItem resolves."""
        from inference.web.schemas.admin import DatasetDocumentItem

        assert DatasetDocumentItem is not None

    def test_dataset_response(self):
        """DatasetResponse resolves."""
        from inference.web.schemas.admin import DatasetResponse

        assert DatasetResponse is not None

    def test_dataset_detail_response(self):
        """DatasetDetailResponse resolves."""
        from inference.web.schemas.admin import DatasetDetailResponse

        assert DatasetDetailResponse is not None

    def test_dataset_list_item(self):
        """DatasetListItem resolves."""
        from inference.web.schemas.admin import DatasetListItem

        assert DatasetListItem is not None

    def test_dataset_list_response(self):
        """DatasetListResponse resolves."""
        from inference.web.schemas.admin import DatasetListResponse

        assert DatasetListResponse is not None

    def test_dataset_train_request(self):
        """DatasetTrainRequest resolves."""
        from inference.web.schemas.admin import DatasetTrainRequest

        assert DatasetTrainRequest is not None
class TestForwardReferences:
    """Forward references inside the split schema modules resolve correctly."""

    def test_document_detail_has_annotation_items(self):
        """DocumentDetailResponse still declares its nested list fields."""
        from inference.web.schemas.admin import DocumentDetailResponse

        field_map = DocumentDetailResponse.model_fields
        assert "annotations" in field_map
        assert "training_history" in field_map

    def test_dataset_train_request_has_config(self):
        """DatasetTrainRequest accepts a TrainingConfig with default epochs."""
        from inference.web.schemas.admin import DatasetTrainRequest, TrainingConfig

        request = DatasetTrainRequest(name="test", config=TrainingConfig())
        assert request.config.epochs == 100

View File

@@ -7,15 +7,15 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import UUID
from src.data.admin_models import TrainingTask, TrainingLog
from src.web.api.v1.admin.training import _validate_uuid, create_training_router
from src.web.core.scheduler import (
from inference.data.admin_models import TrainingTask, TrainingLog
from inference.web.api.v1.admin.training import _validate_uuid, create_training_router
from inference.web.core.scheduler import (
TrainingScheduler,
get_training_scheduler,
start_scheduler,
stop_scheduler,
)
from src.web.schemas.admin import (
from inference.web.schemas.admin import (
TrainingConfig,
TrainingStatus,
TrainingTaskCreate,

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
@@ -110,7 +110,7 @@ def app():
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_admin_router((".pdf", ".png", ".jpg"))
router = create_locks_router()
app.include_router(router)
return app

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:

View File

@@ -11,7 +11,7 @@ from unittest.mock import MagicMock
import pytest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
class TestAsyncTask:

View File

@@ -11,12 +11,12 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
from src.web.api.v1.async_api.routes import create_async_router, set_async_service
from src.web.services.async_processing import AsyncSubmitResult
from src.web.dependencies import init_dependencies
from src.web.rate_limiter import RateLimiter, RateLimitStatus
from src.web.schemas.inference import AsyncStatus
from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.services.async_processing import AsyncSubmitResult
from inference.web.dependencies import init_dependencies
from inference.web.rate_limiter import RateLimiter, RateLimitStatus
from inference.web.schemas.inference import AsyncStatus
# Valid UUID for testing
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"

View File

@@ -10,11 +10,11 @@ from unittest.mock import MagicMock, patch
import pytest
from src.data.async_request_db import AsyncRequest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
from src.web.config import AsyncConfig, StorageConfig
from src.web.rate_limiter import RateLimiter
from inference.data.async_request_db import AsyncRequest
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.rate_limiter import RateLimiter
@pytest.fixture

View File

@@ -8,8 +8,8 @@ from pathlib import Path
from unittest.mock import Mock, MagicMock
from uuid import uuid4
from src.web.services.autolabel import AutoLabelService
from src.data.admin_db import AdminDB
from inference.web.services.autolabel import AutoLabelService
from inference.data.admin_db import AdminDB
class MockDocument:

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from src.web.workers.batch_queue import BatchTask, BatchTaskQueue
from inference.web.workers.batch_queue import BatchTask, BatchTaskQueue
class MockBatchService:

View File

@@ -11,10 +11,10 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.batch.routes import router
from src.web.core.auth import validate_admin_token, get_admin_db
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from src.web.services.batch_upload import BatchUploadService
from inference.web.api.v1.batch.routes import router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
class MockAdminDB:

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
import pytest
from src.data.admin_db import AdminDB
from src.web.services.batch_upload import BatchUploadService
from inference.data.admin_db import AdminDB
from inference.web.services.batch_upload import BatchUploadService
@pytest.fixture

View File

@@ -0,0 +1,331 @@
"""
Tests for DatasetBuilder service.
TDD: Write tests first, then implement dataset_builder.py.
"""
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
TrainingDataset,
FIELD_CLASSES,
)
@pytest.fixture
def tmp_admin_images(tmp_path):
    """Build a fake admin_images tree: 5 document dirs with 2 PNG pages each.

    Returns (tmp_path, doc_ids) so tests can locate the generated files.
    """
    document_ids = [uuid4() for _ in range(5)]
    for document_id in document_ids:
        pages_dir = tmp_path / "admin_images" / str(document_id)
        pages_dir.mkdir(parents=True)
        # Two pages per document; content only needs to exist, not be a real PNG.
        for page_number in (1, 2):
            (pages_dir / f"page_{page_number}.png").write_bytes(b"fake-png-data")
    return tmp_path, document_ids
@pytest.fixture
def mock_admin_db():
    """AdminDB stand-in whose create_dataset yields a fresh 'building' dataset."""
    mock_db = MagicMock()
    mock_db.create_dataset.return_value = TrainingDataset(
        dataset_id=uuid4(),
        name="test-dataset",
        status="building",
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
    )
    return mock_db
@pytest.fixture
def sample_documents(tmp_admin_images):
    """Build mock AdminDocuments whose file_path points at the fake image dirs."""
    base_path, document_ids = tmp_admin_images
    documents = []
    for document_id in document_ids:
        document = MagicMock(spec=AdminDocument)
        document.document_id = document_id
        document.filename = f"{document_id}.pdf"
        document.page_count = 2
        document.file_path = str(base_path / "admin_images" / str(document_id))
        documents.append(document)
    return documents
@pytest.fixture
def sample_annotations(sample_documents):
    """Map each document id (as str) to one mock annotation per page (2 pages).

    All annotations share one normalized YOLO box for class 0 (invoice_number).
    """
    per_document = {}
    for document in sample_documents:
        page_annotations = []
        for page_number in (1, 2):
            annotation = MagicMock(spec=AdminAnnotation)
            annotation.document_id = document.document_id
            annotation.page_number = page_number
            annotation.class_id = 0
            annotation.class_name = "invoice_number"
            annotation.x_center = 0.5
            annotation.y_center = 0.3
            annotation.width = 0.2
            annotation.height = 0.05
            page_annotations.append(annotation)
        per_document[str(document.document_id)] = page_annotations
    return per_document
class TestDatasetBuilder:
    """Tests for DatasetBuilder.build_dataset.

    The repeated wiring (mock-DB return values plus the standard 80/10/10
    seeded build call) lives in the _run_build helper; each test then
    inspects either the directory tree the build produced or the calls
    recorded on the mock DB.
    """

    @staticmethod
    def _run_build(tmp_path, mock_admin_db, sample_documents, sample_annotations):
        """Wire the mock DB and run a standard 80/10/10 build with seed=42.

        Returns the dataset output directory (base_dir / dataset_id).
        """
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )
        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )
        return tmp_path / "datasets" / str(dataset.dataset_id)

    @staticmethod
    def _added_documents(mock_admin_db):
        """Return the documents payload passed to add_dataset_documents,
        whether it was supplied as a keyword or positionally."""
        call_args = mock_admin_db.add_dataset_documents.call_args
        if "documents" in call_args[1]:
            return call_args[1]["documents"]
        return call_args[0][1]

    def test_build_creates_directory_structure(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Dataset builder should create images/ and labels/ with train/val/test subdirs."""
        result_dir = self._run_build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )
        for split in ["train", "val", "test"]:
            assert (result_dir / "images" / split).exists()
            assert (result_dir / "labels" / split).exists()

    def test_build_copies_images(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Images should be copied from admin_images to dataset folder."""
        result_dir = self._run_build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )
        total_images = sum(
            len(list((result_dir / "images" / split).glob("*.png")))
            for split in ["train", "val", "test"]
        )
        assert total_images == 10  # 5 docs * 2 pages

    def test_build_generates_yolo_labels(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """YOLO label files should be generated with correct format."""
        result_dir = self._run_build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )
        total_labels = sum(
            len(list((result_dir / "labels" / split).glob("*.txt")))
            for split in ["train", "val", "test"]
        )
        assert total_labels == 10  # 5 docs * 2 pages
        # Check label format: "class_id x_center y_center width height"
        label_files = list((result_dir / "labels").rglob("*.txt"))
        parts = label_files[0].read_text().strip().split()
        assert len(parts) == 5
        assert int(parts[0]) == 0  # class_id
        assert 0 <= float(parts[1]) <= 1  # x_center
        assert 0 <= float(parts[2]) <= 1  # y_center

    def test_build_generates_data_yaml(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """data.yaml should be generated with correct field classes."""
        result_dir = self._run_build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )
        yaml_path = result_dir / "data.yaml"
        assert yaml_path.exists()
        content = yaml_path.read_text()
        assert "train:" in content
        assert "val:" in content
        assert "nc:" in content
        assert "invoice_number" in content

    def test_build_splits_documents_correctly(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Documents should be split into train/val/test according to ratios."""
        self._run_build(tmp_path, mock_admin_db, sample_documents, sample_annotations)
        splits = [d["split"] for d in self._added_documents(mock_admin_db)]
        assert "train" in splits
        # With 5 docs, 80/10/10 -> 4 train, 0-1 val, 0-1 test
        assert splits.count("train") >= 3  # At least 3 of 5 should be train

    def test_build_updates_status_to_ready(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """After successful build, dataset status should be updated to 'ready'."""
        self._run_build(tmp_path, mock_admin_db, sample_documents, sample_annotations)
        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "ready"
        assert call_kwargs["total_documents"] == 5
        assert call_kwargs["total_images"] == 10

    def test_build_sets_failed_on_error(self, tmp_path, mock_admin_db):
        """If build fails, dataset status should be set to 'failed'."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = []  # No docs found
        dataset = mock_admin_db.create_dataset.return_value
        with pytest.raises(ValueError):
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=["nonexistent-id"],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=tmp_path / "admin_images",
            )
        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "failed"

    def test_build_with_seed_produces_deterministic_splits(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Same seed should produce same splits."""
        results = []
        for _ in range(2):
            # Clear recorded calls so each pass captures only its own build.
            mock_admin_db.add_dataset_documents.reset_mock()
            mock_admin_db.update_dataset_status.reset_mock()
            self._run_build(
                tmp_path, mock_admin_db, sample_documents, sample_annotations
            )
            results.append(
                [(d["document_id"], d["split"]) for d in self._added_documents(mock_admin_db)]
            )
        assert results[0] == results[1]

View File

@@ -0,0 +1,200 @@
"""
Tests for Dataset API routes in training.py.
"""
import asyncio
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import UUID
from inference.data.admin_models import TrainingDataset, DatasetDocument
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetTrainRequest,
TrainingConfig,
TrainingStatus,
)
# Fixed identifiers shared across the route tests: stable UUID strings keep
# responses assertable, and TEST_TOKEN is the value passed as the endpoints'
# admin_token argument.
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011"
TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
TEST_TOKEN = "test-admin-token-12345"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
def _make_dataset(**overrides) -> MagicMock:
    """Build a MagicMock TrainingDataset with sensible defaults.

    Keyword overrides replace individual attributes (e.g. status="building").
    """
    attrs = {
        "dataset_id": UUID(TEST_DATASET_UUID),
        "name": "test-dataset",
        "description": "Test dataset",
        "status": "ready",
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
        "total_documents": 2,
        "total_images": 4,
        "total_annotations": 10,
        "dataset_path": "/data/datasets/test-dataset",
        "error_message": None,
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        "updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    attrs.update(overrides)
    dataset = MagicMock(spec=TrainingDataset)
    for attr_name, attr_value in attrs.items():
        setattr(dataset, attr_name, attr_value)
    return dataset
def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock:
    """Build a MagicMock DatasetDocument for *doc_id* assigned to *split*."""
    dataset_doc = MagicMock(spec=DatasetDocument)
    dataset_doc.document_id = UUID(doc_id)
    dataset_doc.split = split
    dataset_doc.page_count = 2
    dataset_doc.annotation_count = 5
    return dataset_doc
def _find_endpoint(name: str):
    """Look up a route handler on a fresh training router by function name."""
    for route in create_training_router().routes:
        endpoint = getattr(route, "endpoint", None)
        if endpoint is not None and endpoint.__name__ == name:
            return endpoint
    raise AssertionError(f"Endpoint {name} not found")
class TestCreateDatasetRoute:
    """Tests for POST /admin/training/datasets."""

    def test_router_has_dataset_endpoints(self):
        """The training router must expose dataset paths."""
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("datasets" in p for p in paths)

    def test_create_dataset_calls_builder(self):
        """Creating a dataset should persist it and invoke DatasetBuilder."""
        fn = _find_endpoint("create_dataset")
        mock_db = MagicMock()
        mock_db.create_dataset.return_value = _make_dataset(status="building")
        mock_builder = MagicMock()
        mock_builder.build_dataset.return_value = {
            "total_documents": 2,
            "total_images": 4,
            "total_annotations": 10,
        }
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
        )
        # Patch the builder class so no real dataset build runs.
        # (Was bound `as mock_cls` but never used — binding removed.)
        with patch(
            "inference.web.services.dataset_builder.DatasetBuilder",
            return_value=mock_builder,
        ):
            result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
        mock_db.create_dataset.assert_called_once()
        mock_builder.build_dataset.assert_called_once()
        assert result.dataset_id == TEST_DATASET_UUID
        assert result.name == "test-dataset"
class TestListDatasetsRoute:
    """Tests for GET /admin/training/datasets."""

    def test_list_datasets(self):
        """A single stored dataset should come back with total == 1."""
        endpoint = _find_endpoint("list_datasets")
        mock_db = MagicMock()
        mock_db.get_datasets.return_value = ([_make_dataset()], 1)
        response = asyncio.run(
            endpoint(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0)
        )
        assert response.total == 1
        assert len(response.datasets) == 1
        assert response.datasets[0].name == "test-dataset"
class TestGetDatasetRoute:
    """Tests for GET /admin/training/datasets/{dataset_id}."""

    def test_get_dataset_returns_detail(self):
        """Detail response should carry the dataset id and its documents."""
        endpoint = _find_endpoint("get_dataset")
        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset()
        mock_db.get_dataset_documents.return_value = [
            _make_dataset_doc(TEST_DOC_UUID_1, "train"),
            _make_dataset_doc(TEST_DOC_UUID_2, "val"),
        ]
        response = asyncio.run(
            endpoint(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)
        )
        assert response.dataset_id == TEST_DATASET_UUID
        assert len(response.documents) == 2

    def test_get_dataset_not_found(self):
        """An unknown dataset id should raise a 404 HTTPException."""
        endpoint = _find_endpoint("get_dataset")
        mock_db = MagicMock()
        mock_db.get_dataset.return_value = None
        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(
                endpoint(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)
            )
        assert exc_info.value.status_code == 404
class TestDeleteDatasetRoute:
    """Tests for DELETE /admin/training/datasets/{dataset_id}."""

    def test_delete_dataset(self):
        """Deleting should call AdminDB.delete_dataset and confirm."""
        endpoint = _find_endpoint("delete_dataset")
        mock_db = MagicMock()
        # dataset_path=None so no on-disk cleanup path is exercised.
        mock_db.get_dataset.return_value = _make_dataset(dataset_path=None)
        response = asyncio.run(
            endpoint(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)
        )
        mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
        assert response["message"] == "Dataset deleted"
class TestTrainFromDatasetRoute:
    """Tests for POST /admin/training/datasets/{dataset_id}/train."""

    def test_train_from_ready_dataset(self):
        """A 'ready' dataset should yield a pending training task."""
        endpoint = _find_endpoint("train_from_dataset")
        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.create_training_task.return_value = TEST_TASK_UUID
        train_request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
        response = asyncio.run(
            endpoint(
                dataset_id=TEST_DATASET_UUID,
                request=train_request,
                admin_token=TEST_TOKEN,
                db=mock_db,
            )
        )
        assert response.task_id == TEST_TASK_UUID
        assert response.status == TrainingStatus.PENDING
        mock_db.create_training_task.assert_called_once()

    def test_train_from_building_dataset_fails(self):
        """A dataset still 'building' should be rejected with a 400."""
        endpoint = _find_endpoint("train_from_dataset")
        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="building")
        train_request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(
                endpoint(
                    dataset_id=TEST_DATASET_UUID,
                    request=train_request,
                    admin_token=TEST_TOKEN,
                    db=mock_db,
                )
            )
        assert exc_info.value.status_code == 400

View File

@@ -11,8 +11,8 @@ from fastapi.testclient import TestClient
from PIL import Image
import io
from src.web.app import create_app
from src.web.config import ModelConfig, StorageConfig, AppConfig
from inference.web.app import create_app
from inference.web.config import ModelConfig, StorageConfig, AppConfig
@pytest.fixture
@@ -87,8 +87,8 @@ class TestHealthEndpoint:
class TestInferEndpoint:
"""Test /api/v1/infer endpoint."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_infer_accepts_png_file(
self,
mock_yolo_detector,
@@ -150,8 +150,8 @@ class TestInferEndpoint:
assert response.status_code == 422 # Unprocessable Entity
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_infer_returns_cross_validation_if_available(
self,
mock_yolo_detector,
@@ -210,8 +210,8 @@ class TestInferEndpoint:
# This test documents that it should be added
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_infer_handles_processing_errors_gracefully(
self,
mock_yolo_detector,
@@ -280,16 +280,16 @@ class TestInferenceServiceImports:
This test will fail if there are ImportError issues like:
- from ..inference.pipeline (wrong relative import)
- from src.web.inference (non-existent module)
- from inference.web.inference (non-existent module)
It ensures the imports are correct before runtime.
"""
from src.web.services.inference import InferenceService
from inference.web.services.inference import InferenceService
# Import the modules that InferenceService tries to import
from src.inference.pipeline import InferencePipeline
from src.inference.yolo_detector import YOLODetector
from src.pdf.renderer import render_pdf_to_images
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
from shared.pdf.renderer import render_pdf_to_images
# If we got here, all imports work correctly
assert InferencePipeline is not None

View File

@@ -10,8 +10,8 @@ from unittest.mock import Mock, patch
from PIL import Image
import io
from src.web.services.inference import InferenceService
from src.web.config import ModelConfig, StorageConfig
from inference.web.services.inference import InferenceService
from inference.web.config import ModelConfig, StorageConfig
@pytest.fixture
@@ -72,8 +72,8 @@ class TestInferenceServiceInitialization:
gpu_available = inference_service.gpu_available
assert isinstance(gpu_available, bool)
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_initialize_imports_correctly(
self,
mock_yolo_detector,
@@ -102,8 +102,8 @@ class TestInferenceServiceInitialization:
mock_yolo_detector.assert_called_once()
mock_pipeline.assert_called_once()
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_initialize_sets_up_pipeline(
self,
mock_yolo_detector,
@@ -135,8 +135,8 @@ class TestInferenceServiceInitialization:
enable_fallback=True,
)
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_initialize_idempotent(
self,
mock_yolo_detector,
@@ -161,8 +161,8 @@ class TestInferenceServiceInitialization:
class TestInferenceServiceProcessing:
"""Test inference processing methods."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('ultralytics.YOLO')
def test_process_image_basic_flow(
self,
@@ -197,8 +197,8 @@ class TestInferenceServiceProcessing:
assert result.confidence == {"InvoiceNumber": 0.95}
assert result.processing_time_ms > 0
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
def test_process_image_handles_errors(
self,
mock_yolo_detector,
@@ -228,9 +228,9 @@ class TestInferenceServiceProcessing:
class TestInferenceServicePDFRendering:
"""Test PDF rendering imports."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('src.pdf.renderer.render_pdf_to_images')
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('shared.pdf.renderer.render_pdf_to_images')
@patch('ultralytics.YOLO')
def test_pdf_visualization_imports_correctly(
self,
@@ -245,7 +245,7 @@ class TestInferenceServicePDFRendering:
Test that _save_pdf_visualization imports render_pdf_to_images correctly.
This catches the import error we had with:
from ..pdf.renderer (wrong) vs from src.pdf.renderer (correct)
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
"""
# Setup mocks
mock_detector_instance = Mock()

View File

@@ -8,8 +8,8 @@ from unittest.mock import MagicMock
import pytest
from src.data.async_request_db import ApiKeyConfig
from src.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
from inference.data.async_request_db import ApiKeyConfig
from inference.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
class TestRateLimiter:

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.training import create_training_router
from src.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockTrainingTask: