From 0be466ab556270934635686ff396bcb5f749343e Mon Sep 17 00:00:00 2001 From: Dmitry Nikitenko Date: Tue, 13 Jan 2026 15:22:59 +0600 Subject: [PATCH] Refactor MemoryCache to use ExpiresAt (less goroutines) --- internal/cache/memory.go | 75 +++++++++++++---- internal/cache/memory_test.go | 152 +++++++++++++++++++--------------- 2 files changed, 145 insertions(+), 82 deletions(-) diff --git a/internal/cache/memory.go b/internal/cache/memory.go index c01ce5c..fe6c189 100644 --- a/internal/cache/memory.go +++ b/internal/cache/memory.go @@ -7,30 +7,58 @@ import ( "time" ) +type cacheItem struct { + value []byte + expiresAt time.Time +} + // MemoryCache implements the Cache interface using RAM type MemoryCache struct { - cache map[string][]byte + cache map[string]cacheItem mu sync.RWMutex + stop chan struct{} } // NewMemoryClient creates a new cache client func NewMemoryClient() *MemoryCache { - cache := make(map[string][]byte, 100) + mc := &MemoryCache{ + cache: make(map[string]cacheItem, 100), + stop: make(chan struct{}), + } + + go func() { + for { + select { + case <-mc.stop: + return + case <-time.After(1 * time.Minute): + mc.cleanup(time.Now()) + } + } + }() - return &MemoryCache{cache: cache} + return mc } // Get retrieves a value from memory func (c *MemoryCache) Get(_ context.Context, key string) ([]byte, error) { c.mu.RLock() - val, exists := c.cache[key] + item, exists := c.cache[key] c.mu.RUnlock() if !exists { return nil, ErrCacheMiss } - return val, nil + if item.expiresAt.Before(time.Now()) { + c.mu.Lock() + delete(c.cache, key) + c.mu.Unlock() + + return nil, ErrCacheMiss + } + + return item.value, nil } // Set stores a value in memory with the specified TTL @@ -41,21 +69,16 @@ func (c *MemoryCache) Set(_ context.Context, key string, value []byte, ttl time. } c.mu.Lock() - c.cache[key] = value + c.cache[key] = cacheItem{value: value, expiresAt: time.Now().Add(ttl)} c.mu.Unlock() - go func() { - time.Sleep(ttl) - c.mu.Lock() - delete(c.cache, key) - c.mu.Unlock() - }() - return nil } // Close releases the memory func (c *MemoryCache) Close() error { + close(c.stop) + c.mu.Lock() clear(c.cache) c.mu.Unlock() @@ -63,12 +86,34 @@ func (c *MemoryCache) Close() error { return nil } +func (c *MemoryCache) cleanup(now time.Time) { + c.mu.RLock() + + var expiredKeys []string + + for key, item := range c.cache { + if item.expiresAt.Before(now) || item.expiresAt.Equal(now) { + expiredKeys = append(expiredKeys, key) + } + } + + c.mu.RUnlock() + + if len(expiredKeys) > 0 { + c.mu.Lock() + for _, key := range expiredKeys { + delete(c.cache, key) + } + c.mu.Unlock() + } +} + // snapshot returns a copy of the cache for testing purposes -func (c *MemoryCache) snapshot() map[string][]byte { +func (c *MemoryCache) snapshot() map[string]cacheItem { c.mu.RLock() defer c.mu.RUnlock() - result := make(map[string][]byte, len(c.cache)) + result := make(map[string]cacheItem, len(c.cache)) maps.Copy(result, c.cache) diff --git a/internal/cache/memory_test.go b/internal/cache/memory_test.go index 022b74e..c7d941c 100644 --- a/internal/cache/memory_test.go +++ b/internal/cache/memory_test.go @@ -3,6 +3,7 @@ package cache import ( "context" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/assert" @@ -13,62 +14,39 @@ func TestMemoryCacheWrite(t *testing.T) { ctx := context.Background() tests := []struct { - name string - prepare func() *MemoryCache - want map[string][]byte + name string + value []byte + duration time.Duration + wantLen int }{ { - name: "should write a new entry in empty cache", - prepare: func() *MemoryCache { - c := NewMemoryClient() - - if err := c.Set(ctx, "test", []byte{42}, 100*time.Millisecond); err != nil { - t.Error(err) - } - - return c - }, - want: map[string][]byte{"test": {42}}, + name: "should write a new entry in empty cache", + value: []byte{42}, + duration: 100 * time.Millisecond, + wantLen: 1, }, { - name: "should write a new entry in non-empty cache", - prepare: func() *MemoryCache { - c := NewMemoryClient() - - if err := c.Set(ctx, "test1", []byte{42}, 100*time.Millisecond); err != nil { - t.Error(err) - } - - if err := c.Set(ctx, "test2", []byte{42}, 100*time.Millisecond); err != nil { - t.Error(err) - } - - return c - }, - want: map[string][]byte{"test1": {42}, "test2": {42}}, + name: "should skip caching if TTL is 0", + value: []byte{42}, + duration: 0, + wantLen: 0, }, - { - name: "should skip caching if TTL is 0", - prepare: func() *MemoryCache { + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { c := NewMemoryClient() + defer c.Close() - if err := c.Set(ctx, "test", []byte{42}, 0); err != nil { + if err := c.Set(ctx, "test", tt.value, tt.duration); err != nil { t.Error(err) } - return c - }, - want: map[string][]byte{}, - }, - } + snapshot := c.snapshot() - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := tt.prepare() - snapshot := c.snapshot() - - assert.Equal(t, len(tt.want), len(snapshot)) - assert.Equal(t, tt.want, snapshot) + assert.Len(t, snapshot, tt.wantLen) + }) }) } } @@ -100,6 +78,7 @@ func TestMemoryCacheRead(t *testing.T) { }, { name: "should respond with ErrCacheMiss when no such entry exists", + key: "test3", prepare: func() *MemoryCache { c := NewMemoryClient() @@ -120,11 +99,15 @@ func TestMemoryCacheRead(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := tt.prepare() - v, err := c.Get(ctx, tt.key) + synctest.Test(t, func(t *testing.T) { + c := tt.prepare() + defer c.Close() + + v, err := c.Get(ctx, tt.key) - assert.Equal(t, tt.want, v) - assert.Equal(t, tt.err, err) + assert.Equal(t, tt.want, v) + assert.Equal(t, tt.err, err) + }) }) } } @@ -135,7 +118,7 @@ func TestMemoryCacheClose(t *testing.T) { tests := []struct { name string prepare func() *MemoryCache - want map[string][]byte + want map[string]cacheItem }{ { name: "should clear cache on close", @@ -148,19 +131,21 @@ func TestMemoryCacheClose(t *testing.T) { return c }, - want: map[string][]byte{}, + want: map[string]cacheItem{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := tt.prepare() - err := c.Close() - require.NoError(t, err) - snapshot := c.snapshot() - - assert.Equal(t, len(tt.want), len(snapshot)) - assert.Equal(t, tt.want, snapshot) + synctest.Test(t, func(t *testing.T) { + c := tt.prepare() + err := c.Close() + require.NoError(t, err) + snapshot := c.snapshot() + + assert.Equal(t, len(tt.want), len(snapshot)) + assert.Equal(t, tt.want, snapshot) + }) }) } } @@ -171,32 +156,65 @@ func TestMemoryCacheTTL(t *testing.T) { tests := []struct { name string prepare func() *MemoryCache - want map[string][]byte + want map[string]cacheItem }{ { - name: "should remove a cache entry after TTL", + name: "should remove a cache entry after TTL via cleanup", prepare: func() *MemoryCache { c := NewMemoryClient() - if err := c.Set(ctx, "test", []byte{42}, 5*time.Millisecond); err != nil { + if err := c.Set(ctx, "test", []byte{42}, 30*time.Second); err != nil { t.Error(err) } return c }, - want: map[string][]byte{}, + want: map[string]cacheItem{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := tt.prepare() + synctest.Test(t, func(t *testing.T) { + c := tt.prepare() + defer c.Close() - time.Sleep(10 * time.Millisecond) + time.Sleep(2 * time.Minute) + synctest.Wait() - snapshot := c.snapshot() - assert.Equal(t, len(tt.want), len(snapshot)) - assert.Equal(t, tt.want, snapshot) + snapshot := c.snapshot() + assert.Equal(t, len(tt.want), len(snapshot)) + assert.Equal(t, tt.want, snapshot) + }) }) } + + t.Run("should remove expired entry on Get before cleanup runs", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + c := NewMemoryClient() + defer c.Close() + + // Set with TTL less than cleanup interval (1 minute) + err := c.Set(ctx, "test", []byte{42}, 30*time.Second) + require.NoError(t, err) + + // Verify entry exists + val, err := c.Get(ctx, "test") + require.NoError(t, err) + assert.Equal(t, []byte{42}, val) + + // Wait for TTL to expire (but less than cleanup interval) + time.Sleep(31 * time.Second) + + // Get should detect expiration and remove entry + val, err = c.Get(ctx, "test") + assert.Equal(t, ErrCacheMiss, err) + assert.Nil(t, val) + + // Verify entry was removed from cache + snapshot := c.snapshot() + assert.Empty(t, snapshot) + assert.Equal(t, map[string]cacheItem{}, snapshot) + }) + }) }