Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions internal/cmn/templatefuncs/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright (C) 2026 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later

package templatefuncs

import (
"fmt"
"reflect"
"strings"
"text/template"

sprig "github.com/go-task/slim-sprig/v3"
)

// FuncMap returns Dagu's hermetic template function map.
//
// The map is built from slim-sprig's hermetic text functions, removes
// functions that should not be available in DAG templates, and applies
// Dagu-specific pipeline-friendly overrides.
func FuncMap() template.FuncMap {
// Start from the hermetic (no env/network/random) slim-sprig set.
m := sprig.HermeticTxtFuncMap()

// Defense-in-depth: remove any functions that should never be available in
// DAG templates. Some of these are not currently present in the hermetic
// set; keep the blocklist here so future slim-sprig changes cannot expose
// them accidentally.
for _, name := range blockedFuncs {
delete(m, name)
}

// Dagu-specific overrides. These preserve pipeline-compatible argument
// order (pipeline value as last arg) and existing behavior. Each override is
// intentional; slim-sprig defines overlapping names with different arg order
// or semantics.
m["split"] = func(sep, s string) []string {
return strings.Split(s, sep)
}
m["join"] = func(sep string, v any) (string, error) {
if v == nil {
return "", nil
}
switch elems := v.(type) {
case []string:
return strings.Join(elems, sep), nil
case []any:
strs := make([]string, len(elems))
for i, e := range elems {
strs[i] = fmt.Sprint(e)
}
return strings.Join(strs, sep), nil
default:
rv := reflect.ValueOf(v)
if rv.IsValid() && (rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array) {
strs := make([]string, rv.Len())
for i := range strs {
strs[i] = fmt.Sprint(rv.Index(i).Interface())
}
return strings.Join(strs, sep), nil
}
return "", fmt.Errorf("join: unsupported type %T", v)
}
}
m["count"] = func(v any) (int, error) {
if v == nil {
return 0, nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() { //nolint:exhaustive // unsupported kinds return an error below
case reflect.Slice, reflect.Map, reflect.Array:
return rv.Len(), nil
case reflect.String:
return rv.Len(), nil
default:
return 0, fmt.Errorf("count: unsupported type %T", v)
}
}
m["add"] = func(b, a int) int {
return a + b
}
m["empty"] = func(v any) bool {
if v == nil {
return true
}
rv := reflect.ValueOf(v)
switch rv.Kind() { //nolint:exhaustive // non-empty scalar kinds are handled by IsZero below
case reflect.String:
return rv.Len() == 0
case reflect.Slice, reflect.Map, reflect.Array:
return rv.Len() == 0
default:
return rv.IsZero()
}
}
m["upper"] = func(s string) string {
return strings.ToUpper(s)
}
m["lower"] = func(s string) string {
return strings.ToLower(s)
}
m["trim"] = func(s string) string {
return strings.TrimSpace(s)
}
m["default"] = func(def, val any) any {
if val == nil {
return def
}
rv := reflect.ValueOf(val)
switch rv.Kind() { //nolint:exhaustive // scalar zero values are handled by IsZero below
case reflect.String:
if rv.Len() == 0 {
return def
}
case reflect.Slice, reflect.Map, reflect.Array:
if rv.Len() == 0 {
return def
}
default:
if rv.IsZero() {
return def
}
}
return val
}

return m
}

// blockedFuncs are removed even from the hermetic set as defense-in-depth.
// Some names are not present in slim-sprig v3 today; keep them blocked so
// future or forked slim-sprig versions cannot expose non-hermetic helpers.
var blockedFuncs = []string{
// Environment variable access
"env", "expandenv",
// Network I/O
"getHostByName",
// Non-deterministic time
"now", "date", "dateInZone", "date_in_zone",
"dateModify", "date_modify", "mustDateModify", "must_date_modify",
"ago", "duration", "durationRound",
"unixEpoch", "toDate", "mustToDate",
"htmlDate", "htmlDateInZone",
// Crypto key generation
"genPrivateKey", "derivePassword",
"buildCustomCert", "genCA",
"genSelfSignedCert", "genSignedCert",
// Non-deterministic random
"randBytes", "randString", "randNumeric",
"randAlphaNum", "randAlpha", "randAscii", "randInt",
"uuidv4",
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// BlockedFuncNames returns the names removed from the hermetic slim-sprig
// function map.
func BlockedFuncNames() []string {
return append([]string(nil), blockedFuncs...)
}
37 changes: 37 additions & 0 deletions internal/cmn/templatefuncs/functions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (C) 2026 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later

package templatefuncs

import (
"bytes"
"testing"
"text/template"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestFuncMapJoinRejectsUnsupportedInput(t *testing.T) {
t.Parallel()

tmpl, err := template.New("test").Funcs(FuncMap()).Parse(`{{ . | join "," }}`)
require.NoError(t, err)

var out bytes.Buffer
err = tmpl.Execute(&out, map[string]string{"a": "b"})
require.Error(t, err)
assert.Contains(t, err.Error(), "join: unsupported type map[string]string")
assert.Empty(t, out.String())
}

func TestFuncMapCountNilIsZero(t *testing.T) {
t.Parallel()

tmpl, err := template.New("test").Funcs(FuncMap()).Parse(`{{ count . }}`)
require.NoError(t, err)

var out bytes.Buffer
require.NoError(t, tmpl.Execute(&out, nil))
assert.Equal(t, "0", out.String())
}
20 changes: 11 additions & 9 deletions internal/core/spec/step_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"
gotemplate "text/template"

"github.com/dagucloud/dagu/internal/cmn/templatefuncs"
"github.com/dagucloud/dagu/internal/core"
"github.com/dagucloud/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
Expand Down Expand Up @@ -606,17 +607,18 @@ func renderCustomStepTemplateValue(stepTypeName string, value any, data map[stri
}

func renderCustomStepTemplateString(stepTypeName string, text string, data map[string]any) (string, error) {
funcs := templatefuncs.FuncMap()
funcs["json"] = func(v any) (string, error) {
raw, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(raw), nil
}

tmpl, err := gotemplate.New(stepTypeName).
Option("missingkey=error").
Funcs(gotemplate.FuncMap{
"json": func(v any) (string, error) {
raw, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(raw), nil
},
}).
Funcs(funcs).
Parse(text)
if err != nil {
return "", fmt.Errorf("failed to parse template string: %w", err)
Expand Down
131 changes: 131 additions & 0 deletions internal/core/spec/step_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,137 @@ steps:
assert.Contains(t, err.Error(), `fields "with" and "config" cannot be used together`)
}

func TestCustomStepTypes_TemplateSupportsHermeticFunctions(t *testing.T) {
t.Parallel()

dag, err := LoadYAML(context.Background(), []byte(`
name: custom-step-template-functions
step_types:
format_message:
type: command
input_schema:
type: object
additionalProperties: false
required: [message]
properties:
message:
type: string
fallback:
type: string
default: ""
template:
exec:
command: /bin/echo
args:
- '{{ .input.message | trim | upper | replace "HELLO" "HI" }}'
- '{{ list "b" "a" "b" | uniq | sortAlpha | join "," }}'
- '{{ .input.fallback | default "fallback" }}'
steps:
- type: format_message
config:
message: " hello "
`))
require.NoError(t, err)
require.Len(t, dag.Steps, 1)

step := dag.Steps[0]
require.Len(t, step.Commands, 1)
assert.Equal(t, []string{"HI", "a,b", "fallback"}, step.Commands[0].Args)
assert.Equal(t, "format_message", step.ExecutorConfig.Metadata["custom_type"])
}

func TestCustomStepTypes_TemplateKeepsJSONHelper(t *testing.T) {
t.Parallel()

dag, err := LoadYAML(context.Background(), []byte(`
name: custom-step-json-helper
step_types:
emit:
type: command
input_schema:
type: object
additionalProperties: false
required: [message]
properties:
message:
type: string
template:
exec:
command: /bin/echo
args:
- '{{ json .input.message }}'
steps:
- type: emit
config:
message: 'hello "quoted" world'
`))
require.NoError(t, err)
require.Len(t, dag.Steps, 1)

step := dag.Steps[0]
require.Len(t, step.Commands, 1)
assert.Equal(t, []string{`"hello \"quoted\" world"`}, step.Commands[0].Args)
}

func TestCustomStepTypes_TemplateRejectsBlockedFunctions(t *testing.T) {
t.Parallel()

_, err := LoadYAML(context.Background(), []byte(`
name: custom-step-template-blocked-functions
step_types:
stamp:
type: command
input_schema:
type: object
additionalProperties: false
properties: {}
template:
exec:
command: /bin/echo
args:
- '{{ now }}'
steps:
- type: stamp
`))
require.Error(t, err)
assert.Contains(t, err.Error(), `function "now" not defined`)
}

func TestCustomStepTypes_HarnessCommandCanUseTypedInput(t *testing.T) {
t.Parallel()

dag, err := LoadYAML(context.Background(), []byte(`
name: custom-step-harness-typed-input
step_types:
codex_task:
type: harness
input_schema:
type: object
additionalProperties: false
required: [prompt]
properties:
prompt:
type: string
template:
command:
$input: prompt
config:
provider: codex
steps:
- type: codex_task
config:
prompt: 'Review "quoted" text'
`))
require.NoError(t, err)
require.Len(t, dag.Steps, 1)

step := dag.Steps[0]
assert.Equal(t, "harness", step.ExecutorConfig.Type)
require.Len(t, step.Commands, 1)
assert.Equal(t, `Review "quoted" text`, step.Commands[0].CmdWithArgs)
assert.Equal(t, "codex_task", step.ExecutorConfig.Metadata["custom_type"])
}

func TestCustomStepTypes_RuntimeVariableInputsDeferSchemaValidation(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion internal/intg/queue/proc_liveness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ max_active_runs: 1
steps:
- name: echo
command: echo hello
`, WithProcConfig(50*time.Millisecond, 50*time.Millisecond, 100*time.Millisecond), WithZombieConfig(50*time.Millisecond, 1)).
`, WithProcConfig(queueTestProcHeartbeatInterval, queueTestProcHeartbeatInterval, queueTestProcStaleThreshold), WithZombieConfig(50*time.Millisecond, 3)).
Enqueue(1)
defer f.Stop()

Expand Down
Loading