restructure project
This commit is contained in:
331
tests/web/test_dataset_builder.py
Normal file
331
tests/web/test_dataset_builder.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Tests for DatasetBuilder service.
|
||||
|
||||
TDD: Write tests first, then implement dataset_builder.py.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
TrainingDataset,
|
||||
FIELD_CLASSES,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_admin_images(tmp_path):
    """Populate ``tmp_path / "admin_images"`` with placeholder page images.

    Five random document ids are generated. Each one gets its own
    sub-directory containing ``page_1.png`` and ``page_2.png``, both filled
    with dummy bytes (they are not real PNG data).

    Returns:
        A ``(tmp_path, doc_ids)`` tuple, where ``doc_ids`` is the list of
        generated UUIDs in creation order.
    """
    images_root = tmp_path / "admin_images"
    doc_ids = []
    for _ in range(5):
        new_id = uuid4()
        doc_ids.append(new_id)

        page_dir = images_root / str(new_id)
        page_dir.mkdir(parents=True)

        # Two pages per document.
        for page_no in (1, 2):
            (page_dir / f"page_{page_no}.png").write_bytes(b"fake-png-data")

    return tmp_path, doc_ids
|
||||
|
||||
|
||||
@pytest.fixture
def mock_admin_db():
    """Return a ``MagicMock`` standing in for the admin DB.

    Only ``create_dataset`` is pre-configured: it returns a
    ``TrainingDataset`` in the ``"building"`` state with an 80/10 train/val
    ratio and seed 42. Every other attribute is a default ``MagicMock``.
    """
    building_dataset = TrainingDataset(
        dataset_id=uuid4(),
        name="test-dataset",
        status="building",
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
    )

    admin_db = MagicMock()
    admin_db.create_dataset.return_value = building_dataset
    return admin_db
|
||||
|
||||
|
||||
@pytest.fixture
def sample_documents(tmp_admin_images):
    """Build one ``AdminDocument``-spec'd mock per id from ``tmp_admin_images``.

    Each mock reports two pages, a ``<document_id>.pdf`` filename and a
    ``file_path`` pointing at that document's directory under
    ``admin_images``.
    """
    root_dir, generated_ids = tmp_admin_images
    images_root = root_dir / "admin_images"

    def _make_document(document_id):
        mock_doc = MagicMock(spec=AdminDocument)
        mock_doc.document_id = document_id
        mock_doc.filename = f"{document_id}.pdf"
        mock_doc.page_count = 2
        mock_doc.file_path = str(images_root / str(document_id))
        return mock_doc

    return [_make_document(document_id) for document_id in generated_ids]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_annotations(sample_documents):
    """Map ``str(document_id)`` to a list of mock annotations, one per page.

    Every annotation is an ``AdminAnnotation``-spec'd mock for class 0
    (``"invoice_number"``) with the same box:
    center (0.5, 0.3), width 0.2, height 0.05.
    """

    def _page_annotation(document_id, page_number):
        mock_ann = MagicMock(spec=AdminAnnotation)
        mock_ann.document_id = document_id
        mock_ann.page_number = page_number
        mock_ann.class_id = 0
        mock_ann.class_name = "invoice_number"
        mock_ann.x_center = 0.5
        mock_ann.y_center = 0.3
        mock_ann.width = 0.2
        mock_ann.height = 0.05
        return mock_ann

    return {
        str(document.document_id): [
            _page_annotation(document.document_id, page_number)
            for page_number in (1, 2)
        ]
        for document in sample_documents
    }
