Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func NewAgentLoop(b *chat.Hub, provider providers.LLMProvider, model string, max
}
reg.Register(fsTool)

reg.Register(tools.NewExecTool(60))
reg.Register(tools.NewExecToolWithWorkspace(60, workspace))
reg.Register(tools.NewWebTool())
reg.Register(tools.NewSpawnTool())
if scheduler != nil {
Expand Down
33 changes: 33 additions & 0 deletions internal/agent/loop_exec_workspace_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package agent

import (
"context"
"path/filepath"
"testing"

"github.com/local/picobot/internal/chat"
"github.com/local/picobot/internal/providers"
)

func TestNewAgentLoop_ConfiguresExecToolToWorkspace(t *testing.T) {
ws := t.TempDir()
hub := chat.NewHub(1)
provider := providers.NewStubProvider()

ag := NewAgentLoop(hub, provider, provider.GetDefaultModel(), 3, ws, nil)

execTool := ag.tools.Get("exec")
if execTool == nil {
t.Fatalf("exec tool not registered")
}

out, err := execTool.Execute(context.Background(), map[string]interface{}{
"cmd": []interface{}{"pwd"},
})
if err != nil {
t.Fatalf("exec failed: %v", err)
}
if filepath.Clean(out) != filepath.Clean(ws) {
t.Fatalf("exec working dir mismatch: got %q want %q", out, ws)
}
}
45 changes: 45 additions & 0 deletions internal/agent/tools/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
Expand All @@ -14,6 +15,7 @@ import (
// For safety:
// - prefer array form: {"cmd": ["ls", "-la"]}
// - string form (shell) is disallowed to avoid shell injection
// - safe allowlist enabled by default (opt out with PICOBOT_EXEC_ALLOW_UNSAFE=1)
// - blacklist dangerous program names (rm, sudo, dd, mkfs, shutdown, reboot)
// - arguments containing absolute paths, ~ or .. are rejected
// - optional allowedDir enforces a working directory
Expand Down Expand Up @@ -63,13 +65,39 @@ var dangerous = map[string]struct{}{
"reboot": {},
}

// Default safe allowlist. Set PICOBOT_EXEC_ALLOW_UNSAFE=1 to bypass this list.
// Shell-capable binaries (e.g. git/find/rg) are intentionally excluded.
var safeExecAllowlist = map[string]struct{}{
"cat": {},
"date": {},
"echo": {},
"grep": {},
"head": {},
"ls": {},
"pwd": {},
"sleep": {},
"stat": {},
"tail": {},
"true": {},
"false": {},
"uname": {},
"wc": {},
"whoami": {},
}

func isDangerousProg(prog string) bool {
base := filepath.Base(prog)
base = strings.ToLower(base)
_, ok := dangerous[base]
return ok
}

func isSafeAllowedProg(prog string) bool {
base := strings.ToLower(filepath.Base(prog))
_, ok := safeExecAllowlist[base]
return ok
}

func hasUnsafeArg(s string) bool {
if strings.HasPrefix(s, "/") || strings.HasPrefix(s, "~") || strings.Contains(s, "..") {
return true
Expand Down Expand Up @@ -106,6 +134,14 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
}

prog := argv[0]
if !isEnvTrue("PICOBOT_EXEC_ALLOW_UNSAFE") {
if strings.Contains(prog, "/") || strings.Contains(prog, "\\") {
return "", fmt.Errorf("exec: program path %q is disallowed; use a program name from safe allowlist", prog)
}
if !isSafeAllowedProg(prog) {
return "", fmt.Errorf("exec: program '%s' is not in safe allowlist; set PICOBOT_EXEC_ALLOW_UNSAFE=1 to override", prog)
}
}
if isDangerousProg(prog) {
return "", fmt.Errorf("exec: program '%s' is disallowed", prog)
}
Expand Down Expand Up @@ -135,3 +171,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
out = strings.TrimRight(out, "\n")
return out, nil
}

func isEnvTrue(key string) bool {
switch strings.ToLower(strings.TrimSpace(os.Getenv(key))) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
61 changes: 61 additions & 0 deletions internal/agent/tools/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)

Expand Down Expand Up @@ -63,3 +64,63 @@ func TestExecTimeout(t *testing.T) {
t.Fatalf("expected timeout error")
}
}

func TestExecRejectsProgramPathByDefault(t *testing.T) {
e := NewExecTool(2)
_, err := e.Execute(context.Background(), map[string]interface{}{"cmd": []interface{}{"./script.sh"}})
if err == nil {
t.Fatalf("expected program path rejection")
}
if !strings.Contains(err.Error(), "program path") {
t.Fatalf("unexpected error: %v", err)
}
}

func TestExecRejectsNonAllowlistedProgramByDefault(t *testing.T) {
e := NewExecTool(2)
_, err := e.Execute(context.Background(), map[string]interface{}{"cmd": []interface{}{"sh", "-c", "echo hi"}})
if err == nil {
t.Fatalf("expected non-allowlisted program rejection")
}
if !strings.Contains(err.Error(), "safe allowlist") {
t.Fatalf("unexpected error: %v", err)
}
}

func TestExecUnsafeOverrideAllowsNonAllowlistedProgram(t *testing.T) {
t.Setenv("PICOBOT_EXEC_ALLOW_UNSAFE", "1")
e := NewExecTool(2)
out, err := e.Execute(context.Background(), map[string]interface{}{"cmd": []interface{}{"sh", "-c", "echo hi"}})
if err != nil {
t.Fatalf("expected command to pass with unsafe override: %v", err)
}
if out != "hi" {
t.Fatalf("unexpected output: %q", out)
}
}

func TestExecRejectsGitAliasBypassByDefault(t *testing.T) {
e := NewExecTool(2)
_, err := e.Execute(context.Background(), map[string]interface{}{
"cmd": []interface{}{"git", "-c", "alias.pwn=!echo bypassed", "pwn"},
})
if err == nil {
t.Fatalf("expected git alias bypass payload to be rejected")
}
if !strings.Contains(err.Error(), "safe allowlist") {
t.Fatalf("unexpected error: %v", err)
}
}

func TestExecRejectsFindExecBypassByDefault(t *testing.T) {
e := NewExecTool(2)
_, err := e.Execute(context.Background(), map[string]interface{}{
"cmd": []interface{}{"find", ".", "-maxdepth", "0", "-exec", "sh", "-c", "echo via_find", ";"},
})
if err == nil {
t.Fatalf("expected find -exec payload to be rejected")
}
if !strings.Contains(err.Error(), "safe allowlist") {
t.Fatalf("unexpected error: %v", err)
}
}