WIP
This commit is contained in:
@@ -9,7 +9,8 @@ from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from shared.fields import FIELD_CLASSES
|
||||
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
||||
from inference.web.schemas.admin import (
|
||||
AnnotationCreate,
|
||||
|
||||
@@ -31,6 +31,7 @@ class MockAdminDocument:
|
||||
self.batch_id = kwargs.get('batch_id', None)
|
||||
self.csv_field_values = kwargs.get('csv_field_values', None)
|
||||
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
|
||||
self.category = kwargs.get('category', 'invoice')
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
@@ -67,12 +68,13 @@ class MockAdminDB:
|
||||
|
||||
def get_documents_by_token(
|
||||
self,
|
||||
admin_token,
|
||||
admin_token=None,
|
||||
status=None,
|
||||
upload_source=None,
|
||||
has_annotations=None,
|
||||
auto_label_status=None,
|
||||
batch_id=None,
|
||||
category=None,
|
||||
limit=20,
|
||||
offset=0
|
||||
):
|
||||
@@ -95,6 +97,8 @@ class MockAdminDB:
|
||||
docs = [d for d in docs if d.auto_label_status == auto_label_status]
|
||||
if batch_id:
|
||||
docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
|
||||
if category:
|
||||
docs = [d for d in docs if d.category == category]
|
||||
|
||||
total = len(docs)
|
||||
return docs[offset:offset+limit], total
|
||||
|
||||
@@ -215,8 +215,10 @@ class TestAsyncProcessingService:
|
||||
|
||||
def test_cleanup_orphan_files(self, async_service, mock_db):
|
||||
"""Test cleanup of orphan files."""
|
||||
# Create an orphan file
|
||||
# Create the async upload directory
|
||||
temp_dir = async_service._async_config.temp_upload_dir
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
orphan_file = temp_dir / "orphan-request.pdf"
|
||||
orphan_file.write_bytes(b"orphan content")
|
||||
|
||||
@@ -228,7 +230,13 @@ class TestAsyncProcessingService:
|
||||
# Mock database to say file doesn't exist
|
||||
mock_db.get_request.return_value = None
|
||||
|
||||
count = async_service._cleanup_orphan_files()
|
||||
# Mock the storage helper to return the same directory as the fixture
|
||||
with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
|
||||
mock_helper = MagicMock()
|
||||
mock_helper.get_uploads_base_path.return_value = temp_dir
|
||||
mock_storage.return_value = mock_helper
|
||||
|
||||
count = async_service._cleanup_orphan_files()
|
||||
|
||||
assert count == 1
|
||||
assert not orphan_file.exists()
|
||||
|
||||
@@ -5,7 +5,75 @@ TDD Phase 5: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
import numpy as np
|
||||
|
||||
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
|
||||
|
||||
TEST_ADMIN_TOKEN = "test-admin-token-12345"
|
||||
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
|
||||
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_token() -> str:
|
||||
"""Provide admin token for testing."""
|
||||
return TEST_ADMIN_TOKEN
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db() -> MagicMock:
|
||||
"""Create a mock AdminDB for testing."""
|
||||
mock = MagicMock()
|
||||
# Default return values
|
||||
mock.get_document_by_token.return_value = None
|
||||
mock.get_dataset.return_value = None
|
||||
mock.get_augmented_datasets.return_value = ([], 0)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
"""Create test client with admin authentication."""
|
||||
app = FastAPI()
|
||||
|
||||
# Override dependencies
|
||||
def get_token_override():
|
||||
return TEST_ADMIN_TOKEN
|
||||
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
|
||||
app.dependency_overrides[validate_admin_token] = get_token_override
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
|
||||
# Include router - the router already has /augmentation prefix
|
||||
# so we add /api/v1/admin to get /api/v1/admin/augmentation
|
||||
router = create_augmentation_router()
|
||||
app.include_router(router, prefix="/api/v1/admin")
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
"""Create test client WITHOUT admin authentication override."""
|
||||
app = FastAPI()
|
||||
|
||||
# Only override the database, NOT the token validation
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
|
||||
router = create_augmentation_router()
|
||||
app.include_router(router, prefix="/api/v1/admin")
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestAugmentationTypesEndpoint:
|
||||
@@ -34,10 +102,10 @@ class TestAugmentationTypesEndpoint:
|
||||
assert "stage" in aug_type
|
||||
|
||||
def test_list_augmentation_types_unauthorized(
|
||||
self, admin_client: TestClient
|
||||
self, unauthenticated_client: TestClient
|
||||
) -> None:
|
||||
"""Test that unauthorized request is rejected."""
|
||||
response = admin_client.get("/api/v1/admin/augmentation/types")
|
||||
response = unauthenticated_client.get("/api/v1/admin/augmentation/types")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -74,16 +142,30 @@ class TestAugmentationPreviewEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing augmentation on a document."""
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"augmentation_type": "gaussian_noise",
|
||||
"params": {"std": 15},
|
||||
},
|
||||
)
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
) as mock_load:
|
||||
mock_load.return_value = fake_image
|
||||
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"augmentation_type": "gaussian_noise",
|
||||
"params": {"std": 15},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -136,18 +218,32 @@ class TestAugmentationPreviewConfigEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing full config on a document."""
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"gaussian_noise": {"enabled": True, "probability": 1.0},
|
||||
"lighting_variation": {"enabled": True, "probability": 1.0},
|
||||
"preserve_bboxes": True,
|
||||
"seed": 42,
|
||||
},
|
||||
)
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
) as mock_load:
|
||||
mock_load.return_value = fake_image
|
||||
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"gaussian_noise": {"enabled": True, "probability": 1.0},
|
||||
"lighting_variation": {"enabled": True, "probability": 1.0},
|
||||
"preserve_bboxes": True,
|
||||
"seed": 42,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -164,8 +260,14 @@ class TestAugmentationBatchEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_dataset_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test creating augmented dataset."""
|
||||
# Mock dataset exists
|
||||
mock_dataset = MagicMock()
|
||||
mock_dataset.total_images = 100
|
||||
mock_admin_db.get_dataset.return_value = mock_dataset
|
||||
|
||||
response = admin_client.post(
|
||||
"/api/v1/admin/augmentation/batch",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
@@ -250,12 +352,10 @@ class TestAugmentedDatasetsListEndpoint:
|
||||
@pytest.fixture
|
||||
def sample_document_id() -> str:
|
||||
"""Provide a sample document ID for testing."""
|
||||
# This would need to be created in test setup
|
||||
return "test-document-id"
|
||||
return TEST_DOCUMENT_UUID
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset_id() -> str:
|
||||
"""Provide a sample dataset ID for testing."""
|
||||
# This would need to be created in test setup
|
||||
return "test-dataset-id"
|
||||
return TEST_DATASET_UUID
|
||||
|
||||
@@ -35,6 +35,8 @@ def _make_dataset(**overrides) -> MagicMock:
|
||||
name="test-dataset",
|
||||
description="Test dataset",
|
||||
status="ready",
|
||||
training_status=None,
|
||||
active_training_task_id=None,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
@@ -183,6 +185,8 @@ class TestListDatasetsRoute:
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
|
||||
# Mock the active training tasks lookup to return empty dict
|
||||
mock_db.get_active_training_tasks_for_datasets.return_value = {}
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||
|
||||
|
||||
363
tests/web/test_dataset_training_status.py
Normal file
363
tests/web/test_dataset_training_status.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Tests for dataset training status feature.
|
||||
|
||||
Tests cover:
|
||||
1. Database model fields (training_status, active_training_task_id)
|
||||
2. AdminDB update_dataset_training_status method
|
||||
3. API response includes training status fields
|
||||
4. Scheduler updates dataset status during training lifecycle
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Database Model
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTrainingDatasetModel:
|
||||
"""Tests for TrainingDataset model fields."""
|
||||
|
||||
def test_training_dataset_has_training_status_field(self):
|
||||
"""TrainingDataset model should have training_status field."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(
|
||||
name="test-dataset",
|
||||
training_status="running",
|
||||
)
|
||||
assert dataset.training_status == "running"
|
||||
|
||||
def test_training_dataset_has_active_training_task_id_field(self):
|
||||
"""TrainingDataset model should have active_training_task_id field."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
task_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
name="test-dataset",
|
||||
active_training_task_id=task_id,
|
||||
)
|
||||
assert dataset.active_training_task_id == task_id
|
||||
|
||||
def test_training_dataset_defaults(self):
|
||||
"""TrainingDataset should have correct defaults for new fields."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test-dataset")
|
||||
assert dataset.training_status is None
|
||||
assert dataset.active_training_task_id is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test AdminDB Methods
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdminDBDatasetTrainingStatus:
|
||||
"""Tests for AdminDB.update_dataset_training_status method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
def test_update_dataset_training_status_sets_status(self, mock_session):
|
||||
"""update_dataset_training_status should set training_status."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="running",
|
||||
)
|
||||
|
||||
assert dataset.training_status == "running"
|
||||
mock_session.add.assert_called_once_with(dataset)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_dataset_training_status_sets_task_id(self, mock_session):
|
||||
"""update_dataset_training_status should set active_training_task_id."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
task_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="running",
|
||||
active_training_task_id=str(task_id),
|
||||
)
|
||||
|
||||
assert dataset.active_training_task_id == task_id
|
||||
|
||||
def test_update_dataset_training_status_updates_main_status_on_complete(
|
||||
self, mock_session
|
||||
):
|
||||
"""update_dataset_training_status should update main status to 'trained' when completed."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="completed",
|
||||
update_main_status=True,
|
||||
)
|
||||
|
||||
assert dataset.status == "trained"
|
||||
assert dataset.training_status == "completed"
|
||||
|
||||
def test_update_dataset_training_status_clears_task_id_on_complete(
|
||||
self, mock_session
|
||||
):
|
||||
"""update_dataset_training_status should clear task_id when training completes."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
task_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=dataset_id,
|
||||
name="test-dataset",
|
||||
status="ready",
|
||||
training_status="running",
|
||||
active_training_task_id=task_id,
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="completed",
|
||||
active_training_task_id=None,
|
||||
)
|
||||
|
||||
assert dataset.active_training_task_id is None
|
||||
|
||||
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
|
||||
"""update_dataset_training_status should handle missing dataset gracefully."""
|
||||
mock_session.get.return_value = None
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
db = AdminDB()
|
||||
# Should not raise
|
||||
db.update_dataset_training_status(
|
||||
dataset_id=str(uuid4()),
|
||||
training_status="running",
|
||||
)
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test API Response
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDatasetDetailResponseTrainingStatus:
|
||||
"""Tests for DatasetDetailResponse including training status fields."""
|
||||
|
||||
def test_dataset_detail_response_includes_training_status(self):
|
||||
"""DatasetDetailResponse schema should include training_status field."""
|
||||
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
|
||||
response = DatasetDetailResponse(
|
||||
dataset_id=str(uuid4()),
|
||||
name="test-dataset",
|
||||
description=None,
|
||||
status="ready",
|
||||
training_status="running",
|
||||
active_training_task_id=str(uuid4()),
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=10,
|
||||
total_images=15,
|
||||
total_annotations=100,
|
||||
dataset_path="/path/to/dataset",
|
||||
error_message=None,
|
||||
documents=[],
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert response.training_status == "running"
|
||||
assert response.active_training_task_id is not None
|
||||
|
||||
def test_dataset_detail_response_allows_null_training_status(self):
|
||||
"""DatasetDetailResponse should allow null training_status."""
|
||||
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
|
||||
response = DatasetDetailResponse(
|
||||
dataset_id=str(uuid4()),
|
||||
name="test-dataset",
|
||||
description=None,
|
||||
status="ready",
|
||||
training_status=None,
|
||||
active_training_task_id=None,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=10,
|
||||
total_images=15,
|
||||
total_annotations=100,
|
||||
dataset_path=None,
|
||||
error_message=None,
|
||||
documents=[],
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert response.training_status is None
|
||||
assert response.active_training_task_id is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Scheduler Training Status Updates
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSchedulerDatasetStatusUpdates:
|
||||
"""Tests for scheduler updating dataset status during training."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(self):
|
||||
"""Create mock AdminDB."""
|
||||
mock = MagicMock()
|
||||
mock.get_dataset.return_value = MagicMock(
|
||||
dataset_id=uuid4(),
|
||||
name="test-dataset",
|
||||
dataset_path="/path/to/dataset",
|
||||
total_images=100,
|
||||
)
|
||||
mock.get_pending_training_tasks.return_value = []
|
||||
return mock
|
||||
|
||||
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
|
||||
"""Scheduler should set dataset training_status to 'running' when task starts."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
|
||||
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
|
||||
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
scheduler._db = mock_db
|
||||
|
||||
task_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Execute task (will fail but we check the status update call)
|
||||
try:
|
||||
scheduler._execute_task(
|
||||
task_id=task_id,
|
||||
config={"model_name": "yolo11n.pt"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
except Exception:
|
||||
pass # Expected to fail in test environment
|
||||
|
||||
# Check that training status was updated to running
|
||||
mock_db.update_dataset_training_status.assert_called()
|
||||
first_call = mock_db.update_dataset_training_status.call_args_list[0]
|
||||
assert first_call.kwargs["training_status"] == "running"
|
||||
assert first_call.kwargs["active_training_task_id"] == task_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Dataset Status Values
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDatasetStatusValues:
|
||||
"""Tests for valid dataset status values."""
|
||||
|
||||
def test_dataset_status_building(self):
|
||||
"""Dataset can have status 'building'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="building")
|
||||
assert dataset.status == "building"
|
||||
|
||||
def test_dataset_status_ready(self):
|
||||
"""Dataset can have status 'ready'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="ready")
|
||||
assert dataset.status == "ready"
|
||||
|
||||
def test_dataset_status_trained(self):
|
||||
"""Dataset can have status 'trained'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="trained")
|
||||
assert dataset.status == "trained"
|
||||
|
||||
def test_dataset_status_failed(self):
|
||||
"""Dataset can have status 'failed'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="failed")
|
||||
assert dataset.status == "failed"
|
||||
|
||||
def test_training_status_values(self):
|
||||
"""Training status can have various values."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
|
||||
for status in valid_statuses:
|
||||
dataset = TrainingDataset(name="test", training_status=status)
|
||||
assert dataset.training_status == status
|
||||
207
tests/web/test_document_category.py
Normal file
207
tests/web/test_document_category.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Tests for Document Category Feature.
|
||||
|
||||
TDD tests for adding category field to admin_documents table.
|
||||
Documents can be categorized (e.g., invoice, letter, receipt) for training different models.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from inference.data.admin_models import AdminDocument
|
||||
|
||||
|
||||
# Test constants
|
||||
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
class TestAdminDocumentCategoryField:
|
||||
"""Tests for AdminDocument category field."""
|
||||
|
||||
def test_document_has_category_field(self):
|
||||
"""Test AdminDocument model has category field."""
|
||||
doc = AdminDocument(
|
||||
document_id=UUID(TEST_DOC_UUID),
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/path/to/file.pdf",
|
||||
)
|
||||
assert hasattr(doc, "category")
|
||||
|
||||
def test_document_category_defaults_to_invoice(self):
|
||||
"""Test category defaults to 'invoice' when not specified."""
|
||||
doc = AdminDocument(
|
||||
document_id=UUID(TEST_DOC_UUID),
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/path/to/file.pdf",
|
||||
)
|
||||
assert doc.category == "invoice"
|
||||
|
||||
def test_document_accepts_custom_category(self):
|
||||
"""Test document accepts custom category values."""
|
||||
categories = ["invoice", "letter", "receipt", "contract", "custom_type"]
|
||||
|
||||
for cat in categories:
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/path/to/file.pdf",
|
||||
category=cat,
|
||||
)
|
||||
assert doc.category == cat
|
||||
|
||||
def test_document_category_is_string_type(self):
|
||||
"""Test category field is a string type."""
|
||||
doc = AdminDocument(
|
||||
document_id=UUID(TEST_DOC_UUID),
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/path/to/file.pdf",
|
||||
category="letter",
|
||||
)
|
||||
assert isinstance(doc.category, str)
|
||||
|
||||
|
||||
class TestDocumentCategoryInReadModel:
|
||||
"""Tests for category in response models."""
|
||||
|
||||
def test_admin_document_read_has_category(self):
|
||||
"""Test AdminDocumentRead includes category field."""
|
||||
from inference.data.admin_models import AdminDocumentRead
|
||||
|
||||
# Check the model has category field in its schema
|
||||
assert "category" in AdminDocumentRead.model_fields
|
||||
|
||||
|
||||
class TestDocumentCategoryAPI:
|
||||
"""Tests for document category in API endpoints."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db(self):
|
||||
"""Create mock AdminDB."""
|
||||
db = MagicMock()
|
||||
db.is_valid_admin_token.return_value = True
|
||||
return db
|
||||
|
||||
def test_upload_document_with_category(self, mock_admin_db):
|
||||
"""Test uploading document with category parameter."""
|
||||
from inference.web.schemas.admin import DocumentUploadResponse
|
||||
|
||||
# Verify response schema supports category
|
||||
response = DocumentUploadResponse(
|
||||
document_id=TEST_DOC_UUID,
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
page_count=1,
|
||||
status="pending",
|
||||
message="Upload successful",
|
||||
category="letter",
|
||||
)
|
||||
assert response.category == "letter"
|
||||
|
||||
def test_list_documents_returns_category(self, mock_admin_db):
|
||||
"""Test list documents endpoint returns category."""
|
||||
from inference.web.schemas.admin import DocumentItem
|
||||
|
||||
item = DocumentItem(
|
||||
document_id=TEST_DOC_UUID,
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
page_count=1,
|
||||
status="pending",
|
||||
annotation_count=0,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
category="invoice",
|
||||
)
|
||||
assert item.category == "invoice"
|
||||
|
||||
def test_document_detail_includes_category(self, mock_admin_db):
|
||||
"""Test document detail response includes category."""
|
||||
from inference.web.schemas.admin import DocumentDetailResponse
|
||||
|
||||
# Check schema has category
|
||||
assert "category" in DocumentDetailResponse.model_fields
|
||||
|
||||
|
||||
class TestDocumentCategoryFiltering:
|
||||
"""Tests for filtering documents by category."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db(self):
|
||||
"""Create mock AdminDB with category filtering support."""
|
||||
db = MagicMock()
|
||||
db.is_valid_admin_token.return_value = True
|
||||
|
||||
# Mock documents with different categories
|
||||
invoice_doc = MagicMock()
|
||||
invoice_doc.document_id = uuid4()
|
||||
invoice_doc.category = "invoice"
|
||||
|
||||
letter_doc = MagicMock()
|
||||
letter_doc.document_id = uuid4()
|
||||
letter_doc.category = "letter"
|
||||
|
||||
db.get_documents_by_category.return_value = [invoice_doc]
|
||||
return db
|
||||
|
||||
def test_filter_documents_by_category(self, mock_admin_db):
|
||||
"""Test filtering documents by category."""
|
||||
# This tests the DB method signature
|
||||
result = mock_admin_db.get_documents_by_category("invoice")
|
||||
assert len(result) == 1
|
||||
assert result[0].category == "invoice"
|
||||
|
||||
|
||||
class TestDocumentCategoryUpdate:
|
||||
"""Tests for updating document category."""
|
||||
|
||||
def test_update_document_category_schema(self):
|
||||
"""Test update document request supports category."""
|
||||
from inference.web.schemas.admin import DocumentUpdateRequest
|
||||
|
||||
request = DocumentUpdateRequest(category="letter")
|
||||
assert request.category == "letter"
|
||||
|
||||
def test_update_document_category_optional(self):
|
||||
"""Test category is optional in update request."""
|
||||
from inference.web.schemas.admin import DocumentUpdateRequest
|
||||
|
||||
# Should not raise - category is optional
|
||||
request = DocumentUpdateRequest()
|
||||
assert request.category is None
|
||||
|
||||
|
||||
class TestDatasetWithCategory:
|
||||
"""Tests for dataset creation with category filtering."""
|
||||
|
||||
def test_dataset_create_with_category_filter(self):
|
||||
"""Test creating dataset can filter by document category."""
|
||||
from inference.web.schemas.admin import DatasetCreateRequest
|
||||
|
||||
request = DatasetCreateRequest(
|
||||
name="Invoice Training Set",
|
||||
document_ids=[TEST_DOC_UUID],
|
||||
category="invoice", # Optional filter
|
||||
)
|
||||
assert request.category == "invoice"
|
||||
|
||||
def test_dataset_create_category_is_optional(self):
|
||||
"""Test category filter is optional when creating dataset."""
|
||||
from inference.web.schemas.admin import DatasetCreateRequest
|
||||
|
||||
request = DatasetCreateRequest(
|
||||
name="Mixed Training Set",
|
||||
document_ids=[TEST_DOC_UUID],
|
||||
)
|
||||
# category should be optional
|
||||
assert not hasattr(request, "category") or request.category is None
|
||||
165
tests/web/test_document_category_api.py
Normal file
165
tests/web/test_document_category_api.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Tests for Document Category API Endpoints.
|
||||
|
||||
TDD tests for category filtering and management in document endpoints.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# Test constants
|
||||
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
class TestGetCategoriesEndpoint:
    """Tests for GET /admin/documents/categories endpoint."""

    def test_categories_endpoint_returns_list(self):
        """The response schema carries a category list plus a count."""
        from inference.web.schemas.admin import DocumentCategoriesResponse

        expected = ["invoice", "letter", "receipt"]
        resp = DocumentCategoriesResponse(categories=expected, total=3)

        assert resp.total == 3
        assert resp.categories == expected

    def test_categories_response_schema(self):
        """Both declared fields exist on DocumentCategoriesResponse."""
        from inference.web.schemas.admin import DocumentCategoriesResponse

        declared = DocumentCategoriesResponse.model_fields
        assert "categories" in declared
        assert "total" in declared
