#!/usr/bin/env python3
"""
Training Service Entry Point.

Runs a specific training task by ID (for Azure ACI on-demand mode)
or polls the database for pending tasks (for local dev).
"""

import argparse
import logging
import sys
import time

from training.data.training_db import TrainingTaskDB

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)


def execute_training_task(
    db: TrainingTaskDB,
    task: dict,
    *,
    exit_on_failure: bool = True,
) -> None:
    """Execute a single training task and record its outcome in the database.

    The task is first marked ``"running"`` via ``db.update_status``. Training
    is then started with parameters read from ``task["config"]`` (missing
    keys fall back to the defaults below). On success the result is stored
    with ``db.complete_task``; on any ``Exception`` the failure is logged
    and stored with ``db.fail_task``.

    Args:
        db: Task database used to update status and store the outcome.
        task: Task record. Must contain ``"task_id"``; ``"config"`` is
            optional and may be ``None``.
        exit_on_failure: When true (the default), terminate the process with
            exit code 1 after recording a failure, so a one-shot run reports
            the failure through its exit status. When false, return normally
            after recording the failure so a long-running caller can carry on.

    Raises:
        SystemExit: With code 1 if training fails and ``exit_on_failure``
            is true.
    """
    task_id = task["task_id"]
    config = task.get("config") or {}

    logger.info("Starting training task %s with config: %s", task_id, config)
    db.update_status(task_id, "running")

    try:
        # Imported at call time rather than at module load.
        # NOTE(review): presumably to avoid loading the training stack when
        # no task is run — confirm.
        from training.cli.train import run_training

        result = run_training(
            epochs=config.get("epochs", 100),
            batch=config.get("batch_size", 16),
            model=config.get("base_model", "yolo11n.pt"),
            imgsz=config.get("imgsz", 1280),
            name=config.get("name", f"training_{task_id[:8]}"),
        )

        db.complete_task(
            task_id,
            model_path=result.get("model_path", ""),
            metrics=result.get("metrics", {}),
        )
        logger.info("Training task %s completed successfully.", task_id)
    except Exception as e:
        # Top-level boundary: log the traceback and persist the error text.
        logger.exception("Training task %s failed", task_id)
        db.fail_task(task_id, str(e))
        if exit_on_failure:
            sys.exit(1)


def main() -> None:
    """Parse command-line arguments and run the selected mode.

    Modes:
        ``--task-id ID``: fetch that task and execute it once. Exits with
            code 1 if the task is not found or if training fails.
        ``--poll``: loop forever, fetching at most one pending task per
            iteration, executing it, then sleeping ``--poll-interval``
            seconds. A failed task is recorded and polling continues.

    If neither option is given, the help text is printed and the process
    exits with code 1. When both are given, ``--task-id`` takes precedence.
    """
    parser = argparse.ArgumentParser(description="Invoice Training Service")
    parser.add_argument(
        "--task-id",
        help="Specific task ID to run (ACI on-demand mode)",
    )
    parser.add_argument(
        "--poll",
        action="store_true",
        help="Poll database for pending tasks (local dev mode)",
    )
    parser.add_argument(
        "--poll-interval",
        type=int,
        default=60,
        help="Seconds between polls (default: 60)",
    )
    args = parser.parse_args()

    db = TrainingTaskDB()

    if args.task_id:
        task = db.get_task(args.task_id)
        if not task:
            logger.error("Task %s not found", args.task_id)
            sys.exit(1)
        execute_training_task(db, task)
    elif args.poll:
        logger.info(
            "Starting training service in poll mode (interval=%ds)",
            args.poll_interval,
        )
        while True:
            tasks = db.get_pending_tasks(limit=1)
            for task in tasks:
                # One failed task must not take down the poller; the failure
                # is already logged and stored by execute_training_task.
                execute_training_task(db, task, exit_on_failure=False)
            time.sleep(args.poll_interval)
    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()