Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@ import PrecompileTools: @compile_workload
include("precompile.jl")

function __init__()
# Initialize system UUID
# Clear any precompile-cached UUID for this process: the precompile workload
# runs system_uuid() and the resulting SYSTEM_UUIDS entry gets baked into
# the compiled image. Without clearing it here, get!() would return that
# stale build-time UUID instead of reading the actual runtime UUID file,
# causing mismatches between process 1 and workers.
delete!(SYSTEM_UUIDS, myid())
system_uuid()

@static if !isdefined(Base, :get_extension)
Expand Down
15 changes: 15 additions & 0 deletions src/cancellation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ function _cancel!(state, tid, force, graceful, halt_sch)
@dagdebug tid :cancel "Interrupting running task ($Tf)"
Threads.@spawn Base.throwto(task, InterruptException())
else
# Skip if already cancelled to avoid duplicate results in the scheduler queue
tid in istate.cancelled && continue
@dagdebug tid :cancel "Cancelling running task ($Tf)"
# Tell the processor to just drop this task
task_occupancy = task_spec.est_occupancy
Expand All @@ -156,6 +158,19 @@ function _cancel!(state, tid, force, graceful, halt_sch)
cancel!(istate.cancel_tokens[tid]; graceful)
end
end
# Also cancel tokens for tasks that have been dequeued but not yet
# recorded in istate.tasks (race window between token assignment and
# task registration). Just cancel the token so the task sees it when
# it starts; DoTaskSpec will handle posting the result normally.
if !force
for (tid, token) in istate.cancel_tokens
_tid !== nothing && tid != _tid && continue
haskey(istate.tasks, tid) && continue # already handled above
tid in istate.cancelled && continue
@dagdebug tid :cancel "Cancelling pre-running task token"
cancel!(token; graceful)
end
end
end
if any_cancelled
notify(istate.reschedule)
Expand Down
12 changes: 10 additions & 2 deletions src/dtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,21 @@ function waitany(tasks::Vector{DTask})
return
end
cond = Threads.Condition()
done = Ref(false)
for task in tasks
Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin
wait(task)
@lock cond notify(cond)
@lock cond begin
done[] = true
notify(cond)
end
end)
end
@lock cond wait(cond)
@lock cond begin
while !done[]
wait(cond)
end
end
return
end
function waitall(tasks::Vector{DTask})
Expand Down
3 changes: 3 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@
@assert isempty(Sch.WORKER_MONITOR_CHANS)
@assert isempty(Sch.WORKER_MONITOR_TASKS)
ID_COUNTER[] = 1
# Clear the precompile-time UUID cache so it is not baked into the compiled
# image; __init__ re-populates it from the shared UUID file at load time.
delete!(SYSTEM_UUIDS, myid())
end
4 changes: 4 additions & 0 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,10 @@ function (dts::DoTaskSpec)()

# Ensure that any spawned tasks get cleaned up
Dagger.cancel!(dts.cancel_token)

# Reset TLS so that reusable tasks don't inherit stale Dagger context.
Dagger.DTASK_TLS[] = nothing
Dagger.DTASK_CANCEL_TOKEN[] = nothing
end
if was_cancelled
# A result was already posted to the return queue
Expand Down
49 changes: 39 additions & 10 deletions src/sch/eager.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
const EAGER_INIT = Threads.Atomic{Bool}(false)
const EAGER_READY = Base.Event()
# Condition variable used to synchronize EAGER_STATE changes.
# Waiters must hold this lock, check EAGER_STATE[], and wait in a loop.
const EAGER_STATE_LOCK = Threads.Condition()
const EAGER_ID_MAP = LockedObject(Dict{UInt64,Int}())
const EAGER_CONTEXT = Ref{Union{Context,Nothing}}(nothing)
const EAGER_STATE = Ref{Union{ComputeState,Nothing}}(nothing)
Expand All @@ -16,12 +18,21 @@ function init_eager()
throw(ConcurrencyViolationError("init_eager can only be called on worker 1"))
end
if Threads.atomic_xchg!(EAGER_INIT, true)
wait(EAGER_READY)
# Secondary path: another caller is initializing or the scheduler is already running.
# Wait (under the condition lock) for EAGER_STATE to become non-nothing (ready) or
# for EAGER_INIT to become false (scheduler exited without becoming ready).
@lock EAGER_STATE_LOCK begin
while EAGER_STATE[] === nothing && EAGER_INIT[]
wait(EAGER_STATE_LOCK)
end
end
if EAGER_STATE[] === nothing
throw(ConcurrencyViolationError("Eager scheduler failed to start"))
end
return
end

