Init
This commit is contained in:
230
scripts/04_train_yolo.py
Normal file
230
scripts/04_train_yolo.py
Normal file
@@ -0,0 +1,230 @@
"""
YOLO Training Script - Step 4

Trains a YOLOv8 model on the prepared invoice dataset.
"""
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
import os
|
||||
|
||||
# Project layout: this script lives one directory below the project root,
# so the root is the grandparent of this file's absolute path.
BASE_DIR = str(Path(os.path.abspath(__file__)).parent.parent)
DATASET_YAML = Path(BASE_DIR) / "data" / "yolo_dataset" / "dataset.yaml"
MODELS_DIR = Path(BASE_DIR) / "models"

# Default training hyper-parameters (each can be overridden from the CLI).
MODEL_SIZE = "n"  # One of: n (nano), s (small), m (medium), l (large), x (xlarge)
EPOCHS = 100
BATCH_SIZE = 16
IMAGE_SIZE = 640
# CUDA device index 0 when a GPU is visible to PyTorch, otherwise the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else 0

# Make sure the output directory for training runs exists at import time.
MODELS_DIR.mkdir(parents=False, exist_ok=True)
|
||||
|
||||
|
||||
def _first_cuda_index(device):
    """Return the index of the first CUDA device named by ``device``.

    Understands an int (``0``), strings such as ``"0"``, ``"cuda"``,
    ``"cuda:1"`` or ``"0,1"``, and a list/tuple of indices. Falls back to 0
    when no index can be read from the value.
    """
    if isinstance(device, (list, tuple)):
        device = device[0] if device else 0
    spec = str(device).lower().replace("cuda", "").replace(":", "")
    first = spec.split(",")[0].strip()
    return int(first) if first.isdigit() else 0


def train_model(
    model_size=MODEL_SIZE,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    img_size=IMAGE_SIZE,
    device=DEVICE,
):
    """
    Train a YOLOv8 model on the invoice dataset.

    Loads the pretrained ``yolov8{model_size}.pt`` weights and trains them
    on the dataset described by ``DATASET_YAML``. The run is written under
    ``MODELS_DIR / "payment_slip_detector_v1"``.

    Args:
        model_size: Size of YOLO model (n, s, m, l, x)
        epochs: Number of training epochs
        batch_size: Batch size for training
        img_size: Input image size
        device: Device to train on, e.g. ``0``, ``"cuda"`` or ``"cpu"``

    Returns:
        Whatever ``model.train(...)`` returns, or ``None`` when the dataset
        config file does not exist (an error is printed in that case).
    """
    run_name = "payment_slip_detector_v1"
    run_dir = MODELS_DIR / run_name
    banner = "=" * 60

    print(banner)
    print("YOLOv8 Invoice Detection Training")
    print(banner)

    # Bail out early: training cannot start without the dataset config.
    if not DATASET_YAML.exists():
        print(f"Error: {DATASET_YAML} not found")
        print("Please ensure the dataset.yaml file exists")
        return None

    # Print configuration
    print("\nConfiguration:")
    print(f" Model: YOLOv8{model_size}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Image size: {img_size}")
    print(f" Device: {device}")
    print(f" Dataset config: {DATASET_YAML}")
    print()

    # Initialize model
    print(f"Loading YOLOv8{model_size} model...")
    model = YOLO(f"yolov8{model_size}.pt")  # Load pretrained model

    # Print device info. Compare against "cpu" rather than against the int 0,
    # so that specs like "cuda", "0" or 1 are not misreported as CPU.
    if str(device).strip().lower() == "cpu":
        print("Using CPU (training will be slower)")
    elif torch.cuda.is_available():
        gpu_index = _first_cuda_index(device)
        gpu_props = torch.cuda.get_device_properties(gpu_index)
        print(f"Using GPU: {torch.cuda.get_device_name(gpu_index)}")
        print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")
    else:
        # Non-CPU device requested but CUDA is not visible to PyTorch.
        print(f"Using device: {device}")

    print("\nStarting training...")
    print("-" * 60)

    # Train the model
    results = model.train(
        data=str(DATASET_YAML),
        epochs=epochs,
        imgsz=img_size,
        batch=batch_size,
        device=device,
        project=str(MODELS_DIR),
        name=run_name,
        exist_ok=True,
        patience=20,  # Early stopping patience
        save=True,
        save_period=10,  # Save checkpoint every 10 epochs
        verbose=True,
        plots=True,  # Generate training plots
    )

    print("\n" + banner)
    print("Training complete!")
    print(banner)

    # Print results
    best_model_path = run_dir / "weights" / "best.pt"
    last_model_path = run_dir / "weights" / "last.pt"

    print("\nTrained models saved to:")
    print(f" Best model: {best_model_path}")
    print(f" Last model: {last_model_path}")

    print("\nTraining plots saved to:")
    print(f" {run_dir}")

    print("\n" + banner)

    return results
