Files
invoice-master-poc-v2/tests/integration/services/test_dataset_builder_integration.py
Yaojia Wang b602d0a340 re-structure
2026-02-01 22:55:31 +01:00

454 lines
16 KiB
Python

"""
Dataset Builder Service Integration Tests
Tests DatasetBuilder with real file operations and repository interactions.
"""
import shutil
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4
import pytest
import yaml
from backend.data.admin_models import AdminAnnotation, AdminDocument
from backend.data.repositories.annotation_repository import AnnotationRepository
from backend.data.repositories.dataset_repository import DatasetRepository
from backend.data.repositories.document_repository import DocumentRepository
from backend.web.services.dataset_builder import DatasetBuilder
@pytest.fixture
def dataset_builder(patched_session, temp_dataset_dir):
    """Provide a DatasetBuilder wired to real repositories.

    Uses the patched DB session fixture and a temporary directory as the
    dataset output root.
    """
    datasets = DatasetRepository()
    documents = DocumentRepository()
    annotations = AnnotationRepository()
    return DatasetBuilder(
        datasets_repo=datasets,
        documents_repo=documents,
        annotations_repo=annotations,
        base_dir=temp_dataset_dir,
    )
@pytest.fixture
def admin_images_dir(temp_upload_dir):
    """Return a fresh ``admin_images`` directory under the temp upload root."""
    path = temp_upload_dir / "admin_images"
    path.mkdir(parents=True, exist_ok=True)
    return path
@pytest.fixture
def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir):
    """Create five 2-page documents, their page images, and three annotations each.

    Documents alternate between two group keys so group-aware split logic
    can be exercised. Returns the refreshed document records.
    """
    doc_repo = DocumentRepository()
    ann_repo = AnnotationRepository()
    # Minimal bytes with a PNG magic header — enough to be copied as an image.
    fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
    created = []
    for idx in range(5):
        doc_id = doc_repo.create(
            filename=f"invoice_{idx}.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path=f"/uploads/invoice_{idx}.pdf",
            page_count=2,
            category="invoice",
            group_key=f"group_{idx % 2}",  # alternate between two groups
        )
        # One fake PNG per page (pages are 1-based).
        page_dir = admin_images_dir / doc_id
        page_dir.mkdir(parents=True, exist_ok=True)
        for page_no in (1, 2):
            (page_dir / f"page_{page_no}.png").write_bytes(fake_png)
        # Three auto-sourced annotations on page 1, stacked vertically.
        for field_idx in range(3):
            ann_repo.create(
                document_id=doc_id,
                page_number=1,
                class_id=field_idx,
                class_name=f"field_{field_idx}",
                x_center=0.5,
                y_center=0.1 + field_idx * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + field_idx * 160,
                bbox_width=160,
                bbox_height=40,
                text_value=f"value_{field_idx}",
                confidence=0.95,
                source="auto",
            )
        created.append(doc_repo.get(doc_id))
    return created
