Commit a516de4320
Author: Yaojia Wang
Date: 2026-02-01 00:08:40 +01:00
Parent: 33ada0350d
90 changed files with 11642 additions and 398 deletions

View File

@@ -9,7 +9,8 @@ from uuid import UUID
from fastapi import HTTPException
- from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
+ from inference.data.admin_models import AdminAnnotation, AdminDocument
+ from shared.fields import FIELD_CLASSES
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from inference.web.schemas.admin import (
AnnotationCreate,

View File

@@ -31,6 +31,7 @@ class MockAdminDocument:
self.batch_id = kwargs.get('batch_id', None)
self.csv_field_values = kwargs.get('csv_field_values', None)
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
+ self.category = kwargs.get('category', 'invoice')
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
@@ -67,12 +68,13 @@ class MockAdminDB:
def get_documents_by_token(
self,
- admin_token,
+ admin_token=None,
status=None,
upload_source=None,
has_annotations=None,
auto_label_status=None,
batch_id=None,
+ category=None,
limit=20,
offset=0
):
@@ -95,6 +97,8 @@ class MockAdminDB:
docs = [d for d in docs if d.auto_label_status == auto_label_status]
if batch_id:
docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
+ if category:
+ docs = [d for d in docs if d.category == category]
total = len(docs)
return docs[offset:offset+limit], total
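The mock above mirrors a category filter that the real AdminDB presumably applies in SQL. A minimal sketch of that query, assuming SQLModel and that AdminDocument.category maps to a plain string column (the actual implementation is not part of this diff):

from sqlmodel import Session, select
from inference.data.admin_models import AdminDocument

def get_documents_filtered(session: Session, category: str | None = None,
                           limit: int = 20, offset: int = 0) -> list[AdminDocument]:
    """Return documents, optionally restricted to one category (sketch only)."""
    stmt = select(AdminDocument)
    if category:
        stmt = stmt.where(AdminDocument.category == category)
    return list(session.exec(stmt.offset(offset).limit(limit)).all())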

View File

@@ -215,8 +215,10 @@ class TestAsyncProcessingService:
def test_cleanup_orphan_files(self, async_service, mock_db):
"""Test cleanup of orphan files."""
# Create an orphan file
+ # Create the async upload directory
temp_dir = async_service._async_config.temp_upload_dir
+ temp_dir.mkdir(parents=True, exist_ok=True)
orphan_file = temp_dir / "orphan-request.pdf"
orphan_file.write_bytes(b"orphan content")
@@ -228,7 +230,13 @@ class TestAsyncProcessingService:
# Mock database to say file doesn't exist
mock_db.get_request.return_value = None
- count = async_service._cleanup_orphan_files()
+ # Mock the storage helper to return the same directory as the fixture
+ with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
+ mock_helper = MagicMock()
+ mock_helper.get_uploads_base_path.return_value = temp_dir
+ mock_storage.return_value = mock_helper
+ count = async_service._cleanup_orphan_files()
assert count == 1
assert not orphan_file.exists()
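The updated test implies that _cleanup_orphan_files now resolves the uploads directory through the storage helper and removes files that have no matching request row. A sketch of that logic under those assumptions (the loop body, file pattern, and use of the filename stem as the request id are guesses, not the repository's code):

from pathlib import Path

def cleanup_orphan_files(db, storage_helper) -> int:
    """Delete uploaded files whose request id has no database record."""
    base: Path | None = storage_helper.get_uploads_base_path()
    if base is None:  # non-local backends expose no local directory to sweep
        return 0
    removed = 0
    for path in base.iterdir():
        if not path.is_file():
            continue
        if db.get_request(path.stem) is None:  # "orphan-request.pdf" -> "orphan-request"
            path.unlink()
            removed += 1
    return removed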

View File

@@ -5,7 +5,75 @@ TDD Phase 5: RED - Write tests first, then implement to pass.
"""
import pytest
from unittest.mock import MagicMock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
import numpy as np
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.core.auth import validate_admin_token, get_admin_db
TEST_ADMIN_TOKEN = "test-admin-token-12345"
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"
@pytest.fixture
def admin_token() -> str:
"""Provide admin token for testing."""
return TEST_ADMIN_TOKEN
@pytest.fixture
def mock_admin_db() -> MagicMock:
"""Create a mock AdminDB for testing."""
mock = MagicMock()
# Default return values
mock.get_document_by_token.return_value = None
mock.get_dataset.return_value = None
mock.get_augmented_datasets.return_value = ([], 0)
return mock
@pytest.fixture
def admin_client(mock_admin_db: MagicMock) -> TestClient:
"""Create test client with admin authentication."""
app = FastAPI()
# Override dependencies
def get_token_override():
return TEST_ADMIN_TOKEN
def get_db_override():
return mock_admin_db
app.dependency_overrides[validate_admin_token] = get_token_override
app.dependency_overrides[get_admin_db] = get_db_override
# Include router - the router already has /augmentation prefix
# so we add /api/v1/admin to get /api/v1/admin/augmentation
router = create_augmentation_router()
app.include_router(router, prefix="/api/v1/admin")
return TestClient(app)
@pytest.fixture
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
"""Create test client WITHOUT admin authentication override."""
app = FastAPI()
# Only override the database, NOT the token validation
def get_db_override():
return mock_admin_db
app.dependency_overrides[get_admin_db] = get_db_override
router = create_augmentation_router()
app.include_router(router, prefix="/api/v1/admin")
return TestClient(app)
class TestAugmentationTypesEndpoint:
@@ -34,10 +102,10 @@ class TestAugmentationTypesEndpoint:
assert "stage" in aug_type
def test_list_augmentation_types_unauthorized(
- self, admin_client: TestClient
+ self, unauthenticated_client: TestClient
) -> None:
"""Test that unauthorized request is rejected."""
- response = admin_client.get("/api/v1/admin/augmentation/types")
+ response = unauthenticated_client.get("/api/v1/admin/augmentation/types")
assert response.status_code == 401
@@ -74,16 +142,30 @@ class TestAugmentationPreviewEndpoint:
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
+ mock_admin_db: MagicMock,
) -> None:
"""Test previewing augmentation on a document."""
- response = admin_client.post(
- f"/api/v1/admin/augmentation/preview/{sample_document_id}",
- headers={"X-Admin-Token": admin_token},
- json={
- "augmentation_type": "gaussian_noise",
- "params": {"std": 15},
- },
- )
+ # Mock document exists
+ mock_document = MagicMock()
+ mock_document.images_dir = "/fake/path"
+ mock_admin_db.get_document.return_value = mock_document
+ # Create a fake image (100x100 RGB)
+ fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+ with patch(
+ "inference.web.services.augmentation_service.AugmentationService._load_document_page"
+ ) as mock_load:
+ mock_load.return_value = fake_image
+ response = admin_client.post(
+ f"/api/v1/admin/augmentation/preview/{sample_document_id}",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "augmentation_type": "gaussian_noise",
+ "params": {"std": 15},
+ },
+ )
assert response.status_code == 200
data = response.json()
@@ -136,18 +218,32 @@ class TestAugmentationPreviewConfigEndpoint:
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
+ mock_admin_db: MagicMock,
) -> None:
"""Test previewing full config on a document."""
- response = admin_client.post(
- f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
- headers={"X-Admin-Token": admin_token},
- json={
- "gaussian_noise": {"enabled": True, "probability": 1.0},
- "lighting_variation": {"enabled": True, "probability": 1.0},
- "preserve_bboxes": True,
- "seed": 42,
- },
- )
+ # Mock document exists
+ mock_document = MagicMock()
+ mock_document.images_dir = "/fake/path"
+ mock_admin_db.get_document.return_value = mock_document
+ # Create a fake image (100x100 RGB)
+ fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+ with patch(
+ "inference.web.services.augmentation_service.AugmentationService._load_document_page"
+ ) as mock_load:
+ mock_load.return_value = fake_image
+ response = admin_client.post(
+ f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "gaussian_noise": {"enabled": True, "probability": 1.0},
+ "lighting_variation": {"enabled": True, "probability": 1.0},
+ "preserve_bboxes": True,
+ "seed": 42,
+ },
+ )
assert response.status_code == 200
data = response.json()
@@ -164,8 +260,14 @@ class TestAugmentationBatchEndpoint:
admin_client: TestClient,
admin_token: str,
sample_dataset_id: str,
+ mock_admin_db: MagicMock,
) -> None:
"""Test creating augmented dataset."""
+ # Mock dataset exists
+ mock_dataset = MagicMock()
+ mock_dataset.total_images = 100
+ mock_admin_db.get_dataset.return_value = mock_dataset
response = admin_client.post(
"/api/v1/admin/augmentation/batch",
headers={"X-Admin-Token": admin_token},
@@ -250,12 +352,10 @@ class TestAugmentedDatasetsListEndpoint:
@pytest.fixture
def sample_document_id() -> str:
"""Provide a sample document ID for testing."""
- # This would need to be created in test setup
- return "test-document-id"
+ return TEST_DOCUMENT_UUID
@pytest.fixture
def sample_dataset_id() -> str:
"""Provide a sample dataset ID for testing."""
- # This would need to be created in test setup
- return "test-dataset-id"
+ return TEST_DATASET_UUID
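The preview tests patch AugmentationService._load_document_page and feed it a random 100x100 RGB array, so the endpoint only needs a numpy image to work on. For orientation, a sketch of what a gaussian_noise augmentation with a "std" parameter typically does; this is illustrative, not the project's AugmentationService:

import numpy as np

def apply_gaussian_noise(image: np.ndarray, std: float = 15.0, seed: int | None = None) -> np.ndarray:
    """Add zero-mean Gaussian noise to an RGB uint8 image and clip back to [0, 255]."""
    rng = np.random.default_rng(seed)
    noisy = image.astype(np.float64) + rng.normal(0.0, std, size=image.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)

# Matching the fake page used in the tests:
# page = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
# preview = apply_gaussian_noise(page, std=15, seed=42)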

View File

@@ -35,6 +35,8 @@ def _make_dataset(**overrides) -> MagicMock:
name="test-dataset",
description="Test dataset",
status="ready",
+ training_status=None,
+ active_training_task_id=None,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
@@ -183,6 +185,8 @@ class TestListDatasetsRoute:
mock_db = MagicMock()
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
+ # Mock the active training tasks lookup to return empty dict
+ mock_db.get_active_training_tasks_for_datasets.return_value = {}
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))

View File

@@ -0,0 +1,363 @@
"""
Tests for dataset training status feature.
Tests cover:
1. Database model fields (training_status, active_training_task_id)
2. AdminDB update_dataset_training_status method
3. API response includes training status fields
4. Scheduler updates dataset status during training lifecycle
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
# =============================================================================
# Test Database Model
# =============================================================================
class TestTrainingDatasetModel:
"""Tests for TrainingDataset model fields."""
def test_training_dataset_has_training_status_field(self):
"""TrainingDataset model should have training_status field."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(
name="test-dataset",
training_status="running",
)
assert dataset.training_status == "running"
def test_training_dataset_has_active_training_task_id_field(self):
"""TrainingDataset model should have active_training_task_id field."""
from inference.data.admin_models import TrainingDataset
task_id = uuid4()
dataset = TrainingDataset(
name="test-dataset",
active_training_task_id=task_id,
)
assert dataset.active_training_task_id == task_id
def test_training_dataset_defaults(self):
"""TrainingDataset should have correct defaults for new fields."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test-dataset")
assert dataset.training_status is None
assert dataset.active_training_task_id is None
# =============================================================================
# Test AdminDB Methods
# =============================================================================
class TestAdminDBDatasetTrainingStatus:
"""Tests for AdminDB.update_dataset_training_status method."""
@pytest.fixture
def mock_session(self):
"""Create mock database session."""
session = MagicMock()
return session
def test_update_dataset_training_status_sets_status(self, mock_session):
"""update_dataset_training_status should set training_status."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="running",
)
assert dataset.training_status == "running"
mock_session.add.assert_called_once_with(dataset)
mock_session.commit.assert_called_once()
def test_update_dataset_training_status_sets_task_id(self, mock_session):
"""update_dataset_training_status should set active_training_task_id."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="running",
active_training_task_id=str(task_id),
)
assert dataset.active_training_task_id == task_id
def test_update_dataset_training_status_updates_main_status_on_complete(
self, mock_session
):
"""update_dataset_training_status should update main status to 'trained' when completed."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="completed",
update_main_status=True,
)
assert dataset.status == "trained"
assert dataset.training_status == "completed"
def test_update_dataset_training_status_clears_task_id_on_complete(
self, mock_session
):
"""update_dataset_training_status should clear task_id when training completes."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
training_status="running",
active_training_task_id=task_id,
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="completed",
active_training_task_id=None,
)
assert dataset.active_training_task_id is None
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
"""update_dataset_training_status should handle missing dataset gracefully."""
mock_session.get.return_value = None
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
# Should not raise
db.update_dataset_training_status(
dataset_id=str(uuid4()),
training_status="running",
)
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()
# =============================================================================
# Test API Response
# =============================================================================
class TestDatasetDetailResponseTrainingStatus:
"""Tests for DatasetDetailResponse including training status fields."""
def test_dataset_detail_response_includes_training_status(self):
"""DatasetDetailResponse schema should include training_status field."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
name="test-dataset",
description=None,
status="ready",
training_status="running",
active_training_task_id=str(uuid4()),
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=10,
total_images=15,
total_annotations=100,
dataset_path="/path/to/dataset",
error_message=None,
documents=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
assert response.training_status == "running"
assert response.active_training_task_id is not None
def test_dataset_detail_response_allows_null_training_status(self):
"""DatasetDetailResponse should allow null training_status."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
name="test-dataset",
description=None,
status="ready",
training_status=None,
active_training_task_id=None,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=10,
total_images=15,
total_annotations=100,
dataset_path=None,
error_message=None,
documents=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
assert response.training_status is None
assert response.active_training_task_id is None
# =============================================================================
# Test Scheduler Training Status Updates
# =============================================================================
class TestSchedulerDatasetStatusUpdates:
"""Tests for scheduler updating dataset status during training."""
@pytest.fixture
def mock_db(self):
"""Create mock AdminDB."""
mock = MagicMock()
mock.get_dataset.return_value = MagicMock(
dataset_id=uuid4(),
name="test-dataset",
dataset_path="/path/to/dataset",
total_images=100,
)
mock.get_pending_training_tasks.return_value = []
return mock
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
"""Scheduler should set dataset training_status to 'running' when task starts."""
from inference.web.core.scheduler import TrainingScheduler
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
scheduler = TrainingScheduler()
scheduler._db = mock_db
task_id = str(uuid4())
dataset_id = str(uuid4())
# Execute task (will fail but we check the status update call)
try:
scheduler._execute_task(
task_id=task_id,
config={"model_name": "yolo11n.pt"},
dataset_id=dataset_id,
)
except Exception:
pass # Expected to fail in test environment
# Check that training status was updated to running
mock_db.update_dataset_training_status.assert_called()
first_call = mock_db.update_dataset_training_status.call_args_list[0]
assert first_call.kwargs["training_status"] == "running"
assert first_call.kwargs["active_training_task_id"] == task_id
# =============================================================================
# Test Dataset Status Values
# =============================================================================
class TestDatasetStatusValues:
"""Tests for valid dataset status values."""
def test_dataset_status_building(self):
"""Dataset can have status 'building'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="building")
assert dataset.status == "building"
def test_dataset_status_ready(self):
"""Dataset can have status 'ready'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="ready")
assert dataset.status == "ready"
def test_dataset_status_trained(self):
"""Dataset can have status 'trained'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="trained")
assert dataset.status == "trained"
def test_dataset_status_failed(self):
"""Dataset can have status 'failed'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="failed")
assert dataset.status == "failed"
def test_training_status_values(self):
"""Training status can have various values."""
from inference.data.admin_models import TrainingDataset
valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
for status in valid_statuses:
dataset = TrainingDataset(name="test", training_status=status)
assert dataset.training_status == status
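Taken together, the AdminDB tests above pin down the expected behaviour of update_dataset_training_status. A minimal sketch that would satisfy them, written as a free function mirroring the AdminDB method and assuming the get_session_context name that the tests patch; the real implementation may differ:

from uuid import UUID
from inference.data.admin_db import get_session_context
from inference.data.admin_models import TrainingDataset

def update_dataset_training_status(
    dataset_id: str,
    training_status: str,
    active_training_task_id: str | None = None,
    update_main_status: bool = False,
) -> None:
    """Persist the training lifecycle state on a dataset row (sketch)."""
    with get_session_context() as session:
        dataset = session.get(TrainingDataset, UUID(dataset_id))
        if dataset is None:
            return  # missing dataset is tolerated, as the tests require
        dataset.training_status = training_status
        dataset.active_training_task_id = (
            UUID(active_training_task_id) if active_training_task_id else None
        )
        if update_main_status and training_status == "completed":
            dataset.status = "trained"
        session.add(dataset)
        session.commit()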

View File

@@ -0,0 +1,207 @@
"""
Tests for Document Category Feature.
TDD tests for adding category field to admin_documents table.
Documents can be categorized (e.g., invoice, letter, receipt) for training different models.
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock
from uuid import UUID, uuid4
from inference.data.admin_models import AdminDocument
# Test constants
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"
class TestAdminDocumentCategoryField:
"""Tests for AdminDocument category field."""
def test_document_has_category_field(self):
"""Test AdminDocument model has category field."""
doc = AdminDocument(
document_id=UUID(TEST_DOC_UUID),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
)
assert hasattr(doc, "category")
def test_document_category_defaults_to_invoice(self):
"""Test category defaults to 'invoice' when not specified."""
doc = AdminDocument(
document_id=UUID(TEST_DOC_UUID),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
)
assert doc.category == "invoice"
def test_document_accepts_custom_category(self):
"""Test document accepts custom category values."""
categories = ["invoice", "letter", "receipt", "contract", "custom_type"]
for cat in categories:
doc = AdminDocument(
document_id=uuid4(),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
category=cat,
)
assert doc.category == cat
def test_document_category_is_string_type(self):
"""Test category field is a string type."""
doc = AdminDocument(
document_id=UUID(TEST_DOC_UUID),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
category="letter",
)
assert isinstance(doc.category, str)
class TestDocumentCategoryInReadModel:
"""Tests for category in response models."""
def test_admin_document_read_has_category(self):
"""Test AdminDocumentRead includes category field."""
from inference.data.admin_models import AdminDocumentRead
# Check the model has category field in its schema
assert "category" in AdminDocumentRead.model_fields
class TestDocumentCategoryAPI:
"""Tests for document category in API endpoints."""
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
return db
def test_upload_document_with_category(self, mock_admin_db):
"""Test uploading document with category parameter."""
from inference.web.schemas.admin import DocumentUploadResponse
# Verify response schema supports category
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
message="Upload successful",
category="letter",
)
assert response.category == "letter"
def test_list_documents_returns_category(self, mock_admin_db):
"""Test list documents endpoint returns category."""
from inference.web.schemas.admin import DocumentItem
item = DocumentItem(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
annotation_count=0,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
category="invoice",
)
assert item.category == "invoice"
def test_document_detail_includes_category(self, mock_admin_db):
"""Test document detail response includes category."""
from inference.web.schemas.admin import DocumentDetailResponse
# Check schema has category
assert "category" in DocumentDetailResponse.model_fields
class TestDocumentCategoryFiltering:
"""Tests for filtering documents by category."""
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB with category filtering support."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
# Mock documents with different categories
invoice_doc = MagicMock()
invoice_doc.document_id = uuid4()
invoice_doc.category = "invoice"
letter_doc = MagicMock()
letter_doc.document_id = uuid4()
letter_doc.category = "letter"
db.get_documents_by_category.return_value = [invoice_doc]
return db
def test_filter_documents_by_category(self, mock_admin_db):
"""Test filtering documents by category."""
# This tests the DB method signature
result = mock_admin_db.get_documents_by_category("invoice")
assert len(result) == 1
assert result[0].category == "invoice"
class TestDocumentCategoryUpdate:
"""Tests for updating document category."""
def test_update_document_category_schema(self):
"""Test update document request supports category."""
from inference.web.schemas.admin import DocumentUpdateRequest
request = DocumentUpdateRequest(category="letter")
assert request.category == "letter"
def test_update_document_category_optional(self):
"""Test category is optional in update request."""
from inference.web.schemas.admin import DocumentUpdateRequest
# Should not raise - category is optional
request = DocumentUpdateRequest()
assert request.category is None
class TestDatasetWithCategory:
"""Tests for dataset creation with category filtering."""
def test_dataset_create_with_category_filter(self):
"""Test creating dataset can filter by document category."""
from inference.web.schemas.admin import DatasetCreateRequest
request = DatasetCreateRequest(
name="Invoice Training Set",
document_ids=[TEST_DOC_UUID],
category="invoice", # Optional filter
)
assert request.category == "invoice"
def test_dataset_create_category_is_optional(self):
"""Test category filter is optional when creating dataset."""
from inference.web.schemas.admin import DatasetCreateRequest
request = DatasetCreateRequest(
name="Mixed Training Set",
document_ids=[TEST_DOC_UUID],
)
# category should be optional
assert not hasattr(request, "category") or request.category is None
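These tests imply a string column on admin_documents that defaults to "invoice" and accepts arbitrary values. A sketch of the assumed field shape, shown on a stand-in SQLModel class rather than the real AdminDocument:

from sqlmodel import Field, SQLModel

class DocumentCategorySketch(SQLModel):
    """Stand-in showing only the assumed shape of the new column."""
    category: str = Field(default="invoice", index=True,
                          description="e.g. invoice, letter, receipt")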

View File

@@ -0,0 +1,165 @@
"""
Tests for Document Category API Endpoints.
TDD tests for category filtering and management in document endpoints.
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import UUID, uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
# Test constants
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"
class TestGetCategoriesEndpoint:
"""Tests for GET /admin/documents/categories endpoint."""
def test_categories_endpoint_returns_list(self):
"""Test categories endpoint returns list of available categories."""
from inference.web.schemas.admin import DocumentCategoriesResponse
# Test schema exists and works
response = DocumentCategoriesResponse(
categories=["invoice", "letter", "receipt"],
total=3,
)
assert response.categories == ["invoice", "letter", "receipt"]
assert response.total == 3
def test_categories_response_schema(self):
"""Test DocumentCategoriesResponse schema structure."""
from inference.web.schemas.admin import DocumentCategoriesResponse
assert "categories" in DocumentCategoriesResponse.model_fields
assert "total" in DocumentCategoriesResponse.model_fields
class TestDocumentListFilterByCategory:
"""Tests for filtering documents by category."""
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
# Mock documents with different categories
invoice_doc = MagicMock()
invoice_doc.document_id = uuid4()
invoice_doc.category = "invoice"
invoice_doc.filename = "invoice1.pdf"
letter_doc = MagicMock()
letter_doc.document_id = uuid4()
letter_doc.category = "letter"
letter_doc.filename = "letter1.pdf"
db.get_documents.return_value = ([invoice_doc], 1)
db.get_document_categories.return_value = ["invoice", "letter", "receipt"]
return db
def test_list_documents_accepts_category_filter(self, mock_admin_db):
"""Test list documents endpoint accepts category query parameter."""
# The endpoint should accept ?category=invoice parameter
# This test verifies the schema/query parameter exists
from inference.web.schemas.admin import DocumentListResponse
# Schema should work with category filter applied
assert DocumentListResponse is not None
def test_get_document_categories_from_db(self, mock_admin_db):
"""Test fetching unique categories from database."""
categories = mock_admin_db.get_document_categories()
assert "invoice" in categories
assert "letter" in categories
assert len(categories) == 3
class TestDocumentUploadWithCategory:
"""Tests for uploading documents with category."""
def test_upload_request_accepts_category(self):
"""Test upload request can include category field."""
# When uploading via form data, category should be accepted
# This is typically a form field, not a schema
pass
def test_upload_response_includes_category(self):
"""Test upload response includes the category that was set."""
from inference.web.schemas.admin import DocumentUploadResponse
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
category="letter", # Custom category
message="Upload successful",
)
assert response.category == "letter"
def test_upload_defaults_to_invoice_category(self):
"""Test upload defaults to 'invoice' if no category specified."""
from inference.web.schemas.admin import DocumentUploadResponse
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
message="Upload successful",
# No category specified - should default to "invoice"
)
assert response.category == "invoice"
class TestAdminDBCategoryMethods:
"""Tests for AdminDB category-related methods."""
def test_get_document_categories_method_exists(self):
"""Test AdminDB has get_document_categories method."""
from inference.data.admin_db import AdminDB
db = AdminDB()
assert hasattr(db, "get_document_categories")
def test_get_documents_accepts_category_filter(self):
"""Test get_documents_by_token method accepts category parameter."""
from inference.data.admin_db import AdminDB
import inspect
db = AdminDB()
# Check the method exists and accepts category parameter
method = getattr(db, "get_documents_by_token", None)
assert callable(method)
# Check category is in the method signature
sig = inspect.signature(method)
assert "category" in sig.parameters
class TestUpdateDocumentCategory:
"""Tests for updating document category."""
def test_update_document_category_method_exists(self):
"""Test AdminDB has method to update document category."""
from inference.data.admin_db import AdminDB
db = AdminDB()
assert hasattr(db, "update_document_category")
def test_update_request_schema(self):
"""Test DocumentUpdateRequest can update category."""
from inference.web.schemas.admin import DocumentUpdateRequest
request = DocumentUpdateRequest(category="receipt")
assert request.category == "receipt"
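The get_document_categories method is only asserted to exist and to return the distinct categories. A plausible sketch under the same SQLModel assumption as above; the real query is not in this diff:

from sqlmodel import Session, select
from inference.data.admin_models import AdminDocument

def get_document_categories(session: Session) -> list[str]:
    """Return the distinct category values currently stored on documents."""
    rows = session.exec(select(AdminDocument.category).distinct()).all()
    return sorted({c for c in rows if c})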

View File

@@ -32,10 +32,10 @@ def test_app(tmp_path):
use_gpu=False,
dpi=150,
),
- storage=StorageConfig(
+ file=StorageConfig(
upload_dir=upload_dir,
result_dir=result_dir,
- allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
+ allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
max_file_size_mb=50,
),
)
@@ -252,20 +252,25 @@ class TestResultsEndpoint:
response = client.get("/api/v1/results/nonexistent.png")
assert response.status_code == 404
- def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
+ def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
"""Test that existing result file is returned."""
- # Get storage config from app
- storage_config = test_app.extra.get("storage_config")
- if not storage_config:
- pytest.skip("Storage config not available in test app")
- # Create a test result file
- result_file = storage_config.result_dir / "test_result.png"
+ # Create a test result file in temp directory
+ result_dir = tmp_path / "results"
+ result_dir.mkdir(exist_ok=True)
+ result_file = result_dir / "test_result.png"
img = Image.new('RGB', (100, 100), color='red')
img.save(result_file)
- # Request the file
- response = client.get("/api/v1/results/test_result.png")
+ # Mock the storage helper to return our test file path
+ with patch(
+ "inference.web.api.v1.public.inference.get_storage_helper"
+ ) as mock_storage:
+ mock_helper = Mock()
+ mock_helper.get_result_local_path.return_value = result_file
+ mock_storage.return_value = mock_helper
+ # Request the file
+ response = client.get("/api/v1/results/test_result.png")
assert response.status_code == 200
assert response.headers["content-type"] == "image/png"
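The rewritten test implies that the results route now resolves files through the storage helper instead of reading the configured result_dir directly. A sketch of that lookup pattern, assuming a FastAPI route and the get_storage_helper name the test patches; the real handler and its media-type handling are not shown in this diff:

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse
from inference.web.services.storage_helpers import get_storage_helper

router = APIRouter()

@router.get("/results/{filename}")
def get_result_image(filename: str) -> FileResponse:
    """Serve a result file if the storage backend can expose a local path."""
    local_path = get_storage_helper().get_result_local_path(filename)
    if local_path is None or not local_path.exists():
        raise HTTPException(status_code=404, detail="Result not found")
    return FileResponse(local_path, media_type="image/png")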

View File

@@ -266,7 +266,11 @@ class TestActivateModelVersionRoute:
mock_db = MagicMock()
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
- result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+ # Create mock request with app state
+ mock_request = MagicMock()
+ mock_request.app.state.inference_service = None
+ result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result.status == "active"
@@ -278,10 +282,14 @@ class TestActivateModelVersionRoute:
mock_db = MagicMock()
mock_db.activate_model_version.return_value = None
+ # Create mock request with app state
+ mock_request = MagicMock()
+ mock_request.app.state.inference_service = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
- asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+ asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
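The route now takes a Request so it can reach app.state.inference_service after activation. A rough sketch of what the updated handler might look like; how a running inference service actually reloads the activated weights is not shown in the diff and is left as a hypothetical hook:

from fastapi import HTTPException, Request

async def activate_model_version(version_id: str, request: Request, admin_token: str, db):
    """Activate a stored model version and consult any loaded inference service."""
    version = db.activate_model_version(version_id)
    if version is None:
        raise HTTPException(status_code=404, detail="Model version not found")
    inference_service = getattr(request.app.state, "inference_service", None)
    if inference_service is not None:
        # Hypothetical hook: the diff only shows that app.state is consulted,
        # not how the service picks up the newly activated model.
        pass
    return version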

View File

@@ -0,0 +1,828 @@
"""Tests for storage helpers module."""
import pytest
from unittest.mock import MagicMock, patch
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
from shared.storage import PREFIXES
@pytest.fixture
def mock_storage() -> MagicMock:
"""Create a mock storage backend."""
storage = MagicMock()
storage.upload_bytes = MagicMock()
storage.download_bytes = MagicMock(return_value=b"test content")
storage.get_presigned_url = MagicMock(return_value="https://example.com/file")
storage.exists = MagicMock(return_value=True)
storage.delete = MagicMock(return_value=True)
storage.list_files = MagicMock(return_value=[])
return storage
@pytest.fixture
def helper(mock_storage: MagicMock) -> StorageHelper:
"""Create a storage helper with mock backend."""
return StorageHelper(storage=mock_storage)
class TestStorageHelperInit:
"""Tests for StorageHelper initialization."""
def test_init_with_storage(self, mock_storage: MagicMock) -> None:
"""Should use provided storage backend."""
helper = StorageHelper(storage=mock_storage)
assert helper.storage is mock_storage
def test_storage_property(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Storage property should return the backend."""
assert helper.storage is mock_storage
class TestDocumentOperations:
"""Tests for document storage operations."""
def test_upload_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should upload document with correct path."""
doc_id, path = helper.upload_document(b"pdf content", "invoice.pdf", "doc123")
assert doc_id == "doc123"
assert path == "documents/doc123.pdf"
mock_storage.upload_bytes.assert_called_once_with(
b"pdf content", "documents/doc123.pdf", overwrite=True
)
def test_upload_document_generates_id(self, helper: StorageHelper) -> None:
"""Should generate document ID if not provided."""
doc_id, path = helper.upload_document(b"content", "file.pdf")
assert doc_id is not None
assert len(doc_id) > 0
assert path.startswith("documents/")
def test_download_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should download document from correct path."""
content = helper.download_document("doc123")
assert content == b"test content"
mock_storage.download_bytes.assert_called_once_with("documents/doc123.pdf")
def test_get_document_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for document."""
url = helper.get_document_url("doc123", expires_in_seconds=7200)
assert url == "https://example.com/file"
mock_storage.get_presigned_url.assert_called_once_with(
"documents/doc123.pdf", 7200
)
def test_document_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check document existence."""
exists = helper.document_exists("doc123")
assert exists is True
mock_storage.exists.assert_called_once_with("documents/doc123.pdf")
def test_delete_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete document."""
result = helper.delete_document("doc123")
assert result is True
mock_storage.delete.assert_called_once_with("documents/doc123.pdf")
class TestImageOperations:
"""Tests for image storage operations."""
def test_save_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save page image with correct path."""
path = helper.save_page_image("doc123", 1, b"image data")
assert path == "images/doc123/page_1.png"
mock_storage.upload_bytes.assert_called_once_with(
b"image data", "images/doc123/page_1.png", overwrite=True
)
def test_get_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get page image from correct path."""
content = helper.get_page_image("doc123", 2)
assert content == b"test content"
mock_storage.download_bytes.assert_called_once_with("images/doc123/page_2.png")
def test_get_page_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for page image."""
url = helper.get_page_image_url("doc123", 3)
assert url == "https://example.com/file"
mock_storage.get_presigned_url.assert_called_once_with(
"images/doc123/page_3.png", 3600
)
def test_delete_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete all images for a document."""
mock_storage.list_files.return_value = [
"images/doc123/page_1.png",
"images/doc123/page_2.png",
]
deleted = helper.delete_document_images("doc123")
assert deleted == 2
mock_storage.list_files.assert_called_once_with("images/doc123/")
def test_list_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should list all images for a document."""
mock_storage.list_files.return_value = ["images/doc123/page_1.png"]
images = helper.list_document_images("doc123")
assert images == ["images/doc123/page_1.png"]
class TestUploadOperations:
"""Tests for upload staging operations."""
def test_save_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save upload to correct path."""
path = helper.save_upload(b"content", "file.pdf")
assert path == "uploads/file.pdf"
mock_storage.upload_bytes.assert_called_once()
def test_save_upload_with_subfolder(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save upload with subfolder."""
path = helper.save_upload(b"content", "file.pdf", "async")
assert path == "uploads/async/file.pdf"
def test_get_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get upload from correct path."""
content = helper.get_upload("file.pdf", "async")
mock_storage.download_bytes.assert_called_once_with("uploads/async/file.pdf")
def test_delete_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete upload."""
result = helper.delete_upload("file.pdf")
assert result is True
mock_storage.delete.assert_called_once_with("uploads/file.pdf")
class TestResultOperations:
"""Tests for result file operations."""
def test_save_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save result to correct path."""
path = helper.save_result(b"result data", "output.json")
assert path == "results/output.json"
mock_storage.upload_bytes.assert_called_once()
def test_get_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get result from correct path."""
content = helper.get_result("output.json")
mock_storage.download_bytes.assert_called_once_with("results/output.json")
def test_get_result_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for result."""
url = helper.get_result_url("output.json")
mock_storage.get_presigned_url.assert_called_once_with("results/output.json", 3600)
def test_result_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check result existence."""
exists = helper.result_exists("output.json")
assert exists is True
mock_storage.exists.assert_called_once_with("results/output.json")
class TestExportOperations:
"""Tests for export file operations."""
def test_save_export(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save export to correct path."""
path = helper.save_export(b"export data", "exp123", "dataset.zip")
assert path == "exports/exp123/dataset.zip"
mock_storage.upload_bytes.assert_called_once()
def test_get_export_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for export."""
url = helper.get_export_url("exp123", "dataset.zip")
mock_storage.get_presigned_url.assert_called_once_with(
"exports/exp123/dataset.zip", 3600
)
class TestRawPdfOperations:
"""Tests for raw PDF operations (legacy compatibility)."""
def test_save_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save raw PDF to correct path."""
path = helper.save_raw_pdf(b"pdf data", "invoice.pdf")
assert path == "raw_pdfs/invoice.pdf"
mock_storage.upload_bytes.assert_called_once()
def test_get_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get raw PDF from correct path."""
content = helper.get_raw_pdf("invoice.pdf")
mock_storage.download_bytes.assert_called_once_with("raw_pdfs/invoice.pdf")
def test_raw_pdf_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check raw PDF existence."""
exists = helper.raw_pdf_exists("invoice.pdf")
assert exists is True
mock_storage.exists.assert_called_once_with("raw_pdfs/invoice.pdf")
class TestAdminImageOperations:
"""Tests for admin image storage operations."""
def test_save_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save admin image with correct path."""
path = helper.save_admin_image("doc123", 1, b"image data")
assert path == "admin_images/doc123/page_1.png"
mock_storage.upload_bytes.assert_called_once_with(
b"image data", "admin_images/doc123/page_1.png", overwrite=True
)
def test_get_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get admin image from correct path."""
content = helper.get_admin_image("doc123", 2)
assert content == b"test content"
mock_storage.download_bytes.assert_called_once_with("admin_images/doc123/page_2.png")
def test_get_admin_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for admin image."""
url = helper.get_admin_image_url("doc123", 3)
assert url == "https://example.com/file"
mock_storage.get_presigned_url.assert_called_once_with(
"admin_images/doc123/page_3.png", 3600
)
def test_admin_image_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check admin image existence."""
exists = helper.admin_image_exists("doc123", 1)
assert exists is True
mock_storage.exists.assert_called_once_with("admin_images/doc123/page_1.png")
def test_get_admin_image_path(self, helper: StorageHelper) -> None:
"""Should return correct admin image path."""
path = helper.get_admin_image_path("doc123", 2)
assert path == "admin_images/doc123/page_2.png"
def test_list_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should list all admin images for a document."""
mock_storage.list_files.return_value = [
"admin_images/doc123/page_1.png",
"admin_images/doc123/page_2.png",
]
images = helper.list_admin_images("doc123")
assert images == ["admin_images/doc123/page_1.png", "admin_images/doc123/page_2.png"]
mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
def test_delete_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete all admin images for a document."""
mock_storage.list_files.return_value = [
"admin_images/doc123/page_1.png",
"admin_images/doc123/page_2.png",
]
deleted = helper.delete_admin_images("doc123")
assert deleted == 2
mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
class TestGetLocalPath:
"""Tests for get_local_path method."""
def test_get_admin_image_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test image
test_path = Path(temp_dir) / "admin_images" / "doc123"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "page_1.png").write_bytes(b"test image")
local_path = helper.get_admin_image_local_path("doc123", 1)
assert local_path is not None
assert local_path.exists()
assert local_path.name == "page_1.png"
def test_get_admin_image_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when storage doesn't support local paths."""
# Mock storage without get_local_path method (simulating cloud storage)
mock_storage.get_local_path = MagicMock(return_value=None)
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_admin_image_local_path("doc123", 1)
assert local_path is None
def test_get_admin_image_local_path_nonexistent_file(self) -> None:
"""Should return None when file doesn't exist."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
local_path = helper.get_admin_image_local_path("nonexistent", 1)
assert local_path is None
class TestGetAdminImageDimensions:
"""Tests for get_admin_image_dimensions method."""
def test_get_dimensions_with_local_storage(self) -> None:
"""Should return image dimensions when using local storage."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
from PIL import Image
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test image with known dimensions
test_path = Path(temp_dir) / "admin_images" / "doc123"
test_path.mkdir(parents=True, exist_ok=True)
img = Image.new("RGB", (800, 600), color="white")
img.save(test_path / "page_1.png")
dimensions = helper.get_admin_image_dimensions("doc123", 1)
assert dimensions == (800, 600)
def test_get_dimensions_nonexistent_file(self) -> None:
"""Should return None when file doesn't exist."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
dimensions = helper.get_admin_image_dimensions("nonexistent", 1)
assert dimensions is None
class TestGetStorageHelper:
"""Tests for get_storage_helper function."""
def test_returns_helper_instance(self) -> None:
"""Should return a StorageHelper instance."""
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
mock_get.return_value = MagicMock()
# Reset the global helper
import inference.web.services.storage_helpers as module
module._default_helper = None
helper = get_storage_helper()
assert isinstance(helper, StorageHelper)
def test_returns_same_instance(self) -> None:
"""Should return the same instance on subsequent calls."""
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
mock_get.return_value = MagicMock()
import inference.web.services.storage_helpers as module
module._default_helper = None
helper1 = get_storage_helper()
helper2 = get_storage_helper()
assert helper1 is helper2
class TestDeleteResult:
"""Tests for delete_result method."""
def test_delete_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete result file."""
result = helper.delete_result("output.json")
assert result is True
mock_storage.delete.assert_called_once_with("results/output.json")
class TestResultLocalPath:
"""Tests for get_result_local_path method."""
def test_get_result_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test result file
test_path = Path(temp_dir) / "results"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "output.json").write_bytes(b"test result")
local_path = helper.get_result_local_path("output.json")
assert local_path is not None
assert local_path.exists()
assert local_path.name == "output.json"
def test_get_result_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when storage doesn't support local paths."""
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_result_local_path("output.json")
assert local_path is None
def test_get_result_local_path_nonexistent_file(self) -> None:
"""Should return None when file doesn't exist."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
local_path = helper.get_result_local_path("nonexistent.json")
assert local_path is None
class TestResultsBasePath:
"""Tests for get_results_base_path method."""
def test_get_results_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_results_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "results"
def test_get_results_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_results_base_path()
assert base_path is None
class TestUploadLocalPath:
"""Tests for get_upload_local_path method."""
def test_get_upload_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test upload file
test_path = Path(temp_dir) / "uploads"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "file.pdf").write_bytes(b"test upload")
local_path = helper.get_upload_local_path("file.pdf")
assert local_path is not None
assert local_path.exists()
assert local_path.name == "file.pdf"
def test_get_upload_local_path_with_subfolder(self) -> None:
"""Should return local path with subfolder."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test upload file with subfolder
test_path = Path(temp_dir) / "uploads" / "async"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "file.pdf").write_bytes(b"test upload")
local_path = helper.get_upload_local_path("file.pdf", "async")
assert local_path is not None
assert local_path.exists()
def test_get_upload_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_upload_local_path("file.pdf")
assert local_path is None
class TestUploadsBasePath:
"""Tests for get_uploads_base_path method."""
def test_get_uploads_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_uploads_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "uploads"
def test_get_uploads_base_path_with_subfolder(self) -> None:
"""Should return base path with subfolder."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_uploads_base_path("async")
assert base_path is not None
assert base_path.exists()
assert base_path.name == "async"
def test_get_uploads_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_uploads_base_path()
assert base_path is None
class TestUploadExists:
"""Tests for upload_exists method."""
def test_upload_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check upload existence."""
exists = helper.upload_exists("file.pdf")
assert exists is True
mock_storage.exists.assert_called_once_with("uploads/file.pdf")
def test_upload_exists_with_subfolder(
self, helper: StorageHelper, mock_storage: MagicMock
) -> None:
"""Should check upload existence with subfolder."""
helper.upload_exists("file.pdf", "async")
mock_storage.exists.assert_called_once_with("uploads/async/file.pdf")
class TestDatasetsBasePath:
"""Tests for get_datasets_base_path method."""
def test_get_datasets_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_datasets_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "datasets"
def test_get_datasets_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_datasets_base_path()
assert base_path is None
class TestAdminImagesBasePath:
"""Tests for get_admin_images_base_path method."""
def test_get_admin_images_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_admin_images_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "admin_images"
def test_get_admin_images_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_admin_images_base_path()
assert base_path is None
class TestRawPdfsBasePath:
"""Tests for get_raw_pdfs_base_path method."""
def test_get_raw_pdfs_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_raw_pdfs_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "raw_pdfs"
def test_get_raw_pdfs_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_raw_pdfs_base_path()
assert base_path is None
class TestRawPdfLocalPath:
"""Tests for get_raw_pdf_local_path method."""
def test_get_raw_pdf_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test raw PDF
test_path = Path(temp_dir) / "raw_pdfs"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "invoice.pdf").write_bytes(b"test pdf")
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
assert local_path is not None
assert local_path.exists()
assert local_path.name == "invoice.pdf"
def test_get_raw_pdf_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
assert local_path is None
class TestRawPdfPath:
"""Tests for get_raw_pdf_path method."""
def test_get_raw_pdf_path(self, helper: StorageHelper) -> None:
"""Should return correct storage path."""
path = helper.get_raw_pdf_path("invoice.pdf")
assert path == "raw_pdfs/invoice.pdf"
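# --- Hedged sketch (not the shipped implementation) ---------------------------
# The raw-PDF tests above pin down a contract: get_raw_pdf_path always returns
# the logical key "raw_pdfs/<filename>", while get_raw_pdf_local_path resolves a
# real filesystem path only for LocalStorageBackend and returns None otherwise.
# The `base_path` attribute on LocalStorageBackend is an assumption.
from pathlib import Path
from typing import Optional

from shared.storage.local import LocalStorageBackend


class _RawPdfPathSketch:
    def __init__(self, storage) -> None:
        self._storage = storage

    def get_raw_pdf_path(self, filename: str) -> str:
        # Backend-agnostic logical key, as asserted in TestRawPdfPath.
        return f"raw_pdfs/{filename}"

    def get_raw_pdf_local_path(self, filename: str) -> Optional[Path]:
        # Local-only resolution, as asserted in TestRawPdfLocalPath.
        if not isinstance(self._storage, LocalStorageBackend):
            return None
        return Path(self._storage.base_path) / "raw_pdfs" / filename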
class TestAutolabelOutputPath:
"""Tests for get_autolabel_output_path method."""
def test_get_autolabel_output_path_with_local_storage(self) -> None:
"""Should return output path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
output_path = helper.get_autolabel_output_path()
assert output_path is not None
assert output_path.exists()
assert output_path.name == "autolabel_output"
def test_get_autolabel_output_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
output_path = helper.get_autolabel_output_path()
assert output_path is None
class TestTrainingDataPath:
"""Tests for get_training_data_path method."""
def test_get_training_data_path_with_local_storage(self) -> None:
"""Should return training path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
training_path = helper.get_training_data_path()
assert training_path is not None
assert training_path.exists()
assert training_path.name == "training"
def test_get_training_data_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
training_path = helper.get_training_data_path()
assert training_path is None
class TestExportsBasePath:
"""Tests for get_exports_base_path method."""
def test_get_exports_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_exports_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "exports"
def test_get_exports_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_exports_base_path()
assert base_path is None
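# --- Hedged sketch (not the shipped implementation) ---------------------------
# Every *_base_path test in this file asserts the same contract: a created
# directory under the local storage root for LocalStorageBackend, None for any
# other backend. One private helper could back all of the getters; its name
# (_local_dir) and the `base_path` attribute are assumptions.
from pathlib import Path
from typing import Optional

from shared.storage.local import LocalStorageBackend


class _BasePathSketch:
    def __init__(self, storage) -> None:
        self._storage = storage

    def _local_dir(self, name: str) -> Optional[Path]:
        if not isinstance(self._storage, LocalStorageBackend):
            return None
        path = Path(self._storage.base_path) / name
        path.mkdir(parents=True, exist_ok=True)  # the tests assert .exists()
        return path

    def get_datasets_base_path(self) -> Optional[Path]:
        return self._local_dir("datasets")

    def get_exports_base_path(self) -> Optional[Path]:
        return self._local_dir("exports")

    # get_admin_images_base_path ("admin_images"), get_raw_pdfs_base_path
    # ("raw_pdfs"), get_autolabel_output_path ("autolabel_output") and
    # get_training_data_path ("training") would follow the same one-liner shape.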

View File

@@ -0,0 +1,306 @@
"""
Tests for storage backend integration in web application.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
class TestStorageBackendInitialization:
"""Tests for storage backend initialization in web config."""
def test_get_storage_backend_returns_backend(self, tmp_path: Path) -> None:
"""Test that get_storage_backend returns a StorageBackend instance."""
from shared.storage.base import StorageBackend
from inference.web.config import get_storage_backend
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
}
with patch.dict(os.environ, env, clear=False):
backend = get_storage_backend()
assert isinstance(backend, StorageBackend)
def test_get_storage_backend_uses_config_file_if_exists(
self, tmp_path: Path
) -> None:
"""Test that storage config file is used when present."""
from shared.storage.local import LocalStorageBackend
from inference.web.config import get_storage_backend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text(f"""
backend: local
local:
base_path: {storage_path}
""")
backend = get_storage_backend(config_path=config_file)
assert isinstance(backend, LocalStorageBackend)
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
"""Test fallback to environment variables when no config file."""
from shared.storage.local import LocalStorageBackend
from inference.web.config import get_storage_backend
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
}
with patch.dict(os.environ, env, clear=False):
backend = get_storage_backend(config_path=None)
assert isinstance(backend, LocalStorageBackend)
def test_app_config_has_storage_backend(self, tmp_path: Path) -> None:
"""Test that AppConfig can be created with storage backend."""
from shared.storage.base import StorageBackend
from inference.web.config import AppConfig, create_app_config
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
}
with patch.dict(os.environ, env, clear=False):
config = create_app_config()
assert hasattr(config, "storage_backend")
assert isinstance(config.storage_backend, StorageBackend)
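# --- Hedged sketch (not the shipped implementation) ---------------------------
# Resolution order these tests pin down: an explicit YAML config file wins, and
# only then do STORAGE_BACKEND / STORAGE_BASE_PATH environment variables apply.
# The YAML keys mirror the fixture above; the use of yaml.safe_load and the
# error handling are assumptions.
import os
from pathlib import Path
from typing import Optional

import yaml

from shared.storage.local import LocalStorageBackend


def _get_storage_backend_sketch(config_path: Optional[Path] = None):
    if config_path is not None and Path(config_path).exists():
        cfg = yaml.safe_load(Path(config_path).read_text())
        if cfg.get("backend") == "local":
            return LocalStorageBackend(str(cfg["local"]["base_path"]))
        raise ValueError(f"Unsupported storage backend: {cfg.get('backend')}")
    if os.environ.get("STORAGE_BACKEND", "local") == "local":
        return LocalStorageBackend(os.environ.get("STORAGE_BASE_PATH", "storage"))
    raise ValueError("Unsupported STORAGE_BACKEND value")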
class TestStorageBackendInDocumentUpload:
"""Tests for storage backend usage in document upload."""
def test_upload_document_uses_storage_backend(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document upload uses storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create a mock upload file
pdf_content = b"%PDF-1.4 test content"
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
# Upload should use storage backend
result = service.upload_document(
content=pdf_content,
filename="test.pdf",
dataset_id="dataset-1",
)
assert result is not None
# Verify file was stored via storage backend
assert backend.exists(f"documents/{result.id}.pdf")
def test_upload_document_stores_logical_path(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document stores logical path, not absolute path."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
pdf_content = b"%PDF-1.4 test content"
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
result = service.upload_document(
content=pdf_content,
filename="test.pdf",
dataset_id="dataset-1",
)
# Path should be logical (relative), not absolute
assert not result.file_path.startswith("/")
assert not result.file_path.startswith("C:")
assert result.file_path.startswith("documents/")
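# --- Hedged sketch (not the shipped implementation) ---------------------------
# Upload contract asserted above: bytes go through the storage backend under a
# relative, logical key ("documents/<id>.pdf"), and that key (never an absolute
# path) is what gets persisted. The uuid-based id and the create_document call
# shape are assumptions.
import uuid


class _DocumentUploadSketch:
    def __init__(self, admin_db, storage_backend) -> None:
        self._db = admin_db
        self._storage = storage_backend

    def upload_document(self, content: bytes, filename: str, dataset_id: str):
        doc_id = str(uuid.uuid4())
        logical_path = f"documents/{doc_id}.pdf"  # relative on purpose
        self._storage.upload_bytes(content, logical_path)
        return self._db.create_document(
            id=doc_id,
            file_path=logical_path,
            original_filename=filename,
            dataset_id=dataset_id,
        )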
class TestStorageBackendInDocumentDownload:
"""Tests for storage backend usage in document download/serving."""
def test_get_document_url_returns_presigned_url(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document URL uses presigned URL from storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create a test file
doc_path = "documents/test-doc.pdf"
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
url = service.get_document_url(doc_path)
# Should return a URL (file:// for local, https:// for cloud)
assert url is not None
assert "test-doc.pdf" in url
def test_download_document_uses_storage_backend(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document download uses storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create a test file
doc_path = "documents/test-doc.pdf"
original_content = b"%PDF-1.4 test content"
backend.upload_bytes(original_content, doc_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
content = service.download_document(doc_path)
assert content == original_content
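# --- Hedged sketch (not the shipped implementation) ---------------------------
# Download side: URLs come from the backend (file:// for local storage, a
# presigned https:// URL for cloud), and raw bytes are read back through the
# same backend. `get_presigned_url` and `download_bytes` are assumed names for
# the StorageBackend interface.
class _DocumentDownloadSketch:
    def __init__(self, admin_db, storage_backend) -> None:
        self._db = admin_db
        self._storage = storage_backend

    def get_document_url(self, logical_path: str) -> str:
        return self._storage.get_presigned_url(logical_path)

    def download_document(self, logical_path: str) -> bytes:
        return self._storage.download_bytes(logical_path)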
class TestStorageBackendInImageServing:
"""Tests for storage backend usage in image serving."""
def test_get_page_image_url_returns_presigned_url(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that page image URL uses presigned URL."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create a test image
image_path = "images/doc-123/page_1.png"
backend.upload_bytes(b"fake png content", image_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
url = service.get_page_image_url("doc-123", 1)
assert url is not None
assert "page_1.png" in url
def test_save_page_image_uses_storage_backend(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that page image saving uses storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
image_content = b"fake png content"
service.save_page_image("doc-123", 1, image_content)
# Verify image was stored
assert backend.exists("images/doc-123/page_1.png")
class TestStorageBackendInDocumentDeletion:
"""Tests for storage backend usage in document deletion."""
def test_delete_document_removes_from_storage(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document deletion removes file from storage."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create test files
doc_path = "documents/test-doc.pdf"
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
service.delete_document_files(doc_path)
assert not backend.exists(doc_path)
def test_delete_document_removes_images(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document deletion removes associated images."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create test files
doc_id = "test-doc-123"
backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
service.delete_document_images(doc_id)
assert not backend.exists(f"images/{doc_id}/page_1.png")
assert not backend.exists(f"images/{doc_id}/page_2.png")
@pytest.fixture
def mock_admin_db() -> MagicMock:
"""Create a mock AdminDB for testing."""
mock = MagicMock()
mock.get_document.return_value = None
mock.create_document.return_value = MagicMock(
id="test-doc-id",
file_path="documents/test-doc-id.pdf",
)
return mock

View File

@@ -103,6 +103,31 @@ class MockAnnotation:
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockModelVersion:
"""Mock ModelVersion for testing."""
def __init__(self, **kwargs):
self.version_id = kwargs.get('version_id', uuid4())
self.version = kwargs.get('version', '1.0.0')
self.name = kwargs.get('name', 'Test Model')
self.description = kwargs.get('description', None)
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
self.status = kwargs.get('status', 'inactive')
self.is_active = kwargs.get('is_active', False)
self.task_id = kwargs.get('task_id', None)
self.dataset_id = kwargs.get('dataset_id', None)
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
self.document_count = kwargs.get('document_count', 100)
self.training_config = kwargs.get('training_config', {})
self.file_size = kwargs.get('file_size', 52428800)
self.trained_at = kwargs.get('trained_at', datetime.utcnow())
self.activated_at = kwargs.get('activated_at', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 4."""
@@ -111,6 +136,7 @@ class MockAdminDB:
self.annotations = {}
self.training_tasks = {}
self.training_links = {}
self.model_versions = {}
def get_documents_for_training(
self,
@@ -174,6 +200,14 @@ class MockAdminDB:
"""Get training task by ID."""
return self.training_tasks.get(str(task_id))
def get_model_versions(self, status=None, limit=20, offset=0):
"""Get model versions with optional filtering."""
models = list(self.model_versions.values())
if status:
models = [m for m in models if m.status == status]
total = len(models)
return models[offset:offset+limit], total
@pytest.fixture
def app():
@@ -241,6 +275,30 @@ def app():
)
mock_db.training_links[str(doc1.document_id)] = [link1]
# Add model versions
model1 = MockModelVersion(
version="1.0.0",
name="Model v1.0.0",
status="inactive",
is_active=False,
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
)
model2 = MockModelVersion(
version="1.1.0",
name="Model v1.1.0",
status="active",
is_active=True,
metrics_mAP=0.951,
metrics_precision=0.94,
metrics_recall=0.92,
document_count=600,
)
mock_db.model_versions[str(model1.version_id)] = model1
mock_db.model_versions[str(model2.version_id)] = model2
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
class TestTrainingModels:
"""Tests for GET /admin/training/models endpoint."""
"""Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
def test_get_training_models_success(self, client):
"""Test getting trained models list."""
"""Test getting model versions list."""
response = client.get("/admin/training/models")
assert response.status_code == 200
@@ -338,43 +396,44 @@ class TestTrainingModels:
assert len(data["models"]) == 2
def test_get_training_models_includes_metrics(self, client):
"""Test that models include metrics."""
"""Test that model versions include metrics."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check first model has metrics
# Check first model has metrics fields
model = data["models"][0]
assert "metrics" in model
assert "mAP" in model["metrics"]
assert model["metrics"]["mAP"] is not None
assert "precision" in model["metrics"]
assert "recall" in model["metrics"]
assert "metrics_mAP" in model
assert model["metrics_mAP"] is not None
def test_get_training_models_includes_download_url(self, client):
"""Test that completed models have download URLs."""
def test_get_training_models_includes_version_fields(self, client):
"""Test that model versions include version fields."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check completed models have download URLs
for model in data["models"]:
if model["status"] == "completed":
assert "download_url" in model
assert model["download_url"] is not None
# Check model has expected fields
model = data["models"][0]
assert "version_id" in model
assert "version" in model
assert "name" in model
assert "status" in model
assert "is_active" in model
assert "document_count" in model
def test_get_training_models_filter_by_status(self, client):
"""Test filtering models by status."""
response = client.get("/admin/training/models?status=completed")
"""Test filtering model versions by status."""
response = client.get("/admin/training/models?status=active")
assert response.status_code == 200
data = response.json()
# All returned models should be completed
assert data["total"] == 1
# All returned models should be active
for model in data["models"]:
assert model["status"] == "completed"
assert model["status"] == "active"
def test_get_training_models_pagination(self, client):
"""Test pagination for models."""
"""Test pagination for model versions."""
response = client.get("/admin/training/models?limit=1&offset=0")
assert response.status_code == 200