Add more tests
This commit is contained in:
@@ -196,3 +196,121 @@ class TestAnnotationModel:
|
||||
assert 0 <= ann.y_center <= 1
|
||||
assert 0 <= ann.width <= 1
|
||||
assert 0 <= ann.height <= 1
|
||||
|
||||
|
||||
class TestAutoLabelFilePathResolution:
    """Tests for auto-label file path resolution.

    The auto-label endpoint needs to resolve the storage path
    (e.g., "raw_pdfs/uuid.pdf") to an actual filesystem path via the
    storage helper.
    """

    def test_extracts_filename_from_storage_path(self):
        """Test that filename is extracted from storage path correctly.

        NOTE(review): this re-implements the extraction expression locally
        instead of calling the endpoint, so it only guards the expression.
        """
        # Storage paths are like "raw_pdfs/uuid.pdf"
        storage_path = "raw_pdfs/550e8400-e29b-41d4-a716-446655440000.pdf"

        # The annotation endpoint extracts filename
        filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path

        assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"

    def test_handles_path_without_prefix(self):
        """Test that paths without prefix are handled."""
        storage_path = "550e8400-e29b-41d4-a716-446655440000.pdf"

        # Same expression as above; without a "/" the path is returned as-is.
        filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path

        assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"

    def test_storage_helper_resolves_path(self):
        """Test that storage helper can resolve the path.

        NOTE(review): ``get_storage_helper`` is patched and then imported
        from the same module, so the assertions exercise the mock only.
        """
        from pathlib import Path
        from unittest.mock import MagicMock, patch

        # Mock storage helper
        mock_storage = MagicMock()
        mock_path = Path("/storage/raw_pdfs/test.pdf")
        mock_storage.get_raw_pdf_local_path.return_value = mock_path

        with patch(
            "inference.web.services.storage_helpers.get_storage_helper",
            return_value=mock_storage,
        ):
            # Imported while the patch is active, so this name is the mock.
            from inference.web.services.storage_helpers import get_storage_helper

            storage = get_storage_helper()
            result = storage.get_raw_pdf_local_path("test.pdf")

            assert result == mock_path
            mock_storage.get_raw_pdf_local_path.assert_called_once_with("test.pdf")

    def test_auto_label_request_validation(self):
        """Test AutoLabelRequest validates field_values."""
        # Valid request
        request = AutoLabelRequest(
            field_values={"InvoiceNumber": "12345"},
            replace_existing=False,
        )
        assert request.field_values == {"InvoiceNumber": "12345"}

        # Empty field_values should be valid at schema level
        # (endpoint validates non-empty)
        request_empty = AutoLabelRequest(
            field_values={},
            replace_existing=False,
        )
        assert request_empty.field_values == {}
|
||||
|
||||
|
||||
class TestMatchClassAttributes:
    """Tests for Match class attributes used in auto-labeling.

    The autolabel service uses Match objects from FieldMatcher.
    Verifies the correct attribute names are used.
    """

    def test_match_has_matched_text_attribute(self):
        """Test that Match class has matched_text attribute (not matched_value)."""
        from shared.matcher.models import Match

        # Create a Match object
        match = Match(
            field="invoice_number",
            value="12345",
            bbox=(100, 100, 200, 150),
            page_no=0,
            score=0.95,
            matched_text="INV-12345",
            context_keywords=["faktura", "nummer"],
        )

        # Verify matched_text exists (this is what autolabel.py should use)
        assert hasattr(match, "matched_text")
        assert match.matched_text == "INV-12345"

        # Verify matched_value does NOT exist
        # This was the bug - autolabel.py was using matched_value instead of matched_text
        assert not hasattr(match, "matched_value")

    def test_match_attributes_for_annotation_creation(self):
        """Test that Match has all attributes needed for annotation creation."""
        from shared.matcher.models import Match

        match = Match(
            field="amount",
            value="1000.00",
            bbox=(50, 200, 150, 230),
            page_no=0,
            score=0.88,
            matched_text="1 000,00",
            context_keywords=["att betala", "summa"],
        )

        # These are all the attributes used in autolabel._create_annotations_from_matches
        assert hasattr(match, "bbox")
        assert hasattr(match, "matched_text")  # NOT matched_value
        assert hasattr(match, "score")

        # Verify bbox format
        assert len(match.bbox) == 4  # (x0, y0, x1, y1)
|
||||
|
||||
@@ -3,7 +3,7 @@ Tests for Admin Authentication.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -132,6 +132,47 @@ class TestTokenRepository:
|
||||
with patch.object(repo, "_now", return_value=datetime.utcnow()):
|
||||
assert repo.is_valid("test-token") is False
|
||||
|
||||
def test_is_valid_expired_token_timezone_aware(self):
    """Test expired token with timezone-aware datetime.

    This verifies the fix for comparing timezone-aware and naive datetimes.
    The auth API now creates tokens with timezone-aware expiration dates.
    """
    with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
        # The patched _session is used as a context manager; entering it
        # yields the mock session configured below.
        mock_session = MagicMock()
        mock_ctx.return_value.__enter__.return_value = mock_session

        # Create token with timezone-aware expiration (as auth API now does)
        mock_token = AdminToken(
            token="test-token",
            name="Test",
            is_active=True,
            expires_at=datetime.now(timezone.utc) - timedelta(days=1),
        )
        mock_session.get.return_value = mock_token

        repo = TokenRepository()
        # _now() returns timezone-aware datetime, should compare correctly
        assert repo.is_valid("test-token") is False
