Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 60 additions & 15 deletions internal/cache/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,58 @@ import (
"time"
)

type cacheItem struct {
value []byte
expiresAt time.Time
}

// MemoryCache implements the Cache interface using RAM
type MemoryCache struct {
cache map[string][]byte
cache map[string]cacheItem
mu sync.RWMutex
stop chan struct{}
}

// NewMemoryClient creates a new cache client
func NewMemoryClient() *MemoryCache {
cache := make(map[string][]byte, 100)
mc := &MemoryCache{
cache: make(map[string]cacheItem, 100),
stop: make(chan struct{}),
}

go func() {
for {
select {
case <-mc.stop:
return
case <-time.After(1 * time.Minute):
mc.cleanup(time.Now())
}
}
}()

return &MemoryCache{cache: cache}
return mc
}

// Get retrieves a value from memory
func (c *MemoryCache) Get(_ context.Context, key string) ([]byte, error) {
c.mu.RLock()
val, exists := c.cache[key]
item, exists := c.cache[key]
c.mu.RUnlock()

if !exists {
return nil, ErrCacheMiss
}

return val, nil
if item.expiresAt.Before(time.Now()) {
c.mu.Lock()
delete(c.cache, key)
c.mu.Unlock()

return nil, ErrCacheMiss
}
Comment thread
nDmitry marked this conversation as resolved.

return item.value, nil
}

// Set stores a value in memory with the specified TTL
Expand All @@ -41,34 +69,51 @@ func (c *MemoryCache) Set(_ context.Context, key string, value []byte, ttl time.
}

c.mu.Lock()
c.cache[key] = value
c.cache[key] = cacheItem{value: value, expiresAt: time.Now().Add(ttl)}
c.mu.Unlock()

go func() {
time.Sleep(ttl)
c.mu.Lock()
delete(c.cache, key)
c.mu.Unlock()
}()

return nil
}

// Close releases the memory
func (c *MemoryCache) Close() error {
close(c.stop)

c.mu.Lock()
clear(c.cache)
c.mu.Unlock()

return nil
}
Comment thread
nDmitry marked this conversation as resolved.
Comment thread
nDmitry marked this conversation as resolved.

