restructure project
packages/training/training/yolo/dataset_builder.py: 249 additions (new file)
@@ -0,0 +1,249 @@
"""
YOLO Dataset Builder

Builds training dataset from PDFs and structured CSV data.
"""

import csv
import random
import shutil
from dataclasses import dataclass
from pathlib import Path


@dataclass
class DatasetStats:
    """Statistics about the generated dataset."""
    total_documents: int
    successful: int
    failed: int
    total_annotations: int
    annotations_per_class: dict[str, int]
    train_count: int
    val_count: int
    test_count: int


class DatasetBuilder:
    """Builds YOLO training dataset from PDFs and CSV data."""

    def __init__(
        self,
        pdf_dir: str | Path,
        csv_path: str | Path,
        output_dir: str | Path,
        document_id_column: str = 'DocumentId',
        dpi: int = 300,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1
    ):
        """
        Initialize dataset builder.

        Args:
            pdf_dir: Directory containing PDF files
            csv_path: Path to structured data CSV
            output_dir: Output directory for dataset
            document_id_column: Column name for document ID
            dpi: Resolution for rendering PDF pages to images
            train_ratio: Fraction for training set
            val_ratio: Fraction for validation set
            test_ratio: Fraction for test set
        """
        self.pdf_dir = Path(pdf_dir)
        self.csv_path = Path(csv_path)
        self.output_dir = Path(output_dir)
        self.document_id_column = document_id_column
        self.dpi = dpi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
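        # The three split ratios are expected to sum to 1.0.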

    def load_structured_data(self) -> dict[str, dict]:
        """Load structured data from CSV."""
        data = {}

        with open(self.csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                doc_id = row.get(self.document_id_column)
                if doc_id:
                    data[doc_id] = row

        return data

    def find_pdf_for_document(self, doc_id: str) -> Path | None:
        """Find PDF file for a document ID."""
        # Try common naming patterns
        patterns = [
            f"{doc_id}.pdf",
            f"{doc_id.lower()}.pdf",
            f"{doc_id.upper()}.pdf",
            f"*{doc_id}*.pdf",
        ]

        for pattern in patterns:
            matches = list(self.pdf_dir.glob(pattern))
            if matches:
                return matches[0]

        return None

    def build(self, seed: int = 42) -> DatasetStats:
        """
        Build the complete dataset.

        Args:
            seed: Random seed for train/val/test split

        Returns:
            DatasetStats with build results
        """
        from .annotation_generator import AnnotationGenerator, CLASS_NAMES

        random.seed(seed)

        # Setup output directories
        for split in ['train', 'val', 'test']:
            (self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
            (self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
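        # This produces the conventional YOLO layout:
        #   <output_dir>/{train,val,test}/{images,labels}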

        # Generate config files
        AnnotationGenerator.generate_classes_file(self.output_dir / 'classes.txt')
        AnnotationGenerator.generate_yaml_config(self.output_dir / 'dataset.yaml')

        # Load structured data
        structured_data = self.load_structured_data()

        # Stats tracking
        stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'annotations': 0,
            'per_class': {name: 0 for name in CLASS_NAMES},
            'splits': {'train': 0, 'val': 0, 'test': 0}
        }

        # Process each document
        processed_items = []

        for doc_id, data in structured_data.items():
            stats['total'] += 1

            pdf_path = self.find_pdf_for_document(doc_id)
            if not pdf_path:
                print(f"Warning: PDF not found for document {doc_id}")
                stats['failed'] += 1
                continue

            try:
                # Generate to temp dir first
                temp_dir = self.output_dir / 'temp' / doc_id
                temp_dir.mkdir(parents=True, exist_ok=True)

                from .annotation_generator import generate_annotations
                annotation_files = generate_annotations(
                    pdf_path, data, temp_dir, self.dpi
                )

                if annotation_files:
                    processed_items.append({
                        'doc_id': doc_id,
                        'temp_dir': temp_dir,
                        'annotation_files': annotation_files
                    })
                    stats['successful'] += 1
                else:
                    print(f"Warning: No annotations generated for {doc_id}")
                    stats['failed'] += 1

            except Exception as e:
                print(f"Error processing {doc_id}: {e}")
                stats['failed'] += 1

        # Shuffle and split
        random.shuffle(processed_items)

        n_train = int(len(processed_items) * self.train_ratio)
        n_val = int(len(processed_items) * self.val_ratio)
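        # int() truncates, so the test split takes all remaining documents.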

        splits = {
            'train': processed_items[:n_train],
            'val': processed_items[n_train:n_train + n_val],
            'test': processed_items[n_train + n_val:]
        }

        # Move files to final locations
        for split_name, items in splits.items():
            for item in items:
                temp_dir = item['temp_dir']

                # Move images
                for img in (temp_dir / 'images').glob('*'):
                    dest = self.output_dir / split_name / 'images' / img.name
                    shutil.move(str(img), str(dest))

                # Move labels and count annotations
                for label in (temp_dir / 'labels').glob('*.txt'):
                    dest = self.output_dir / split_name / 'labels' / label.name
                    shutil.move(str(label), str(dest))
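
                    # Each label line follows the YOLO format
                    # "class_id x_center y_center width height"; only the
                    # class id matters for the per-class counts below.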
                    # Count annotations per class
                    with open(dest, 'r') as f:
                        for line in f:
                            line = line.strip()
                            if not line:
                                continue  # skip blank lines in the label file
                            class_id = int(line.split()[0])
                            if 0 <= class_id < len(CLASS_NAMES):
                                stats['per_class'][CLASS_NAMES[class_id]] += 1
                                stats['annotations'] += 1

                stats['splits'][split_name] += 1

        # Cleanup temp dir
        shutil.rmtree(self.output_dir / 'temp', ignore_errors=True)

        return DatasetStats(
            total_documents=stats['total'],
            successful=stats['successful'],
            failed=stats['failed'],
            total_annotations=stats['annotations'],
            annotations_per_class=stats['per_class'],
            train_count=stats['splits']['train'],
            val_count=stats['splits']['val'],
            test_count=stats['splits']['test']
        )

    def process_single_document(
        self,
        doc_id: str,
        data: dict,
        split: str = 'train'
    ) -> bool:
        """
        Process a single document (for incremental building).

        Args:
            doc_id: Document ID
            data: Structured data dict
            split: Which split to add to

        Returns:
            True if successful
        """
        from .annotation_generator import generate_annotations

        pdf_path = self.find_pdf_for_document(doc_id)
        if not pdf_path:
            return False

        try:
            output_subdir = self.output_dir / split
            annotation_files = generate_annotations(
                pdf_path, data, output_subdir, self.dpi
            )
            return len(annotation_files) > 0
        except Exception as e:
            print(f"Error processing {doc_id}: {e}")
            return False
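

# Minimal usage sketch, assuming this module is imported from within its
# package (build() relies on the relative import of .annotation_generator).
# The paths below are hypothetical:
#
#     builder = DatasetBuilder(
#         pdf_dir='data/pdfs',
#         csv_path='data/structured.csv',
#         output_dir='data/yolo_dataset',
#     )
#     stats = builder.build(seed=42)
#     print(f"{stats.successful}/{stats.total_documents} documents, "
#           f"{stats.total_annotations} annotations "
#           f"(train={stats.train_count}, val={stats.val_count}, test={stats.test_count})")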