"""
YOLO Training Script - Step 4

Trains a YOLOv8 model on the prepared invoice dataset and provides helper
modes to validate the trained weights or run a prediction on a sample image.
"""

import argparse
import os
from pathlib import Path

import torch
from ultralytics import YOLO

# Paths (anchored at the project root, i.e. the parent of this script's folder)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATASET_YAML = Path(BASE_DIR) / "data" / "yolo_dataset" / "dataset.yaml"
MODELS_DIR = Path(BASE_DIR) / "models"

# Training run layout: everything for a run lives under MODELS_DIR / RUN_NAME
RUN_NAME = "payment_slip_detector_v1"
RUN_DIR = MODELS_DIR / RUN_NAME
BEST_WEIGHTS = RUN_DIR / "weights" / "best.pt"
LAST_WEIGHTS = RUN_DIR / "weights" / "last.pt"
PREDICTIONS_DIR = MODELS_DIR / "predictions"
VAL_IMAGES_DIR = DATASET_YAML.parent / "images" / "val"

# Training configuration
MODEL_SIZE = "n"  # Options: n (nano), s (small), m (medium), l (large), x (xlarge)
EPOCHS = 100
BATCH_SIZE = 16
IMAGE_SIZE = 640
DEVICE = 0 if torch.cuda.is_available() else "cpu"  # Use GPU with PyTorch 2.7 + CUDA 12.8

# Create models directory
MODELS_DIR.mkdir(exist_ok=True)

_RULE = "=" * 60


def _banner(title, leading_newline=False):
    """Print *title* framed by two horizontal rules."""
    print(("\n" if leading_newline else "") + _RULE)
    print(title)
    print(_RULE)


def _cuda_device_index(device):
    """Return the first CUDA index named by *device*, or None.

    Understands ints (``0``), digit strings (``"0"``, ``"0,1"``), ``"cuda"``,
    ``"cuda:N"`` and lists/tuples of those. Returns None when CUDA is not
    available, when the index is out of range, or when *device* does not
    name a CUDA device (e.g. ``"cpu"``).
    """
    if not torch.cuda.is_available():
        return None

    if isinstance(device, (list, tuple)):
        return _cuda_device_index(device[0]) if device else None

    index = None
    if isinstance(device, int) and not isinstance(device, bool):
        index = device
    elif isinstance(device, str):
        spec = device.strip().lower()
        if spec == "cuda":
            index = 0
        else:
            if spec.startswith("cuda:"):
                spec = spec[len("cuda:"):]
            first = spec.split(",")[0].strip()
            if first.isdigit():
                index = int(first)

    if index is None or not 0 <= index < torch.cuda.device_count():
        return None
    return index


def train_model(
    model_size=MODEL_SIZE,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    img_size=IMAGE_SIZE,
    device=DEVICE
):
    """
    Train YOLOv8 model on invoice dataset

    Prints an error and returns early when DATASET_YAML does not exist.
    Otherwise loads the pretrained ``yolov8{model_size}.pt`` weights, trains
    them, and prints where the resulting weights and plots are stored.

    Args:
        model_size: Size of YOLO model (n, s, m, l, x)
        epochs: Number of training epochs
        batch_size: Batch size for training
        img_size: Input image size
        device: Device to use for training, e.g. 0, "0", "cuda", "cuda:0"
            or "cpu"; passed through unchanged to ``model.train``
    """
    _banner("YOLOv8 Invoice Detection Training")

    # Check if dataset.yaml exists
    if not DATASET_YAML.exists():
        print(f"Error: {DATASET_YAML} not found")
        print("Please ensure the dataset.yaml file exists")
        return

    # Print configuration
    print("\nConfiguration:")
    print(f" Model: YOLOv8{model_size}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Image size: {img_size}")
    print(f" Device: {device}")
    print(f" Dataset config: {DATASET_YAML}")
    print()

    # Initialize model
    print(f"Loading YOLOv8{model_size} model...")
    model = YOLO(f"yolov8{model_size}.pt")  # Load pretrained model

    # Print device info. Any CUDA spelling counts as a GPU, not only the int 0.
    gpu_index = _cuda_device_index(device)
    if gpu_index is not None:
        total_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1e9
        print(f"Using GPU: {torch.cuda.get_device_name(gpu_index)}")
        print(f"GPU Memory: {total_gb:.2f} GB")
    elif str(device).strip().lower() == "cpu":
        print("Using CPU (training will be slower)")
    else:
        # Not recognised as CUDA or CPU here; report it as given.
        print(f"Using device: {device}")

    print("\nStarting training...")
    print("-" * 60)

    # Train the model
    model.train(
        data=str(DATASET_YAML),
        epochs=epochs,
        imgsz=img_size,
        batch=batch_size,
        device=device,
        project=str(MODELS_DIR),
        name=RUN_NAME,
        exist_ok=True,
        patience=20,  # Early stopping patience
        save=True,
        save_period=10,  # Save checkpoint every 10 epochs
        verbose=True,
        plots=True  # Generate training plots
    )

    _banner("Training complete!", leading_newline=True)

    # Print results
    print("\nTrained models saved to:")
    print(f" Best model: {BEST_WEIGHTS}")
    print(f" Last model: {LAST_WEIGHTS}")
    print("\nTraining plots saved to:")
    print(f" {RUN_DIR}")
    print("\n" + _RULE)


