Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -0,0 +1,497 @@
"""
Dashboard Service Integration Tests
Tests DashboardStatsService and DashboardActivityService with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AnnotationHistory,
ModelVersion,
TrainingDataset,
TrainingTask,
)
from inference.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
is_annotation_complete,
IDENTIFIER_CLASS_IDS,
PAYMENT_CLASS_IDS,
)
class TestIsAnnotationComplete:
"""Tests for is_annotation_complete function."""
def test_complete_with_invoice_number_and_bankgiro(self):
"""Test complete with invoice_number (0) and bankgiro (4)."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 4}, # bankgiro
]
assert is_annotation_complete(annotations) is True
def test_complete_with_ocr_number_and_plusgiro(self):
"""Test complete with ocr_number (3) and plusgiro (5)."""
annotations = [
{"class_id": 3}, # ocr_number
{"class_id": 5}, # plusgiro
]
assert is_annotation_complete(annotations) is True
def test_incomplete_missing_identifier(self):
"""Test incomplete when missing identifier."""
annotations = [
{"class_id": 4}, # bankgiro only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_missing_payment(self):
"""Test incomplete when missing payment."""
annotations = [
{"class_id": 0}, # invoice_number only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_empty_annotations(self):
"""Test incomplete with empty annotations."""
assert is_annotation_complete([]) is False
def test_complete_with_multiple_fields(self):
"""Test complete with multiple fields."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 1}, # invoice_date
{"class_id": 3}, # ocr_number
{"class_id": 4}, # bankgiro
{"class_id": 5}, # plusgiro
{"class_id": 6}, # amount
]
assert is_annotation_complete(annotations) is True
class TestDashboardStatsService:
"""Tests for DashboardStatsService."""
def test_get_stats_empty_database(self, patched_session):
"""Test stats with empty database."""
service = DashboardStatsService()
stats = service.get_stats()
assert stats["total_documents"] == 0
assert stats["annotation_complete"] == 0
assert stats["annotation_incomplete"] == 0
assert stats["pending"] == 0
assert stats["completeness_rate"] == 0.0
def test_get_stats_with_documents(self, patched_session, admin_token):
"""Test stats with various document states."""
service = DashboardStatsService()
session = patched_session
# Create documents with different statuses
docs = []
for i, status in enumerate(["pending", "auto_labeling", "labeled", "labeled", "exported"]):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/doc_{i}.pdf",
page_count=1,
status=status,
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
docs.append(doc)
session.commit()
stats = service.get_stats()
assert stats["total_documents"] == 5
assert stats["pending"] == 2 # pending + auto_labeling
def test_get_stats_complete_annotations(self, patched_session, admin_token):
"""Test completeness calculation with proper annotations."""
service = DashboardStatsService()
session = patched_session
# Create a labeled document with complete annotations
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="complete_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/complete_doc.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add identifier annotation (invoice_number = class_id 0)
ann1 = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
text_value="INV-001",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann1)
# Add payment annotation (bankgiro = class_id 4)
ann2 = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=4,
class_name="bankgiro",
x_center=0.5,
y_center=0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=160,
bbox_width=160,
bbox_height=40,
text_value="123-4567",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann2)
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 1
assert stats["annotation_incomplete"] == 0
assert stats["completeness_rate"] == 100.0
def test_get_stats_incomplete_annotations(self, patched_session, admin_token):
"""Test completeness with incomplete annotations."""
service = DashboardStatsService()
session = patched_session
# Create a labeled document missing payment annotation
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="incomplete_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/incomplete_doc.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add only identifier annotation (missing payment)
ann = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
text_value="INV-001",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann)
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 0
assert stats["annotation_incomplete"] == 1
assert stats["completeness_rate"] == 0.0
def test_get_stats_mixed_completeness(self, patched_session, admin_token):
"""Test stats with mix of complete and incomplete documents."""
service = DashboardStatsService()
session = patched_session
# Create 2 labeled documents
docs = []
for i in range(2):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"mixed_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/mixed_doc_{i}.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
docs.append(doc)
session.commit()
# First document: complete (has identifier + payment)
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[0].document_id,
page_number=1,
class_id=0, # invoice_number
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[0].document_id,
page_number=1,
class_id=4, # bankgiro
class_name="bankgiro",
x_center=0.5, y_center=0.2, width=0.2, height=0.05,
bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
# Second document: incomplete (missing payment)
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[1].document_id,
page_number=1,
class_id=0, # invoice_number only
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 1
assert stats["annotation_incomplete"] == 1
assert stats["completeness_rate"] == 50.0
class TestDashboardActivityService:
"""Tests for DashboardActivityService."""
def test_get_recent_activities_empty(self, patched_session):
"""Test activities with empty database."""
service = DashboardActivityService()
activities = service.get_recent_activities()
assert activities == []
def test_get_recent_activities_document_uploads(self, patched_session, admin_token):
"""Test activities include document uploads."""
service = DashboardActivityService()
session = patched_session
# Create documents
for i in range(3):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"activity_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/activity_doc_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
activities = service.get_recent_activities()
upload_activities = [a for a in activities if a["type"] == "document_uploaded"]
assert len(upload_activities) == 3
def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation):
"""Test activities include annotation overrides."""
service = DashboardActivityService()
session = patched_session
# Create annotation history with override
history = AnnotationHistory(
history_id=uuid4(),
annotation_id=sample_annotation.annotation_id,
document_id=sample_document.document_id,
action="override",
previous_value={"text_value": "OLD-001"},
new_value={"text_value": "NEW-001", "class_name": "invoice_number"},
changed_by="test-admin",
created_at=datetime.now(timezone.utc),
)
session.add(history)
session.commit()
activities = service.get_recent_activities()
override_activities = [a for a in activities if a["type"] == "annotation_modified"]
assert len(override_activities) >= 1
def test_get_recent_activities_training_completed(self, patched_session, sample_training_task):
"""Test activities include training completions."""
service = DashboardActivityService()
session = patched_session
# Update training task to completed
sample_training_task.status = "completed"
sample_training_task.metrics_mAP = 0.85
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
training_activities = [a for a in activities if a["type"] == "training_completed"]
assert len(training_activities) >= 1
assert "mAP" in training_activities[0]["metadata"]
def test_get_recent_activities_training_failed(self, patched_session, sample_training_task):
"""Test activities include training failures."""
service = DashboardActivityService()
session = patched_session
# Update training task to failed
sample_training_task.status = "failed"
sample_training_task.error_message = "CUDA out of memory"
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
failed_activities = [a for a in activities if a["type"] == "training_failed"]
assert len(failed_activities) >= 1
assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory"
def test_get_recent_activities_model_activated(self, patched_session, sample_model_version):
"""Test activities include model activations."""
service = DashboardActivityService()
session = patched_session
# Activate model
sample_model_version.is_active = True
sample_model_version.activated_at = datetime.now(timezone.utc)
session.add(sample_model_version)
session.commit()
activities = service.get_recent_activities()
activation_activities = [a for a in activities if a["type"] == "model_activated"]
assert len(activation_activities) >= 1
assert activation_activities[0]["metadata"]["version"] == sample_model_version.version
def test_get_recent_activities_limit(self, patched_session, admin_token):
"""Test activity limit parameter."""
service = DashboardActivityService()
session = patched_session
# Create many documents
for i in range(20):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"limit_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/limit_doc_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
activities = service.get_recent_activities(limit=5)
assert len(activities) <= 5
def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task):
"""Test activities are sorted by timestamp descending."""
service = DashboardActivityService()
session = patched_session
# Create document
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="sorted_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/sorted_doc.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
# Complete training task
sample_training_task.status = "completed"
sample_training_task.metrics_mAP = 0.90
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
# Verify sorted by timestamp DESC
timestamps = [a["timestamp"] for a in activities]
assert timestamps == sorted(timestamps, reverse=True)