498 lines
17 KiB
Python
498 lines
17 KiB
Python
"""
|
|
Dashboard Service Integration Tests
|
|
|
|
Tests DashboardStatsService and DashboardActivityService with real database operations.
|
|
"""
|
|
|
|
from datetime import datetime, timezone
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from inference.data.admin_models import (
|
|
AdminAnnotation,
|
|
AdminDocument,
|
|
AnnotationHistory,
|
|
ModelVersion,
|
|
TrainingDataset,
|
|
TrainingTask,
|
|
)
|
|
from inference.web.services.dashboard_service import (
|
|
DashboardStatsService,
|
|
DashboardActivityService,
|
|
is_annotation_complete,
|
|
IDENTIFIER_CLASS_IDS,
|
|
PAYMENT_CLASS_IDS,
|
|
)
|
|
|
|
|
|
class TestIsAnnotationComplete:
    """Tests for the is_annotation_complete function.

    The cases below treat a document as complete only when its annotations
    include both an identifier field and a payment field. The hard-coded
    class ids are commented with the field they stand for. The last three
    tests derive their inputs from IDENTIFIER_CLASS_IDS / PAYMENT_CLASS_IDS,
    so they keep working if those constants change.
    """

    def test_complete_with_invoice_number_and_bankgiro(self):
        """Test complete with invoice_number (0) and bankgiro (4)."""
        annotations = [
            {"class_id": 0},  # invoice_number
            {"class_id": 4},  # bankgiro
        ]
        assert is_annotation_complete(annotations) is True

    def test_complete_with_ocr_number_and_plusgiro(self):
        """Test complete with ocr_number (3) and plusgiro (5)."""
        annotations = [
            {"class_id": 3},  # ocr_number
            {"class_id": 5},  # plusgiro
        ]
        assert is_annotation_complete(annotations) is True

    def test_incomplete_missing_identifier(self):
        """Test incomplete when missing identifier."""
        annotations = [
            {"class_id": 4},  # bankgiro only
        ]
        assert is_annotation_complete(annotations) is False

    def test_incomplete_missing_payment(self):
        """Test incomplete when missing payment."""
        annotations = [
            {"class_id": 0},  # invoice_number only
        ]
        assert is_annotation_complete(annotations) is False

    def test_incomplete_empty_annotations(self):
        """Test incomplete with empty annotations."""
        assert is_annotation_complete([]) is False

    def test_complete_with_multiple_fields(self):
        """Test complete with multiple fields."""
        annotations = [
            {"class_id": 0},  # invoice_number
            {"class_id": 1},  # invoice_date
            {"class_id": 3},  # ocr_number
            {"class_id": 4},  # bankgiro
            {"class_id": 5},  # plusgiro
            {"class_id": 6},  # amount
        ]
        assert is_annotation_complete(annotations) is True

    def test_incomplete_with_every_identifier_but_no_payment(self):
        """Any number of identifier fields without a payment field is incomplete."""
        # NOTE(review): assumes no class id is both an identifier and a payment id.
        annotations = [{"class_id": class_id} for class_id in IDENTIFIER_CLASS_IDS]
        assert annotations, "IDENTIFIER_CLASS_IDS is unexpectedly empty"
        assert is_annotation_complete(annotations) is False

    def test_incomplete_with_every_payment_but_no_identifier(self):
        """Any number of payment fields without an identifier field is incomplete."""
        # NOTE(review): assumes no class id is both an identifier and a payment id.
        annotations = [{"class_id": class_id} for class_id in PAYMENT_CLASS_IDS]
        assert annotations, "PAYMENT_CLASS_IDS is unexpectedly empty"
        assert is_annotation_complete(annotations) is False

    def test_complete_with_one_id_from_each_constant(self):
        """One identifier id plus one payment id, taken from the constants, is complete."""
        identifier_id = next(iter(IDENTIFIER_CLASS_IDS))
        payment_id = next(iter(PAYMENT_CLASS_IDS))
        annotations = [{"class_id": identifier_id}, {"class_id": payment_id}]
        assert is_annotation_complete(annotations) is True