func (c *MemoryCache) cleanup(now time.Time) {
c.mu.RLock()

var expiredKeys []string

for key, item := range c.cache {
if item.expiresAt.Before(now) || item.expiresAt.Equal(now) {
expiredKeys = append(expiredKeys, key)
}
}

c.mu.RUnlock()

if len(expiredKeys) > 0 {
c.mu.Lock()
for _, key := range expiredKeys {
delete(c.cache, key)
}
c.mu.Unlock()
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// snapshot returns a copy of the cache for testing purposes
func (c *MemoryCache) snapshot() map[string][]byte {
func (c *MemoryCache) snapshot() map[string]cacheItem {
c.mu.RLock()
defer c.mu.RUnlock()

result := make(map[string][]byte, len(c.cache))
result := make(map[string]cacheItem, len(c.cache))

maps.Copy(result, c.cache)

Expand Down
152 changes: 85 additions & 67 deletions internal/cache/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cache
import (
"context"
"testing"
"testing/synctest"
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"time"

"github.com/stretchr/testify/assert"
Expand All @@ -13,62 +14,39 @@ func TestMemoryCacheWrite(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
prepare func() *MemoryCache
want map[string][]byte
name string
value []byte
duration time.Duration
wantLen int
}{
{
name: "should write a new entry in empty cache",
prepare: func() *MemoryCache {
c := NewMemoryClient()

if err := c.Set(ctx, "test", []byte{42}, 100*time.Millisecond); err != nil {
t.Error(err)
}

return c
},
want: map[string][]byte{"test": {42}},
name: "should write a new entry in empty cache",
value: []byte{42},
duration: 100 * time.Millisecond,
wantLen: 1,
},
{
name: "should write a new entry in non-empty cache",
prepare: func() *MemoryCache {
c := NewMemoryClient()

if err := c.Set(ctx, "test1", []byte{42}, 100*time.Millisecond); err != nil {
t.Error(err)
}

if err := c.Set(ctx, "test2", []byte{42}, 100*time.Millisecond); err != nil {
t.Error(err)
}

return c
},
want: map[string][]byte{"test1": {42}, "test2": {42}},
name: "should skip caching if TTL is 0",
value: []byte{42},
duration: 0,
wantLen: 0,
},
{
name: "should skip caching if TTL is 0",
prepare: func() *MemoryCache {
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
c := NewMemoryClient()
defer c.Close()

if err := c.Set(ctx, "test", []byte{42}, 0); err != nil {
if err := c.Set(ctx, "test", tt.value, tt.duration); err != nil {
t.Error(err)
}

return c
},
want: map[string][]byte{},
},
}
snapshot := c.snapshot()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := tt.prepare()
snapshot := c.snapshot()

assert.Equal(t, len(tt.want), len(snapshot))
assert.Equal(t, tt.want, snapshot)
assert.Len(t, snapshot, tt.wantLen)
})
})
}
}
Expand Down Expand Up @@ -100,6 +78,7 @@ func TestMemoryCacheRead(t *testing.T) {
},
{
name: "should respond with ErrCacheMiss when no such entry exists",
key: "test3",
prepare: func() *MemoryCache {
c := NewMemoryClient()

Expand All @@ -120,11 +99,15 @@ func TestMemoryCacheRead(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := tt.prepare()
v, err := c.Get(ctx, tt.key)
synctest.Test(t, func(t *testing.T) {
c := tt.prepare()
defer c.Close()

v, err := c.Get(ctx, tt.key)

assert.Equal(t, tt.want, v)
assert.Equal(t, tt.err, err)
assert.Equal(t, tt.want, v)
assert.Equal(t, tt.err, err)
})
})
}
}
Expand All @@ -135,7 +118,7 @@ func TestMemoryCacheClose(t *testing.T) {
tests := []struct {
name string
prepare func() *MemoryCache
want map[string][]byte
want map[string]cacheItem
}{
{
name: "should clear cache on close",
Expand All @@ -148,19 +131,21 @@ func TestMemoryCacheClose(t *testing.T) {

return c
},
want: map[string][]byte{},
want: map[string]cacheItem{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := tt.prepare()
err := c.Close()
require.NoError(t, err)
snapshot := c.snapshot()

assert.Equal(t, len(tt.want), len(snapshot))
assert.Equal(t, tt.want, snapshot)
synctest.Test(t, func(t *testing.T) {
c := tt.prepare()
err := c.Close()
require.NoError(t, err)
snapshot := c.snapshot()

assert.Equal(t, len(tt.want), len(snapshot))
assert.Equal(t, tt.want, snapshot)
})
})
}
}
Expand All @@ -171,32 +156,65 @@ func TestMemoryCacheTTL(t *testing.T) {
tests := []struct {
name string
prepare func() *MemoryCache
want map[string][]byte
want map[string]cacheItem
}{
{
name: "should remove a cache entry after TTL",
name: "should remove a cache entry after TTL via cleanup",
prepare: func() *MemoryCache {
c := NewMemoryClient()

if err := c.Set(ctx, "test", []byte{42}, 5*time.Millisecond); err != nil {
if err := c.Set(ctx, "test", []byte{42}, 30*time.Second); err != nil {
t.Error(err)
}

return c
},
want: map[string][]byte{},
want: map[string]cacheItem{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := tt.prepare()
synctest.Test(t, func(t *testing.T) {
c := tt.prepare()
defer c.Close()

time.Sleep(10 * time.Millisecond)
time.Sleep(2 * time.Minute)
synctest.Wait()

snapshot := c.snapshot()
assert.Equal(t, len(tt.want), len(snapshot))
assert.Equal(t, tt.want, snapshot)
snapshot := c.snapshot()
assert.Equal(t, len(tt.want), len(snapshot))
assert.Equal(t, tt.want, snapshot)
})
})
}

t.Run("should remove expired entry on Get before cleanup runs", func(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
c := NewMemoryClient()
defer c.Close()

// Set with TTL less than cleanup interval (1 minute)
err := c.Set(ctx, "test", []byte{42}, 30*time.Second)
require.NoError(t, err)

// Verify entry exists
val, err := c.Get(ctx, "test")
require.NoError(t, err)
assert.Equal(t, []byte{42}, val)

// Wait for TTL to expire (but less than cleanup interval)
time.Sleep(31 * time.Second)

// Get should detect expiration and remove entry
val, err = c.Get(ctx, "test")
assert.Equal(t, ErrCacheMiss, err)
assert.Nil(t, val)

// Verify entry was removed from cache
snapshot := c.snapshot()
assert.Empty(t, snapshot)
assert.Equal(t, map[string]cacheItem{}, snapshot)
})
})
}