restructure project
This commit is contained in:
100
packages/training/run_training.py
Normal file
100
packages/training/run_training.py
Normal file
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training Service Entry Point.
|
||||
|
||||
Runs a specific training task by ID (for Azure ACI on-demand mode)
|
||||
or polls the database for pending tasks (for local dev).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
from training.data.training_db import TrainingTaskDB
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def execute_training_task(db: TrainingTaskDB, task: dict) -> None:
|
||||
"""Execute a single training task."""
|
||||
task_id = task["task_id"]
|
||||
config = task.get("config") or {}
|
||||
|
||||
logger.info("Starting training task %s with config: %s", task_id, config)
|
||||
db.update_status(task_id, "running")
|
||||
|
||||
try:
|
||||
from training.cli.train import run_training
|
||||
|
||||
result = run_training(
|
||||
epochs=config.get("epochs", 100),
|
||||
batch=config.get("batch_size", 16),
|
||||
model=config.get("base_model", "yolo11n.pt"),
|
||||
imgsz=config.get("imgsz", 1280),
|
||||
name=config.get("name", f"training_{task_id[:8]}"),
|
||||
)
|
||||
|
||||
db.complete_task(
|
||||
task_id,
|
||||
model_path=result.get("model_path", ""),
|
||||
metrics=result.get("metrics", {}),
|
||||
)
|
||||
logger.info("Training task %s completed successfully.", task_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Training task %s failed", task_id)
|
||||
db.fail_task(task_id, str(e))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Invoice Training Service")
|
||||
parser.add_argument(
|
||||
"--task-id",
|
||||
help="Specific task ID to run (ACI on-demand mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--poll",
|
||||
action="store_true",
|
||||
help="Poll database for pending tasks (local dev mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--poll-interval",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Seconds between polls (default: 60)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
db = TrainingTaskDB()
|
||||
|
||||
if args.task_id:
|
||||
task = db.get_task(args.task_id)
|
||||
if not task:
|
||||
logger.error("Task %s not found", args.task_id)
|
||||
sys.exit(1)
|
||||
execute_training_task(db, task)
|
||||
|
||||
elif args.poll:
|
||||
logger.info(
|
||||
"Starting training service in poll mode (interval=%ds)",
|
||||
args.poll_interval,
|
||||
)
|
||||
while True:
|
||||
tasks = db.get_pending_tasks(limit=1)
|
||||
for task in tasks:
|
||||
execute_training_task(db, task)
|
||||
time.sleep(args.poll_interval)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user