|
||||
|
||||
|
||||
class TestDatasetBuilder:
    """Tests for DatasetBuilder."""

    # Split sub-directories expected under both images/ and labels/.
    _SPLITS = ("train", "val", "test")

    @staticmethod
    def _build(tmp_path, db, documents, annotations):
        """Run ``DatasetBuilder.build_dataset`` with the settings shared by the tests.

        ``DatasetBuilder`` is imported inside the helper, so the import is
        only attempted when a test actually runs.

        Args:
            tmp_path: Per-test temporary directory. Datasets are written
                under ``tmp_path / "datasets"`` and source images are read
                from ``tmp_path / "admin_images"``.
            db: Mocked admin DB. ``get_documents_by_ids`` and
                ``get_annotations_for_document`` are wired up here.
            documents: Objects ``db.get_documents_by_ids`` should return;
                their ``document_id`` values are also passed to the build.
            annotations: Mapping of ``str(document_id)`` to a list of
                annotations; unknown ids yield an empty list.

        Returns:
            ``tmp_path / "datasets" / <dataset_id>``, the directory the
            tests inspect for output.
        """
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=db, base_dir=tmp_path / "datasets")

        db.get_documents_by_ids.return_value = documents
        db.get_annotations_for_document.side_effect = lambda doc_id: (
            annotations.get(str(doc_id), [])
        )

        dataset = db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )
        return tmp_path / "datasets" / str(dataset.dataset_id)

    @staticmethod
    def _added_documents(db):
        """Return the ``documents`` argument of the last ``add_dataset_documents`` call.

        Accepts either the keyword form (``documents=...``) or the
        positional form where the documents are the second argument.
        """
        call_args = db.add_dataset_documents.call_args
        if "documents" in call_args[1]:
            return call_args[1]["documents"]
        return call_args[0][1]

    def test_build_creates_directory_structure(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Dataset builder should create images/ and labels/ with train/val/test subdirs."""
        result_dir = self._build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )

        for split in self._SPLITS:
            assert (result_dir / "images" / split).exists()
            assert (result_dir / "labels" / split).exists()

    def test_build_copies_images(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Images should be copied from admin_images to dataset folder."""
        result_dir = self._build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )

        # Check total images copied
        total_images = sum(
            len(list((result_dir / "images" / split).glob("*.png")))
            for split in self._SPLITS
        )
        assert total_images == 10  # 5 docs * 2 pages

    def test_build_generates_yolo_labels(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """YOLO label files should be generated with correct format."""
        result_dir = self._build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )

        total_labels = sum(
            len(list((result_dir / "labels" / split).glob("*.txt")))
            for split in self._SPLITS
        )
        assert total_labels == 10  # 5 docs * 2 pages

        # Check label format: "class_id x_center y_center width height"
        label_files = list((result_dir / "labels").rglob("*.txt"))
        content = label_files[0].read_text().strip()
        parts = content.split()
        assert len(parts) == 5
        assert int(parts[0]) == 0  # class_id
        # All four box values (x_center, y_center, width, height) must be
        # normalized to the [0, 1] range.
        for value in parts[1:]:
            assert 0 <= float(value) <= 1

    def test_build_generates_data_yaml(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """data.yaml should be generated with correct field classes."""
        result_dir = self._build(
            tmp_path, mock_admin_db, sample_documents, sample_annotations
        )

        yaml_path = result_dir / "data.yaml"
        assert yaml_path.exists()
        content = yaml_path.read_text()
        assert "train:" in content
        assert "val:" in content
        assert "nc:" in content
        assert "invoice_number" in content

    def test_build_splits_documents_correctly(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Documents should be split into train/val/test according to ratios."""
        self._build(tmp_path, mock_admin_db, sample_documents, sample_annotations)

        # Verify add_dataset_documents was called with correct splits
        docs_added = self._added_documents(mock_admin_db)
        splits = [d["split"] for d in docs_added]
        assert "train" in splits
        # 80% of 5 docs is nominally 4 train documents; accept 3 or more so
        # the test does not depend on how the split sizes are rounded.
        train_count = splits.count("train")
        assert train_count >= 3

    def test_build_updates_status_to_ready(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """After successful build, dataset status should be updated to 'ready'."""
        self._build(tmp_path, mock_admin_db, sample_documents, sample_annotations)

        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "ready"
        assert call_kwargs["total_documents"] == 5
        assert call_kwargs["total_images"] == 10

    def test_build_sets_failed_on_error(
        self, tmp_path, mock_admin_db
    ):
        """If build fails, dataset status should be set to 'failed'."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = []  # No docs found

        dataset = mock_admin_db.create_dataset.return_value
        with pytest.raises(ValueError):
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=["nonexistent-id"],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=tmp_path / "admin_images",
            )

        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "failed"

    def test_build_with_seed_produces_deterministic_splits(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Same seed should produce same splits."""
        results = []
        for _ in range(2):
            # Clear recorded calls so each run is inspected on its own.
            mock_admin_db.add_dataset_documents.reset_mock()
            mock_admin_db.update_dataset_status.reset_mock()

            self._build(
                tmp_path, mock_admin_db, sample_documents, sample_annotations
            )
            docs = self._added_documents(mock_admin_db)
            results.append([(d["document_id"], d["split"]) for d in docs])

        assert results[0] == results[1]
|
||||
Reference in New Issue
Block a user