diff --git a/lua/claudecode/config.lua b/lua/claudecode/config.lua index a4fd436..427e8f4 100644 --- a/lua/claudecode/config.lua +++ b/lua/claudecode/config.lua @@ -10,6 +10,7 @@ local M = {} M.defaults = { port_range = { min = 10000, max = 65535 }, auto_start = true, + bind_address = "127.0.0.1", terminal_cmd = nil, env = {}, -- Custom environment variables for Claude terminal log_level = "info", diff --git a/lua/claudecode/server/tcp.lua b/lua/claudecode/server/tcp.lua index 4aac69e..62e9ef0 100644 --- a/lua/claudecode/server/tcp.lua +++ b/lua/claudecode/server/tcp.lua @@ -77,7 +77,7 @@ function M.create_server(config, callbacks, auth_token) on_error = callbacks.on_error or function() end, } - local bind_success, bind_err = tcp_server:bind("127.0.0.1", port) + local bind_success, bind_err = tcp_server:bind(config.bind_address, port) if not bind_success then tcp_server:close() return nil, "Failed to bind to port " .. port .. ": " .. (bind_err or "unknown error") diff --git a/lua/claudecode/types.lua b/lua/claudecode/types.lua index 2acc365..b2f990d 100644 --- a/lua/claudecode/types.lua +++ b/lua/claudecode/types.lua @@ -106,6 +106,7 @@ ---@field auto_start boolean ---@field terminal_cmd string|nil ---@field env table +---@field bind_address string ---@field log_level ClaudeCodeLogLevel ---@field track_selection boolean ---@field focus_after_send boolean diff --git a/tests/unit/config_spec.lua b/tests/unit/config_spec.lua index dafc925..0bec839 100644 --- a/tests/unit/config_spec.lua +++ b/tests/unit/config_spec.lua @@ -299,5 +299,28 @@ describe("Configuration", function() expect(tostring(err)).to_match("must be a string or function") end) + it("should default bind_address to 127.0.0.1 for security", function() + -- Critical security requirement: must default to localhost-only to prevent + -- external connections to the WebSocket server + expect(config.defaults.bind_address).to_be("127.0.0.1") + end) + + it("should preserve 127.0.0.1 bind_address when no user config is provided", function() + local final_config = config.apply(nil) + expect(final_config.bind_address).to_be("127.0.0.1") + end) + + it("should preserve 127.0.0.1 bind_address when user config does not specify it", function() + local user_config = { log_level = "debug" } + local final_config = config.apply(user_config) + expect(final_config.bind_address).to_be("127.0.0.1") + end) + + it("should allow custom bind_address to be configured", function() + local user_config = { bind_address = "0.0.0.0" } + local final_config = config.apply(user_config) + expect(final_config.bind_address).to_be("0.0.0.0") + end) + teardown() end) diff --git a/tests/unit/server/tcp_spec.lua b/tests/unit/server/tcp_spec.lua index 83a96e3..6b907f3 100644 --- a/tests/unit/server/tcp_spec.lua +++ b/tests/unit/server/tcp_spec.lua @@ -98,6 +98,65 @@ describe("TCP server disconnect handling", function() expect(server.clients[client.id]).to_be_nil() end) + it("should bind the server to config.bind_address", function() + local bind_calls = {} + local original_new_tcp = vim.loop.new_tcp + vim.loop.new_tcp = function() + local handle = original_new_tcp() + local original_bind = handle.bind + handle.bind = function(self, host, port_arg) + table.insert(bind_calls, host) + return original_bind(self, host, port_arg) + end + return handle + end + + local callbacks = { + on_message = function() end, + on_connect = function() end, + on_disconnect = function() end, + on_error = function() end, + } + + local config = { port_range = { min = 10000, max = 10000 }, bind_address = "127.0.0.1" } + local server, err = tcp.create_server(config, callbacks, nil) + vim.loop.new_tcp = original_new_tcp + + assert.is_nil(err) + assert.is_table(server) + -- The last bind call is from create_server (after find_available_port's test bind) + assert.are.equal("127.0.0.1", bind_calls[#bind_calls]) + end) + + it("should use a custom bind_address from config when specified", function() + local bind_calls = {} + local original_new_tcp = vim.loop.new_tcp + vim.loop.new_tcp = function() + local handle = original_new_tcp() + local original_bind = handle.bind + handle.bind = function(self, host, port_arg) + table.insert(bind_calls, host) + return original_bind(self, host, port_arg) + end + return handle + end + + local callbacks = { + on_message = function() end, + on_connect = function() end, + on_disconnect = function() end, + on_error = function() end, + } + + local config = { port_range = { min = 10000, max = 10000 }, bind_address = "0.0.0.0" } + local server, err = tcp.create_server(config, callbacks, nil) + vim.loop.new_tcp = original_new_tcp + + assert.is_nil(err) + assert.is_table(server) + assert.are.equal("0.0.0.0", bind_calls[#bind_calls]) + end) + it("should only call on_disconnect once if multiple disconnect paths fire", function() client_manager.process_data = function(cl, data, on_message, on_close, on_error, auth_token) on_close(cl, 1000, "bye")