"""
|
|
Dataset Builder Service
|
|
|
|
Creates training datasets by copying images from admin storage,
|
|
generating YOLO label files, and splitting into train/val/test sets.
|
|
"""
|
|
|
|
import logging
|
|
import random
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
from shared.fields import FIELD_CLASSES
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|

class DatasetBuilder:
    """Builds YOLO training datasets from admin documents."""

    def __init__(
        self,
        datasets_repo,
        documents_repo,
        annotations_repo,
        base_dir: Path,
    ):
        self._datasets_repo = datasets_repo
        self._documents_repo = documents_repo
        self._annotations_repo = annotations_repo
        self._base_dir = Path(base_dir)

    def build_dataset(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        """Build a complete YOLO dataset from document IDs.

        Args:
            dataset_id: UUID of the dataset record.
            document_ids: List of document UUIDs to include.
            train_ratio: Fraction for training set.
            val_ratio: Fraction for validation set.
            seed: Random seed for reproducible splits.
            admin_images_dir: Root directory of admin images.

        Returns:
            Summary dict with total_documents, total_images, total_annotations.

        Raises:
            ValueError: If no valid documents are found.
        """
        try:
            return self._do_build(
                dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
            )
        except Exception as e:
            # Mark the dataset record as failed before re-raising so callers
            # can surface the error.
            self._datasets_repo.update_status(
                dataset_id=dataset_id,
                status="failed",
                error_message=str(e),
            )
            raise

    def _do_build(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        # 1. Fetch documents
        documents = self._documents_repo.get_by_ids(document_ids)
        if not documents:
            raise ValueError("No valid documents found for the given IDs")

        # 2. Create directory structure
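        # Layout created below (one images/ and one labels/ tree per split):
        #   <base_dir>/<dataset_id>/images/{train,val,test}/
        #   <base_dir>/<dataset_id>/labels/{train,val,test}/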
        dataset_dir = self._base_dir / dataset_id
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
            (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

        # 3. Group documents by group_key and assign splits
        doc_list = list(documents)
        doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)

        # 4. Process each document
        total_images = 0
        total_annotations = 0
        dataset_docs = []

        for doc in doc_list:
            doc_id = str(doc.document_id)
            split = doc_splits[doc_id]
            annotations = self._annotations_repo.get_for_document(doc_id)

            # Group annotations by page
            page_annotations: dict[int, list] = {}
            for ann in annotations:
                page_annotations.setdefault(ann.page_number, []).append(ann)

            doc_image_count = 0
            doc_ann_count = 0

            # Copy images and write labels for each page
            for page_num in range(1, doc.page_count + 1):
                src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
                if not src_image.exists():
                    logger.warning("Image not found: %s", src_image)
                    continue

                dst_name = f"{doc_id}_page{page_num}"
                dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
                shutil.copy2(src_image, dst_image)
                doc_image_count += 1

                # Write YOLO label file
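                # One line per box: "<class_id> <x_center> <y_center> <width> <height>",
                # coordinates expected to be normalized to [0, 1] per YOLO convention.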
                page_anns = page_annotations.get(page_num, [])
                label_lines = []
                for ann in page_anns:
                    label_lines.append(
                        f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
                        f"{ann.width:.6f} {ann.height:.6f}"
                    )
                    doc_ann_count += 1

                label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
                label_path.write_text("\n".join(label_lines))

            total_images += doc_image_count
            total_annotations += doc_ann_count

            dataset_docs.append({
                "document_id": doc_id,
                "split": split,
                "page_count": doc_image_count,
                "annotation_count": doc_ann_count,
            })

        # 5. Record document-split assignments in DB
        self._datasets_repo.add_documents(
            dataset_id=dataset_id,
            documents=dataset_docs,
        )

        # 6. Generate data.yaml
        self._generate_data_yaml(dataset_dir)

        # 7. Update dataset status
        self._datasets_repo.update_status(
            dataset_id=dataset_id,
            status="ready",
            total_documents=len(doc_list),
            total_images=total_images,
            total_annotations=total_annotations,
            dataset_path=str(dataset_dir),
        )

        return {
            "total_documents": len(doc_list),
            "total_images": total_images,
            "total_annotations": total_annotations,
        }

    def _assign_splits_by_group(
        self,
        documents: list,
        train_ratio: float,
        val_ratio: float,
        seed: int,
    ) -> dict[str, str]:
        """Assign splits based on group_key.

        Documents sharing a group_key always land in the same split, so
        related documents never straddle train/val/test. Documents with no
        group_key are treated as single-document groups. All groups are
        shuffled with a seeded RNG and partitioned by ratio.

        Args:
            documents: List of AdminDocument objects.
            train_ratio: Fraction for training set.
            val_ratio: Fraction for validation set.
            seed: Random seed for reproducibility.

        Returns:
            Dict mapping document_id (str) -> split ("train"/"val"/"test").
        """
        # Group documents by group_key; an ungrouped document gets a unique
        # pseudo-key so it forms its own single-document group.
        groups: dict[str, list] = {}
        for doc in documents:
            key = doc.group_key or f"__ungrouped_{doc.document_id}"
            groups.setdefault(key, []).append(doc)

        doc_splits: dict[str, str] = {}

        # Single- and multi-document groups alike take part in the
        # shuffle-and-split; grouping only keeps related docs together.
        all_groups = list(groups.items())

        if all_groups:
            rng = random.Random(seed)
            rng.shuffle(all_groups)

            n_groups = len(all_groups)
            n_train = max(1, round(n_groups * train_ratio))
            # At least 1 val group whenever there is more than 1 group
            n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
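            # e.g. 10 groups with train_ratio=0.7 and val_ratio=0.2 gives
            # 7 train, 2 val, and 1 test group. With very few groups, train
            # can absorb everything, leaving val/test empty.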
            for i, (_key, docs) in enumerate(all_groups):
                if i < n_train:
                    split = "train"
                elif i < n_train + n_val:
                    split = "val"
                else:
                    split = "test"

                for doc in docs:
                    doc_splits[str(doc.document_id)] = split

        logger.info(
            "Split assignment: %d groups shuffled (%d train docs, %d val docs)",
            len(all_groups),
            sum(1 for s in doc_splits.values() if s == "train"),
            sum(1 for s in doc_splits.values() if s == "val"),
        )

        return doc_splits

    def _generate_data_yaml(self, dataset_dir: Path) -> None:
        """Generate the YOLO data.yaml configuration file."""
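        # Shape of the emitted file (sketch; yaml.dump sorts keys
        # alphabetically, and class names come from FIELD_CLASSES):
        #   path: /abs/path/to/dataset
        #   train: images/train
        #   val: images/val
        #   test: images/test
        #   nc: <number of classes>
        #   names: [<class names>]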
        data = {
            "path": str(dataset_dir.absolute()),
            "train": "images/train",
            "val": "images/val",
            "test": "images/test",
            "nc": len(FIELD_CLASSES),
            "names": FIELD_CLASSES,
        }
        yaml_path = dataset_dir / "data.yaml"
        yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
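

# Example wiring (sketch; the repository objects and paths are assumptions
# supplied by the application, not defined in this module):
#
#     builder = DatasetBuilder(datasets_repo, documents_repo, annotations_repo,
#                              base_dir=Path("/data/datasets"))
#     summary = builder.build_dataset(dataset_id, document_ids,
#                                     train_ratio=0.7, val_ratio=0.2, seed=42,
#                                     admin_images_dir=Path("/data/admin/images"))

if __name__ == "__main__":
    # Ad-hoc sanity check (illustrative only): exercises the group-aware split
    # with stub documents; no repositories or files are touched.
    from types import SimpleNamespace

    stub_docs = [
        SimpleNamespace(document_id=f"doc-{i}", group_key=key)
        for i, key in enumerate(["a", "a", "b", None, "c", "c", None, "d"])
    ]
    builder = DatasetBuilder(None, None, None, base_dir=Path("."))
    splits = builder._assign_splits_by_group(
        stub_docs, train_ratio=0.7, val_ratio=0.2, seed=42
    )
    for doc_id, split in sorted(splits.items()):
        print(doc_id, split)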