Files
invoice-master-poc-v2/packages/inference/inference/web/services/dataset_builder.py
2026-01-27 23:58:17 +01:00

189 lines
5.9 KiB
Python

"""
Dataset Builder Service
Creates training datasets by copying images from admin storage,
generating YOLO label files, and splitting into train/val/test sets.
"""
import logging
import random
import shutil
from pathlib import Path
import yaml
from inference.data.admin_models import FIELD_CLASSES
logger = logging.getLogger(__name__)
class DatasetBuilder:
    """Builds YOLO training datasets from admin documents.

    Copies page images from admin storage into a per-dataset directory
    tree, writes YOLO-format label files, splits documents into
    train/val/test sets, and records the outcome in the database.
    """

    def __init__(self, db, base_dir: Path):
        """
        Args:
            db: Database gateway; must provide get_documents_by_ids,
                get_annotations_for_document, add_dataset_documents and
                update_dataset_status (see usage below).
            base_dir: Root directory under which each dataset gets a
                subdirectory named after its dataset_id.
        """
        self._db = db
        self._base_dir = Path(base_dir)

    def build_dataset(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        """Build a complete YOLO dataset from document IDs.

        Args:
            dataset_id: UUID of the dataset record.
            document_ids: List of document UUIDs to include.
            train_ratio: Fraction for training set.
            val_ratio: Fraction for validation set.
            seed: Random seed for reproducible splits.
            admin_images_dir: Root directory of admin images.

        Returns:
            Summary dict with total_documents, total_images, total_annotations.

        Raises:
            ValueError: If no valid documents found.
        """
        try:
            return self._do_build(
                dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
            )
        except Exception as e:
            # Mark the dataset failed so it is not left stuck in a
            # "building" state, then re-raise the original error.
            self._db.update_dataset_status(
                dataset_id=dataset_id,
                status="failed",
                error_message=str(e),
            )
            raise

    def _do_build(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        """Perform the actual build; see build_dataset for the contract."""
        # 1. Fetch documents
        documents = self._db.get_documents_by_ids(document_ids)
        if not documents:
            raise ValueError("No valid documents found for the given IDs")

        # 2. Create directory structure (images/ and labels/ per split)
        dataset_dir = self._base_dir / dataset_id
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
            (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

        # 3. Shuffle and split documents (seeded RNG => reproducible splits).
        doc_list = list(documents)
        rng = random.Random(seed)
        rng.shuffle(doc_list)
        n = len(doc_list)
        # Clamp the counts so n_train + n_val + n_test == n exactly.
        # Plain round() could overshoot (e.g. n=3 with ratios 0.5/0.5
        # gives n_train=2, n_val=2, n_test=-1), which silently assigned
        # intended test-set documents to the wrong split.
        n_train = min(n, max(1, round(n * train_ratio)))
        n_val = min(n - n_train, max(0, round(n * val_ratio)))
        n_test = n - n_train - n_val
        splits = (
            ["train"] * n_train
            + ["val"] * n_val
            + ["test"] * n_test
        )

        # 4. Process each document
        total_images = 0
        total_annotations = 0
        dataset_docs = []
        for doc, split in zip(doc_list, splits):
            doc_id = str(doc.document_id)
            annotations = self._db.get_annotations_for_document(doc.document_id)

            # Group annotations by page number
            page_annotations: dict[int, list] = {}
            for ann in annotations:
                page_annotations.setdefault(ann.page_number, []).append(ann)

            doc_image_count = 0
            doc_ann_count = 0
            # Copy images and write labels for each page
            for page_num in range(1, doc.page_count + 1):
                src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
                if not src_image.exists():
                    # Skip missing pages; their annotations are not counted.
                    logger.warning("Image not found: %s", src_image)
                    continue

                dst_name = f"{doc_id}_page{page_num}"
                dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
                shutil.copy2(src_image, dst_image)
                doc_image_count += 1

                # Write YOLO label file: one "<class> <xc> <yc> <w> <h>"
                # line per annotation (normalized coordinates).
                page_anns = page_annotations.get(page_num, [])
                label_lines = []
                for ann in page_anns:
                    label_lines.append(
                        f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
                        f"{ann.width:.6f} {ann.height:.6f}"
                    )
                    doc_ann_count += 1
                label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
                label_path.write_text("\n".join(label_lines), encoding="utf-8")

            total_images += doc_image_count
            total_annotations += doc_ann_count
            dataset_docs.append({
                "document_id": doc_id,
                "split": split,
                "page_count": doc_image_count,
                "annotation_count": doc_ann_count,
            })

        # 5. Record document-split assignments in DB
        self._db.add_dataset_documents(
            dataset_id=dataset_id,
            documents=dataset_docs,
        )

        # 6. Generate data.yaml
        self._generate_data_yaml(dataset_dir)

        # 7. Update dataset status
        self._db.update_dataset_status(
            dataset_id=dataset_id,
            status="ready",
            total_documents=len(doc_list),
            total_images=total_images,
            total_annotations=total_annotations,
            dataset_path=str(dataset_dir),
        )
        return {
            "total_documents": len(doc_list),
            "total_images": total_images,
            "total_annotations": total_annotations,
        }

    def _generate_data_yaml(self, dataset_dir: Path) -> None:
        """Generate the YOLO data.yaml configuration file for this dataset."""
        data = {
            "path": str(dataset_dir.absolute()),
            "train": "images/train",
            "val": "images/val",
            "test": "images/test",
            "nc": len(FIELD_CLASSES),  # number of classes
            "names": FIELD_CLASSES,
        }
        yaml_path = dataset_dir / "data.yaml"
        yaml_path.write_text(
            yaml.dump(data, default_flow_style=False, allow_unicode=True),
            encoding="utf-8",
        )