Skip to content
104 changes: 54 additions & 50 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"path/filepath"
"path"
"reflect"
"slices"
"strings"
Expand Down Expand Up @@ -601,87 +602,84 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
return result, nil
}

// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
func LoadPromptDir(r api.Registry, dir string, namespace string) {
useDefaultDir := false
if dir == "" {
dir = "./prompts"
useDefaultDir = true
// LoadPromptDirFromFS loads prompts and partials from a filesystem for the given namespace.
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
// The dir parameter specifies the directory within the filesystem where prompts are located.
func LoadPromptDirFromFS(r api.Registry, fsys fs.FS, dir, namespace string) {
if fsys == nil {
panic(errors.New("no prompt filesystem provided"))
}

path, err := filepath.Abs(dir)
if err != nil {
if !useDefaultDir {
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
}
slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir)
return
if _, err := fs.Stat(fsys, dir); err != nil {
panic(fmt.Errorf("failed to access prompt directory %q in filesystem: %w", dir, err))
}

if _, err := os.Stat(path); os.IsNotExist(err) {
if !useDefaultDir {
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
}
slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir)
return
}

loadPromptDir(r, path, namespace)
}

// loadPromptDir recursively loads prompts and partials from the directory.
func loadPromptDir(r api.Registry, dir string, namespace string) {
entries, err := os.ReadDir(dir)
entries, err := fs.ReadDir(fsys, dir)
if err != nil {
panic(fmt.Errorf("failed to read prompt directory structure: %w", err))
}

for _, entry := range entries {
filename := entry.Name()
path := filepath.Join(dir, filename)
filePath := path.Join(dir, filename)
if entry.IsDir() {
loadPromptDir(r, path, namespace)
LoadPromptDirFromFS(r, fsys, filePath, namespace)
} else if strings.HasSuffix(filename, ".prompt") {
if strings.HasPrefix(filename, "_") {
partialName := strings.TrimSuffix(filename[1:], ".prompt")
source, err := os.ReadFile(path)
source, err := fs.ReadFile(fsys, filePath)
if err != nil {
slog.Error("Failed to read partial file", "error", err)
continue
}
r.RegisterPartial(partialName, string(source))
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path)
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", filePath)
} else {
LoadPrompt(r, dir, filename, namespace)
LoadPromptFromFS(r, fsys, dir, filename, namespace)
}
}
}
}

// LoadPrompt loads a single prompt into the registry.
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
// LoadPromptFromFS loads a single prompt from a filesystem into the registry.
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
// The dir parameter specifies the directory within the filesystem where the prompt is located.
func LoadPromptFromFS(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt {
name := strings.TrimSuffix(filename, ".prompt")
name, variant, _ := strings.Cut(name, ".")

sourceFile := filepath.Join(dir, filename)
source, err := os.ReadFile(sourceFile)
sourceFile := path.Join(dir, filename)
source, err := fs.ReadFile(fsys, sourceFile)
if err != nil {
slog.Error("Failed to read prompt file", "file", sourceFile, "error", err)
return nil
}

p, err := LoadPromptFromSource(r, string(source), name, namespace)
if err != nil {
slog.Error("Failed to load prompt", "file", sourceFile, "error", err)
return nil
}

slog.Debug("Registered Dotprompt", "name", p.Name(), "file", sourceFile)
return p
}

// LoadPromptFromSource loads a prompt from raw .prompt file content.
// The source parameter should contain the complete .prompt file text (frontmatter + template).
// The name parameter is the prompt name (may include variant suffix like "myPrompt.variant").
func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Prompt, error) {
name, variant, _ := strings.Cut(name, ".")

dp := r.Dotprompt()

parsedPrompt, err := dp.Parse(string(source))
parsedPrompt, err := dp.Parse(source)
if err != nil {
slog.Error("Failed to parse file as dotprompt", "file", sourceFile, "error", err)
return nil
return nil, fmt.Errorf("failed to parse dotprompt: %w", err)
}

metadata, err := dp.RenderMetadata(string(source), &parsedPrompt.PromptMetadata)
metadata, err := dp.RenderMetadata(source, &parsedPrompt.PromptMetadata)
if err != nil {
slog.Error("Failed to render dotprompt metadata", "file", sourceFile, "error", err)
return nil
return nil, fmt.Errorf("failed to render dotprompt metadata: %w", err)
}

toolRefs := make([]ToolRef, len(metadata.Tools))
Expand Down Expand Up @@ -765,17 +763,15 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {

dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
if err != nil {
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
return nil
return nil, fmt.Errorf("failed to convert prompt template to messages: %w", err)
}

var systemText string
var nonSystemMessages []*Message
for _, dpMsg := range dpMessages {
parts, err := convertToPartPointers(dpMsg.Content)
if err != nil {
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
return nil
return nil, fmt.Errorf("failed to convert message parts: %w", err)
}

role := Role(dpMsg.Role)
Expand Down Expand Up @@ -809,9 +805,17 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {

prompt := DefinePrompt(r, key, promptOpts...)

slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile)
return prompt, nil
}

return prompt
// LoadPromptDir loads prompts and partials from a directory on the local filesystem.
func LoadPromptDir(r api.Registry, dir string, namespace string) {
LoadPromptDirFromFS(r, os.DirFS(dir), ".", namespace)
}

// LoadPrompt loads a single prompt from a directory on the local filesystem into the registry.
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
return LoadPromptFromFS(r, os.DirFS(dir), ".", filename, namespace)
}

// promptKey generates a unique key for the prompt in the registry.
Expand Down
Loading
Loading