def validate_model(model_path=None):
    """
    Validate trained model on validation set

    Prints an error and returns early when the weights file does not exist.

    Args:
        model_path: Path to model weights (default: best.pt from last training)
    """
    if model_path is None:
        model_path = BEST_WEIGHTS

    if not Path(model_path).exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first")
        return

    _banner("Validating Model")
    print(f"Model: {model_path}\n")

    # Load model
    model = YOLO(str(model_path))

    # Validate
    model.val(data=str(DATASET_YAML))

    _banner("Validation complete!", leading_newline=True)


def predict_sample(model_path=None, image_path=None, conf_threshold=0.25):
    """
    Run prediction on a sample image

    Prints an error and returns early when the weights file does not exist,
    or when no image is given and no sample image can be found.

    Args:
        model_path: Path to model weights (default: best.pt from last training)
        image_path: Path to image to predict on (default: first .jpg, else
            first .png, in the dataset's validation image folder)
        conf_threshold: Confidence threshold for detections
    """
    if model_path is None:
        model_path = BEST_WEIGHTS

    if not Path(model_path).exists():
        print(f"Error: Model not found at {model_path}")
        return

    if image_path is None:
        # Try to get a sample from validation set. The folder is resolved
        # next to dataset.yaml so it does not depend on the working
        # directory; sorting makes the chosen sample deterministic.
        sample_images = (
            sorted(VAL_IMAGES_DIR.glob("*.jpg"))
            + sorted(VAL_IMAGES_DIR.glob("*.png"))
        )
        if not sample_images:
            print("No sample images found")
            return
        image_path = sample_images[0]

    _banner("Running Prediction")
    print(f"Model: {model_path}")
    print(f"Image: {image_path}")
    print(f"Confidence threshold: {conf_threshold}\n")

    # Load model
    model = YOLO(str(model_path))

    # Predict
    results = model.predict(
        source=str(image_path),
        conf=conf_threshold,
        save=True,
        project=str(PREDICTIONS_DIR),
        name="sample"
    )

    # Report the directory recorded on the result when there is one: without
    # exist_ok the library may write to an incremented run directory instead
    # of "sample" on repeated runs.
    save_dir = getattr(results[0], "save_dir", None) if results else None
    if not save_dir:
        save_dir = PREDICTIONS_DIR / "sample"
    print(f"\nPrediction saved to: {save_dir}")
    print(_RULE)


def main():
    """Main training function: parse CLI arguments and run the chosen mode."""
    parser = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset")
    parser.add_argument("--mode", type=str, default="train",
                        choices=["train", "validate", "predict"],
                        help="Mode: train, validate, or predict")
    parser.add_argument("--model-size", type=str, default=MODEL_SIZE,
                        choices=["n", "s", "m", "l", "x"],
                        help="YOLO model size")
    parser.add_argument("--epochs", type=int, default=EPOCHS,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help="Batch size")
    parser.add_argument("--img-size", type=int, default=IMAGE_SIZE,
                        help="Image size")
    parser.add_argument("--model-path", type=str, default=None,
                        help="Path to model weights (for validate/predict)")
    parser.add_argument("--image-path", type=str, default=None,
                        help="Path to image (for predict)")
    args = parser.parse_args()

    if args.mode == "train":
        train_model(
            model_size=args.model_size,
            epochs=args.epochs,
            batch_size=args.batch_size,
            img_size=args.img_size
        )
    elif args.mode == "validate":
        validate_model(model_path=args.model_path)
    elif args.mode == "predict":
        predict_sample(model_path=args.model_path, image_path=args.image_path)


if __name__ == "__main__":
    main()