Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/madengine/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def run(
# Input validation
if timeout < -1:
console.print(
"❌ [red]Timeout must be -1 (default) or a positive integer[/red]"
"❌ [red]Timeout must be -1 (default), 0 (no timeout), or a positive integer[/red]"
)
raise typer.Exit(ExitCode.INVALID_ARGS)

Expand All @@ -192,8 +192,11 @@ def run(
effective_additional_context_file = None

# Convert -1 (default) to actual default timeout value (7200 seconds = 2 hours)
# Convert 0 to None sentinel meaning "no timeout"
if timeout == -1:
timeout = 7200
elif timeout == 0:
timeout = None

try:
# Check if we're doing execution-only or full workflow
Expand All @@ -211,7 +214,7 @@ def run(
f"🚀 [bold cyan]Running Models (Execution Only)[/bold cyan]\n"
f"Manifest: [yellow]{manifest_file}[/yellow]\n"
f"Registry: [yellow]{registry or 'Auto-detected'}[/yellow]\n"
f"Timeout: [yellow]{timeout if timeout != -1 else 'Default'}[/yellow]s",
f"Timeout: [yellow]{f'{timeout}s' if timeout is not None else 'Disabled'}[/yellow]",
title="Execution Configuration",
border_style="green",
)
Expand Down Expand Up @@ -314,7 +317,7 @@ def run(
f"🔨🚀 [bold cyan]Complete Workflow (Build + Run)[/bold cyan]\n"
f"Tags: [yellow]{', '.join(processed_tags) if processed_tags else 'All models'}[/yellow]\n"
f"Registry: [yellow]{registry or 'Local only'}[/yellow]\n"
f"Timeout: [yellow]{timeout if timeout != -1 else 'Default'}[/yellow]s"
f"Timeout: [yellow]{f'{timeout}s' if timeout is not None else 'Disabled'}[/yellow]"
f"{skip_note}",
title="Workflow Configuration",
border_style="magenta",
Expand Down Expand Up @@ -476,4 +479,3 @@ def run(
)
handle_error(e, context=context)
raise typer.Exit(ExitCode.FAILURE)

100 changes: 100 additions & 0 deletions src/madengine/core/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
Shared authentication utilities for madengine.

Centralises credential loading logic used by both BuildOrchestrator and
RunOrchestrator so that fixes and improvements only need to be made once.

Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
"""

import json
import os
import shlex
from typing import Dict, Optional

from madengine.core.errors import (
ConfigurationError,
create_error_context,
handle_error,
)


def load_credentials() -> Optional[Dict]:
"""Load credentials from credential.json and environment variables."""
credentials: Optional[Dict] = None
credential_file = "credential.json"
if os.path.exists(credential_file):
try:
with open(credential_file) as f:
credentials = json.load(f)
except Exception as e:
context = create_error_context(
operation="load_credentials",
component="auth",
file_path=credential_file,
)
handle_error(
ConfigurationError(
f"Could not load credentials: {e}",
context=context,
suggestions=["Check if credential.json exists and has valid JSON format"],
)
)
docker_hub_user = os.environ.get("MAD_DOCKERHUB_USER")
docker_hub_password = os.environ.get("MAD_DOCKERHUB_PASSWORD")
docker_hub_repo = os.environ.get("MAD_DOCKERHUB_REPO")
if docker_hub_user and docker_hub_password:
if credentials is None:
credentials = {}
credentials["dockerhub"] = {
"username": docker_hub_user,
"password": docker_hub_password,
}
if docker_hub_repo:
credentials["dockerhub"]["repository"] = docker_hub_repo
return credentials


def login_to_registry(
registry: Optional[str],
credentials: Optional[Dict],
console,
rich_console,
raise_on_failure: bool = True,
) -> None:
"""Login to a Docker registry (shared implementation for all orchestrators)."""
if not credentials:
rich_console.print("[yellow]No credentials provided for registry login[/yellow]")
return
registry_key = registry if registry else "dockerhub"
if registry and registry.lower() == "docker.io":
registry_key = "dockerhub"
if registry_key not in credentials:
error_msg = f"No credentials found for registry: {registry_key}"
rich_console.print(f"[red]{error_msg}[/red]")
if raise_on_failure:
raise RuntimeError(error_msg)
return
creds = credentials[registry_key]
if "username" not in creds or "password" not in creds:
error_msg = f"Invalid credentials format for registry: {registry_key}"
rich_console.print(f"[red]{error_msg}[/red]")
if raise_on_failure:
raise RuntimeError(error_msg)
return
username = str(creds["username"])
password = str(creds["password"])
quoted_username = shlex.quote(username)
login_command = "printf %s \"$MAD_REGISTRY_PASSWORD\" | docker login"
if registry and registry.lower() not in ["docker.io", "dockerhub"]:
login_command += f" {shlex.quote(str(registry))}"
login_command += f" --username {quoted_username} --password-stdin"
login_env = {**os.environ, "MAD_REGISTRY_PASSWORD": password}
try:
console.sh(login_command, secret=True, env=login_env)
rich_console.print(f"[green]Successfully logged in to registry: {registry or 'DockerHub'}[/green]")
except Exception as e:
rich_console.print(f"[red]Failed to login to registry {registry}: {e}[/red]")
if raise_on_failure:
raise
18 changes: 11 additions & 7 deletions src/madengine/core/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"""
# built-in modules
import os
import re
import shlex
import typing

# user-defined modules
Expand Down Expand Up @@ -57,21 +59,23 @@ def __init__(
self.groupid = self.console.sh("id -g")

# check if container name exists
container_name_quoted = shlex.quote(container_name) # shell safety for stop/rm
container_name_re = re.escape(container_name) # regex safety for --filter
container_name_exists = self.console.sh(
"docker container ps -a | grep " + container_name + " | wc -l"
"docker container ps -aq --filter " + shlex.quote(f"name=^/{container_name_re}$")
)
# if container name exists, clean it up automatically
if container_name_exists != "0":
if container_name_exists:
print(
f"⚠️ Container '{container_name}' already exists. Cleaning up..."
)
# Stop the container (with timeout)
self.console.sh(
f"docker stop -t 1 {container_name} 2>/dev/null || true"
f"docker stop -t 1 {container_name_quoted} 2>/dev/null || true"
)
# Remove the container
self.console.sh(
f"docker rm -f {container_name} 2>/dev/null || true"
f"docker rm -f {container_name_quoted} 2>/dev/null || true"
)
print(f"✓ Cleaned up existing container '{container_name}'")

Expand All @@ -93,10 +97,10 @@ def __init__(
# add envVars
if envVars is not None:
for evar in envVars.keys():
command += "-e " + evar + "=" + envVars[evar] + " "
command += "-e " + evar + "=" + shlex.quote(str(envVars[evar])) + " "

command += "--workdir /myworkspace/ "
command += "--name " + container_name + " "
command += "--name " + container_name_quoted + " "
command += image + " "

# Use 'cat' to keep container alive (blocks waiting for stdin)
Expand All @@ -123,7 +127,7 @@ def sh(self, command: str, timeout: int = 60, secret: bool = False) -> str:
"""
# run as root!
return self.console.sh(
"docker exec " + self.docker_sha + ' bash -c "' + command + '"',
"docker exec " + self.docker_sha + " bash -c " + shlex.quote(command),
timeout=timeout,
secret=secret,
)
Expand Down
16 changes: 7 additions & 9 deletions src/madengine/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

import logging
import traceback
from dataclasses import dataclass
from typing import Optional, Any, Dict, List
from enum import Enum
Expand All @@ -16,7 +15,6 @@
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.table import Table
except ImportError:
raise ImportError("Rich is required for error handling. Install with: pip install rich")

Expand Down Expand Up @@ -83,7 +81,7 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg
)


class ConnectionError(MADEngineError):
class NetworkError(MADEngineError):
"""Connection and network errors."""

def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs):
Expand All @@ -96,6 +94,10 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg
)


# Deprecated: use NetworkError instead; will be removed in a future release
ConnectionError = NetworkError


class AuthenticationError(MADEngineError):
"""Authentication and credential errors."""

Expand All @@ -122,10 +124,6 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg
)


# Backward compatibility alias
RuntimeError = ExecutionError


class BuildError(MADEngineError):
"""Build and compilation errors."""

Expand Down Expand Up @@ -191,7 +189,7 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg
)


class TimeoutError(MADEngineError):
class DeploymentTimeoutError(MADEngineError):
"""Timeout and duration errors."""

def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs):
Expand Down Expand Up @@ -387,4 +385,4 @@ def create_error_context(
phase=phase,
component=component,
**kwargs
)
)
5 changes: 4 additions & 1 deletion src/madengine/core/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""
# built-in modules
import signal
import typing


class Timeout:
Expand Down Expand Up @@ -42,9 +41,13 @@ def handle_timeout(self, signum, frame) -> None:

def __enter__(self) -> None:
"""Enter the context manager."""
if not self.seconds:
return
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds)

def __exit__(self, type, value, traceback) -> None:
"""Exit the context manager."""
if not self.seconds:
return
signal.alarm(0)