124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
"""
Dataset Split Script - Step 3

Splits images and labels into training and validation sets.
"""
|
|
|
|
import shutil
|
|
import random
|
|
from pathlib import Path
|
|
import os
|
|
|
|
# Paths
# BASE_DIR is the parent of the directory that contains this script. It is
# kept as a plain string; everything derived from it is a pathlib.Path.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Join with pathlib instead of string concatenation: BASE_DIR + "/data/..."
# produces a "//data/..." path on POSIX when BASE_DIR is the filesystem root.
YOLO_DATASET_DIR = Path(BASE_DIR) / "data" / "yolo_dataset"

# Input directories holding all images and all label files before the split
TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"

# Output directories populated by the split
TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"

# Configuration
VALIDATION_SPLIT = 0.2  # 20% for validation
RANDOM_SEED = 42
|
|
|
|
|
|
def split_dataset(val_split=VALIDATION_SPLIT, seed=RANDOM_SEED):
    """
    Split the image/label pairs into training and validation sets.

    Images are read from TEMP_IMAGES_DIR and matched by file stem against
    ``<stem>.txt`` files in TEMP_LABELS_DIR. The matched pairs are shuffled
    with a seeded generator, divided according to ``val_split`` and copied
    into the train/val image and label directories, which are emptied first.

    Problems (invalid ``val_split``, missing input directories, no usable
    pairs) are reported on stdout and the function returns before any
    output directory is modified.

    Args:
        val_split: Fraction of data to use for validation (0.0 to 1.0)
        seed: Random seed for reproducibility
    """
    print("=" * 60)
    print("Splitting Dataset into Train/Val Sets")
    print("=" * 60)
    print(f"Validation split: {val_split*100:.1f}%")
    print(f"Random seed: {seed}\n")

    # Reject fractions outside [0, 1] (this also rejects NaN) instead of
    # silently producing a nonsensical slice index further down.
    if not 0.0 <= val_split <= 1.0:
        print(f"Error: val_split must be between 0.0 and 1.0, got {val_split}")
        return

    # Check that the temp directories exist and say which one is missing
    missing_dirs = [
        directory
        for directory in (TEMP_IMAGES_DIR, TEMP_LABELS_DIR)
        if not directory.is_dir()
    ]
    if missing_dirs:
        print("Error: Temporary directories not found")
        for directory in missing_dirs:
            print(f"  Missing: {directory}")
        print("Please run 02_create_labels.py first")
        return

    # Get all image files (sorted, so the seeded shuffle is reproducible)
    image_files = _list_images(TEMP_IMAGES_DIR)
    if not image_files:
        print(f"No image files found in {TEMP_IMAGES_DIR}")
        return

    # Keep only the images that have a corresponding label file
    candidates = (
        (image_file, TEMP_LABELS_DIR / (image_file.stem + ".txt"))
        for image_file in image_files
    )
    valid_pairs = [
        (image_file, label_file)
        for image_file, label_file in candidates
        if label_file.is_file()
    ]
    if not valid_pairs:
        print("No valid image-label pairs found")
        return

    print(f"Found {len(valid_pairs)} image-label pair(s)")
    unlabeled = len(image_files) - len(valid_pairs)
    if unlabeled:
        print(f"Skipped {unlabeled} image(s) without a label file")

    # Shuffle with a private generator so the caller's global random state
    # is left untouched; the resulting order for a given seed is the same
    # as with random.seed(seed) followed by random.shuffle().
    random.Random(seed).shuffle(valid_pairs)

    # Round away floating-point noise before truncating: 1 - 0.9 evaluates
    # to 0.09999999999999998, so 10 * (1 - 0.9) would truncate to 0, not 1.
    split_index = int(round(len(valid_pairs) * (1 - val_split), 9))
    train_pairs = valid_pairs[:split_index]
    val_pairs = valid_pairs[split_index:]

    print("\nSplit results:")
    print(f"  Training set: {len(train_pairs)} samples")
    print(f"  Validation set: {len(val_pairs)} samples")
    if not train_pairs or not val_pairs:
        print("  Warning: one of the sets is empty")
    print()

    # Clear existing train/val directories so stale files from an earlier
    # split cannot leak into this one
    for directory in (
        TRAIN_IMAGES_DIR,
        VAL_IMAGES_DIR,
        TRAIN_LABELS_DIR,
        VAL_LABELS_DIR,
    ):
        if directory.exists():
            shutil.rmtree(directory)
        directory.mkdir(parents=True, exist_ok=True)

    # Copy training files
    print("Copying training files...")
    _copy_pairs(train_pairs, TRAIN_IMAGES_DIR, TRAIN_LABELS_DIR)
    print(f"  Copied {len(train_pairs)} image-label pairs to train/")

    # Copy validation files
    print("Copying validation files...")
    _copy_pairs(val_pairs, VAL_IMAGES_DIR, VAL_LABELS_DIR)
    print(f"  Copied {len(val_pairs)} image-label pairs to val/")

    print("\n" + "=" * 60)
    print("Dataset split complete!")
    print("\nDataset structure:")
    print(f"  {TRAIN_IMAGES_DIR}")
    print(f"  {TRAIN_LABELS_DIR}")
    print(f"  {VAL_IMAGES_DIR}")
    print(f"  {VAL_LABELS_DIR}")
    print("\nNext step: Run 04_train_yolo.py to train the model")
    print("=" * 60)


def _list_images(directory):
    """Return the .jpg/.jpeg/.png files in *directory*, sorted by path."""
    # Compare lower-cased suffixes so "IMG.JPG" is found on case-sensitive
    # filesystems as well.
    suffixes = {".jpg", ".jpeg", ".png"}
    return sorted(
        path
        for path in directory.iterdir()
        if path.is_file() and path.suffix.lower() in suffixes
    )


def _copy_pairs(pairs, images_dir, labels_dir):
    """Copy each (image, label) path pair into the given directories."""
    for image_file, label_file in pairs:
        shutil.copy(image_file, images_dir / image_file.name)
        shutil.copy(label_file, labels_dir / label_file.name)
|
|
|
|
|
|
def main():
    """Entry point: run split_dataset() with its default arguments."""
    split_dataset()
|
|
|
|
|
|
# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|