class TestDatasetBuilderBasic:
    """Tests for basic dataset building operations."""

    @staticmethod
    def _build(dataset_builder, documents, admin_images_dir, name):
        """Create a dataset named *name* and build it with the 0.8/0.1 split, seed 42.

        Returns the created dataset record and the build result dict.
        """
        repo = DatasetRepository()
        dataset = repo.create(name=name)
        result = dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        return dataset, result

    def test_build_dataset_creates_directory_structure(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that building creates proper directory structure."""
        dataset, _ = self._build(
            dataset_builder, documents_with_annotations, admin_images_dir, "Test Dataset"
        )
        root = temp_dataset_dir / str(dataset.dataset_id)
        # Every images/<split> and labels/<split> directory must exist.
        for kind in ("images", "labels"):
            for split in ("train", "val", "test"):
                assert (root / kind / split).exists()

    def test_build_dataset_copies_images(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that images are copied to dataset directory."""
        dataset, result = self._build(
            dataset_builder, documents_with_annotations, admin_images_dir, "Image Copy Test"
        )
        root = temp_dataset_dir / str(dataset.dataset_id)
        copied = sum(
            len(list((root / "images" / split).glob("*.png")))
            for split in ("train", "val", "test")
        )
        # 5 docs * 2 pages = 10 images
        assert copied == 10
        assert result["total_images"] == 10

    def test_build_dataset_generates_labels(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that YOLO label files are generated."""
        dataset, _ = self._build(
            dataset_builder, documents_with_annotations, admin_images_dir, "Label Generation Test"
        )
        root = temp_dataset_dir / str(dataset.dataset_id)
        generated = sum(
            len(list((root / "labels" / split).glob("*.txt")))
            for split in ("train", "val", "test")
        )
        # One label file per image.
        assert generated == 10

    def test_build_dataset_generates_data_yaml(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that data.yaml is generated correctly."""
        dataset, _ = self._build(
            dataset_builder, documents_with_annotations, admin_images_dir, "YAML Generation Test"
        )
        yaml_path = temp_dataset_dir / str(dataset.dataset_id) / "data.yaml"
        assert yaml_path.exists()
        with open(yaml_path) as fh:
            data = yaml.safe_load(fh)
        assert data["train"] == "images/train"
        assert data["val"] == "images/val"
        assert data["test"] == "images/test"
        assert "nc" in data
        assert "names" in data
class TestDatasetBuilderSplits:
    """Tests for train/val/test split assignment."""

    def test_split_ratio_respected(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that split ratios are approximately respected."""
        repo = DatasetRepository()
        dataset = repo.create(name="Split Ratio Test")
        ids = [str(d.document_id) for d in documents_with_annotations]
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=ids,
            train_ratio=0.6,
            val_ratio=0.2,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        # Tally the split assignments persisted to the database.
        counts = {"train": 0, "val": 0, "test": 0}
        for assigned in repo.get_documents(str(dataset.dataset_id)):
            counts[assigned.split] += 1
        # With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1.
        # Rounding and group constraints allow some variation.
        assert counts["train"] >= 2
        assert counts["val"] >= 1 or counts["test"] >= 1

    def test_same_seed_same_split(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that same seed produces same split."""
        repo = DatasetRepository()
        ids = [str(d.document_id) for d in documents_with_annotations]

        def build_and_collect(name):
            # Build a dataset and return {document_id: split}.
            ds = repo.create(name=name)
            dataset_builder.build_dataset(
                dataset_id=str(ds.dataset_id),
                document_ids=ids,
                train_ratio=0.8,
                val_ratio=0.1,
                seed=12345,
                admin_images_dir=admin_images_dir,
            )
            return {
                str(d.document_id): d.split
                for d in repo.get_documents(str(ds.dataset_id))
            }

        # Two builds with identical inputs and seed must agree exactly.
        assert build_and_collect("Seed Test 1") == build_and_collect("Seed Test 2")
class TestDatasetBuilderDatabase:
    """Tests for database interactions."""

    @staticmethod
    def _build_default(dataset_builder, documents, admin_images_dir, name):
        """Build a dataset from *documents* with the 0.8/0.1 split, seed 42."""
        repo = DatasetRepository()
        dataset = repo.create(name=name)
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        return repo, dataset

    def test_updates_dataset_status(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that dataset status is updated after build."""
        repo, dataset = self._build_default(
            dataset_builder, documents_with_annotations, admin_images_dir, "Status Update Test"
        )
        refreshed = repo.get(str(dataset.dataset_id))
        assert refreshed.status == "ready"
        assert refreshed.total_documents == 5
        assert refreshed.total_images == 10
        assert refreshed.total_annotations > 0
        assert refreshed.dataset_path is not None

    def test_records_document_assignments(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that document assignments are recorded in database."""
        repo, dataset = self._build_default(
            dataset_builder, documents_with_annotations, admin_images_dir, "Assignment Recording Test"
        )
        assigned = repo.get_documents(str(dataset.dataset_id))
        assert len(assigned) == 5
        for entry in assigned:
            assert entry.split in ("train", "val", "test")
            assert entry.page_count > 0
class TestDatasetBuilderErrors:
    """Tests for error handling."""

    def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, patched_session):
        """Building with an empty document list raises ValueError."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Empty Docs Test")
        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session):
        """Building with nonexistent document IDs raises ValueError."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Invalid IDs Test")
        fake_ids = [str(uuid4()) for _ in range(3)]
        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=fake_ids,
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session):
        """A failed build sets status to 'failed' and records an error message."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Failure Status Test")
        # Use pytest.raises instead of try/except-pass: the original swallowed
        # the error and would also "pass" this step if the build unexpectedly
        # succeeded; this form requires the ValueError to actually be raised,
        # consistent with the other tests in this class.
        with pytest.raises(ValueError):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )
        updated = dataset_repo.get(str(dataset.dataset_id))
        assert updated.status == "failed"
        assert updated.error_message is not None
class TestLabelFileFormat:
    """Tests for YOLO label file format."""

    def test_label_file_format(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that label files are in correct YOLO format."""
        repo = DatasetRepository()
        dataset = repo.create(name="Label Format Test")
        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in documents_with_annotations],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )
        root = temp_dataset_dir / str(dataset.dataset_id)
        # Gather all label files across every split.
        label_files = [
            path
            for split in ("train", "val", "test")
            for path in (root / "labels" / split).glob("*.txt")
        ]
        # Validate the first non-empty label file found: each row must be
        # "<class_id> <x_center> <y_center> <width> <height>" with normalized
        # coordinates in [0, 1].
        validated = False
        for path in label_files:
            text = path.read_text().strip()
            if not text:
                continue
            for row in text.split("\n"):
                fields = row.split()
                assert len(fields) == 5, f"Expected 5 parts, got {len(fields)}: {row}"
                assert 0 <= int(fields[0]) < 10
                for fraction in map(float, fields[1:]):
                    assert 0 <= fraction <= 1
            validated = True
            break
        assert validated, "No valid label files found"