invoice-master-poc-v2/packages/backend/backend/web/services/dataset_builder.py

"""
Dataset Builder Service
Creates training datasets by copying images from admin storage,
generating YOLO label files, and splitting into train/val/test sets.
"""
import logging
import random
import shutil
from pathlib import Path
import yaml
from shared.fields import FIELD_CLASSES
logger = logging.getLogger(__name__)


class DatasetBuilder:
    """Builds YOLO training datasets from admin documents."""

    def __init__(
        self,
        datasets_repo,
        documents_repo,
        annotations_repo,
        base_dir: Path,
    ):
        self._datasets_repo = datasets_repo
        self._documents_repo = documents_repo
        self._annotations_repo = annotations_repo
        self._base_dir = Path(base_dir)

    def build_dataset(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        """Build a complete YOLO dataset from document IDs.

        Args:
            dataset_id: UUID of the dataset record.
            document_ids: List of document UUIDs to include.
            train_ratio: Fraction for training set.
            val_ratio: Fraction for validation set.
            seed: Random seed for reproducible splits.
            admin_images_dir: Root directory of admin images.

        Returns:
            Summary dict with total_documents, total_images, total_annotations.

        Raises:
            ValueError: If no valid documents found.
        """
        try:
            return self._do_build(
                dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
            )
        except Exception as e:
            # Mark the dataset record as failed before re-raising so it never
            # lingers in a pending state.
            self._datasets_repo.update_status(
                dataset_id=dataset_id,
                status="failed",
                error_message=str(e),
            )
            raise

    def _do_build(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        # 1. Fetch documents
        documents = self._documents_repo.get_by_ids(document_ids)
        if not documents:
            raise ValueError("No valid documents found for the given IDs")

        # 2. Create directory structure
        dataset_dir = self._base_dir / dataset_id
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
            (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

        # 3. Group documents by group_key and assign splits
        doc_list = list(documents)
        doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)

        # 4. Process each document
        total_images = 0
        total_annotations = 0
        dataset_docs = []
        for doc in doc_list:
            doc_id = str(doc.document_id)
            split = doc_splits[doc_id]
            annotations = self._annotations_repo.get_for_document(doc_id)

            # Group annotations by page
            page_annotations: dict[int, list] = {}
            for ann in annotations:
                page_annotations.setdefault(ann.page_number, []).append(ann)

            doc_image_count = 0
            doc_ann_count = 0

            # Copy images and write labels for each page
            for page_num in range(1, doc.page_count + 1):
                src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
                if not src_image.exists():
                    logger.warning("Image not found: %s", src_image)
                    continue

                dst_name = f"{doc_id}_page{page_num}"
                dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
                shutil.copy2(src_image, dst_image)
                doc_image_count += 1

                # Write YOLO label file
                page_anns = page_annotations.get(page_num, [])
                label_lines = []
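                # One line per box in YOLO detection format:
                #   <class_id> <x_center> <y_center> <width> <height>
                # with coordinates normalized to [0, 1] relative to the image
                # size, e.g. "3 0.512345 0.401234 0.210000 0.080000".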
                for ann in page_anns:
                    label_lines.append(
                        f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
                        f"{ann.width:.6f} {ann.height:.6f}"
                    )
                    doc_ann_count += 1
                label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
                label_path.write_text("\n".join(label_lines))
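                # Note: an empty label file is valid here; YOLO treats images
                # without boxes as background-only training samples.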

            total_images += doc_image_count
            total_annotations += doc_ann_count
            dataset_docs.append({
                "document_id": doc_id,
                "split": split,
                # Counts images actually copied, which may be fewer than
                # doc.page_count when page images are missing on disk.
                "page_count": doc_image_count,
                "annotation_count": doc_ann_count,
            })

        # 5. Record document-split assignments in DB
        self._datasets_repo.add_documents(
            dataset_id=dataset_id,
            documents=dataset_docs,
        )

        # 6. Generate data.yaml
        self._generate_data_yaml(dataset_dir)

        # 7. Update dataset status
        self._datasets_repo.update_status(
            dataset_id=dataset_id,
            status="ready",
            total_documents=len(doc_list),
            total_images=total_images,
            total_annotations=total_annotations,
            dataset_path=str(dataset_dir),
        )

        return {
            "total_documents": len(doc_list),
            "total_images": total_images,
            "total_annotations": total_annotations,
        }

    def _assign_splits_by_group(
        self,
        documents: list,
        train_ratio: float,
        val_ratio: float,
        seed: int,
    ) -> dict[str, str]:
        """Assign splits based on group_key.

        Logic:
            - Documents sharing a group_key always land in the same split.
            - Documents without a group_key form single-document groups.
            - All groups are shuffled with the given seed and assigned to
              train/val/test according to the given ratios.

        Args:
            documents: List of AdminDocument objects.
            train_ratio: Fraction for training set.
            val_ratio: Fraction for validation set.
            seed: Random seed for reproducibility.

        Returns:
            Dict mapping document_id (str) -> split ("train"/"val"/"test").
        """
        # Group documents by group_key; a missing/empty group_key makes the
        # document its own single-member group, keyed by its document_id.
        groups: dict[str, list] = {}
        for doc in documents:
            key = doc.group_key if doc.group_key else f"__ungrouped_{doc.document_id}"
            groups.setdefault(key, []).append(doc)

        doc_splits: dict[str, str] = {}

        # Every group takes part in the shuffle & split; documents sharing a
        # group_key are assigned to a split as one unit.
        all_groups = list(groups.items())
# Shuffle all groups and assign splits
if all_groups:
rng = random.Random(seed)
rng.shuffle(all_groups)
n_groups = len(all_groups)
n_train = max(1, round(n_groups * train_ratio))
# Ensure at least 1 in val if we have more than 1 group
n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
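            # Example: 10 groups with train_ratio=0.8, val_ratio=0.1 gives
            # n_train=8 and n_val=1, leaving 1 group for test.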
            for i, (_key, docs) in enumerate(all_groups):
                if i < n_train:
                    split = "train"
                elif i < n_train + n_val:
                    split = "val"
                else:
                    split = "test"
                for doc in docs:
                    doc_splits[str(doc.document_id)] = split

        logger.info(
            "Split assignment: %d groups shuffled -> documents: train=%d, val=%d, test=%d",
            len(all_groups),
            sum(1 for s in doc_splits.values() if s == "train"),
            sum(1 for s in doc_splits.values() if s == "val"),
            sum(1 for s in doc_splits.values() if s == "test"),
        )
        return doc_splits

    def _generate_data_yaml(self, dataset_dir: Path) -> None:
        """Generate the YOLO data.yaml configuration file."""
        data = {
            "path": str(dataset_dir.absolute()),
            "train": "images/train",
            "val": "images/val",
            "test": "images/test",
            "nc": len(FIELD_CLASSES),
            "names": FIELD_CLASSES,
        }
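        # Illustrative output (keys are sorted by yaml.dump; the class names
        # below are placeholders, the real ones come from FIELD_CLASSES):
        #   names:
        #   - field_a
        #   - field_b
        #   nc: 2
        #   path: /abs/path/to/datasets/<dataset_id>
        #   test: images/test
        #   train: images/train
        #   val: images/val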
        yaml_path = dataset_dir / "data.yaml"
        yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
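

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The repository classes shown here are
# assumptions about the surrounding application, not part of this module;
# wire in the real datasets/documents/annotations repositories instead.
#
#     builder = DatasetBuilder(
#         datasets_repo=DatasetsRepo(db),        # hypothetical
#         documents_repo=DocumentsRepo(db),      # hypothetical
#         annotations_repo=AnnotationsRepo(db),  # hypothetical
#         base_dir=Path("/data/datasets"),
#     )
#     summary = builder.build_dataset(
#         dataset_id="<dataset uuid>",
#         document_ids=["<doc uuid>", "<doc uuid>"],
#         train_ratio=0.8,
#         val_ratio=0.1,
#         seed=42,
#         admin_images_dir=Path("/data/admin_images"),
#     )
#     # -> {"total_documents": 2, "total_images": ..., "total_annotations": ...}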