"""
|
|
YOLO Dataset Builder
|
|
|
|
Builds training dataset from PDFs and structured CSV data.
|
|
"""
|
|
|
|
import csv
|
|
import shutil
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
import random
|
|
|
|
|
|
@dataclass
class DatasetStats:
    """Statistics about the generated dataset."""

    total_documents: int
    successful: int
    failed: int
    total_annotations: int
    annotations_per_class: dict[str, int]
    train_count: int
    val_count: int
    test_count: int


class DatasetBuilder:
    """Builds a YOLO training dataset from PDFs and CSV data."""

    def __init__(
        self,
        pdf_dir: str | Path,
        csv_path: str | Path,
        output_dir: str | Path,
        document_id_column: str = 'DocumentId',
        dpi: int = 300,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1
    ):
        """
        Initialize the dataset builder.

        Args:
            pdf_dir: Directory containing the PDF files
            csv_path: Path to the structured data CSV
            output_dir: Output directory for the generated dataset
            document_id_column: Name of the CSV column holding the document ID
            dpi: Resolution for rendering PDF pages to images
            train_ratio: Fraction of documents assigned to the training set
            val_ratio: Fraction of documents assigned to the validation set
            test_ratio: Fraction of documents assigned to the test set
        """
        self.pdf_dir = Path(pdf_dir)
        self.csv_path = Path(csv_path)
        self.output_dir = Path(output_dir)
        self.document_id_column = document_id_column
        self.dpi = dpi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
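        # Note (assumption made explicit): train_ratio, val_ratio and
        # test_ratio are expected to sum to 1.0; build() uses them as-is
        # and does not validate or renormalize them.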

    def load_structured_data(self) -> dict[str, dict]:
        """Load structured data from the CSV, keyed by document ID."""
        data = {}

        with open(self.csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                doc_id = row.get(self.document_id_column)
                if doc_id:
                    data[doc_id] = row

        return data

    def find_pdf_for_document(self, doc_id: str) -> Path | None:
        """Find PDF file for a document ID."""
        # Try common naming patterns
        patterns = [
            f"{doc_id}.pdf",
            f"{doc_id.lower()}.pdf",
            f"{doc_id.upper()}.pdf",
            f"*{doc_id}*.pdf",
        ]
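
        # If more than one file matches a wildcard pattern, the first glob
        # result is used; glob order is not guaranteed, so ambiguous file
        # names resolve arbitrarily.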
        for pattern in patterns:
            matches = list(self.pdf_dir.glob(pattern))
            if matches:
                return matches[0]

        return None

    def build(self, seed: int = 42) -> DatasetStats:
        """
        Build the complete dataset.

        Args:
            seed: Random seed for the train/val/test split

        Returns:
            DatasetStats with build results
        """
        from shared.fields import CLASS_NAMES
        from .annotation_generator import AnnotationGenerator, generate_annotations

        random.seed(seed)

        # Set up output directories
        for split in ['train', 'val', 'test']:
            (self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
            (self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

        # Generate config files
        AnnotationGenerator.generate_classes_file(self.output_dir / 'classes.txt')
        AnnotationGenerator.generate_yaml_config(self.output_dir / 'dataset.yaml')

        # Load structured data
        structured_data = self.load_structured_data()

        # Stats tracking
        stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'annotations': 0,
            'per_class': {name: 0 for name in CLASS_NAMES},
            'splits': {'train': 0, 'val': 0, 'test': 0}
        }

        # Process each document
        processed_items = []

        for doc_id, data in structured_data.items():
            stats['total'] += 1

            pdf_path = self.find_pdf_for_document(doc_id)
            if not pdf_path:
                print(f"Warning: PDF not found for document {doc_id}")
                stats['failed'] += 1
                continue

            try:
                # Generate into a per-document temp dir first, so a document
                # that fails partway never leaves partial output in the splits
                temp_dir = self.output_dir / 'temp' / doc_id
                temp_dir.mkdir(parents=True, exist_ok=True)

                annotation_files = generate_annotations(
                    pdf_path, data, temp_dir, self.dpi
                )

                if annotation_files:
                    processed_items.append({
                        'doc_id': doc_id,
                        'temp_dir': temp_dir,
                        'annotation_files': annotation_files
                    })
                    stats['successful'] += 1
                else:
                    print(f"Warning: No annotations generated for {doc_id}")
                    stats['failed'] += 1

            except Exception as e:
                print(f"Error processing {doc_id}: {e}")
                stats['failed'] += 1

        # Shuffle and split at the document level, so every page image from
        # a given document ends up in the same split
        random.shuffle(processed_items)

        n_train = int(len(processed_items) * self.train_ratio)
        n_val = int(len(processed_items) * self.val_ratio)
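
        # int() floors both counts, so any remainder after the train/val cut
        # falls through to the test split.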
        splits = {
            'train': processed_items[:n_train],
            'val': processed_items[n_train:n_train + n_val],
            'test': processed_items[n_train + n_val:]
        }

        # Move files to final locations
        for split_name, items in splits.items():
            for item in items:
                temp_dir = item['temp_dir']

                # Move images
                for img in (temp_dir / 'images').glob('*'):
                    dest = self.output_dir / split_name / 'images' / img.name
                    shutil.move(str(img), str(dest))

                # Move labels and count annotations
                for label in (temp_dir / 'labels').glob('*.txt'):
                    dest = self.output_dir / split_name / 'labels' / label.name
                    shutil.move(str(label), str(dest))

                    # Count annotations per class
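                    # Each YOLO label line is
                    # "<class_id> <x_center> <y_center> <width> <height>" in
                    # normalized coordinates; only the class id is read here.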
                    with open(dest, 'r') as f:
                        for line in f:
                            class_id = int(line.strip().split()[0])
                            if 0 <= class_id < len(CLASS_NAMES):
                                stats['per_class'][CLASS_NAMES[class_id]] += 1
                                stats['annotations'] += 1

                stats['splits'][split_name] += 1

        # Cleanup temp dir
        shutil.rmtree(self.output_dir / 'temp', ignore_errors=True)

        return DatasetStats(
            total_documents=stats['total'],
            successful=stats['successful'],
            failed=stats['failed'],
            total_annotations=stats['annotations'],
            annotations_per_class=stats['per_class'],
            train_count=stats['splits']['train'],
            val_count=stats['splits']['val'],
            test_count=stats['splits']['test']
        )

    def process_single_document(
        self,
        doc_id: str,
        data: dict,
        split: str = 'train'
    ) -> bool:
        """
        Process a single document (for incremental dataset building).

        Args:
            doc_id: Document ID
            data: Structured data dict for the document
            split: Which split ('train', 'val' or 'test') to add the document to

        Returns:
            True if at least one annotation file was generated
        """
        from .annotation_generator import generate_annotations

        pdf_path = self.find_pdf_for_document(doc_id)
        if not pdf_path:
            return False

        try:
            output_subdir = self.output_dir / split
            annotation_files = generate_annotations(
                pdf_path, data, output_subdir, self.dpi
            )
            return len(annotation_files) > 0
        except Exception as e:
            print(f"Error processing {doc_id}: {e}")
            return False
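

# Example usage (a minimal sketch; the paths below are placeholders, and the
# shared.fields / .annotation_generator modules are assumed importable as
# they are above):
#
#     builder = DatasetBuilder(
#         pdf_dir='data/pdfs',
#         csv_path='data/structured.csv',
#         output_dir='data/yolo_dataset',
#     )
#     stats = builder.build(seed=42)
#     print(f"{stats.successful}/{stats.total_documents} documents processed, "
#           f"{stats.total_annotations} annotations; "
#           f"train/val/test = {stats.train_count}/{stats.val_count}/{stats.test_count}")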