Add more tests

Author: Yaojia Wang
Date: 2026-02-01 22:40:41 +01:00
Parent: a564ac9d70
Commit: 400b12a967

55 changed files with 9306 additions and 267 deletions

@@ -0,0 +1 @@
"""Service integration tests."""

@@ -0,0 +1,497 @@
"""
Dashboard Service Integration Tests
Tests DashboardStatsService and DashboardActivityService with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AnnotationHistory,
ModelVersion,
TrainingDataset,
TrainingTask,
)
from inference.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
is_annotation_complete,
IDENTIFIER_CLASS_IDS,
PAYMENT_CLASS_IDS,
)
class TestIsAnnotationComplete:
"""Tests for is_annotation_complete function."""
def test_complete_with_invoice_number_and_bankgiro(self):
"""Test complete with invoice_number (0) and bankgiro (4)."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 4}, # bankgiro
]
assert is_annotation_complete(annotations) is True
def test_complete_with_ocr_number_and_plusgiro(self):
"""Test complete with ocr_number (3) and plusgiro (5)."""
annotations = [
{"class_id": 3}, # ocr_number
{"class_id": 5}, # plusgiro
]
assert is_annotation_complete(annotations) is True
def test_incomplete_missing_identifier(self):
"""Test incomplete when missing identifier."""
annotations = [
{"class_id": 4}, # bankgiro only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_missing_payment(self):
"""Test incomplete when missing payment."""
annotations = [
{"class_id": 0}, # invoice_number only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_empty_annotations(self):
"""Test incomplete with empty annotations."""
assert is_annotation_complete([]) is False
def test_complete_with_multiple_fields(self):
"""Test complete with multiple fields."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 1}, # invoice_date
{"class_id": 3}, # ocr_number
{"class_id": 4}, # bankgiro
{"class_id": 5}, # plusgiro
{"class_id": 6}, # amount
]
assert is_annotation_complete(annotations) is True
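

# The cases above pin down the contract: completeness requires at least one
# identifier field plus at least one payment field. A minimal sketch of the
# function under test, assuming IDENTIFIER_CLASS_IDS covers {0, 3} and
# PAYMENT_CLASS_IDS covers {4, 5} as the test comments suggest (the sketch
# name is ours, to avoid shadowing the real import):
def _is_annotation_complete_sketch(annotations: list[dict]) -> bool:
    class_ids = {a["class_id"] for a in annotations}
    return (not class_ids.isdisjoint(IDENTIFIER_CLASS_IDS)
            and not class_ids.isdisjoint(PAYMENT_CLASS_IDS))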


class TestDashboardStatsService:
    """Tests for DashboardStatsService."""

    def test_get_stats_empty_database(self, patched_session):
        """Test stats with empty database."""
        service = DashboardStatsService()
        stats = service.get_stats()
        assert stats["total_documents"] == 0
        assert stats["annotation_complete"] == 0
        assert stats["annotation_incomplete"] == 0
        assert stats["pending"] == 0
        assert stats["completeness_rate"] == 0.0

    def test_get_stats_with_documents(self, patched_session, admin_token):
        """Test stats with various document states."""
        service = DashboardStatsService()
        session = patched_session
        # Create documents with different statuses
        docs = []
        for i, status in enumerate(["pending", "auto_labeling", "labeled", "labeled", "exported"]):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/doc_{i}.pdf",
                page_count=1,
                status=status,
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
            docs.append(doc)
        session.commit()
        stats = service.get_stats()
        assert stats["total_documents"] == 5
        assert stats["pending"] == 2  # pending + auto_labeling

    def test_get_stats_complete_annotations(self, patched_session, admin_token):
        """Test completeness calculation with proper annotations."""
        service = DashboardStatsService()
        session = patched_session
        # Create a labeled document with complete annotations
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="complete_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/complete_doc.pdf",
            page_count=1,
            status="labeled",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        session.commit()
        # Add identifier annotation (invoice_number = class_id 0)
        ann1 = AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-001",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(ann1)
        # Add payment annotation (bankgiro = class_id 4)
        ann2 = AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=4,
            class_name="bankgiro",
            x_center=0.5,
            y_center=0.2,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=160,
            bbox_width=160,
            bbox_height=40,
            text_value="123-4567",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(ann2)
        session.commit()
        stats = service.get_stats()
        assert stats["annotation_complete"] == 1
        assert stats["annotation_incomplete"] == 0
        assert stats["completeness_rate"] == 100.0

    def test_get_stats_incomplete_annotations(self, patched_session, admin_token):
        """Test completeness with incomplete annotations."""
        service = DashboardStatsService()
        session = patched_session
        # Create a labeled document missing its payment annotation
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="incomplete_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/incomplete_doc.pdf",
            page_count=1,
            status="labeled",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        session.commit()
        # Add only the identifier annotation (payment missing)
        ann = AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-001",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(ann)
        session.commit()
        stats = service.get_stats()
        assert stats["annotation_complete"] == 0
        assert stats["annotation_incomplete"] == 1
        assert stats["completeness_rate"] == 0.0

    def test_get_stats_mixed_completeness(self, patched_session, admin_token):
        """Test stats with a mix of complete and incomplete documents."""
        service = DashboardStatsService()
        session = patched_session
        # Create 2 labeled documents
        docs = []
        for i in range(2):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"mixed_doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/mixed_doc_{i}.pdf",
                page_count=1,
                status="labeled",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
            docs.append(doc)
        session.commit()
        # First document: complete (has identifier + payment)
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=docs[0].document_id,
            page_number=1,
            class_id=0,  # invoice_number
            class_name="invoice_number",
            x_center=0.5, y_center=0.1, width=0.2, height=0.05,
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=docs[0].document_id,
            page_number=1,
            class_id=4,  # bankgiro
            class_name="bankgiro",
            x_center=0.5, y_center=0.2, width=0.2, height=0.05,
            bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        # Second document: incomplete (missing payment)
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=docs[1].document_id,
            page_number=1,
            class_id=0,  # invoice_number only
            class_name="invoice_number",
            x_center=0.5, y_center=0.1, width=0.2, height=0.05,
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.commit()
        stats = service.get_stats()
        assert stats["annotation_complete"] == 1
        assert stats["annotation_incomplete"] == 1
        assert stats["completeness_rate"] == 50.0
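

# The three completeness tests above pin down the rate formula; a minimal
# sketch, assuming the rate is computed only over labeled/exported documents
# (the helper name is hypothetical):
def _completeness_rate(complete: int, incomplete: int) -> float:
    labeled = complete + incomplete
    return (complete / labeled * 100.0) if labeled else 0.0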


class TestDashboardActivityService:
    """Tests for DashboardActivityService."""

    def test_get_recent_activities_empty(self, patched_session):
        """Test activities with empty database."""
        service = DashboardActivityService()
        activities = service.get_recent_activities()
        assert activities == []

    def test_get_recent_activities_document_uploads(self, patched_session, admin_token):
        """Test activities include document uploads."""
        service = DashboardActivityService()
        session = patched_session
        # Create documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"activity_doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/activity_doc_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()
        activities = service.get_recent_activities()
        upload_activities = [a for a in activities if a["type"] == "document_uploaded"]
        assert len(upload_activities) == 3

    def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation):
        """Test activities include annotation overrides."""
        service = DashboardActivityService()
        session = patched_session
        # Create annotation history with an override
        history = AnnotationHistory(
            history_id=uuid4(),
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="override",
            previous_value={"text_value": "OLD-001"},
            new_value={"text_value": "NEW-001", "class_name": "invoice_number"},
            changed_by="test-admin",
            created_at=datetime.now(timezone.utc),
        )
        session.add(history)
        session.commit()
        activities = service.get_recent_activities()
        override_activities = [a for a in activities if a["type"] == "annotation_modified"]
        assert len(override_activities) >= 1

    def test_get_recent_activities_training_completed(self, patched_session, sample_training_task):
        """Test activities include training completions."""
        service = DashboardActivityService()
        session = patched_session
        # Update training task to completed
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.85
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()
        activities = service.get_recent_activities()
        training_activities = [a for a in activities if a["type"] == "training_completed"]
        assert len(training_activities) >= 1
        assert "mAP" in training_activities[0]["metadata"]

    def test_get_recent_activities_training_failed(self, patched_session, sample_training_task):
        """Test activities include training failures."""
        service = DashboardActivityService()
        session = patched_session
        # Update training task to failed
        sample_training_task.status = "failed"
        sample_training_task.error_message = "CUDA out of memory"
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()
        activities = service.get_recent_activities()
        failed_activities = [a for a in activities if a["type"] == "training_failed"]
        assert len(failed_activities) >= 1
        assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory"

    def test_get_recent_activities_model_activated(self, patched_session, sample_model_version):
        """Test activities include model activations."""
        service = DashboardActivityService()
        session = patched_session
        # Activate model
        sample_model_version.is_active = True
        sample_model_version.activated_at = datetime.now(timezone.utc)
        session.add(sample_model_version)
        session.commit()
        activities = service.get_recent_activities()
        activation_activities = [a for a in activities if a["type"] == "model_activated"]
        assert len(activation_activities) >= 1
        assert activation_activities[0]["metadata"]["version"] == sample_model_version.version

    def test_get_recent_activities_limit(self, patched_session, admin_token):
        """Test activity limit parameter."""
        service = DashboardActivityService()
        session = patched_session
        # Create many documents
        for i in range(20):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"limit_doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/limit_doc_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()
        activities = service.get_recent_activities(limit=5)
        assert len(activities) <= 5

    def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task):
        """Test activities are sorted by timestamp descending."""
        service = DashboardActivityService()
        session = patched_session
        # Create document
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="sorted_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/sorted_doc.pdf",
            page_count=1,
            status="pending",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        # Complete training task
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.90
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()
        activities = service.get_recent_activities()
        # Verify sorted by timestamp DESC
        timestamps = [a["timestamp"] for a in activities]
        assert timestamps == sorted(timestamps, reverse=True)
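

# Taken together, the tests above describe a feed that unions several event
# sources (uploads, annotation overrides, training runs, model activations)
# and sorts by timestamp descending before applying the limit. A minimal
# sketch; the helper name and the default limit are assumptions:
def _merge_activities(*sources: list[dict], limit: int = 10) -> list[dict]:
    merged = [activity for source in sources for activity in source]
    merged.sort(key=lambda a: a["timestamp"], reverse=True)
    return merged[:limit]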

@@ -0,0 +1,453 @@
"""
Dataset Builder Service Integration Tests
Tests DatasetBuilder with real file operations and repository interactions.
"""
import shutil
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4
import pytest
import yaml
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.web.services.dataset_builder import DatasetBuilder
@pytest.fixture
def dataset_builder(patched_session, temp_dataset_dir):
"""Create a DatasetBuilder with real repositories."""
return DatasetBuilder(
datasets_repo=DatasetRepository(),
documents_repo=DocumentRepository(),
annotations_repo=AnnotationRepository(),
base_dir=temp_dataset_dir,
)
@pytest.fixture
def admin_images_dir(temp_upload_dir):
"""Create a directory for admin images."""
images_dir = temp_upload_dir / "admin_images"
images_dir.mkdir(parents=True, exist_ok=True)
return images_dir
@pytest.fixture
def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir):
"""Create documents with annotations and corresponding image files."""
documents = []
doc_repo = DocumentRepository()
ann_repo = AnnotationRepository()
for i in range(5):
# Create document
doc_id = doc_repo.create(
filename=f"invoice_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/invoice_{i}.pdf",
page_count=2,
category="invoice",
group_key=f"group_{i % 2}", # Two groups
)
# Create image files for each page
doc_dir = admin_images_dir / doc_id
doc_dir.mkdir(parents=True, exist_ok=True)
for page in range(1, 3):
image_path = doc_dir / f"page_{page}.png"
# Create a minimal fake PNG
image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
# Create annotations
for j in range(3):
ann_repo.create(
document_id=doc_id,
page_number=1,
class_id=j,
class_name=f"field_{j}",
x_center=0.5,
y_center=0.1 + j * 0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80 + j * 160,
bbox_width=160,
bbox_height=40,
text_value=f"value_{j}",
confidence=0.95,
source="auto",
)
doc = doc_repo.get(doc_id)
documents.append(doc)
return documents


class TestDatasetBuilderBasic:
    """Tests for basic dataset building operations."""

    def test_build_dataset_creates_directory_structure(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that building creates the proper directory structure."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Test Dataset")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        # Check directory structure
        assert (dataset_dir / "images" / "train").exists()
        assert (dataset_dir / "images" / "val").exists()
        assert (dataset_dir / "images" / "test").exists()
        assert (dataset_dir / "labels" / "train").exists()
        assert (dataset_dir / "labels" / "val").exists()
        assert (dataset_dir / "labels" / "test").exists()

    def test_build_dataset_copies_images(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that images are copied to the dataset directory."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Image Copy Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        result = dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        # Count total images across all splits
        total_images = 0
        for split in ["train", "val", "test"]:
            images = list((dataset_dir / "images" / split).glob("*.png"))
            total_images += len(images)
        # 5 docs * 2 pages = 10 images
        assert total_images == 10
        assert result["total_images"] == 10

    def test_build_dataset_generates_labels(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that YOLO label files are generated."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Label Generation Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        # Count total label files
        total_labels = 0
        for split in ["train", "val", "test"]:
            labels = list((dataset_dir / "labels" / split).glob("*.txt"))
            total_labels += len(labels)
        # Same count as images
        assert total_labels == 10

    def test_build_dataset_generates_data_yaml(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that data.yaml is generated correctly."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="YAML Generation Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        yaml_path = dataset_dir / "data.yaml"
        assert yaml_path.exists()
        with open(yaml_path) as f:
            data = yaml.safe_load(f)
        assert data["train"] == "images/train"
        assert data["val"] == "images/val"
        assert data["test"] == "images/test"
        assert "nc" in data
        assert "names" in data


class TestDatasetBuilderSplits:
    """Tests for train/val/test split assignment."""

    def test_split_ratio_respected(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that split ratios are approximately respected."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Split Ratio Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.6,
            val_ratio=0.2,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        # Check document assignments in the database
        dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))
        splits = {"train": 0, "val": 0, "test": 0}
        for doc in dataset_docs:
            splits[doc.split] += 1
        # With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1.
        # Due to rounding and group constraints, allow some variation.
        assert splits["train"] >= 2
        assert splits["val"] >= 1 or splits["test"] >= 1

    def test_same_seed_same_split(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that the same seed produces the same split."""
        dataset_repo = DatasetRepository()
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        # Build first dataset
        dataset1 = dataset_repo.create(name="Seed Test 1")
        dataset_builder.build_dataset(
            dataset_id=str(dataset1.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )
        # Build second dataset with the same seed
        dataset2 = dataset_repo.create(name="Seed Test 2")
        dataset_builder.build_dataset(
            dataset_id=str(dataset2.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )
        # Compare splits
        docs1 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset1.dataset_id))}
        docs2 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset2.dataset_id))}
        assert docs1 == docs2
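

# The determinism test above suggests the builder seeds its shuffle, and the
# group_key fixture suggests documents sharing a group must land in the same
# split. A minimal group-aware split sketch (the helper name is hypothetical):
def _split_by_group(doc_groups: dict[str, list[str]], train_ratio: float,
                    val_ratio: float, seed: int) -> dict[str, str]:
    import random
    keys = sorted(doc_groups)  # stable order before shuffling
    random.Random(seed).shuffle(keys)  # same seed -> same permutation
    n_train = round(len(keys) * train_ratio)
    n_val = round(len(keys) * val_ratio)
    assignment: dict[str, str] = {}
    for idx, key in enumerate(keys):
        split = "train" if idx < n_train else ("val" if idx < n_train + n_val else "test")
        for doc_id in doc_groups[key]:
            assignment[doc_id] = split
    return assignment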


class TestDatasetBuilderDatabase:
    """Tests for database interactions."""

    def test_updates_dataset_status(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that dataset status is updated after a build."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Status Update Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        updated = dataset_repo.get(str(dataset.dataset_id))
        assert updated.status == "ready"
        assert updated.total_documents == 5
        assert updated.total_images == 10
        assert updated.total_annotations > 0
        assert updated.dataset_path is not None

    def test_records_document_assignments(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that document assignments are recorded in the database."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Assignment Recording Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))
        assert len(dataset_docs) == 5
        for doc in dataset_docs:
            assert doc.split in ["train", "val", "test"]
            assert doc.page_count > 0


class TestDatasetBuilderErrors:
    """Tests for error handling."""

    def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, patched_session):
        """Test that building fails with an empty document list."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Empty Docs Test")
        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session):
        """Test that building fails with nonexistent document IDs."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Invalid IDs Test")
        fake_ids = [str(uuid4()) for _ in range(3)]
        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=fake_ids,
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session):
        """Test that dataset status is set to failed on error."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Failure Status Test")
        try:
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )
        except ValueError:
            pass
        updated = dataset_repo.get(str(dataset.dataset_id))
        assert updated.status == "failed"
        assert updated.error_message is not None


class TestLabelFileFormat:
    """Tests for YOLO label file format."""

    def test_label_file_format(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that label files are in the correct YOLO format."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Label Format Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        # Find a label file with content
        label_files = []
        for split in ["train", "val", "test"]:
            label_files.extend(list((dataset_dir / "labels" / split).glob("*.txt")))
        # Check that at least one label file has the correct format
        found_valid_label = False
        for label_file in label_files:
            content = label_file.read_text().strip()
            if content:
                lines = content.split("\n")
                for line in lines:
                    parts = line.split()
                    assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}: {line}"
                    class_id = int(parts[0])
                    x_center = float(parts[1])
                    y_center = float(parts[2])
                    width = float(parts[3])
                    height = float(parts[4])
                    assert 0 <= class_id < 10
                    assert 0 <= x_center <= 1
                    assert 0 <= y_center <= 1
                    assert 0 <= width <= 1
                    assert 0 <= height <= 1
                found_valid_label = True
                break
        assert found_valid_label, "No valid label files found"
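

# For reference, the format asserted above is the standard YOLO detection
# label: one "class_id x_center y_center width height" line per box, all
# coordinates normalized to [0, 1]. A minimal writer sketch (the helper name
# is hypothetical):
def _yolo_label_line(class_id: int, x_center: float, y_center: float,
                     width: float, height: float) -> str:
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"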

@@ -0,0 +1,283 @@
"""
Document Service Integration Tests
Tests DocumentService with real storage operations.
"""
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from inference.web.services.document_service import DocumentService, DocumentResult
class MockStorageBackend:
"""Simple in-memory storage backend for testing."""
def __init__(self):
self._files: dict[str, bytes] = {}
def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None:
if not overwrite and remote_path in self._files:
raise FileExistsError(f"File already exists: {remote_path}")
self._files[remote_path] = content
def download_bytes(self, remote_path: str) -> bytes:
if remote_path not in self._files:
raise FileNotFoundError(f"File not found: {remote_path}")
return self._files[remote_path]
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
return f"https://storage.example.com/{remote_path}?expires={expires_in_seconds}"
def exists(self, remote_path: str) -> bool:
return remote_path in self._files
def delete(self, remote_path: str) -> bool:
if remote_path in self._files:
del self._files[remote_path]
return True
return False
def list_files(self, prefix: str) -> list[str]:
return [path for path in self._files.keys() if path.startswith(prefix)]
@pytest.fixture
def mock_storage():
"""Create a mock storage backend."""
return MockStorageBackend()
@pytest.fixture
def document_service(mock_storage):
"""Create a DocumentService with mock storage."""
return DocumentService(storage_backend=mock_storage)
class TestDocumentUpload:
"""Tests for document upload operations."""
def test_upload_document(self, document_service):
"""Test uploading a document."""
content = b"%PDF-1.4 test content"
filename = "test_invoice.pdf"
result = document_service.upload_document(content, filename)
assert result is not None
assert result.id is not None
assert result.filename == filename
assert result.file_path.startswith("documents/")
assert result.file_path.endswith(".pdf")
def test_upload_document_with_custom_id(self, document_service):
"""Test uploading with custom document ID."""
content = b"%PDF-1.4 test content"
filename = "invoice.pdf"
custom_id = "custom-doc-12345"
result = document_service.upload_document(
content, filename, document_id=custom_id
)
assert result.id == custom_id
assert custom_id in result.file_path
def test_upload_preserves_extension(self, document_service):
"""Test that file extension is preserved."""
cases = [
("document.pdf", ".pdf"),
("image.PNG", ".png"),
("file.JPEG", ".jpeg"),
("noextension", ""),
]
for filename, expected_ext in cases:
result = document_service.upload_document(b"content", filename)
if expected_ext:
assert result.file_path.endswith(expected_ext)
def test_upload_document_overwrite(self, document_service, mock_storage):
"""Test that upload overwrites existing file."""
content1 = b"original content"
content2 = b"new content"
doc_id = "overwrite-test"
document_service.upload_document(content1, "doc.pdf", document_id=doc_id)
document_service.upload_document(content2, "doc.pdf", document_id=doc_id)
# Should have new content
remote_path = f"documents/{doc_id}.pdf"
stored_content = mock_storage.download_bytes(remote_path)
assert stored_content == content2
class TestDocumentDownload:
"""Tests for document download operations."""
def test_download_document(self, document_service, mock_storage):
"""Test downloading a document."""
content = b"test document content"
remote_path = "documents/test-doc.pdf"
mock_storage.upload_bytes(content, remote_path)
downloaded = document_service.download_document(remote_path)
assert downloaded == content
def test_download_nonexistent_document(self, document_service):
"""Test downloading document that doesn't exist."""
with pytest.raises(FileNotFoundError):
document_service.download_document("documents/nonexistent.pdf")
class TestDocumentUrl:
"""Tests for document URL generation."""
def test_get_document_url(self, document_service, mock_storage):
"""Test getting presigned URL for document."""
remote_path = "documents/test-doc.pdf"
mock_storage.upload_bytes(b"content", remote_path)
url = document_service.get_document_url(remote_path, expires_in_seconds=7200)
assert url.startswith("https://")
assert remote_path in url
assert "7200" in url
def test_get_document_url_default_expiry(self, document_service):
"""Test default URL expiry."""
url = document_service.get_document_url("documents/doc.pdf")
assert "3600" in url
class TestDocumentExists:
"""Tests for document existence check."""
def test_document_exists(self, document_service, mock_storage):
"""Test checking if document exists."""
remote_path = "documents/existing.pdf"
mock_storage.upload_bytes(b"content", remote_path)
assert document_service.document_exists(remote_path) is True
def test_document_not_exists(self, document_service):
"""Test checking if nonexistent document exists."""
assert document_service.document_exists("documents/nonexistent.pdf") is False
class TestDocumentDelete:
"""Tests for document deletion."""
def test_delete_document(self, document_service, mock_storage):
"""Test deleting a document."""
remote_path = "documents/to-delete.pdf"
mock_storage.upload_bytes(b"content", remote_path)
result = document_service.delete_document_files(remote_path)
assert result is True
assert document_service.document_exists(remote_path) is False
def test_delete_nonexistent_document(self, document_service):
"""Test deleting document that doesn't exist."""
result = document_service.delete_document_files("documents/nonexistent.pdf")
assert result is False
class TestPageImages:
"""Tests for page image operations."""
def test_save_page_image(self, document_service, mock_storage):
"""Test saving a page image."""
doc_id = "test-doc-123"
page_num = 1
image_content = b"\x89PNG\r\n\x1a\n fake png"
remote_path = document_service.save_page_image(doc_id, page_num, image_content)
assert remote_path == f"images/{doc_id}/page_{page_num}.png"
assert mock_storage.exists(remote_path)
def test_save_multiple_page_images(self, document_service, mock_storage):
"""Test saving images for multiple pages."""
doc_id = "multi-page-doc"
for page_num in range(1, 4):
content = f"page {page_num} content".encode()
document_service.save_page_image(doc_id, page_num, content)
images = document_service.list_document_images(doc_id)
assert len(images) == 3
def test_get_page_image(self, document_service, mock_storage):
"""Test downloading a page image."""
doc_id = "test-doc"
page_num = 2
image_content = b"image data"
document_service.save_page_image(doc_id, page_num, image_content)
downloaded = document_service.get_page_image(doc_id, page_num)
assert downloaded == image_content
def test_get_page_image_url(self, document_service):
"""Test getting URL for page image."""
doc_id = "test-doc"
page_num = 1
url = document_service.get_page_image_url(doc_id, page_num)
assert f"images/{doc_id}/page_{page_num}.png" in url
def test_list_document_images(self, document_service, mock_storage):
"""Test listing all images for a document."""
doc_id = "list-test-doc"
for i in range(5):
document_service.save_page_image(doc_id, i + 1, f"page {i}".encode())
images = document_service.list_document_images(doc_id)
assert len(images) == 5
def test_delete_document_images(self, document_service, mock_storage):
"""Test deleting all images for a document."""
doc_id = "delete-images-doc"
for i in range(3):
document_service.save_page_image(doc_id, i + 1, b"content")
deleted_count = document_service.delete_document_images(doc_id)
assert deleted_count == 3
assert len(document_service.list_document_images(doc_id)) == 0
class TestRoundTrip:
"""Tests for complete upload-download cycles."""
def test_document_round_trip(self, document_service):
"""Test uploading and downloading document."""
original_content = b"%PDF-1.4 complete document content here"
filename = "roundtrip.pdf"
result = document_service.upload_document(original_content, filename)
downloaded = document_service.download_document(result.file_path)
assert downloaded == original_content
def test_image_round_trip(self, document_service):
"""Test saving and retrieving page image."""
doc_id = "roundtrip-doc"
page_num = 1
original_image = b"\x89PNG fake image data"
document_service.save_page_image(doc_id, page_num, original_image)
retrieved = document_service.get_page_image(doc_id, page_num)
assert retrieved == original_image
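

# The paths asserted throughout this file imply a flat storage layout; a
# sketch of the conventions the service appears to use (the helper names are
# hypothetical):
def _document_path(document_id: str, filename: str) -> str:
    ext = Path(filename).suffix.lower()  # extension kept but lowercased ("image.PNG" -> ".png")
    return f"documents/{document_id}{ext}"


def _page_image_path(document_id: str, page_number: int) -> str:
    return f"images/{document_id}/page_{page_number}.png"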