This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -199,67 +199,63 @@ def main():
db.close()
return
# Start training
# Start training using shared trainer
print("\n" + "=" * 60)
print("Starting YOLO Training")
print("=" * 60)
from ultralytics import YOLO
from shared.training import YOLOTrainer, TrainingConfig
# Load model
# Determine resume checkpoint
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
if args.resume and last_checkpoint.exists():
print(f"Resuming from: {last_checkpoint}")
model = YOLO(str(last_checkpoint))
else:
model = YOLO(args.model)
resume_from = str(last_checkpoint) if args.resume and last_checkpoint.exists() else None
# Training arguments
# Create training config
data_yaml = dataset_dir / 'dataset.yaml'
train_args = {
'data': str(data_yaml.absolute()),
'epochs': args.epochs,
'batch': args.batch,
'imgsz': args.imgsz,
'project': args.project,
'name': args.name,
'device': args.device,
'exist_ok': True,
'pretrained': True,
'verbose': True,
'workers': args.workers,
'cache': args.cache,
'resume': args.resume and last_checkpoint.exists(),
# Document-specific augmentation settings
'degrees': 5.0,
'translate': 0.05,
'scale': 0.2,
'shear': 0.0,
'perspective': 0.0,
'flipud': 0.0,
'fliplr': 0.0,
'mosaic': 0.0,
'mixup': 0.0,
'hsv_h': 0.0,
'hsv_s': 0.1,
'hsv_v': 0.2,
}
config = TrainingConfig(
model_path=args.model,
data_yaml=str(data_yaml),
epochs=args.epochs,
batch_size=args.batch,
image_size=args.imgsz,
device=args.device,
project=args.project,
name=args.name,
workers=args.workers,
cache=args.cache,
resume=args.resume,
resume_from=resume_from,
)
# Train
results = model.train(**train_args)
# Run training
trainer = YOLOTrainer(config=config)
result = trainer.train()
if not result.success:
print(f"\nError: Training failed - {result.error}")
db.close()
sys.exit(1)
# Print results
print("\n" + "=" * 60)
print("Training Complete")
print("=" * 60)
print(f"Best model: {args.project}/{args.name}/weights/best.pt")
print(f"Last model: {args.project}/{args.name}/weights/last.pt")
print(f"Best model: {result.model_path}")
print(f"Save directory: {result.save_dir}")
if result.metrics:
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
# Validate on test set
print("\nRunning validation on test set...")
metrics = model.val(split='test')
print(f"mAP50: {metrics.box.map50:.4f}")
print(f"mAP50-95: {metrics.box.map:.4f}")
if result.model_path:
config.model_path = result.model_path
config.data_yaml = str(data_yaml)
test_trainer = YOLOTrainer(config=config)
test_metrics = test_trainer.validate(split='test')
if test_metrics:
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
# Close database
db.close()