"""
|
|
Dataset Builder Service
|
|
|
|
Creates training datasets by copying images from admin storage,
|
|
generating YOLO label files, and splitting into train/val/test sets.
|
|
"""
|
|
|
|
import logging
|
|
import random
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
from inference.data.admin_models import FIELD_CLASSES
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


class DatasetBuilder:
    """Builds YOLO training datasets from admin documents."""

    def __init__(self, db, base_dir: Path):
        self._db = db
        self._base_dir = Path(base_dir)

    def build_dataset(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        """Build a complete YOLO dataset from document IDs.

        Args:
            dataset_id: UUID of the dataset record.
            document_ids: List of document UUIDs to include.
            train_ratio: Fraction of documents assigned to the training set.
            val_ratio: Fraction assigned to the validation set; together with
                train_ratio this should not exceed 1.0. The remainder goes to
                the test set.
            seed: Random seed for reproducible splits.
            admin_images_dir: Root directory of admin images.

        Returns:
            Summary dict with total_documents, total_images, total_annotations.

        Raises:
            ValueError: If no valid documents are found.
        """
        try:
            return self._do_build(
                dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
            )
        except Exception as e:
            # Mark the dataset record as failed so callers can surface the
            # error, then re-raise so the failure is not swallowed.
            self._db.update_dataset_status(
                dataset_id=dataset_id,
                status="failed",
                error_message=str(e),
            )
            raise

    def _do_build(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        # 1. Fetch documents
        documents = self._db.get_documents_by_ids(document_ids)
        if not documents:
            raise ValueError("No valid documents found for the given IDs")

        # 2. Create directory structure
        dataset_dir = self._base_dir / dataset_id
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
            (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

        # 3. Shuffle and split documents (a dedicated Random instance keeps
        # the split reproducible without touching global random state)
        doc_list = list(documents)
        rng = random.Random(seed)
        rng.shuffle(doc_list)

        n = len(doc_list)
        n_train = max(1, round(n * train_ratio))
        # Clamp so the three counts always sum to exactly n; independent
        # rounding could otherwise overshoot and leave n_test negative.
        n_val = max(0, min(round(n * val_ratio), n - n_train))
        n_test = n - n_train - n_val

        splits = (
            ["train"] * n_train
            + ["val"] * n_val
            + ["test"] * n_test
        )
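        # Example: n=10, train_ratio=0.8, val_ratio=0.1
        # -> n_train=8, n_val=1, n_test=1, so `splits` covers all 10 documents.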

        # 4. Process each document
        total_images = 0
        total_annotations = 0
        dataset_docs = []

        for doc, split in zip(doc_list, splits):
            doc_id = str(doc.document_id)
            annotations = self._db.get_annotations_for_document(doc.document_id)

            # Group annotations by page
            page_annotations: dict[int, list] = {}
            for ann in annotations:
                page_annotations.setdefault(ann.page_number, []).append(ann)

            doc_image_count = 0
            doc_ann_count = 0

            # Copy images and write labels for each page
            for page_num in range(1, doc.page_count + 1):
                src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
                if not src_image.exists():
                    logger.warning("Image not found: %s", src_image)
                    continue

                dst_name = f"{doc_id}_page{page_num}"
                dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
                shutil.copy2(src_image, dst_image)
                doc_image_count += 1

                # Write YOLO label file
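                # Each line describes one box in YOLO txt format:
                #   "<class_id> <x_center> <y_center> <width> <height>"
                # with all coordinates normalized to [0, 1] relative to the
                # image size, e.g. "3 0.512000 0.334000 0.210000 0.045000".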
                page_anns = page_annotations.get(page_num, [])
                label_lines = []
                for ann in page_anns:
                    label_lines.append(
                        f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
                        f"{ann.width:.6f} {ann.height:.6f}"
                    )
                    doc_ann_count += 1

                # An empty label file is valid: it marks the page as a
                # background (negative) example.
                label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
                label_path.write_text("\n".join(label_lines))

            total_images += doc_image_count
            total_annotations += doc_ann_count

            dataset_docs.append({
                "document_id": doc_id,
                "split": split,
                "page_count": doc_image_count,
                "annotation_count": doc_ann_count,
            })

        # 5. Record document-split assignments in DB
        self._db.add_dataset_documents(
            dataset_id=dataset_id,
            documents=dataset_docs,
        )

        # 6. Generate data.yaml
        self._generate_data_yaml(dataset_dir)

        # 7. Update dataset status
        self._db.update_dataset_status(
            dataset_id=dataset_id,
            status="ready",
            total_documents=len(doc_list),
            total_images=total_images,
            total_annotations=total_annotations,
            dataset_path=str(dataset_dir),
        )

        return {
            "total_documents": len(doc_list),
            "total_images": total_images,
            "total_annotations": total_annotations,
        }

    def _generate_data_yaml(self, dataset_dir: Path) -> None:
        """Generate the YOLO data.yaml configuration file."""
        data = {
            "path": str(dataset_dir.absolute()),
            "train": "images/train",
            "val": "images/val",
            "test": "images/test",
            "nc": len(FIELD_CLASSES),
            "names": FIELD_CLASSES,
        }
        yaml_path = dataset_dir / "data.yaml"
        # allow_unicode=True means class names may contain non-ASCII text,
        # so pin the file encoding explicitly.
        yaml_path.write_text(
            yaml.dump(data, default_flow_style=False, allow_unicode=True),
            encoding="utf-8",
        )
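
# For reference, the generated data.yaml looks roughly like the sketch below
# (key order may differ, since yaml.dump sorts keys by default; nc and names
# come from FIELD_CLASSES):
#
#   path: /abs/path/to/<base_dir>/<dataset_id>
#   train: images/train
#   val: images/val
#   test: images/test
#   nc: <number of classes>
#   names: [<FIELD_CLASSES entries>]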
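

if __name__ == "__main__":
    # Smoke-test sketch, not part of the service. The stub DB and fake
    # document below are illustrative assumptions only; in production the
    # `db` object is supplied by the application and backs the same four
    # calls used above.
    import tempfile
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)

    class _StubDB:
        def get_documents_by_ids(self, ids):
            return [SimpleNamespace(document_id=i, page_count=1) for i in ids]

        def get_annotations_for_document(self, document_id):
            # One fake box per document, already normalized to [0, 1].
            return [SimpleNamespace(page_number=1, class_id=0,
                                    x_center=0.5, y_center=0.5,
                                    width=0.2, height=0.1)]

        def add_dataset_documents(self, dataset_id, documents):
            print("split assignments:", documents)

        def update_dataset_status(self, dataset_id, status, **fields):
            print("status:", status, fields)

    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        doc_images = root / "admin_images" / "doc-1"
        doc_images.mkdir(parents=True)
        (doc_images / "page_1.png").write_bytes(b"\x89PNG\r\n\x1a\n")  # placeholder

        builder = DatasetBuilder(db=_StubDB(), base_dir=root / "datasets")
        summary = builder.build_dataset(
            dataset_id="demo-dataset",
            document_ids=["doc-1"],
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=root / "admin_images",
        )
        print(summary)  # {'total_documents': 1, 'total_images': 1, 'total_annotations': 1}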