# Primary path: we won the CAS, so we're responsible for starting the scheduler.
ctx = eager_context()
# N.B. We use @async here to prevent the scheduler task from running on a
# different thread than the one that is likely submitting work, as otherwise
Expand All @@ -46,26 +57,44 @@ function init_eager()
seek(iob.io, 0)
write(stderr, iob)
finally
# N.B. Sequence order matters to ensure that observers can see that we failed to start
EAGER_STATE[] = nothing
notify(EAGER_READY)
reset(EAGER_READY)
# Clear EAGER_INIT and EAGER_STATE together under the condition lock.
# Doing both atomically under the lock prevents a race where a new
# scheduler sets EAGER_STATE between our atomic_xchg! and our lock
# acquisition: the new scheduler also needs the lock to set EAGER_STATE,
# so it is forced to wait until after our cleanup, guaranteeing that
# our EAGER_STATE=nothing write cannot overwrite the new state.
@lock EAGER_STATE_LOCK begin
Threads.atomic_xchg!(EAGER_INIT, false)
EAGER_STATE[] = nothing
notify(EAGER_STATE_LOCK; all=true)
end
lock(EAGER_ID_MAP) do id_map
empty!(id_map)
end
Threads.atomic_xchg!(EAGER_INIT, false)
end)
wait(EAGER_READY)

# Wait for eager_thunk to set EAGER_STATE[].
# Loop to handle spurious wakeups and wakeups from old-scheduler cleanup.
@lock EAGER_STATE_LOCK begin
while EAGER_STATE[] === nothing && EAGER_INIT[]
wait(EAGER_STATE_LOCK)
end
end
if EAGER_STATE[] === nothing
throw(ConcurrencyViolationError("Eager scheduler failed to start"))
end
end

function eager_thunk()
exec!(Dagger.sch_handle()) do ctx, state, task, tid, _
EAGER_STATE[] = state
# Set EAGER_STATE and notify all waiters under the condition lock so that
# init_eager's primary wait loop sees the new state atomically.
@lock EAGER_STATE_LOCK begin
EAGER_STATE[] = state
notify(EAGER_STATE_LOCK; all=true)
end
return
end
notify(EAGER_READY)
wait(Dagger.Sch.EAGER_STATE[].halt)
end

Expand Down
6 changes: 4 additions & 2 deletions src/stream-transfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ function stream_push_values!(fetcher::RemoteChannelFetcher, T, our_store::Stream
try
put!(fetcher.chan, value)
catch err
if err isa InvalidStateException && !isopen(fetcher.chan)
unwrapped = Sch.unwrap_nested_exception(err)
if unwrapped isa InvalidStateException && !isopen(fetcher.chan)
@dagdebug our_tid :stream_push "channel closed: $our_tid -> $their_tid"
throw(InterruptException())
end
Expand All @@ -35,7 +36,8 @@ function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::Stream
value = try
take!(fetcher.chan)
catch err
if err isa InvalidStateException && !isopen(fetcher.chan)
unwrapped = Sch.unwrap_nested_exception(err)
if unwrapped isa InvalidStateException && !isopen(fetcher.chan)
@dagdebug our_tid :stream_pull "channel closed: $their_tid -> $our_tid"
throw(InterruptException())
end
Expand Down
49 changes: 41 additions & 8 deletions src/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,26 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B}
initialize_output_stream!(store, output_uid)
end
buffer = store.output_buffers[output_uid]
skip = false
while isfull(buffer)
if !isopen(store)
@dagdebug thunk_id :stream "closed!"
throw(InvalidStateException("Stream is closed", :closed))
end
# Buffer may have been removed by remove_waiters! while we waited
if !haskey(store.output_buffers, output_uid) || !isopen(buffer)
@dagdebug thunk_id :stream "output buffer removed, skipping"
skip = true
break
end
@dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting"
wait(store.lock)
if !isfull(buffer)
@dagdebug thunk_id :stream "buffer has space ($(length(buffer)) values), continuing"
end
task_may_cancel!()
end
put!(buffer, value)
skip || put!(buffer, value)
end
notify(store.lock)
end
Expand Down Expand Up @@ -136,9 +143,15 @@ end
function remove_waiters!(store::StreamStore, waiters::Vector{UInt})
@lock store.lock begin
for w in waiters
delete!(store.output_buffers, w)
# Close and remove the output buffer so the output thread can exit
if haskey(store.output_buffers, w)
close(store.output_buffers[w])
delete!(store.output_buffers, w)
end
delete!(store.output_streams, w)
delete!(store.output_fetchers, w)
idx = findfirst(wo->wo==w, store.waiters)
deleteat!(store.waiters, idx)
idx !== nothing && deleteat!(store.waiters, idx)
delete!(store.input_streams, w)
end
notify(store.lock)
Expand Down Expand Up @@ -197,6 +210,9 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S
rethrow()
end
finally
# Signal stream! that no more values will arrive, so it can exit
# gracefully instead of blocking forever on take!(buffer)
close(buffer)
@dagdebug STREAM_THUNK_ID[] :stream "input stream closed"
end
end)
Expand Down Expand Up @@ -233,12 +249,15 @@ function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt
rethrow()
end
finally
# Close the channel so the downstream input pull thread can detect
# that this upstream source is exhausted and exit gracefully
close(output_fetcher.chan)
@dagdebug thunk_id :stream "output stream closed"
end
end)
end

Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value)
Base.put!(stream::Stream, value) = put!(stream.store, value)

