Files
invoice-master/scripts/03_split_dataset.py
Yaojia Wang dafa86c588 Init
2025-10-26 20:41:11 +01:00

124 lines
3.8 KiB
Python

"""
Dataset Split Script - Step 3
Splits images and labels into training and validation sets
"""
import shutil
import random
from pathlib import Path
import os
# Paths
# Project root: the parent of the directory that contains this script.
# Kept as a plain string so any string-based use of BASE_DIR keeps working.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Joined with pathlib's "/" operator rather than string concatenation so that
# separators are handled by pathlib (no doubled "/" when BASE_DIR is a root).
YOLO_DATASET_DIR = Path(BASE_DIR) / "data" / "yolo_dataset"
# Flat staging directories that split_dataset() reads image/label pairs from.
TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
# Output directories that split_dataset() recreates and fills.
TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
# Configuration
VALIDATION_SPLIT = 0.2  # 20% for validation
RANDOM_SEED = 42  # Default seed for the shuffle performed before splitting
def _list_image_files():
    """Return the image files in TEMP_IMAGES_DIR in a deterministic order."""
    image_suffixes = {".jpg", ".jpeg", ".png"}
    # Directory listing order depends on the filesystem; sorting makes the
    # seeded shuffle produce the same split on every machine.
    return sorted(
        path
        for path in TEMP_IMAGES_DIR.iterdir()
        if path.is_file() and path.suffix.lower() in image_suffixes
    )


def _copy_pairs(pairs, images_dir, labels_dir):
    """Copy each (image, label) pair into the given image/label directories."""
    for image_file, label_file in pairs:
        shutil.copy(image_file, images_dir / image_file.name)
        shutil.copy(label_file, labels_dir / label_file.name)


def split_dataset(val_split=VALIDATION_SPLIT, seed=RANDOM_SEED):
    """
    Split dataset into training and validation sets

    Images (.jpg, .jpeg, .png, any letter case) are read from TEMP_IMAGES_DIR
    and paired with the label file of the same stem (".txt") in
    TEMP_LABELS_DIR; images without a label are ignored. The pairs are
    shuffled with a private random generator seeded with ``seed`` and copied
    into the train/val image and label directories. Existing contents of those
    four output directories are deleted first.

    If the temp directories are missing, or no images / no image-label pairs
    are found, a message is printed and the function returns without touching
    the output directories.

    Args:
        val_split: Fraction of data to use for validation (0.0 to 1.0)
        seed: Random seed for reproducibility

    Raises:
        ValueError: If val_split is not within 0.0 to 1.0
    """
    # Reject nonsensical fractions up front, before anything is deleted;
    # out-of-range values would otherwise yield meaningless slices.
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(
            f"val_split must be between 0.0 and 1.0, got {val_split!r}"
        )

    print("=" * 60)
    print("Splitting Dataset into Train/Val Sets")
    print("=" * 60)
    print(f"Validation split: {val_split*100:.1f}%")
    print(f"Random seed: {seed}\n")

    # Check if temp directories exist
    if not TEMP_IMAGES_DIR.is_dir() or not TEMP_LABELS_DIR.is_dir():
        # Show the resolved paths to help diagnose a wrong project layout.
        print(BASE_DIR)
        print(YOLO_DATASET_DIR)
        print(TEMP_IMAGES_DIR)
        print("Error: Temporary directories not found")
        print("Please run 02_create_labels.py first")
        return

    # Get all image files
    image_files = _list_image_files()
    if not image_files:
        print(f"No image files found in {TEMP_IMAGES_DIR}")
        return

    # Filter images that have corresponding labels
    valid_pairs = []
    for image_file in image_files:
        label_file = TEMP_LABELS_DIR / (image_file.stem + ".txt")
        if label_file.exists():
            valid_pairs.append((image_file, label_file))
    if not valid_pairs:
        print("No valid image-label pairs found")
        return
    print(f"Found {len(valid_pairs)} image-label pair(s)")

    # Shuffle and split. A private generator gives the same permutation as
    # seeding the module-level one, without reseeding the global RNG that
    # other code in the process may rely on.
    random.Random(seed).shuffle(valid_pairs)
    # Round before truncating: 1 - val_split can land just below the exact
    # value (e.g. 10 * (1 - 0.8) == 1.9999999999999996), which would
    # otherwise drop one training sample.
    split_index = int(round(len(valid_pairs) * (1 - val_split), 9))
    train_pairs = valid_pairs[:split_index]
    val_pairs = valid_pairs[split_index:]
    print("\nSplit results:")
    print(f" Training set: {len(train_pairs)} samples")
    print(f" Validation set: {len(val_pairs)} samples")
    if not train_pairs or not val_pairs:
        print(" Warning: one of the sets is empty; add data or adjust val_split")
    print()

    # Clear existing train/val directories
    for directory in (TRAIN_IMAGES_DIR, VAL_IMAGES_DIR, TRAIN_LABELS_DIR, VAL_LABELS_DIR):
        if directory.exists():
            shutil.rmtree(directory)
        directory.mkdir(parents=True, exist_ok=True)

    # Copy training files
    print("Copying training files...")
    _copy_pairs(train_pairs, TRAIN_IMAGES_DIR, TRAIN_LABELS_DIR)
    print(f" Copied {len(train_pairs)} image-label pairs to train/")

    # Copy validation files
    print("Copying validation files...")
    _copy_pairs(val_pairs, VAL_IMAGES_DIR, VAL_LABELS_DIR)
    print(f" Copied {len(val_pairs)} image-label pairs to val/")

    print("\n" + "=" * 60)
    print("Dataset split complete!")
    print("\nDataset structure:")
    print(f" {TRAIN_IMAGES_DIR}")
    print(f" {TRAIN_LABELS_DIR}")
    print(f" {VAL_IMAGES_DIR}")
    print(f" {VAL_LABELS_DIR}")
    print("\nNext step: Run 04_train_yolo.py to train the model")
    print("=" * 60)
def main():
    """Script entry point: run the dataset split with its default settings."""
    split_dataset()


# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()