"""
Tests for the dashboard API endpoints and services.

Tests are split into:

1. Unit tests for business logic (``is_annotation_complete``, etc.)
2. Service tests with a mocked database
3. Integration tests via TestClient (requires a database)

The module also contains validation tests for the dashboard response
schemas and construction tests for the dashboard router.
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
|
# Fixed UUID strings reused as identifiers in the test payloads below; each
# name says which kind of entity (document, model version, training task)
# it stands in for.
TEST_DOC_UUID_1: str = "550e8400-e29b-41d4-a716-446655440001"
TEST_MODEL_UUID: str = "660e8400-e29b-41d4-a716-446655440001"
TEST_TASK_UUID: str = "770e8400-e29b-41d4-a716-446655440001"
|
|
|
|
|
|
class TestAnnotationCompletenessLogic:
    """Unit tests for annotation completeness calculation logic.

    These tests verify the core business logic:

    - Complete: has (invoice_number OR ocr_number) AND (bankgiro OR plusgiro)
    - Incomplete: labeled but missing required fields

    Each test passes a list of annotation dicts (``class_id`` + ``class_name``)
    to ``is_annotation_complete`` and checks the result by identity
    (``is True`` / ``is False``).

    Only cross-group combinations (one identifier plus one payment field)
    should be complete. The same-group cases (two identifiers, or two payment
    fields) guard against an implementation that merely counts key fields
    instead of requiring one from each group. The single-field cases alone
    cannot tell those two rules apart.
    """

    def test_document_with_invoice_number_and_bankgiro_is_complete(self):
        """Document with invoice_number + bankgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_ocr_number_and_plusgiro_is_complete(self):
        """Document with ocr_number + plusgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 3, "class_name": "ocr_number"},
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_invoice_number_and_plusgiro_is_complete(self):
        """Document with invoice_number + plusgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_ocr_number_and_bankgiro_is_complete(self):
        """Document with ocr_number + bankgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 3, "class_name": "ocr_number"},
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_only_identifier_is_incomplete(self):
        """Document with only identifier field should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_only_ocr_number_is_incomplete(self):
        """Document with only ocr_number (the other identifier) is incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 3, "class_name": "ocr_number"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_only_payment_is_incomplete(self):
        """Document with only payment field should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_only_plusgiro_is_incomplete(self):
        """Document with only plusgiro (the other payment field) is incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_both_identifiers_but_no_payment_is_incomplete(self):
        """invoice_number + ocr_number without any payment field is incomplete.

        Two key fields are present, but both come from the identifier group.
        """
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 3, "class_name": "ocr_number"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_both_payments_but_no_identifier_is_incomplete(self):
        """bankgiro + plusgiro without any identifier field is incomplete.

        Two key fields are present, but both come from the payment group.
        """
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 4, "class_name": "bankgiro"},
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_no_annotations_is_incomplete(self):
        """Document with no annotations should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        assert is_annotation_complete([]) is False

    def test_document_with_other_fields_only_is_incomplete(self):
        """Document with only non-essential fields should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 1, "class_name": "invoice_date"},
            {"class_id": 6, "class_name": "amount"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_all_fields_is_complete(self):
        """Document with all fields should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 1, "class_name": "invoice_date"},
            {"class_id": 4, "class_name": "bankgiro"},
            {"class_id": 6, "class_name": "amount"},
        ]

        assert is_annotation_complete(annotations) is True
|
|
|
|
|
|
class TestDashboardStatsService:
    """Completeness-rate checks for the dashboard statistics service.

    NOTE(review): neither test below requests the ``mock_session`` fixture.
    Both recompute the completeness-rate formula inline instead of calling
    DashboardStatsService, so they do not exercise the service yet.
    """

    @pytest.fixture
    def mock_session(self):
        """Return a MagicMock session whose ``exec(...).one()`` yields 0."""
        # Dotted keyword names are applied via Mock.configure_mock, walking
        # exec -> return_value -> one before setting the final return_value.
        return MagicMock(**{"exec.return_value.one.return_value": 0})

    def test_completeness_rate_calculation(self):
        """25 complete out of 33 assessed documents gives roughly 75.76 %."""
        n_complete, n_incomplete = 25, 8
        n_assessed = n_complete + n_incomplete

        rate = round(n_complete / n_assessed * 100, 2)

        assert rate == pytest.approx(75.76, rel=0.01)

    def test_completeness_rate_zero_documents(self):
        """With nothing assessed, the rate falls back to 0.0 (no division)."""
        n_complete = n_incomplete = 0
        n_assessed = n_complete + n_incomplete

        if n_assessed > 0:
            rate = round(n_complete / n_assessed * 100, 2)
        else:
            # Guard against ZeroDivisionError when no documents exist.
            rate = 0.0

        assert rate == 0.0