|
||||
|
||||
def test_is_valid_not_expired_token_timezone_aware(self):
    """Test non-expired token with timezone-aware datetime."""
    with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
        # The patched _session is used as a context manager; entering it
        # yields the mock session configured below.
        mock_session = MagicMock()
        mock_ctx.return_value.__enter__.return_value = mock_session

        # Create token with timezone-aware expiration in the future
        mock_token = AdminToken(
            token="test-token",
            name="Test",
            is_active=True,
            expires_at=datetime.now(timezone.utc) + timedelta(days=1),
        )
        mock_session.get.return_value = mock_token

        repo = TokenRepository()
        assert repo.is_valid("test-token") is True
|
||||
|
||||
def test_is_valid_token_not_found(self):
|
||||
"""Test token not found."""
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
|
||||
317
tests/web/test_dashboard_api.py
Normal file
317
tests/web/test_dashboard_api.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Tests for Dashboard API Endpoints and Services.
|
||||
|
||||
Tests are split into:
|
||||
1. Unit tests for business logic (is_annotation_complete, etc.)
|
||||
2. Service tests with mocked database
|
||||
3. Integration tests via TestClient (requires DB)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
# Test data constants: fixed UUID strings shared by the tests in this module.
TEST_DOC_UUID_1 = "550e8400-e29b-41d4-a716-446655440001"
TEST_MODEL_UUID = "660e8400-e29b-41d4-a716-446655440001"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440001"
|
||||
|
||||
|
||||
class TestAnnotationCompletenessLogic:
    """Unit tests for annotation completeness calculation logic.

    These tests verify the core business logic:
    - Complete: has (invoice_number OR ocr_number) AND (bankgiro OR plusgiro)
    - Incomplete: labeled but missing required fields
    """

    def test_document_with_invoice_number_and_bankgiro_is_complete(self):
        """Document with invoice_number + bankgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_ocr_number_and_plusgiro_is_complete(self):
        """Document with ocr_number + plusgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 3, "class_name": "ocr_number"},
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_invoice_number_and_plusgiro_is_complete(self):
        """Document with invoice_number + plusgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_ocr_number_and_bankgiro_is_complete(self):
        """Document with ocr_number + bankgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 3, "class_name": "ocr_number"},
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_only_identifier_is_incomplete(self):
        """Document with only identifier field should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_only_payment_is_incomplete(self):
        """Document with only payment field should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_no_annotations_is_incomplete(self):
        """Document with no annotations should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        assert is_annotation_complete([]) is False

    def test_document_with_other_fields_only_is_incomplete(self):
        """Document with only non-essential fields should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 1, "class_name": "invoice_date"},
            {"class_id": 6, "class_name": "amount"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_all_fields_is_complete(self):
        """Document with all fields should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 1, "class_name": "invoice_date"},
            {"class_id": 4, "class_name": "bankgiro"},
            {"class_id": 6, "class_name": "amount"},
        ]

        assert is_annotation_complete(annotations) is True
|
||||
|
||||
|
||||
class TestDashboardStatsService:
    """Tests for DashboardStatsService with mocked database.

    NOTE(review): the tests below recompute the completeness formula
    locally and never call the service or use ``mock_session``; confirm
    whether service-level coverage lives elsewhere.
    """

    @pytest.fixture
    def mock_session(self):
        """Create a mock database session whose ``exec(...).one()`` returns 0."""
        session = MagicMock()
        session.exec.return_value.one.return_value = 0
        return session

    def test_completeness_rate_calculation(self):
        """Test completeness rate is calculated correctly."""
        # Direct calculation test
        complete = 25
        incomplete = 8
        total_assessed = complete + incomplete
        expected_rate = round(complete / total_assessed * 100, 2)

        # 25 / 33 * 100 = 75.7575..., i.e. 75.76 after rounding. An absolute
        # tolerance of half a cent pins the two-decimal value; the previous
        # rel=0.01 accepted anything within roughly +/-0.76.
        assert expected_rate == pytest.approx(75.76, abs=0.005)

    def test_completeness_rate_zero_documents(self):
        """Test completeness rate is 0 when no documents."""
        complete = 0
        incomplete = 0
        total_assessed = complete + incomplete

        # Guard against division by zero when nothing has been assessed.
        completeness_rate = (
            round(complete / total_assessed * 100, 2)
            if total_assessed > 0
            else 0.0
        )

        assert completeness_rate == 0.0
|
||||
|
||||
|
||||
class TestDashboardActivityService:
    """Tests for DashboardActivityService activity aggregation."""

    def test_activity_types(self):
        """Test all activity types are defined.

        NOTE(review): the assertion checks each name against the very list
        being iterated, so it can never fail. It should compare against the
        activity types the service actually defines; that definition is not
        visible from this test module.
        """
        expected_types = [
            "document_uploaded",
            "annotation_modified",
            "training_completed",
            "training_failed",
            "model_activated",
        ]

        for activity_type in expected_types:
            assert activity_type in expected_types
|
||||
|
||||
|
||||
class TestDashboardSchemas:
    """Tests for Dashboard API schemas.

    Each test constructs a schema object with literal values and checks
    that the same values are readable back from its attributes.
    """

    def test_dashboard_stats_response_schema(self):
        """Test DashboardStatsResponse schema validation."""
        from inference.web.schemas.admin import DashboardStatsResponse

        response = DashboardStatsResponse(
            total_documents=38,
            annotation_complete=25,
            annotation_incomplete=8,
            pending=5,
            completeness_rate=75.76,
        )

        assert response.total_documents == 38
        assert response.annotation_complete == 25
        assert response.annotation_incomplete == 8
        assert response.pending == 5
        assert response.completeness_rate == 75.76

    def test_active_model_response_schema(self):
        """Test ActiveModelResponse schema with null model."""
        from inference.web.schemas.admin import ActiveModelResponse

        response = ActiveModelResponse(
            model=None,
            running_training=None,
        )

        assert response.model is None
        assert response.running_training is None

    def test_active_model_info_schema(self):
        """Test ActiveModelInfo schema validation."""
        from inference.web.schemas.admin import ActiveModelInfo

        model = ActiveModelInfo(
            version_id=TEST_MODEL_UUID,
            version="1.2.0",
            name="Invoice Model",
            metrics_mAP=0.951,
            metrics_precision=0.94,
            metrics_recall=0.92,
            document_count=500,
            activated_at=datetime(2024, 1, 20, 15, 0, 0, tzinfo=timezone.utc),
        )

        assert model.version == "1.2.0"
        assert model.name == "Invoice Model"
        assert model.metrics_mAP == 0.951

    def test_running_training_info_schema(self):
        """Test RunningTrainingInfo schema validation."""
        from inference.web.schemas.admin import RunningTrainingInfo

        task = RunningTrainingInfo(
            task_id=TEST_TASK_UUID,
            name="Run-2024-02",
            status="running",
            started_at=datetime(2024, 1, 25, 10, 0, 0, tzinfo=timezone.utc),
            progress=45,
        )

        assert task.name == "Run-2024-02"
        assert task.status == "running"
        assert task.progress == 45

    def test_activity_item_schema(self):
        """Test ActivityItem schema validation."""
        from inference.web.schemas.admin import ActivityItem

        activity = ActivityItem(
            type="model_activated",
            description="Activated model v1.2.0",
            timestamp=datetime(2024, 1, 25, 12, 0, 0, tzinfo=timezone.utc),
            metadata={"version_id": TEST_MODEL_UUID, "version": "1.2.0"},
        )

        assert activity.type == "model_activated"
        assert activity.description == "Activated model v1.2.0"
        assert activity.metadata["version"] == "1.2.0"

    def test_recent_activity_response_schema(self):
        """Test RecentActivityResponse schema with empty activities."""
        from inference.web.schemas.admin import RecentActivityResponse

        response = RecentActivityResponse(activities=[])

        assert response.activities == []
|
||||
|
||||
|
||||
class TestDashboardRouterCreation:
    """Tests for dashboard router creation."""

    def test_creates_router_with_expected_endpoints(self):
        """Test router is created with expected endpoint paths."""
        from inference.web.api.v1.admin.dashboard import create_dashboard_router

        router = create_dashboard_router()

        paths = [route.path for route in router.routes]

        # Substring match: each endpoint fragment must appear in some route path.
        assert any("/stats" in p for p in paths)
        assert any("/active-model" in p for p in paths)
        assert any("/activity" in p for p in paths)

    def test_router_has_correct_prefix(self):
        """Test router has /admin/dashboard prefix."""
        from inference.web.api.v1.admin.dashboard import create_dashboard_router

        router = create_dashboard_router()

        assert router.prefix == "/admin/dashboard"

    def test_router_has_dashboard_tag(self):
        """Test router uses Dashboard tag."""
        from inference.web.api.v1.admin.dashboard import create_dashboard_router

        router = create_dashboard_router()

        assert "Dashboard" in router.tags
|
||||
|
||||
|
||||
class TestFieldClassIds:
    """Tests for field class ID constants."""

    def test_identifier_class_ids(self):
        """Test identifier field class IDs."""
        from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS

        # invoice_number = 0, ocr_number = 3
        assert 0 in IDENTIFIER_CLASS_IDS
        assert 3 in IDENTIFIER_CLASS_IDS

    def test_payment_class_ids(self):
        """Test payment field class IDs."""
        from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS

        # bankgiro = 4, plusgiro = 5
        assert 4 in PAYMENT_CLASS_IDS
        assert 5 in PAYMENT_CLASS_IDS
|
||||
Reference in New Issue
Block a user