diff --git a/Makefile b/Makefile index 5a22bbc..027b78e 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .SILENT: -.PHONY: lint test converage up +.PHONY: lint test race converage up lint: go tool -modfile=go.tool.mod golangci-lint run ./... @@ -7,6 +7,9 @@ lint: test: go test ./... -coverprofile cover.out +race: + go test ./... -race + coverage: go tool cover -html cover.out diff --git a/README.md b/README.md index edc6495..0ae5a15 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Running using Docker: $ docker compose up -d ``` -This will start the tgfeed server on port 8080 (can be changed via HTTP_SERVER_PORT environment variable) and a Redis instance for caching. +This will start the tgfeed server on port 8080 (can be changed via HTTP_SERVER_PORT environment variable). ## API Endpoints @@ -65,4 +65,4 @@ Replace `channelname` with the username of the Telegram channel you want to foll ## Docker Compose -The service is preconfigured with Redis for caching. You can customize the configuration through environment variables in the `compose.yaml` file. +The service can be run using Docker Compose. Customize the configuration through environment variables in the `compose.yaml` file. You can uncomment some config values there if you want to keep cache in Redis. Otherwise it will be kept in RAM (by default). \ No newline at end of file diff --git a/cmd/tgfeed/main.go b/cmd/tgfeed/main.go index c485d9b..29e6a63 100644 --- a/cmd/tgfeed/main.go +++ b/cmd/tgfeed/main.go @@ -44,10 +44,6 @@ func main() { redisHost := os.Getenv("REDIS_HOST") - if redisHost == "" { - redisHost = "redis" - } - // Configure IP filtering allowedIPsStr := os.Getenv("ALLOWED_IPS") trustProxy := os.Getenv("REVERSE_PROXY") == "true" || os.Getenv("REVERSE_PROXY") == "1" @@ -66,26 +62,31 @@ func main() { logger.Info("IP filtering enabled", "allowed_ips", allowedIPsStr, "trust_proxy", trustProxy) } - // Initialize Redis cache - redisClient, err := cache.NewRedisClient(ctx, fmt.Sprintf("%s:6379", redisHost)) + var c cache.Cache - if err != nil { - logger.Error("Failed to connect to Redis", "error", err) - os.Exit(1) + if redisHost == "" { + c = cache.NewMemoryClient() + } else { + redisClient, err := cache.NewRedisClient(ctx, fmt.Sprintf("%s:6379", redisHost)) + + if err != nil { + logger.Error("Failed to connect to Redis", "error", err) + os.Exit(1) + } + + c = redisClient } - defer redisClient.Close() + defer c.Close() scraper := feed.NewDefaultScraper() generator := feed.NewGenerator() // Initialize and run the HTTP server - server := rest.NewServer(redisClient, scraper, generator, ipFilter, port) + server := rest.NewServer(c, scraper, generator, ipFilter, port) if err := server.Run(ctx); err != nil { logger.Error("Server error", "error", err) os.Exit(1) } - - logger.Info("Server exited gracefully") } diff --git a/compose.yaml b/compose.yaml index 2e64d00..bf49d4b 100644 --- a/compose.yaml +++ b/compose.yaml @@ -5,7 +5,8 @@ services: environment: - TZ=Europe/Moscow - HTTP_SERVER_PORT=8080 - - REDIS_HOST=redis + # Uncomment the variable if you want to keep cache in Redis + # - REDIS_HOST=redis # You can specify a custom HTML message for cases when the scraper # could not obtain the post content from t.me. # Use {postDeepLink} and {postURL} as placeholders for post links. @@ -23,26 +24,27 @@ services: # - REVERSE_PROXY=true ports: - 8080:8080 - depends_on: - - redis + # Uncomment depends_on if you want to keep cache in Redis + # depends_on: + # - redis restart: unless-stopped +# Uncomment everything below if you want to keep cache in Redis +# redis: +# container_name: redis +# image: redis:alpine +# environment: +# - TZ=Europe/Moscow +# ports: +# - 6379:6379 +# volumes: +# - redis-data:/data +# healthcheck: +# test: ["CMD", "redis-cli", "ping"] +# interval: 30s +# timeout: 10s +# retries: 5 +# start_period: 5s +# restart: unless-stopped - redis: - container_name: redis - image: redis:alpine - environment: - - TZ=Europe/Moscow - ports: - - 6379:6379 - volumes: - - redis-data:/data - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 30s - timeout: 10s - retries: 5 - start_period: 5s - restart: unless-stopped - -volumes: - redis-data: +# volumes: +# redis-data: diff --git a/internal/cache/cache.go b/internal/cache/cache.go index a8bb11b..c451ce8 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -2,9 +2,13 @@ package cache import ( "context" + "errors" "time" ) +// ErrCacheMiss is returned when a key is not found in the cache +var ErrCacheMiss = errors.New("cache miss") + // Cache defines the interface for caching data type Cache interface { // Get retrieves a value from the cache diff --git a/internal/cache/memory.go b/internal/cache/memory.go new file mode 100644 index 0000000..c01ce5c --- /dev/null +++ b/internal/cache/memory.go @@ -0,0 +1,76 @@ +package cache + +import ( + "context" + "maps" + "sync" + "time" +) + +// MemoryCache implements the Cache interface using RAM +type MemoryCache struct { + cache map[string][]byte + mu sync.RWMutex +} + +// NewMemoryClient creates a new cache client +func NewMemoryClient() *MemoryCache { + cache := make(map[string][]byte, 100) + + return &MemoryCache{cache: cache} +} + +// Get retrieves a value from memory +func (c *MemoryCache) Get(_ context.Context, key string) ([]byte, error) { + c.mu.RLock() + val, exists := c.cache[key] + c.mu.RUnlock() + + if !exists { + return nil, ErrCacheMiss + } + + return val, nil +} + +// Set stores a value in memory with the specified TTL +// If ttl is 0, the value will not be cached +func (c *MemoryCache) Set(_ context.Context, key string, value []byte, ttl time.Duration) error { + if ttl == 0 { + return nil // Skip caching if TTL is 0 + } + + c.mu.Lock() + c.cache[key] = value + c.mu.Unlock() + + go func() { + time.Sleep(ttl) + c.mu.Lock() + delete(c.cache, key) + c.mu.Unlock() + }() + + return nil +} + +// Close releases the memory +func (c *MemoryCache) Close() error { + c.mu.Lock() + clear(c.cache) + c.mu.Unlock() + + return nil +} + +// snapshot returns a copy of the cache for testing purposes +func (c *MemoryCache) snapshot() map[string][]byte { + c.mu.RLock() + defer c.mu.RUnlock() + + result := make(map[string][]byte, len(c.cache)) + + maps.Copy(result, c.cache) + + return result +} diff --git a/internal/cache/memory_test.go b/internal/cache/memory_test.go new file mode 100644 index 0000000..022b74e --- /dev/null +++ b/internal/cache/memory_test.go @@ -0,0 +1,202 @@ +package cache + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMemoryCacheWrite(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + prepare func() *MemoryCache + want map[string][]byte + }{ + { + name: "should write a new entry in empty cache", + prepare: func() *MemoryCache { + c := NewMemoryClient() + + if err := c.Set(ctx, "test", []byte{42}, 100*time.Millisecond); err != nil { + t.Error(err) + } + + return c + }, + want: map[string][]byte{"test": {42}}, + }, + { + name: "should write a new entry in non-empty cache", + prepare: func() *MemoryCache { + c := NewMemoryClient() + + if err := c.Set(ctx, "test1", []byte{42}, 100*time.Millisecond); err != nil { + t.Error(err) + } + + if err := c.Set(ctx, "test2", []byte{42}, 100*time.Millisecond); err != nil { + t.Error(err) + } + + return c + }, + want: map[string][]byte{"test1": {42}, "test2": {42}}, + }, + { + name: "should skip caching if TTL is 0", + prepare: func() *MemoryCache { + c := NewMemoryClient() + + if err := c.Set(ctx, "test", []byte{42}, 0); err != nil { + t.Error(err) + } + + return c + }, + want: map[string][]byte{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := tt.prepare() + snapshot := c.snapshot() + + assert.Equal(t, len(tt.want), len(snapshot)) + assert.Equal(t, tt.want, snapshot) + }) + } +} + +func TestMemoryCacheRead(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + key string + prepare func() *MemoryCache + want []byte + err error + }{ + { + name: "should read from cache", + key: "test", + prepare: func() *MemoryCache { + c := NewMemoryClient() + + if err := c.Set(ctx, "test", []byte{42}, 100*time.Millisecond); err != nil { + t.Error(err) + } + + return c + }, + want: []byte{42}, + err: nil, + }, + { + name: "should respond with ErrCacheMiss when no such entry exists", + prepare: func() *MemoryCache { + c := NewMemoryClient() + + if err := c.Set(ctx, "test1", []byte{42}, 100*time.Millisecond); err != nil { + t.Error(err) + } + + if err := c.Set(ctx, "test2", []byte{42}, 100*time.Millisecond); err != nil { + t.Error(err) + } + + return c + }, + want: nil, + err: ErrCacheMiss, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := tt.prepare() + v, err := c.Get(ctx, tt.key) + + assert.Equal(t, tt.want, v) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestMemoryCacheClose(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + prepare func() *MemoryCache + want map[string][]byte + }{ + { + name: "should clear cache on close", + prepare: func() *MemoryCache { + c := NewMemoryClient() + + if err := c.Set(ctx, "test", []byte{42}, 100*time.Millisecond); err != nil { + t.Error(err) + } + + return c + }, + want: map[string][]byte{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := tt.prepare() + err := c.Close() + require.NoError(t, err) + snapshot := c.snapshot() + + assert.Equal(t, len(tt.want), len(snapshot)) + assert.Equal(t, tt.want, snapshot) + }) + } +} + +func TestMemoryCacheTTL(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + prepare func() *MemoryCache + want map[string][]byte + }{ + { + name: "should remove a cache entry after TTL", + prepare: func() *MemoryCache { + c := NewMemoryClient() + + if err := c.Set(ctx, "test", []byte{42}, 5*time.Millisecond); err != nil { + t.Error(err) + } + + return c + }, + want: map[string][]byte{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := tt.prepare() + + time.Sleep(10 * time.Millisecond) + + snapshot := c.snapshot() + assert.Equal(t, len(tt.want), len(snapshot)) + assert.Equal(t, tt.want, snapshot) + }) + } +} diff --git a/internal/cache/redis.go b/internal/cache/redis.go index d5f26cb..9d52e09 100644 --- a/internal/cache/redis.go +++ b/internal/cache/redis.go @@ -2,15 +2,11 @@ package cache import ( "context" - "errors" "time" "github.com/redis/go-redis/v9" ) -// ErrCacheMiss is returned when a key is not found in the cache -var ErrCacheMiss = errors.New("cache miss") - // RedisCache implements the Cache interface using Redis type RedisCache struct { client *redis.Client