function Base.isopen(stream::Stream, id::UInt)::Bool
return MemPool.access_ref(stream.store_ref.handle, id) do store, id
Expand Down Expand Up @@ -520,6 +539,9 @@ function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...)
# FIXME: Remove when scheduler is distributed
uid = UInt(thunk_id)

# Save original args before initialize_input_stream! rebinds them to StreamingValues,
# so that remove_waiters! can find the upstream Stream objects in the finally block
original_args = args
try
# TODO: This kwarg song-and-dance is required to ensure that we don't
# allocate boxes within `stream!`, when possible
Expand All @@ -532,7 +554,7 @@ function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...)
if !sf.stream.store.migrating
# Remove ourself as a waiter for upstream Streams
streams = Set{Stream}()
for (idx, arg) in enumerate(args)
for (idx, arg) in enumerate(original_args)
if arg isa Stream
push!(streams, arg)
end
Expand All @@ -544,7 +566,7 @@ function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...)
end
for stream in streams
@dagdebug thunk_id :stream "dropping waiter"
remove_waiters!(stream, uid)
remove_waiters!(stream, UInt[uid])
@dagdebug thunk_id :stream "dropped waiter"
end

Expand Down Expand Up @@ -576,8 +598,19 @@ function stream!(sf::StreamingFunction, uid,
end

# Get values from Stream args/kwargs
stream_args = _stream_take_values!(args)
stream_kwarg_values = _stream_take_values!(kwarg_values)
# An InvalidStateException here means an input stream was closed because
# the upstream task finished; exit gracefully in that case
local stream_args, stream_kwarg_values
try
stream_args = _stream_take_values!(args)
stream_kwarg_values = _stream_take_values!(kwarg_values)
catch err
if err isa InvalidStateException
@dagdebug STREAM_THUNK_ID[] :stream "input stream closed, exiting"
return
end
rethrow()
end
stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values)

if length(stream_args) > 0 || length(stream_kwarg_values) > 0
Expand Down
56 changes: 28 additions & 28 deletions src/utils/dagdebug.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
function istask end
function task_id end

const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope,
:take, :execute, :move, :processor, :finish,
:cancel, :stream]
# Use a Set for O(1) membership checks (vs O(n) for Vector).
const DAGDEBUG_CATEGORIES = Set{Symbol}([:global, :submit, :schedule, :scope,
:take, :execute, :move, :processor, :finish,
:cancel, :stream])

# Out-of-line emission keeps call-site IR minimal: just one `in` check + one
# function call per @dagdebug site, regardless of how complex the message is.
@noinline function _dagdebug_emit(thunk, cat_sym::Symbol, msg::String)
id = -1
if thunk isa Integer
id = Int(thunk)
elseif istask(thunk)
id = task_id(thunk)
elseif thunk === nothing
id = 0
else
@warn "Unsupported thunk argument to @dagdebug: $(typeof(thunk))"
id = -1
end
if id > 0
@debug "[$id] ($cat_sym) $msg" _module=Dagger
elseif id == 0
@debug "($cat_sym) $msg" _module=Dagger
end
end

macro dagdebug(thunk, category, msg, args...)
cat_sym = category.value
@gensym id
debug_ex_id = :(@debug "[$($id)] ($($(repr(cat_sym)))) $($msg)" _module=Dagger _file=$(string(__source__.file)) _line=$(__source__.line))
append!(debug_ex_id.args, args)
debug_ex_noid = :(@debug "($($(repr(cat_sym)))) $($msg)" _module=Dagger _file=$(string(__source__.file)) _line=$(__source__.line))
append!(debug_ex_noid.args, args)
esc(quote
let $id = -1
if $thunk isa Integer
$id = Int($thunk)
elseif $istask($thunk)
$id = $task_id($thunk)
elseif $thunk === nothing
$id = 0
else
@warn "Unsupported thunk argument to @dagdebug: $(typeof($thunk))"
$id = -1
end
if $id > 0
if $(QuoteNode(cat_sym)) in $DAGDEBUG_CATEGORIES || :all in $DAGDEBUG_CATEGORIES
$debug_ex_id
end
elseif $id == 0
if $(QuoteNode(cat_sym)) in $DAGDEBUG_CATEGORIES || :all in $DAGDEBUG_CATEGORIES
$debug_ex_noid
end
end
if $(QuoteNode(cat_sym)) in $DAGDEBUG_CATEGORIES || :all in $DAGDEBUG_CATEGORIES
$_dagdebug_emit($thunk, $(QuoteNode(cat_sym)), string($msg))
end
end)
end
Expand Down
Loading
Loading