Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/tgfeed/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func main() {
generator := feed.NewGenerator()

// Initialize and run the HTTP server
server := rest.NewServer(c, scraper, generator, ipFilter, port)
server := rest.NewServer(c, scraper, generator, ipFilter, port, trustProxy)

if err := server.Run(ctx); err != nil {
logger.Error("Server error", "error", err)
Expand Down
77 changes: 11 additions & 66 deletions internal/api/rest/ipfilter.go → internal/api/rest/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ type Firewall struct {
// The allowedIPsStr parameter accepts a comma-separated list of IP addresses
// and/or CIDR ranges (e.g., "10.0.0.0/24,192.168.1.1,2001:db8::/32").
// If allowedIPsStr is empty, all IP addresses are allowed by default.
// When trustProxy is true, the firewall will check X-Real-IP and X-Forwarded-For
// headers to determine the client's IP address, which is necessary when the
// application runs behind a reverse proxy.
// Returns an error if any IP address or CIDR notation is invalid.
func NewFirewall(allowedIPsStr string, trustProxy bool) (*Firewall, error) {
if allowedIPsStr == "" {
Expand All @@ -45,8 +42,7 @@ func NewFirewall(allowedIPsStr string, trustProxy bool) (*Firewall, error) {
}

// IsAllowed checks if the request originates from an allowed IP address.
// When trustProxy is enabled, it first checks X-Real-IP and X-Forwarded-For headers
// before falling back to RemoteAddr. If no IP restrictions are configured (empty allowlist),
// If no IP restrictions are configured (empty allowlist),
// all requests are allowed. Returns false if the IP cannot be extracted or is not in the allowlist.
func (f *Firewall) IsAllowed(r *http.Request) bool {
if len(f.allowedNets) == 0 {
Expand All @@ -59,58 +55,24 @@ func (f *Firewall) IsAllowed(r *http.Request) bool {
return false
}

return isIPAllowed(clientIP, f.allowedNets)
return f.isIPAllowed(clientIP)
}

// extractClientIP extracts the client IP address from the request
func extractClientIP(r *http.Request, trustProxy bool) (string, error) {
if trustProxy {
if clientIP := tryExtractFromProxyHeaders(r); clientIP != "" {
return clientIP, nil
}
}

return extractFromRemoteAddr(r.RemoteAddr)
}
// isIPAllowed checks if an IP address is in the allowed networks
func (f *Firewall) isIPAllowed(ipStr string) bool {
ip := net.ParseIP(ipStr)

// tryExtractFromProxyHeaders attempts to extract IP from proxy headers
func tryExtractFromProxyHeaders(r *http.Request) string {
if xRealIP := r.Header.Get("X-Real-IP"); xRealIP != "" {
if ip := net.ParseIP(xRealIP); ip != nil {
return ip.String()
}
if ip == nil {
return false
}

if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")

if len(ips) > 0 {
clientIP := strings.TrimSpace(ips[0])

if ip := net.ParseIP(clientIP); ip != nil {
return ip.String()
}
for _, ipNet := range f.allowedNets {
if ipNet.Contains(ip) {
return true
}
}

return ""
}

// extractFromRemoteAddr extracts IP from RemoteAddr
func extractFromRemoteAddr(remoteAddr string) (string, error) {
host, _, err := net.SplitHostPort(remoteAddr)

if err != nil {
return "", fmt.Errorf("invalid remote address: %w", err)
}

ip := net.ParseIP(host)

if ip == nil {
return "", fmt.Errorf("invalid IP address: %s", host)
}

return ip.String(), nil
return false
}

// parseAllowedIPs parses a comma-separated list of IP addresses and CIDR ranges
Expand Down Expand Up @@ -169,20 +131,3 @@ func parseIPOrCIDR(part string) (*net.IPNet, error) {

return ipNet, nil
}

// isIPAllowed checks if an IP address is in the allowed networks
func isIPAllowed(ipStr string, allowedNets []*net.IPNet) bool {
ip := net.ParseIP(ipStr)

if ip == nil {
return false
}

for _, ipNet := range allowedNets {
if ipNet.Contains(ip) {
return true
}
}

return false
}
File renamed without changes.
67 changes: 67 additions & 0 deletions internal/api/rest/ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package rest

import (
"fmt"
"net"
"net/http"
"strings"
)

// extractClientIP extracts the client IP address from the request.
// When trustProxy is true, the firewall will check X-Real-IP and
// X-Forwarded-For headers to determine the client's IP address,
// which is necessary when the application runs behind a reverse proxy.
func extractClientIP(r *http.Request, trustProxy bool) (string, error) {
if trustProxy {
if clientIP := tryExtractFromProxyHeaders(r); clientIP != "" {
return clientIP, nil
}
}

return extractFromRemoteAddr(r.RemoteAddr)
}

// mustExtractClientIP behaves exactly like extractClientIP except it
// doesn't return an error, ignoring it instead.
func mustExtractClientIP(r *http.Request, trustProxy bool) string {
ip, _ := extractClientIP(r, trustProxy)

return ip
}

// tryExtractFromProxyHeaders attempts to extract IP from proxy headers
func tryExtractFromProxyHeaders(r *http.Request) string {
if xRealIP := r.Header.Get("X-Real-IP"); xRealIP != "" {
if ip := net.ParseIP(xRealIP); ip != nil {
return ip.String()
}
}

if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
clientIP := strings.TrimSpace(ips[0])

if ip := net.ParseIP(clientIP); ip != nil {
return ip.String()
}
}

return ""
}

// extractFromRemoteAddr extracts IP from RemoteAddr
func extractFromRemoteAddr(remoteAddr string) (string, error) {
host, _, err := net.SplitHostPort(remoteAddr)

if err != nil {
return "", fmt.Errorf("invalid remote address: %w", err)
}

ip := net.ParseIP(host)

if ip == nil {
return "", fmt.Errorf("invalid IP address: %s", host)
}

return ip.String(), nil
}
8 changes: 4 additions & 4 deletions internal/api/rest/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

// Logger wraps an http.Handler with request/response logging
func Logger(next http.Handler) http.Handler {
func Logger(next http.Handler, trustProxy bool) http.Handler {
logger := app.Logger()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -25,7 +25,7 @@ func Logger(next http.Handler) http.Handler {
"method", r.Method,
"path", r.URL.Path,
"query", r.URL.RawQuery,
"remote_addr", r.RemoteAddr,
"remote_addr", mustExtractClientIP(r, trustProxy),
"user_agent", r.UserAgent(),
)

Expand Down Expand Up @@ -73,7 +73,7 @@ func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
// When a filter is provided, each request is validated using filter.IsAllowed.
// Denied requests receive a 403 Forbidden response with a JSON error message.
// The middleware logs warnings for denied requests including the remote address and path.
func IPFilterMiddleware(filter IPFilter) func(http.Handler) http.Handler {
func IPFilterMiddleware(filter IPFilter, trustProxy bool) func(http.Handler) http.Handler {
logger := app.Logger()

return func(next http.Handler) http.Handler {
Expand All @@ -88,7 +88,7 @@ func IPFilterMiddleware(filter IPFilter) func(http.Handler) http.Handler {
}

logger.Warn("IP not allowed",
"remote_addr", r.RemoteAddr,
"remote_addr", mustExtractClientIP(r, trustProxy),
"path", r.URL.Path,
)

Expand Down
38 changes: 20 additions & 18 deletions internal/api/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,34 @@ import (

// Server represents the REST API server
type Server struct {
mux *http.ServeMux
server *http.Server
logger *slog.Logger
cache cache.Cache
scraper Scraper
generator Generator
ipFilter IPFilter
port string
mux *http.ServeMux
server *http.Server
logger *slog.Logger
cache cache.Cache
scraper Scraper
generator Generator
ipFilter IPFilter
port string
trustProxy bool
}

// NewServer creates a new REST API server with the specified dependencies.
// The ipFilter parameter controls IP-based access restrictions; pass nil to disable filtering.
// The port parameter specifies the TCP port to listen on (e.g., "8080").
// The server is pre-configured with secure timeout values to mitigate common attacks.
func NewServer(c cache.Cache, s Scraper, g Generator, ipFilter IPFilter, port string) *Server {
func NewServer(c cache.Cache, s Scraper, g Generator, ipFilter IPFilter, port string, trustProxy bool) *Server {
mux := http.NewServeMux()
logger := app.Logger()

server := &Server{
mux: mux,
logger: logger,
cache: c,
scraper: s,
generator: g,
ipFilter: ipFilter,
port: port,
mux: mux,
logger: logger,
cache: c,
scraper: s,
generator: g,
ipFilter: ipFilter,
port: port,
trustProxy: trustProxy,
server: &http.Server{
Addr: ":" + port,
Handler: nil, // Will be set in Run
Expand All @@ -65,8 +67,8 @@ func (s *Server) registerHandlers() {
func (s *Server) Run(ctx context.Context) error {
// Apply middleware chain
handler := http.Handler(s.mux)
handler = IPFilterMiddleware(s.ipFilter)(handler)
handler = Logger(handler)
handler = IPFilterMiddleware(s.ipFilter, s.trustProxy)(handler)
handler = Logger(handler, s.trustProxy)

// Set the handler with middleware
s.server.Handler = handler
Expand Down