WIP
This commit is contained in:
@@ -199,67 +199,63 @@ def main():
|
||||
db.close()
|
||||
return
|
||||
|
||||
# Start training
|
||||
# Start training using shared trainer
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting YOLO Training")
|
||||
print("=" * 60)
|
||||
|
||||
from ultralytics import YOLO
|
||||
from shared.training import YOLOTrainer, TrainingConfig
|
||||
|
||||
# Load model
|
||||
# Determine resume checkpoint
|
||||
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
|
||||
if args.resume and last_checkpoint.exists():
|
||||
print(f"Resuming from: {last_checkpoint}")
|
||||
model = YOLO(str(last_checkpoint))
|
||||
else:
|
||||
model = YOLO(args.model)
|
||||
resume_from = str(last_checkpoint) if args.resume and last_checkpoint.exists() else None
|
||||
|
||||
# Training arguments
|
||||
# Create training config
|
||||
data_yaml = dataset_dir / 'dataset.yaml'
|
||||
train_args = {
|
||||
'data': str(data_yaml.absolute()),
|
||||
'epochs': args.epochs,
|
||||
'batch': args.batch,
|
||||
'imgsz': args.imgsz,
|
||||
'project': args.project,
|
||||
'name': args.name,
|
||||
'device': args.device,
|
||||
'exist_ok': True,
|
||||
'pretrained': True,
|
||||
'verbose': True,
|
||||
'workers': args.workers,
|
||||
'cache': args.cache,
|
||||
'resume': args.resume and last_checkpoint.exists(),
|
||||
# Document-specific augmentation settings
|
||||
'degrees': 5.0,
|
||||
'translate': 0.05,
|
||||
'scale': 0.2,
|
||||
'shear': 0.0,
|
||||
'perspective': 0.0,
|
||||
'flipud': 0.0,
|
||||
'fliplr': 0.0,
|
||||
'mosaic': 0.0,
|
||||
'mixup': 0.0,
|
||||
'hsv_h': 0.0,
|
||||
'hsv_s': 0.1,
|
||||
'hsv_v': 0.2,
|
||||
}
|
||||
config = TrainingConfig(
|
||||
model_path=args.model,
|
||||
data_yaml=str(data_yaml),
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch,
|
||||
image_size=args.imgsz,
|
||||
device=args.device,
|
||||
project=args.project,
|
||||
name=args.name,
|
||||
workers=args.workers,
|
||||
cache=args.cache,
|
||||
resume=args.resume,
|
||||
resume_from=resume_from,
|
||||
)
|
||||
|
||||
# Train
|
||||
results = model.train(**train_args)
|
||||
# Run training
|
||||
trainer = YOLOTrainer(config=config)
|
||||
result = trainer.train()
|
||||
|
||||
if not result.success:
|
||||
print(f"\nError: Training failed - {result.error}")
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print("Training Complete")
|
||||
print("=" * 60)
|
||||
print(f"Best model: {args.project}/{args.name}/weights/best.pt")
|
||||
print(f"Last model: {args.project}/{args.name}/weights/last.pt")
|
||||
print(f"Best model: {result.model_path}")
|
||||
print(f"Save directory: {result.save_dir}")
|
||||
if result.metrics:
|
||||
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
|
||||
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
|
||||
|
||||
# Validate on test set
|
||||
print("\nRunning validation on test set...")
|
||||
metrics = model.val(split='test')
|
||||
print(f"mAP50: {metrics.box.map50:.4f}")
|
||||
print(f"mAP50-95: {metrics.box.map:.4f}")
|
||||
if result.model_path:
|
||||
config.model_path = result.model_path
|
||||
config.data_yaml = str(data_yaml)
|
||||
test_trainer = YOLOTrainer(config=config)
|
||||
test_metrics = test_trainer.validate(split='test')
|
||||
if test_metrics:
|
||||
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
|
||||
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
|
||||
|
||||
# Close database
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user