|
||||
|
||||
|
||||
class TestDocumentListFilterByCategory:
    """Tests for filtering documents by category."""

    @pytest.fixture
    def mock_admin_db(self):
        """Build a MagicMock AdminDB seeded with docs of two categories."""
        def _doc(category, filename):
            doc = MagicMock()
            doc.document_id = uuid4()
            doc.category = category
            doc.filename = filename
            return doc

        invoice_doc = _doc("invoice", "invoice1.pdf")
        letter_doc = _doc("letter", "letter1.pdf")  # excluded by the filter

        db = MagicMock()
        db.is_valid_admin_token.return_value = True
        db.get_documents.return_value = ([invoice_doc], 1)
        db.get_document_categories.return_value = ["invoice", "letter", "receipt"]
        return db

    def test_list_documents_accepts_category_filter(self, mock_admin_db):
        """The list endpoint's response schema is importable with the filter in play."""
        from inference.web.schemas.admin import DocumentListResponse

        # Presence of the schema is the contract exercised here.
        assert DocumentListResponse is not None

    def test_get_document_categories_from_db(self, mock_admin_db):
        """Unique categories can be fetched from the database layer."""
        found = mock_admin_db.get_document_categories()

        assert len(found) == 3
        assert "invoice" in found
        assert "letter" in found
