restructure project

This commit is contained in:
Yaojia Wang
2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions

View File

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
Training Service Entry Point.
Runs a specific training task by ID (for Azure ACI on-demand mode)
or polls the database for pending tasks (for local dev).
"""
import argparse
import logging
import sys
import time
from training.data.training_db import TrainingTaskDB
# Configure the root logger once at import time: timestamped, level-tagged
# lines that include the emitting logger's name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
def execute_training_task(db: TrainingTaskDB, task: dict) -> None:
    """Run one training task and record its outcome through ``db``.

    The task is first marked ``"running"``.  Training is then started with
    the values found in the task's config, falling back to the defaults
    below for any missing key.  On success the result's ``model_path`` and
    ``metrics`` are stored with ``db.complete_task``.  If anything inside
    the training step raises, the traceback is logged, the task is marked
    failed with the exception text via ``db.fail_task``, and the process
    exits with status 1.

    Args:
        db: Task store used to update the status and record the outcome.
        task: Task record.  Must contain ``"task_id"``; may contain a
            ``"config"`` mapping (a missing or falsy config is treated as
            empty).
    """
    task_id = task["task_id"]
    settings = task.get("config") or {}
    logger.info("Starting training task %s with config: %s", task_id, settings)
    db.update_status(task_id, "running")

    try:
        # Imported inside the try block, so an import failure is recorded
        # as a task failure like any other error.
        from training.cli.train import run_training

        training_kwargs = {
            "epochs": settings.get("epochs", 100),
            "batch": settings.get("batch_size", 16),
            "model": settings.get("base_model", "yolo11n.pt"),
            "imgsz": settings.get("imgsz", 1280),
            # Default run name is derived from the first 8 chars of the id.
            "name": settings.get("name", f"training_{task_id[:8]}"),
        }
        outcome = run_training(**training_kwargs)

        trained_model_path = outcome.get("model_path", "")
        trained_metrics = outcome.get("metrics", {})
        db.complete_task(
            task_id,
            model_path=trained_model_path,
            metrics=trained_metrics,
        )
        logger.info("Training task %s completed successfully.", task_id)
    except Exception as exc:
        logger.exception("Training task %s failed", task_id)
        db.fail_task(task_id, str(exc))
        # A non-zero exit status reports the failure to whoever launched
        # this process.
        sys.exit(1)
def main() -> None:
    """Parse command-line arguments and run the training service.

    Modes:
        ``--task-id ID``: fetch that task and run it once.  Exits with
            status 1 when the task is not found; a failed training run
            also exits with status 1 (from ``execute_training_task``).
        ``--poll``: loop forever, fetching at most one pending task per
            iteration and sleeping ``--poll-interval`` seconds between
            iterations.  A failed task is logged and polling continues.

    ``--task-id`` takes precedence when both options are given.  With
    neither option the help text is printed and the process exits with
    status 1.  A negative ``--poll-interval`` in poll mode is rejected as
    a usage error (argparse exits with status 2).
    """
    parser = argparse.ArgumentParser(description="Invoice Training Service")
    parser.add_argument(
        "--task-id",
        help="Specific task ID to run (ACI on-demand mode)",
    )
    parser.add_argument(
        "--poll",
        action="store_true",
        help="Poll database for pending tasks (local dev mode)",
    )
    parser.add_argument(
        "--poll-interval",
        type=int,
        default=60,
        help="Seconds between polls (default: 60)",
    )
    args = parser.parse_args()

    db = TrainingTaskDB()

    if args.task_id:
        task = db.get_task(args.task_id)
        if not task:
            logger.error("Task %s not found", args.task_id)
            sys.exit(1)
        execute_training_task(db, task)
    elif args.poll:
        # time.sleep() raises ValueError for negative values; reject them
        # up front instead of crashing after the first poll iteration.
        if args.poll_interval < 0:
            parser.error("--poll-interval must be non-negative")

        logger.info(
            "Starting training service in poll mode (interval=%ds)",
            args.poll_interval,
        )
        while True:
            for task in db.get_pending_tasks(limit=1):
                try:
                    execute_training_task(db, task)
                except SystemExit:
                    # execute_training_task calls sys.exit(1) after it has
                    # marked the task as failed.  That exit status is meant
                    # for the one-shot mode; a long-running poller must not
                    # die because a single task failed.
                    logger.warning(
                        "Training task %s failed; continuing to poll",
                        task.get("task_id"),
                    )
            time.sleep(args.poll_interval)
    else:
        parser.print_help()
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()