WIP
This commit is contained in:
@@ -16,7 +16,6 @@ from inference.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
TrainingDataset,
|
||||
FIELD_CLASSES,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,10 +34,10 @@ def tmp_admin_images(tmp_path):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db():
|
||||
"""Mock AdminDB with dataset and document methods."""
|
||||
db = MagicMock()
|
||||
db.create_dataset.return_value = TrainingDataset(
|
||||
def mock_datasets_repo():
|
||||
"""Mock DatasetRepository."""
|
||||
repo = MagicMock()
|
||||
repo.create.return_value = TrainingDataset(
|
||||
dataset_id=uuid4(),
|
||||
name="test-dataset",
|
||||
status="building",
|
||||
@@ -46,7 +45,19 @@ def mock_admin_db():
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
)
|
||||
return db
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_documents_repo():
|
||||
"""Mock DocumentRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_annotations_repo():
|
||||
"""Mock AnnotationRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -60,6 +71,7 @@ def sample_documents(tmp_admin_images):
|
||||
doc.filename = f"{doc_id}.pdf"
|
||||
doc.page_count = 2
|
||||
doc.file_path = str(tmp_path / "admin_images" / str(doc_id))
|
||||
doc.group_key = None # Default to no group
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@@ -89,21 +101,27 @@ class TestDatasetBuilder:
|
||||
"""Tests for DatasetBuilder."""
|
||||
|
||||
def test_build_creates_directory_structure(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Dataset builder should create images/ and labels/ with train/val/test subdirs."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
dataset_dir = tmp_path / "datasets" / "test"
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# Mock DB calls
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
# Mock repo calls
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -119,18 +137,24 @@ class TestDatasetBuilder:
|
||||
assert (result_dir / "labels" / split).exists()
|
||||
|
||||
def test_build_copies_images(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Images should be copied from admin_images to dataset folder."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
result = builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -149,18 +173,24 @@ class TestDatasetBuilder:
|
||||
assert total_images == 10 # 5 docs * 2 pages
|
||||
|
||||
def test_build_generates_yolo_labels(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""YOLO label files should be generated with correct format."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -187,18 +217,24 @@ class TestDatasetBuilder:
|
||||
assert 0 <= float(parts[2]) <= 1 # y_center
|
||||
|
||||
def test_build_generates_data_yaml(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""data.yaml should be generated with correct field classes."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -217,18 +253,24 @@ class TestDatasetBuilder:
|
||||
assert "invoice_number" in content
|
||||
|
||||
def test_build_splits_documents_correctly(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Documents should be split into train/val/test according to ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -238,8 +280,8 @@ class TestDatasetBuilder:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
# Verify add_dataset_documents was called with correct splits
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
# Verify add_documents was called with correct splits
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
splits = [d["split"] for d in docs_added]
|
||||
assert "train" in splits
|
||||
@@ -248,18 +290,24 @@ class TestDatasetBuilder:
|
||||
assert train_count >= 3 # At least 3 of 5 should be train
|
||||
|
||||
def test_build_updates_status_to_ready(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""After successful build, dataset status should be updated to 'ready'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -269,22 +317,27 @@ class TestDatasetBuilder:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
mock_admin_db.update_dataset_status.assert_called_once()
|
||||
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
|
||||
mock_datasets_repo.update_status.assert_called_once()
|
||||
call_kwargs = mock_datasets_repo.update_status.call_args[1]
|
||||
assert call_kwargs["status"] == "ready"
|
||||
assert call_kwargs["total_documents"] == 5
|
||||
assert call_kwargs["total_images"] == 10
|
||||
|
||||
def test_build_sets_failed_on_error(
|
||||
self, tmp_path, mock_admin_db
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""If build fails, dataset status should be set to 'failed'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = [] # No docs found
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = [] # No docs found
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
with pytest.raises(ValueError):
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
@@ -295,27 +348,33 @@ class TestDatasetBuilder:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
mock_admin_db.update_dataset_status.assert_called_once()
|
||||
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
|
||||
mock_datasets_repo.update_status.assert_called_once()
|
||||
call_kwargs = mock_datasets_repo.update_status.call_args[1]
|
||||
assert call_kwargs["status"] == "failed"
|
||||
|
||||
def test_build_with_seed_produces_deterministic_splits(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Same seed should produce same splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
results = []
|
||||
for _ in range(2):
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
mock_admin_db.add_dataset_documents.reset_mock()
|
||||
mock_admin_db.update_dataset_status.reset_mock()
|
||||
mock_datasets_repo.add_documents.reset_mock()
|
||||
mock_datasets_repo.update_status.reset_mock()
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -324,7 +383,7 @@ class TestDatasetBuilder:
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
results.append([(d["document_id"], d["split"]) for d in docs])
|
||||
|
||||
@@ -342,11 +401,18 @@ class TestAssignSplitsByGroup:
|
||||
doc.page_count = 1
|
||||
return doc
|
||||
|
||||
def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
|
||||
def test_single_doc_groups_are_distributed(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with unique group_key are distributed across splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 3 documents, each with unique group_key
|
||||
docs = [
|
||||
@@ -363,11 +429,18 @@ class TestAssignSplitsByGroup:
|
||||
assert train_count >= 1
|
||||
assert val_count >= 1 # Ensure val is not empty
|
||||
|
||||
def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
|
||||
def test_null_group_key_treated_as_single_doc_group(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key=None),
|
||||
@@ -384,11 +457,18 @@ class TestAssignSplitsByGroup:
|
||||
assert train_count >= 1
|
||||
assert val_count >= 1
|
||||
|
||||
def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
|
||||
def test_multi_doc_groups_stay_together(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with same group_key should be assigned to the same split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 6 documents in 2 groups
|
||||
docs = [
|
||||
@@ -410,11 +490,18 @@ class TestAssignSplitsByGroup:
|
||||
splits_b = [result[str(d.document_id)] for d in docs[3:]]
|
||||
assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"
|
||||
|
||||
def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
|
||||
def test_multi_doc_groups_split_by_ratio(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Multi-doc groups should be split according to train/val/test ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 10 groups with 2 docs each
|
||||
docs = []
|
||||
@@ -445,11 +532,18 @@ class TestAssignSplitsByGroup:
|
||||
assert split_counts["val"] >= 1
|
||||
assert split_counts["val"] <= 3
|
||||
|
||||
def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
|
||||
def test_mixed_single_and_multi_doc_groups(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Mix of single-doc and multi-doc groups should be handled correctly."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
# Single-doc groups
|
||||
@@ -476,11 +570,18 @@ class TestAssignSplitsByGroup:
|
||||
assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
|
||||
assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]
|
||||
|
||||
def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
|
||||
def test_deterministic_with_seed(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Same seed should produce same split assignments."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||
@@ -496,11 +597,18 @@ class TestAssignSplitsByGroup:
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
|
||||
def test_different_seed_may_produce_different_splits(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Different seeds should potentially produce different split assignments."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# Many groups to increase chance of different results
|
||||
docs = []
|
||||
@@ -515,11 +623,18 @@ class TestAssignSplitsByGroup:
|
||||
# Results should be different (very likely with 20 groups)
|
||||
assert result1 != result2
|
||||
|
||||
def test_all_docs_assigned(self, tmp_path, mock_admin_db):
|
||||
def test_all_docs_assigned(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Every document should be assigned a split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||
@@ -535,21 +650,35 @@ class TestAssignSplitsByGroup:
|
||||
assert str(doc.document_id) in result
|
||||
assert result[str(doc.document_id)] in ["train", "val", "test"]
|
||||
|
||||
def test_empty_documents_list(self, tmp_path, mock_admin_db):
|
||||
def test_empty_documents_list(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Empty document list should return empty result."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
|
||||
def test_only_multi_doc_groups(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""When all groups have multiple docs, splits should follow ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 5 groups with 3 docs each
|
||||
docs = []
|
||||
@@ -574,11 +703,18 @@ class TestAssignSplitsByGroup:
|
||||
assert split_counts["train"] >= 2
|
||||
assert split_counts["train"] <= 4
|
||||
|
||||
def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
|
||||
def test_only_single_doc_groups(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""When all groups have single doc, they are distributed across splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key="unique-1"),
|
||||
@@ -658,20 +794,26 @@ class TestBuildDatasetWithGroupKey:
|
||||
return annotations
|
||||
|
||||
def test_build_respects_group_key_splits(
|
||||
self, grouped_documents, grouped_annotations, mock_admin_db
|
||||
self, grouped_documents, grouped_annotations,
|
||||
mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""build_dataset should use group_key for split assignment."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
tmp_path, docs = grouped_documents
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = docs
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = docs
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
grouped_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in docs],
|
||||
@@ -681,8 +823,8 @@ class TestBuildDatasetWithGroupKey:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
# Get the document splits from add_dataset_documents call
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
# Get the document splits from add_documents call
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
|
||||
# Build mapping of doc_id -> split
|
||||
@@ -701,7 +843,9 @@ class TestBuildDatasetWithGroupKey:
|
||||
supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
|
||||
assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"
|
||||
|
||||
def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
|
||||
def test_build_with_all_same_group_key(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""All docs with same group_key should go to same split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
@@ -720,11 +864,16 @@ class TestBuildDatasetWithGroupKey:
|
||||
doc.group_key = "same-group"
|
||||
docs.append(doc)
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = docs
|
||||
mock_admin_db.get_annotations_for_document.return_value = []
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = docs
|
||||
mock_annotations_repo.get_for_document.return_value = []
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in docs],
|
||||
@@ -734,7 +883,7 @@ class TestBuildDatasetWithGroupKey:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
|
||||
splits = [d["split"] for d in docs_added]
|
||||
|
||||
Reference in New Issue
Block a user