|
||||
|
||||
|
||||
class TestDocumentUploadWithCategory:
    """Tests for uploading documents with category."""

    def test_upload_request_accepts_category(self):
        """Category arrives as multipart form data, so there is no schema to exercise."""
        pass

    def _upload_response(self, **overrides):
        """Construct a DocumentUploadResponse with shared boilerplate fields."""
        from inference.web.schemas.admin import DocumentUploadResponse

        payload = {
            "document_id": TEST_DOC_UUID,
            "filename": "test.pdf",
            "file_size": 1024,
            "page_count": 1,
            "status": "pending",
            "message": "Upload successful",
        }
        payload.update(overrides)
        return DocumentUploadResponse(**payload)

    def test_upload_response_includes_category(self):
        """An explicit category round-trips through the response schema."""
        assert self._upload_response(category="letter").category == "letter"

    def test_upload_defaults_to_invoice_category(self):
        """With no category given, the schema falls back to 'invoice'."""
        assert self._upload_response().category == "invoice"
|
||||
|
||||
|
||||
class TestAdminDBCategoryMethods:
    """Tests for AdminDB category-related methods."""

    def test_get_document_categories_method_exists(self):
        """AdminDB exposes get_document_categories."""
        from inference.data.admin_db import AdminDB

        assert hasattr(AdminDB(), "get_document_categories")

    def test_get_documents_accepts_category_filter(self):
        """get_documents_by_token accepts a category keyword."""
        import inspect

        from inference.data.admin_db import AdminDB

        method = getattr(AdminDB(), "get_documents_by_token", None)
        assert callable(method)
        # The category filter must appear in the method signature.
        assert "category" in inspect.signature(method).parameters
