diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dcf973..9b34b9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.16.0] - 2026-04-06 + +### Added + +- **router** — Named routes via `RouteEntry.Name(name)` and `Router.URL(name, params...)` for reverse URL generation with `{param}` and `{param...}` placeholder substitution +- **router** — Route introspection: `Router.Routes()` returns a snapshot of all registered `RouteInfo` entries; `Router.Walk(fn)` iterates with early-exit support +- **router** — `RouteInfo.HandlerName` captures the original handler's function name via `runtime.FuncForPC` for debugging and documentation +- **router** — Parameter constraints: `ValidateParams(handler, constraints...)` wraps a handler with path-parameter validation; built-in constraint constructors `Int`, `UUID`, `Regex`, `OneOf` +- **router** — `With(middleware...)` returns a child group for per-route inline middleware without affecting sibling routes +- **router** — `Route(prefix, fn, middleware...)` for inline sub-routing — creates a child group, calls `fn` to register routes, and returns the group for further use +- **router** — `Mount(prefix, handler)` attaches an `http.Handler` (or `*Router`) at a prefix with `http.StripPrefix`; sub-router routes and named routes are merged into the parent's route table +- **router** — `Static(prefix, dir)` serves files from a filesystem directory; `File(pattern, filePath)` serves a single file for GET requests +- **router** — `WithNotFound(handler)` and `WithMethodNotAllowed(handler)` options for custom 404/405 handlers, taking precedence over the `ErrorHandler` +- **router** — `WithStripSlash()` silently removes trailing slashes before routing; `WithRedirectSlash()` sends 301 redirects (mutually exclusive, panics if both set) + +### Changed + +- **router** — All method helpers (`Get`, `Post`, `Put`, `Patch`, `Delete`, `Head`, `Options` and their `Func` variants) and `Handle`/`HandleFunc` now return `*RouteEntry` for optional `.Name()` chaining +- **router** — `register()` accepts an additional `origFn` parameter to capture the original handler name before middleware wrapping + ## [0.15.0] - 2026-04-04 ### Changed diff --git a/README.md b/README.md index a7f99b3..cf871d7 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ A production-ready Go toolkit for building REST APIs. Zero mandatory dependencie - **`response`** — Consistent JSON envelope, fluent builder, pagination helpers, SSE streaming, XML, JSONP, and more - **`middleware`** — Request ID, logging, panic recovery, CORS, rate limiting, auth, security headers, timeout - **`httpclient`** — HTTP client with retries, exponential backoff, circuit breaker, and `HTTPClient` interface for mocking -- **`router`** — Route grouping with `.Get()`/`.Post()` method helpers, prefix groups, and per-group middleware on top of `http.ServeMux` +- **`router`** — Route grouping with method helpers, named routes, URL generation, parameter constraints, sub-router mounting, static file serving, and trailing-slash handling on top of `http.ServeMux` - **`server`** — Graceful shutdown wrapper with signal handling, lifecycle hooks, and TLS support - **`health`** — Health check endpoint builder with dependency checks, timeouts, and liveness/readiness probes - **`config`** — Load configuration from env vars, `.env` files, and JSON files into typed structs with validation @@ -376,6 +376,62 @@ admin.Delete("/users/{id}", deleteUser) // Handle/HandleFunc for http.Handler (e.g. file servers) api.Handle("GET /docs", http.FileServer(http.Dir("./docs"))) +// --- Named routes & URL generation --- +r.Get("/users/{id}", getUser).Name("get-user") +r.Get("/files/{path...}", serveFile).Name("files") + +url := r.URL("get-user", "id", "42") // "/users/42" +url = r.URL("files", "path", "docs/readme") // "/files/docs/readme" + +// --- Inline sub-routing with Route() --- +r.Route("/users", func(sub *router.Group) { + sub.Get("/", listUsers) + sub.Get("/{id}", getUser) + sub.Post("/", createUser) +}) + +// --- Per-route middleware with With() --- +r.With(authMiddleware).Get("/admin", adminHandler) + +// --- Mount sub-routers --- +adminRouter := router.New() +adminRouter.Get("/stats", statsHandler) +r.Mount("/admin", adminRouter) // routes & named routes are merged + +// --- Static files --- +r.Static("/assets", "./public") // serve directory +r.File("/favicon.ico", "./favicon.ico") // serve single file + +// --- Parameter constraints --- +r.Get("/users/{id}", router.ValidateParams(getUser, + router.Int("id"), +)) +r.Get("/items/{slug}", router.ValidateParams(getItem, + router.Regex("slug", `^[a-z0-9-]+$`), +)) +r.Get("/status/{s}", router.ValidateParams(getStatus, + router.OneOf("s", "active", "inactive", "pending"), +)) + +// --- Trailing slash handling --- +r = router.New(router.WithStripSlash()) // "/users/" → "/users" (silent) +r = router.New(router.WithRedirectSlash()) // "/users/" → 301 → "/users" + +// --- Custom 404/405 handlers --- +r = router.New( + router.WithNotFound(custom404Handler), + router.WithMethodNotAllowed(custom405Handler), +) + +// --- Route introspection --- +for _, ri := range r.Routes() { + fmt.Printf("%s %s → %s\n", ri.Method, ri.Pattern, ri.HandlerName) +} +r.Walk(func(ri router.RouteInfo) error { + // ... + return nil +}) + // Use with server package srv := server.New(r, server.WithAddr(":8080")) srv.Start() diff --git a/router/constraint.go b/router/constraint.go new file mode 100644 index 0000000..d75c026 --- /dev/null +++ b/router/constraint.go @@ -0,0 +1,82 @@ +package router + +import ( + "fmt" + "net/http" + "regexp" + "strconv" + "strings" + + "github.com/KARTIKrocks/apikit/errors" +) + +// ParamConstraint defines a validation rule for a single path parameter. +type ParamConstraint struct { + Name string // path parameter name (must match {name} in the pattern) + Validate func(string) bool // returns true if the value is valid + ErrMessage string // message for the BadRequest error on failure +} + +// ValidateParams wraps a HandlerFunc with path parameter validation. +// Constraints are checked in order before the handler is called. +// On the first failure, it returns errors.BadRequest with the constraint's ErrMessage. +func ValidateParams(fn HandlerFunc, constraints ...ParamConstraint) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { + for _, c := range constraints { + val := r.PathValue(c.Name) + if !c.Validate(val) { + return errors.BadRequest(c.ErrMessage) + } + } + return fn(w, r) + } +} + +// Int returns a constraint that requires the parameter to be a valid integer. +func Int(name string) ParamConstraint { + return ParamConstraint{ + Name: name, + Validate: func(s string) bool { + _, err := strconv.Atoi(s) + return err == nil + }, + ErrMessage: fmt.Sprintf("parameter %q must be an integer", name), + } +} + +// UUID returns a constraint that requires the parameter to be a valid UUID. +func UUID(name string) ParamConstraint { + re := regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`) + return ParamConstraint{ + Name: name, + Validate: re.MatchString, + ErrMessage: fmt.Sprintf("parameter %q must be a valid UUID", name), + } +} + +// Regex returns a constraint that requires the parameter to match the given regular expression. +// It panics if the pattern is not a valid regular expression. +func Regex(name string, pattern string) ParamConstraint { + re := regexp.MustCompile(pattern) + return ParamConstraint{ + Name: name, + Validate: re.MatchString, + ErrMessage: fmt.Sprintf("parameter %q has invalid format", name), + } +} + +// OneOf returns a constraint that requires the parameter to be one of the allowed values. +func OneOf(name string, values ...string) ParamConstraint { + allowed := make(map[string]struct{}, len(values)) + for _, v := range values { + allowed[v] = struct{}{} + } + return ParamConstraint{ + Name: name, + Validate: func(s string) bool { + _, ok := allowed[s] + return ok + }, + ErrMessage: fmt.Sprintf("parameter %q must be one of: %s", name, strings.Join(values, ", ")), + } +} diff --git a/router/group.go b/router/group.go index d342c3c..3ac7526 100644 --- a/router/group.go +++ b/router/group.go @@ -1,6 +1,7 @@ package router import ( + "fmt" "net/http" "strings" @@ -16,96 +17,105 @@ type Group struct { } // Get registers an error-returning handler for GET requests. -func (g *Group) Get(pattern string, fn HandlerFunc) { - g.register("GET", pattern, g.router.wrapError(fn)) +func (g *Group) Get(pattern string, fn HandlerFunc) *RouteEntry { + return g.register("GET", pattern, g.router.wrapError(fn), fn) } // GetFunc registers a standard http.HandlerFunc for GET requests. -func (g *Group) GetFunc(pattern string, fn http.HandlerFunc) { - g.register("GET", pattern, fn) +func (g *Group) GetFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.register("GET", pattern, fn, fn) } // Post registers an error-returning handler for POST requests. -func (g *Group) Post(pattern string, fn HandlerFunc) { - g.register("POST", pattern, g.router.wrapError(fn)) +func (g *Group) Post(pattern string, fn HandlerFunc) *RouteEntry { + return g.register("POST", pattern, g.router.wrapError(fn), fn) } // PostFunc registers a standard http.HandlerFunc for POST requests. -func (g *Group) PostFunc(pattern string, fn http.HandlerFunc) { - g.register("POST", pattern, fn) +func (g *Group) PostFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.register("POST", pattern, fn, fn) } // Put registers an error-returning handler for PUT requests. -func (g *Group) Put(pattern string, fn HandlerFunc) { - g.register("PUT", pattern, g.router.wrapError(fn)) +func (g *Group) Put(pattern string, fn HandlerFunc) *RouteEntry { + return g.register("PUT", pattern, g.router.wrapError(fn), fn) } // PutFunc registers a standard http.HandlerFunc for PUT requests. -func (g *Group) PutFunc(pattern string, fn http.HandlerFunc) { - g.register("PUT", pattern, fn) +func (g *Group) PutFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.register("PUT", pattern, fn, fn) } // Patch registers an error-returning handler for PATCH requests. -func (g *Group) Patch(pattern string, fn HandlerFunc) { - g.register("PATCH", pattern, g.router.wrapError(fn)) +func (g *Group) Patch(pattern string, fn HandlerFunc) *RouteEntry { + return g.register("PATCH", pattern, g.router.wrapError(fn), fn) } // PatchFunc registers a standard http.HandlerFunc for PATCH requests. -func (g *Group) PatchFunc(pattern string, fn http.HandlerFunc) { - g.register("PATCH", pattern, fn) +func (g *Group) PatchFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.register("PATCH", pattern, fn, fn) } // Delete registers an error-returning handler for DELETE requests. -func (g *Group) Delete(pattern string, fn HandlerFunc) { - g.register("DELETE", pattern, g.router.wrapError(fn)) +func (g *Group) Delete(pattern string, fn HandlerFunc) *RouteEntry { + return g.register("DELETE", pattern, g.router.wrapError(fn), fn) } // DeleteFunc registers a standard http.HandlerFunc for DELETE requests. -func (g *Group) DeleteFunc(pattern string, fn http.HandlerFunc) { - g.register("DELETE", pattern, fn) +func (g *Group) DeleteFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.register("DELETE", pattern, fn, fn) } // Head registers an error-returning handler for HEAD requests. -func (g *Group) Head(pattern string, fn HandlerFunc) { - g.register("HEAD", pattern, g.router.wrapError(fn)) +func (g *Group) Head(pattern string, fn HandlerFunc) *RouteEntry { + return g.register("HEAD", pattern, g.router.wrapError(fn), fn) } // HeadFunc registers a standard http.HandlerFunc for HEAD requests. -func (g *Group) HeadFunc(pattern string, fn http.HandlerFunc) { - g.register("HEAD", pattern, fn) +func (g *Group) HeadFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.register("HEAD", pattern, fn, fn) } // Options registers an error-returning handler for OPTIONS requests. -func (g *Group) Options(pattern string, fn HandlerFunc) { - g.register("OPTIONS", pattern, g.router.wrapError(fn)) +func (g *Group) Options(pattern string, fn HandlerFunc) *RouteEntry { + return g.register("OPTIONS", pattern, g.router.wrapError(fn), fn) } // OptionsFunc registers a standard http.HandlerFunc for OPTIONS requests. -func (g *Group) OptionsFunc(pattern string, fn http.HandlerFunc) { - g.register("OPTIONS", pattern, fn) +func (g *Group) OptionsFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.register("OPTIONS", pattern, fn, fn) } // Handle registers an http.Handler for the given pattern. // The pattern may include a method prefix (e.g. "GET /path"). -func (g *Group) Handle(pattern string, handler http.Handler) { +func (g *Group) Handle(pattern string, handler http.Handler) *RouteEntry { method, path := splitPattern(pattern) - fullPath := joinPath(g.fullPrefix(), path) + prefix, chain := g.resolve() + fullPath := joinPath(prefix, path) fullPattern := fullPath if method != "" { fullPattern = method + " " + fullPath } - chain := g.collectMiddleware() + origHandler := handler if len(chain) > 0 { handler = middleware.Chain(chain...)(handler) } g.router.mux.Handle(fullPattern, markMatched(handler)) + + idx := len(g.router.routes) + g.router.routes = append(g.router.routes, RouteInfo{ + Method: method, + Pattern: fullPath, + HandlerName: handlerName(origHandler), + }) + return &RouteEntry{router: g.router, index: idx} } // HandleFunc registers an http.HandlerFunc for the given pattern. // The pattern may include a method prefix (e.g. "GET /path"). -func (g *Group) HandleFunc(pattern string, fn http.HandlerFunc) { - g.Handle(pattern, fn) +func (g *Group) HandleFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return g.Handle(pattern, fn) } // Use appends middleware to this group. Middleware added via Use only @@ -124,52 +134,169 @@ func (g *Group) Group(prefix string, mw ...middleware.Middleware) *Group { } } -// register builds the full pattern and registers the handler on the mux. -func (g *Group) register(method, pattern string, handler http.Handler) { - fullPath := joinPath(g.fullPrefix(), pattern) - fullPattern := method + " " + fullPath +// Route creates a child group with the given prefix and optional middleware, +// then calls fn to register routes on it. This is syntactic sugar for inline sub-routing: +// +// r.Route("/users", func(sub *Group) { +// sub.Get("/", listUsers) +// sub.Get("/{id}", getUser) +// }) +func (g *Group) Route(prefix string, fn func(*Group), mw ...middleware.Middleware) *Group { + sub := g.Group(prefix, mw...) + fn(sub) + return sub +} - chain := g.collectMiddleware() +// Static serves files from the given filesystem directory under the URL prefix. +// Group middleware is applied to all requests. +// +// r.Static("/assets", "./public") +func (g *Group) Static(prefix, dir string) { + groupPrefix, chain := g.resolve() + fullPrefix := joinPath(groupPrefix, prefix) + fs := http.StripPrefix(fullPrefix, http.FileServer(http.Dir(dir))) if len(chain) > 0 { - handler = middleware.Chain(chain...)(handler) + fs = middleware.Chain(chain...)(fs) } - g.router.mux.Handle(fullPattern, markMatched(handler)) + + g.router.mux.Handle(fullPrefix+"/", markMatched(fs)) + + g.router.routes = append(g.router.routes, RouteInfo{ + Method: "GET", + Pattern: fullPrefix + "/{file...}", + }) } -// collectMiddleware walks the parent chain from root to current group -// and returns the accumulated middleware slice in order. -func (g *Group) collectMiddleware() []middleware.Middleware { - // Build parent chain (current → root). - var groups []*Group - for cur := g; cur != nil; cur = cur.parent { - groups = append(groups, cur) +// File registers a handler that serves a single file for GET requests. +// +// r.File("/favicon.ico", "./public/favicon.ico") +func (g *Group) File(pattern, filePath string) { + g.register("GET", pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, filePath) + }), nil) +} + +// Mount attaches an http.Handler at the given prefix, stripping the prefix +// from the request path before passing it to the handler. Group middleware is applied. +// If the handler is a *Router, its routes are merged into the parent's route table for introspection. +// +// admin := router.New() +// admin.Get("/stats", statsHandler) +// r.Mount("/admin", admin) +func (g *Group) Mount(prefix string, handler http.Handler) { + groupPrefix, chain := g.resolve() + fullPrefix := joinPath(groupPrefix, prefix) + + // Order: markMatched → middleware → StripPrefix → handler + // This ensures parent middleware sees the full path (consistent with Static/regular routes) + // and markMatched always runs regardless of StripPrefix outcome. + h := http.StripPrefix(fullPrefix, handler) + if len(chain) > 0 { + h = middleware.Chain(chain...)(h) + } + + g.router.mux.Handle(fullPrefix+"/", markMatched(h)) + g.router.mux.Handle(fullPrefix, markMatched(h)) + + // Merge sub-router routes for introspection; otherwise record a single mount entry. + if sub, ok := handler.(*Router); ok { + for _, ri := range sub.routes { + idx := len(g.router.routes) + g.router.routes = append(g.router.routes, RouteInfo{ + Method: ri.Method, + Pattern: joinPath(fullPrefix, ri.Pattern), + Name: ri.Name, + HandlerName: ri.HandlerName, + }) + if ri.Name != "" { + if _, exists := g.router.namedRoutes[ri.Name]; exists { + panic(fmt.Sprintf("router: duplicate route name %q (from mounted sub-router)", ri.Name)) + } + g.router.namedRoutes[ri.Name] = idx + } + } + } else { + g.router.routes = append(g.router.routes, RouteInfo{ + Pattern: fullPrefix + "/", + HandlerName: handlerName(handler), + }) + } +} + +// With returns a child group that shares this group's prefix but adds +// the given middleware. It is intended for per-route middleware: +// +// r.With(authMW).Get("/admin", adminHandler) +func (g *Group) With(mw ...middleware.Middleware) *Group { + return &Group{ + prefix: "", + middlewares: mw, + router: g.router, + parent: g, } +} - // Reverse to get root → current order. - var mws []middleware.Middleware - for i := len(groups) - 1; i >= 0; i-- { - mws = append(mws, groups[i].middlewares...) +// register builds the full pattern, registers the handler on the mux, and records the route. +func (g *Group) register(method, pattern string, handler http.Handler, origFn any) *RouteEntry { + prefix, chain := g.resolve() + fullPath := joinPath(prefix, pattern) + fullPattern := method + " " + fullPath + if len(chain) > 0 { + handler = middleware.Chain(chain...)(handler) } - return mws + g.router.mux.Handle(fullPattern, markMatched(handler)) + + idx := len(g.router.routes) + g.router.routes = append(g.router.routes, RouteInfo{ + Method: method, + Pattern: fullPath, + HandlerName: handlerName(origFn), + }) + return &RouteEntry{router: g.router, index: idx} } -// fullPrefix returns the concatenated prefix from root to this group. -// It normalizes joins to prevent double slashes (e.g. "/api/" + "/users" → "/api/users"). -func (g *Group) fullPrefix() string { - var groups []*Group +// resolve walks the parent chain once and returns both the full prefix +// and the accumulated middleware slice (root → current order). +// It pre-sizes slices and tracks the last written byte to avoid +// intermediate string allocations. +func (g *Group) resolve() (prefix string, mws []middleware.Middleware) { + // Count depth for pre-sizing. + depth := 0 + for cur := g; cur != nil; cur = cur.parent { + depth++ + } + + // Collect groups in root → current order. + groups := make([]*Group, depth) + i := depth - 1 for cur := g; cur != nil; cur = cur.parent { - groups = append(groups, cur) + groups[i] = cur + i-- } + // Build prefix, tracking last byte to avoid b.String() in the loop. var b strings.Builder - for i := len(groups) - 1; i >= 0; i-- { - prefix := groups[i].prefix - if b.Len() > 0 && strings.HasSuffix(b.String(), "/") && strings.HasPrefix(prefix, "/") { - prefix = prefix[1:] // strip leading slash to avoid double slash + var lastByte byte + for _, grp := range groups { + p := grp.prefix + if len(p) == 0 { + continue } - b.WriteString(prefix) + if b.Len() > 0 && lastByte == '/' && p[0] == '/' { + p = p[1:] + } + if len(p) > 0 { + b.WriteString(p) + lastByte = p[len(p)-1] + } + } + prefix = b.String() + + // Collect middleware in root → current order. + for _, grp := range groups { + mws = append(mws, grp.middlewares...) } - return b.String() + return } // joinPath concatenates a prefix and path, normalizing double slashes at the join point. @@ -185,7 +312,7 @@ func joinPath(prefix, path string) string { // splitPattern separates a Go 1.22 pattern into method and path parts. // "GET /users" → ("GET", "/users"), "/users" → ("", "/users") -// It trims extra whitespace between method and path and validates the method is uppercase. +// It trims extra whitespace between method and path. func splitPattern(pattern string) (method, path string) { method, path, found := strings.Cut(pattern, " ") if !found { diff --git a/router/route.go b/router/route.go new file mode 100644 index 0000000..115e3f2 --- /dev/null +++ b/router/route.go @@ -0,0 +1,50 @@ +package router + +import ( + "fmt" + "reflect" + "runtime" +) + +// RouteInfo holds metadata about a registered route. +type RouteInfo struct { + Method string // HTTP method ("GET", "POST", etc.). Empty for method-agnostic Handle() routes. + Pattern string // Full path pattern as registered, e.g. "/api/v1/users/{id}". + Name string // Optional name set via RouteEntry.Name(), used for URL generation. + HandlerName string // Runtime function name of the original handler. +} + +// RouteEntry is returned by route registration methods to allow optional chaining. +// Callers that ignore the return value get the same behavior as before. +type RouteEntry struct { + router *Router + index int +} + +// Name assigns a name to this route for URL generation. +// It panics if the name is already taken (fail-fast, like http.ServeMux on duplicate patterns). +func (re *RouteEntry) Name(name string) *RouteEntry { + if _, exists := re.router.namedRoutes[name]; exists { + panic(fmt.Sprintf("router: duplicate route name %q", name)) + } + re.router.routes[re.index].Name = name + re.router.namedRoutes[name] = re.index + return re +} + +// handlerName extracts a human-readable function name from a handler using runtime reflection. +func handlerName(fn any) string { + if fn == nil { + return "" + } + v := reflect.ValueOf(fn) + if v.Kind() != reflect.Func { + return "" + } + pc := v.Pointer() + f := runtime.FuncForPC(pc) + if f == nil { + return "" + } + return f.Name() +} diff --git a/router/route_test.go b/router/route_test.go new file mode 100644 index 0000000..a8d7a4b --- /dev/null +++ b/router/route_test.go @@ -0,0 +1,516 @@ +package router + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// ─── Route Introspection ──────────────────────────────────────────────────── + +func TestRoutesEmpty(t *testing.T) { + r := New() + if got := len(r.Routes()); got != 0 { + t.Errorf("expected 0 routes, got %d", got) + } +} + +func TestRoutesBasic(t *testing.T) { + r := New() + r.Get("/users", noopHandler) + r.Post("/users", noopHandler) + r.Put("/users/{id}", noopHandler) + + routes := r.Routes() + if len(routes) != 3 { + t.Fatalf("expected 3 routes, got %d", len(routes)) + } + + want := []struct{ method, pattern string }{ + {"GET", "/users"}, + {"POST", "/users"}, + {"PUT", "/users/{id}"}, + } + for i, w := range want { + if routes[i].Method != w.method { + t.Errorf("route[%d] method: expected %q, got %q", i, w.method, routes[i].Method) + } + if routes[i].Pattern != w.pattern { + t.Errorf("route[%d] pattern: expected %q, got %q", i, w.pattern, routes[i].Pattern) + } + } +} + +func TestRoutesWithGroups(t *testing.T) { + r := New() + api := r.Group("/api") + v1 := api.Group("/v1") + v1.Get("/items", noopHandler) + + routes := r.Routes() + if len(routes) != 1 { + t.Fatalf("expected 1 route, got %d", len(routes)) + } + if routes[0].Pattern != "/api/v1/items" { + t.Errorf("expected /api/v1/items, got %q", routes[0].Pattern) + } +} + +func TestRoutesCopy(t *testing.T) { + r := New() + r.Get("/a", noopHandler) + + routes := r.Routes() + routes[0].Method = "MODIFIED" + + // Original should be unchanged. + if r.Routes()[0].Method != "GET" { + t.Error("modifying returned slice should not affect router") + } +} + +func TestWalk(t *testing.T) { + r := New() + r.Get("/a", noopHandler) + r.Post("/b", noopHandler) + r.Delete("/c", noopHandler) + + var collected []string + err := r.Walk(func(ri RouteInfo) error { + collected = append(collected, ri.Method+" "+ri.Pattern) + return nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) != 3 { + t.Fatalf("expected 3 entries, got %d", len(collected)) + } + if collected[0] != "GET /a" { + t.Errorf("first: expected %q, got %q", "GET /a", collected[0]) + } +} + +func TestWalkEarlyExit(t *testing.T) { + r := New() + r.Get("/a", noopHandler) + r.Get("/b", noopHandler) + r.Get("/c", noopHandler) + + stopErr := fmt.Errorf("stop") + count := 0 + err := r.Walk(func(ri RouteInfo) error { + count++ + if count == 2 { + return stopErr + } + return nil + }) + if err != stopErr { + t.Errorf("expected stop error, got %v", err) + } + if count != 2 { + t.Errorf("expected 2 iterations, got %d", count) + } +} + +func TestRouteInfoHandlerName(t *testing.T) { + r := New() + r.Get("/test", namedTestHandler) + + routes := r.Routes() + if !strings.Contains(routes[0].HandlerName, "namedTestHandler") { + t.Errorf("expected handler name to contain 'namedTestHandler', got %q", routes[0].HandlerName) + } +} + +func TestHandleHandlerNameWithMiddleware(t *testing.T) { + r := New() + api := r.Group("/api", headerMiddleware("X-Test", "yes")) + api.Handle("GET /data", http.HandlerFunc(namedStdlibHandler)) + + routes := r.Routes() + if strings.Contains(routes[0].HandlerName, "Chain") { + t.Errorf("handler name should be the original, not middleware wrapper, got %q", routes[0].HandlerName) + } +} + +func TestProbeWriterUnwrap(t *testing.T) { + r := New() + r.Get("/flush", func(w http.ResponseWriter, req *http.Request) error { + rc := http.NewResponseController(w) + if err := rc.Flush(); err != nil { + return fmt.Errorf("flush failed: %w", err) + } + fmt.Fprint(w, "flushed") + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/flush", nil) + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if !rec.Flushed { + t.Error("expected response to be flushed") + } +} + +func TestRoutesHandle(t *testing.T) { + r := New() + r.Handle("/catch-all", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {})) + + routes := r.Routes() + if len(routes) != 1 { + t.Fatalf("expected 1 route, got %d", len(routes)) + } + if routes[0].Method != "" { + t.Errorf("expected empty method for method-agnostic Handle, got %q", routes[0].Method) + } +} + +func TestRoutesHandleWithMethod(t *testing.T) { + r := New() + r.Handle("GET /explicit", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {})) + + routes := r.Routes() + if routes[0].Method != "GET" { + t.Errorf("expected method GET, got %q", routes[0].Method) + } + if routes[0].Pattern != "/explicit" { + t.Errorf("expected pattern /explicit, got %q", routes[0].Pattern) + } +} + +func TestRoutesFuncHelpers(t *testing.T) { + r := New() + r.GetFunc("/a", func(w http.ResponseWriter, req *http.Request) {}) + r.PostFunc("/b", func(w http.ResponseWriter, req *http.Request) {}) + + routes := r.Routes() + if len(routes) != 2 { + t.Fatalf("expected 2 routes, got %d", len(routes)) + } + if routes[0].Method != "GET" || routes[1].Method != "POST" { + t.Error("unexpected methods for Func helpers") + } +} + +// ─── Named Routes ─────────────────────────────────────────────────────────── + +func TestNamedRouteBasic(t *testing.T) { + r := New() + r.Get("/users", noopHandler).Name("list-users") + + routes := r.Routes() + if routes[0].Name != "list-users" { + t.Errorf("expected name 'list-users', got %q", routes[0].Name) + } +} + +func TestNamedRouteDuplicatePanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Error("expected panic on duplicate route name") + } + }() + + r := New() + r.Get("/a", noopHandler).Name("dup") + r.Get("/b", noopHandler).Name("dup") +} + +func TestNamedRouteOnGroup(t *testing.T) { + r := New() + api := r.Group("/api") + api.Get("/users/{id}", noopHandler).Name("get-user") + + routes := r.Routes() + if routes[0].Name != "get-user" { + t.Errorf("expected name 'get-user', got %q", routes[0].Name) + } +} + +func TestNamedRouteChaining(t *testing.T) { + r := New() + entry := r.Get("/test", noopHandler).Name("test-route") + if entry == nil { + t.Error("Name() should return non-nil RouteEntry") + } +} + +// ─── URL Generation ───────────────────────────────────────────────────────── + +func TestURLBasic(t *testing.T) { + r := New() + r.Get("/users/{id}", noopHandler).Name("get-user") + + got := r.URL("get-user", "id", "42") + if got != "/users/42" { + t.Errorf("expected /users/42, got %q", got) + } +} + +func TestURLMultipleParams(t *testing.T) { + r := New() + r.Get("/users/{userID}/posts/{postID}", noopHandler).Name("get-post") + + got := r.URL("get-post", "userID", "1", "postID", "99") + if got != "/users/1/posts/99" { + t.Errorf("expected /users/1/posts/99, got %q", got) + } +} + +func TestURLNoParams(t *testing.T) { + r := New() + r.Get("/health", noopHandler).Name("health") + + got := r.URL("health") + if got != "/health" { + t.Errorf("expected /health, got %q", got) + } +} + +func TestURLCatchAll(t *testing.T) { + r := New() + r.Get("/files/{path...}", noopHandler).Name("files") + + got := r.URL("files", "path", "docs/readme.md") + if got != "/files/docs/readme.md" { + t.Errorf("expected /files/docs/readme.md, got %q", got) + } +} + +func TestURLWithGroup(t *testing.T) { + r := New() + api := r.Group("/api/v1") + api.Get("/users/{id}", noopHandler).Name("api-user") + + got := r.URL("api-user", "id", "5") + if got != "/api/v1/users/5" { + t.Errorf("expected /api/v1/users/5, got %q", got) + } +} + +func TestURLUnknownNamePanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Error("expected panic for unknown route name") + } + }() + + r := New() + r.URL("nonexistent") +} + +func TestURLOddParamsPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Error("expected panic for odd param count") + } + }() + + r := New() + r.Get("/test/{id}", noopHandler).Name("test") + r.URL("test", "id") +} + +func TestURLMissingParamPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Error("expected panic for missing param") + } + }() + + r := New() + r.Get("/users/{id}", noopHandler).Name("user") + r.URL("user", "wrong", "42") +} + +func TestURLExtraParamsPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Error("expected panic for extra params") + } + }() + + r := New() + r.Get("/users/{id}", noopHandler).Name("user") + r.URL("user", "id", "42", "extra", "value") +} + +// ─── Parameter Constraints ────────────────────────────────────────────────── + +func TestConstraintInt(t *testing.T) { + c := Int("id") + if !c.Validate("42") { + t.Error("expected 42 to be valid int") + } + if !c.Validate("-1") { + t.Error("expected -1 to be valid int") + } + if c.Validate("abc") { + t.Error("expected 'abc' to be invalid int") + } + if c.Validate("") { + t.Error("expected empty string to be invalid int") + } +} + +func TestConstraintUUID(t *testing.T) { + c := UUID("id") + if !c.Validate("550e8400-e29b-41d4-a716-446655440000") { + t.Error("expected valid UUID to pass") + } + if c.Validate("not-a-uuid") { + t.Error("expected invalid UUID to fail") + } + if c.Validate("") { + t.Error("expected empty string to fail") + } +} + +func TestConstraintRegex(t *testing.T) { + c := Regex("slug", `^[a-z0-9-]+$`) + if !c.Validate("hello-world") { + t.Error("expected 'hello-world' to match") + } + if c.Validate("Hello World!") { + t.Error("expected 'Hello World!' not to match") + } +} + +func TestConstraintRegexInvalidPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Error("expected panic on invalid regex") + } + }() + Regex("id", "[invalid") +} + +func TestConstraintOneOf(t *testing.T) { + c := OneOf("status", "active", "inactive", "pending") + if !c.Validate("active") { + t.Error("expected 'active' to be valid") + } + if c.Validate("deleted") { + t.Error("expected 'deleted' to be invalid") + } +} + +func TestConstraintMultiple(t *testing.T) { + r := New() + r.Get("/users/{id}/status/{status}", ValidateParams( + func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "ok") + return nil + }, + Int("id"), + OneOf("status", "active", "inactive"), + )) + + // Valid request. + rec := doRequest(r, "GET", "/users/42/status/active") + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + + // Invalid id. + rec = doRequest(r, "GET", "/users/abc/status/active") + if rec.Code != http.StatusBadRequest { + t.Errorf("expected 400 for invalid id, got %d", rec.Code) + } + + // Invalid status. + rec = doRequest(r, "GET", "/users/42/status/deleted") + if rec.Code != http.StatusBadRequest { + t.Errorf("expected 400 for invalid status, got %d", rec.Code) + } +} + +func TestConstraintErrorFormat(t *testing.T) { + r := New() + r.Get("/items/{id}", ValidateParams(noopHandler, Int("id"))) + + rec := doRequest(r, "GET", "/items/abc") + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rec.Code) + } + + var env errorEnvelope + if err := json.NewDecoder(rec.Body).Decode(&env); err != nil { + t.Fatalf("failed to decode: %v", err) + } + if env.Success { + t.Error("expected success=false") + } + if env.Error.Code != "BAD_REQUEST" { + t.Errorf("expected code BAD_REQUEST, got %s", env.Error.Code) + } + if !strings.Contains(env.Error.Message, "integer") { + t.Errorf("expected error message to mention 'integer', got %q", env.Error.Message) + } +} + +func TestConstraintPassThrough(t *testing.T) { + called := false + r := New() + r.Get("/items/{id}", ValidateParams( + func(w http.ResponseWriter, req *http.Request) error { + called = true + fmt.Fprint(w, req.PathValue("id")) + return nil + }, + Int("id"), + )) + + rec := doRequest(r, "GET", "/items/99") + if !called { + t.Error("expected handler to be called") + } + if rec.Body.String() != "99" { + t.Errorf("expected body '99', got %q", rec.Body.String()) + } +} + +func TestConstraintIntegration(t *testing.T) { + r := New() + api := r.Group("/api") + api.Get("/users/{id}", ValidateParams( + func(w http.ResponseWriter, req *http.Request) error { + w.WriteHeader(http.StatusOK) + return nil + }, + Int("id"), + )) + + // Valid. + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/api/users/123", nil) + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + + // Invalid. + rec = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/api/users/abc", nil) + r.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", rec.Code) + } +} + +// ─── Test Helpers ─────────────────────────────────────────────────────────── + +func noopHandler(_ http.ResponseWriter, _ *http.Request) error { return nil } + +func namedTestHandler(_ http.ResponseWriter, _ *http.Request) error { return nil } + +func namedStdlibHandler(_ http.ResponseWriter, _ *http.Request) {} diff --git a/router/router.go b/router/router.go index 7867756..7a77b80 100644 --- a/router/router.go +++ b/router/router.go @@ -21,12 +21,19 @@ import ( "encoding/json" stderrors "errors" "net/http" + "strconv" + "strings" + "sync" "time" "github.com/KARTIKrocks/apikit/errors" "github.com/KARTIKrocks/apikit/middleware" ) +var probeWriterPool = sync.Pool{ + New: func() any { return &probeWriter{} }, +} + // HandlerFunc is an HTTP handler that returns an error. // Errors are handled by the router's ErrorHandler. type HandlerFunc func(http.ResponseWriter, *http.Request) error @@ -40,9 +47,15 @@ type Option func(*Router) // Router is a thin wrapper around http.ServeMux that provides // method helpers, route grouping, and per-group middleware. type Router struct { - mux *http.ServeMux - group Group - errorHandler ErrorHandler + mux *http.ServeMux + group Group + errorHandler ErrorHandler + notFoundHandler http.Handler + methodNotAllowedHandler http.Handler + stripSlash bool + redirectSlash bool + routes []RouteInfo + namedRoutes map[string]int // name → index into routes } // New creates a new Router with the given options. @@ -50,6 +63,7 @@ func New(opts ...Option) *Router { r := &Router{ mux: http.NewServeMux(), errorHandler: DefaultErrorHandler, + namedRoutes: make(map[string]int), } r.group = Group{ router: r, @@ -57,6 +71,9 @@ func New(opts ...Option) *Router { for _, opt := range opts { opt(r) } + if r.stripSlash && r.redirectSlash { + panic("router: WithStripSlash and WithRedirectSlash are mutually exclusive") + } return r } @@ -67,29 +84,98 @@ func WithErrorHandler(fn ErrorHandler) Option { } } +// WithNotFound sets a custom handler for 404 Not Found responses. +// When set, this handler is called instead of the ErrorHandler for unmatched routes. +func WithNotFound(handler http.Handler) Option { + return func(r *Router) { + r.notFoundHandler = handler + } +} + +// WithMethodNotAllowed sets a custom handler for 405 Method Not Allowed responses. +// When set, this handler is called instead of the ErrorHandler for disallowed methods. +func WithMethodNotAllowed(handler http.Handler) Option { + return func(r *Router) { + r.methodNotAllowedHandler = handler + } +} + +// WithStripSlash silently removes trailing slashes from request paths before routing. +// "/users/" becomes "/users". The root path "/" is never modified. +func WithStripSlash() Option { + return func(r *Router) { + r.stripSlash = true + } +} + +// WithRedirectSlash sends a 301 Moved Permanently redirect for requests with a trailing slash. +// "/users/" redirects to "/users". The root path "/" is never redirected. +// Mutually exclusive with WithStripSlash. +func WithRedirectSlash() Option { + return func(r *Router) { + r.redirectSlash = true + } +} + // ServeHTTP implements http.Handler. // It intercepts 404 and 405 responses from the underlying ServeMux // and routes them through the router's ErrorHandler for consistent error format. func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // Use a probe writer to detect 404/405 from ServeMux before writing to the real response. - pw := &probeWriter{ResponseWriter: w} + // Handle trailing slashes before routing. + path := req.URL.Path + if path != "/" && strings.HasSuffix(path, "/") { + trimmed := strings.TrimRight(path, "/") + if trimmed == "" { + trimmed = "/" + } + if r.redirectSlash { + target := trimmed + if req.URL.RawQuery != "" { + target += "?" + req.URL.RawQuery + } + http.Redirect(w, req, target, http.StatusMovedPermanently) + return + } + if r.stripSlash { + req.URL.Path = trimmed + } + } + + // Use a pooled probe writer to detect 404/405 from ServeMux before writing to the real response. + pw := probeWriterPool.Get().(*probeWriter) + pw.ResponseWriter = w + pw.code = 0 + pw.intercepted = false + pw.wroteBody = false + pw.matched = false + r.mux.ServeHTTP(pw, req) if !pw.intercepted { + probeWriterPool.Put(pw) return } - // ServeMux returned 404 or 405 — handle through our error handler. + // ServeMux returned 404 or 405 — use dedicated handler if set, otherwise ErrorHandler. switch pw.code { case http.StatusNotFound: - r.errorHandler(w, req, errors.NotFound("")) + if r.notFoundHandler != nil { + r.notFoundHandler.ServeHTTP(w, req) + } else { + r.errorHandler(w, req, errors.NotFound("")) + } case http.StatusMethodNotAllowed: - r.errorHandler(w, req, &errors.Error{ - StatusCode: http.StatusMethodNotAllowed, - Code: errors.CodeMethodNotAllowed, - Message: "Method not allowed", - }) + if r.methodNotAllowedHandler != nil { + r.methodNotAllowedHandler.ServeHTTP(w, req) + } else { + r.errorHandler(w, req, &errors.Error{ + StatusCode: http.StatusMethodNotAllowed, + Code: errors.CodeMethodNotAllowed, + Message: "Method not allowed", + }) + } } + probeWriterPool.Put(pw) } // probeWriter intercepts WriteHeader calls to detect 404/405 from ServeMux. @@ -126,6 +212,10 @@ func (pw *probeWriter) Write(b []byte) (int, error) { return pw.ResponseWriter.Write(b) } +// Unwrap returns the underlying ResponseWriter so that http.ResponseController +// can reach interfaces like http.Flusher and http.Hijacker through the wrapper. +func (pw *probeWriter) Unwrap() http.ResponseWriter { return pw.ResponseWriter } + // markMatched wraps a handler to set the matched flag on the probeWriter. // This lets probeWriter distinguish ServeMux's default 404/405 from // intentional 404/405 responses returned by user handlers. @@ -139,56 +229,98 @@ func markMatched(h http.Handler) http.Handler { } // Get registers an error-returning handler for GET requests. -func (r *Router) Get(pattern string, fn HandlerFunc) { r.group.Get(pattern, fn) } +func (r *Router) Get(pattern string, fn HandlerFunc) *RouteEntry { return r.group.Get(pattern, fn) } // GetFunc registers a standard http.HandlerFunc for GET requests. -func (r *Router) GetFunc(pattern string, fn http.HandlerFunc) { r.group.GetFunc(pattern, fn) } +func (r *Router) GetFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.GetFunc(pattern, fn) +} // Post registers an error-returning handler for POST requests. -func (r *Router) Post(pattern string, fn HandlerFunc) { r.group.Post(pattern, fn) } +func (r *Router) Post(pattern string, fn HandlerFunc) *RouteEntry { return r.group.Post(pattern, fn) } // PostFunc registers a standard http.HandlerFunc for POST requests. -func (r *Router) PostFunc(pattern string, fn http.HandlerFunc) { r.group.PostFunc(pattern, fn) } +func (r *Router) PostFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.PostFunc(pattern, fn) +} // Put registers an error-returning handler for PUT requests. -func (r *Router) Put(pattern string, fn HandlerFunc) { r.group.Put(pattern, fn) } +func (r *Router) Put(pattern string, fn HandlerFunc) *RouteEntry { return r.group.Put(pattern, fn) } // PutFunc registers a standard http.HandlerFunc for PUT requests. -func (r *Router) PutFunc(pattern string, fn http.HandlerFunc) { r.group.PutFunc(pattern, fn) } +func (r *Router) PutFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.PutFunc(pattern, fn) +} // Patch registers an error-returning handler for PATCH requests. -func (r *Router) Patch(pattern string, fn HandlerFunc) { r.group.Patch(pattern, fn) } +func (r *Router) Patch(pattern string, fn HandlerFunc) *RouteEntry { + return r.group.Patch(pattern, fn) +} // PatchFunc registers a standard http.HandlerFunc for PATCH requests. -func (r *Router) PatchFunc(pattern string, fn http.HandlerFunc) { r.group.PatchFunc(pattern, fn) } +func (r *Router) PatchFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.PatchFunc(pattern, fn) +} // Delete registers an error-returning handler for DELETE requests. -func (r *Router) Delete(pattern string, fn HandlerFunc) { r.group.Delete(pattern, fn) } +func (r *Router) Delete(pattern string, fn HandlerFunc) *RouteEntry { + return r.group.Delete(pattern, fn) +} // DeleteFunc registers a standard http.HandlerFunc for DELETE requests. -func (r *Router) DeleteFunc(pattern string, fn http.HandlerFunc) { r.group.DeleteFunc(pattern, fn) } +func (r *Router) DeleteFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.DeleteFunc(pattern, fn) +} // Head registers an error-returning handler for HEAD requests. -func (r *Router) Head(pattern string, fn HandlerFunc) { r.group.Head(pattern, fn) } +func (r *Router) Head(pattern string, fn HandlerFunc) *RouteEntry { return r.group.Head(pattern, fn) } // HeadFunc registers a standard http.HandlerFunc for HEAD requests. -func (r *Router) HeadFunc(pattern string, fn http.HandlerFunc) { r.group.HeadFunc(pattern, fn) } +func (r *Router) HeadFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.HeadFunc(pattern, fn) +} // Options registers an error-returning handler for OPTIONS requests. -func (r *Router) Options(pattern string, fn HandlerFunc) { r.group.Options(pattern, fn) } +func (r *Router) Options(pattern string, fn HandlerFunc) *RouteEntry { + return r.group.Options(pattern, fn) +} // OptionsFunc registers a standard http.HandlerFunc for OPTIONS requests. -func (r *Router) OptionsFunc(pattern string, fn http.HandlerFunc) { r.group.OptionsFunc(pattern, fn) } +func (r *Router) OptionsFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.OptionsFunc(pattern, fn) +} // Handle registers an http.Handler for the given pattern. -func (r *Router) Handle(pattern string, handler http.Handler) { r.group.Handle(pattern, handler) } +func (r *Router) Handle(pattern string, handler http.Handler) *RouteEntry { + return r.group.Handle(pattern, handler) +} // HandleFunc registers an http.HandlerFunc for the given pattern. -func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) { r.group.HandleFunc(pattern, fn) } +func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) *RouteEntry { + return r.group.HandleFunc(pattern, fn) +} // Use adds middleware to the root group. func (r *Router) Use(mw ...middleware.Middleware) { r.group.Use(mw...) } +// With returns a group that shares the root prefix but adds +// the given middleware for the next registered route(s). +func (r *Router) With(mw ...middleware.Middleware) *Group { return r.group.With(mw...) } + +// Route creates a sub-group with the given prefix and calls fn to register routes on it. +func (r *Router) Route(prefix string, fn func(*Group), mw ...middleware.Middleware) *Group { + return r.group.Route(prefix, fn, mw...) +} + +// Mount attaches an http.Handler at the given prefix. +func (r *Router) Mount(prefix string, handler http.Handler) { r.group.Mount(prefix, handler) } + +// Static serves files from the given directory under the URL prefix. +func (r *Router) Static(prefix, dir string) { r.group.Static(prefix, dir) } + +// File registers a handler that serves a single file for GET requests. +func (r *Router) File(pattern, filePath string) { r.group.File(pattern, filePath) } + // Group creates a new route group with the given prefix and optional middleware. func (r *Router) Group(prefix string, mw ...middleware.Middleware) *Group { return r.group.Group(prefix, mw...) @@ -233,11 +365,7 @@ func DefaultErrorHandler(w http.ResponseWriter, _ *http.Request, err error) { fields = apiErr.Fields } - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(code) - - _ = json.NewEncoder(w).Encode(errorEnvelope{ + body, err := json.Marshal(errorEnvelope{ Success: false, Error: &errorBody{ Code: errCode, @@ -246,4 +374,15 @@ func DefaultErrorHandler(w http.ResponseWriter, _ *http.Request, err error) { }, Timestamp: time.Now().Unix(), }) + if err != nil { + http.Error(w, `{"success":false,"error":{"code":"INTERNAL_ERROR","message":"An internal error occurred"}}`, http.StatusInternalServerError) + return + } + body = append(body, '\n') + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(code) + _, _ = w.Write(body) } diff --git a/router/router_test.go b/router/router_test.go index a07f602..c84b272 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" "github.com/KARTIKrocks/apikit/errors" @@ -34,7 +36,7 @@ func doRequest(handler http.Handler, method, path string) *httptest.ResponseReco func TestMethodHelpers(t *testing.T) { methods := []struct { - register func(r *Router, pattern string, fn HandlerFunc) + register func(r *Router, pattern string, fn HandlerFunc) *RouteEntry method string }{ {(*Router).Get, "GET"}, @@ -791,3 +793,651 @@ func TestGroupHeadAndOptions(t *testing.T) { t.Error("Group OPTIONS: expected middleware applied") } } + +// ─── With() — Per-Route Inline Middleware ────────────────────────────────── + +func TestWithBasic(t *testing.T) { + r := New() + r.With(headerMiddleware("X-Auth", "yes")).Get("/admin", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "admin") + return nil + }) + + rec := doRequest(r, "GET", "/admin") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Header().Get("X-Auth") != "yes" { + t.Error("expected X-Auth=yes from With middleware") + } + if rec.Body.String() != "admin" { + t.Errorf("expected body 'admin', got %q", rec.Body.String()) + } +} + +func TestWithDoesNotAffectSiblingRoutes(t *testing.T) { + r := New() + r.With(headerMiddleware("X-Auth", "yes")).Get("/admin", func(w http.ResponseWriter, req *http.Request) error { + return nil + }) + r.Get("/public", func(w http.ResponseWriter, req *http.Request) error { + return nil + }) + + rec := doRequest(r, "GET", "/public") + if rec.Header().Get("X-Auth") != "" { + t.Error("With middleware should not apply to sibling routes") + } +} + +func TestWithOnGroup(t *testing.T) { + r := New() + api := r.Group("/api", headerMiddleware("X-Group", "api")) + api.With(headerMiddleware("X-Auth", "yes")).Get("/secret", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "secret") + return nil + }) + + rec := doRequest(r, "GET", "/api/secret") + if rec.Header().Get("X-Group") != "api" { + t.Error("expected group middleware X-Group=api") + } + if rec.Header().Get("X-Auth") != "yes" { + t.Error("expected With middleware X-Auth=yes") + } +} + +func TestWithPreservesMiddlewareOrdering(t *testing.T) { + r := New() + r.Use(headerMiddleware("X-Order", "root")) + r.With(headerMiddleware("X-Order", "with")).Get("/test", func(w http.ResponseWriter, req *http.Request) error { + return nil + }) + + rec := doRequest(r, "GET", "/test") + values := rec.Header().Values("X-Order") + if len(values) != 2 { + t.Fatalf("expected 2 X-Order values, got %d: %v", len(values), values) + } + if values[0] != "root" || values[1] != "with" { + t.Errorf("expected [root, with], got %v", values) + } +} + +func TestWithChained(t *testing.T) { + r := New() + r.With(headerMiddleware("X-A", "a")).With(headerMiddleware("X-B", "b")).Get("/test", func(w http.ResponseWriter, req *http.Request) error { + return nil + }) + + rec := doRequest(r, "GET", "/test") + if rec.Header().Get("X-A") != "a" { + t.Error("expected X-A=a") + } + if rec.Header().Get("X-B") != "b" { + t.Error("expected X-B=b") + } +} + +func TestWithNamedRoute(t *testing.T) { + r := New() + r.With(headerMiddleware("X-Auth", "yes")).Get("/users/{id}", func(w http.ResponseWriter, req *http.Request) error { + return nil + }).Name("get-user") + + got := r.URL("get-user", "id", "42") + if got != "/users/42" { + t.Errorf("expected /users/42, got %q", got) + } +} + +// ─── Route() — Inline Sub-Routing ───────────────────────────────────────── + +func TestRouteBasic(t *testing.T) { + r := New() + r.Route("/users", func(sub *Group) { + sub.Get("/", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "list") + return nil + }) + sub.Get("/{id}", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprintf(w, "get:%s", req.PathValue("id")) + return nil + }) + sub.Post("/", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "create") + return nil + }) + }) + + tests := []struct { + method, path, body string + }{ + {"GET", "/users/", "list"}, + {"GET", "/users/42", "get:42"}, + {"POST", "/users/", "create"}, + } + for _, tt := range tests { + rec := doRequest(r, tt.method, tt.path) + if rec.Code != http.StatusOK { + t.Errorf("%s %s: expected 200, got %d", tt.method, tt.path, rec.Code) + } + if rec.Body.String() != tt.body { + t.Errorf("%s %s: expected body %q, got %q", tt.method, tt.path, tt.body, rec.Body.String()) + } + } +} + +func TestRouteNested(t *testing.T) { + r := New() + r.Route("/api", func(api *Group) { + api.Route("/v1", func(v1 *Group) { + v1.Get("/items", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "items-v1") + return nil + }) + }) + }) + + rec := doRequest(r, "GET", "/api/v1/items") + if rec.Body.String() != "items-v1" { + t.Errorf("expected body 'items-v1', got %q", rec.Body.String()) + } +} + +func TestRouteWithMiddleware(t *testing.T) { + r := New() + r.Route("/admin", func(sub *Group) { + sub.Get("/dashboard", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "dashboard") + return nil + }) + }, headerMiddleware("X-Admin", "yes")) + + rec := doRequest(r, "GET", "/admin/dashboard") + if rec.Header().Get("X-Admin") != "yes" { + t.Error("expected middleware to be applied via Route") + } +} + +func TestRouteReturnsGroup(t *testing.T) { + r := New() + g := r.Route("/api", func(sub *Group) { + sub.Get("/data", func(w http.ResponseWriter, req *http.Request) error { + return nil + }) + }) + + // Should be able to add more routes to the returned group. + g.Get("/extra", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "extra") + return nil + }) + + rec := doRequest(r, "GET", "/api/extra") + if rec.Body.String() != "extra" { + t.Errorf("expected body 'extra', got %q", rec.Body.String()) + } +} + +func TestRouteWithNamedRoutes(t *testing.T) { + r := New() + r.Route("/users", func(sub *Group) { + sub.Get("/{id}", func(w http.ResponseWriter, req *http.Request) error { + return nil + }).Name("get-user") + }) + + got := r.URL("get-user", "id", "5") + if got != "/users/5" { + t.Errorf("expected /users/5, got %q", got) + } +} + +// ─── Custom NotFound / MethodNotAllowed Handlers ────────────────────────── + +func TestCustomNotFoundHandler(t *testing.T) { + r := New(WithNotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, "custom 404") + }))) + r.Get("/exists", func(w http.ResponseWriter, req *http.Request) error { return nil }) + + rec := doRequest(r, "GET", "/nope") + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d", rec.Code) + } + if rec.Body.String() != "custom 404" { + t.Errorf("expected 'custom 404', got %q", rec.Body.String()) + } +} + +func TestCustomMethodNotAllowedHandler(t *testing.T) { + r := New(WithMethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + fmt.Fprint(w, "custom 405") + }))) + r.Get("/users", func(w http.ResponseWriter, req *http.Request) error { return nil }) + + rec := doRequest(r, "POST", "/users") + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected 405, got %d", rec.Code) + } + if rec.Body.String() != "custom 405" { + t.Errorf("expected 'custom 405', got %q", rec.Body.String()) + } +} + +func TestNotFoundFallsBackToErrorHandler(t *testing.T) { + // No WithNotFound set — should use DefaultErrorHandler (existing behavior). + r := New() + r.Get("/exists", func(w http.ResponseWriter, req *http.Request) error { return nil }) + + rec := doRequest(r, "GET", "/nope") + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d", rec.Code) + } + var env errorEnvelope + json.NewDecoder(rec.Body).Decode(&env) + if env.Error.Code != "NOT_FOUND" { + t.Errorf("expected NOT_FOUND, got %s", env.Error.Code) + } +} + +func TestCustomNotFoundTakesPrecedenceOverErrorHandler(t *testing.T) { + var errorHandlerCalled bool + r := New( + WithErrorHandler(func(w http.ResponseWriter, req *http.Request, err error) { + errorHandlerCalled = true + DefaultErrorHandler(w, req, err) + }), + WithNotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, "custom wins") + })), + ) + r.Get("/exists", func(w http.ResponseWriter, req *http.Request) error { return nil }) + + rec := doRequest(r, "GET", "/nope") + if errorHandlerCalled { + t.Error("ErrorHandler should not be called when NotFoundHandler is set") + } + if rec.Body.String() != "custom wins" { + t.Errorf("expected 'custom wins', got %q", rec.Body.String()) + } +} + +// ─── Trailing Slash Handling ────────────────────────────────────────────── + +func TestStripSlash(t *testing.T) { + r := New(WithStripSlash()) + r.Get("/users", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "users") + return nil + }) + + // /users/ should match /users after stripping. + rec := doRequest(r, "GET", "/users/") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "users" { + t.Errorf("expected body 'users', got %q", rec.Body.String()) + } +} + +func TestStripSlashRootUnaffected(t *testing.T) { + r := New(WithStripSlash()) + r.Get("/", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "root") + return nil + }) + + rec := doRequest(r, "GET", "/") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "root" { + t.Errorf("expected body 'root', got %q", rec.Body.String()) + } +} + +func TestRedirectSlash(t *testing.T) { + r := New(WithRedirectSlash()) + r.Get("/users", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "users") + return nil + }) + + rec := doRequest(r, "GET", "/users/") + if rec.Code != http.StatusMovedPermanently { + t.Fatalf("expected 301, got %d", rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/users" { + t.Errorf("expected redirect to /users, got %q", loc) + } +} + +func TestRedirectSlashPreservesQuery(t *testing.T) { + r := New(WithRedirectSlash()) + r.Get("/users", func(w http.ResponseWriter, req *http.Request) error { return nil }) + + req := httptest.NewRequest("GET", "/users/?page=2&limit=10", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusMovedPermanently { + t.Fatalf("expected 301, got %d", rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/users?page=2&limit=10" { + t.Errorf("expected redirect to /users?page=2&limit=10, got %q", loc) + } +} + +func TestRedirectSlashRootUnaffected(t *testing.T) { + r := New(WithRedirectSlash()) + r.Get("/", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "root") + return nil + }) + + rec := doRequest(r, "GET", "/") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } +} + +func TestStripSlashAndRedirectSlashPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Error("expected panic when both WithStripSlash and WithRedirectSlash are set") + } + }() + New(WithStripSlash(), WithRedirectSlash()) +} + +func TestStripSlashDoubleSlashPath(t *testing.T) { + r := New(WithStripSlash()) + r.Get("/", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "root") + return nil + }) + + // "//" should be trimmed to "/" not "". + rec := doRequest(r, "GET", "//") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "root" { + t.Errorf("expected body 'root', got %q", rec.Body.String()) + } +} + +func TestRedirectSlashDoubleSlashPath(t *testing.T) { + r := New(WithRedirectSlash()) + r.Get("/", func(w http.ResponseWriter, req *http.Request) error { return nil }) + + rec := doRequest(r, "GET", "//") + if rec.Code != http.StatusMovedPermanently { + t.Fatalf("expected 301, got %d", rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/" { + t.Errorf("expected redirect to /, got %q", loc) + } +} + +// ─── Mount() — Mount Sub-Routers ───────────────────────────────────────── + +func TestMountHTTPHandler(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /hello", func(w http.ResponseWriter, req *http.Request) { + fmt.Fprint(w, "hello from sub") + }) + + r := New() + r.Mount("/sub", mux) + + rec := doRequest(r, "GET", "/sub/hello") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "hello from sub" { + t.Errorf("expected 'hello from sub', got %q", rec.Body.String()) + } +} + +func TestMountRouter(t *testing.T) { + admin := New() + admin.Get("/stats", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "stats") + return nil + }) + admin.Get("/users", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "admin-users") + return nil + }) + + r := New() + r.Mount("/admin", admin) + + tests := []struct { + path, body string + }{ + {"/admin/stats", "stats"}, + {"/admin/users", "admin-users"}, + } + for _, tt := range tests { + rec := doRequest(r, "GET", tt.path) + if rec.Code != http.StatusOK { + t.Errorf("%s: expected 200, got %d", tt.path, rec.Code) + } + if rec.Body.String() != tt.body { + t.Errorf("%s: expected body %q, got %q", tt.path, tt.body, rec.Body.String()) + } + } +} + +func TestMountWithGroupMiddleware(t *testing.T) { + sub := New() + sub.Get("/data", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "data") + return nil + }) + + r := New() + api := r.Group("/api", headerMiddleware("X-MW", "applied")) + api.Mount("/sub", sub) + + rec := doRequest(r, "GET", "/api/sub/data") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Header().Get("X-MW") != "applied" { + t.Error("expected group middleware to be applied to mounted handler") + } +} + +func TestMountRouteIntrospection(t *testing.T) { + sub := New() + sub.Get("/stats", func(w http.ResponseWriter, req *http.Request) error { return nil }) + sub.Post("/stats", func(w http.ResponseWriter, req *http.Request) error { return nil }) + + r := New() + r.Mount("/admin", sub) + + routes := r.Routes() + if len(routes) != 2 { + t.Fatalf("expected 2 routes from mounted router, got %d", len(routes)) + } + if routes[0].Pattern != "/admin/stats" { + t.Errorf("expected /admin/stats, got %q", routes[0].Pattern) + } + if routes[0].Method != "GET" { + t.Errorf("expected GET, got %q", routes[0].Method) + } + if routes[1].Method != "POST" { + t.Errorf("expected POST, got %q", routes[1].Method) + } +} + +func TestMountInGroup(t *testing.T) { + sub := New() + sub.Get("/items", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "items") + return nil + }) + + r := New() + api := r.Group("/api/v1") + api.Mount("/catalog", sub) + + rec := doRequest(r, "GET", "/api/v1/catalog/items") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "items" { + t.Errorf("expected 'items', got %q", rec.Body.String()) + } +} + +func TestMountNamedRoutesPropagated(t *testing.T) { + sub := New() + sub.Get("/users/{id}", func(w http.ResponseWriter, req *http.Request) error { + return nil + }).Name("get-user") + + r := New() + r.Mount("/api", sub) + + got := r.URL("get-user", "id", "42") + if got != "/api/users/42" { + t.Errorf("expected /api/users/42, got %q", got) + } +} + +func TestMountMiddlewareSeesFullPath(t *testing.T) { + sub := New() + sub.Get("/data", func(w http.ResponseWriter, req *http.Request) error { + fmt.Fprint(w, "data") + return nil + }) + + var capturedPath string + pathCapture := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + next.ServeHTTP(w, r) + }) + } + + r := New() + api := r.Group("/api", pathCapture) + api.Mount("/sub", sub) + + rec := doRequest(r, "GET", "/api/sub/data") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if capturedPath != "/api/sub/data" { + t.Errorf("middleware should see full path /api/sub/data, got %q", capturedPath) + } +} + +// ─── Static() / File() — Static File Serving ───────────────────────────── + +func TestStaticServesFiles(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "style.css"), []byte("body{}"), 0644) + + r := New() + r.Static("/assets", dir) + + rec := doRequest(r, "GET", "/assets/style.css") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "body{}" { + t.Errorf("expected 'body{}', got %q", rec.Body.String()) + } +} + +func TestStaticMissingFile(t *testing.T) { + dir := t.TempDir() + + r := New() + r.Static("/assets", dir) + + rec := doRequest(r, "GET", "/assets/missing.js") + if rec.Code != http.StatusNotFound { + t.Errorf("expected 404 for missing file, got %d", rec.Code) + } +} + +func TestStaticWithGroupMiddleware(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "app.js"), []byte("alert(1)"), 0644) + + r := New() + api := r.Group("/cdn", headerMiddleware("X-CDN", "yes")) + api.Static("/assets", dir) + + rec := doRequest(r, "GET", "/cdn/assets/app.js") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Header().Get("X-CDN") != "yes" { + t.Error("expected group middleware to be applied to static handler") + } +} + +func TestStaticIntrospection(t *testing.T) { + dir := t.TempDir() + r := New() + r.Static("/assets", dir) + + routes := r.Routes() + if len(routes) != 1 { + t.Fatalf("expected 1 route, got %d", len(routes)) + } + if routes[0].Method != "GET" { + t.Errorf("expected GET, got %q", routes[0].Method) + } + if routes[0].Pattern != "/assets/{file...}" { + t.Errorf("expected /assets/{file...}, got %q", routes[0].Pattern) + } +} + +func TestFileSingle(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "favicon.ico") + os.WriteFile(path, []byte("icon-data"), 0644) + + r := New() + r.File("/favicon.ico", path) + + rec := doRequest(r, "GET", "/favicon.ico") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "icon-data" { + t.Errorf("expected 'icon-data', got %q", rec.Body.String()) + } +} + +func TestFileInGroup(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "robots.txt") + os.WriteFile(path, []byte("User-agent: *"), 0644) + + r := New() + api := r.Group("/public") + api.File("/robots.txt", path) + + rec := doRequest(r, "GET", "/public/robots.txt") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "User-agent: *" { + t.Errorf("expected 'User-agent: *', got %q", rec.Body.String()) + } +} diff --git a/router/url.go b/router/url.go new file mode 100644 index 0000000..e57066d --- /dev/null +++ b/router/url.go @@ -0,0 +1,63 @@ +package router + +import ( + "fmt" + "strings" +) + +// URL builds a URL path for the named route, substituting path parameters. +// Parameters are provided as key-value pairs: r.URL("get-user", "id", "42") returns "/users/42". +// +// It panics if the route name is not found, if the number of params is odd, +// if a {placeholder} in the pattern has no matching key in params, +// or if extra params are provided that don't match any placeholder. +func (r *Router) URL(name string, params ...string) string { + idx, ok := r.namedRoutes[name] + if !ok { + panic(fmt.Sprintf("router: unknown route name %q", name)) + } + if len(params)%2 != 0 { + panic("router: URL params must be key-value pairs") + } + + pattern := r.routes[idx].Pattern + + // Build param map. + m := make(map[string]string, len(params)/2) + for i := 0; i < len(params); i += 2 { + m[params[i]] = params[i+1] + } + + // Replace {name} and {name...} placeholders, tracking which params are used. + used := 0 + var b strings.Builder + b.Grow(len(pattern)) + for i := 0; i < len(pattern); { + if pattern[i] == '{' { + end := strings.IndexByte(pattern[i:], '}') + if end == -1 { + b.WriteByte(pattern[i]) + i++ + continue + } + placeholder := pattern[i+1 : i+end] + key := strings.TrimSuffix(placeholder, "...") + val, found := m[key] + if !found { + panic(fmt.Sprintf("router: missing param %q for route %q", key, name)) + } + used++ + b.WriteString(val) + i += end + 1 + } else { + b.WriteByte(pattern[i]) + i++ + } + } + + if used != len(m) { + panic(fmt.Sprintf("router: %d extra param(s) provided for route %q", len(m)-used, name)) + } + + return b.String() +} diff --git a/router/walk.go b/router/walk.go new file mode 100644 index 0000000..79d55a3 --- /dev/null +++ b/router/walk.go @@ -0,0 +1,18 @@ +package router + +// Walk iterates over all registered routes in registration order, calling fn for each. +// If fn returns a non-nil error, Walk stops and returns that error. +func (r *Router) Walk(fn func(RouteInfo) error) error { + for _, ri := range r.routes { + if err := fn(ri); err != nil { + return err + } + } + return nil +} + +// Routes returns a copy of all registered routes. +// The returned slice is a snapshot; modifying it has no effect on the router. +func (r *Router) Routes() []RouteInfo { + return append([]RouteInfo(nil), r.routes...) +}