Init
This commit is contained in:
123
scripts/03_split_dataset.py
Normal file
123
scripts/03_split_dataset.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Dataset Split Script - Step 3
|
||||
Splits images and labels into training and validation sets
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import random
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# Paths
# Project root: two directory levels up from this script file.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Joined with pathlib instead of string concatenation so that no path
# separator is hard-coded; BASE_DIR itself stays a plain string.
YOLO_DATASET_DIR = Path(BASE_DIR) / "data" / "yolo_dataset"
# Source directories that split_dataset() reads all images and labels from.
TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"

# Destination directories that split_dataset() clears and repopulates.
TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"

# Configuration
VALIDATION_SPLIT = 0.2  # 20% for validation
RANDOM_SEED = 42
|
||||
|
||||
|
||||
# Image extensions picked up from the temp directory (compared case-insensitively).
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})


def _list_images(images_dir):
    """Return the image files in *images_dir*, sorted by file name."""
    candidates = (
        entry
        for entry in images_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in _IMAGE_SUFFIXES
    )
    # Directory listing order is filesystem-dependent; sorting makes the
    # seeded shuffle below produce the same split on every machine.
    return sorted(candidates, key=lambda entry: entry.name)


def _pair_with_labels(image_files, labels_dir):
    """Return (image, label) path tuples for images whose ``<stem>.txt`` label exists."""
    pairs = []
    for image_file in image_files:
        label_file = labels_dir / (image_file.stem + ".txt")
        if label_file.exists():
            pairs.append((image_file, label_file))
    return pairs


def _reset_directory(directory):
    """Delete *directory* if present and recreate it empty."""
    if directory.exists():
        shutil.rmtree(directory)
    directory.mkdir(parents=True, exist_ok=True)


def _copy_pairs(pairs, images_dir, labels_dir):
    """Copy every (image, label) pair into the given destination directories."""
    for image_file, label_file in pairs:
        shutil.copy(image_file, images_dir / image_file.name)
        shutil.copy(label_file, labels_dir / label_file.name)


def split_dataset(val_split=VALIDATION_SPLIT, seed=RANDOM_SEED):
    """
    Split dataset into training and validation sets.

    Image files in TEMP_IMAGES_DIR that have a same-stem ``.txt`` label in
    TEMP_LABELS_DIR are shuffled with a seeded generator and copied into the
    train/val image and label directories. Those four destination directories
    are deleted and recreated first. Progress is reported on stdout.

    Args:
        val_split: Fraction of data to use for validation (0.0 to 1.0)
        seed: Random seed for reproducibility

    Returns:
        None. Returns early, after printing a message, when the temp
        directories are missing, contain no images, or no image has a label.

    Raises:
        ValueError: If val_split is outside the range 0.0 to 1.0.
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(
            f"val_split must be between 0.0 and 1.0, got {val_split!r}"
        )

    print("=" * 60)
    print("Splitting Dataset into Train/Val Sets")
    print("=" * 60)
    print(f"Validation split: {val_split*100:.1f}%")
    print(f"Random seed: {seed}\n")

    # Check if temp directories exist
    if not TEMP_IMAGES_DIR.exists() or not TEMP_LABELS_DIR.exists():
        # The resolved paths are echoed before the error message; presumably
        # to help diagnose a wrong project location.
        print(BASE_DIR)
        print(YOLO_DATASET_DIR)
        print(TEMP_IMAGES_DIR)
        print("Error: Temporary directories not found")
        print("Please run 02_create_labels.py first")
        return

    # Get all image files
    image_files = _list_images(TEMP_IMAGES_DIR)
    if not image_files:
        print(f"No image files found in {TEMP_IMAGES_DIR}")
        return

    # Keep only images that have corresponding labels
    valid_pairs = _pair_with_labels(image_files, TEMP_LABELS_DIR)
    if not valid_pairs:
        print("No valid image-label pairs found")
        return

    print(f"Found {len(valid_pairs)} image-label pair(s)")

    # Shuffle with a private generator so the global random state is untouched.
    random.Random(seed).shuffle(valid_pairs)

    # Round before truncating: float noise such as 10 * (1 - 0.9) ==
    # 0.9999999999999998 would otherwise drop a sample from the training set.
    split_index = int(round(len(valid_pairs) * (1 - val_split), 6))
    train_pairs = valid_pairs[:split_index]
    val_pairs = valid_pairs[split_index:]

    print("\nSplit results:")
    print(f"  Training set: {len(train_pairs)} samples")
    print(f"  Validation set: {len(val_pairs)} samples")
    if not train_pairs or not val_pairs:
        print("  Warning: one of the sets is empty")
    print()

    # Clear existing train/val directories
    for directory in (TRAIN_IMAGES_DIR, VAL_IMAGES_DIR, TRAIN_LABELS_DIR, VAL_LABELS_DIR):
        _reset_directory(directory)

    # Copy training files
    print("Copying training files...")
    _copy_pairs(train_pairs, TRAIN_IMAGES_DIR, TRAIN_LABELS_DIR)
    print(f"  Copied {len(train_pairs)} image-label pairs to train/")

    # Copy validation files
    print("Copying validation files...")
    _copy_pairs(val_pairs, VAL_IMAGES_DIR, VAL_LABELS_DIR)
    print(f"  Copied {len(val_pairs)} image-label pairs to val/")

    print("\n" + "=" * 60)
    print("Dataset split complete!")
    print("\nDataset structure:")
    print(f"  {TRAIN_IMAGES_DIR}")
    print(f"  {TRAIN_LABELS_DIR}")
    print(f"  {VAL_IMAGES_DIR}")
    print(f"  {VAL_LABELS_DIR}")
    print("\nNext step: Run 04_train_yolo.py to train the model")
    print("=" * 60)
|
||||
|
||||
|
||||
def main():
    """Script entry point: run split_dataset() with its default arguments."""
    split_dataset()


# Only run the split when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user