|
||||
|
||||
|
||||
class TestUpdateDocumentCategory:
    """Tests for updating document category."""

    def test_update_document_category_method_exists(self):
        """AdminDB exposes update_document_category."""
        from inference.data.admin_db import AdminDB

        assert hasattr(AdminDB(), "update_document_category")

    def test_update_request_schema(self):
        """DocumentUpdateRequest carries a category change."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        assert DocumentUpdateRequest(category="receipt").category == "receipt"
|
||||
@@ -32,10 +32,10 @@ def test_app(tmp_path):
|
||||
use_gpu=False,
|
||||
dpi=150,
|
||||
),
|
||||
storage=StorageConfig(
|
||||
file=StorageConfig(
|
||||
upload_dir=upload_dir,
|
||||
result_dir=result_dir,
|
||||
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
|
||||
allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
|
||||
max_file_size_mb=50,
|
||||
),
|
||||
)
|
||||
@@ -252,20 +252,25 @@ class TestResultsEndpoint:
|
||||
response = client.get("/api/v1/results/nonexistent.png")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
|
||||
def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
|
||||
"""Test that existing result file is returned."""
|
||||
# Get storage config from app
|
||||
storage_config = test_app.extra.get("storage_config")
|
||||
if not storage_config:
|
||||
pytest.skip("Storage config not available in test app")
|
||||
|
||||
# Create a test result file
|
||||
result_file = storage_config.result_dir / "test_result.png"
|
||||
# Create a test result file in temp directory
|
||||
result_dir = tmp_path / "results"
|
||||
result_dir.mkdir(exist_ok=True)
|
||||
result_file = result_dir / "test_result.png"
|
||||
img = Image.new('RGB', (100, 100), color='red')
|
||||
img.save(result_file)
|
||||
|
||||
# Request the file
|
||||
response = client.get("/api/v1/results/test_result.png")
|
||||
# Mock the storage helper to return our test file path
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper"
|
||||
) as mock_storage:
|
||||
mock_helper = Mock()
|
||||
mock_helper.get_result_local_path.return_value = result_file
|
||||
mock_storage.return_value = mock_helper
|
||||
|
||||
# Request the file
|
||||
response = client.get("/api/v1/results/test_result.png")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "image/png"
|
||||
|
||||
@@ -266,7 +266,11 @@ class TestActivateModelVersionRoute:
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
# Create mock request with app state
|
||||
mock_request = MagicMock()
|
||||
mock_request.app.state.inference_service = None
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
|
||||
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
|
||||
assert result.status == "active"
|
||||
@@ -278,10 +282,14 @@ class TestActivateModelVersionRoute:
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = None
|
||||
|
||||
# Create mock request with app state
|
||||
mock_request = MagicMock()
|
||||
mock_request.app.state.inference_service = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
|
||||
828
tests/web/test_storage_helpers.py
Normal file
828
tests/web/test_storage_helpers.py
Normal file
@@ -0,0 +1,828 @@
|
||||
"""Tests for storage helpers module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
|
||||
from shared.storage import PREFIXES
|
||||
|
||||
|
||||
@pytest.fixture
def mock_storage() -> MagicMock:
    """Return a MagicMock storage backend with canned return values."""
    backend = MagicMock()
    canned = {
        "upload_bytes": MagicMock(),
        "download_bytes": MagicMock(return_value=b"test content"),
        "get_presigned_url": MagicMock(return_value="https://example.com/file"),
        "exists": MagicMock(return_value=True),
        "delete": MagicMock(return_value=True),
        "list_files": MagicMock(return_value=[]),
    }
    for attr, stub in canned.items():
        setattr(backend, attr, stub)
    return backend
|
||||
|
||||
|
||||
@pytest.fixture
def helper(mock_storage: MagicMock) -> StorageHelper:
    """Wrap the mock backend in the StorageHelper under test."""
    return StorageHelper(storage=mock_storage)
