2 changes: 1 addition & 1 deletion .github/workflows/phoenix/submit-bench.sh
@@ -21,7 +21,7 @@ sbatch_common_opts="\
#SBATCH -J shb-${sbatch_script%%.sh}-$device # job name
#SBATCH --account=gts-sbryngelson3 # account
#SBATCH -N1 # nodes
#SBATCH -t 02:00:00 # walltime
#SBATCH -t 03:00:00 # walltime
#SBATCH -q embers # QOS
#SBATCH -o $job_slug.out # stdout+stderr
#SBATCH --mem-per-cpu=2G # default mem (overridden below)
2 changes: 1 addition & 1 deletion .github/workflows/phoenix/submit.sh
@@ -21,7 +21,7 @@ sbatch_common_opts="\
#SBATCH -J shb-${sbatch_script%%.sh}-$device # job name
#SBATCH --account=gts-sbryngelson3 # account
#SBATCH -N1 # nodes
#SBATCH -t 03:00:00 # walltime
#SBATCH -t 04:00:00 # walltime
#SBATCH -q embers # QOS
#SBATCH -o $job_slug.out # stdout+stderr
#SBATCH --mem-per-cpu=2G # default mem (overridden below)
2 changes: 1 addition & 1 deletion .github/workflows/phoenix/test.sh
@@ -22,7 +22,7 @@ if [ "$job_device" = "gpu" ]; then
n_test_threads=`expr $gpu_count \* 2`
fi

./mfc.sh test --max-attempts 3 -a -j $n_test_threads $device_opts -- -c phoenix
./mfc.sh test --max-attempts 3 -a --sched-debug -j $n_test_threads $device_opts -- -c phoenix

sleep 10
rm -rf "$currentdir" || true
1 change: 1 addition & 0 deletions toolchain/mfc/args.py
@@ -91,6 +91,7 @@ def add_common_arguments(p, mask = None):
test.add_argument( "--no-examples", action="store_true", default=False, help="Do not test example cases." )
test.add_argument("--case-optimization", action="store_true", default=False, help="(GPU Optimization) Compile MFC targets with some case parameters hard-coded.")
test.add_argument( "--dry-run", action="store_true", default=False, help="Build and generate case files but do not run tests.")
test.add_argument( "--sched-debug", action="store_true", default=False, help="Enable detailed scheduler debug logging.")

test_meg = test.add_mutually_exclusive_group()
test_meg.add_argument("--generate", action="store_true", default=False, help="(Test Generation) Generate golden files.")
78 changes: 72 additions & 6 deletions toolchain/mfc/sched.py
@@ -3,6 +3,7 @@
import traceback

from .printer import cons
from .state import ARG

class WorkerThread(threading.Thread):
def __init__(self, *args, **kwargs):
@@ -44,65 +45,120 @@ def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Set[int] = No

sched.LOAD = { id: 0.0 for id in devices or [] }

# Debug logging setup
gpu_mode = devices is not None and len(devices) > 0
debug_enabled = ARG("sched_debug", False) # Check for --sched-debug flag

def debug_log(msg):
if debug_enabled:
cons.print(msg)

if debug_enabled:
debug_log(f"[SCHED DEBUG] Starting scheduler: {len(tasks)} tasks, {nThreads} threads, GPU mode: {gpu_mode}")
if gpu_mode:
debug_log(f"[SCHED DEBUG] GPU devices: {devices}")

def join_first_dead_thread(progress, complete_tracker) -> None:
nonlocal threads, nAvailable

debug_log(f"[SCHED DEBUG] Checking {len(threads)} active threads for completion")

for threadID, threadHolder in enumerate(threads):
# Check if thread is not alive OR if it's been running for too long
thread_not_alive = not threadHolder.thread.is_alive()

debug_log(f"[SCHED DEBUG] Thread {threadID}: alive={threadHolder.thread.is_alive()}, devices={threadHolder.devices}")

if thread_not_alive:
debug_log(f"[SCHED DEBUG] Thread {threadID} detected as dead, attempting to join...")

# Properly join the thread with timeout to prevent infinite hangs
join_start_time = time.time()
timeout_duration = 120.0 if gpu_mode else 30.0 # Longer timeout for GPU

debug_log(f"[SCHED DEBUG] Joining thread {threadID} with {timeout_duration}s timeout...")

try:
threadHolder.thread.join(timeout=30.0) # 30 second timeout
threadHolder.thread.join(timeout=timeout_duration)
join_end_time = time.time()
join_duration = join_end_time - join_start_time

debug_log(f"[SCHED DEBUG] Thread {threadID} join completed in {join_duration:.2f}s")

# Double-check that thread actually finished joining
if threadHolder.thread.is_alive():
# Thread didn't finish within timeout - this is a serious issue
raise RuntimeError(f"Thread {threadID} failed to join within 30 seconds timeout. "
f"Thread may be hung or in an inconsistent state.")
debug_log(f"[SCHED DEBUG] ERROR: Thread {threadID} still alive after {timeout_duration}s timeout!")
debug_log(f"[SCHED DEBUG] Thread {threadID} devices: {threadHolder.devices}")
debug_log(f"[SCHED DEBUG] Thread {threadID} exception: {threadHolder.thread.exc}")
raise RuntimeError(f"Thread {threadID} failed to join within {timeout_duration} seconds timeout. "
f"Thread may be hung or in an inconsistent state. "
f"GPU devices: {threadHolder.devices}")

except Exception as join_exc:
# Handle join-specific exceptions with more context
debug_log(f"[SCHED DEBUG] Exception during thread {threadID} join: {join_exc}")
raise RuntimeError(f"Failed to join thread {threadID}: {join_exc}. "
f"This may indicate a system threading issue or hung test case.") from join_exc
f"This may indicate a system threading issue or hung test case. "
f"GPU devices: {threadHolder.devices}") from join_exc

debug_log(f"[SCHED DEBUG] Thread {threadID} successfully joined")

# Check for and propagate any exceptions that occurred in the worker thread
# But only if the worker function didn't complete successfully
# (This allows test failures to be handled gracefully by handle_case)
if threadHolder.thread.exc is not None:
debug_log(f"[SCHED DEBUG] Thread {threadID} had exception: {threadHolder.thread.exc}")
debug_log(f"[SCHED DEBUG] Thread {threadID} completed successfully: {threadHolder.thread.completed_successfully}")

if threadHolder.thread.completed_successfully:
# Test framework handled the exception gracefully (e.g., test failure)
# Don't re-raise - this is expected behavior
debug_log(f"[SCHED DEBUG] Thread {threadID} exception was handled gracefully by test framework")
pass
# Unhandled exception - this indicates a real problem
elif hasattr(threadHolder.thread, 'exc_info') and threadHolder.thread.exc_info:
error_msg = f"Worker thread {threadID} failed with unhandled exception:\n{threadHolder.thread.exc_info}"
debug_log(f"[SCHED DEBUG] Thread {threadID} had unhandled exception!")
raise RuntimeError(error_msg) from threadHolder.thread.exc
else:
debug_log(f"[SCHED DEBUG] Thread {threadID} had unhandled exception without details")
raise threadHolder.thread.exc

# Update scheduler state
nAvailable += threadHolder.ppn
for device in threadHolder.devices or set():
old_load = sched.LOAD[device]
sched.LOAD[device] -= threadHolder.load / threadHolder.ppn
debug_log(f"[SCHED DEBUG] Device {device} load: {old_load:.3f} -> {sched.LOAD[device]:.3f}")

progress.advance(complete_tracker)

debug_log(f"[SCHED DEBUG] Thread {threadID} cleanup complete, removing from active threads")
del threads[threadID]

break

debug_log(f"[SCHED DEBUG] join_first_dead_thread completed, {len(threads)} threads remaining")

with rich.progress.Progress(console=cons.raw, transient=True) as progress:
queue_tracker = progress.add_task("Queued ", total=len(tasks))
complete_tracker = progress.add_task("Completed", total=len(tasks))

debug_log(f"[SCHED DEBUG] Starting task queue processing...")

# Queue Tests
for task in tasks:
for task_idx, task in enumerate(tasks):
debug_log(f"[SCHED DEBUG] Processing task {task_idx+1}/{len(tasks)}: ppn={task.ppn}, load={task.load}")

# Wait until there are threads available
while nAvailable < task.ppn:
debug_log(f"[SCHED DEBUG] Waiting for resources: need {task.ppn}, have {nAvailable}")

# This is important if "-j 1" is used (the default) since there
# are test cases that require test.ppn=2
if task.ppn > nThreads and nAvailable > 0:
debug_log(f"[SCHED DEBUG] Task requires more threads ({task.ppn}) than available ({nThreads}), but some are free ({nAvailable})")
break

# Keep track of threads that are done
@@ -118,24 +174,34 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
# Use the least loaded devices
if devices is not None:
use_devices = set()
for _ in range(task.ppn):
debug_log(f"[SCHED DEBUG] Assigning GPU devices for task {task_idx+1}")
for device_idx in range(task.ppn):
device = min(sched.LOAD.items(), key=lambda x: x[1])[0]
sched.LOAD[device] += task.load / task.ppn
use_devices.add(device)
debug_log(f"[SCHED DEBUG] Assigned device {device} (load now: {sched.LOAD[device]:.3f})")

nAvailable -= task.ppn

debug_log(f"[SCHED DEBUG] Starting thread for task {task_idx+1}, devices: {use_devices}")
thread = WorkerThread(target=task.func, args=tuple(task.args) + (use_devices,))
thread.start()

threads.append(WorkerThreadHolder(thread, task.ppn, task.load, use_devices))
debug_log(f"[SCHED DEBUG] Thread started for task {task_idx+1}, {len(threads)} active threads")

debug_log(f"[SCHED DEBUG] All tasks queued, waiting for completion...")

# Wait for the last tests to complete (MOVED INSIDE CONTEXT)
while len(threads) != 0:
debug_log(f"[SCHED DEBUG] Waiting for {len(threads)} threads to complete...")

# Keep track of threads that are done
join_first_dead_thread(progress, complete_tracker)

# Do not overwhelm this core with this loop
time.sleep(0.05)

debug_log(f"[SCHED DEBUG] Scheduler completed successfully!")

sched.LOAD = {}
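
For context, the logging added to sched.py reduces to one gating helper: messages are formatted at the call sites but only printed when the user passed --sched-debug, which the toolchain exposes through ARG("sched_debug"). The following is a minimal, self-contained sketch of that pattern only; the ARG stub, the _ARGS dict, and plain print() stand in for the toolchain's state.ARG and printer.cons and are not part of this PR.

```python
# Illustrative sketch of the --sched-debug gating pattern (stand-ins, not toolchain code).
_ARGS = {"sched_debug": True}  # pretend the user passed --sched-debug

def ARG(name, default=False):
    # Stand-in for toolchain ARG(): look up a parsed CLI value by its argparse dest name.
    return _ARGS.get(name, default)

def debug_log(msg: str) -> None:
    # Only emit scheduler diagnostics when debugging was requested.
    if ARG("sched_debug", False):
        print(msg)

debug_log("[SCHED DEBUG] Starting scheduler: 12 tasks, 4 threads, GPU mode: False")
```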