Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -196,3 +196,121 @@ class TestAnnotationModel:
assert 0 <= ann.y_center <= 1
assert 0 <= ann.width <= 1
assert 0 <= ann.height <= 1
class TestAutoLabelFilePathResolution:
"""Tests for auto-label file path resolution.
The auto-label endpoint needs to resolve the storage path (e.g., "raw_pdfs/uuid.pdf")
to an actual filesystem path via the storage helper.
"""
def test_extracts_filename_from_storage_path(self):
"""Test that filename is extracted from storage path correctly."""
# Storage paths are like "raw_pdfs/uuid.pdf"
storage_path = "raw_pdfs/550e8400-e29b-41d4-a716-446655440000.pdf"
# The annotation endpoint extracts filename
filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path
assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"
def test_handles_path_without_prefix(self):
"""Test that paths without prefix are handled."""
storage_path = "550e8400-e29b-41d4-a716-446655440000.pdf"
filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path
assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"
def test_storage_helper_resolves_path(self):
"""Test that storage helper can resolve the path."""
from pathlib import Path
from unittest.mock import MagicMock, patch
# Mock storage helper
mock_storage = MagicMock()
mock_path = Path("/storage/raw_pdfs/test.pdf")
mock_storage.get_raw_pdf_local_path.return_value = mock_path
with patch(
"inference.web.services.storage_helpers.get_storage_helper",
return_value=mock_storage,
):
from inference.web.services.storage_helpers import get_storage_helper
storage = get_storage_helper()
result = storage.get_raw_pdf_local_path("test.pdf")
assert result == mock_path
mock_storage.get_raw_pdf_local_path.assert_called_once_with("test.pdf")
def test_auto_label_request_validation(self):
"""Test AutoLabelRequest validates field_values."""
# Valid request
request = AutoLabelRequest(
field_values={"InvoiceNumber": "12345"},
replace_existing=False,
)
assert request.field_values == {"InvoiceNumber": "12345"}
# Empty field_values should be valid at schema level
# (endpoint validates non-empty)
request_empty = AutoLabelRequest(
field_values={},
replace_existing=False,
)
assert request_empty.field_values == {}
class TestMatchClassAttributes:
"""Tests for Match class attributes used in auto-labeling.
The autolabel service uses Match objects from FieldMatcher.
Verifies the correct attribute names are used.
"""
def test_match_has_matched_text_attribute(self):
"""Test that Match class has matched_text attribute (not matched_value)."""
from shared.matcher.models import Match
# Create a Match object
match = Match(
field="invoice_number",
value="12345",
bbox=(100, 100, 200, 150),
page_no=0,
score=0.95,
matched_text="INV-12345",
context_keywords=["faktura", "nummer"],
)
# Verify matched_text exists (this is what autolabel.py should use)
assert hasattr(match, "matched_text")
assert match.matched_text == "INV-12345"
# Verify matched_value does NOT exist
# This was the bug - autolabel.py was using matched_value instead of matched_text
assert not hasattr(match, "matched_value")
def test_match_attributes_for_annotation_creation(self):
"""Test that Match has all attributes needed for annotation creation."""
from shared.matcher.models import Match
match = Match(
field="amount",
value="1000.00",
bbox=(50, 200, 150, 230),
page_no=0,
score=0.88,
matched_text="1 000,00",
context_keywords=["att betala", "summa"],
)
# These are all the attributes used in autolabel._create_annotations_from_matches
assert hasattr(match, "bbox")
assert hasattr(match, "matched_text") # NOT matched_value
assert hasattr(match, "score")
# Verify bbox format
assert len(match.bbox) == 4 # (x0, y0, x1, y1)

View File

