diff --git a/Project.toml b/Project.toml index ccb3250..f68ad16 100644 --- a/Project.toml +++ b/Project.toml @@ -11,13 +11,19 @@ ModelContextProtocol = "c58f755f-f2a7-4f48-bf29-4e9659b78499" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" +[sources] +# DEV-ONLY pin (do not merge to main): send_progress and RequestContext-passing +# tool handlers are not yet in a released ModelContextProtocol.jl. When upstream +# releases them, drop this [sources] block and bump the [compat] entry instead. +ModelContextProtocol = {url = "https://github.com/samtalki/ModelContextProtocol.jl", rev = "feat/progress-notifications"} + [compat] Aqua = "0.8" Dates = "1" JSON3 = "1" JuliaSyntaxHighlighting = "1" Malt = "1.4" -ModelContextProtocol = "0.4" +ModelContextProtocol = "0.4, 0.5" Pkg = "1" Revise = "3" Test = "1" diff --git a/src/packages.jl b/src/packages.jl index ad97424..3dfbe82 100644 --- a/src/packages.jl +++ b/src/packages.jl @@ -54,11 +54,19 @@ function activate_project_on_worker!(path::String; session_name::Union{String,No end """ - run_pkg_action_on_worker(action::String, pkg_list::Vector{String}; session_name::Union{String,Nothing}=nothing) + run_pkg_action_on_worker(action::String, pkg_list::Vector{String}; + session_name=nothing, progress_cb=nothing) Run a Pkg action on the worker process. + +When `progress_cb` is supplied, the action runs through `_run_with_heartbeat`, which +emits a heartbeat every couple of seconds while the action is in flight (the callback +receives `(n, message)` and is used to send MCP `notifications/progress`). Without it +the action runs as a single blocking call, exactly as before. """ -function run_pkg_action_on_worker(action::String, pkg_list::Vector{String}; session_name::Union{String,Nothing}=nothing) +function run_pkg_action_on_worker(action::String, pkg_list::Vector{String}; + session_name::Union{String,Nothing}=nothing, + progress_cb=nothing) session = resolve_session(session_name) worker = ensure_worker!(session) @@ -110,7 +118,10 @@ function run_pkg_action_on_worker(action::String, pkg_list::Vector{String}; sess end try - return _remote_eval_fetch(worker, pkg_expr) + if progress_cb === nothing + return _remote_eval_fetch(worker, pkg_expr) + end + return _run_with_heartbeat(session, worker, pkg_expr, "Pkg.$action", progress_cb) catch e _handle_worker_crash!(session, e) return (error = "Pkg.$action failed — $(_crash_message(e))", diff --git a/src/tools.jl b/src/tools.jl index 5493e4d..57e753f 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -376,7 +376,7 @@ Examples: required = false ) ], - handler = params -> begin + handler = (params, ctx) -> begin try action_lower = _validate_action(params, ["add", "rm", "status", "update", "instantiate", "resolve", "test", "develop", "free"]) @@ -401,7 +401,12 @@ Examples: end session_name = get(params, "session", nothing) - result = run_pkg_action_on_worker(action_lower, pkg_list; session_name=session_name) + # Emit MCP progress for the long-running actions. send_progress is a + # no-op when the client supplied no progressToken, so this is safe to + # always wire up. The status line is also teed to the log viewer. + progress_cb = (n, msg) -> ModelContextProtocol.send_progress(ctx, n; message=msg) + result = run_pkg_action_on_worker(action_lower, pkg_list; + session_name=session_name, progress_cb=progress_cb) if result.error !== nothing return TextContent(text = "Error during Pkg.$action_lower:\n$(result.error)") diff --git a/src/worker.jl b/src/worker.jl index e7be7e6..28d5b8d 100644 --- a/src/worker.jl +++ b/src/worker.jl @@ -163,6 +163,67 @@ function _start_output_drain!(session::SessionState, w::Malt.Worker) return nothing end +""" + _tee_status!(session::SessionState, msg::String) + +Route a single status line the same safe ways as drained worker output: the MCP +server's stderr, the session's `recent_output` ring, and the live log viewer when +attached. Never the stdout transport. Used by `_run_with_heartbeat` so long-op +progress is visible in the log viewer and the `session/log` resource even when the +client does not render `notifications/progress`. +""" +function _tee_status!(session::SessionState, msg::String) + tagged = "[worker:$(session.name):progress] " * msg + try; println(stderr, tagged); catch; end + try + push!(session.recent_output, msg) + length(session.recent_output) > MAX_RECENT_OUTPUT && popfirst!(session.recent_output) + catch; end + try + if LOG_VIEWER.log_io !== nothing + println(LOG_VIEWER.log_io, tagged); flush(LOG_VIEWER.log_io) + end + catch; end + return nothing +end + +""" + _run_with_heartbeat(session, worker, expr, label, progress_cb; interval=2.0) -> Any + +Evaluate `expr` on `worker` while emitting a heartbeat every `interval` seconds for +as long as it keeps running. Each heartbeat tees a status line through +`_tee_status!` and calls `progress_cb(n, message)` (used to emit an MCP +`notifications/progress`). Returns the worker result; a worker failure is unwrapped +(`_unwrap`) and rethrown so the caller can classify it. + +The eval runs on a child task while the heartbeat runs on the calling task, so only +the calling task ever writes the transport. The heartbeat stops before this returns, +so the response write that follows cannot interleave with a notification write. +""" +function _run_with_heartbeat(session::SessionState, worker::Malt.Worker, expr, + label::AbstractString, progress_cb; interval::Real=2.0) + task = @async _remote_eval_fetch(worker, expr) + n = 0 + t0 = time() + while timedwait(() -> istaskdone(task), interval; pollint=0.1) === :timed_out + n += 1 + msg = "$label: still running ($(round(Int, time() - t0))s)" + _tee_status!(session, msg) + try; progress_cb(n, msg); catch; end + end + result = try + fetch(task) + catch e + throw(_unwrap(e)) + end + if n > 0 + msg = "$label: done ($(round(Int, time() - t0))s)" + _tee_status!(session, msg) + try; progress_cb(n + 1, msg); catch; end + end + return result +end + """ ensure_worker!(session::SessionState; _retry_without_revise::Bool=false) -> Malt.Worker diff --git a/test/test_competitive_features.jl b/test/test_competitive_features.jl index f3f26b6..c0f1bf9 100644 --- a/test/test_competitive_features.jl +++ b/test/test_competitive_features.jl @@ -267,3 +267,43 @@ end @test s.socket_path === nothing end end + +@testset "Progress heartbeat for long ops" begin + MCP = AgentREPL.ModelContextProtocol + + @testset "Heartbeat fires and tees to the log ring" begin + s = AgentREPL.create_session!("progress-hb") + try + w = AgentREPL.ensure_worker!(s) + ticks = Tuple{Int,String}[] + result = AgentREPL._run_with_heartbeat(s, w, :(sleep(0.7); 7), "TestOp", + (n, msg) -> push!(ticks, (n, msg)); interval=0.2) + @test result == 7 + @test length(ticks) >= 2 + @test occursin("still running", ticks[1][2]) + @test occursin("done", ticks[end][2]) + @test any(l -> occursin("TestOp", l), s.recent_output) # teed to the ring + finally + try; AgentREPL.destroy_session!("progress-hb"); catch; end + end + end + + @testset "Heartbeat emits MCP notifications/progress" begin + s = AgentREPL.create_session!("progress-wire") + try + w = AgentREPL.ensure_worker!(s) + server = MCP.mcp_server(name="hb", version="0.0.0") + buf = IOBuffer() + server.transport = MCP.StdioTransport(output=buf) + ctx = MCP.RequestContext(server=server, progress_token="hb-1") + AgentREPL._run_with_heartbeat(s, w, :(sleep(0.7); nothing), "WireOp", + (n, msg) -> MCP.send_progress(ctx, n; message=msg); interval=0.2) + wire = String(take!(buf)) + @test occursin("notifications/progress", wire) + @test occursin("hb-1", wire) + @test occursin("WireOp", wire) + finally + try; AgentREPL.destroy_session!("progress-wire"); catch; end + end + end +end