106 lines
3.0 KiB
Python
106 lines
3.0 KiB
Python
"""Trigger training jobs on Azure Container Instances."""
|
|
|
|
import logging
|
|
import os
|
|
|
|
logger = logging.getLogger(__name__)


# The Azure SDK is an optional dependency: it is only needed when training is
# triggered on Azure Container Instances. If any of these imports fails, the
# module still imports and _AZURE_SDK_AVAILABLE is set to False so callers can
# degrade gracefully.
try:
    from azure.identity import DefaultAzureCredential
    from azure.mgmt.containerinstance import ContainerInstanceManagementClient
    from azure.mgmt.containerinstance.models import (
        Container,
        ContainerGroup,
        EnvironmentVariable,
        GpuResource,
        ResourceRequests,
        ResourceRequirements,
    )
except ImportError:
    _AZURE_SDK_AVAILABLE = False
else:
    # Only reached when every import above succeeded.
    _AZURE_SDK_AVAILABLE = True
|
|
|
|
|
|
def _env_or_default(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* if it is unset or empty."""
    return os.environ.get(name) or default


def _task_slug(task_id: str) -> str:
    """Return an ACI-safe slug built from the first 8 characters of *task_id*.

    Container group names may contain only lowercase letters, digits and
    single dashes, and may not start or end with a dash. The prefix is
    lowercased, and every character that is not an ASCII letter or digit
    becomes a dash. Empty dash runs are then dropped, so leading, trailing
    and doubled dashes disappear. An already lowercase alphanumeric prefix
    (e.g. the first 8 hex digits of a UUID) passes through unchanged.
    Returns "" when nothing usable remains.
    """
    cleaned = "".join(
        ch if ch.isascii() and ch.isalnum() else "-" for ch in task_id[:8].lower()
    )
    return "-".join(part for part in cleaned.split("-") if part)


def _training_env_vars(task_id: str) -> list:
    """Build the environment variables handed to the training container."""
    env_vars = [EnvironmentVariable(name="TASK_ID", value=task_id)]

    # Non-secret DB connection settings are passed as plain values;
    # unset or empty ones are left out.
    for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
        val = os.environ.get(var, "")
        if val:
            env_vars.append(EnvironmentVariable(name=var, value=val))

    # The password is sent as a secure value so ACI does not expose it in the
    # container group's readable properties.
    db_password = os.environ.get("DB_PASSWORD", "")
    if db_password:
        env_vars.append(
            EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
        )
    return env_vars


def start_training_container(task_id: str) -> str | None:
    """
    Start an Azure Container Instance for a training task.

    The container group runs ``python run_training.py --task-id <task_id>``
    with 4 CPUs, 16 GB of memory and one GPU, on Linux, with restart
    policy ``Never``.

    Configuration is read from environment variables (empty values are
    treated the same as unset):

    * ``AZURE_SUBSCRIPTION_ID`` -- required; without it nothing is created.
    * ``AZURE_RESOURCE_GROUP`` -- default ``invoice-training-rg``.
    * ``AZURE_ACR_IMAGE`` -- defaults to a placeholder image reference that
      must be overridden in real deployments (a warning is logged).
    * ``AZURE_GPU_SKU`` -- default ``V100``.
    * ``AZURE_LOCATION`` -- default ``eastus``.
    * ``DB_HOST``, ``DB_PORT``, ``DB_NAME``, ``DB_USER``, ``DB_PASSWORD`` --
      forwarded to the container when set (the password as a secure value).

    Returns the container group name once the create request has been
    submitted. Returns None if the Azure SDK is missing, the configuration
    or task_id is unusable, or the Azure call raises.

    Creation is a long-running operation that is not awaited here. A
    returned name therefore means "request accepted", not "container
    provisioned".
    """
    if not _AZURE_SDK_AVAILABLE:
        logger.warning(
            "Azure SDK not installed. Install azure-mgmt-containerinstance "
            "and azure-identity to use ACI trigger."
        )
        return None

    subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
    if not subscription_id:
        logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
        return None

    slug = _task_slug(task_id)
    if not slug:
        logger.error(
            "task_id %r yields no usable characters for an ACI name. "
            "Cannot start ACI.",
            task_id,
        )
        return None
    # NOTE(review): only the first 8 characters of task_id feed the name, and
    # begin_create_or_update targets an existing group of the same name rather
    # than failing -- confirm task ids are unique within that prefix.
    container_name = f"training-{slug}"

    resource_group = _env_or_default("AZURE_RESOURCE_GROUP", "invoice-training-rg")
    gpu_sku = _env_or_default("AZURE_GPU_SKU", "V100")
    location = _env_or_default("AZURE_LOCATION", "eastus")
    image = os.environ.get("AZURE_ACR_IMAGE")
    if not image:
        image = "youracr.azurecr.io/invoice-training:latest"
        logger.warning(
            "AZURE_ACR_IMAGE not set; falling back to placeholder image %s.",
            image,
        )

    container = Container(
        name=container_name,
        image=image,
        resources=ResourceRequirements(
            requests=ResourceRequests(
                cpu=4,
                memory_in_gb=16,
                gpu=GpuResource(count=1, sku=gpu_sku),
            )
        ),
        environment_variables=_training_env_vars(task_id),
        command=["python", "run_training.py", "--task-id", task_id],
    )
    group = ContainerGroup(
        location=location,
        containers=[container],
        os_type="Linux",
        # One-shot training job: do not restart once the process exits.
        restart_policy="Never",
    )

    logger.info("Creating ACI container group: %s", container_name)
    try:
        credential = DefaultAzureCredential()
        client = ContainerInstanceManagementClient(credential, subscription_id)
        # Returns a poller for the long-running operation. It is deliberately
        # not awaited, so failures after submission are not observed here.
        client.container_groups.begin_create_or_update(
            resource_group, container_name, group
        )
    except Exception:
        # Trigger boundary: honour the documented "None on failure" contract
        # instead of propagating auth/HTTP/SDK errors to the caller.
        logger.exception("Failed to create ACI container group %s", container_name)
        return None

    return container_name
|