diff --git a/RFC.md b/RFC.md new file mode 100644 index 0000000..3b65350 --- /dev/null +++ b/RFC.md @@ -0,0 +1,58 @@ +# API RFC Notes + +## Handler Metadata Example + +```go +type createWidgetHandler struct{} + +func (h *createWidgetHandler) Describe() api.RouteDescription { + return api.RouteDescription{ + StatusCode: http.StatusCreated, + RequestBody: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + }, + Response: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string"}, + }, + }, + } +} + +func (h *createWidgetHandler) OperationID() string { return "widgets_create" } +func (h *createWidgetHandler) Tags() []string { return []string{"widgets"} } +func (h *createWidgetHandler) Summary() string { return "Create widget" } +func (h *createWidgetHandler) Description() string { return "Creates a widget." } + +func (h *createWidgetHandler) Render() api.RenderHints { + return api.RenderHints{ + Kind: "form", + Fields: []api.FieldHint{ + {Name: "name", Label: "Name", Type: "text", Required: true}, + }, + Actions: []api.ActionHint{ + {Name: "preview", Label: "Preview", Method: http.MethodGet}, + }, + } +} + +func (g *widgetsGroup) Describe() []api.RouteDescription { + handler := &createWidgetHandler{} + return []api.RouteDescription{ + { + Method: http.MethodPost, + Path: "/", + Handler: handler, + }, + } +} +``` + +When a `RouteDescription` carries a handler that implements `api.Describable` +and/or `api.Renderable`, `SpecBuilder` uses that metadata to populate the +OpenAPI `operationId`, `tags`, `summary`, `description`, and the +`x-render-hints` vendor extension. diff --git a/api_describable_test.go b/api_describable_test.go new file mode 100644 index 0000000..b97ed92 --- /dev/null +++ b/api_describable_test.go @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/gin-gonic/gin" + + api "dappco.re/go/api" +) + +type describableSpecGroup struct { + name string + basePath string + descs []api.RouteDescription +} + +func (g *describableSpecGroup) Name() string { return g.name } +func (g *describableSpecGroup) BasePath() string { return g.basePath } +func (g *describableSpecGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (g *describableSpecGroup) Describe() []api.RouteDescription { return g.descs } + +type describableHandler struct { + desc api.RouteDescription + operationID string + tags []string + summary string + longDescription string +} + +func (h *describableHandler) Describe() api.RouteDescription { + if h == nil { + return api.RouteDescription{} + } + return h.desc +} + +func (h *describableHandler) OperationID() string { + if h == nil { + return "" + } + return h.operationID +} + +func (h *describableHandler) Tags() []string { + if h == nil { + return nil + } + return h.tags +} + +func (h *describableHandler) Summary() string { + if h == nil { + return "" + } + return h.summary +} + +func (h *describableHandler) Description() string { + if h == nil { + return "" + } + return h.longDescription +} + +func buildDescribableOperation(t *testing.T, group api.RouteGroup, path, method string) map[string]any { + t.Helper() + + builder := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + data, err := builder.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + pathItem, ok := paths[path].(map[string]any) + if !ok { + t.Fatalf("expected path %q in spec", path) + } + + operation, ok := pathItem[method].(map[string]any) + if !ok { + t.Fatalf("expected %s operation on %q", method, path) + } + + return operation +} + +func TestDescribable_Good_HandlerMetadataFlowsToOpenAPI(t *testing.T) { + handler := &describableHandler{ + desc: api.RouteDescription{ + StatusCode: http.StatusCreated, + RequestBody: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + }, + Response: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string"}, + }, + }, + }, + operationID: "widgets_create", + tags: []string{"widgets", "catalog"}, + summary: "Create widget", + longDescription: "Creates a widget and returns the stored record.", + } + + group := &describableSpecGroup{ + name: "widgets", + basePath: "/api/widgets", + descs: []api.RouteDescription{ + { + Method: http.MethodPost, + Path: "/", + Handler: handler, + }, + }, + } + + operation := buildDescribableOperation(t, group, "/api/widgets", "post") + + if got := operation["operationId"]; got != "widgets_create" { + t.Fatalf("expected handler operationId, got %v", got) + } + if got := operation["summary"]; got != "Create widget" { + t.Fatalf("expected handler summary, got %v", got) + } + if got := operation["description"]; got != "Creates a widget and returns the stored record." { + t.Fatalf("expected handler description, got %v", got) + } + + tags, ok := operation["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", operation["tags"]) + } + if len(tags) != 2 || tags[0] != "widgets" || tags[1] != "catalog" { + t.Fatalf("expected handler tags, got %v", tags) + } + + requestBody := operation["requestBody"].(map[string]any) + content := requestBody["content"].(map[string]any) + schema := content["application/json"].(map[string]any)["schema"].(map[string]any) + properties := schema["properties"].(map[string]any) + if _, ok := properties["name"]; !ok { + t.Fatal("expected request body schema from handler Describe") + } + + responses := operation["responses"].(map[string]any) + if _, ok := responses["201"]; !ok { + t.Fatal("expected status code from handler Describe") + } +} + +func TestDescribable_Bad_MissingHandlerMetadataFallsBackSafely(t *testing.T) { + group := &describableSpecGroup{ + name: "widgets", + basePath: "/api/widgets", + descs: []api.RouteDescription{ + { + Method: http.MethodGet, + Path: "/status", + Summary: "Widget status", + Description: "Returns widget availability.", + Tags: []string{"status"}, + Handler: &describableHandler{}, + }, + }, + } + + operation := buildDescribableOperation(t, group, "/api/widgets/status", "get") + + if got := operation["operationId"]; got != "get_api_widgets_status" { + t.Fatalf("expected generated operationId fallback, got %v", got) + } + if got := operation["summary"]; got != "Widget status" { + t.Fatalf("expected route summary fallback, got %v", got) + } + if got := operation["description"]; got != "Returns widget availability." { + t.Fatalf("expected route description fallback, got %v", got) + } + + tags, ok := operation["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", operation["tags"]) + } + if len(tags) != 1 || tags[0] != "status" { + t.Fatalf("expected route tag fallback, got %v", tags) + } +} + +func TestDescribable_Ugly_NilHandlerIsIgnored(t *testing.T) { + group := &describableSpecGroup{ + name: "widgets", + basePath: "/api/widgets", + descs: []api.RouteDescription{ + { + Method: http.MethodGet, + Path: "/status", + Handler: (*describableHandler)(nil), + }, + }, + } + + operation := buildDescribableOperation(t, group, "/api/widgets/status", "get") + + if got := operation["operationId"]; got != "get_api_widgets_status" { + t.Fatalf("expected generated operationId with nil handler, got %v", got) + } + + tags, ok := operation["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", operation["tags"]) + } + if len(tags) != 1 || tags[0] != "widgets" { + t.Fatalf("expected group-name tag fallback, got %v", tags) + } +} diff --git a/api_renderable_test.go b/api_renderable_test.go new file mode 100644 index 0000000..8f43998 --- /dev/null +++ b/api_renderable_test.go @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/gin-gonic/gin" + + api "dappco.re/go/api" +) + +type renderableSpecGroup struct { + name string + basePath string + descs []api.RouteDescription +} + +func (g *renderableSpecGroup) Name() string { return g.name } +func (g *renderableSpecGroup) BasePath() string { return g.basePath } +func (g *renderableSpecGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (g *renderableSpecGroup) Describe() []api.RouteDescription { return g.descs } + +type renderableHandler struct { + hints api.RenderHints +} + +func (h *renderableHandler) Render() api.RenderHints { + if h == nil { + return api.RenderHints{} + } + return h.hints +} + +func buildRenderableOperation(t *testing.T, group api.RouteGroup, path, method string) map[string]any { + t.Helper() + + builder := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + data, err := builder.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + pathItem, ok := paths[path].(map[string]any) + if !ok { + t.Fatalf("expected path %q in spec", path) + } + + operation, ok := pathItem[method].(map[string]any) + if !ok { + t.Fatalf("expected %s operation on %q", method, path) + } + + return operation +} + +func TestRenderable_Good_HandlerHintsFlowToOpenAPI(t *testing.T) { + group := &renderableSpecGroup{ + name: "widgets", + basePath: "/api/widgets", + descs: []api.RouteDescription{ + { + Method: http.MethodPost, + Path: "/", + Handler: &renderableHandler{ + hints: api.RenderHints{ + Kind: "form", + Fields: []api.FieldHint{ + { + Name: "name", + Label: "Name", + Type: "text", + Required: true, + Validation: map[string]any{ + "minLength": 3, + }, + }, + }, + Actions: []api.ActionHint{ + { + Name: "preview", + Label: "Preview", + Method: http.MethodGet, + Variant: "secondary", + }, + }, + }, + }, + }, + }, + } + + operation := buildRenderableOperation(t, group, "/api/widgets", "post") + + rawHints, ok := operation["x-render-hints"].(map[string]any) + if !ok { + t.Fatalf("expected x-render-hints extension, got %T", operation["x-render-hints"]) + } + if got := rawHints["kind"]; got != "form" { + t.Fatalf("expected render kind form, got %v", got) + } + + fields, ok := rawHints["fields"].([]any) + if !ok || len(fields) != 1 { + t.Fatalf("expected one render field, got %v", rawHints["fields"]) + } + field := fields[0].(map[string]any) + if got := field["name"]; got != "name" { + t.Fatalf("expected render field name, got %v", got) + } + if got := field["required"]; got != true { + t.Fatalf("expected render field required=true, got %v", got) + } + validation := field["validation"].(map[string]any) + if got := validation["minLength"]; got != float64(3) { + t.Fatalf("expected validation minLength=3, got %v", got) + } + + actions, ok := rawHints["actions"].([]any) + if !ok || len(actions) != 1 { + t.Fatalf("expected one render action, got %v", rawHints["actions"]) + } + action := actions[0].(map[string]any) + if got := action["name"]; got != "preview" { + t.Fatalf("expected render action name, got %v", got) + } +} + +func TestRenderable_Bad_EmptyHintsAreOmittedSafely(t *testing.T) { + group := &renderableSpecGroup{ + name: "widgets", + basePath: "/api/widgets", + descs: []api.RouteDescription{ + { + Method: http.MethodGet, + Path: "/status", + Handler: &renderableHandler{}, + }, + }, + } + + operation := buildRenderableOperation(t, group, "/api/widgets/status", "get") + + if _, ok := operation["x-render-hints"]; ok { + t.Fatalf("expected empty render hints to be omitted, got %v", operation["x-render-hints"]) + } +} + +func TestRenderable_Ugly_NilHandlerIsIgnored(t *testing.T) { + group := &renderableSpecGroup{ + name: "widgets", + basePath: "/api/widgets", + descs: []api.RouteDescription{ + { + Method: http.MethodGet, + Path: "/status", + Handler: (*renderableHandler)(nil), + }, + }, + } + + operation := buildRenderableOperation(t, group, "/api/widgets/status", "get") + + if _, ok := operation["x-render-hints"]; ok { + t.Fatalf("expected nil renderable handler to be ignored, got %v", operation["x-render-hints"]) + } +} diff --git a/bridge_internal_test.go b/bridge_internal_test.go new file mode 100644 index 0000000..971ebe5 --- /dev/null +++ b/bridge_internal_test.go @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "math" + "testing" +) + +// TestBridge_schemaInt_overflow_Bad verifies that uint/uint64 values exceeding +// math.MaxInt return (0, false) instead of silently wrapping to a negative int. +// +// G115 (gosec): integer overflow on coercion would let attacker-controlled +// JSON numbers >= 2^63 wrap to negative values, which downstream feeds into +// range checks / slice indices / array sizes with wrong sign. +func TestBridge_schemaInt_overflow_Bad(t *testing.T) { + tests := []struct { + name string + value any + }{ + {name: "uint64 max", value: uint64(math.MaxUint64)}, + {name: "uint64 over MaxInt", value: uint64(math.MaxInt) + 1}, + {name: "uint over MaxInt", value: uint(math.MaxInt) + 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := schemaInt(tt.value) + if ok { + t.Errorf("schemaInt(%v) returned ok=true; expected false on overflow", tt.value) + } + if got != 0 { + t.Errorf("schemaInt(%v) returned %d; expected 0 on overflow", tt.value, got) + } + }) + } +} + +// TestBridge_schemaInt_inrange_Good verifies that valid values still convert. +func TestBridge_schemaInt_inrange_Good(t *testing.T) { + tests := []struct { + name string + value any + want int + }{ + {name: "uint zero", value: uint(0), want: 0}, + {name: "uint small", value: uint(42), want: 42}, + {name: "uint64 small", value: uint64(100), want: 100}, + {name: "uint64 maxint", value: uint64(math.MaxInt), want: math.MaxInt}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := schemaInt(tt.value) + if !ok { + t.Errorf("schemaInt(%v) returned ok=false; expected true", tt.value) + } + if got != tt.want { + t.Errorf("schemaInt(%v) = %d; want %d", tt.value, got, tt.want) + } + }) + } +} + +// TestBridge_schemaInt_boundary_Ugly tests the exact MaxInt boundary — +// MaxInt itself must succeed, MaxInt+1 must fail. +func TestBridge_schemaInt_boundary_Ugly(t *testing.T) { + // uint64(MaxInt) — boundary, must succeed + if got, ok := schemaInt(uint64(math.MaxInt)); !ok || got != math.MaxInt { + t.Errorf("schemaInt(uint64(MaxInt)) = (%d, %v); want (MaxInt, true)", got, ok) + } + // uint64(MaxInt)+1 — one over boundary, must fail + if _, ok := schemaInt(uint64(math.MaxInt) + 1); ok { + t.Error("schemaInt(uint64(MaxInt)+1) returned ok=true; expected false (boundary)") + } +} diff --git a/cache_config_test.go b/cache_config_test.go new file mode 100644 index 0000000..eb8d18c --- /dev/null +++ b/cache_config_test.go @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "testing" + "time" + + api "dappco.re/go/api" +) + +// TestCacheConfig_Good_SnapshotsConfiguredEngine verifies that CacheConfig +// reflects the cache limits supplied during engine construction. +func TestCacheConfig_Good_SnapshotsConfiguredEngine(t *testing.T) { + e, err := api.New(api.WithCacheLimits(5*time.Minute, 10, 1024)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.CacheConfig() + + if !cfg.Enabled { + t.Fatal("expected cache config to be enabled") + } + if cfg.TTL != 5*time.Minute { + t.Fatalf("expected TTL %v, got %v", 5*time.Minute, cfg.TTL) + } + if cfg.MaxEntries != 10 { + t.Fatalf("expected MaxEntries 10, got %d", cfg.MaxEntries) + } + if cfg.MaxBytes != 1024 { + t.Fatalf("expected MaxBytes 1024, got %d", cfg.MaxBytes) + } +} + +// TestCacheConfig_Bad_NilEngineReturnsZeroValue verifies the nil-receiver +// guard returns an empty snapshot instead of panicking. +func TestCacheConfig_Bad_NilEngineReturnsZeroValue(t *testing.T) { + var e *api.Engine + + cfg := e.CacheConfig() + if cfg != (api.CacheConfig{}) { + t.Fatalf("expected zero-value cache config, got %+v", cfg) + } +} + +// TestCacheConfig_Ugly_UnconfiguredEngineStaysDisabled verifies that an +// engine without cache middleware reports a disabled snapshot. +func TestCacheConfig_Ugly_UnconfiguredEngineStaysDisabled(t *testing.T) { + e, err := api.New() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.CacheConfig() + if cfg.Enabled { + t.Fatal("expected cache config to remain disabled") + } + if cfg.TTL != 0 || cfg.MaxEntries != 0 || cfg.MaxBytes != 0 { + t.Fatalf("expected zero cache settings, got %+v", cfg) + } +} diff --git a/go/codegen_test.go b/go/codegen_test.go index f772ef0..9dedced 100644 --- a/go/codegen_test.go +++ b/go/codegen_test.go @@ -266,3 +266,74 @@ func TestSDKGenerator_Generate_PackageNameAccepted_Good(t *testing.T) { }) } } + +// TestSDKGenerator_Generate_PackageNameRejected_Bad verifies the regex-validation +// hardening from Mantis #322 — PackageName containing flag-injection characters +// is rejected before exec.CommandContext is reached. +func TestSDKGenerator_Generate_PackageNameRejected_Bad(t *testing.T) { + tmp := t.TempDir() + specPath := filepath.Join(tmp, "spec.yaml") + if err := os.WriteFile(specPath, []byte("openapi: 3.0.0\n"), 0o644); err != nil { + t.Fatalf("write spec: %v", err) + } + + rejects := []string{ + "foo --extra=evil", // space + flag injection + "foo;rm -rf /", // command separator + "foo bar", // bare space + "--shell-injection", // leading dash + "foo$(whoami)", // command substitution + } + for _, name := range rejects { + t.Run(name, func(t *testing.T) { + gen := &api.SDKGenerator{ + SpecPath: specPath, + OutputDir: tmp, + PackageName: name, + } + err := gen.Generate(context.Background(), "go") + if err == nil { + t.Errorf("expected rejection for PackageName=%q, got nil error", name) + return + } + if !strings.Contains(err.Error(), "package name") { + t.Errorf("expected rejection error containing 'package name', got %q", err.Error()) + } + }) + } +} + +// TestSDKGenerator_Generate_PackageNameAccepted_Good verifies legitimate names +// pass the regex; any subsequent error must NOT be the regex-rejection. +func TestSDKGenerator_Generate_PackageNameAccepted_Good(t *testing.T) { + accepts := []string{ + "foo", + "FooBar", + "foo_bar", + "foo-bar", + "Foo123", + "a", + } + tmp := t.TempDir() + specPath := filepath.Join(tmp, "spec.yaml") + if err := os.WriteFile(specPath, []byte("openapi: 3.0.0\n"), 0o644); err != nil { + t.Fatalf("write spec: %v", err) + } + for _, name := range accepts { + t.Run(name, func(t *testing.T) { + gen := &api.SDKGenerator{ + SpecPath: specPath, + OutputDir: tmp, + PackageName: name, + } + err := gen.Generate(context.Background(), "go") + // Likely fails because openapi-generator-cli isn't installed in + // CI; the error MUST NOT be the regex-rejection ("package name + // X rejected"). + if err != nil && strings.Contains(err.Error(), "package name") && + strings.Contains(err.Error(), "rejected") { + t.Errorf("name %q was unexpectedly rejected by regex: %v", name, err) + } + }) + } +} diff --git a/php/tests/Feature/TicketControllerTest.php b/php/tests/Feature/TicketControllerTest.php new file mode 100644 index 0000000..a3c5e49 --- /dev/null +++ b/php/tests/Feature/TicketControllerTest.php @@ -0,0 +1,76 @@ +create(array_merge([ + 'workspace_id' => null, + 'user_id' => null, + 'subject' => 'Private support issue', + 'message' => 'Customer-only support conversation', + 'status' => 'open', + 'priority' => 'normal', + 'metadata' => [], + 'last_replied_at' => now(), + ], $attributes)); +} + +it('TicketController_findTicket_AnonymousAccess_Bad_blocks_unscoped_lookup', function () { + $ticket = ticketControllerTestTicket(); + + $response = $this->getJson("/api/test-support/tickets/{$ticket->id}"); + + $response + ->assertStatus(403) + ->assertJsonMissing(['subject' => 'Private support issue']); +}); + +it('TicketController_findTicket_FailOpenAttempt_Ugly_logs_warning_context', function () { + $ticket = ticketControllerTestTicket(); + + Log::shouldReceive('warning') + ->once() + ->with('TicketController.findTicket fail-open attempt', \Mockery::on(function (array $context) use ($ticket): bool { + return ($context['ticket_id'] ?? null) === (string) $ticket->id + && isset($context['actor_ip']) + && ($context['route'] ?? null) === "api/test-support/tickets/{$ticket->id}"; + })); + + $this->getJson("/api/test-support/tickets/{$ticket->id}") + ->assertStatus(403); +}); + +it('TicketController_findTicket_AuthenticatedUser_Good_returns_owned_ticket', function () { + $user = User::query()->create([ + 'name' => fake()->name(), + 'email' => fake()->unique()->safeEmail(), + 'password' => 'password', + ]); + $ticket = ticketControllerTestTicket([ + 'user_id' => $user->id, + 'subject' => 'Owned support issue', + ]); + + $this->actingAs($user); + + $response = $this->getJson("/api/test-support/tickets/{$ticket->id}"); + + $response + ->assertOk() + ->assertJsonPath('data.id', $ticket->id) + ->assertJsonPath('data.subject', 'Owned support issue'); +}); diff --git a/pkg/stream/stream_group.go b/pkg/stream/stream_group.go new file mode 100644 index 0000000..f4234ce --- /dev/null +++ b/pkg/stream/stream_group.go @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: EUPL-1.2 + +// Package stream defines declarative SSE and WebSocket endpoint groups that +// can be mounted onto an api.Engine. +package stream + +import ( + "net/http" + "slices" + "strings" + + core "dappco.re/go/core" + + "github.com/gin-gonic/gin" +) + +// Protocol identifies the wire protocol a stream handler serves. +type Protocol string + +const ( + // ProtocolSSE identifies a Server-Sent Events endpoint. + ProtocolSSE Protocol = "sse" + // ProtocolWebSocket identifies a WebSocket endpoint. + ProtocolWebSocket Protocol = "websocket" +) + +// Handler describes a single stream-capable route. +// +// The protocol and path are retained as declarative metadata so callers can +// inspect mounted stream surfaces and future OpenAPI hooks can consume them. +type Handler struct { + Protocol Protocol + Method string + Path string + Handle gin.HandlerFunc +} + +// Registrar is the minimal Gin registration surface required by StreamGroup. +// Both *gin.Engine and *gin.RouterGroup satisfy this contract. +type Registrar interface { + Handle(httpMethod, relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes +} + +// StreamGroup declares a named set of SSE/WebSocket handlers. +// +// Example: +// +// var group stream.StreamGroup = stream.NewGroup( +// "system", +// stream.SSE("/events", func(c *gin.Context) {}), +// ) +type StreamGroup interface { + // Register mounts all handlers onto the supplied registrar. + Register(reg Registrar) + + // Name returns a human-readable identifier for the group. + Name() string + + // Handlers returns the group's declared handler metadata. + Handlers() []Handler +} + +// Group is a small concrete StreamGroup implementation backed by a handler +// slice. It is suitable for most SSE/WebSocket endpoint declarations. +type Group struct { + name string + handlers []Handler +} + +// NewGroup creates a StreamGroup with normalised handler metadata. +func NewGroup(name string, handlers ...Handler) *Group { + return &Group{ + name: core.Trim(name), + handlers: normaliseHandlers(handlers), + } +} + +// Name returns the group's identifier. +func (g *Group) Name() string { + if g == nil { + return "" + } + return g.name +} + +// Handlers returns a defensive copy of the group's handler metadata. +func (g *Group) Handlers() []Handler { + if g == nil || len(g.handlers) == 0 { + return nil + } + return slices.Clone(g.handlers) +} + +// Register mounts all valid handlers onto the supplied registrar. +func (g *Group) Register(reg Registrar) { + if g == nil || reg == nil { + return + } + + for _, handler := range g.handlers { + reg.Handle(handler.Method, handler.Path, handler.Handle) + } +} + +// SSE creates a GET Server-Sent Events handler descriptor. +func SSE(path string, handle gin.HandlerFunc) Handler { + return Handler{ + Protocol: ProtocolSSE, + Method: http.MethodGet, + Path: path, + Handle: handle, + } +} + +// WebSocket creates a GET WebSocket handler descriptor. +func WebSocket(path string, handle gin.HandlerFunc) Handler { + return Handler{ + Protocol: ProtocolWebSocket, + Method: http.MethodGet, + Path: path, + Handle: handle, + } +} + +func normaliseHandlers(handlers []Handler) []Handler { + if len(handlers) == 0 { + return nil + } + + out := make([]Handler, 0, len(handlers)) + for _, handler := range handlers { + handler = normaliseHandler(handler) + if !handler.valid() { + continue + } + out = append(out, handler) + } + + if len(out) == 0 { + return nil + } + + return out +} + +func normaliseHandler(handler Handler) Handler { + handler.Protocol = normaliseProtocol(handler.Protocol) + + method := strings.ToUpper(core.Trim(handler.Method)) + if method == "" { + method = http.MethodGet + } + handler.Method = method + handler.Path = normalisePath(handler.Path) + + return handler +} + +func (h Handler) valid() bool { + return h.Protocol != "" && h.Path != "" && h.Handle != nil +} + +func normaliseProtocol(protocol Protocol) Protocol { + switch strings.ToLower(core.Trim(string(protocol))) { + case "event-stream", "eventstream", "sse": + return ProtocolSSE + case "websocket", "ws": + return ProtocolWebSocket + default: + return "" + } +} + +func normalisePath(path string) string { + path = core.Trim(path) + if path == "" { + return "" + } + + trimmed := strings.Trim(path, "/") + if trimmed == "" { + return "/" + } + + return "/" + trimmed +} diff --git a/pkg/stream/stream_group_example_test.go b/pkg/stream/stream_group_example_test.go new file mode 100644 index 0000000..29de77c --- /dev/null +++ b/pkg/stream/stream_group_example_test.go @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream_test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + + api "dappco.re/go/api" + "dappco.re/go/api/pkg/stream" + + "github.com/gin-gonic/gin" +) + +func ExampleNewGroup() { + gin.SetMode(gin.TestMode) + + engine, _ := api.New() + engine.RegisterStreamGroup(stream.NewGroup( + "system", + stream.SSE("/events", func(c *gin.Context) { + c.Data(http.StatusOK, "text/event-stream", []byte("data: ready\n\n")) + }), + )) + + rec := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/events", nil) + engine.Handler().ServeHTTP(rec, req) + + _, _ = io.WriteString(os.Stdout, strings.TrimSpace(rec.Body.String())) + // Output: data: ready +} diff --git a/pkg/stream/stream_group_test.go b/pkg/stream/stream_group_test.go new file mode 100644 index 0000000..c857d87 --- /dev/null +++ b/pkg/stream/stream_group_test.go @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream_test + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + api "dappco.re/go/api" + "dappco.re/go/api/pkg/stream" + + "github.com/gin-gonic/gin" +) + +func TestStreamGroup_Good_RoundTrip(t *testing.T) { + gin.SetMode(gin.TestMode) + + group := stream.NewGroup( + "events", + stream.SSE("/events", func(c *gin.Context) { + c.Data(http.StatusOK, "text/event-stream", []byte("data: ready\n\n")) + }), + stream.WebSocket("/ws", func(c *gin.Context) { + c.Header("Upgrade", "websocket") + c.Status(http.StatusSwitchingProtocols) + }), + ) + + handlers := group.Handlers() + if len(handlers) != 2 { + t.Fatalf("expected 2 handlers, got %d", len(handlers)) + } + if handlers[0].Protocol != stream.ProtocolSSE { + t.Fatalf("expected first protocol %q, got %q", stream.ProtocolSSE, handlers[0].Protocol) + } + if handlers[0].Method != http.MethodGet { + t.Fatalf("expected first method %q, got %q", http.MethodGet, handlers[0].Method) + } + if handlers[0].Path != "/events" { + t.Fatalf("expected first path %q, got %q", "/events", handlers[0].Path) + } + if handlers[1].Protocol != stream.ProtocolWebSocket { + t.Fatalf("expected second protocol %q, got %q", stream.ProtocolWebSocket, handlers[1].Protocol) + } + if handlers[1].Path != "/ws" { + t.Fatalf("expected second path %q, got %q", "/ws", handlers[1].Path) + } + + router := gin.New() + group.Register(router) + + sseRecorder := httptest.NewRecorder() + sseReq, _ := http.NewRequest(http.MethodGet, "/events", nil) + router.ServeHTTP(sseRecorder, sseReq) + + if sseRecorder.Code != http.StatusOK { + t.Fatalf("expected SSE status 200, got %d", sseRecorder.Code) + } + if got := sseRecorder.Header().Get("Content-Type"); got != "text/event-stream" { + t.Fatalf("expected SSE content type %q, got %q", "text/event-stream", got) + } + + wsRecorder := httptest.NewRecorder() + wsReq, _ := http.NewRequest(http.MethodGet, "/ws", nil) + router.ServeHTTP(wsRecorder, wsReq) + + if wsRecorder.Code != http.StatusSwitchingProtocols { + t.Fatalf("expected WebSocket status 101, got %d", wsRecorder.Code) + } +} + +func TestStreamGroup_Bad_DropsInvalidHandlersAndClonesMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + group := stream.NewGroup( + "invalid", + stream.Handler{ + Protocol: stream.ProtocolSSE, + Method: http.MethodGet, + Path: "", + Handle: func(*gin.Context) {}, + }, + stream.Handler{ + Protocol: stream.ProtocolWebSocket, + Method: http.MethodGet, + Path: "/ws", + Handle: nil, + }, + stream.SSE("/events", func(c *gin.Context) { + c.Status(http.StatusNoContent) + }), + ) + + handlers := group.Handlers() + if len(handlers) != 1 { + t.Fatalf("expected 1 valid handler, got %d", len(handlers)) + } + + handlers[0].Path = "/mutated" + + fresh := group.Handlers() + if len(fresh) != 1 { + t.Fatalf("expected 1 fresh handler, got %d", len(fresh)) + } + if fresh[0].Path != "/events" { + t.Fatalf("expected cloned handler path %q, got %q", "/events", fresh[0].Path) + } + + router := gin.New() + group.Register(router) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/events", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("expected valid handler to remain registered, got %d", w.Code) + } +} + +func TestStreamGroup_Ugly_NormalisesWhitespaceWrappedMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + group := stream.NewGroup( + " ugly ", + stream.Handler{ + Protocol: " WS ", + Method: " get ", + Path: " /tenant/socket/ ", + Handle: func(c *gin.Context) { + c.String(http.StatusAccepted, "ok") + }, + }, + ) + + if group.Name() != "ugly" { + t.Fatalf("expected trimmed name %q, got %q", "ugly", group.Name()) + } + + handlers := group.Handlers() + if len(handlers) != 1 { + t.Fatalf("expected 1 handler, got %d", len(handlers)) + } + if handlers[0].Protocol != stream.ProtocolWebSocket { + t.Fatalf("expected normalised protocol %q, got %q", stream.ProtocolWebSocket, handlers[0].Protocol) + } + if handlers[0].Method != http.MethodGet { + t.Fatalf("expected normalised method %q, got %q", http.MethodGet, handlers[0].Method) + } + if handlers[0].Path != "/tenant/socket" { + t.Fatalf("expected normalised path %q, got %q", "/tenant/socket", handlers[0].Path) + } + + router := gin.New() + group.Register(router) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/tenant/socket", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Fatalf("expected normalised handler status 202, got %d", w.Code) + } +} + +func TestEngineRegisterStreamGroup_Good_MultiTenantRegistration(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine, err := api.New() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + engine.RegisterStreamGroup(stream.NewGroup( + "tenant-a", + stream.SSE("/tenants/a/events", func(c *gin.Context) { + c.Data(http.StatusOK, "text/event-stream", []byte("data: tenant-a\n\n")) + }), + )) + engine.RegisterStreamGroup(stream.NewGroup( + "tenant-b", + stream.SSE("/tenants/b/events", func(c *gin.Context) { + c.Data(http.StatusOK, "text/event-stream", []byte("data: tenant-b\n\n")) + }), + )) + + server := httptest.NewServer(engine.Handler()) + defer server.Close() + + for _, tc := range []struct { + path string + body string + }{ + {path: "/tenants/a/events", body: "data: tenant-a\n\n"}, + {path: "/tenants/b/events", body: "data: tenant-b\n\n"}, + } { + resp, reqErr := http.Get(server.URL + tc.path) + if reqErr != nil { + t.Fatalf("request %s failed: %v", tc.path, reqErr) + } + + func() { + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("%s: expected status 200, got %d", tc.path, resp.StatusCode) + } + if got := resp.Header.Get("Content-Type"); got != "text/event-stream" { + t.Fatalf("%s: expected content type %q, got %q", tc.path, "text/event-stream", got) + } + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + t.Fatalf("%s: read body failed: %v", tc.path, readErr) + } + if string(body) != tc.body { + t.Fatalf("%s: expected body %q, got %q", tc.path, tc.body, string(body)) + } + }() + } +} diff --git a/runtime_config_test.go b/runtime_config_test.go new file mode 100644 index 0000000..235357a --- /dev/null +++ b/runtime_config_test.go @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "slices" + "testing" + "time" + + api "dappco.re/go/api" +) + +// TestEngine_RuntimeConfig_Good_SnapshotsCurrentSettings verifies the +// aggregate runtime snapshot mirrors the current engine configuration. +func TestEngine_RuntimeConfig_Good_SnapshotsCurrentSettings(t *testing.T) { + broker := api.NewSSEBroker() + e, err := api.New( + api.WithSwagger("Runtime API", "Runtime snapshot", "1.2.3"), + api.WithSwaggerPath("/docs"), + api.WithCacheLimits(5*time.Minute, 10, 1024), + api.WithGraphQL(newTestSchema(), api.WithPlayground()), + api.WithI18n(api.I18nConfig{ + DefaultLocale: "en-GB", + Supported: []string{"en-GB", "fr"}, + }), + api.WithWSPath("/socket"), + api.WithSSE(broker), + api.WithSSEPath("/events"), + api.WithAuthentik(api.AuthentikConfig{ + Issuer: "https://auth.example.com", + ClientID: "runtime-client", + TrustedProxy: true, + PublicPaths: []string{"/public", "/docs"}, + }), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.RuntimeConfig() + + if !cfg.Swagger.Enabled { + t.Fatal("expected swagger snapshot to be enabled") + } + if cfg.Swagger.Path != "/docs" { + t.Fatalf("expected swagger path /docs, got %q", cfg.Swagger.Path) + } + if cfg.Transport.SwaggerPath != "/docs" { + t.Fatalf("expected transport swagger path /docs, got %q", cfg.Transport.SwaggerPath) + } + if cfg.Transport.GraphQLPlaygroundPath != "/graphql/playground" { + t.Fatalf("expected transport graphql playground path /graphql/playground, got %q", cfg.Transport.GraphQLPlaygroundPath) + } + if !cfg.Cache.Enabled || cfg.Cache.TTL != 5*time.Minute { + t.Fatalf("expected cache snapshot to be populated, got %+v", cfg.Cache) + } + if !cfg.GraphQL.Enabled { + t.Fatal("expected GraphQL snapshot to be enabled") + } + if cfg.GraphQL.Path != "/graphql" { + t.Fatalf("expected GraphQL path /graphql, got %q", cfg.GraphQL.Path) + } + if !cfg.GraphQL.Playground { + t.Fatal("expected GraphQL playground snapshot to be enabled") + } + if cfg.GraphQL.PlaygroundPath != "/graphql/playground" { + t.Fatalf("expected GraphQL playground path /graphql/playground, got %q", cfg.GraphQL.PlaygroundPath) + } + if cfg.I18n.DefaultLocale != "en-GB" { + t.Fatalf("expected default locale en-GB, got %q", cfg.I18n.DefaultLocale) + } + if !slices.Equal(cfg.I18n.Supported, []string{"en-GB", "fr"}) { + t.Fatalf("expected supported locales [en-GB fr], got %v", cfg.I18n.Supported) + } + if cfg.Authentik.Issuer != "https://auth.example.com" { + t.Fatalf("expected Authentik issuer https://auth.example.com, got %q", cfg.Authentik.Issuer) + } + if cfg.Authentik.ClientID != "runtime-client" { + t.Fatalf("expected Authentik client ID runtime-client, got %q", cfg.Authentik.ClientID) + } + if !cfg.Authentik.TrustedProxy { + t.Fatal("expected Authentik trusted proxy to be enabled") + } + if !slices.Equal(cfg.Authentik.PublicPaths, []string{"/public", "/docs"}) { + t.Fatalf("expected Authentik public paths [/public /docs], got %v", cfg.Authentik.PublicPaths) + } +} + +// TestEngine_RuntimeConfig_Good_EmptyOnNilEngine verifies the nil receiver +// guard returns an empty runtime snapshot. +func TestEngine_RuntimeConfig_Good_EmptyOnNilEngine(t *testing.T) { + var e *api.Engine + + cfg := e.RuntimeConfig() + if cfg.Swagger.Enabled || cfg.Transport.SwaggerEnabled || cfg.GraphQL.Enabled || cfg.Cache.Enabled || cfg.I18n.DefaultLocale != "" || cfg.Authentik.Issuer != "" { + t.Fatalf("expected zero-value runtime config, got %+v", cfg) + } +} diff --git a/ssrf_guard.go b/ssrf_guard.go new file mode 100644 index 0000000..7ee667b --- /dev/null +++ b/ssrf_guard.go @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "errors" + "net" + "net/url" + "strings" +) + +// SSRF mitigation per Cerberus mechanism review on Mantis #318. +// +// Both SSEClient.Connect (transport_client.go:229) and OpenAPIClient.Call +// (client.go:342) flow through doHTTPClientRequest. The polyglot-gateway +// threat model (RFC §11) makes attacker-controlled outbound URLs reachable +// via: +// - SSEClient(rawURL) where rawURL flows from request input +// - WithBaseURL(baseURL) where baseURL is loaded from attacker-influenced config +// - WithSpecReader spec.servers[].url +// +// validateOutboundURL is the singular choke-point validator applied at +// doHTTPClientRequest before client.Do(req). It denies by default: +// - schemes other than http/https +// - hosts that resolve to RFC1918 / loopback / link-local / cloud-metadata IPs +// +// The validator is applied at request time (not just construction time) so +// DNS rebinding attacks cannot bypass pre-resolution checks — by the time +// the request fires, the literal host has been re-resolved. + +// errOutboundURLBlocked is returned when validateOutboundURL rejects a URL. +// Callers see a wrapped error from client.Do; tests assert on errors.Is. +var errOutboundURLBlocked = errors.New("outbound URL blocked by SSRF guard") + +// allowedSchemes is the deny-by-default scheme allowlist for outbound HTTP. +// Excludes file://, gopher://, ftp://, dict://, ldap://, etc. +var allowedSchemes = map[string]struct{}{ + "http": {}, + "https": {}, +} + +// metadataHosts are cloud instance-metadata hostnames that must NOT resolve +// to a usable backend. Compared after URL parse, before DNS resolution. +var metadataHosts = map[string]struct{}{ + "metadata.google.internal": {}, + "metadata.googleapis.com": {}, + "metadata.azure.com": {}, + "169.254.169.254": {}, // AWS / GCP / OpenStack / Azure (legacy) + "fd00:ec2::254": {}, // AWS IPv6 + "100.100.100.200": {}, // Alibaba Cloud +} + +// resolveHost is overridden in tests to avoid real DNS lookups while still +// exercising the IP-rejection logic. +var resolveHost = net.LookupIP + +// validateOutboundURL checks rawURL against the deny-by-default outbound +// policy. Returns errOutboundURLBlocked (or a wrap thereof) on rejection. +// +// Pass empty rawURL is rejected. Caller should never call client.Do with +// an unvalidated URL. +func validateOutboundURL(rawURL string) error { + if rawURL == "" { + return wrapBlocked("empty URL") + } + u, err := url.Parse(rawURL) + if err != nil { + return wrapBlocked("parse failed: " + err.Error()) + } + if _, ok := allowedSchemes[strings.ToLower(u.Scheme)]; !ok { + return wrapBlocked("disallowed scheme: " + u.Scheme) + } + host := u.Hostname() + if host == "" { + return wrapBlocked("empty host") + } + if _, ok := metadataHosts[strings.ToLower(host)]; ok { + return wrapBlocked("metadata host: " + host) + } + + // If host is a literal IP, check directly. Otherwise resolve and check + // every result. DNS rebinding can change resolution between calls; this + // re-checks at request time per the choke-point design. + if ip := net.ParseIP(host); ip != nil { + if reason := blockedIPReason(ip); reason != "" { + return wrapBlocked(reason + ": " + host) + } + return nil + } + + ips, err := resolveHost(host) + if err != nil { + // Resolution failed — let net/http surface the real error rather + // than masking as a block. Genuine NXDOMAIN should not look like + // a security-policy rejection. + return nil + } + for _, ip := range ips { + if reason := blockedIPReason(ip); reason != "" { + return wrapBlocked(reason + " resolution for " + host + ": " + ip.String()) + } + } + return nil +} + +// blockedIPReason returns a non-empty reason if the IP is in a denied range, +// else "". +func blockedIPReason(ip net.IP) string { + if ip.IsLoopback() { + return "loopback IP" + } + if ip.IsPrivate() { + // IsPrivate covers RFC1918 (IPv4) + RFC4193 (IPv6 ULA). + return "private IP" + } + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + // 169.254.0.0/16 (IPv4) + fe80::/10 (IPv6) — covers cloud metadata. + return "link-local IP" + } + if ip.IsUnspecified() { + // 0.0.0.0 / :: + return "unspecified IP" + } + if ip.IsMulticast() { + return "multicast IP" + } + return "" +} + +// wrapBlocked formats a rejection reason as an error wrapping errOutboundURLBlocked +// so callers can errors.Is(err, errOutboundURLBlocked) on the rejection class. +func wrapBlocked(reason string) error { + return blockedURLError{reason: reason} +} + +type blockedURLError struct{ reason string } + +func (e blockedURLError) Error() string { return errOutboundURLBlocked.Error() + ": " + e.reason } +func (e blockedURLError) Unwrap() error { return errOutboundURLBlocked } diff --git a/ssrf_guard_internal_test.go b/ssrf_guard_internal_test.go new file mode 100644 index 0000000..38c3dd0 --- /dev/null +++ b/ssrf_guard_internal_test.go @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "errors" + "net" + "strings" + "testing" +) + +// TestSSRF_OutboundURL_BlocksMetadata_Ugly — Cerberus mechanism review +// recommendation per Mantis #318. AWS/GCP/Azure metadata endpoints must be +// rejected by literal-host match before DNS resolution. +func TestSSRF_OutboundURL_BlocksMetadata_Ugly(t *testing.T) { + cases := []string{ + "http://169.254.169.254/latest/meta-data/iam/security-credentials/", + "https://metadata.google.internal/computeMetadata/v1/instance/", + "http://metadata.azure.com/", + "http://[fd00:ec2::254]/", + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + err := validateOutboundURL(raw) + if err == nil { + t.Errorf("validateOutboundURL(%q) returned nil; expected block", raw) + return + } + if !errors.Is(err, errOutboundURLBlocked) { + t.Errorf("expected errors.Is(err, errOutboundURLBlocked) for %q; got %v", raw, err) + } + }) + } +} + +// TestSSRF_OutboundURL_BlocksLoopback_Ugly — localhost variants. +func TestSSRF_OutboundURL_BlocksLoopback_Ugly(t *testing.T) { + cases := []string{ + "http://127.0.0.1/", + "http://127.5.5.5/", + "http://[::1]/", + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + err := validateOutboundURL(raw) + if err == nil { + t.Errorf("validateOutboundURL(%q) returned nil; expected loopback block", raw) + return + } + if !errors.Is(err, errOutboundURLBlocked) { + t.Errorf("expected errors.Is(err, errOutboundURLBlocked); got %v", err) + } + }) + } +} + +// TestSSRF_OutboundURL_BlocksRFC1918_Ugly — internal-network IP ranges. +func TestSSRF_OutboundURL_BlocksRFC1918_Ugly(t *testing.T) { + cases := []string{ + "http://10.0.0.1/", + "http://10.255.255.255/", + "http://172.16.0.1/", + "http://172.31.255.255/", + "http://192.168.1.1/", + "http://192.168.255.255/", + "http://[fc00::1]/", // IPv6 ULA + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + err := validateOutboundURL(raw) + if err == nil { + t.Errorf("validateOutboundURL(%q) returned nil; expected RFC1918/ULA block", raw) + return + } + if !errors.Is(err, errOutboundURLBlocked) { + t.Errorf("expected errors.Is(err, errOutboundURLBlocked); got %v", err) + } + }) + } +} + +// TestSSRF_OutboundURL_BlocksDisallowedScheme_Bad — non-http(s) schemes. +func TestSSRF_OutboundURL_BlocksDisallowedScheme_Bad(t *testing.T) { + cases := []string{ + "file:///etc/passwd", + "gopher://evil.example.com/_command", + "ftp://example.com/", + "dict://example.com:11211/stat", + "ldap://example.com/", + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + err := validateOutboundURL(raw) + if err == nil { + t.Errorf("validateOutboundURL(%q) returned nil; expected scheme block", raw) + return + } + if !strings.Contains(err.Error(), "disallowed scheme") { + t.Errorf("expected 'disallowed scheme' error; got %v", err) + } + }) + } +} + +// TestSSRF_OutboundURL_AllowsHTTPS_Good — sanity that public HTTPS still works. +// We override resolveHost to return a public IP so we don't depend on real DNS. +func TestSSRF_OutboundURL_AllowsHTTPS_Good(t *testing.T) { + prev := resolveHost + defer func() { resolveHost = prev }() + resolveHost = func(host string) ([]net.IP, error) { + // Pretend example.com resolves to a public IP. + return []net.IP{net.IPv4(93, 184, 216, 34)}, nil + } + + cases := []string{ + "https://example.com/", + "https://example.com/path?q=1", + "http://example.com:8080/", + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + if err := validateOutboundURL(raw); err != nil { + t.Errorf("validateOutboundURL(%q) blocked unexpectedly: %v", raw, err) + } + }) + } +} + +// TestSSRF_OutboundURL_BlocksDNSResolveToPrivate_Ugly — DNS-rebinding-style: +// a public-looking hostname that resolves to an RFC1918 IP must still be +// blocked by the post-resolution check. +func TestSSRF_OutboundURL_BlocksDNSResolveToPrivate_Ugly(t *testing.T) { + prev := resolveHost + defer func() { resolveHost = prev }() + resolveHost = func(host string) ([]net.IP, error) { + // Attacker's domain that resolves to a private IP. + return []net.IP{net.IPv4(10, 0, 0, 5)}, nil + } + + err := validateOutboundURL("https://attacker.example.com/") + if err == nil { + t.Fatal("expected post-resolution private-IP block; got nil") + } + if !errors.Is(err, errOutboundURLBlocked) { + t.Errorf("expected errOutboundURLBlocked; got %v", err) + } + if !strings.Contains(err.Error(), "10.0.0.5") { + t.Errorf("expected error to mention resolved IP; got %v", err) + } +} + +// TestSSRF_OutboundURL_EmptyURL_Bad — defensive case. +func TestSSRF_OutboundURL_EmptyURL_Bad(t *testing.T) { + err := validateOutboundURL("") + if err == nil { + t.Fatal("expected empty-URL block; got nil") + } + if !errors.Is(err, errOutboundURLBlocked) { + t.Errorf("expected errOutboundURLBlocked; got %v", err) + } +} + +// TestSSRF_OutboundURL_AllowsResolverFailure_Good — if DNS resolution fails, +// let net/http surface the real error rather than masking as a security block. +func TestSSRF_OutboundURL_AllowsResolverFailure_Good(t *testing.T) { + prev := resolveHost + defer func() { resolveHost = prev }() + resolveHost = func(host string) ([]net.IP, error) { + return nil, errors.New("simulated NXDOMAIN") + } + + if err := validateOutboundURL("https://nonexistent.example.invalid/"); err != nil { + t.Errorf("expected nil (let net/http surface the error); got %v", err) + } +} diff --git a/tests/cli/api/Taskfile.yaml b/tests/cli/api/Taskfile.yaml new file mode 100644 index 0000000..1dd61fa --- /dev/null +++ b/tests/cli/api/Taskfile.yaml @@ -0,0 +1,26 @@ +version: "3" + +tasks: + default: + deps: + - build + - vet + - test + + build: + desc: Compile every package in api. + dir: ../../.. + cmds: + - GOWORK=off go build ./... + + vet: + desc: Run go vet across the module. + dir: ../../.. + cmds: + - GOWORK=off go vet ./... + + test: + desc: Run unit tests. + dir: ../../.. + cmds: + - GOWORK=off go test -count=1 ./...