"""
|
|
Tests for DatasetBuilder service.
|
|
|
|
TDD: Write tests first, then implement dataset_builder.py.
|
|
"""
|
|
|
|
from unittest.mock import MagicMock
from uuid import uuid4

import pytest

from inference.data.admin_models import (
    AdminAnnotation,
    AdminDocument,
    TrainingDataset,
)
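
# The DatasetBuilder surface exercised in this file, inferred from the calls
# in the tests below. A rough sketch only: the names and parameters come from
# the tests themselves; the bodies are assumptions, not the real
# implementation.
#
#     class DatasetBuilder:
#         def __init__(self, db, base_dir): ...
#         def build_dataset(self, dataset_id, document_ids, train_ratio,
#                           val_ratio, seed, admin_images_dir): ...
#         def _assign_splits_by_group(self, docs, train_ratio, val_ratio,
#                                     seed) -> dict[str, str]: ...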


@pytest.fixture
def tmp_admin_images(tmp_path):
    """Create mock admin images directory with sample images."""
    doc_ids = [uuid4() for _ in range(5)]
    for doc_id in doc_ids:
        doc_dir = tmp_path / "admin_images" / str(doc_id)
        doc_dir.mkdir(parents=True)
        # Create 2 pages per doc
        for page in range(1, 3):
            img_path = doc_dir / f"page_{page}.png"
            img_path.write_bytes(b"fake-png-data")
    return tmp_path, doc_ids


@pytest.fixture
def mock_admin_db():
    """Mock AdminDB with dataset and document methods."""
    db = MagicMock()
    db.create_dataset.return_value = TrainingDataset(
        dataset_id=uuid4(),
        name="test-dataset",
        status="building",
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
    )
    return db
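
# Besides create_dataset, the tests stub these AdminDB methods on the mock as
# needed: get_documents_by_ids, get_annotations_for_document,
# add_dataset_documents, and update_dataset_status.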


@pytest.fixture
def sample_documents(tmp_admin_images):
    """Create sample AdminDocument objects."""
    tmp_path, doc_ids = tmp_admin_images
    docs = []
    for doc_id in doc_ids:
        doc = MagicMock(spec=AdminDocument)
        doc.document_id = doc_id
        doc.filename = f"{doc_id}.pdf"
        doc.page_count = 2
        doc.file_path = str(tmp_path / "admin_images" / str(doc_id))
        docs.append(doc)
    return docs


@pytest.fixture
def sample_annotations(sample_documents):
    """Create sample annotations for each document page."""
    annotations = {}
    for doc in sample_documents:
        doc_anns = []
        for page in range(1, 3):
            ann = MagicMock(spec=AdminAnnotation)
            ann.document_id = doc.document_id
            ann.page_number = page
            ann.class_id = 0
            ann.class_name = "invoice_number"
            ann.x_center = 0.5
            ann.y_center = 0.3
            ann.width = 0.2
            ann.height = 0.05
            doc_anns.append(ann)
        annotations[str(doc.document_id)] = doc_anns
    return annotations


class TestDatasetBuilder:
    """Tests for DatasetBuilder."""

    def test_build_creates_directory_structure(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Dataset builder should create images/ and labels/ with train/val/test subdirs."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # Stub the DB calls the builder makes.
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
        for split in ["train", "val", "test"]:
            assert (result_dir / "images" / split).exists()
            assert (result_dir / "labels" / split).exists()
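
    # On-disk layout asserted across this class (file naming inside the splits
    # is left to the implementation):
    #
    #     <base_dir>/<dataset_id>/
    #         data.yaml
    #         images/{train,val,test}/*.png
    #         labels/{train,val,test}/*.txt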

    def test_build_copies_images(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Images should be copied from admin_images into the dataset folder."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        # Check total images copied across all splits.
        result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
        total_images = sum(
            len(list((result_dir / "images" / split).glob("*.png")))
            for split in ["train", "val", "test"]
        )
        assert total_images == 10  # 5 docs * 2 pages

    def test_build_generates_yolo_labels(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """YOLO label files should be generated with the correct format."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
        total_labels = sum(
            len(list((result_dir / "labels" / split).glob("*.txt")))
            for split in ["train", "val", "test"]
        )
        assert total_labels == 10  # 5 docs * 2 pages

        # Check label format: "class_id x_center y_center width height"
        label_files = list((result_dir / "labels").rglob("*.txt"))
        content = label_files[0].read_text().strip()
        parts = content.split()
        assert len(parts) == 5
        assert int(parts[0]) == 0  # class_id
        assert 0 <= float(parts[1]) <= 1  # x_center
        assert 0 <= float(parts[2]) <= 1  # y_center
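
    # Given the fixture annotation (class 0, center 0.5/0.3, size 0.2/0.05),
    # each label file should contain a single line of the form:
    #
    #     0 0.5 0.3 0.2 0.05
    #
    # Exact float formatting is an implementation detail, which is why the
    # asserts above only range-check the values.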

    def test_build_generates_data_yaml(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """data.yaml should be generated with the correct field classes."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        yaml_path = tmp_path / "datasets" / str(dataset.dataset_id) / "data.yaml"
        assert yaml_path.exists()
        content = yaml_path.read_text()
        assert "train:" in content
        assert "val:" in content
        assert "nc:" in content
        assert "invoice_number" in content
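
    # A minimal data.yaml satisfying these asserts, assuming the usual
    # Ultralytics YOLO layout (the real class list would come from
    # FIELD_CLASSES in inference.data.admin_models):
    #
    #     train: images/train
    #     val: images/val
    #     test: images/test
    #     nc: <len(FIELD_CLASSES)>
    #     names: [invoice_number, ...]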

    def test_build_splits_documents_correctly(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Documents should be split into train/val/test according to the ratios."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        # Verify add_dataset_documents was called with the expected splits.
        call_args = mock_admin_db.add_dataset_documents.call_args
        docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
        splits = [d["split"] for d in docs_added]
        assert "train" in splits
        # With 5 docs, 80/10/10 -> 4 train, 0-1 val, 0-1 test
        train_count = splits.count("train")
        assert train_count >= 3  # At least 3 of 5 should be train

    def test_build_updates_status_to_ready(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """After a successful build, the dataset status should be updated to 'ready'."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "ready"
        assert call_kwargs["total_documents"] == 5
        assert call_kwargs["total_images"] == 10

    def test_build_sets_failed_on_error(self, tmp_path, mock_admin_db):
        """If the build fails, the dataset status should be set to 'failed'."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = []  # No docs found

        dataset = mock_admin_db.create_dataset.return_value
        with pytest.raises(ValueError):
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=["nonexistent-id"],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=tmp_path / "admin_images",
            )

        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "failed"

    def test_build_with_seed_produces_deterministic_splits(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """The same seed should produce the same splits."""
        from inference.web.services.dataset_builder import DatasetBuilder

        results = []
        for _ in range(2):
            builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
            mock_admin_db.get_documents_by_ids.return_value = sample_documents
            mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
                sample_annotations.get(str(doc_id), [])
            )
            mock_admin_db.add_dataset_documents.reset_mock()
            mock_admin_db.update_dataset_status.reset_mock()

            dataset = mock_admin_db.create_dataset.return_value
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[str(d.document_id) for d in sample_documents],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=tmp_path / "admin_images",
            )
            call_args = mock_admin_db.add_dataset_documents.call_args
            docs = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
            results.append([(d["document_id"], d["split"]) for d in docs])

        assert results[0] == results[1]


class TestAssignSplitsByGroup:
    """Tests for the _assign_splits_by_group method with group_key logic."""
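
    # Contract encoded by these tests: documents sharing a group_key must land
    # in the same split (so related documents never leak between train and
    # val/test), and a null or empty group_key makes a document its own
    # singleton group. One plausible seeded implementation -- a sketch under
    # those assumptions, not the actual code:
    #
    #     groups = defaultdict(list)  # group label -> docs
    #     for doc in docs:
    #         groups[doc.group_key or str(doc.document_id)].append(doc)
    #     keys = list(groups)
    #     random.Random(seed).shuffle(keys)
    #     n_train = round(len(keys) * train_ratio)
    #     n_val = max(1, round(len(keys) * val_ratio))
    #     # assign whole groups: keys[:n_train] -> train,
    #     # keys[n_train:n_train + n_val] -> val, the rest -> test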

    def _make_mock_doc(self, doc_id, group_key=None):
        """Create a mock AdminDocument with document_id and group_key."""
        doc = MagicMock(spec=AdminDocument)
        doc.document_id = doc_id
        doc.group_key = group_key
        doc.page_count = 1
        return doc

    def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
        """Documents with unique group_keys are distributed across splits."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 3 documents, each with a unique group_key
        docs = [
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-B"),
            self._make_mock_doc(uuid4(), group_key="group-C"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        # With 3 groups: 70% train = 2, 20% val = 1 (at least 1)
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 1
        assert val_count >= 1  # Ensure val is not empty

    def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
        """Documents with a null/empty group_key are each treated as independent single-doc groups."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key=None),
            self._make_mock_doc(uuid4(), group_key=""),
            self._make_mock_doc(uuid4(), group_key=None),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        # Each null/empty group_key doc is independent and distributed across
        # splits. With 3 docs: at least 1 in train and 1 in val.
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 1
        assert val_count >= 1

    def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
        """Documents with the same group_key should be assigned to the same split."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 6 documents in 2 groups
        docs = [
            self._make_mock_doc(uuid4(), group_key="supplier-A"),
            self._make_mock_doc(uuid4(), group_key="supplier-A"),
            self._make_mock_doc(uuid4(), group_key="supplier-A"),
            self._make_mock_doc(uuid4(), group_key="supplier-B"),
            self._make_mock_doc(uuid4(), group_key="supplier-B"),
            self._make_mock_doc(uuid4(), group_key="supplier-B"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)

        # All docs in supplier-A should share one split
        splits_a = [result[str(d.document_id)] for d in docs[:3]]
        assert len(set(splits_a)) == 1, "All docs in supplier-A should be in same split"

        # All docs in supplier-B should share one split
        splits_b = [result[str(d.document_id)] for d in docs[3:]]
        assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"

    def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
        """Multi-doc groups should be split according to the train/val/test ratios."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 10 groups with 2 docs each
        docs = []
        for i in range(10):
            group_key = f"group-{i}"
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        # Count groups per split, verifying along the way that each group
        # keeps a single split.
        group_splits = {}
        for doc in docs:
            split = result[str(doc.document_id)]
            if doc.group_key not in group_splits:
                group_splits[doc.group_key] = split
            else:
                assert group_splits[doc.group_key] == split

        split_counts = {"train": 0, "val": 0, "test": 0}
        for split in group_splits.values():
            split_counts[split] += 1

        # With 10 groups, 70/20/10 -> ~7 train, ~2 val, ~1 test
        assert split_counts["train"] >= 6
        assert split_counts["train"] <= 8
        assert split_counts["val"] >= 1
        assert split_counts["val"] <= 3

    def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
        """A mix of single-doc and multi-doc groups should be handled correctly."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            # Single-doc groups
            self._make_mock_doc(uuid4(), group_key="single-1"),
            self._make_mock_doc(uuid4(), group_key="single-2"),
            self._make_mock_doc(uuid4(), group_key=None),
            # Multi-doc groups
            self._make_mock_doc(uuid4(), group_key="multi-A"),
            self._make_mock_doc(uuid4(), group_key="multi-A"),
            self._make_mock_doc(uuid4(), group_key="multi-B"),
            self._make_mock_doc(uuid4(), group_key="multi-B"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)

        # All groups are shuffled and distributed; ensure at least 1 in train
        # and 1 in val.
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 1
        assert val_count >= 1

        # Multi-doc groups stay together
        assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
        assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]

    def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
        """The same seed should produce the same split assignments."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-B"),
            self._make_mock_doc(uuid4(), group_key="group-B"),
            self._make_mock_doc(uuid4(), group_key="group-C"),
            self._make_mock_doc(uuid4(), group_key="group-C"),
        ]

        result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
        result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)

        assert result1 == result2

    def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
        """Different seeds should (very likely) produce different split assignments."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # Many groups, to make identical shuffles vanishingly unlikely
        docs = []
        for i in range(20):
            group_key = f"group-{i}"
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))

        result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=1)
        result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=999)

        # Results should differ (very likely with 20 groups)
        assert result1 != result2

    def test_all_docs_assigned(self, tmp_path, mock_admin_db):
        """Every document should be assigned a split."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key=None),
            self._make_mock_doc(uuid4(), group_key="single"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        assert len(result) == len(docs)
        for doc in docs:
            assert str(doc.document_id) in result
            assert result[str(doc.document_id)] in ["train", "val", "test"]

    def test_empty_documents_list(self, tmp_path, mock_admin_db):
        """An empty document list should return an empty result."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)

        assert result == {}

    def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
        """When all groups have multiple docs, splits should follow the ratios."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 5 groups with 3 docs each
        docs = []
        for i in range(5):
            group_key = f"group-{i}"
            for _ in range(3):
                docs.append(self._make_mock_doc(uuid4(), group_key=group_key))

        result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)

        # Record each group's split
        group_splits = {}
        for doc in docs:
            if doc.group_key not in group_splits:
                group_splits[doc.group_key] = result[str(doc.document_id)]

        split_counts = {"train": 0, "val": 0, "test": 0}
        for split in group_splits.values():
            split_counts[split] += 1

        # With 5 groups, 60/20/20 -> 3 train, 1 val, 1 test
        assert split_counts["train"] >= 2
        assert split_counts["train"] <= 4

    def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
        """When every group has a single doc, docs are distributed across splits."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key="unique-1"),
            self._make_mock_doc(uuid4(), group_key="unique-2"),
            self._make_mock_doc(uuid4(), group_key="unique-3"),
            self._make_mock_doc(uuid4(), group_key=None),
            self._make_mock_doc(uuid4(), group_key=""),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)

        # With 5 groups: 60% train = 3, 20% val = 1 (at least 1)
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 2
        assert val_count >= 1  # Ensure val is not empty


class TestBuildDatasetWithGroupKey:
    """Integration tests for build_dataset with group_key logic."""

    @pytest.fixture
    def grouped_documents(self, tmp_path):
        """Create documents with various group_key configurations."""
        docs = []

        # Four groups: 2 multi-doc groups + 2 single-doc groups
        group_configs = [
            ("supplier-A", 3),  # Multi-doc group: 3 docs
            ("supplier-B", 2),  # Multi-doc group: 2 docs
            ("unique-1", 1),  # Single-doc group
            (None, 1),  # Null group_key
        ]

        for group_key, count in group_configs:
            for _ in range(count):
                doc_id = uuid4()

                # Create image files
                doc_dir = tmp_path / "admin_images" / str(doc_id)
                doc_dir.mkdir(parents=True)
                for page in range(1, 3):
                    (doc_dir / f"page_{page}.png").write_bytes(b"fake-png")

                # Create mock document
                doc = MagicMock(spec=AdminDocument)
                doc.document_id = doc_id
                doc.filename = f"{doc_id}.pdf"
                doc.page_count = 2
                doc.group_key = group_key
                doc.file_path = str(doc_dir)
                docs.append(doc)

        return tmp_path, docs

    @pytest.fixture
    def grouped_annotations(self, grouped_documents):
        """Create annotations for the grouped documents."""
        tmp_path, docs = grouped_documents
        annotations = {}
        for doc in docs:
            doc_anns = []
            for page in range(1, 3):
                ann = MagicMock(spec=AdminAnnotation)
                ann.document_id = doc.document_id
                ann.page_number = page
                ann.class_id = 0
                ann.class_name = "invoice_number"
                ann.x_center = 0.5
                ann.y_center = 0.3
                ann.width = 0.2
                ann.height = 0.05
                doc_anns.append(ann)
            annotations[str(doc.document_id)] = doc_anns
        return annotations

    def test_build_respects_group_key_splits(
        self, grouped_documents, grouped_annotations, mock_admin_db
    ):
        """build_dataset should use group_key for split assignment."""
        from inference.web.services.dataset_builder import DatasetBuilder

        tmp_path, docs = grouped_documents

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = docs
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            grouped_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in docs],
            train_ratio=0.5,
            val_ratio=0.5,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        # Get the document splits from the add_dataset_documents call
        call_args = mock_admin_db.add_dataset_documents.call_args
        docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]

        # Build a mapping of doc_id -> split
        doc_split_map = {d["document_id"]: d["split"] for d in docs_added}

        # Verify all docs are assigned a valid split
        for doc_id in doc_split_map:
            assert doc_split_map[doc_id] in ("train", "val", "test")

        # Verify multi-doc groups stay together
        supplier_a_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-A"]
        supplier_a_splits = [doc_split_map[doc_id] for doc_id in supplier_a_ids]
        assert len(set(supplier_a_splits)) == 1, "supplier-A docs should be in same split"

        supplier_b_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-B"]
        supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
        assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"

    def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
        """All docs with the same group_key should go to the same split."""
        from inference.web.services.dataset_builder import DatasetBuilder

        # Create 5 docs, all with the same group_key
        docs = []
        for _ in range(5):
            doc_id = uuid4()
            doc_dir = tmp_path / "admin_images" / str(doc_id)
            doc_dir.mkdir(parents=True)
            (doc_dir / "page_1.png").write_bytes(b"fake-png")

            doc = MagicMock(spec=AdminDocument)
            doc.document_id = doc_id
            doc.filename = f"{doc_id}.pdf"
            doc.page_count = 1
            doc.group_key = "same-group"
            docs.append(doc)

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = docs
        mock_admin_db.get_annotations_for_document.return_value = []

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in docs],
            train_ratio=0.6,
            val_ratio=0.2,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        call_args = mock_admin_db.add_dataset_documents.call_args
        docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]

        splits = [d["split"] for d in docs_added]
        # One group, so every doc lands in the same split
        assert len(set(splits)) == 1, "All docs with same group_key should be in same split"