|
|
|
|
|
|
class TestDashboardActivityService:
    """Tests for DashboardActivityService activity aggregation."""

    def test_activity_types(self):
        """Every dashboard activity type is distinct and accepted by ActivityItem.

        Each expected type string is fed through the ``ActivityItem`` schema,
        so a type the schema does not accept fails here. The test also checks
        that the schema echoes the same ``type`` value back.
        """
        from inference.web.schemas.admin import ActivityItem

        expected_types = [
            "document_uploaded",
            "annotation_modified",
            "training_completed",
            "training_failed",
            "model_activated",
        ]

        # Guard against copy-paste duplicates in the expected list itself.
        assert len(set(expected_types)) == len(expected_types)

        timestamp = datetime(2024, 1, 25, 12, 0, 0, tzinfo=timezone.utc)
        for activity_type in expected_types:
            # If the schema restricts ``type`` to a fixed set of values, an
            # unsupported one raises during construction.
            activity = ActivityItem(
                type=activity_type,
                description=f"{activity_type} event",
                timestamp=timestamp,
                metadata={},
            )
            assert activity.type == activity_type
|
|
|
|
|
|
class TestDashboardSchemas:
    """Construction checks for the dashboard API response schemas.

    Each test builds one schema instance from literal values and reads a
    subset of its fields back.
    """

    def test_dashboard_stats_response_schema(self):
        """DashboardStatsResponse keeps every count and the rate as given."""
        from inference.web.schemas.admin import DashboardStatsResponse

        field_values = {
            "total_documents": 38,
            "annotation_complete": 25,
            "annotation_incomplete": 8,
            "pending": 5,
            "completeness_rate": 75.76,
        }
        stats = DashboardStatsResponse(**field_values)

        for field_name, value in field_values.items():
            assert getattr(stats, field_name) == value

    def test_active_model_response_schema(self):
        """ActiveModelResponse accepts None for both model and running_training."""
        from inference.web.schemas.admin import ActiveModelResponse

        empty = ActiveModelResponse(model=None, running_training=None)

        assert empty.model is None
        assert empty.running_training is None

    def test_active_model_info_schema(self):
        """ActiveModelInfo exposes version, name and mAP from its inputs."""
        from inference.web.schemas.admin import ActiveModelInfo

        activated = datetime(2024, 1, 20, 15, 0, 0, tzinfo=timezone.utc)
        info = ActiveModelInfo(
            version_id=TEST_MODEL_UUID,
            version="1.2.0",
            name="Invoice Model",
            metrics_mAP=0.951,
            metrics_precision=0.94,
            metrics_recall=0.92,
            document_count=500,
            activated_at=activated,
        )

        for field_name, value in (
            ("version", "1.2.0"),
            ("name", "Invoice Model"),
            ("metrics_mAP", 0.951),
        ):
            assert getattr(info, field_name) == value

    def test_running_training_info_schema(self):
        """RunningTrainingInfo exposes name, status and progress from its inputs."""
        from inference.web.schemas.admin import RunningTrainingInfo

        started = datetime(2024, 1, 25, 10, 0, 0, tzinfo=timezone.utc)
        training = RunningTrainingInfo(
            task_id=TEST_TASK_UUID,
            name="Run-2024-02",
            status="running",
            started_at=started,
            progress=45,
        )

        for field_name, value in (
            ("name", "Run-2024-02"),
            ("status", "running"),
            ("progress", 45),
        ):
            assert getattr(training, field_name) == value

    def test_activity_item_schema(self):
        """ActivityItem keeps its type, description and metadata payload."""
        from inference.web.schemas.admin import ActivityItem

        when = datetime(2024, 1, 25, 12, 0, 0, tzinfo=timezone.utc)
        item = ActivityItem(
            type="model_activated",
            description="Activated model v1.2.0",
            timestamp=when,
            metadata={"version_id": TEST_MODEL_UUID, "version": "1.2.0"},
        )

        assert item.type == "model_activated"
        assert item.description == "Activated model v1.2.0"
        assert item.metadata["version"] == "1.2.0"

    def test_recent_activity_response_schema(self):
        """RecentActivityResponse accepts an empty activity list."""
        from inference.web.schemas.admin import RecentActivityResponse

        feed = RecentActivityResponse(activities=[])

        assert feed.activities == []
|
|
|
|
|
|
class TestDashboardRouterCreation:
    """Checks on the router returned by ``create_dashboard_router``."""

    @staticmethod
    def _make_router():
        """Import the factory lazily and build a fresh dashboard router."""
        from inference.web.api.v1.admin.dashboard import create_dashboard_router

        return create_dashboard_router()

    def test_creates_router_with_expected_endpoints(self):
        """Routes for stats, active-model and activity are all registered."""
        route_paths = [route.path for route in self._make_router().routes]

        # Substring match: each fragment only has to appear in some route path.
        for fragment in ("/stats", "/active-model", "/activity"):
            assert any(fragment in path for path in route_paths)

    def test_router_has_correct_prefix(self):
        """The router is mounted under /admin/dashboard."""
        assert self._make_router().prefix == "/admin/dashboard"

    def test_router_has_dashboard_tag(self):
        """The router's tags include "Dashboard"."""
        assert "Dashboard" in self._make_router().tags
|
|
|
|
|
|
class TestFieldClassIds:
    """Membership checks on the class-ID groups used for completeness."""

    def test_identifier_class_ids(self):
        """invoice_number (0) and ocr_number (3) are identifier class IDs."""
        from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS

        for class_id in (0, 3):  # invoice_number, ocr_number
            assert class_id in IDENTIFIER_CLASS_IDS

    def test_payment_class_ids(self):
        """bankgiro (4) and plusgiro (5) are payment class IDs."""
        from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS

        for class_id in (4, 5):  # bankgiro, plusgiro
            assert class_id in PAYMENT_CLASS_IDS
|