#!/usr/bin/env python3
"""
Training Service Entry Point.

Runs a specific training task by ID (for Azure ACI on-demand mode)
or polls the database for pending tasks (for local dev).
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
import time
|
|
|
|
from training.data.training_db import TrainingTaskDB
|
|
|
|
# Configure the root logger once at import time: INFO and above, with a
# timestamp, level and logger name in front of every message.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def execute_training_task(db: TrainingTaskDB, task: dict) -> None:
    """Execute a single training task and record its outcome.

    The task is first marked ``"running"`` via ``db.update_status``, then
    ``run_training`` is called with values taken from the task's config
    (falling back to built-in defaults).  On success the returned model path
    and metrics are stored with ``db.complete_task``.  On any ``Exception``
    the failure is logged with its traceback, stored with ``db.fail_task``,
    and the process exits with status 1.

    Args:
        db: Object providing ``update_status``, ``complete_task`` and
            ``fail_task`` for the task being run.
        task: Task record.  Must contain ``"task_id"``; may contain a
            ``"config"`` mapping with the optional keys ``epochs``,
            ``batch_size``, ``base_model``, ``imgsz`` and ``name``.

    Raises:
        KeyError: If ``task`` has no ``"task_id"`` entry.
        SystemExit: With code 1 when training or result recording fails.
    """
    task_id = task["task_id"]
    # A missing or None config is treated as "use all defaults".
    config = task.get("config") or {}

    logger.info("Starting training task %s with config: %s", task_id, config)
    db.update_status(task_id, "running")

    try:
        # Imported inside the try block so that an import failure is also
        # recorded against the task instead of escaping unhandled.
        from training.cli.train import run_training

        result = run_training(
            epochs=config.get("epochs", 100),
            batch=config.get("batch_size", 16),
            model=config.get("base_model", "yolo11n.pt"),
            imgsz=config.get("imgsz", 1280),
            # str() keeps the default run name working even when the task ID
            # is not a plain string (slicing a non-str ID would raise).
            name=config.get("name", f"training_{str(task_id)[:8]}"),
        )

        db.complete_task(
            task_id,
            model_path=result.get("model_path", ""),
            metrics=result.get("metrics", {}),
        )
        logger.info("Training task %s completed successfully.", task_id)

    except Exception as e:
        logger.exception("Training task %s failed", task_id)
        # Exceptions raised without a message have an empty str(); fall back
        # to repr() so the stored failure reason is never blank.
        db.fail_task(task_id, str(e) or repr(e))
        sys.exit(1)
|
|
|
|
|
|
def main() -> None:
    """Parse command-line arguments and run the training service.

    Modes:
        ``--task-id ID``: look up that task and execute it once.  Exits with
            status 1 if the task is not found.  Takes precedence over
            ``--poll`` when both are given.
        ``--poll``: loop forever, fetching at most one pending task per
            iteration, executing it, then sleeping ``--poll-interval``
            seconds.
        neither: print the help text and exit with status 1.

    Raises:
        SystemExit: Status 1 for an unknown task or when no mode is selected;
            status 2 (from argparse) for invalid arguments, including a
            negative ``--poll-interval`` in poll mode.
    """
    parser = argparse.ArgumentParser(description="Invoice Training Service")
    parser.add_argument(
        "--task-id",
        help="Specific task ID to run (ACI on-demand mode)",
    )
    parser.add_argument(
        "--poll",
        action="store_true",
        help="Poll database for pending tasks (local dev mode)",
    )
    parser.add_argument(
        "--poll-interval",
        type=int,
        default=60,
        help="Seconds between polls (default: 60)",
    )
    args = parser.parse_args()

    # time.sleep() raises ValueError for negative values, which would crash
    # the service only after the first poll (and possibly a full training
    # run).  Reject it up front; only poll mode uses the interval, so
    # --task-id runs are unaffected.
    if args.poll and not args.task_id and args.poll_interval < 0:
        parser.error("--poll-interval must be non-negative")

    db = TrainingTaskDB()

    if args.task_id:
        task = db.get_task(args.task_id)
        if not task:
            logger.error("Task %s not found", args.task_id)
            sys.exit(1)
        execute_training_task(db, task)

    elif args.poll:
        logger.info(
            "Starting training service in poll mode (interval=%ds)",
            args.poll_interval,
        )
        while True:
            # limit=1: take at most one pending task per iteration.
            tasks = db.get_pending_tasks(limit=1)
            for task in tasks:
                execute_training_task(db, task)
            time.sleep(args.poll_interval)

    else:
        parser.print_help()
        sys.exit(1)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|