Files
invoice-master/scripts/04_train_yolo.py
Yaojia Wang dafa86c588 Init
2025-10-26 20:41:11 +01:00

231 lines
6.8 KiB
Python

"""
YOLO Training Script - Step 4
Trains YOLOv8 model on the prepared invoice dataset
"""
from pathlib import Path
from ultralytics import YOLO
import torch
import os
# Paths
# BASE_DIR is two directory levels above this file (the project root when
# this script lives in scripts/). Kept as a str; derived paths use pathlib.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATASET_YAML = Path(BASE_DIR) / "data" / "yolo_dataset" / "dataset.yaml"
MODELS_DIR = Path(BASE_DIR) / "models"

# Training configuration
MODEL_SIZE = "n"  # Options: n (nano), s (small), m (medium), l (large), x (xlarge)
EPOCHS = 100
BATCH_SIZE = 16
IMAGE_SIZE = 640
# Use CUDA device index 0 when PyTorch can see a GPU, otherwise the CPU.
DEVICE = 0 if torch.cuda.is_available() else "cpu"

# Create models directory (parents=True so missing intermediate dirs don't fail)
MODELS_DIR.mkdir(parents=True, exist_ok=True)
def train_model(
    model_size=MODEL_SIZE,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    img_size=IMAGE_SIZE,
    device=DEVICE,
    run_name="payment_slip_detector_v1",
):
    """
    Train a YOLOv8 model on the invoice dataset.

    Starts from the pretrained ``yolov8<model_size>.pt`` checkpoint and trains
    on the dataset described by DATASET_YAML. Weights and plots are written to
    ``MODELS_DIR / run_name`` (existing contents are reused/overwritten).

    Args:
        model_size: Size of YOLO model (n, s, m, l, x).
        epochs: Number of training epochs.
        batch_size: Batch size for training.
        img_size: Input image size (passed to Ultralytics as ``imgsz``).
        device: Device passed to Ultralytics, e.g. a CUDA device index such
            as 0, or "cpu". Defaults to DEVICE (GPU 0 if available, else CPU).
        run_name: Name of the run directory created under MODELS_DIR. The
            default is the run that validate_model/predict_sample look for.

    Returns:
        Whatever ``model.train()`` returns, or None if DATASET_YAML is missing.
    """
    print("=" * 60)
    print("YOLOv8 Invoice Detection Training")
    print("=" * 60)

    # Check if dataset.yaml exists
    if not DATASET_YAML.exists():
        print(f"Error: {DATASET_YAML} not found")
        print("Please ensure the dataset.yaml file exists")
        return None

    run_dir = MODELS_DIR / run_name

    # Print configuration
    print("\nConfiguration:")
    print(f" Model: YOLOv8{model_size}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Image size: {img_size}")
    print(f" Device: {device}")
    print(f" Dataset config: {DATASET_YAML}")
    print()

    # Initialize model from pretrained weights
    print(f"Loading YOLOv8{model_size} model...")
    model = YOLO(f"yolov8{model_size}.pt")

    # Print device info. Any integer device is treated as a CUDA index, so
    # GPU details are reported for every index, not just 0.
    if isinstance(device, int) and torch.cuda.is_available():
        props = torch.cuda.get_device_properties(device)
        print(f"Using GPU: {torch.cuda.get_device_name(device)}")
        print(f"GPU Memory: {props.total_memory / 1e9:.2f} GB")
    elif str(device).lower() == "cpu":
        print("Using CPU (training will be slower)")
    else:
        # e.g. "0,1" multi-GPU strings or other backends; report verbatim
        print(f"Using device: {device}")

    print("\nStarting training...")
    print("-" * 60)

    # Train the model; project/name place outputs in MODELS_DIR / run_name
    results = model.train(
        data=str(DATASET_YAML),
        epochs=epochs,
        imgsz=img_size,
        batch=batch_size,
        device=device,
        project=str(MODELS_DIR),
        name=run_name,
        exist_ok=True,  # reuse the run dir instead of creating run_name2, ...
        patience=20,  # Early stopping patience
        save=True,
        save_period=10,  # Save checkpoint every 10 epochs
        verbose=True,
        plots=True,  # Generate training plots
    )

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)

    # Print results
    weights_dir = run_dir / "weights"
    print("\nTrained models saved to:")
    print(f" Best model: {weights_dir / 'best.pt'}")
    print(f" Last model: {weights_dir / 'last.pt'}")
    print("\nTraining plots saved to:")
    print(f" {run_dir}")
    print("\n" + "=" * 60)

    return results
