""" Tests for DatasetBuilder service. TDD: Write tests first, then implement dataset_builder.py. """ import shutil import tempfile from pathlib import Path from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest from inference.data.admin_models import ( AdminAnnotation, AdminDocument, TrainingDataset, FIELD_CLASSES, ) @pytest.fixture def tmp_admin_images(tmp_path): """Create mock admin images directory with sample images.""" doc_ids = [uuid4() for _ in range(5)] for doc_id in doc_ids: doc_dir = tmp_path / "admin_images" / str(doc_id) doc_dir.mkdir(parents=True) # Create 2 pages per doc for page in range(1, 3): img_path = doc_dir / f"page_{page}.png" img_path.write_bytes(b"fake-png-data") return tmp_path, doc_ids @pytest.fixture def mock_admin_db(): """Mock AdminDB with dataset and document methods.""" db = MagicMock() db.create_dataset.return_value = TrainingDataset( dataset_id=uuid4(), name="test-dataset", status="building", train_ratio=0.8, val_ratio=0.1, seed=42, ) return db @pytest.fixture def sample_documents(tmp_admin_images): """Create sample AdminDocument objects.""" tmp_path, doc_ids = tmp_admin_images docs = [] for doc_id in doc_ids: doc = MagicMock(spec=AdminDocument) doc.document_id = doc_id doc.filename = f"{doc_id}.pdf" doc.page_count = 2 doc.file_path = str(tmp_path / "admin_images" / str(doc_id)) docs.append(doc) return docs @pytest.fixture def sample_annotations(sample_documents): """Create sample annotations for each document page.""" annotations = {} for doc in sample_documents: doc_anns = [] for page in range(1, 3): ann = MagicMock(spec=AdminAnnotation) ann.document_id = doc.document_id ann.page_number = page ann.class_id = 0 ann.class_name = "invoice_number" ann.x_center = 0.5 ann.y_center = 0.3 ann.width = 0.2 ann.height = 0.05 doc_anns.append(ann) annotations[str(doc.document_id)] = doc_anns return annotations class TestDatasetBuilder: """Tests for DatasetBuilder.""" def test_build_creates_directory_structure( self, tmp_path, mock_admin_db, sample_documents, sample_annotations ): """Dataset builder should create images/ and labels/ with train/val/test subdirs.""" from inference.web.services.dataset_builder import DatasetBuilder dataset_dir = tmp_path / "datasets" / "test" builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") # Mock DB calls mock_admin_db.get_documents_by_ids.return_value = sample_documents mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( sample_annotations.get(str(doc_id), []) ) dataset = mock_admin_db.create_dataset.return_value builder.build_dataset( dataset_id=str(dataset.dataset_id), document_ids=[str(d.document_id) for d in sample_documents], train_ratio=0.8, val_ratio=0.1, seed=42, admin_images_dir=tmp_path / "admin_images", ) result_dir = tmp_path / "datasets" / str(dataset.dataset_id) for split in ["train", "val", "test"]: assert (result_dir / "images" / split).exists() assert (result_dir / "labels" / split).exists() def test_build_copies_images( self, tmp_path, mock_admin_db, sample_documents, sample_annotations ): """Images should be copied from admin_images to dataset folder.""" from inference.web.services.dataset_builder import DatasetBuilder builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") mock_admin_db.get_documents_by_ids.return_value = sample_documents mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( sample_annotations.get(str(doc_id), []) ) dataset = mock_admin_db.create_dataset.return_value result = 
    def test_build_copies_images(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Images should be copied from admin_images to dataset folder."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        # Check total images copied
        result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
        total_images = sum(
            len(list((result_dir / "images" / split).glob("*.png")))
            for split in ["train", "val", "test"]
        )
        assert total_images == 10  # 5 docs * 2 pages

    def test_build_generates_yolo_labels(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """YOLO label files should be generated with correct format."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        result_dir = tmp_path / "datasets" / str(dataset.dataset_id)
        total_labels = sum(
            len(list((result_dir / "labels" / split).glob("*.txt")))
            for split in ["train", "val", "test"]
        )
        assert total_labels == 10  # 5 docs * 2 pages

        # Check label format: "class_id x_center y_center width height"
        label_files = list((result_dir / "labels").rglob("*.txt"))
        content = label_files[0].read_text().strip()
        parts = content.split()
        assert len(parts) == 5
        assert int(parts[0]) == 0  # class_id
        assert 0 <= float(parts[1]) <= 1  # x_center
        assert 0 <= float(parts[2]) <= 1  # y_center

    def test_build_generates_data_yaml(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """data.yaml should be generated with correct field classes."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        yaml_path = tmp_path / "datasets" / str(dataset.dataset_id) / "data.yaml"
        assert yaml_path.exists()
        content = yaml_path.read_text()
        assert "train:" in content
        assert "val:" in content
        assert "nc:" in content
        assert "invoice_number" in content
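    # Sketch of a data.yaml that would satisfy the assertions above. The
    # exact keys and class list depend on FIELD_CLASSES and the eventual
    # implementation; this layout is an assumption, not a requirement:
    #
    #   path: <base_dir>/<dataset_id>
    #   train: images/train
    #   val: images/val
    #   test: images/test
    #   nc: <len(FIELD_CLASSES)>
    #   names: ['invoice_number', ...]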
    def test_build_splits_documents_correctly(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Documents should be split into train/val/test according to ratios."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        # Verify add_dataset_documents was called with correct splits
        call_args = mock_admin_db.add_dataset_documents.call_args
        docs_added = (
            call_args[1]["documents"]
            if "documents" in call_args[1]
            else call_args[0][1]
        )
        splits = [d["split"] for d in docs_added]
        assert "train" in splits
        # With 5 docs, 80/10/10 -> 4 train, 0-1 val, 0-1 test
        train_count = splits.count("train")
        assert train_count >= 3  # At least 3 of 5 should be train

    def test_build_updates_status_to_ready(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """After successful build, dataset status should be updated to 'ready'."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = sample_documents
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            sample_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in sample_documents],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "ready"
        assert call_kwargs["total_documents"] == 5
        assert call_kwargs["total_images"] == 10

    def test_build_sets_failed_on_error(self, tmp_path, mock_admin_db):
        """If build fails, dataset status should be set to 'failed'."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = []  # No docs found

        dataset = mock_admin_db.create_dataset.return_value
        with pytest.raises(ValueError):
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=["nonexistent-id"],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=tmp_path / "admin_images",
            )

        mock_admin_db.update_dataset_status.assert_called_once()
        call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
        assert call_kwargs["status"] == "failed"

    def test_build_with_seed_produces_deterministic_splits(
        self, tmp_path, mock_admin_db, sample_documents, sample_annotations
    ):
        """Same seed should produce same splits."""
        from inference.web.services.dataset_builder import DatasetBuilder

        results = []
        for _ in range(2):
            builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
            mock_admin_db.get_documents_by_ids.return_value = sample_documents
            mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
                sample_annotations.get(str(doc_id), [])
            )
            mock_admin_db.add_dataset_documents.reset_mock()
            mock_admin_db.update_dataset_status.reset_mock()

            dataset = mock_admin_db.create_dataset.return_value
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[str(d.document_id) for d in sample_documents],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=tmp_path / "admin_images",
            )

            call_args = mock_admin_db.add_dataset_documents.call_args
            docs = (
                call_args[1]["documents"]
                if "documents" in call_args[1]
                else call_args[0][1]
            )
            results.append([(d["document_id"], d["split"]) for d in docs])

        assert results[0] == results[1]
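
# ---------------------------------------------------------------------------
# Reference sketch (hypothetical, not imported by the tests above): one
# possible shape for inference/web/services/dataset_builder.py that would
# satisfy this suite. The `db` method names (get_documents_by_ids,
# get_annotations_for_document, add_dataset_documents, update_dataset_status)
# mirror the mocks; FIELD_CLASSES is assumed to be a sequence of class-name
# strings including "invoice_number". Everything else is an assumption.
# ---------------------------------------------------------------------------
import random


class DatasetBuilderSketch:
    """Minimal sketch of the DatasetBuilder contract exercised above."""

    def __init__(self, db, base_dir) -> None:
        self.db = db
        self.base_dir = Path(base_dir)

    def build_dataset(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir,
    ) -> Path:
        dataset_dir = self.base_dir / dataset_id
        try:
            documents = self.db.get_documents_by_ids(document_ids)
            if not documents:
                raise ValueError("No documents found for dataset")

            # Deterministic, seeded document-level split.
            rng = random.Random(seed)
            shuffled = list(documents)
            rng.shuffle(shuffled)
            n_train = round(len(shuffled) * train_ratio)
            n_val = round(len(shuffled) * val_ratio)
            assigned = (
                [("train", d) for d in shuffled[:n_train]]
                + [("val", d) for d in shuffled[n_train:n_train + n_val]]
                + [("test", d) for d in shuffled[n_train + n_val:]]
            )

            for split in ("train", "val", "test"):
                (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
                (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

            total_images = 0
            for split, doc in assigned:
                annotations = self.db.get_annotations_for_document(doc.document_id)
                by_page: dict = {}
                for ann in annotations:
                    by_page.setdefault(ann.page_number, []).append(ann)
                src_dir = Path(admin_images_dir) / str(doc.document_id)
                for page in range(1, doc.page_count + 1):
                    stem = f"{doc.document_id}_page_{page}"
                    shutil.copy(
                        src_dir / f"page_{page}.png",
                        dataset_dir / "images" / split / f"{stem}.png",
                    )
                    # One YOLO line per annotation: class x_c y_c w h.
                    lines = [
                        f"{a.class_id} {a.x_center} {a.y_center} {a.width} {a.height}"
                        for a in by_page.get(page, [])
                    ]
                    (dataset_dir / "labels" / split / f"{stem}.txt").write_text(
                        "\n".join(lines) + "\n"
                    )
                    total_images += 1

            # Minimal YOLO dataset config.
            (dataset_dir / "data.yaml").write_text(
                f"path: {dataset_dir}\n"
                "train: images/train\n"
                "val: images/val\n"
                "test: images/test\n"
                f"nc: {len(FIELD_CLASSES)}\n"
                f"names: {list(FIELD_CLASSES)}\n"
            )

            self.db.add_dataset_documents(
                dataset_id=dataset_id,
                documents=[
                    {"document_id": str(d.document_id), "split": split}
                    for split, d in assigned
                ],
            )
            self.db.update_dataset_status(
                dataset_id=dataset_id,
                status="ready",
                total_documents=len(documents),
                total_images=total_images,
            )
            return dataset_dir
        except Exception:
            self.db.update_dataset_status(dataset_id=dataset_id, status="failed")
            raise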