diff --git a/.golangci.yml b/.golangci.yml index afd5bae..3fd4f23 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,28 +1,34 @@ version: "2" run: timeout: 5m - tests: false + #tests: false linters: default: all disable: - depguard - - gochecknoglobals - noinlineerr - nlreturn + - paralleltest + - testpackage - wsl - wsl_v5 settings: + errcheck: + exclude-functions: + - io.WriteString + - (net/http.ResponseWriter).Write exhaustruct: + allow-empty: true exclude: - net/http.Server + goconst: + ignore-string-values: + - "default router" varnamelen: ignore-decls: - - w http.ResponseWriter - r *http.Request - errcheck: - exclude-functions: - - io.WriteString - - (net/http.ResponseWriter).Write + - w http.ResponseWriter + - w *httptest.ResponseRecorder exclusions: generated: lax warn-unused: true @@ -30,7 +36,7 @@ linters: # - comments # - common-false-positives - legacy - # - std-error-handling + - std-error-handling formatters: enable: - gofumpt diff --git a/example/main.go b/example/main.go index 75ea5c0..1ff6c1e 100644 --- a/example/main.go +++ b/example/main.go @@ -4,7 +4,6 @@ package main import ( "io" "log" - "log/slog" "net/http" "net/http/pprof" @@ -14,7 +13,7 @@ import ( func main() { log.SetFlags(log.Lshortfile) - router := mux.NewRouter(slog.Default(), mux.Logger) + router := mux.NewRouter(mux.Logger) router.NotFound(notFound).NotAllowed(notAllowed) // router.Use(mux.Logger) router.Get("/{$}", page) diff --git a/middleware.go b/middleware.go index 88387d1..70b7b2a 100644 --- a/middleware.go +++ b/middleware.go @@ -4,12 +4,9 @@ import ( "fmt" "log/slog" "net/http" - "strings" "time" ) -var logger *slog.Logger - type statusRecorder struct { http.ResponseWriter @@ -28,20 +25,21 @@ func Logger(next http.Handler) http.Handler { now := time.Now() rec := statusRecorder{w, http.StatusOK} next.ServeHTTP(&rec, r) - remote := strings.Split(r.RemoteAddr, ":")[0] + // remote := strings.Split(r.RemoteAddr, ":")[0] + remote := r.RemoteAddr if r.Header.Get("X-Forwarded-For") != "" { remote = r.Header.Get("X-Forwarded-For") } details := fmt.Sprintf( - "%s %s %s %s %d %s %s", + "%s %s%s %d %s %s %s", r.Method, r.Host, r.URL.Path, - remote, rec.status, + remote, time.Since(now).String(), r.UserAgent(), ) - logger.Info(details) + slog.Info(details) }) } diff --git a/router.go b/router.go index 6be2e9d..886a858 100644 --- a/router.go +++ b/router.go @@ -17,7 +17,6 @@ type Middleware func(http.Handler) http.Handler // Router provides a chain of middlewares and routes. type Router struct { *http.ServeMux - *slog.Logger chain http.Handler methods []string @@ -25,12 +24,11 @@ type Router struct { notAllowed func(http.ResponseWriter, string, int) } -// DefaultRouter creates a new Router using the default ServeMux. -func DefaultRouter() *Router { +// defaultRouter creates a new Router using the default ServeMux. +func defaultRouter() *Router { mux := http.NewServeMux() router := &Router{ ServeMux: mux, - Logger: slog.New(slog.DiscardHandler), chain: mux, methods: []string{}, notFound: http.NotFound, @@ -55,7 +53,6 @@ func DefaultRouter() *Router { if len(allowed) != 0 { w.Header().Set("Allow", strings.Join(allowed, ", ")) router.notAllowed(w, "Method Not Allowed", http.StatusMethodNotAllowed) - // http.Error(w, "Custom Method Not Allowed", http.StatusMethodNotAllowed) return } // http.Error(w, "Custom Not Found", http.StatusNotFound) @@ -65,12 +62,8 @@ func DefaultRouter() *Router { } // NewRouter creates a new Router with the given middleware applied. -func NewRouter(l *slog.Logger, middleware ...Middleware) *Router { - router := DefaultRouter() - if l != nil { - router.Logger = l - logger = l - } +func NewRouter(middleware ...Middleware) *Router { + router := defaultRouter() router.Use(middleware...) return router } @@ -95,10 +88,9 @@ func (router *Router) Group(prefix string, middlewares ...Middleware) *Router { } } - subRouter := DefaultRouter() + subRouter := defaultRouter() subRouter.notFound = router.notFound subRouter.notAllowed = router.notAllowed - subRouter.Logger = router.Logger subRouter.Use(middlewares...) router.Handle(prefix+"/", http.StripPrefix(prefix, subRouter)) return subRouter @@ -203,9 +195,9 @@ func (router *Router) Run(addr string) { ReadHeaderTimeout: time.Second, Handler: router, } - router.Info("Starting server:", "Address", addr) + slog.Info("Starting server:", "Address", addr) if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - router.Error("Router.Run: failed to start server: ", "error", err) + slog.Error("Router.Run: failed to start server: ", "error", err) } } diff --git a/router_bench_test.go b/router_bench_test.go index c0a2a62..8baaad9 100644 --- a/router_bench_test.go +++ b/router_bench_test.go @@ -2,13 +2,12 @@ package mux import ( "io" - "log/slog" "net/http" "net/http/httptest" "testing" ) -func dummyHandler(w http.ResponseWriter, r *http.Request) { +func dummyHandler(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "ok") } @@ -23,10 +22,10 @@ func makeMiddleware(tag string) Middleware { } func BenchmarkRouterBase(b *testing.B) { - router := DefaultRouter() + router := defaultRouter() router.HandleFunc("/bench", dummyHandler) - req := httptest.NewRequest("GET", "/bench", nil) + req := httptest.NewRequest(http.MethodGet, "/bench", nil) w := httptest.NewRecorder() for b.Loop() { @@ -35,14 +34,14 @@ func BenchmarkRouterBase(b *testing.B) { } func Benchmark5Routes(b *testing.B) { - router := DefaultRouter() + router := defaultRouter() router.HandleFunc("/bench1", dummyHandler) router.HandleFunc("/bench2", dummyHandler) router.HandleFunc("/bench3", dummyHandler) router.HandleFunc("/bench4", dummyHandler) router.HandleFunc("/bench5", dummyHandler) - req := httptest.NewRequest("GET", "/bench2", nil) + req := httptest.NewRequest(http.MethodGet, "/bench2", nil) w := httptest.NewRecorder() for b.Loop() { @@ -51,10 +50,10 @@ func Benchmark5Routes(b *testing.B) { } func BenchmarkRouterWithOneMiddleware(b *testing.B) { - router := NewRouter(nil, makeMiddleware("m1")) + router := NewRouter(makeMiddleware("m1")) router.HandleFunc("/bench", dummyHandler) - req := httptest.NewRequest("GET", "/bench", nil) + req := httptest.NewRequest(http.MethodGet, "/bench", nil) w := httptest.NewRecorder() for b.Loop() { @@ -63,7 +62,7 @@ func BenchmarkRouterWithOneMiddleware(b *testing.B) { } func BenchmarkRouterWithFiveMiddlewares(b *testing.B) { - router := DefaultRouter() + router := defaultRouter() router.Use( makeMiddleware("m1"), makeMiddleware("m2"), @@ -73,21 +72,21 @@ func BenchmarkRouterWithFiveMiddlewares(b *testing.B) { ) router.HandleFunc("/bench", dummyHandler) - req := httptest.NewRequest("GET", "/bench", nil) + req := httptest.NewRequest(http.MethodGet, "/bench", nil) w := httptest.NewRecorder() - for i := 0; i < b.N; i++ { + for b.Loop() { router.ServeHTTP(w, req) } } func BenchmarkRouterGroup(b *testing.B) { - mainRouter := DefaultRouter() + mainRouter := defaultRouter() group := mainRouter.Group("/api", makeMiddleware("group")) group.HandleFunc("/bench", dummyHandler) - req := httptest.NewRequest("GET", "/api/bench", nil) + req := httptest.NewRequest(http.MethodGet, "/api/bench", nil) w := httptest.NewRecorder() for b.Loop() { @@ -96,7 +95,7 @@ func BenchmarkRouterGroup(b *testing.B) { } func BenchmarkLogginMiddleWare(b *testing.B) { - router := NewRouter(slog.New(slog.DiscardHandler), Logger) + router := NewRouter(Logger) router.HandleFunc("/bench", dummyHandler) req := httptest.NewRequest(http.MethodGet, "/bench", nil) w := httptest.NewRecorder() diff --git a/router_test.go b/router_test.go index d770185..78195c2 100644 --- a/router_test.go +++ b/router_test.go @@ -13,12 +13,12 @@ import ( ) func TestDefaultRouter(t *testing.T) { - router := DefaultRouter() - router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "default router") }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -31,12 +31,13 @@ func TestDefaultRouter(t *testing.T) { } func TestLoggerMiddleware(t *testing.T) { - expectedLog := "DELETE example.com / 192.0.2.1 204" - expectedForwardedLog := "DELETE example.com / 192.168.0.1 204" + expectedLog := "DELETE example.com/ 204 192.0.2.1" + expectedForwardedLog := "DELETE example.com/ 204 192.168.0.1" buf := new(bytes.Buffer) logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{})) - router := NewRouter(logger, Logger) - router.Delete("/", func(w http.ResponseWriter, r *http.Request) { + slog.SetDefault(logger) + router := NewRouter(Logger) + router.Delete("/", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNoContent) }) req := httptest.NewRequest(http.MethodDelete, "/", nil) @@ -62,12 +63,12 @@ func TestMiddlewareExecution(t *testing.T) { next.ServeHTTP(w, r) }) } - router := DefaultRouter() + router := NewRouter() router.Use(middleware) - router.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + router.HandleFunc("/ping", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "pong") }) - req := httptest.NewRequest("GET", "/ping", nil) + req := httptest.NewRequest(http.MethodGet, "/ping", nil) req.Header.Set("User-Agent", "Go-Test") w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -82,7 +83,7 @@ func TestMiddlewareExecution(t *testing.T) { } func TestGroupRouting(t *testing.T) { - router := DefaultRouter() + router := NewRouter() subRouter := router.Group("/api", func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -95,7 +96,7 @@ func TestGroupRouting(t *testing.T) { io.WriteString(w, r.Header.Get("X-Group")) }) - req := httptest.NewRequest("GET", "/api/hello", nil) + req := httptest.NewRequest(http.MethodGet, "/api/hello", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -114,35 +115,35 @@ func TestNilMiddlewarePanic(t *testing.T) { } }() - router := DefaultRouter() + router := NewRouter() router.Group("/should-panic", nil) } func TestChainedMiddlewareOrder(t *testing.T) { var trace []string - m1 := func(next http.Handler) http.Handler { + mid1 := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { trace = append(trace, "m1") next.ServeHTTP(w, r) }) } - m2 := func(next http.Handler) http.Handler { + mid2 := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { trace = append(trace, "m2") next.ServeHTTP(w, r) }) } - router := DefaultRouter() - router.Use(m1, m2) + router := NewRouter() + router.Use(mid1, mid2) - router.HandleFunc("/chain", func(w http.ResponseWriter, r *http.Request) { + router.HandleFunc("/chain", func(w http.ResponseWriter, _ *http.Request) { trace = append(trace, "handler") io.WriteString(w, "done") }) - req := httptest.NewRequest("GET", "/chain", nil) + req := httptest.NewRequest(http.MethodGet, "/chain", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -156,11 +157,11 @@ func TestChainedMiddlewareOrder(t *testing.T) { } func TestRouter_All(t *testing.T) { - router := DefaultRouter() - router.All("/test", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.All("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "default router") }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) resp := w.Result() @@ -168,7 +169,7 @@ func TestRouter_All(t *testing.T) { if string(body) != "default router" { t.Errorf("Expected 'default router', got '%s'", string(body)) } - req = httptest.NewRequest("POST", "/test", nil) + req = httptest.NewRequest(http.MethodPost, "/test", nil) w = httptest.NewRecorder() router.ServeHTTP(w, req) resp = w.Result() @@ -179,8 +180,8 @@ func TestRouter_All(t *testing.T) { } func TestRouter_Get(t *testing.T) { - router := DefaultRouter() - router.Get("/test", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.Get("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "default router") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -194,8 +195,8 @@ func TestRouter_Get(t *testing.T) { } func TestRouter_Post(t *testing.T) { - router := DefaultRouter() - router.Post("/test", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.Post("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "default router") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) @@ -209,8 +210,8 @@ func TestRouter_Post(t *testing.T) { } func TestRouter_Delete(t *testing.T) { - router := DefaultRouter() - router.Delete("/test", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.Delete("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "default router") }) req := httptest.NewRequest(http.MethodDelete, "/test", nil) @@ -224,8 +225,8 @@ func TestRouter_Delete(t *testing.T) { } func TestRouter_Put(t *testing.T) { - router := DefaultRouter() - router.Put("/test", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.Put("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "default router") }) req := httptest.NewRequest(http.MethodPut, "/test", nil) @@ -239,8 +240,8 @@ func TestRouter_Put(t *testing.T) { } func TestRouter_Patch(t *testing.T) { - router := DefaultRouter() - router.Patch("/test", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.Patch("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "default router") }) req := httptest.NewRequest(http.MethodPatch, "/test", nil) @@ -256,8 +257,9 @@ func TestRouter_Patch(t *testing.T) { func TestRouterRun(t *testing.T) { buf := new(bytes.Buffer) logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{})) - router := NewRouter(logger) - router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + slog.SetDefault(logger) + router := NewRouter() + router.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "test response") }) @@ -278,7 +280,7 @@ func TestRouterRun(t *testing.T) { // start server again, should err with address already in use router.Run(":8000") - line, err = buf.ReadString(0x0a) + _, err = buf.ReadString(0x0a) if err != nil { t.Error("read log buffer", err) } @@ -292,12 +294,12 @@ func TestRouterRun(t *testing.T) { } func TestStaticFiles(t *testing.T) { - router := DefaultRouter() + router := NewRouter() router.Static("/files", "example/static") router.ServeFile("/hello", "example/static/hello.txt") // get dir - req := httptest.NewRequest("GET", "/files/", nil) + req := httptest.NewRequest(http.MethodGet, "/files/", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) body, err := io.ReadAll(w.Result().Body) @@ -309,7 +311,7 @@ func TestStaticFiles(t *testing.T) { } // get file from dir - req = httptest.NewRequest("GET", "/files/hello.txt", nil) + req = httptest.NewRequest(http.MethodGet, "/files/hello.txt", nil) w = httptest.NewRecorder() router.ServeHTTP(w, req) body, err = io.ReadAll(w.Result().Body) @@ -321,7 +323,7 @@ func TestStaticFiles(t *testing.T) { } // get file directly - req = httptest.NewRequest("GET", "/hello", nil) + req = httptest.NewRequest(http.MethodGet, "/hello", nil) w = httptest.NewRecorder() router.ServeHTTP(w, req) body, err = io.ReadAll(w.Result().Body) @@ -337,60 +339,60 @@ func TestStaticFiles(t *testing.T) { var content embed.FS func TestStaticFilesFS(t *testing.T) { - router := DefaultRouter() + router := NewRouter() router.StaticFS("/files", content) router.ServeFileFS("/hello/", "example/static/hello.txt", content) // get dir t.Run("getDir", func(t *testing.T) { - req := httptest.NewRequest("GET", "/files/", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - body, err := io.ReadAll(w.Result().Body) - if err != nil { - t.Error("error reading body", err) - } - if !strings.Contains(string(body), ">example/") { - t.Error("wrong response", string(body)) - } -}) + req := httptest.NewRequest(http.MethodGet, "/files/", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + body, err := io.ReadAll(w.Result().Body) + if err != nil { + t.Error("error reading body", err) + } + if !strings.Contains(string(body), ">example/") { + t.Error("wrong response", string(body)) + } + }) // get file from dir t.Run("getFile", func(t *testing.T) { - req := httptest.NewRequest("GET", "/files/example/static/hello.txt", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - body, err := io.ReadAll(w.Result().Body) - if err != nil { - t.Error("error reading body", err) - } - if !strings.Contains(string(body), "hello world") { - t.Error("wrong response", string(body)) - } -}) + req := httptest.NewRequest(http.MethodGet, "/files/example/static/hello.txt", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + body, err := io.ReadAll(w.Result().Body) + if err != nil { + t.Error("error reading body", err) + } + if !strings.Contains(string(body), "hello world") { + t.Error("wrong response", string(body)) + } + }) // get file directly t.Run("getFileDirect", func(t *testing.T) { - req := httptest.NewRequest("GET", "/hello/", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - body, err := io.ReadAll(w.Result().Body) - if err != nil { - t.Error("error reading body", err) - } - if !strings.Contains(string(body), "hello world") { - t.Error("wrong response", string(body)) - t.Log(w.Result().Header) - } -}) + req := httptest.NewRequest(http.MethodGet, "/hello/", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + body, err := io.ReadAll(w.Result().Body) + if err != nil { + t.Error("error reading body", err) + } + if !strings.Contains(string(body), "hello world") { + t.Error("wrong response", string(body)) + t.Log(w.Result().Header) + } + }) } func TestErrorHandling(t *testing.T) { - router := NewRouter(slog.Default()).NotAllowed( + router := NewRouter().NotAllowed( func(w http.ResponseWriter, _ string, _ int) { http.Error(w, "Custom Method Not Allowed", http.StatusMethodNotAllowed) }).NotFound( - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) io.WriteString(w, "Custom Not Found") }) @@ -426,8 +428,8 @@ func TestErrorHandling(t *testing.T) { } func TestCustomMethod(t *testing.T) { - router := DefaultRouter() - router.CustomMethod("UPDATE", "/{$}", func(w http.ResponseWriter, r *http.Request) { + router := NewRouter() + router.CustomMethod("UPDATE", "/{$}", func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "custom method handler") }) req := httptest.NewRequest("UPDATE", "/", nil)