|
||||
|
||||
|
||||
class TestStorageHelperInit:
    """StorageHelper construction behaviour."""

    def test_init_with_storage(self, mock_storage: MagicMock) -> None:
        """The injected backend is used verbatim."""
        assert StorageHelper(storage=mock_storage).storage is mock_storage

    def test_storage_property(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """The storage property hands back the same backend object."""
        assert helper.storage is mock_storage
|
||||
|
||||
|
||||
class TestDocumentOperations:
    """Document-level storage round trips."""

    def test_upload_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Uploading routes the bytes to documents/<id>.pdf."""
        returned_id, key = helper.upload_document(b"pdf content", "invoice.pdf", "doc123")

        assert (returned_id, key) == ("doc123", "documents/doc123.pdf")
        mock_storage.upload_bytes.assert_called_once_with(
            b"pdf content", "documents/doc123.pdf", overwrite=True
        )

    def test_upload_document_generates_id(self, helper: StorageHelper) -> None:
        """A missing document ID is generated by the helper."""
        generated_id, key = helper.upload_document(b"content", "file.pdf")

        assert generated_id is not None and len(generated_id) > 0
        assert key.startswith("documents/")

    def test_download_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Downloading reads documents/<id>.pdf."""
        assert helper.download_document("doc123") == b"test content"
        mock_storage.download_bytes.assert_called_once_with("documents/doc123.pdf")

    def test_get_document_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """A presigned URL honours the expiry argument."""
        url = helper.get_document_url("doc123", expires_in_seconds=7200)

        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with("documents/doc123.pdf", 7200)

    def test_document_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Existence checks hit the backend with the document key."""
        assert helper.document_exists("doc123") is True
        mock_storage.exists.assert_called_once_with("documents/doc123.pdf")

    def test_delete_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Deletion targets the document key."""
        assert helper.delete_document("doc123") is True
        mock_storage.delete.assert_called_once_with("documents/doc123.pdf")
|
||||
|
||||
|
||||
class TestImageOperations:
    """Page-image storage round trips."""

    def test_save_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Saving a page routes bytes to images/<doc>/page_<n>.png."""
        key = helper.save_page_image("doc123", 1, b"image data")

        assert key == "images/doc123/page_1.png"
        mock_storage.upload_bytes.assert_called_once_with(
            b"image data", "images/doc123/page_1.png", overwrite=True
        )

    def test_get_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Fetching a page reads the matching key."""
        assert helper.get_page_image("doc123", 2) == b"test content"
        mock_storage.download_bytes.assert_called_once_with("images/doc123/page_2.png")

    def test_get_page_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Presigned URLs default to a 3600-second expiry."""
        assert helper.get_page_image_url("doc123", 3) == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "images/doc123/page_3.png", 3600
        )

    def test_delete_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Deleting removes every listed image under the document prefix."""
        mock_storage.list_files.return_value = [
            "images/doc123/page_1.png",
            "images/doc123/page_2.png",
        ]

        assert helper.delete_document_images("doc123") == 2
        mock_storage.list_files.assert_called_once_with("images/doc123/")

    def test_list_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Listing returns whatever the backend reports under the prefix."""
        listing = ["images/doc123/page_1.png"]
        mock_storage.list_files.return_value = listing

        assert helper.list_document_images("doc123") == listing
|
||||
|
||||
|
||||
class TestUploadOperations:
    """Tests for upload staging operations."""

    def test_save_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save upload to correct path."""
        path = helper.save_upload(b"content", "file.pdf")

        assert path == "uploads/file.pdf"
        mock_storage.upload_bytes.assert_called_once()

    def test_save_upload_with_subfolder(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save upload with subfolder."""
        path = helper.save_upload(b"content", "file.pdf", "async")

        assert path == "uploads/async/file.pdf"

    def test_get_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get upload from correct path and return its bytes."""
        content = helper.get_upload("file.pdf", "async")

        # Fix: the fetched bytes were previously captured but never asserted.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("uploads/async/file.pdf")

    def test_delete_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should delete upload."""
        result = helper.delete_upload("file.pdf")

        assert result is True
        mock_storage.delete.assert_called_once_with("uploads/file.pdf")
|
||||
|
||||
|
||||
class TestResultOperations:
    """Tests for result file operations."""

    def test_save_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save result to correct path."""
        path = helper.save_result(b"result data", "output.json")

        assert path == "results/output.json"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get result from correct path and return its bytes."""
        content = helper.get_result("output.json")

        # Fix: returned bytes were previously captured but never asserted.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("results/output.json")

    def test_get_result_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get presigned URL for result."""
        url = helper.get_result_url("output.json")

        # Fix: the URL value was previously captured but never asserted.
        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with("results/output.json", 3600)

    def test_result_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should check result existence."""
        exists = helper.result_exists("output.json")

        assert exists is True
        mock_storage.exists.assert_called_once_with("results/output.json")
|
||||
|
||||
|
||||
class TestExportOperations:
    """Tests for export file operations."""

    def test_save_export(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save export to correct path."""
        path = helper.save_export(b"export data", "exp123", "dataset.zip")

        assert path == "exports/exp123/dataset.zip"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_export_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get presigned URL for export."""
        url = helper.get_export_url("exp123", "dataset.zip")

        # Fix: the URL value was previously captured but never asserted.
        assert url == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "exports/exp123/dataset.zip", 3600
        )
|
||||
|
||||
|
||||
class TestRawPdfOperations:
    """Tests for raw PDF operations (legacy compatibility)."""

    def test_save_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should save raw PDF to correct path."""
        path = helper.save_raw_pdf(b"pdf data", "invoice.pdf")

        assert path == "raw_pdfs/invoice.pdf"
        mock_storage.upload_bytes.assert_called_once()

    def test_get_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should get raw PDF from correct path and return its bytes."""
        content = helper.get_raw_pdf("invoice.pdf")

        # Fix: returned bytes were previously captured but never asserted.
        assert content == b"test content"
        mock_storage.download_bytes.assert_called_once_with("raw_pdfs/invoice.pdf")

    def test_raw_pdf_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Should check raw PDF existence."""
        exists = helper.raw_pdf_exists("invoice.pdf")

        assert exists is True
        mock_storage.exists.assert_called_once_with("raw_pdfs/invoice.pdf")
|
||||
|
||||
|
||||
class TestAdminImageOperations:
    """Admin-image storage round trips."""

    def test_save_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Saving routes bytes to admin_images/<doc>/page_<n>.png."""
        key = helper.save_admin_image("doc123", 1, b"image data")

        assert key == "admin_images/doc123/page_1.png"
        mock_storage.upload_bytes.assert_called_once_with(
            b"image data", "admin_images/doc123/page_1.png", overwrite=True
        )

    def test_get_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Fetching reads the matching admin-image key."""
        assert helper.get_admin_image("doc123", 2) == b"test content"
        mock_storage.download_bytes.assert_called_once_with("admin_images/doc123/page_2.png")

    def test_get_admin_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Presigned URLs default to a 3600-second expiry."""
        assert helper.get_admin_image_url("doc123", 3) == "https://example.com/file"
        mock_storage.get_presigned_url.assert_called_once_with(
            "admin_images/doc123/page_3.png", 3600
        )

    def test_admin_image_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Existence checks hit the admin-image key."""
        assert helper.admin_image_exists("doc123", 1) is True
        mock_storage.exists.assert_called_once_with("admin_images/doc123/page_1.png")

    def test_get_admin_image_path(self, helper: StorageHelper) -> None:
        """Path construction is deterministic from doc ID and page number."""
        assert helper.get_admin_image_path("doc123", 2) == "admin_images/doc123/page_2.png"

    def test_list_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Listing returns every image under the document prefix."""
        listing = [
            "admin_images/doc123/page_1.png",
            "admin_images/doc123/page_2.png",
        ]
        mock_storage.list_files.return_value = listing

        assert helper.list_admin_images("doc123") == listing
        mock_storage.list_files.assert_called_once_with("admin_images/doc123/")

    def test_delete_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Deletion removes every listed image and reports the count."""
        mock_storage.list_files.return_value = [
            "admin_images/doc123/page_1.png",
            "admin_images/doc123/page_2.png",
        ]

        assert helper.delete_admin_images("doc123") == 2
        mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
|
||||
|
||||
|
||||
class TestGetLocalPath:
    """Tests for local-path resolution of admin images."""

    def test_get_admin_image_local_path_with_local_storage(self) -> None:
        """A local backend yields a real filesystem path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            # Seed a page image where the helper expects it.
            image_dir = Path(root) / "admin_images" / "doc123"
            image_dir.mkdir(parents=True, exist_ok=True)
            (image_dir / "page_1.png").write_bytes(b"test image")

            found = subject.get_admin_image_local_path("doc123", 1)

            assert found is not None
            assert found.exists() and found.name == "page_1.png"

    def test_get_admin_image_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """A backend without local-path support yields None."""
        # Simulate a cloud backend whose get_local_path resolves nothing.
        mock_storage.get_local_path = MagicMock(return_value=None)
        subject = StorageHelper(storage=mock_storage)

        assert subject.get_admin_image_local_path("doc123", 1) is None

    def test_get_admin_image_local_path_nonexistent_file(self) -> None:
        """A missing file yields None even on local storage."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            assert subject.get_admin_image_local_path("nonexistent", 1) is None
|
||||
|
||||
|
||||
class TestGetAdminImageDimensions:
    """Tests for get_admin_image_dimensions."""

    def test_get_dimensions_with_local_storage(self) -> None:
        """Known image dimensions are reported from local storage."""
        import tempfile
        from pathlib import Path

        from PIL import Image
        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            # Write an 800x600 white image to the expected location.
            image_dir = Path(root) / "admin_images" / "doc123"
            image_dir.mkdir(parents=True, exist_ok=True)
            Image.new("RGB", (800, 600), color="white").save(image_dir / "page_1.png")

            assert subject.get_admin_image_dimensions("doc123", 1) == (800, 600)

    def test_get_dimensions_nonexistent_file(self) -> None:
        """A missing file yields None."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            assert subject.get_admin_image_dimensions("nonexistent", 1) is None
|
||||
|
||||
|
||||
class TestGetStorageHelper:
    """Tests for the get_storage_helper module-level accessor."""

    def _fresh_helper_calls(self, count):
        """Reset the module singleton, then call the accessor ``count`` times."""
        import inference.web.services.storage_helpers as module

        with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
            mock_get.return_value = MagicMock()
            module._default_helper = None
            return [get_storage_helper() for _ in range(count)]

    def test_returns_helper_instance(self) -> None:
        """The accessor hands back a StorageHelper."""
        (result,) = self._fresh_helper_calls(1)
        assert isinstance(result, StorageHelper)

    def test_returns_same_instance(self) -> None:
        """Repeated calls share one singleton instance."""
        first, second = self._fresh_helper_calls(2)
        assert first is second
|
||||
|
||||
|
||||
class TestDeleteResult:
    """Tests for delete_result."""

    def test_delete_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
        """Deleting a result targets results/<name>."""
        assert helper.delete_result("output.json") is True
        mock_storage.delete.assert_called_once_with("results/output.json")
|
||||
|
||||
|
||||
class TestResultLocalPath:
    """Tests for get_result_local_path."""

    def test_get_result_local_path_with_local_storage(self) -> None:
        """A local backend yields the on-disk result path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            # Seed a result file where the helper expects it.
            results_dir = Path(root) / "results"
            results_dir.mkdir(parents=True, exist_ok=True)
            (results_dir / "output.json").write_bytes(b"test result")

            found = subject.get_result_local_path("output.json")

            assert found is not None
            assert found.exists() and found.name == "output.json"

    def test_get_result_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """A non-local backend yields None."""
        subject = StorageHelper(storage=mock_storage)
        assert subject.get_result_local_path("output.json") is None

    def test_get_result_local_path_nonexistent_file(self) -> None:
        """A missing result yields None."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            assert subject.get_result_local_path("nonexistent.json") is None
|
||||
|
||||
|
||||
class TestResultsBasePath:
    """Tests for get_results_base_path."""

    def test_get_results_base_path_with_local_storage(self) -> None:
        """A local backend exposes an existing results directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            base = subject.get_results_base_path()

            assert base is not None
            assert base.exists() and base.name == "results"

    def test_get_results_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """A non-local backend yields None."""
        subject = StorageHelper(storage=mock_storage)
        assert subject.get_results_base_path() is None
|
||||
|
||||
|
||||
class TestUploadLocalPath:
    """Tests for get_upload_local_path."""

    def test_get_upload_local_path_with_local_storage(self) -> None:
        """A local backend yields the on-disk upload path."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            uploads_dir = Path(root) / "uploads"
            uploads_dir.mkdir(parents=True, exist_ok=True)
            (uploads_dir / "file.pdf").write_bytes(b"test upload")

            found = subject.get_upload_local_path("file.pdf")

            assert found is not None
            assert found.exists() and found.name == "file.pdf"

    def test_get_upload_local_path_with_subfolder(self) -> None:
        """A subfolder component is honoured in the lookup."""
        import tempfile
        from pathlib import Path

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            nested = Path(root) / "uploads" / "async"
            nested.mkdir(parents=True, exist_ok=True)
            (nested / "file.pdf").write_bytes(b"test upload")

            found = subject.get_upload_local_path("file.pdf", "async")

            assert found is not None
            assert found.exists()

    def test_get_upload_local_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """A non-local backend yields None."""
        subject = StorageHelper(storage=mock_storage)
        assert subject.get_upload_local_path("file.pdf") is None
|
||||
|
||||
|
||||
class TestUploadsBasePath:
    """Tests for get_uploads_base_path."""

    def test_get_uploads_base_path_with_local_storage(self) -> None:
        """A local backend exposes an existing uploads directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            base = subject.get_uploads_base_path()

            assert base is not None
            assert base.exists() and base.name == "uploads"

    def test_get_uploads_base_path_with_subfolder(self) -> None:
        """Requesting a subfolder returns that nested directory."""
        import tempfile

        from shared.storage.local import LocalStorageBackend

        with tempfile.TemporaryDirectory() as root:
            subject = StorageHelper(storage=LocalStorageBackend(root))

            base = subject.get_uploads_base_path("async")

            assert base is not None
            assert base.exists() and base.name == "async"

    def test_get_uploads_base_path_returns_none_for_cloud(
        self, mock_storage: MagicMock
    ) -> None:
        """A non-local backend yields None."""
        subject = StorageHelper(storage=mock_storage)
        assert subject.get_uploads_base_path() is None
|
||||
|
||||
|
||||
class TestUploadExists:
|
||||
"""Tests for upload_exists method."""
|
||||
|
||||
def test_upload_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||
"""Should check upload existence."""
|
||||
exists = helper.upload_exists("file.pdf")
|
||||
|
||||
assert exists is True
|
||||
mock_storage.exists.assert_called_once_with("uploads/file.pdf")
|
||||
|
||||
def test_upload_exists_with_subfolder(
|
||||
self, helper: StorageHelper, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should check upload existence with subfolder."""
|
||||
helper.upload_exists("file.pdf", "async")
|
||||
|
||||
mock_storage.exists.assert_called_once_with("uploads/async/file.pdf")
|
||||
|
||||
|
||||
class TestDatasetsBasePath:
|
||||
"""Tests for get_datasets_base_path method."""
|
||||
|
||||
def test_get_datasets_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_datasets_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "datasets"
|
||||
|
||||
def test_get_datasets_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_datasets_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestAdminImagesBasePath:
|
||||
"""Tests for get_admin_images_base_path method."""
|
||||
|
||||
def test_get_admin_images_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_admin_images_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "admin_images"
|
||||
|
||||
def test_get_admin_images_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_admin_images_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestRawPdfsBasePath:
|
||||
"""Tests for get_raw_pdfs_base_path method."""
|
||||
|
||||
def test_get_raw_pdfs_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_raw_pdfs_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "raw_pdfs"
|
||||
|
||||
def test_get_raw_pdfs_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_raw_pdfs_base_path()
|
||||
assert base_path is None
|
||||
|
||||
|
||||
class TestRawPdfLocalPath:
|
||||
"""Tests for get_raw_pdf_local_path method."""
|
||||
|
||||
def test_get_raw_pdf_local_path_with_local_storage(self) -> None:
|
||||
"""Should return local path when using local storage backend."""
|
||||
from pathlib import Path
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
# Create a test raw PDF
|
||||
test_path = Path(temp_dir) / "raw_pdfs"
|
||||
test_path.mkdir(parents=True, exist_ok=True)
|
||||
(test_path / "invoice.pdf").write_bytes(b"test pdf")
|
||||
|
||||
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
|
||||
|
||||
assert local_path is not None
|
||||
assert local_path.exists()
|
||||
assert local_path.name == "invoice.pdf"
|
||||
|
||||
def test_get_raw_pdf_local_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
|
||||
assert local_path is None
|
||||
|
||||
|
||||
class TestRawPdfPath:
|
||||
"""Tests for get_raw_pdf_path method."""
|
||||
|
||||
def test_get_raw_pdf_path(self, helper: StorageHelper) -> None:
|
||||
"""Should return correct storage path."""
|
||||
path = helper.get_raw_pdf_path("invoice.pdf")
|
||||
assert path == "raw_pdfs/invoice.pdf"
|
||||
|
||||
|
||||
class TestAutolabelOutputPath:
|
||||
"""Tests for get_autolabel_output_path method."""
|
||||
|
||||
def test_get_autolabel_output_path_with_local_storage(self) -> None:
|
||||
"""Should return output path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
output_path = helper.get_autolabel_output_path()
|
||||
|
||||
assert output_path is not None
|
||||
assert output_path.exists()
|
||||
assert output_path.name == "autolabel_output"
|
||||
|
||||
def test_get_autolabel_output_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
output_path = helper.get_autolabel_output_path()
|
||||
assert output_path is None
|
||||
|
||||
|
||||
class TestTrainingDataPath:
|
||||
"""Tests for get_training_data_path method."""
|
||||
|
||||
def test_get_training_data_path_with_local_storage(self) -> None:
|
||||
"""Should return training path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
training_path = helper.get_training_data_path()
|
||||
|
||||
assert training_path is not None
|
||||
assert training_path.exists()
|
||||
assert training_path.name == "training"
|
||||
|
||||
def test_get_training_data_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
training_path = helper.get_training_data_path()
|
||||
assert training_path is None
|
||||
|
||||
|
||||
class TestExportsBasePath:
|
||||
"""Tests for get_exports_base_path method."""
|
||||
|
||||
def test_get_exports_base_path_with_local_storage(self) -> None:
|
||||
"""Should return base path when using local storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage = LocalStorageBackend(temp_dir)
|
||||
helper = StorageHelper(storage=storage)
|
||||
|
||||
base_path = helper.get_exports_base_path()
|
||||
|
||||
assert base_path is not None
|
||||
assert base_path.exists()
|
||||
assert base_path.name == "exports"
|
||||
|
||||
def test_get_exports_base_path_returns_none_for_cloud(
|
||||
self, mock_storage: MagicMock
|
||||
) -> None:
|
||||
"""Should return None when not using local storage."""
|
||||
helper = StorageHelper(storage=mock_storage)
|
||||
base_path = helper.get_exports_base_path()
|
||||
assert base_path is None
|
||||
306
tests/web/test_storage_integration.py
Normal file
306
tests/web/test_storage_integration.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Tests for storage backend integration in web application.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStorageBackendInitialization:
|
||||
"""Tests for storage backend initialization in web config."""
|
||||
|
||||
def test_get_storage_backend_returns_backend(self, tmp_path: Path) -> None:
|
||||
"""Test that get_storage_backend returns a StorageBackend instance."""
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
backend = get_storage_backend()
|
||||
|
||||
assert isinstance(backend, StorageBackend)
|
||||
|
||||
def test_get_storage_backend_uses_config_file_if_exists(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that storage config file is used when present."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
|
||||
config_file = tmp_path / "storage.yaml"
|
||||
storage_path = tmp_path / "storage"
|
||||
config_file.write_text(f"""
|
||||
backend: local
|
||||
|
||||
local:
|
||||
base_path: {storage_path}
|
||||
""")
|
||||
|
||||
backend = get_storage_backend(config_path=config_file)
|
||||
|
||||
assert isinstance(backend, LocalStorageBackend)
|
||||
|
||||
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
|
||||
"""Test fallback to environment variables when no config file."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
backend = get_storage_backend(config_path=None)
|
||||
|
||||
assert isinstance(backend, LocalStorageBackend)
|
||||
|
||||
def test_app_config_has_storage_backend(self, tmp_path: Path) -> None:
|
||||
"""Test that AppConfig can be created with storage backend."""
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
from inference.web.config import AppConfig, create_app_config
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
config = create_app_config()
|
||||
|
||||
assert hasattr(config, "storage_backend")
|
||||
assert isinstance(config.storage_backend, StorageBackend)
|
||||
|
||||
|
||||
class TestStorageBackendInDocumentUpload:
|
||||
"""Tests for storage backend usage in document upload."""
|
||||
|
||||
def test_upload_document_uses_storage_backend(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document upload uses storage backend."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create a mock upload file
|
||||
pdf_content = b"%PDF-1.4 test content"
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
# Upload should use storage backend
|
||||
result = service.upload_document(
|
||||
content=pdf_content,
|
||||
filename="test.pdf",
|
||||
dataset_id="dataset-1",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
# Verify file was stored via storage backend
|
||||
assert backend.exists(f"documents/{result.id}.pdf")
|
||||
|
||||
def test_upload_document_stores_logical_path(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document stores logical path, not absolute path."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
pdf_content = b"%PDF-1.4 test content"
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
result = service.upload_document(
|
||||
content=pdf_content,
|
||||
filename="test.pdf",
|
||||
dataset_id="dataset-1",
|
||||
)
|
||||
|
||||
# Path should be logical (relative), not absolute
|
||||
assert not result.file_path.startswith("/")
|
||||
assert not result.file_path.startswith("C:")
|
||||
assert result.file_path.startswith("documents/")
|
||||
|
||||
|
||||
class TestStorageBackendInDocumentDownload:
|
||||
"""Tests for storage backend usage in document download/serving."""
|
||||
|
||||
def test_get_document_url_returns_presigned_url(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document URL uses presigned URL from storage backend."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create a test file
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
url = service.get_document_url(doc_path)
|
||||
|
||||
# Should return a URL (file:// for local, https:// for cloud)
|
||||
assert url is not None
|
||||
assert "test-doc.pdf" in url
|
||||
|
||||
def test_download_document_uses_storage_backend(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document download uses storage backend."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create a test file
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
original_content = b"%PDF-1.4 test content"
|
||||
backend.upload_bytes(original_content, doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
content = service.download_document(doc_path)
|
||||
|
||||
assert content == original_content
|
||||
|
||||
|
||||
class TestStorageBackendInImageServing:
|
||||
"""Tests for storage backend usage in image serving."""
|
||||
|
||||
def test_get_page_image_url_returns_presigned_url(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that page image URL uses presigned URL."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create a test image
|
||||
image_path = "images/doc-123/page_1.png"
|
||||
backend.upload_bytes(b"fake png content", image_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
url = service.get_page_image_url("doc-123", 1)
|
||||
|
||||
assert url is not None
|
||||
assert "page_1.png" in url
|
||||
|
||||
def test_save_page_image_uses_storage_backend(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that page image saving uses storage backend."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
image_content = b"fake png content"
|
||||
service.save_page_image("doc-123", 1, image_content)
|
||||
|
||||
# Verify image was stored
|
||||
assert backend.exists("images/doc-123/page_1.png")
|
||||
|
||||
|
||||
class TestStorageBackendInDocumentDeletion:
|
||||
"""Tests for storage backend usage in document deletion."""
|
||||
|
||||
def test_delete_document_removes_from_storage(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document deletion removes file from storage."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create test files
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
service.delete_document_files(doc_path)
|
||||
|
||||
assert not backend.exists(doc_path)
|
||||
|
||||
def test_delete_document_removes_images(
|
||||
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||
) -> None:
|
||||
"""Test that document deletion removes associated images."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
storage_path = tmp_path / "storage"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
# Create test files
|
||||
doc_id = "test-doc-123"
|
||||
backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
|
||||
backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
service.delete_document_images(doc_id)
|
||||
|
||||
assert not backend.exists(f"images/{doc_id}/page_1.png")
|
||||
assert not backend.exists(f"images/{doc_id}/page_2.png")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db() -> MagicMock:
|
||||
"""Create a mock AdminDB for testing."""
|
||||
mock = MagicMock()
|
||||
mock.get_document.return_value = None
|
||||
mock.create_document.return_value = MagicMock(
|
||||
id="test-doc-id",
|
||||
file_path="documents/test-doc-id.pdf",
|
||||
)
|
||||
return mock
|
||||
@@ -103,6 +103,31 @@ class MockAnnotation:
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockModelVersion:
|
||||
"""Mock ModelVersion for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.version_id = kwargs.get('version_id', uuid4())
|
||||
self.version = kwargs.get('version', '1.0.0')
|
||||
self.name = kwargs.get('name', 'Test Model')
|
||||
self.description = kwargs.get('description', None)
|
||||
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
|
||||
self.status = kwargs.get('status', 'inactive')
|
||||
self.is_active = kwargs.get('is_active', False)
|
||||
self.task_id = kwargs.get('task_id', None)
|
||||
self.dataset_id = kwargs.get('dataset_id', None)
|
||||
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
|
||||
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
|
||||
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
|
||||
self.document_count = kwargs.get('document_count', 100)
|
||||
self.training_config = kwargs.get('training_config', {})
|
||||
self.file_size = kwargs.get('file_size', 52428800)
|
||||
self.trained_at = kwargs.get('trained_at', datetime.utcnow())
|
||||
self.activated_at = kwargs.get('activated_at', None)
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing Phase 4."""
|
||||
|
||||
@@ -111,6 +136,7 @@ class MockAdminDB:
|
||||
self.annotations = {}
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
self.model_versions = {}
|
||||
|
||||
def get_documents_for_training(
|
||||
self,
|
||||
@@ -174,6 +200,14 @@ class MockAdminDB:
|
||||
"""Get training task by ID."""
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
def get_model_versions(self, status=None, limit=20, offset=0):
|
||||
"""Get model versions with optional filtering."""
|
||||
models = list(self.model_versions.values())
|
||||
if status:
|
||||
models = [m for m in models if m.status == status]
|
||||
total = len(models)
|
||||
return models[offset:offset+limit], total
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
@@ -241,6 +275,30 @@ def app():
|
||||
)
|
||||
mock_db.training_links[str(doc1.document_id)] = [link1]
|
||||
|
||||
# Add model versions
|
||||
model1 = MockModelVersion(
|
||||
version="1.0.0",
|
||||
name="Model v1.0.0",
|
||||
status="inactive",
|
||||
is_active=False,
|
||||
metrics_mAP=0.935,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
document_count=500,
|
||||
)
|
||||
model2 = MockModelVersion(
|
||||
version="1.1.0",
|
||||
name="Model v1.1.0",
|
||||
status="active",
|
||||
is_active=True,
|
||||
metrics_mAP=0.951,
|
||||
metrics_precision=0.94,
|
||||
metrics_recall=0.92,
|
||||
document_count=600,
|
||||
)
|
||||
mock_db.model_versions[str(model1.version_id)] = model1
|
||||
mock_db.model_versions[str(model2.version_id)] = model2
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
|
||||
|
||||
|
||||
class TestTrainingModels:
|
||||
"""Tests for GET /admin/training/models endpoint."""
|
||||
"""Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
|
||||
|
||||
def test_get_training_models_success(self, client):
|
||||
"""Test getting trained models list."""
|
||||
"""Test getting model versions list."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -338,43 +396,44 @@ class TestTrainingModels:
|
||||
assert len(data["models"]) == 2
|
||||
|
||||
def test_get_training_models_includes_metrics(self, client):
|
||||
"""Test that models include metrics."""
|
||||
"""Test that model versions include metrics."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check first model has metrics
|
||||
# Check first model has metrics fields
|
||||
model = data["models"][0]
|
||||
assert "metrics" in model
|
||||
assert "mAP" in model["metrics"]
|
||||
assert model["metrics"]["mAP"] is not None
|
||||
assert "precision" in model["metrics"]
|
||||
assert "recall" in model["metrics"]
|
||||
assert "metrics_mAP" in model
|
||||
assert model["metrics_mAP"] is not None
|
||||
|
||||
def test_get_training_models_includes_download_url(self, client):
|
||||
"""Test that completed models have download URLs."""
|
||||
def test_get_training_models_includes_version_fields(self, client):
|
||||
"""Test that model versions include version fields."""
|
||||
response = client.get("/admin/training/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check completed models have download URLs
|
||||
for model in data["models"]:
|
||||
if model["status"] == "completed":
|
||||
assert "download_url" in model
|
||||
assert model["download_url"] is not None
|
||||
# Check model has expected fields
|
||||
model = data["models"][0]
|
||||
assert "version_id" in model
|
||||
assert "version" in model
|
||||
assert "name" in model
|
||||
assert "status" in model
|
||||
assert "is_active" in model
|
||||
assert "document_count" in model
|
||||
|
||||
def test_get_training_models_filter_by_status(self, client):
|
||||
"""Test filtering models by status."""
|
||||
response = client.get("/admin/training/models?status=completed")
|
||||
"""Test filtering model versions by status."""
|
||||
response = client.get("/admin/training/models?status=active")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# All returned models should be completed
|
||||
assert data["total"] == 1
|
||||
# All returned models should be active
|
||||
for model in data["models"]:
|
||||
assert model["status"] == "completed"
|
||||
assert model["status"] == "active"
|
||||
|
||||
def test_get_training_models_pagination(self, client):
|
||||
"""Test pagination for models."""
|
||||
"""Test pagination for model versions."""
|
||||
response = client.get("/admin/training/models?limit=1&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user