|
||||
|
||||
|
||||
def validate_model(model_path=None):
    """
    Validate a trained model using the dataset described by ``DATASET_YAML``.

    Args:
        model_path: Path to model weights. When ``None``, defaults to
            ``MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"``.

    Returns:
        Whatever ``model.val(...)`` returns, or ``None`` when the weights
        file does not exist (an error is printed in that case).
    """
    if model_path is None:
        model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"

    if not Path(model_path).exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first")
        return None

    banner = "=" * 60
    print(banner)
    print("Validating Model")
    print(banner)
    print(f"Model: {model_path}\n")

    # Load model
    model = YOLO(str(model_path))

    # Validate; hand the metrics back to the caller instead of discarding them.
    results = model.val(data=str(DATASET_YAML))

    print("\n" + banner)
    print("Validation complete!")
    print(banner)

    return results
|
||||
|
||||
|
||||
def predict_sample(model_path=None, image_path=None, conf_threshold=0.25):
    """
    Run prediction on a sample image and save the annotated output.

    Args:
        model_path: Path to model weights. When ``None``, defaults to
            ``MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"``.
        image_path: Path to image to predict on. When ``None``, the first
            ``.jpg``/``.png`` file (in sorted order) from the ``images/val``
            directory next to ``DATASET_YAML`` is used.
        conf_threshold: Confidence threshold for detections

    Returns:
        Whatever ``model.predict(...)`` returns, or ``None`` when the weights
        file or a sample image cannot be found (an error is printed).
    """
    if model_path is None:
        model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"

    if not Path(model_path).exists():
        print(f"Error: Model not found at {model_path}")
        return None

    if image_path is None:
        # Try to get a sample from validation set. Anchor the lookup at the
        # dataset directory (not the current working directory) so the script
        # behaves the same no matter where it is launched from.
        val_images_dir = DATASET_YAML.parent / "images" / "val"
        sample_images = sorted(
            [*val_images_dir.glob("*.jpg"), *val_images_dir.glob("*.png")]
        )
        if not sample_images:
            print("No sample images found")
            return None
        image_path = sample_images[0]

    banner = "=" * 60
    output_dir = MODELS_DIR / "predictions" / "sample"

    print(banner)
    print("Running Prediction")
    print(banner)
    print(f"Model: {model_path}")
    print(f"Image: {image_path}")
    print(f"Confidence threshold: {conf_threshold}\n")

    # Load model
    model = YOLO(str(model_path))

    # Predict
    results = model.predict(
        source=str(image_path),
        conf=conf_threshold,
        save=True,
        project=str(MODELS_DIR / "predictions"),
        name="sample",
        # Reuse the same run directory on repeated calls so the path printed
        # below is the one the output is written to.
        exist_ok=True,
    )

    print(f"\nPrediction saved to: {output_dir}")
    print(banner)

    return results
|
||||
|
||||
|
||||
def main():
    """Parse command-line options and run the selected mode."""
    import argparse

    cli = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset")
    cli.add_argument(
        "--mode",
        type=str,
        default="train",
        choices=["train", "validate", "predict"],
        help="Mode: train, validate, or predict",
    )
    cli.add_argument(
        "--model-size",
        type=str,
        default=MODEL_SIZE,
        choices=["n", "s", "m", "l", "x"],
        help="YOLO model size",
    )
    cli.add_argument(
        "--epochs", type=int, default=EPOCHS, help="Number of training epochs"
    )
    cli.add_argument(
        "--batch-size", type=int, default=BATCH_SIZE, help="Batch size"
    )
    cli.add_argument(
        "--img-size", type=int, default=IMAGE_SIZE, help="Image size"
    )
    cli.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="Path to model weights (for validate/predict)",
    )
    cli.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="Path to image (for predict)",
    )

    opts = cli.parse_args()

    # Dispatch on the requested mode; argparse's ``choices`` already limits
    # it to the three values handled here.
    if opts.mode == "predict":
        predict_sample(model_path=opts.model_path, image_path=opts.image_path)
        return

    if opts.mode == "validate":
        validate_model(model_path=opts.model_path)
        return

    if opts.mode == "train":
        train_model(
            model_size=opts.model_size,
            epochs=opts.epochs,
            batch_size=opts.batch_size,
            img_size=opts.img_size,
        )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user