Add more tests

tests/integration/services/test_dataset_builder_integration.py | 453 (new file)
@@ -0,0 +1,453 @@
"""
Dataset Builder Service Integration Tests

Tests DatasetBuilder with real file operations and repository interactions.
"""

from uuid import uuid4

import pytest
import yaml

from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.web.services.dataset_builder import DatasetBuilder


@pytest.fixture
def dataset_builder(patched_session, temp_dataset_dir):
    """Create a DatasetBuilder with real repositories."""
    return DatasetBuilder(
        datasets_repo=DatasetRepository(),
        documents_repo=DocumentRepository(),
        annotations_repo=AnnotationRepository(),
        base_dir=temp_dataset_dir,
    )


@pytest.fixture
def admin_images_dir(temp_upload_dir):
    """Create a directory for admin images."""
    images_dir = temp_upload_dir / "admin_images"
    images_dir.mkdir(parents=True, exist_ok=True)
    return images_dir


@pytest.fixture
def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir):
    """Create documents with annotations and corresponding image files."""
    documents = []
    doc_repo = DocumentRepository()
    ann_repo = AnnotationRepository()

    for i in range(5):
        # Create document
        doc_id = doc_repo.create(
            filename=f"invoice_{i}.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path=f"/uploads/invoice_{i}.pdf",
            page_count=2,
            category="invoice",
            group_key=f"group_{i % 2}",  # Two groups
        )

        # Create image files for each page
        doc_dir = admin_images_dir / doc_id
        doc_dir.mkdir(parents=True, exist_ok=True)

        for page in range(1, 3):
            image_path = doc_dir / f"page_{page}.png"
            # Create a minimal fake PNG (valid header plus padding)
            image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)

        # Create annotations on page 1 only
        for j in range(3):
            ann_repo.create(
                document_id=doc_id,
                page_number=1,
                class_id=j,
                class_name=f"field_{j}",
                x_center=0.5,
                y_center=0.1 + j * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + j * 160,
                bbox_width=160,
                bbox_height=40,
                text_value=f"value_{j}",
                confidence=0.95,
                source="auto",
            )

        doc = doc_repo.get(doc_id)
        documents.append(doc)

    return documents
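

# Shared helper for the extra test sketches added below. Those sketches are
# additions, not part of the original suite; they only reuse calls already
# exercised in this file (DatasetRepository, build_dataset, get_documents),
# and each one notes its own assumptions in its docstring. This wrapper just
# captures the argument set that most tests here pass verbatim.
def _build(builder, dataset, doc_ids, images_dir, seed=42, train_ratio=0.8, val_ratio=0.1):
    """Build a dataset with the default ratios used throughout this file."""
    return builder.build_dataset(
        dataset_id=str(dataset.dataset_id),
        document_ids=doc_ids,
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        seed=seed,
        admin_images_dir=images_dir,
    )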


class TestDatasetBuilderBasic:
    """Tests for basic dataset building operations."""

    def test_build_dataset_creates_directory_structure(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that building creates proper directory structure."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Test Dataset")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Check directory structure
        assert (dataset_dir / "images" / "train").exists()
        assert (dataset_dir / "images" / "val").exists()
        assert (dataset_dir / "images" / "test").exists()
        assert (dataset_dir / "labels" / "train").exists()
        assert (dataset_dir / "labels" / "val").exists()
        assert (dataset_dir / "labels" / "test").exists()

    def test_build_dataset_copies_images(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that images are copied to dataset directory."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Image Copy Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        result = dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Count total images across all splits
        total_images = 0
        for split in ["train", "val", "test"]:
            images = list((dataset_dir / "images" / split).glob("*.png"))
            total_images += len(images)

        # 5 docs * 2 pages = 10 images
        assert total_images == 10
        assert result["total_images"] == 10

    def test_build_dataset_generates_labels(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that YOLO label files are generated."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Label Generation Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Count total label files
        total_labels = 0
        for split in ["train", "val", "test"]:
            labels = list((dataset_dir / "labels" / split).glob("*.txt"))
            total_labels += len(labels)

        # Same count as images
        assert total_labels == 10

    def test_build_dataset_generates_data_yaml(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that data.yaml is generated correctly."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="YAML Generation Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        yaml_path = dataset_dir / "data.yaml"

        assert yaml_path.exists()

        with open(yaml_path) as f:
            data = yaml.safe_load(f)

        assert data["train"] == "images/train"
        assert data["val"] == "images/val"
        assert data["test"] == "images/test"
        assert "nc" in data
        assert "names" in data


class TestDatasetBuilderSplits:
    """Tests for train/val/test split assignment."""

    def test_split_ratio_respected(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that split ratios are approximately respected."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Split Ratio Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.6,
            val_ratio=0.2,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        # Check document assignments in database
        dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))

        splits = {"train": 0, "val": 0, "test": 0}
        for doc in dataset_docs:
            splits[doc.split] += 1

        # With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1.
        # Due to rounding and group constraints, allow some variation.
        assert splits["train"] >= 2
        assert splits["val"] >= 1 or splits["test"] >= 1

    def test_same_seed_same_split(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that same seed produces same split."""
        dataset_repo = DatasetRepository()
        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        # Build first dataset
        dataset1 = dataset_repo.create(name="Seed Test 1")
        dataset_builder.build_dataset(
            dataset_id=str(dataset1.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )

        # Build second dataset with same seed
        dataset2 = dataset_repo.create(name="Seed Test 2")
        dataset_builder.build_dataset(
            dataset_id=str(dataset2.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )

        # Compare splits
        docs1 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset1.dataset_id))}
        docs2 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset2.dataset_id))}

        assert docs1 == docs2
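
    def test_same_group_same_split(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Sketch: documents sharing a group_key should land in one split.

        The fixture assigns group_key "group_0"/"group_1", and the ratio test
        above already allows for "group constraints" when counting splits.
        This assumes the builder splits at group level and that the document
        model exposes group_key; drop the test if it splits per document.
        """
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Group Split Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        _build(dataset_builder, dataset, doc_ids, admin_images_dir)

        group_by_doc = {str(d.document_id): d.group_key for d in documents_with_annotations}
        split_by_group = {}
        for doc in dataset_repo.get_documents(str(dataset.dataset_id)):
            group = group_by_doc[str(doc.document_id)]
            split_by_group.setdefault(group, set()).add(doc.split)

        for group, group_splits in split_by_group.items():
            assert len(group_splits) == 1, f"group {group} spans splits {group_splits}"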


class TestDatasetBuilderDatabase:
    """Tests for database interactions."""

    def test_updates_dataset_status(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that dataset status is updated after build."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Status Update Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        updated = dataset_repo.get(str(dataset.dataset_id))

        assert updated.status == "ready"
        assert updated.total_documents == 5
        assert updated.total_images == 10
        assert updated.total_annotations > 0
        assert updated.dataset_path is not None

    def test_records_document_assignments(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that document assignments are recorded in database."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Assignment Recording Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))

        assert len(dataset_docs) == 5

        for doc in dataset_docs:
            assert doc.split in ["train", "val", "test"]
            assert doc.page_count > 0
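
    def test_total_annotation_count(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Sketch: 5 documents x 3 annotations each should yield 15 in total.

        The status test above only asserts total_annotations > 0; this pins
        the exact count, assuming the builder counts one per annotation row
        and does not filter by confidence or source.
        """
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Annotation Count Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        _build(dataset_builder, dataset, doc_ids, admin_images_dir)

        updated = dataset_repo.get(str(dataset.dataset_id))
        assert updated.total_annotations == 15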


class TestDatasetBuilderErrors:
    """Tests for error handling."""

    def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, patched_session):
        """Test that building fails with empty document list."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Empty Docs Test")

        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session):
        """Test that building fails with nonexistent document IDs."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Invalid IDs Test")

        fake_ids = [str(uuid4()) for _ in range(3)]

        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=fake_ids,
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session):
        """Test that dataset status is set to failed on error."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Failure Status Test")

        # pytest.raises is stricter than a bare try/except: the test now also
        # fails if no error is raised at all.
        with pytest.raises(ValueError):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

        updated = dataset_repo.get(str(dataset.dataset_id))
        assert updated.status == "failed"
        assert updated.error_message is not None


class TestLabelFileFormat:
    """Tests for YOLO label file format."""

    def test_label_file_format(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that label files are in correct YOLO format."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Label Format Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Find a label file with content
        label_files = []
        for split in ["train", "val", "test"]:
            label_files.extend(list((dataset_dir / "labels" / split).glob("*.txt")))

        # Check at least one label file has correct format
        found_valid_label = False
        for label_file in label_files:
            content = label_file.read_text().strip()
            if content:
                lines = content.split("\n")
                for line in lines:
                    parts = line.split()
                    assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}: {line}"

                    class_id = int(parts[0])
                    x_center = float(parts[1])
                    y_center = float(parts[2])
                    width = float(parts[3])
                    height = float(parts[4])

                    assert 0 <= class_id < 10
                    assert 0 <= x_center <= 1
                    assert 0 <= y_center <= 1
                    assert 0 <= width <= 1
                    assert 0 <= height <= 1

                found_valid_label = True
                break

        assert found_valid_label, "No valid label files found"
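
    def test_images_and_labels_pair_up(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Sketch: every image should have a label file with the same stem.

        YOLO loaders pair images/<split>/x.png with labels/<split>/x.txt, so
        the stems should match within each split; the matching image/label
        counts asserted above suggest this holds. Assumes the builder writes
        a label file even for pages without annotations (page 2 here).
        """
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Pairing Test")
        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        _build(dataset_builder, dataset, doc_ids, admin_images_dir)

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        for split in ["train", "val", "test"]:
            image_stems = {p.stem for p in (dataset_dir / "images" / split).glob("*.png")}
            label_stems = {p.stem for p in (dataset_dir / "labels" / split).glob("*.txt")}
            assert image_stems == label_stems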