def validate_model(model_path=None):
    """
    Validate a trained model on the validation split in DATASET_YAML.

    Args:
        model_path: Path to model weights (str or Path). Defaults to
            ``MODELS_DIR/payment_slip_detector_v1/weights/best.pt``.

    Returns:
        The object returned by ``model.val()``, or None if the model weights
        or the dataset config are missing.
    """
    if model_path is None:
        model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
    model_path = Path(model_path)

    if not model_path.exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first")
        return None

    # model.val() needs the dataset config; fail with a clear message
    # instead of an exception from inside Ultralytics.
    if not DATASET_YAML.exists():
        print(f"Error: {DATASET_YAML} not found")
        print("Please ensure the dataset.yaml file exists")
        return None

    print("=" * 60)
    print("Validating Model")
    print("=" * 60)
    print(f"Model: {model_path}\n")

    # Load model
    model = YOLO(str(model_path))

    # Validate
    results = model.val(data=str(DATASET_YAML))

    print("\n" + "=" * 60)
    print("Validation complete!")
    print("=" * 60)

    return results
def predict_sample(model_path=None, image_path=None, conf_threshold=0.25):
    """
    Run prediction on a single image and save the annotated output.

    Args:
        model_path: Path to model weights (str or Path). Defaults to
            ``MODELS_DIR/payment_slip_detector_v1/weights/best.pt``.
        image_path: Path to the image to predict on. Defaults to the first
            (alphabetically) .jpg/.jpeg/.png file in the dataset's
            ``images/val`` directory next to DATASET_YAML.
        conf_threshold: Confidence threshold for detections.

    Returns:
        The object returned by ``model.predict()``, or None if the model,
        image, or sample images could not be found.
    """
    if model_path is None:
        model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
    model_path = Path(model_path)
    if not model_path.exists():
        print(f"Error: Model not found at {model_path}")
        return None

    if image_path is None:
        # Locate validation images relative to dataset.yaml (i.e. BASE_DIR),
        # not the current working directory, so the script works no matter
        # where it is launched from.
        val_images_dir = DATASET_YAML.parent / "images" / "val"
        image_suffixes = {".jpg", ".jpeg", ".png"}
        # sorted() makes the chosen sample deterministic across runs
        sample_images = sorted(
            p for p in val_images_dir.glob("*")
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
        if not sample_images:
            print("No sample images found")
            return None
        image_path = sample_images[0]
    elif not Path(image_path).exists():
        print(f"Error: Image not found at {image_path}")
        return None

    print("=" * 60)
    print("Running Prediction")
    print("=" * 60)
    print(f"Model: {model_path}")
    print(f"Image: {image_path}")
    print(f"Confidence threshold: {conf_threshold}\n")

    # Load model
    model = YOLO(str(model_path))

    output_dir = MODELS_DIR / "predictions" / "sample"

    # Predict. exist_ok=True keeps the output in predictions/sample instead
    # of Ultralytics auto-incrementing to sample2, sample3, ... on reruns,
    # so the path printed below is always where the output actually went.
    results = model.predict(
        source=str(image_path),
        conf=conf_threshold,
        save=True,
        project=str(output_dir.parent),
        name=output_dir.name,
        exist_ok=True,
    )

    print(f"\nPrediction saved to: {output_dir}")
    print("=" * 60)

    return results
def main(argv=None):
    """
    Parse command-line arguments and dispatch to train, validate or predict.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            ``sys.argv[1:]``, which is the normal command-line behaviour.
    """
    import argparse

    def _parse_device(value):
        # argparse yields strings; a bare CUDA index must reach Ultralytics
        # (and train_model's GPU check) as an int. Other values ("cpu",
        # "0,1", ...) pass through unchanged.
        return int(value) if value.isdigit() else value

    parser = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset")
    parser.add_argument("--mode", type=str, default="train",
                        choices=["train", "validate", "predict"],
                        help="Mode: train, validate, or predict")
    parser.add_argument("--model-size", type=str, default=MODEL_SIZE,
                        choices=["n", "s", "m", "l", "x"],
                        help="YOLO model size")
    parser.add_argument("--epochs", type=int, default=EPOCHS,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help="Batch size")
    parser.add_argument("--img-size", type=int, default=IMAGE_SIZE,
                        help="Image size")
    parser.add_argument("--device", type=_parse_device, default=DEVICE,
                        help="Training device: CUDA index (e.g. 0) or 'cpu' "
                             "(default: GPU 0 if available, else cpu)")
    parser.add_argument("--model-path", type=str, default=None,
                        help="Path to model weights (for validate/predict)")
    parser.add_argument("--image-path", type=str, default=None,
                        help="Path to image (for predict)")
    # Same default as predict_sample's conf_threshold
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold (for predict)")
    args = parser.parse_args(argv)

    if args.mode == "train":
        train_model(
            model_size=args.model_size,
            epochs=args.epochs,
            batch_size=args.batch_size,
            img_size=args.img_size,
            device=args.device,
        )
    elif args.mode == "validate":
        validate_model(model_path=args.model_path)
    elif args.mode == "predict":
        predict_sample(
            model_path=args.model_path,
            image_path=args.image_path,
            conf_threshold=args.conf,
        )


if __name__ == "__main__":
    main()