231 lines
6.8 KiB
Python
231 lines
6.8 KiB
Python
"""
YOLO Training Script - Step 4

Trains a YOLOv8 model on the prepared invoice dataset.
"""
|
|
|
|
from pathlib import Path
|
|
from ultralytics import YOLO
|
|
import torch
|
|
import os
|
|
|
|
# Paths
# Project root: the parent of the directory that contains this script. Kept as a
# str for backward compatibility; derived paths are built with pathlib joins so
# the script does not depend on the current working directory or on "/" literals.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATASET_YAML = Path(BASE_DIR) / "data" / "yolo_dataset" / "dataset.yaml"
MODELS_DIR = Path(BASE_DIR) / "models"

# Training configuration
MODEL_SIZE = "n"  # Options: n (nano), s (small), m (medium), l (large), x (xlarge)
EPOCHS = 100
BATCH_SIZE = 16
IMAGE_SIZE = 640
# CUDA device index 0 when a GPU is visible to PyTorch, otherwise train on the CPU.
DEVICE = 0 if torch.cuda.is_available() else "cpu"

# Create models directory (training runs and predictions are written under it)
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def _print_device_info(device):
    """Print a short description of the device training will run on.

    The device spec is normalised first (``0``, ``"0"``, ``"0,1"``, ``"cuda"``,
    ``"cuda:0"`` and ``"cpu"`` are all handled), so a GPU spec other than the
    literal int ``0`` is not mislabelled as CPU. CUDA is only queried for an
    index below ``torch.cuda.device_count()``, which is 0 on CPU-only
    machines, so requesting a GPU there no longer crashes this report.
    """
    spec = str(device).strip().lower()
    if spec.startswith("cuda"):
        # "cuda" -> "0", "cuda:1" -> "1"
        spec = spec[len("cuda"):].lstrip(":") or "0"
    first = spec.split(",")[0].strip()
    index = int(first) if first.isdigit() else None

    if spec == "cpu":
        print("Using CPU (training will be slower)")
    elif index is not None and index < torch.cuda.device_count():
        print(f"Using GPU: {torch.cuda.get_device_name(index)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(index).total_memory / 1e9:.2f} GB")
    else:
        # Anything else (e.g. "mps", or a GPU index that is not visible) is
        # reported verbatim and left for Ultralytics to accept or reject.
        print(f"Using device: {device}")


def train_model(
    model_size=MODEL_SIZE,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    img_size=IMAGE_SIZE,
    device=DEVICE,
    run_name="payment_slip_detector_v1",
):
    """
    Train YOLOv8 model on invoice dataset.

    Args:
        model_size: Size of YOLO model (n, s, m, l, x); selects the
            pretrained ``yolov8{model_size}.pt`` weights.
        epochs: Number of training epochs.
        batch_size: Batch size for training.
        img_size: Input image size.
        device: Device passed to Ultralytics, e.g. ``0`` for the first CUDA
            GPU or ``"cpu"``.
        run_name: Name of the run directory created under ``MODELS_DIR``
            (weights end up in ``MODELS_DIR / run_name / "weights"``).

    Returns:
        The object returned by ``YOLO.train`` (Ultralytics training metrics),
        or ``None`` if the dataset config file does not exist.
    """
    print("=" * 60)
    print("YOLOv8 Invoice Detection Training")
    print("=" * 60)

    # Check if dataset.yaml exists
    if not DATASET_YAML.exists():
        print(f"Error: {DATASET_YAML} not found")
        print("Please ensure the dataset.yaml file exists")
        return None

    run_dir = MODELS_DIR / run_name

    # Print configuration
    print("\nConfiguration:")
    print(f" Model: YOLOv8{model_size}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Image size: {img_size}")
    print(f" Device: {device}")
    print(f" Dataset config: {DATASET_YAML}")
    print()

    # Initialize model
    print(f"Loading YOLOv8{model_size} model...")
    model = YOLO(f"yolov8{model_size}.pt")  # Load pretrained model

    # Print device info
    _print_device_info(device)

    print("\nStarting training...")
    print("-" * 60)

    # Train the model
    results = model.train(
        data=str(DATASET_YAML),
        epochs=epochs,
        imgsz=img_size,
        batch=batch_size,
        device=device,
        project=str(MODELS_DIR),
        name=run_name,
        exist_ok=True,  # reuse run_dir so the paths printed below stay accurate
        patience=20,  # Early stopping patience
        save=True,
        save_period=10,  # Save checkpoint every 10 epochs
        verbose=True,
        plots=True,  # Generate training plots
    )

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)

    # Print results
    best_model_path = run_dir / "weights" / "best.pt"
    last_model_path = run_dir / "weights" / "last.pt"

    print("\nTrained models saved to:")
    print(f" Best model: {best_model_path}")
    print(f" Last model: {last_model_path}")

    print("\nTraining plots saved to:")
    print(f" {run_dir}")

    print("\n" + "=" * 60)

    return results
