WIP
This commit is contained in:
@@ -11,7 +11,11 @@ from fastapi.testclient import TestClient
|
||||
import numpy as np
|
||||
|
||||
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_dataset_repository,
|
||||
)
|
||||
|
||||
|
||||
TEST_ADMIN_TOKEN = "test-admin-token-12345"
|
||||
@@ -26,18 +30,27 @@ def admin_token() -> str:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db() -> MagicMock:
|
||||
"""Create a mock AdminDB for testing."""
|
||||
def mock_document_repo() -> MagicMock:
|
||||
"""Create a mock DocumentRepository for testing."""
|
||||
mock = MagicMock()
|
||||
# Default return values
|
||||
mock.get_document_by_token.return_value = None
|
||||
mock.get_dataset.return_value = None
|
||||
mock.get_augmented_datasets.return_value = ([], 0)
|
||||
mock.get.return_value = None
|
||||
mock.get_by_token.return_value = None
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
def mock_dataset_repo() -> MagicMock:
|
||||
"""Create a mock DatasetRepository for testing."""
|
||||
mock = MagicMock()
|
||||
# Default return values
|
||||
mock.get.return_value = None
|
||||
mock.get_paginated.return_value = ([], 0)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
|
||||
"""Create test client with admin authentication."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -45,11 +58,15 @@ def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
def get_token_override():
|
||||
return TEST_ADMIN_TOKEN
|
||||
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
def get_document_repo_override():
|
||||
return mock_document_repo
|
||||
|
||||
def get_dataset_repo_override():
|
||||
return mock_dataset_repo
|
||||
|
||||
app.dependency_overrides[validate_admin_token] = get_token_override
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
app.dependency_overrides[get_document_repository] = get_document_repo_override
|
||||
app.dependency_overrides[get_dataset_repository] = get_dataset_repo_override
|
||||
|
||||
# Include router - the router already has /augmentation prefix
|
||||
# so we add /api/v1/admin to get /api/v1/admin/augmentation
|
||||
@@ -60,15 +77,19 @@ def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
def unauthenticated_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
|
||||
"""Create test client WITHOUT admin authentication override."""
|
||||
app = FastAPI()
|
||||
|
||||
# Only override the database, NOT the token validation
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
# Only override the repositories, NOT the token validation
|
||||
def get_document_repo_override():
|
||||
return mock_document_repo
|
||||
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
def get_dataset_repo_override():
|
||||
return mock_dataset_repo
|
||||
|
||||
app.dependency_overrides[get_document_repository] = get_document_repo_override
|
||||
app.dependency_overrides[get_dataset_repository] = get_dataset_repo_override
|
||||
|
||||
router = create_augmentation_router()
|
||||
app.include_router(router, prefix="/api/v1/admin")
|
||||
@@ -142,13 +163,13 @@ class TestAugmentationPreviewEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
mock_document_repo: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing augmentation on a document."""
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
mock_document_repo.get.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
@@ -218,13 +239,13 @@ class TestAugmentationPreviewConfigEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
mock_document_repo: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing full config on a document."""
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
mock_document_repo.get.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
@@ -260,13 +281,13 @@ class TestAugmentationBatchEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_dataset_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
mock_dataset_repo: MagicMock,
|
||||
) -> None:
|
||||
"""Test creating augmented dataset."""
|
||||
# Mock dataset exists
|
||||
mock_dataset = MagicMock()
|
||||
mock_dataset.total_images = 100
|
||||
mock_admin_db.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_repo.get.return_value = mock_dataset
|
||||
|
||||
response = admin_client.post(
|
||||
"/api/v1/admin/augmentation/batch",
|
||||
|
||||
Reference in New Issue
Block a user