diff --git a/internal/application/identra/config.go b/internal/application/identra/config.go index 6760e53..5c3f963 100644 --- a/internal/application/identra/config.go +++ b/internal/application/identra/config.go @@ -24,6 +24,22 @@ type Config struct { GORMClient *gorm.Config MongoClient *mongo.Config RedisClient *redis.Config + + // LoginMaxAttempts is the maximum number of failed login attempts (password + // or email-code) allowed within LoginLockoutDuration before the account is + // temporarily locked. 0 means use DefaultLoginMaxAttempts. + LoginMaxAttempts int + // LoginLockoutDuration is the sliding window during which failed login + // attempts are counted. 0 means use DefaultLoginLockoutDuration. + LoginLockoutDuration time.Duration + + // SendCodeMaxAttempts is the maximum number of email verification codes + // that can be requested per email address within SendCodeWindow. 0 means + // use DefaultSendCodeMaxAttempts. + SendCodeMaxAttempts int + // SendCodeWindow is the sliding window for the send-code rate limit. 0 + // means use DefaultSendCodeWindow. + SendCodeWindow time.Duration } const ( @@ -31,6 +47,20 @@ const ( DefaultAccessTokenExpiration = 15 * time.Minute // Short-lived access token DefaultRefreshTokenExpiration = 7 * 24 * time.Hour // 7 days refresh token DefaultTokenIssuer = "identra" + + // DefaultLoginMaxAttempts is the default maximum number of failed login + // attempts before a temporary lockout is applied. + DefaultLoginMaxAttempts = 5 + // DefaultLoginLockoutDuration is the default window over which failed login + // attempts are counted. + DefaultLoginLockoutDuration = 15 * time.Minute + + // DefaultSendCodeMaxAttempts is the default maximum number of email + // verification codes that can be sent per address within DefaultSendCodeWindow. + DefaultSendCodeMaxAttempts = 5 + // DefaultSendCodeWindow is the default rate-limit window for sending email + // verification codes. + DefaultSendCodeWindow = 1 * time.Hour ) type MongoConfig struct { diff --git a/internal/application/identra/service.go b/internal/application/identra/service.go index 2ff6a24..6d6ad20 100644 --- a/internal/application/identra/service.go +++ b/internal/application/identra/service.go @@ -19,6 +19,7 @@ import ( identra_v1_pb "github.com/poly-workshop/identra/gen/go/identra/v1" "github.com/poly-workshop/identra/internal/domain" "github.com/poly-workshop/identra/internal/infrastructure/cache" + "github.com/poly-workshop/identra/internal/infrastructure/mail" "github.com/poly-workshop/identra/internal/infrastructure/oauth" "github.com/poly-workshop/identra/internal/infrastructure/persistence" "github.com/poly-workshop/identra/internal/infrastructure/security" @@ -48,12 +49,19 @@ type Service struct { tokenCfg security.TokenConfig githubOAuthConfig *oauth2.Config oauthFetchEmailIfMissing bool - mailer *smtp.Mailer + mailer mail.Sender + + // loginRateLimiter counts failed login attempts per email address and + // blocks further attempts after the configured threshold. + loginRateLimiter cache.RateLimiter + // sendCodeRateLimiter limits how many email verification codes can be sent + // to a single address within the configured window. + sendCodeRateLimiter cache.RateLimiter } func NewService(ctx context.Context, cfg Config) (*Service, error) { mailerCfg := cfg.SmtpMailer - var mailer *smtp.Mailer + var mailer mail.Sender if strings.TrimSpace(mailerCfg.Host) != "" { if err := validateMailerConfig(mailerCfg); err != nil { @@ -109,6 +117,44 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) { return nil, fmt.Errorf("failed to initialize email code store: %w", storeErr) } + loginMaxAttempts := cfg.LoginMaxAttempts + if loginMaxAttempts <= 0 { + loginMaxAttempts = DefaultLoginMaxAttempts + } + loginLockoutDuration := cfg.LoginLockoutDuration + if loginLockoutDuration <= 0 { + loginLockoutDuration = DefaultLoginLockoutDuration + } + + loginLimiter, loginLimiterErr := cache.NewRedisRateLimiter( + redis.NewRDB(*cfg.RedisClient), + "identra:rl:login:", + loginMaxAttempts, + loginLockoutDuration, + ) + if loginLimiterErr != nil { + return nil, fmt.Errorf("failed to initialize login rate limiter: %w", loginLimiterErr) + } + + sendCodeMaxAttempts := cfg.SendCodeMaxAttempts + if sendCodeMaxAttempts <= 0 { + sendCodeMaxAttempts = DefaultSendCodeMaxAttempts + } + sendCodeWindow := cfg.SendCodeWindow + if sendCodeWindow <= 0 { + sendCodeWindow = DefaultSendCodeWindow + } + + sendCodeLimiter, sendCodeLimiterErr := cache.NewRedisRateLimiter( + redis.NewRDB(*cfg.RedisClient), + "identra:rl:send_code:", + sendCodeMaxAttempts, + sendCodeWindow, + ) + if sendCodeLimiterErr != nil { + return nil, fmt.Errorf("failed to initialize send-code rate limiter: %w", sendCodeLimiterErr) + } + return &Service{ userStore: userStore, keyManager: km, @@ -119,6 +165,8 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) { oauthFetchEmailIfMissing: cfg.OAuthFetchEmailIfMissing, mailer: mailer, userStoreCleanup: cleanup, + loginRateLimiter: loginLimiter, + sendCodeRateLimiter: sendCodeLimiter, }, nil } @@ -398,6 +446,21 @@ func (s *Service) SendLoginEmailCode( return nil, status.Error(codes.InvalidArgument, "email is required") } + if s.sendCodeRateLimiter != nil { + allowed, rlErr := s.sendCodeRateLimiter.IsAllowed(ctx, email) + if rlErr != nil { + slog.ErrorContext(ctx, "send-code rate limiter error", "error", rlErr) + // fail open — a limiter error must not prevent legitimate users + } else if !allowed { + return nil, status.Error(codes.ResourceExhausted, "too many verification code requests, please try again later") + } + if rlErr == nil { + if recordErr := s.sendCodeRateLimiter.Record(ctx, email); recordErr != nil { + slog.ErrorContext(ctx, "failed to record send-code attempt", "error", recordErr) + } + } + } + code, err := generateEmailCode() if err != nil { slog.ErrorContext(ctx, "failed to generate email code", "error", err) @@ -501,12 +564,27 @@ func (s *Service) LoginByEmailCode( return nil, status.Error(codes.InvalidArgument, "email and code are required") } + if s.loginRateLimiter != nil { + allowed, rlErr := s.loginRateLimiter.IsAllowed(ctx, email) + if rlErr != nil { + slog.ErrorContext(ctx, "login rate limiter error", "error", rlErr) + // fail open + } else if !allowed { + return nil, status.Error(codes.ResourceExhausted, "too many failed attempts, please try again later") + } + } + ok, err := s.emailCodeStore.Consume(ctx, email, code) if err != nil { slog.ErrorContext(ctx, "failed to validate verification code", "error", err) return nil, status.Error(codes.Internal, "failed to validate code") } if !ok { + if s.loginRateLimiter != nil { + if recordErr := s.loginRateLimiter.Record(ctx, email); recordErr != nil { + slog.ErrorContext(ctx, "failed to record login failure", "error", recordErr) + } + } return nil, status.Error(codes.Unauthenticated, "invalid or expired code") } @@ -522,6 +600,12 @@ func (s *Service) LoginByEmailCode( return nil, status.Error(codes.Internal, "failed to fetch user") } + if s.loginRateLimiter != nil { + if resetErr := s.loginRateLimiter.Reset(ctx, email); resetErr != nil { + slog.ErrorContext(ctx, "failed to reset login rate limit", "error", resetErr) + } + } + s.recordLogin(ctx, usr) tokenPair, err := security.NewTokenPair(usr.ID, s.tokenCfg) if err != nil { @@ -586,6 +670,16 @@ func (s *Service) LoginByPassword( return nil, status.Error(codes.InvalidArgument, "email and password are required") } + if s.loginRateLimiter != nil { + allowed, rlErr := s.loginRateLimiter.IsAllowed(ctx, email) + if rlErr != nil { + slog.ErrorContext(ctx, "login rate limiter error", "error", rlErr) + // fail open + } else if !allowed { + return nil, status.Error(codes.ResourceExhausted, "too many failed attempts, please try again later") + } + } + usr, err := s.userStore.GetByEmail(ctx, email) switch { case err == nil: @@ -606,9 +700,20 @@ func (s *Service) LoginByPassword( return nil, status.Error(codes.Internal, "failed to verify password") } if !valid { + if s.loginRateLimiter != nil { + if recordErr := s.loginRateLimiter.Record(ctx, email); recordErr != nil { + slog.ErrorContext(ctx, "failed to record login failure", "error", recordErr) + } + } return nil, status.Error(codes.Unauthenticated, "invalid credentials") } + if s.loginRateLimiter != nil { + if resetErr := s.loginRateLimiter.Reset(ctx, email); resetErr != nil { + slog.ErrorContext(ctx, "failed to reset login rate limit", "error", resetErr) + } + } + s.recordLogin(ctx, usr) tokenPair, err := security.NewTokenPair(usr.ID, s.tokenCfg) if err != nil { diff --git a/internal/application/identra/service_rate_limit_test.go b/internal/application/identra/service_rate_limit_test.go new file mode 100644 index 0000000..2dda39d --- /dev/null +++ b/internal/application/identra/service_rate_limit_test.go @@ -0,0 +1,272 @@ +package identra + +import ( + "context" + "testing" + + identra_v1_pb "github.com/poly-workshop/identra/gen/go/identra/v1" + "github.com/poly-workshop/identra/internal/domain" + "github.com/poly-workshop/identra/internal/infrastructure/cache" + "github.com/poly-workshop/identra/internal/infrastructure/notification/smtp" + "github.com/poly-workshop/identra/internal/infrastructure/security" + "google.golang.org/grpc/codes" +) + +// mockEmailCodeStore is a simple in-memory email code store for testing. +type mockEmailCodeStore struct { + codes map[string]string +} + +func newMockEmailCodeStore() *mockEmailCodeStore { + return &mockEmailCodeStore{codes: make(map[string]string)} +} + +func (m *mockEmailCodeStore) Set(_ context.Context, email, code string) error { + m.codes[email] = code + return nil +} + +func (m *mockEmailCodeStore) Consume(_ context.Context, email, code string) (bool, error) { + v, ok := m.codes[email] + if !ok { + return false, nil + } + if v != code { + return false, nil + } + delete(m.codes, email) + return true, nil +} + +// fakeMailer is a no-op mail sender for tests. +type fakeMailer struct{} + +func (f *fakeMailer) SendEmail(_ smtp.Message) error { return nil } + +// mockRateLimiter is a controllable RateLimiter for unit tests. +type mockRateLimiter struct { + allowed bool + recorded int + resets int +} + +func newMockRateLimiter(allowed bool) *mockRateLimiter { + return &mockRateLimiter{allowed: allowed} +} + +func (m *mockRateLimiter) IsAllowed(_ context.Context, _ string) (bool, error) { + return m.allowed, nil +} + +func (m *mockRateLimiter) Record(_ context.Context, _ string) error { + m.recorded++ + return nil +} + +func (m *mockRateLimiter) Reset(_ context.Context, _ string) error { + m.resets++ + return nil +} + +// Verify mockRateLimiter satisfies the cache.RateLimiter interface at compile time. +var _ cache.RateLimiter = (*mockRateLimiter)(nil) + +// ---- LoginByPassword rate-limit tests ---- + +func TestLoginByPassword_RateLimit_BlocksWhenLimitExceeded(t *testing.T) { + hash, err := security.HashPassword("correct") + if err != nil { + t.Fatalf("hash: %v", err) + } + store := newMockUserStore() + _ = store.Create(context.Background(), &domain.UserModel{ + ID: "uid1", + Email: "user@example.com", + HashedPassword: &hash, + }) + + svc := &Service{ + userStore: store, + loginRateLimiter: newMockRateLimiter(false), // limiter says: blocked + } + + _, loginErr := svc.LoginByPassword(context.Background(), &identra_v1_pb.LoginByPasswordRequest{ + Email: "user@example.com", + Password: "correct", + }) + requireCode(t, loginErr, codes.ResourceExhausted) +} + +func TestLoginByPassword_RateLimit_RecordsOnFailure(t *testing.T) { + hash, err := security.HashPassword("correct") + if err != nil { + t.Fatalf("hash: %v", err) + } + store := newMockUserStore() + _ = store.Create(context.Background(), &domain.UserModel{ + ID: "uid1", + Email: "user@example.com", + HashedPassword: &hash, + }) + + limiter := newMockRateLimiter(true) + svc := &Service{ + userStore: store, + loginRateLimiter: limiter, + } + + _, loginErr := svc.LoginByPassword(context.Background(), &identra_v1_pb.LoginByPasswordRequest{ + Email: "user@example.com", + Password: "wrong-password", + }) + requireCode(t, loginErr, codes.Unauthenticated) + + if limiter.recorded != 1 { + t.Errorf("expected 1 recorded failure, got %d", limiter.recorded) + } + if limiter.resets != 0 { + t.Errorf("expected no resets on failure, got %d", limiter.resets) + } +} + +func TestLoginByPassword_RateLimit_ResetsOnSuccess(t *testing.T) { + hash, err := security.HashPassword("correct") + if err != nil { + t.Fatalf("hash: %v", err) + } + store := newMockUserStore() + _ = store.Create(context.Background(), &domain.UserModel{ + ID: "uid1", + Email: "user@example.com", + HashedPassword: &hash, + }) + + limiter := newMockRateLimiter(true) + svc := &Service{ + userStore: store, + tokenCfg: newTestTokenConfig(t), + loginRateLimiter: limiter, + } + + _, loginErr := svc.LoginByPassword(context.Background(), &identra_v1_pb.LoginByPasswordRequest{ + Email: "user@example.com", + Password: "correct", + }) + if loginErr != nil { + t.Fatalf("expected success, got %v", loginErr) + } + + if limiter.recorded != 0 { + t.Errorf("expected no recorded failures on success, got %d", limiter.recorded) + } + if limiter.resets != 1 { + t.Errorf("expected 1 reset on success, got %d", limiter.resets) + } +} + +// ---- LoginByEmailCode rate-limit tests ---- + +func TestLoginByEmailCode_RateLimit_BlocksWhenLimitExceeded(t *testing.T) { + emailStore := newMockEmailCodeStore() + _ = emailStore.Set(context.Background(), "user@example.com", "123456") + + svc := &Service{ + userStore: newMockUserStore(), + emailCodeStore: emailStore, + loginRateLimiter: newMockRateLimiter(false), + } + + _, err := svc.LoginByEmailCode(context.Background(), &identra_v1_pb.LoginByEmailCodeRequest{ + Email: "user@example.com", + Code: "123456", + }) + requireCode(t, err, codes.ResourceExhausted) +} + +func TestLoginByEmailCode_RateLimit_RecordsOnFailure(t *testing.T) { + emailStore := newMockEmailCodeStore() + _ = emailStore.Set(context.Background(), "user@example.com", "123456") + + limiter := newMockRateLimiter(true) + svc := &Service{ + userStore: newMockUserStore(), + emailCodeStore: emailStore, + loginRateLimiter: limiter, + } + + _, err := svc.LoginByEmailCode(context.Background(), &identra_v1_pb.LoginByEmailCodeRequest{ + Email: "user@example.com", + Code: "wrong", + }) + requireCode(t, err, codes.Unauthenticated) + + if limiter.recorded != 1 { + t.Errorf("expected 1 recorded failure, got %d", limiter.recorded) + } + if limiter.resets != 0 { + t.Errorf("expected no resets on failure, got %d", limiter.resets) + } +} + +func TestLoginByEmailCode_RateLimit_ResetsOnSuccess(t *testing.T) { + emailStore := newMockEmailCodeStore() + _ = emailStore.Set(context.Background(), "user@example.com", "123456") + + limiter := newMockRateLimiter(true) + svc := &Service{ + userStore: newMockUserStore(), + emailCodeStore: emailStore, + tokenCfg: newTestTokenConfig(t), + loginRateLimiter: limiter, + } + + _, err := svc.LoginByEmailCode(context.Background(), &identra_v1_pb.LoginByEmailCodeRequest{ + Email: "user@example.com", + Code: "123456", + }) + if err != nil { + t.Fatalf("expected success, got %v", err) + } + + if limiter.recorded != 0 { + t.Errorf("expected no recorded failures on success, got %d", limiter.recorded) + } + if limiter.resets != 1 { + t.Errorf("expected 1 reset on success, got %d", limiter.resets) + } +} + +// ---- SendLoginEmailCode rate-limit tests ---- + +func TestSendLoginEmailCode_RateLimit_BlocksWhenLimitExceeded(t *testing.T) { + svc := &Service{ + mailer: &fakeMailer{}, + emailCodeStore: newMockEmailCodeStore(), + sendCodeRateLimiter: newMockRateLimiter(false), + } + + _, err := svc.SendLoginEmailCode(context.Background(), &identra_v1_pb.SendLoginEmailCodeRequest{ + Email: "user@example.com", + }) + requireCode(t, err, codes.ResourceExhausted) +} + +func TestSendLoginEmailCode_RateLimit_RecordsOnAllowed(t *testing.T) { + limiter := newMockRateLimiter(true) + svc := &Service{ + mailer: &fakeMailer{}, + emailCodeStore: newMockEmailCodeStore(), + sendCodeRateLimiter: limiter, + } + + _, err := svc.SendLoginEmailCode(context.Background(), &identra_v1_pb.SendLoginEmailCodeRequest{ + Email: "user@example.com", + }) + if err != nil { + t.Fatalf("expected success, got %v", err) + } + + if limiter.recorded != 1 { + t.Errorf("expected 1 recorded send, got %d", limiter.recorded) + } +} diff --git a/internal/infrastructure/cache/redis_rate_limiter.go b/internal/infrastructure/cache/redis_rate_limiter.go new file mode 100644 index 0000000..41f6c8a --- /dev/null +++ b/internal/infrastructure/cache/redis_rate_limiter.go @@ -0,0 +1,97 @@ +package cache + +import ( + "context" + "errors" + "time" + + "github.com/redis/go-redis/v9" +) + +// RateLimiter provides rate limiting and brute-force protection based on a +// per-key attempt counter stored in Redis. +type RateLimiter interface { + // IsAllowed returns true when the number of recorded attempts for key is + // strictly below the configured maximum. It does NOT modify the counter. + IsAllowed(ctx context.Context, key string) (bool, error) + + // Record increments the attempt counter for key. On the very first + // increment the key's TTL is set to the configured window duration. + Record(ctx context.Context, key string) error + + // Reset deletes the attempt counter for key (call after a successful + // action to give the user a fresh window). + Reset(ctx context.Context, key string) error +} + +// NewRedisRateLimiter creates a Redis-backed RateLimiter. +// +// - prefix key prefix (e.g. "identra:rl:login:") +// - maxAttempts maximum number of attempts allowed within window before +// IsAllowed returns false +// - window duration after which the counter automatically expires +func NewRedisRateLimiter(rdb redis.UniversalClient, prefix string, maxAttempts int, window time.Duration) (RateLimiter, error) { + if rdb == nil { + return nil, errors.New("redis client is required for rate limiter") + } + if maxAttempts <= 0 { + return nil, errors.New("maxAttempts must be positive") + } + if window <= 0 { + return nil, errors.New("window must be positive") + } + return &redisRateLimiter{ + rdb: rdb, + prefix: prefix, + maxAttempts: int64(maxAttempts), + window: window, + }, nil +} + +type redisRateLimiter struct { + rdb redis.UniversalClient + prefix string + maxAttempts int64 + window time.Duration +} + +func (r *redisRateLimiter) fullKey(key string) string { + return r.prefix + key +} + +// isAllowedScript checks whether the current counter is below maxAttempts. +// KEYS[1]: counter key +// ARGV[1]: maxAttempts +var isAllowedScript = redis.NewScript(` +local v = redis.call("GET", KEYS[1]) +if not v then return 1 end +if tonumber(v) < tonumber(ARGV[1]) then return 1 end +return 0 +`) + +func (r *redisRateLimiter) IsAllowed(ctx context.Context, key string) (bool, error) { + res, err := isAllowedScript.Run(ctx, r.rdb, []string{r.fullKey(key)}, r.maxAttempts).Int64() + if err != nil { + return false, err + } + return res == 1, nil +} + +// recordScript increments the counter and sets the TTL on the first increment. +// KEYS[1]: counter key +// ARGV[1]: window in seconds +var recordScript = redis.NewScript(` +local n = redis.call("INCR", KEYS[1]) +if n == 1 then + redis.call("EXPIRE", KEYS[1], tonumber(ARGV[1])) +end +return n +`) + +func (r *redisRateLimiter) Record(ctx context.Context, key string) error { + return recordScript.Run(ctx, r.rdb, []string{r.fullKey(key)}, int(r.window.Seconds())).Err() +} + +func (r *redisRateLimiter) Reset(ctx context.Context, key string) error { + return r.rdb.Del(ctx, r.fullKey(key)).Err() +}