|
|
|
|
|
|
def validate_model(model_path=None):
    """
    Validate trained model on validation set.

    Args:
        model_path: Path to model weights (str or Path). Defaults to
            ``best.pt`` of the ``payment_slip_detector_v1`` run under
            ``MODELS_DIR``.

    Returns:
        The object returned by ``YOLO.val`` (Ultralytics validation metrics),
        or ``None`` if the weights file does not exist.
    """
    if model_path is None:
        model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"

    if not Path(model_path).exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first")
        return None

    print("=" * 60)
    print("Validating Model")
    print("=" * 60)
    print(f"Model: {model_path}\n")

    # Load model
    model = YOLO(str(model_path))

    # Validate against the dataset described by DATASET_YAML
    results = model.val(data=str(DATASET_YAML))

    print("\n" + "=" * 60)
    print("Validation complete!")
    print("=" * 60)

    # Return the metrics so callers can inspect them instead of discarding them
    return results
|
|
|
|
|
|
def predict_sample(model_path=None, image_path=None, conf_threshold=0.25):
    """
    Run prediction on a sample image and save the annotated output.

    Args:
        model_path: Path to model weights (str or Path). Defaults to
            ``best.pt`` of the ``payment_slip_detector_v1`` run under
            ``MODELS_DIR``.
        image_path: Path to image to predict on. Defaults to the first image
            (sorted by path) with a .jpg/.jpeg/.png suffix in the dataset's
            ``images/val`` directory next to ``DATASET_YAML``.
        conf_threshold: Confidence threshold for detections.

    Returns:
        The object returned by ``YOLO.predict``, or ``None`` if the model or
        the image cannot be found.
    """
    if model_path is None:
        model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"

    if not Path(model_path).exists():
        print(f"Error: Model not found at {model_path}")
        return None

    if image_path is None:
        # Try to get a sample from validation set. Resolve it relative to the
        # dataset config (anchored at BASE_DIR) rather than the current working
        # directory, and sort so the chosen sample is deterministic.
        val_images_dir = DATASET_YAML.parent / "images" / "val"
        sample_images = sorted(
            candidate
            for candidate in val_images_dir.glob("*")
            if candidate.suffix.lower() in (".jpg", ".jpeg", ".png") and candidate.is_file()
        )
        if not sample_images:
            print("No sample images found")
            return None
        image_path = sample_images[0]
    elif not Path(image_path).exists():
        print(f"Error: Image not found at {image_path}")
        return None

    output_dir = MODELS_DIR / "predictions" / "sample"

    print("=" * 60)
    print("Running Prediction")
    print("=" * 60)
    print(f"Model: {model_path}")
    print(f"Image: {image_path}")
    print(f"Confidence threshold: {conf_threshold}\n")

    # Load model
    model = YOLO(str(model_path))

    # Predict
    results = model.predict(
        source=str(image_path),
        conf=conf_threshold,
        save=True,
        project=str(output_dir.parent),
        name=output_dir.name,
        # Reuse the same output directory on every run so the path printed
        # below matches where the annotated image is actually written.
        exist_ok=True,
    )

    print(f"\nPrediction saved to: {output_dir}")
    print("=" * 60)

    return results
|
|
|
|
|
|
def main():
    """Main training function: parse CLI arguments and dispatch to the chosen mode.

    Modes:
        train:    train_model with the size/epochs/batch/image-size/device options.
        validate: validate_model on --model-path (or its default weights).
        predict:  predict_sample on --model-path / --image-path with --conf.
    """
    import argparse

    def parse_device(value):
        # Numeric specs become ints to match DEVICE (0 = first GPU); anything
        # else ("cpu", "0,1", "cuda:0", ...) is passed through as a string.
        return int(value) if value.isdigit() else value

    parser = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset")
    parser.add_argument("--mode", type=str, default="train", choices=["train", "validate", "predict"],
                        help="Mode: train, validate, or predict")
    parser.add_argument("--model-size", type=str, default=MODEL_SIZE,
                        choices=["n", "s", "m", "l", "x"],
                        help="YOLO model size")
    parser.add_argument("--epochs", type=int, default=EPOCHS,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help="Batch size")
    parser.add_argument("--img-size", type=int, default=IMAGE_SIZE,
                        help="Image size")
    parser.add_argument("--device", type=parse_device, default=DEVICE,
                        help="Training device, e.g. 0 for the first GPU or cpu (for train)")
    parser.add_argument("--model-path", type=str, default=None,
                        help="Path to model weights (for validate/predict)")
    parser.add_argument("--image-path", type=str, default=None,
                        help="Path to image (for predict)")
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold for detections (for predict)")

    args = parser.parse_args()

    if args.mode == "train":
        train_model(
            model_size=args.model_size,
            epochs=args.epochs,
            batch_size=args.batch_size,
            img_size=args.img_size,
            device=args.device,
        )
    elif args.mode == "validate":
        validate_model(model_path=args.model_path)
    elif args.mode == "predict":
        predict_sample(
            model_path=args.model_path,
            image_path=args.image_path,
            conf_threshold=args.conf,
        )


if __name__ == "__main__":
    main()
|