|
|
|
|
|
|
class TestDashboardStatsService:
    """Tests for DashboardStatsService.get_stats backed by real database rows."""

    @staticmethod
    def _make_document(admin_token, filename, status):
        """Build (but do not add/commit) an AdminDocument with the given status.

        Every other column uses the fixed values shared by all tests in this
        class. The file path is derived from *filename* under ``/uploads/``.
        """
        now = datetime.now(timezone.utc)
        return AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename=filename,
            file_size=1024,
            content_type="application/pdf",
            file_path=f"/uploads/{filename}",
            page_count=1,
            status=status,
            upload_source="ui",
            category="invoice",
            created_at=now,
            updated_at=now,
        )

    @staticmethod
    def _make_annotation(document_id, class_id, class_name, *, y_center, bbox_y, **extra):
        """Build (but do not add/commit) a page-1 AdminAnnotation for *document_id*.

        Only the vertical position varies between callers; the remaining
        geometry is fixed. *extra* is forwarded to the model (e.g.
        ``text_value``), so columns a test does not set keep the model default.
        """
        now = datetime.now(timezone.utc)
        return AdminAnnotation(
            annotation_id=uuid4(),
            document_id=document_id,
            page_number=1,
            class_id=class_id,
            class_name=class_name,
            x_center=0.5,
            y_center=y_center,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=bbox_y,
            bbox_width=160,
            bbox_height=40,
            created_at=now,
            updated_at=now,
            **extra,
        )

    def test_get_stats_empty_database(self, patched_session):
        """Test stats with empty database."""
        service = DashboardStatsService()

        stats = service.get_stats()

        assert stats["total_documents"] == 0
        assert stats["annotation_complete"] == 0
        assert stats["annotation_incomplete"] == 0
        assert stats["pending"] == 0
        assert stats["completeness_rate"] == 0.0

    def test_get_stats_with_documents(self, patched_session, admin_token):
        """Test stats with various document states."""
        service = DashboardStatsService()
        session = patched_session

        # One document per status; "labeled" appears twice.
        statuses = ["pending", "auto_labeling", "labeled", "labeled", "exported"]
        for i, status in enumerate(statuses):
            session.add(self._make_document(admin_token, f"doc_{i}.pdf", status))
        session.commit()

        stats = service.get_stats()

        assert stats["total_documents"] == 5
        assert stats["pending"] == 2  # pending + auto_labeling

    def test_get_stats_complete_annotations(self, patched_session, admin_token):
        """Test completeness calculation with proper annotations."""
        service = DashboardStatsService()
        session = patched_session

        # Create a labeled document with complete annotations
        doc = self._make_document(admin_token, "complete_doc.pdf", "labeled")
        session.add(doc)
        session.commit()

        # Identifier annotation (invoice_number = class_id 0)
        session.add(self._make_annotation(
            doc.document_id, 0, "invoice_number",
            y_center=0.1, bbox_y=80, text_value="INV-001",
        ))
        # Payment annotation (bankgiro = class_id 4)
        session.add(self._make_annotation(
            doc.document_id, 4, "bankgiro",
            y_center=0.2, bbox_y=160, text_value="123-4567",
        ))
        session.commit()

        stats = service.get_stats()

        assert stats["annotation_complete"] == 1
        assert stats["annotation_incomplete"] == 0
        assert stats["completeness_rate"] == 100.0

    def test_get_stats_incomplete_annotations(self, patched_session, admin_token):
        """Test completeness with incomplete annotations."""
        service = DashboardStatsService()
        session = patched_session

        # Create a labeled document missing payment annotation
        doc = self._make_document(admin_token, "incomplete_doc.pdf", "labeled")
        session.add(doc)
        session.commit()

        # Add only identifier annotation (missing payment)
        session.add(self._make_annotation(
            doc.document_id, 0, "invoice_number",
            y_center=0.1, bbox_y=80, text_value="INV-001",
        ))
        session.commit()

        stats = service.get_stats()

        assert stats["annotation_complete"] == 0
        assert stats["annotation_incomplete"] == 1
        assert stats["completeness_rate"] == 0.0

    def test_get_stats_mixed_completeness(self, patched_session, admin_token):
        """Test stats with mix of complete and incomplete documents."""
        service = DashboardStatsService()
        session = patched_session

        # Create 2 labeled documents
        docs = [
            self._make_document(admin_token, f"mixed_doc_{i}.pdf", "labeled")
            for i in range(2)
        ]
        for doc in docs:
            session.add(doc)
        session.commit()

        # First document: complete (has identifier + payment)
        session.add(self._make_annotation(
            docs[0].document_id, 0, "invoice_number", y_center=0.1, bbox_y=80,
        ))
        session.add(self._make_annotation(
            docs[0].document_id, 4, "bankgiro", y_center=0.2, bbox_y=160,
        ))

        # Second document: incomplete (invoice_number only, missing payment)
        session.add(self._make_annotation(
            docs[1].document_id, 0, "invoice_number", y_center=0.1, bbox_y=80,
        ))
        session.commit()

        stats = service.get_stats()

        assert stats["annotation_complete"] == 1
        assert stats["annotation_incomplete"] == 1
        assert stats["completeness_rate"] == 50.0