@@ -3,7 +3,7 @@ Tests for Admin Authentication.
"""
import pytest
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from fastapi import HTTPException
@@ -132,6 +132,47 @@ class TestTokenRepository:
with patch.object(repo, "_now", return_value=datetime.utcnow()):
assert repo.is_valid("test-token") is False
def test_is_valid_expired_token_timezone_aware(self):
"""Test expired token with timezone-aware datetime.
This verifies the fix for comparing timezone-aware and naive datetimes.
The auth API now creates tokens with timezone-aware expiration dates.
"""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
# Create token with timezone-aware expiration (as auth API now does)
mock_token = AdminToken(
token="test-token",
name="Test",
is_active=True,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
)
mock_session.get.return_value = mock_token
repo = TokenRepository()
# _now() returns timezone-aware datetime, should compare correctly
assert repo.is_valid("test-token") is False
def test_is_valid_not_expired_token_timezone_aware(self):
"""Test non-expired token with timezone-aware datetime."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
# Create token with timezone-aware expiration in the future
mock_token = AdminToken(
token="test-token",
name="Test",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=1),
)
mock_session.get.return_value = mock_token
repo = TokenRepository()
assert repo.is_valid("test-token") is True
def test_is_valid_token_not_found(self):
"""Test token not found."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:

View File

@@ -0,0 +1,317 @@
"""
Tests for Dashboard API Endpoints and Services.
Tests are split into:
1. Unit tests for business logic (is_annotation_complete, etc.)
2. Service tests with mocked database
3. Integration tests via TestClient (requires DB)
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
# Test data constants
TEST_DOC_UUID_1 = "550e8400-e29b-41d4-a716-446655440001"
TEST_MODEL_UUID = "660e8400-e29b-41d4-a716-446655440001"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440001"
class TestAnnotationCompletenessLogic:
"""Unit tests for annotation completeness calculation logic.
These tests verify the core business logic:
- Complete: has (invoice_number OR ocr_number) AND (bankgiro OR plusgiro)
- Incomplete: labeled but missing required fields
"""
def test_document_with_invoice_number_and_bankgiro_is_complete(self):
"""Document with invoice_number + bankgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
{"class_id": 4, "class_name": "bankgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_ocr_number_and_plusgiro_is_complete(self):
"""Document with ocr_number + plusgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 3, "class_name": "ocr_number"},
{"class_id": 5, "class_name": "plusgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_invoice_number_and_plusgiro_is_complete(self):
"""Document with invoice_number + plusgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
{"class_id": 5, "class_name": "plusgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_ocr_number_and_bankgiro_is_complete(self):
"""Document with ocr_number + bankgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 3, "class_name": "ocr_number"},
{"class_id": 4, "class_name": "bankgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_only_identifier_is_incomplete(self):
"""Document with only identifier field should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
]
assert is_annotation_complete(annotations) is False
def test_document_with_only_payment_is_incomplete(self):
"""Document with only payment field should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 4, "class_name": "bankgiro"},
]
assert is_annotation_complete(annotations) is False
def test_document_with_no_annotations_is_incomplete(self):
"""Document with no annotations should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
assert is_annotation_complete([]) is False
def test_document_with_other_fields_only_is_incomplete(self):
"""Document with only non-essential fields should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 1, "class_name": "invoice_date"},
{"class_id": 6, "class_name": "amount"},
]
assert is_annotation_complete(annotations) is False
def test_document_with_all_fields_is_complete(self):
"""Document with all fields should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
{"class_id": 1, "class_name": "invoice_date"},
{"class_id": 4, "class_name": "bankgiro"},
{"class_id": 6, "class_name": "amount"},
]
assert is_annotation_complete(annotations) is True
class TestDashboardStatsService:
"""Tests for DashboardStatsService with mocked database."""
@pytest.fixture
def mock_session(self):
"""Create a mock database session."""
session = MagicMock()
session.exec.return_value.one.return_value = 0
return session
def test_completeness_rate_calculation(self):
"""Test completeness rate is calculated correctly."""
# Direct calculation test
complete = 25
incomplete = 8
total_assessed = complete + incomplete
expected_rate = round(complete / total_assessed * 100, 2)
assert expected_rate == pytest.approx(75.76, rel=0.01)
def test_completeness_rate_zero_documents(self):
"""Test completeness rate is 0 when no documents."""
complete = 0
incomplete = 0
total_assessed = complete + incomplete
completeness_rate = (
round(complete / total_assessed * 100, 2)
if total_assessed > 0
else 0.0
)
assert completeness_rate == 0.0
class TestDashboardActivityService:
"""Tests for DashboardActivityService activity aggregation."""
def test_activity_types(self):
"""Test all activity types are defined."""
expected_types = [
"document_uploaded",
"annotation_modified",
"training_completed",
"training_failed",
"model_activated",
]
for activity_type in expected_types:
assert activity_type in expected_types
class TestDashboardSchemas:
"""Tests for Dashboard API schemas."""
def test_dashboard_stats_response_schema(self):
"""Test DashboardStatsResponse schema validation."""
from inference.web.schemas.admin import DashboardStatsResponse
response = DashboardStatsResponse(
total_documents=38,
annotation_complete=25,
annotation_incomplete=8,
pending=5,
completeness_rate=75.76,
)
assert response.total_documents == 38
assert response.annotation_complete == 25
assert response.annotation_incomplete == 8
assert response.pending == 5
assert response.completeness_rate == 75.76
def test_active_model_response_schema(self):
"""Test ActiveModelResponse schema with null model."""
from inference.web.schemas.admin import ActiveModelResponse
response = ActiveModelResponse(
model=None,
running_training=None,
)
assert response.model is None
assert response.running_training is None
def test_active_model_info_schema(self):
"""Test ActiveModelInfo schema validation."""
from inference.web.schemas.admin import ActiveModelInfo
model = ActiveModelInfo(
version_id=TEST_MODEL_UUID,
version="1.2.0",
name="Invoice Model",
metrics_mAP=0.951,
metrics_precision=0.94,
metrics_recall=0.92,
document_count=500,
activated_at=datetime(2024, 1, 20, 15, 0, 0, tzinfo=timezone.utc),
)
assert model.version == "1.2.0"
assert model.name == "Invoice Model"
assert model.metrics_mAP == 0.951
def test_running_training_info_schema(self):
"""Test RunningTrainingInfo schema validation."""
from inference.web.schemas.admin import RunningTrainingInfo
task = RunningTrainingInfo(
task_id=TEST_TASK_UUID,
name="Run-2024-02",
status="running",
started_at=datetime(2024, 1, 25, 10, 0, 0, tzinfo=timezone.utc),
progress=45,
)
assert task.name == "Run-2024-02"
assert task.status == "running"
assert task.progress == 45
def test_activity_item_schema(self):
"""Test ActivityItem schema validation."""
from inference.web.schemas.admin import ActivityItem
activity = ActivityItem(
type="model_activated",
description="Activated model v1.2.0",
timestamp=datetime(2024, 1, 25, 12, 0, 0, tzinfo=timezone.utc),
metadata={"version_id": TEST_MODEL_UUID, "version": "1.2.0"},
)
assert activity.type == "model_activated"
assert activity.description == "Activated model v1.2.0"
assert activity.metadata["version"] == "1.2.0"
def test_recent_activity_response_schema(self):
"""Test RecentActivityResponse schema with empty activities."""
from inference.web.schemas.admin import RecentActivityResponse
response = RecentActivityResponse(activities=[])
assert response.activities == []
class TestDashboardRouterCreation:
"""Tests for dashboard router creation."""
def test_creates_router_with_expected_endpoints(self):
"""Test router is created with expected endpoint paths."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
paths = [route.path for route in router.routes]
assert any("/stats" in p for p in paths)
assert any("/active-model" in p for p in paths)
assert any("/activity" in p for p in paths)
def test_router_has_correct_prefix(self):
"""Test router has /admin/dashboard prefix."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
assert router.prefix == "/admin/dashboard"
def test_router_has_dashboard_tag(self):
"""Test router uses Dashboard tag."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
assert "Dashboard" in router.tags
class TestFieldClassIds:
"""Tests for field class ID constants."""
def test_identifier_class_ids(self):
"""Test identifier field class IDs."""
from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS
# invoice_number = 0, ocr_number = 3
assert 0 in IDENTIFIER_CLASS_IDS
assert 3 in IDENTIFIER_CLASS_IDS
def test_payment_class_ids(self):
"""Test payment field class IDs."""
from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS
# bankgiro = 4, plusgiro = 5
assert 4 in PAYMENT_CLASS_IDS
assert 5 in PAYMENT_CLASS_IDS