|
|
|
|
|
|
class TestDashboardActivityService:
    """Tests for DashboardActivityService.get_recent_activities."""

    @staticmethod
    def _make_document(admin_token, filename, status="pending"):
        """Build (but do not add/commit) an uploaded AdminDocument.

        Every column except *filename* and *status* uses the fixed values
        shared by the tests in this class. The file path is derived from
        *filename* under ``/uploads/``.
        """
        now = datetime.now(timezone.utc)
        return AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename=filename,
            file_size=1024,
            content_type="application/pdf",
            file_path=f"/uploads/{filename}",
            page_count=1,
            status=status,
            upload_source="ui",
            category="invoice",
            created_at=now,
            updated_at=now,
        )

    def test_get_recent_activities_empty(self, patched_session):
        """Test activities with empty database."""
        service = DashboardActivityService()

        activities = service.get_recent_activities()

        assert activities == []

    def test_get_recent_activities_document_uploads(self, patched_session, admin_token):
        """Test activities include document uploads."""
        service = DashboardActivityService()
        session = patched_session

        for i in range(3):
            session.add(self._make_document(admin_token, f"activity_doc_{i}.pdf"))
        session.commit()

        activities = service.get_recent_activities()

        upload_activities = [a for a in activities if a["type"] == "document_uploaded"]
        assert len(upload_activities) == 3

    def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation):
        """Test activities include annotation overrides."""
        service = DashboardActivityService()
        session = patched_session

        # Create annotation history with override
        history = AnnotationHistory(
            history_id=uuid4(),
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="override",
            previous_value={"text_value": "OLD-001"},
            new_value={"text_value": "NEW-001", "class_name": "invoice_number"},
            changed_by="test-admin",
            created_at=datetime.now(timezone.utc),
        )
        session.add(history)
        session.commit()

        activities = service.get_recent_activities()

        override_activities = [a for a in activities if a["type"] == "annotation_modified"]
        assert len(override_activities) >= 1

    def test_get_recent_activities_training_completed(self, patched_session, sample_training_task):
        """Test activities include training completions."""
        service = DashboardActivityService()
        session = patched_session

        # Update training task to completed
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.85
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        training_activities = [a for a in activities if a["type"] == "training_completed"]
        assert len(training_activities) >= 1
        assert "mAP" in training_activities[0]["metadata"]

    def test_get_recent_activities_training_failed(self, patched_session, sample_training_task):
        """Test activities include training failures."""
        service = DashboardActivityService()
        session = patched_session

        # Update training task to failed
        sample_training_task.status = "failed"
        sample_training_task.error_message = "CUDA out of memory"
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        failed_activities = [a for a in activities if a["type"] == "training_failed"]
        assert len(failed_activities) >= 1
        assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory"

    def test_get_recent_activities_model_activated(self, patched_session, sample_model_version):
        """Test activities include model activations."""
        service = DashboardActivityService()
        session = patched_session

        # Activate model
        sample_model_version.is_active = True
        sample_model_version.activated_at = datetime.now(timezone.utc)
        session.add(sample_model_version)
        session.commit()

        activities = service.get_recent_activities()

        activation_activities = [a for a in activities if a["type"] == "model_activated"]
        assert len(activation_activities) >= 1
        assert activation_activities[0]["metadata"]["version"] == sample_model_version.version

    def test_get_recent_activities_limit(self, patched_session, admin_token):
        """Test activity limit parameter."""
        service = DashboardActivityService()
        session = patched_session

        # Create more uploads than the requested limit
        for i in range(20):
            session.add(self._make_document(admin_token, f"limit_doc_{i}.pdf"))
        session.commit()

        activities = service.get_recent_activities(limit=5)

        # 20 upload activities exceed the limit, so the result must be
        # truncated to exactly 5. A "<= 5" check would also pass if the
        # service returned nothing.
        assert len(activities) == 5

    def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task):
        """Test activities are sorted by timestamp descending."""
        service = DashboardActivityService()
        session = patched_session

        # Create document
        session.add(self._make_document(admin_token, "sorted_doc.pdf"))

        # Complete training task
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.90
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        # The upload and the training completion should both be present;
        # otherwise the ordering check below is vacuous.
        assert len(activities) >= 2

        # Verify sorted by timestamp DESC
        timestamps = [a["timestamp"] for a in activities]
        assert timestamps == sorted(timestamps, reverse=True)
|