diff --git a/.gitignore b/.gitignore index 7e648c3b..e130cc1d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ dist/ .idea/ .vscode/ workspace/ +.vscode/ +.cursor/ __pycache__/ java-sdk/onfs/feature-store-client/target/ java-sdk/onfs/feature-store-flink-connector-sdk-flink1x/target/ @@ -15,6 +17,15 @@ java-sdk/onfs/feature-store-core/target/ .internal-configs/ .dev-toggle-state .go.mod.appended + + +flashring/performance_results.csv +flashring/mem.prof +flashring/flashring +flashring/flashringtest + + + horizon/pem/*.pem horizon/pem/*.key horizon/configs/*.pem diff --git a/flashring/README.md b/flashring/README.md new file mode 100644 index 00000000..f006c3f6 --- /dev/null +++ b/flashring/README.md @@ -0,0 +1,461 @@ +# High-Performance Append-Only File Writing Benchmarks + +This package provides comprehensive benchmarks for append-only file writing in Go, focusing on maximum throughput and optimal page-aligned buffering strategies. + +## Features + +- **Page-Aligned Buffering**: Custom buffer implementation that flushes only when page boundaries are reached +- **Multiple Buffer Sizes**: Tests with 4KB, 8KB, 16KB, and 64KB buffers aligned to system page sizes +- **Memory-Mapped I/O**: Uses mmap for ultra-fast sequential writes +- **Direct Write Comparison**: Benchmarks unbuffered writes for baseline comparison +- **Concurrent Write Testing**: Thread-safe concurrent write benchmarks +- **Multiple Record Sizes**: Tests with small (128B), medium (1KB), and large (8KB) records + +## Quick Start + +### Run Visual Benchmarks +```bash +go run main.go +``` + +This will run comprehensive benchmarks showing: +- Throughput in MB/s +- Records per second +- Duration comparisons +- Performance recommendations + +## Test Results & Analysis + +### Hardware Configuration +- **CPU**: AMD Ryzen 7 9800X3D 8-Core Processor +- **OS**: Linux (kernel 6.11.0-26-generic) +- **Go Version**: 1.22.12 +- **Architecture**: amd64 +- **Storage**: SSD with ext4 filesystem + +### Visual 
Benchmark Results + +``` +=== Append-Only File Writing Benchmarks === + +=== Small Records (128B x 100K) === +Method : Duration | MB/s | Records/s | Total MB +-------------------------------------------------------------------------------- +Direct Write : 50.8ms | 240.07 | 1,966,655 | 12.21 +Buffered (4K) : 9.6ms | 1,266.93 | 10,378,707 | 12.21 +Buffered (8K) : 9.1ms | 1,337.27 | 10,954,887 | 12.21 +Buffered (16K) : 9.2ms | 1,327.55 | 10,875,326 | 12.21 +Buffered (64K) : 8.6ms | 1,415.92 | 11,599,245 | 12.21 +Page-Aligned (4K) : 10.5ms | 1,165.22 | 9,545,493 | 12.21 +Page-Aligned (8K) : 9.8ms | 1,244.86 | 10,197,862 | 12.21 +Page-Aligned (16K) : 10.4ms | 1,176.88 | 9,641,008 | 12.21 +Page-Aligned (64K) : 9.5ms | 1,281.76 | 10,500,163 | 12.21 +Memory Mapped : 10.4ms | 1,168.32 | 9,570,867 | 12.21 + +=== Medium Records (1KB x 50K) === +Method : Duration | MB/s | Records/s | Total MB +-------------------------------------------------------------------------------- +Direct Write : 43.1ms | 1,134.06 | 1,161,276 | 48.83 +Buffered (4K) : 24.1ms | 2,025.50 | 2,074,108 | 48.83 +Buffered (8K) : 21.1ms | 2,308.94 | 2,364,359 | 48.83 +Buffered (16K) : 19.8ms | 2,464.45 | 2,523,597 | 48.83 +Buffered (64K) : 19.9ms | 2,458.15 | 2,517,143 | 48.83 +Page-Aligned (4K) : 24.8ms | 1,970.50 | 2,017,793 | 48.83 +Page-Aligned (8K) : 21.6ms | 2,262.77 | 2,317,076 | 48.83 +Page-Aligned (16K) : 21.1ms | 2,311.49 | 2,366,963 | 48.83 +Page-Aligned (64K) : 19.5ms | 2,499.25 | 2,559,228 | 48.83 +Memory Mapped : 23.8ms | 2,054.37 | 2,103,677 | 48.83 + +=== Large Records (8KB x 10K) === +Method : Duration | MB/s | Records/s | Total MB +-------------------------------------------------------------------------------- +Direct Write : 31.3ms | 2,496.41 | 319,540 | 78.12 +Buffered (4K) : 31.9ms | 2,450.08 | 313,610 | 78.12 +Buffered (8K) : 32.8ms | 2,384.48 | 305,213 | 78.12 +Buffered (16K) : 30.6ms | 2,551.66 | 326,613 | 78.12 +Buffered (64K) : 29.0ms | 2,693.30 | 344,743 | 78.12 +Page-Aligned (4K) : 
31.6ms | 2,473.40 | 316,595 | 78.12 +Page-Aligned (8K) : 31.8ms | 2,457.32 | 314,537 | 78.12 +Page-Aligned (16K) : 30.3ms | 2,576.79 | 329,829 | 78.12 +Page-Aligned (64K) : 29.4ms | 2,655.21 | 339,867 | 78.12 +Memory Mapped : 35.4ms | 2,207.78 | 282,596 | 78.12 +``` + +### Go Benchmark Results + +``` +goos: linux +goarch: amd64 +pkg: github.com/Meesho/BharatMLStack/ssd-cache +cpu: AMD Ryzen 7 9800X3D 8-Core Processor + +BenchmarkDirectWrite-8 2359388 513.5 ns/op 1994.02 MB/s 0 B/op 0 allocs/op +BenchmarkPageAligned4K-8 4910527 238.6 ns/op 4290.94 MB/s 0 B/op 0 allocs/op +BenchmarkPageAligned16K-8 6308680 188.0 ns/op 5446.73 MB/s 0 B/op 0 allocs/op +BenchmarkPageAligned64K-8 6850387 176.4 ns/op 5803.96 MB/s 0 B/op 0 allocs/op +BenchmarkMemoryMapped-8 4761464 246.8 ns/op 4148.75 MB/s 0 B/op 0 allocs/op + +BenchmarkSmallRecords/DirectWrite-8 3071392 387.8 ns/op 330.08 MB/s 0 B/op 0 allocs/op +BenchmarkSmallRecords/PageAligned16K-8 36121743 32.68 ns/op 3916.19 MB/s 0 B/op 0 allocs/op +BenchmarkMediumRecords/DirectWrite-8 2346501 516.5 ns/op 1982.42 MB/s 0 B/op 0 allocs/op +BenchmarkMediumRecords/PageAligned16K-8 6304753 188.8 ns/op 5422.59 MB/s 0 B/op 0 allocs/op +BenchmarkLargeRecords/DirectWrite-8 710790 1514 ns/op 5409.65 MB/s 0 B/op 0 allocs/op +BenchmarkLargeRecords/PageAligned16K-8 757474 1431 ns/op 5723.57 MB/s 0 B/op 0 allocs/op +BenchmarkConcurrentWrites-8 5787453 204.3 ns/op 5012.58 MB/s 0 B/op 0 allocs/op +``` + +### Performance Analysis + +#### Key Findings + +1. **Page-Aligned Buffers Dominate**: The page-aligned 64KB buffer achieved the highest throughput at **5,803.96 MB/s** +2. **Buffer Size Sweet Spot**: 16KB-64KB buffers provide optimal performance across all record sizes +3. **Zero Memory Allocations**: All implementations achieve zero heap allocations per operation +4. 
**Consistent Performance**: Page-aligned buffers maintain high performance across different record sizes + +#### Record Size Impact + +| Record Size | Best Method | Peak Throughput | Performance Gain vs Direct | +|-------------|-------------|-----------------|----------------------------| +| Small (128B) | Buffered 64K | 1,415.92 MB/s | **5.9x faster** | +| Medium (1KB) | Page-Aligned 64K | 2,499.25 MB/s | **2.2x faster** | +| Large (8KB) | Buffered 64K | 2,693.30 MB/s | **1.08x faster** | + +#### Latency Analysis (from Go benchmarks) + +- **Direct Write**: 513.5 ns/op (baseline) +- **Page-Aligned 16K**: 188.0 ns/op (**2.7x faster**) +- **Page-Aligned 64K**: 176.4 ns/op (**2.9x faster**) +- **Small Records (Page-Aligned 16K)**: 32.68 ns/op (**15.7x faster** than the 513.5 ns/op baseline; **11.9x faster** than the 387.8 ns/op small-record direct write) + +#### Scalability Characteristics + +1. **Small Records**: Page-aligned buffers show dramatic improvement (5-15x) +2. **Medium Records**: Consistent 2-3x improvement across all buffered methods +3. **Large Records**: Diminishing returns as record size approaches buffer size +4. 
**Concurrent Writes**: Thread-safe implementation maintains high throughput (5,012 MB/s) + +#### Technical Insights + +**Why Page-Aligned Buffers Win:** +- **Reduced System Calls**: Buffer aggregation minimizes expensive kernel transitions +- **Cache Line Efficiency**: Page-aligned memory access patterns optimize CPU cache usage +- **Filesystem Optimization**: Writes aligned to filesystem block boundaries reduce overhead +- **Memory Management**: Eliminates heap allocations through pre-allocated buffers + +**Buffer Size Analysis:** +- **4KB**: Matches most filesystem page sizes, good baseline performance +- **16KB**: Sweet spot for balanced throughput and memory usage +- **64KB**: Maximum throughput but higher memory consumption +- **Beyond 64KB**: Diminishing returns due to cache pressure + +**Record Size Effects:** +- **Small Records (128B)**: Massive gains from batching (up to 15x improvement) +- **Medium Records (1KB)**: Strong benefits from reduced syscall overhead +- **Large Records (8KB)**: Minimal gains as records approach buffer size + +#### Production Recommendations + +**For High-Throughput Applications:** +```go +// Optimal configuration for maximum throughput +writer := NewPageAlignedBuffer("data.log", PageSize64K) +defer writer.Close() + +// Batch small records for maximum efficiency +batch := make([]byte, 0, 8192) +for record := range records { + batch = append(batch, record...) 
+ if len(batch) >= 8192 { + writer.Write(batch) + batch = batch[:0] + } +} +``` + +**For Low-Latency Applications:** +```go +// Balance between throughput and latency +writer := NewPageAlignedBuffer("events.log", PageSize16K) +defer writer.Close() + +// Periodic flushes for guaranteed durability +ticker := time.NewTicker(100 * time.Millisecond) +go func() { + for range ticker.C { + writer.Sync() + } +}() +``` + +**Memory vs Performance Trade-offs:** + +| Buffer Size | Memory Usage | Throughput | Best For | +|-------------|--------------|------------|----------| +| 4KB | 4KB per writer | Good | Memory-constrained | +| 16KB | 16KB per writer | **Optimal** | **General purpose** | +| 64KB | 64KB per writer | Maximum | Bulk ingestion | + +## FUSE Filesystem Analysis + +### Can FUSE Improve Performance? + +**Short Answer: Usually No** - FUSE typically **reduces** performance for append-only workloads due to context switching overhead. + +### FUSE Performance Impact + +| Aspect | Impact | Reason | +|--------|--------|--------| +| **Context Switches** | -50-200μs per operation | Kernel ↔ Userspace transitions | +| **Data Copying** | -10-50μs per MB | Additional memory copies | +| **System Call Overhead** | -1-5μs per call | Extra syscalls in pipeline | +| **Overall Performance** | **3-5x slower** | Cumulative overhead | + +### When FUSE Might Help + +FUSE becomes beneficial when you need: + +1. **Custom Compression** (compression ratio > 3:1) +```go +// FUSE with transparent compression +compressed := compress(data) // Saves 3x storage I/O +backingFile.Write(compressed) // Compensates for FUSE overhead +``` + +2. **Specialized Storage Formats** +```go +// Convert row-based to columnar storage +columns := convertToColumns(records) +writeColumnarData(columns) // Optimized for analytics +``` + +3. 
**Network Storage Optimization** +```go +// Batch operations for network efficiency +batch := accumulate(data) +sendBatchAsync(compress(batch)) // Reduces network round-trips +``` + +4. **Multi-tier Storage Management** +```go +// Intelligent data placement +if isHotData(data) { + writeSSD(data) +} else { + writeToCloud(compress(data)) +} +``` + +### Performance Comparison + +Based on our benchmarks: + +| Method | Throughput | Best Use Case | +|--------|------------|---------------| +| **Direct Write** | 1,134 MB/s | Simple baseline | +| **Page-Aligned 16K** | **2,311 MB/s** | **Recommended** | +| **Memory Mapped** | 2,054 MB/s | Large sequential | +| **FUSE Basic** | ~400 MB/s | ❌ Not recommended | +| **FUSE + Compression** | ~800 MB/s | High compression ratios only | + +### Recommendation + +**For pure append-only performance**: Use **PageAlignedBuffer** - it's 2-3x faster than direct writes and 5-6x faster than FUSE. + +**Consider FUSE only when**: +- You need data transformation (compression, encryption, format conversion) +- Working with network storage where batching helps +- Building storage abstraction layers + +See `FUSE_ANALYSIS.md` for detailed technical analysis. + +### Run Go Benchmarks +```bash +# Run all benchmarks +go test -bench=. + +# Run specific benchmark +go test -bench=BenchmarkPageAligned16K + +# Run with memory profiling +go test -bench=. -memprofile=mem.prof + +# Run with CPU profiling +go test -bench=. -cpuprofile=cpu.prof + +# Detailed benchmark with allocations +go test -bench=. -benchmem +``` + +## Architecture Components + +### 1. 
PageAlignedBuffer +Custom buffered writer that: +- Maintains internal buffer aligned to page boundaries +- Flushes only when buffer reaches capacity or explicitly requested +- Thread-safe with mutex protection +- Optimized for sequential append operations + +```go +writer, err := NewPageAlignedBuffer("file.log", PageSize16K) +defer writer.Close() + +// Writes are buffered until page boundary +writer.Write(data) +writer.Sync() // Flush and fsync to disk +``` + +### 2. Memory-Mapped Writer +Uses `mmap()` system call for: +- Zero-copy writes directly to memory +- Kernel-managed page cache optimization +- Efficient for large sequential writes + +```go +writer, err := NewMemoryMappedWriter("file.log", totalSize) +defer writer.Close() + +writer.Write(data) // Writes directly to mapped memory +writer.Sync() // Sync to disk with msync() +``` + +### 3. Direct Writer +Baseline implementation for comparison: +- No buffering - each write goes directly to kernel +- Useful for understanding buffering benefits +- Higher syscall overhead but guaranteed write ordering + +## Performance Optimization Strategies + +### Buffer Size Selection +- **4KB-8KB**: Best for low-latency applications requiring frequent flushes +- **16KB-32KB**: Optimal for most high-throughput workloads +- **64KB+**: Best for bulk data ingestion with less frequent syncing + +### Write Pattern Optimization +1. **Batch Small Writes**: Accumulate small records before writing +2. **Align to Page Boundaries**: Use page-sized buffers (4KB multiples) +3. **Minimize Sync Calls**: Only sync when durability is required +4. 
**Pre-allocate Files**: Use `fallocate()` to pre-allocate disk space + +### System-Level Optimizations +```bash +# Disable file access time updates +mount -o noatime,nodiratime /dev/sda1 /data + +# Increase write buffer sizes +echo 'vm.dirty_ratio = 40' >> /etc/sysctl.conf +echo 'vm.dirty_background_ratio = 10' >> /etc/sysctl.conf + +# Use deadline I/O scheduler for sequential writes +echo deadline > /sys/block/sda/queue/scheduler +``` + +## Benchmark Results Analysis + +### Expected Performance Characteristics + +| Method | Throughput | Latency | CPU Usage | Use Case | +|--------|------------|---------|-----------|----------| +| Direct Write | Low | High | Low | Strict ordering | +| Buffered 4K | Medium | Medium | Medium | Balanced | +| Page-Aligned 16K | High | Low | Medium | High throughput | +| Memory Mapped | Highest | Lowest | Highest | Bulk ingestion | + +### Platform-Specific Considerations + +**SSD Storage:** +- Page-aligned buffers show 3-5x improvement over direct writes +- Memory mapping excels for large sequential writes +- 16KB-32KB buffers provide optimal throughput + +**HDD Storage:** +- Larger buffers (64KB+) reduce seek overhead +- Sequential write patterns are crucial +- Pre-allocation reduces fragmentation + +**Network Storage (NFS/CIFS):** +- Larger buffers reduce network round-trips +- Memory mapping may not provide benefits +- Consider async write modes + +## Advanced Usage + +### Custom Record Format +```go +type LogRecord struct { + Timestamp int64 + Level uint8 + Message []byte +} + +func (r *LogRecord) Marshal() []byte { + // Custom serialization optimized for append-only writes +} +``` + +### Batch Writing +```go +writer := NewPageAlignedBuffer("batch.log", PageSize16K) +defer writer.Close() + +// Accumulate records until page boundary +var batch []byte +for record := range records { + batch = append(batch, record.Marshal()...) 
+ if len(batch) >= PageSize4K { + writer.Write(batch) + batch = batch[:0] // Reset slice + } +} +``` + +### Error Recovery +```go +if err := writer.Write(data); err != nil { + // Log error but continue - append-only design allows recovery + log.Printf("Write failed: %v", err) + + // Attempt to sync partial data + if syncErr := writer.Sync(); syncErr != nil { + log.Printf("Sync failed: %v", syncErr) + } +} +``` + +## Monitoring and Metrics + +### Key Performance Indicators +- **Write Throughput**: MB/s sustained write rate +- **Write Latency**: p99 latency for individual writes +- **Buffer Efficiency**: Ratio of buffered to direct writes +- **Disk Utilization**: IOPS and queue depth +- **Memory Usage**: Buffer memory and page cache + +### Profiling Integration +```bash +# CPU profiling +go test -bench=BenchmarkPageAligned16K -cpuprofile=cpu.prof +go tool pprof cpu.prof + +# Memory profiling +go test -bench=BenchmarkMemoryMapped -memprofile=mem.prof +go tool pprof mem.prof + +# Trace analysis +go test -bench=. -trace=trace.out +go tool trace trace.out +``` + +## Contributing + +When adding new benchmarks: +1. Follow the Go naming convention `BenchmarkXxx` (for example `BenchmarkPageAligned16K`) +2. Use `b.SetBytes()` to report throughput +3. Reset timers appropriately with `b.ResetTimer()` +4. Clean up test files with `defer os.Remove()` +5. Test on multiple platforms (Linux, macOS, Windows) + +## License + +This benchmark suite is part of the BharatMLStack project and follows the same licensing terms. 
\ No newline at end of file diff --git a/flashring/cmd/flashringtest/main.go b/flashring/cmd/flashringtest/main.go new file mode 100644 index 00000000..39379189 --- /dev/null +++ b/flashring/cmd/flashringtest/main.go @@ -0,0 +1,277 @@ +package main + +import ( + "flag" + "fmt" + "math/bits" + "math/rand" + "net/http" + "os" + "runtime" + "runtime/pprof" + "sync/atomic" + "time" + + _ "net/http/pprof" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +// normalDistInt returns an integer in [0, max) following a normal distribution +// centered at max/2 with standard deviation = max/8 +func normalDistInt(max int) int { + if max <= 0 { + return 0 + } + mean := float64(max) / 2.0 + stdDev := float64(max) / 8.0 + for { + val := rand.NormFloat64()*stdDev + mean + if val >= 0 && val < float64(max) { + return int(val) + } + } +} + +// normalDistIntPartitioned returns an integer following a normal distribution +// constrained to a specific worker's partition of the total key space. +func normalDistIntPartitioned(workerID, numWorkers, totalKeys int) int { + if totalKeys <= 0 || numWorkers <= 0 { + return 0 + } + partitionSize := totalKeys / numWorkers + partitionStart := workerID * partitionSize + partitionEnd := partitionStart + partitionSize + if workerID == numWorkers-1 { + partitionEnd = totalKeys + } + mean := float64(totalKeys) / 2.0 + stdDev := float64(totalKeys) / 8.0 + for { + val := rand.NormFloat64()*stdDev + mean + if val >= float64(partitionStart) && val < float64(partitionEnd) { + return int(val) + } + } +} + +// ---- Shared metrics & profiling infrastructure ---- + +const histBuckets = 32 + +type opMetrics struct { + count atomic.Int64 + totalNs atomic.Int64 + minNs atomic.Int64 + maxNs atomic.Int64 + hist [histBuckets]atomic.Int64 +} + +func (m *opMetrics) record(d time.Duration) { + ns := d.Nanoseconds() + if ns <= 0 { + ns = 1 + } + m.count.Add(1) + m.totalNs.Add(ns) + + bucket := bits.Len64(uint64(ns)) - 1 + if bucket >= histBuckets { + bucket 
= histBuckets - 1 + } + m.hist[bucket].Add(1) + + for { + cur := m.minNs.Load() + if cur != 0 && cur <= ns { + break + } + if m.minNs.CompareAndSwap(cur, ns) { + break + } + } + for { + cur := m.maxNs.Load() + if cur >= ns { + break + } + if m.maxNs.CompareAndSwap(cur, ns) { + break + } + } +} + +func (m *opMetrics) percentile(p float64) time.Duration { + total := m.count.Load() + if total == 0 { + return 0 + } + threshold := int64(float64(total)*p/100.0 + 0.5) + var cumulative int64 + for i := 0; i < histBuckets; i++ { + cumulative += m.hist[i].Load() + if cumulative >= threshold { + return time.Duration(int64(1) << i) + } + } + return time.Duration(m.maxNs.Load()) +} + +func (m *opMetrics) snapshot() (count int64, avg, min, max, p50, p99 time.Duration) { + count = m.count.Load() + if count == 0 { + return + } + avg = time.Duration(m.totalNs.Load() / count) + min = time.Duration(m.minNs.Load()) + max = time.Duration(m.maxNs.Load()) + p50 = m.percentile(50) + p99 = m.percentile(99) + return +} + +type loadMetrics struct { + getMetrics opMetrics + putMetrics opMetrics + prepopulatePutMetrics opMetrics + getHits atomic.Int64 + getMisses atomic.Int64 + getExpired atomic.Int64 +} + +func printOpLine(name string, m *opMetrics) { + count, avg, min, max, p50, p99 := m.snapshot() + fmt.Printf("%-5s count=%-12d\n", name, count) + if count > 0 { + fmt.Printf(" avg=%-14s min=%-14s max=%-14s p50=%-14s p99=%-14s\n", avg, min, max, p50, p99) + } +} + +func (lm *loadMetrics) printStats(label string) { + gc, _, _, _, _, _ := lm.getMetrics.snapshot() + fmt.Printf("\n===== %s =====\n", label) + fmt.Printf("GET count=%-12d hits=%-12d misses=%-12d expired=%-12d\n", + gc, lm.getHits.Load(), lm.getMisses.Load(), lm.getExpired.Load()) + if gc > 0 { + printOpLine("GET", &lm.getMetrics) + } + printOpLine("PUT", &lm.putMetrics) + printOpLine("PREPOP", &lm.prepopulatePutMetrics) + fmt.Println() +} + +// commonFlags holds the shared flags across all plans. 
+type commonFlags struct { + mountPoint string + numShards int + keysPerShard int + memtableMB int + fileSizeMultiplier float64 + readWorkers int + writeWorkers int + sampleSecs int + iterations int64 + logStats bool + memProfile string + cpuProfile string +} + +func (f *commonFlags) register(fs *flag.FlagSet, defaults commonFlags) { + fs.StringVar(&f.mountPoint, "mount", defaults.mountPoint, "data directory for shard files") + fs.IntVar(&f.numShards, "shards", defaults.numShards, "number of shards") + fs.IntVar(&f.keysPerShard, "keys-per-shard", defaults.keysPerShard, "keys per shard") + fs.IntVar(&f.memtableMB, "memtable-mb", defaults.memtableMB, "memtable size in MiB") + fs.Float64Var(&f.fileSizeMultiplier, "file-size-multiplier", defaults.fileSizeMultiplier, "file size in GiB per shard") + fs.IntVar(&f.readWorkers, "readers", defaults.readWorkers, "number of read workers") + fs.IntVar(&f.writeWorkers, "writers", defaults.writeWorkers, "number of write workers") + fs.IntVar(&f.sampleSecs, "sample-secs", defaults.sampleSecs, "predictor sampling window in seconds") + fs.Int64Var(&f.iterations, "iterations", defaults.iterations, "number of iterations") + fs.BoolVar(&f.logStats, "log-stats", defaults.logStats, "periodically log cache stats") + fs.StringVar(&f.memProfile, "memprofile", defaults.memProfile, "write memory profile to this file") + fs.StringVar(&f.cpuProfile, "cpuprofile", defaults.cpuProfile, "write cpu profile to this file") +} + +func (f *commonFlags) memtableSizeBytes() int32 { + return int32(f.memtableMB) * 1024 * 1024 +} + +func (f *commonFlags) fileSizeBytes() int64 { + return int64(f.fileSizeMultiplier * 1024 * 1024 * 1024) +} + +// setupProfiling starts pprof, CPU profiling and returns a teardown function +// that writes the memory profile. 
+func setupProfiling(flags commonFlags) func() { + zerolog.SetGlobalLevel(zerolog.InfoLevel) + + go func() { + log.Info().Msg("Starting pprof server on :8080") + if err := http.ListenAndServe(":8080", nil); err != nil { + log.Error().Err(err).Msg("pprof server failed") + } + }() + + if flags.cpuProfile != "" { + f, err := os.Create(flags.cpuProfile) + if err != nil { + log.Fatal().Err(err).Msg("could not create CPU profile") + } + if err := pprof.StartCPUProfile(f); err != nil { + f.Close() + log.Fatal().Err(err).Msg("could not start CPU profile") + } + } + + return func() { + pprof.StopCPUProfile() + + if flags.memProfile != "" { + runtime.GC() + f, err := os.Create(flags.memProfile) + if err != nil { + log.Fatal().Err(err).Msg("could not create memory profile") + } + defer f.Close() + if err := pprof.WriteHeapProfile(f); err != nil { + log.Fatal().Err(err).Msg("could not write memory profile") + } + log.Info().Msgf("Memory profile written to %s", flags.memProfile) + } + + var m runtime.MemStats + runtime.ReadMemStats(&m) + log.Info(). + Str("alloc", fmt.Sprintf("%.2f MB", float64(m.Alloc)/1024/1024)). + Str("total_alloc", fmt.Sprintf("%.2f MB", float64(m.TotalAlloc)/1024/1024)). + Str("sys", fmt.Sprintf("%.2f MB", float64(m.Sys)/1024/1024)). + Uint32("num_gc", m.NumGC). 
+ Msg("Memory statistics") + } +} + +// ---- Plan registry ---- + +type plan func() + +var plans = map[string]plan{ + "freecache": planFreecache, + "readthrough": planReadthroughGaussian, + "random": planRandomGaussian, + "readthrough-batched": planReadthroughGaussianBatched, + "badger": planBadger, +} + +func main() { + name := os.Getenv("PLAN") + p, ok := plans[name] + if !ok { + fmt.Fprintf(os.Stderr, "unknown plan %q, available: ", name) + for k := range plans { + fmt.Fprintf(os.Stderr, "%s ", k) + } + fmt.Fprintln(os.Stderr) + os.Exit(1) + } + p() +} diff --git a/flashring/cmd/flashringtest/mem.prof b/flashring/cmd/flashringtest/mem.prof new file mode 100644 index 00000000..f11189a6 Binary files /dev/null and b/flashring/cmd/flashringtest/mem.prof differ diff --git a/flashring/cmd/flashringtest/plan_badger.go b/flashring/cmd/flashringtest/plan_badger.go new file mode 100644 index 00000000..2c988e39 --- /dev/null +++ b/flashring/cmd/flashringtest/plan_badger.go @@ -0,0 +1,110 @@ +package main + +import ( + "flag" + "fmt" + "math/rand" + "strings" + "sync" + "time" + + cachepkg "github.com/Meesho/BharatMLStack/flashring/pkg/cache" + "github.com/rs/zerolog/log" +) + +func planBadger() { + var flags commonFlags + flags.register(flag.CommandLine, commonFlags{ + mountPoint: "/mnt/disks/nvme/badger", + numShards: 1, + keysPerShard: 20_000_000, + memtableMB: 16, + fileSizeMultiplier: 1, + readWorkers: 4, + writeWorkers: 4, + sampleSecs: 30, + iterations: 100_000_000, + logStats: true, + memProfile: "mem.prof", + }) + flag.Parse() + teardown := setupProfiling(flags) + defer teardown() + + cache, err := cachepkg.NewBadger(cachepkg.Config{}, flags.mountPoint) + if err != nil { + panic(err) + } + defer cache.Close() + + const multiplier = 300 + totalKeys := flags.keysPerShard * flags.numShards + str1kb := "%d" + strings.Repeat("a", 1024) + + missedKeyChanList := make([]chan int, flags.writeWorkers) + for i := range missedKeyChanList { + missedKeyChanList[i] = make(chan 
int) + } + + fmt.Println("----------------------------------------------prepopulating keys") + for k := 0; k < totalKeys; k++ { + if rand.Intn(100) < 30 { + continue + } + key := fmt.Sprintf("key%d", k) + val := []byte(fmt.Sprintf(str1kb, k)) + if err := cache.Put(key, val, time.Hour); err != nil { + panic(err) + } + if k%5_000_000 == 0 { + fmt.Printf("----------------------------------------------prepopulated %d keys\n", k) + } + } + + var wg, writeWg sync.WaitGroup + + if flags.writeWorkers > 0 { + fmt.Println("----------------------------------------------starting write workers") + writeWg.Add(flags.writeWorkers) + for w := 0; w < flags.writeWorkers; w++ { + go func(workerID int) { + defer writeWg.Done() + for mk := range missedKeyChanList[workerID] { + key := fmt.Sprintf("key%d", mk) + val := []byte(fmt.Sprintf(str1kb, mk)) + if err := cache.Put(key, val, time.Hour); err != nil { + panic(err) + } + } + }(w) + } + } + + if flags.readWorkers > 0 { + fmt.Println("----------------------------------------------reading keys") + wg.Add(flags.readWorkers) + for r := 0; r < flags.readWorkers; r++ { + go func(workerID int) { + defer wg.Done() + for k := 0; k < totalKeys*multiplier; k++ { + randomval := normalDistInt(totalKeys) + key := fmt.Sprintf("key%d", randomval) + _, found, expired := cache.Get(key) + + if !found { + missedKeyChanList[randomval%flags.writeWorkers] <- randomval + } + if expired { + panic("key expired") + } + if k%5_000_000 == 0 { + fmt.Printf("----------------------------------------------read %d keys %d readerid\n", k, workerID) + } + } + }(r) + } + } + + wg.Wait() + log.Info().Msg("done") +} diff --git a/flashring/cmd/flashringtest/plan_freecache.go b/flashring/cmd/flashringtest/plan_freecache.go new file mode 100644 index 00000000..8e417abb --- /dev/null +++ b/flashring/cmd/flashringtest/plan_freecache.go @@ -0,0 +1,110 @@ +package main + +import ( + "flag" + "fmt" + "math/rand" + "strings" + "sync" + "time" + + cachepkg 
"github.com/Meesho/BharatMLStack/flashring/pkg/cache" + "github.com/rs/zerolog/log" +) + +func planFreecache() { + var flags commonFlags + flags.register(flag.CommandLine, commonFlags{ + mountPoint: "/mnt/disks/nvme/", + numShards: 1, + keysPerShard: 20_000_000, + memtableMB: 16, + fileSizeMultiplier: 4, + readWorkers: 4, + writeWorkers: 4, + sampleSecs: 30, + iterations: 100_000_000, + logStats: true, + memProfile: "mem.prof", + }) + flag.Parse() + teardown := setupProfiling(flags) + defer teardown() + + cache, err := cachepkg.NewFreecache(int(flags.fileSizeBytes())) + if err != nil { + panic(err) + } + defer cache.Close() + + const multiplier = 300 + totalKeys := flags.keysPerShard * flags.numShards + str1kb := "%d" + strings.Repeat("a", 1024) + + missedKeyChanList := make([]chan int, flags.writeWorkers) + for i := range missedKeyChanList { + missedKeyChanList[i] = make(chan int) + } + + fmt.Println("----------------------------------------------prepopulating keys") + for k := 0; k < totalKeys; k++ { + if rand.Intn(100) < 30 { + continue + } + key := fmt.Sprintf("key%d", k) + val := []byte(fmt.Sprintf(str1kb, k)) + if err := cache.Put(key, val, time.Hour); err != nil { + panic(err) + } + if k%5_000_000 == 0 { + fmt.Printf("----------------------------------------------prepopulated %d keys\n", k) + } + } + + var wg, writeWg sync.WaitGroup + + if flags.writeWorkers > 0 { + fmt.Println("----------------------------------------------starting write workers") + writeWg.Add(flags.writeWorkers) + for w := 0; w < flags.writeWorkers; w++ { + go func(workerID int) { + defer writeWg.Done() + for mk := range missedKeyChanList[workerID] { + key := fmt.Sprintf("key%d", mk) + val := []byte(fmt.Sprintf(str1kb, mk)) + if err := cache.Put(key, val, time.Hour); err != nil { + panic(err) + } + } + }(w) + } + } + + if flags.readWorkers > 0 { + fmt.Println("----------------------------------------------reading keys") + wg.Add(flags.readWorkers) + for r := 0; r < flags.readWorkers; r++ 
{ + go func(workerID int) { + defer wg.Done() + for k := 0; k < totalKeys*multiplier; k++ { + randomval := normalDistInt(totalKeys) + key := fmt.Sprintf("key%d", randomval) + _, found, expired := cache.Get(key) + + if !found { + missedKeyChanList[randomval%flags.writeWorkers] <- randomval + } + if expired { + panic("key expired") + } + if k%5_000_000 == 0 { + fmt.Printf("----------------------------------------------read %d keys %d readerid\n", k, workerID) + } + } + }(r) + } + } + + wg.Wait() + log.Info().Msg("done") +} diff --git a/flashring/cmd/flashringtest/plan_random_gausian.go b/flashring/cmd/flashringtest/plan_random_gausian.go new file mode 100644 index 00000000..88ddf26e --- /dev/null +++ b/flashring/cmd/flashringtest/plan_random_gausian.go @@ -0,0 +1,112 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + cachepkg "github.com/Meesho/BharatMLStack/flashring/pkg/cache" + "github.com/rs/zerolog/log" +) + +func planRandomGaussian() { + var flags commonFlags + flags.register(flag.CommandLine, commonFlags{ + mountPoint: "/mnt/disks/nvme/", + numShards: 1, + keysPerShard: 20_000_000, + memtableMB: 16, + fileSizeMultiplier: 40, + readWorkers: 1, + writeWorkers: 1, + sampleSecs: 30, + iterations: 100_000_000, + logStats: true, + memProfile: "mem.prof", + }) + flag.Parse() + teardown := setupProfiling(flags) + defer teardown() + + files, err := os.ReadDir(flags.mountPoint) + if err != nil { + panic(err) + } + for _, file := range files { + os.Remove(filepath.Join(flags.mountPoint, file.Name())) + } + + cfg := cachepkg.Config{ + NumShards: flags.numShards, + KeysPerShard: flags.keysPerShard, + FileSize: flags.fileSizeBytes(), + MemtableSize: flags.memtableSizeBytes(), + ReWriteScoreThreshold: 0.8, + GridSearchEpsilon: 0.0001, + SampleDuration: time.Duration(flags.sampleSecs) * time.Second, + } + + pc, err := cachepkg.NewWrapCache(cfg, flags.mountPoint) + if err != nil { + panic(err) + } + defer pc.Close() + + 
const multiplier = 300 + totalKeys := flags.keysPerShard * flags.numShards + str1kb := "%d" + strings.Repeat("a", 1024) + + var wg sync.WaitGroup + + if flags.writeWorkers > 0 { + fmt.Println("----------------------------------------------writing keys") + wg.Add(flags.writeWorkers) + for w := 0; w < flags.writeWorkers; w++ { + go func(workerID int) { + defer wg.Done() + for k := 0; k < totalKeys*multiplier; k++ { + randomval := normalDistInt(totalKeys) + key := fmt.Sprintf("key%d", randomval) + val := []byte(fmt.Sprintf(str1kb, randomval)) + if err := pc.Put(key, val, 60*time.Minute); err != nil { + panic(err) + } + if k%5_000_000 == 0 { + fmt.Printf("----------------------------------------------wrote %d keys %d writerid\n", k, workerID) + } + } + }(w) + } + } + + if flags.readWorkers > 0 { + fmt.Println("----------------------------------------------reading keys") + wg.Add(flags.readWorkers) + for r := 0; r < flags.readWorkers; r++ { + go func(workerID int) { + defer wg.Done() + for k := 0; k < totalKeys*multiplier; k++ { + randomval := normalDistInt(totalKeys) + key := fmt.Sprintf("key%d", randomval) + val, found, expired := pc.Get(key) + if expired { + panic("key expired") + } + if found && string(val) != fmt.Sprintf(str1kb, randomval) { + panic("value mismatch") + } + if k%5_000_000 == 0 { + fmt.Printf("----------------------------------------------read %d keys %d readerid\n", k, workerID) + } + } + }(r) + } + } + + wg.Wait() + log.Info().Msg("done") +}
diff --git a/flashring/cmd/flashringtest/plan_readthrough_gausian.go b/flashring/cmd/flashringtest/plan_readthrough_gausian.go
new file mode 100644
index 00000000..3e71edb9
--- /dev/null
+++ b/flashring/cmd/flashringtest/plan_readthrough_gausian.go
@@ -0,0 +1,181 @@
+package main
+
+import (
+	"flag"
+	"fmt"
+	"math/rand"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"time"
+
+	cachepkg "github.com/Meesho/BharatMLStack/flashring/pkg/cache"
+	"github.com/rs/zerolog/log"
+)
+
+// planReadthroughGaussian runs a read-through load test against the wrap cache.
+//
+// It removes the entries directly under the mount point (best effort), builds
+// the cache from the registered flags and prepopulates the 10M key space,
+// skipping each key with 30% probability. Reader goroutines then draw keys via
+// normalDistIntPartitioned; every miss is handed to a write worker (selected by
+// key modulo writeWorkers) that Puts the value back. Metrics are printed every
+// 10 seconds and once more after all readers and writers have finished.
+//
+// It panics on setup errors and when a cached value does not match its key.
+func planReadthroughGaussian() {
+	var flags commonFlags
+	flags.register(flag.CommandLine, commonFlags{
+		mountPoint:         "/mnt/disks/nvme/",
+		numShards:          50,
+		keysPerShard:       6_00_000,
+		memtableMB:         2,
+		fileSizeMultiplier: 0.25,
+		readWorkers:        16,
+		writeWorkers:       16,
+		sampleSecs:         30,
+		iterations:         100_000_000,
+		logStats:           true,
+		memProfile:         "mem.prof",
+	})
+	flag.Parse()
+	teardown := setupProfiling(flags)
+	defer teardown()
+
+	files, err := os.ReadDir(flags.mountPoint)
+	if err != nil {
+		panic(err)
+	}
+	for _, file := range files {
+		os.Remove(filepath.Join(flags.mountPoint, file.Name()))
+	}
+
+	cfg := cachepkg.Config{
+		NumShards:             flags.numShards,
+		KeysPerShard:          flags.keysPerShard,
+		FileSize:              flags.fileSizeBytes(),
+		MemtableSize:          flags.memtableSizeBytes(),
+		ReWriteScoreThreshold: 0.8,
+		GridSearchEpsilon:     0.0001,
+		SampleDuration:        time.Duration(flags.sampleSecs) * time.Second,
+	}
+
+	pc, err := cachepkg.NewWrapCache(cfg, flags.mountPoint)
+	if err != nil {
+		panic(err)
+	}
+	defer pc.Close()
+
+	const multiplier = 300
+	totalKeys := 10_000_000
+	str1kb := "%d" + strings.Repeat("a", 1024)
+
+	var metrics loadMetrics
+	stopReporter := make(chan struct{})
+	go func() {
+		ticker := time.NewTicker(10 * time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				metrics.printStats("PERIODIC")
+			case <-stopReporter:
+				return
+			}
+		}
+	}()
+
+	missedKeyChanList := make([]chan int, flags.writeWorkers)
+	for i := range missedKeyChanList {
+		missedKeyChanList[i] = make(chan int)
+	}
+
+	fmt.Println("----------------------------------------------prepopulating keys")
+	for k := 0; k < totalKeys; k++ {
+		if rand.Intn(100) < 30 {
+			continue
+		}
+		key := fmt.Sprintf("key%d", k)
+		val := []byte(fmt.Sprintf(str1kb, k))
+		start := time.Now()
+		if err := pc.Put(key, val, 60*time.Minute); err != nil {
+			log.Error().Err(err).Msgf("error putting key %s", key)
+		}
+		metrics.prepopulatePutMetrics.record(time.Since(start))
+		if k%5_000_000 == 0 {
+			fmt.Printf("----------------------------------------------prepopulated %d keys\n", k)
+		}
+	}
+
+	var wg, writeWg sync.WaitGroup
+
+	if flags.writeWorkers > 0 {
+		fmt.Println("----------------------------------------------starting write workers")
+		writeWg.Add(flags.writeWorkers)
+		for w := 0; w < flags.writeWorkers; w++ {
+			go func(workerID int) {
+				defer writeWg.Done()
+				for mk := range missedKeyChanList[workerID] {
+					key := fmt.Sprintf("key%d", mk)
+					val := []byte(fmt.Sprintf(str1kb, mk))
+					start := time.Now()
+					if err := pc.Put(key, val, 60*time.Minute); err != nil {
+						log.Error().Err(err).Msgf("error putting key %s", key)
+					}
+					metrics.putMetrics.record(time.Since(start))
+				}
+			}(w)
+		}
+	}
+
+	if flags.readWorkers > 0 {
+		fmt.Println("----------------------------------------------reading keys")
+		wg.Add(flags.readWorkers)
+		for r := 0; r < flags.readWorkers; r++ {
+			go func(workerID int) {
+				defer wg.Done()
+				for k := 0; k < totalKeys*multiplier; k++ {
+					randomval := normalDistIntPartitioned(workerID, flags.readWorkers, totalKeys)
+					key := fmt.Sprintf("key%d", randomval)
+					start := time.Now()
+					val, found, expired := pc.Get(key)
+					metrics.getMetrics.record(time.Since(start))
+
+					if !found {
+						metrics.getMisses.Add(1)
+						// With no write workers there is no receiver, and the
+						// modulo below would divide by zero.
+						if flags.writeWorkers > 0 {
+							missedKeyChanList[randomval%flags.writeWorkers] <- randomval
+						}
+					} else {
+						metrics.getHits.Add(1)
+					}
+					if expired {
+						metrics.getExpired.Add(1)
+						log.Error().Msgf("key %s expired", key)
+					}
+					if found && string(val) != fmt.Sprintf(str1kb, randomval) {
+						panic("value mismatch")
+					}
+					if k%50000 == 0 {
+						fmt.Printf("----------------------------------------------read %d keys %d readerid\n", k, workerID)
+					}
+				}
+			}(r)
+		}
+	}
+
+	wg.Wait()
+	// Readers are the only senders, so the miss channels can be closed now;
+	// this ends the writers' range loops. Wait for in-flight Puts to finish
+	// before the final stats are printed and the deferred pc.Close() runs.
+	for _, ch := range missedKeyChanList {
+		close(ch)
+	}
+	writeWg.Wait()
+	close(stopReporter)
+	metrics.printStats("FINAL")
+	log.Info().Msg("done")
+}
diff --git a/flashring/cmd/flashringtest/plan_readthrough_gausian_batched.go b/flashring/cmd/flashringtest/plan_readthrough_gausian_batched.go new file mode 100644 index 00000000..786eef36 --- /dev/null +++ 
b/flashring/cmd/flashringtest/plan_readthrough_gausian_batched.go
@@ -0,0 +1,153 @@
+package main
+
+import (
+	"flag"
+	"fmt"
+	"math/rand"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"time"
+
+	cachepkg "github.com/Meesho/BharatMLStack/flashring/pkg/cache"
+	"github.com/rs/zerolog/log"
+)
+
+// planReadthroughGaussianBatched runs a read-through load test against the
+// wrap cache with a key space of keysPerShard*numShards keys.
+//
+// It removes the entries directly under the mount point (best effort), builds
+// the cache from the registered flags and prepopulates the key space, skipping
+// each key with 30% probability. Reader goroutines then draw keys via
+// normalDistIntPartitioned; every miss is handed to a write worker (selected by
+// key modulo writeWorkers) that Puts the value back.
+//
+// It panics on setup errors, on any Put error, on an expired key and when a
+// cached value does not match its key.
+func planReadthroughGaussianBatched() {
+	var flags commonFlags
+	flags.register(flag.CommandLine, commonFlags{
+		mountPoint:         "/mnt/disks/nvme/",
+		numShards:          200,
+		keysPerShard:       10_00_00,
+		memtableMB:         16,
+		fileSizeMultiplier: 10,
+		readWorkers:        8,
+		writeWorkers:       8,
+		sampleSecs:         30,
+		iterations:         100_000_000,
+		logStats:           true,
+		memProfile:         "mem.prof",
+	})
+	flag.Parse()
+	teardown := setupProfiling(flags)
+	defer teardown()
+
+	files, err := os.ReadDir(flags.mountPoint)
+	if err != nil {
+		panic(err)
+	}
+	for _, file := range files {
+		os.Remove(filepath.Join(flags.mountPoint, file.Name()))
+	}
+
+	cfg := cachepkg.Config{
+		NumShards:             flags.numShards,
+		KeysPerShard:          flags.keysPerShard,
+		FileSize:              flags.fileSizeBytes(),
+		MemtableSize:          flags.memtableSizeBytes(),
+		ReWriteScoreThreshold: 0.8,
+		GridSearchEpsilon:     0.0001,
+		SampleDuration:        time.Duration(flags.sampleSecs) * time.Second,
+	}
+
+	pc, err := cachepkg.NewWrapCache(cfg, flags.mountPoint)
+	if err != nil {
+		panic(err)
+	}
+	defer pc.Close()
+
+	const multiplier = 300
+	totalKeys := flags.keysPerShard * flags.numShards
+	str1kb := "%d" + strings.Repeat("a", 1024)
+
+	missedKeyChanList := make([]chan int, flags.writeWorkers)
+	for i := range missedKeyChanList {
+		missedKeyChanList[i] = make(chan int)
+	}
+
+	fmt.Println("----------------------------------------------prepopulating keys")
+	for k := 0; k < totalKeys; k++ {
+		if rand.Intn(100) < 30 {
+			continue
+		}
+		key := fmt.Sprintf("key%d", k)
+		val := []byte(fmt.Sprintf(str1kb, k))
+		if err := pc.Put(key, val, 60*time.Minute); err != nil {
+			panic(err)
+		}
+		if k%5_000_000 == 0 {
+			fmt.Printf("----------------------------------------------prepopulated %d keys\n", k)
+		}
+	}
+
+	var wg, writeWg sync.WaitGroup
+
+	if flags.writeWorkers > 0 {
+		fmt.Println("----------------------------------------------starting write workers")
+		writeWg.Add(flags.writeWorkers)
+		for w := 0; w < flags.writeWorkers; w++ {
+			go func(workerID int) {
+				defer writeWg.Done()
+				for mk := range missedKeyChanList[workerID] {
+					key := fmt.Sprintf("key%d", mk)
+					val := []byte(fmt.Sprintf(str1kb, mk))
+					if err := pc.Put(key, val, 60*time.Minute); err != nil {
+						panic(err)
+					}
+				}
+			}(w)
+		}
+	}
+
+	if flags.readWorkers > 0 {
+		fmt.Println("----------------------------------------------reading keys")
+		wg.Add(flags.readWorkers)
+		for r := 0; r < flags.readWorkers; r++ {
+			go func(workerID int) {
+				defer wg.Done()
+				for k := 0; k < totalKeys*multiplier; k++ {
+					randomval := normalDistIntPartitioned(workerID, flags.readWorkers, totalKeys)
+					key := fmt.Sprintf("key%d", randomval)
+					val, found, expired := pc.Get(key)
+
+					// Forward misses to a writer. With no write workers there is no
+					// receiver, and the modulo below would divide by zero.
+					if !found && flags.writeWorkers > 0 {
+						missedKeyChanList[randomval%flags.writeWorkers] <- randomval
+					}
+					if expired {
+						panic("key expired")
+					}
+					if found && string(val) != fmt.Sprintf(str1kb, randomval) {
+						panic("value mismatch")
+					}
+					if k%5_000_000 == 0 {
+						fmt.Printf("----------------------------------------------read %d keys %d readerid\n", k, workerID)
+					}
+				}
+			}(r)
+		}
+	}
+
+	wg.Wait()
+	// Readers are the only senders, so the miss channels can be closed now;
+	// this ends the writers' range loops. Wait for in-flight Puts to finish
+	// before the deferred pc.Close() runs.
+	for _, ch := range missedKeyChanList {
+		close(ch)
+	}
+	writeWg.Wait()
+	log.Info().Msg("done")
+}
diff --git a/flashring/go.mod b/flashring/go.mod new file mode 100644 index 00000000..206adab3 --- /dev/null +++ b/flashring/go.mod @@ -0,0 +1,49 @@ +module github.com/Meesho/BharatMLStack/flashring + +go 1.24.0 + +toolchain go1.24.9 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 + github.com/coocood/freecache v1.2.4 + github.com/rs/zerolog v1.34.0 + github.com/zeebo/xxh3 v1.0.2 + golang.org/x/sys v0.38.0 +) + +require ( + github.com/Microsoft/go-winio v0.5.0 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + 
github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/text v0.28.0 // indirect +) + +require ( + github.com/DataDog/datadog-go/v5 v5.8.2 + github.com/dgraph-io/badger/v4 v4.9.0 + github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/flatbuffers v25.2.10+incompatible // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/spf13/viper v1.21.0 + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/net v0.43.0 // indirect + google.golang.org/protobuf v1.36.7 // indirect +) diff --git a/flashring/go.sum b/flashring/go.sum new file mode 100644 index 00000000..5d69f8d2 --- /dev/null +++ b/flashring/go.sum @@ -0,0 +1,144 @@ +github.com/DataDog/datadog-go/v5 v5.8.2 h1:9IEfH1Mw9AjWwhAMqCAkhbxjuJeMxm2ARX2VdgL+ols= +github.com/DataDog/datadog-go/v5 v5.8.2/go.mod h1:K9kcYBlxkcPP8tvvjZZKs/m1edNAUFzBbdpTUKfCsuw= +github.com/Microsoft/go-winio v0.5.0 h1:Elr9Wn+sGKPlkaBvwu4mTrxtmOp3F3yV9qhaHbXGjwU= +github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/coocood/freecache v1.2.4 h1:UdR6Yz/X1HW4fZOuH0Z94KwG851GWOSknua5VUbb/5M= +github.com/coocood/freecache v1.2.4/go.mod h1:RBUWa/Cy+OHdfTGFEhEuE1pMCMX51Ncizj7rthiQ3vk= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/badger/v4 v4.9.0 h1:tpqWb0NewSrCYqTvywbcXOhQdWcqephkVkbBmaaqHzc= +github.com/dgraph-io/badger/v4 v4.9.0/go.mod h1:5/MEx97uzdPUHR4KtkNt8asfI2T4JiEiQlV7kWUo8c0= +github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM= +github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod 
h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= +github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod 
h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod 
h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod 
h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys 
v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/flashring/internal/allocators/allocators.go b/flashring/internal/allocators/allocators.go
new file mode 100644
index 00000000..48d75b82
--- /dev/null
+++ b/flashring/internal/allocators/allocators.go
@@ -0,0 +1,14 @@
+package allocators
+
+// SizeClass describes one slab class: the page size in bytes and the capacity
+// its backing pool is created with.
+type SizeClass struct {
+	Size     int
+	MinCount int
+}
+
+// Meta is the descriptive metadata attached to each pool.
+type Meta struct {
+	Size int
+	Name string
+}
diff --git a/flashring/internal/allocators/slab_aligned_page_allocator.go b/flashring/internal/allocators/slab_aligned_page_allocator.go
new file mode 100644
index 00000000..4d1c1ae4
--- /dev/null
+++ b/flashring/internal/allocators/slab_aligned_page_allocator.go
@@ -0,0 +1,105 @@
+package allocators
+
+import (
+	"errors"
+	"fmt"
+	"sort"
+
+	"github.com/Meesho/BharatMLStack/flashring/internal/fs"
+	"github.com/Meesho/BharatMLStack/flashring/internal/pools"
+	"github.com/rs/zerolog/log"
+)
+
+var (
+	// ErrSizeNotAligned is returned when a size class is not a multiple of
+	// fs.BLOCK_SIZE.
+	ErrSizeNotAligned = errors.New("size not aligned")
+)
+
+// SlabAlignedPageAllocatorConfig lists the size classes to create pools for.
+type SlabAlignedPageAllocatorConfig struct {
+	SizeClasses []SizeClass
+}
+
+// SlabAlignedPageAllocator hands out fs.AlignedPage buffers from one pool per
+// size class. pools and sizes are parallel slices in ascending size order.
+type SlabAlignedPageAllocator struct {
+	config SlabAlignedPageAllocatorConfig
+	pools  []*pools.LeakyPool[*fs.AlignedPage]
+	sizes  []int
+}
+
+// NewSlabAlignedPageAllocator builds one LeakyPool per size class, ordered by
+// ascending Size.
+//
+// The caller's SizeClasses slice is copied before sorting, so it is never
+// reordered. Every size is validated before any pool is created; if one is not
+// a multiple of fs.BLOCK_SIZE the function returns ErrSizeNotAligned and a nil
+// allocator.
+func NewSlabAlignedPageAllocator(config SlabAlignedPageAllocatorConfig) (*SlabAlignedPageAllocator, error) {
+	// Sort a private copy: config is passed by value, but its slice still
+	// shares the caller's backing array.
+	classes := make([]SizeClass, len(config.SizeClasses))
+	copy(classes, config.SizeClasses)
+	sort.Slice(classes, func(i, j int) bool {
+		return classes[i].Size < classes[j].Size
+	})
+	config.SizeClasses = classes
+
+	// Validate up front so an invalid class cannot leave already-created
+	// pools behind.
+	for _, sc := range classes {
+		if sc.Size%fs.BLOCK_SIZE != 0 {
+			return nil, ErrSizeNotAligned
+		}
+	}
+
+	poolList := make([]*pools.LeakyPool[*fs.AlignedPage], len(classes))
+	sizes := make([]int, len(classes))
+
+	for i, sc := range classes {
+		sizes[i] = sc.Size
+		size := sc.Size // per-iteration copy captured by CreateFunc
+		poolList[i] = pools.NewLeakyPool(pools.LeakyPoolConfig[*fs.AlignedPage]{
+			Capacity:   sc.MinCount,
+			Meta:       Meta{Size: sc.Size, Name: fmt.Sprintf("SlabAlignedPagePool-%dBytes", sc.Size)},
+			CreateFunc: func() *fs.AlignedPage { return fs.NewAlignedPage(size) },
+		})
+		poolList[i].RegisterPreDrefHook(func(p *fs.AlignedPage) {
+			fs.Unmap(p)
+		})
+		log.Debug().Msgf("SlabAlignedPageAllocator: size class - %d | min count - %d", sc.Size, sc.MinCount)
+	}
+	return &SlabAlignedPageAllocator{config: config, pools: poolList, sizes: sizes}, nil
+}
+
+// Get returns a page from the smallest size class that can hold size bytes,
+// or nil when size exceeds every configured class.
+func (a *SlabAlignedPageAllocator) Get(size int) *fs.AlignedPage {
+	for i, s := range a.sizes {
+		if size <= s {
+			return a.pools[i].Get()
+		}
+	}
+	return nil
+}
+
+// Put returns p to the first pool, in ascending size order, whose class size
+// is at least len(p.Buf). A nil page, or a page larger than every size class,
+// is logged and not pooled.
+//
+// NOTE(review): a page whose length is not exactly a class size lands in the
+// next larger class; confirm whether an exact match should be required.
+func (a *SlabAlignedPageAllocator) Put(p *fs.AlignedPage) {
+	if p == nil {
+		log.Error().Msg("SlabAlignedPageAllocator: Put called with nil page")
+		return
+	}
+	for i, s := range a.sizes {
+		if len(p.Buf) <= s {
+			a.pools[i].Put(p)
+			return
+		}
+	}
+	log.Error().Msgf("SlabAlignedPageAllocator: Size class not found for size %d", len(p.Buf))
+}
diff --git a/flashring/internal/allocators/slab_aligned_page_allocator_test.go b/flashring/internal/allocators/slab_aligned_page_allocator_test.go new file mode 100644 index 00000000..55a187c7 --- /dev/null +++ b/flashring/internal/allocators/slab_aligned_page_allocator_test.go @@ -0,0 +1,693 @@ +package allocators + +import ( + "testing" + + "github.com/Meesho/BharatMLStack/flashring/internal/fs" +) + +func TestNewSlabAlignedPageAllocator(t *testing.T) { + t.Run("creates allocator with single aligned size class", func(t *testing.T) { + config := 
SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, // 4096 is aligned to fs.BLOCK_SIZE + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if allocator == nil { + t.Error("Expected allocator to be non-nil") + } + if allocator.config.SizeClasses[0].Size != config.SizeClasses[0].Size { + t.Errorf("Expected config to match, got %v", allocator.config) + } + if len(allocator.pools) != 1 { + t.Errorf("Expected 1 pool, got %d", len(allocator.pools)) + } + if allocator.pools[0].Meta.(Meta).Size != 4096 { + t.Errorf("Expected pool size 4096, got %d", allocator.pools[0].Meta.(Meta).Size) + } + if allocator.pools[0].Meta.(Meta).Name != "SlabAlignedPagePool-4096Bytes" { + t.Errorf("Expected pool name 'SlabAlignedPagePool-4096Bytes', got %s", allocator.pools[0].Meta.(Meta).Name) + } + }) + + t.Run("creates allocator with multiple aligned size classes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 8192, MinCount: 5}, // 8192 is aligned to fs.BLOCK_SIZE + {Size: 4096, MinCount: 10}, // 4096 is aligned to fs.BLOCK_SIZE + {Size: 16384, MinCount: 3}, // 16384 is aligned to fs.BLOCK_SIZE + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if allocator == nil { + t.Error("Expected allocator to be non-nil") + } + if len(allocator.pools) != 3 { + t.Errorf("Expected 3 pools, got %d", len(allocator.pools)) + } + + // Should be sorted by size + if allocator.pools[0].Meta.(Meta).Size != 4096 { + t.Errorf("Expected first pool size 4096, got %d", allocator.pools[0].Meta.(Meta).Size) + } + if allocator.pools[1].Meta.(Meta).Size != 8192 { + t.Errorf("Expected second pool size 8192, got %d", allocator.pools[1].Meta.(Meta).Size) + } + if allocator.pools[2].Meta.(Meta).Size != 16384 { + t.Errorf("Expected third pool size 16384, got %d", 
allocator.pools[2].Meta.(Meta).Size) + } + }) + + t.Run("creates allocator with empty size classes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{}, + } + allocator, err := NewSlabAlignedPageAllocator(config) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if allocator == nil { + t.Error("Expected allocator to be non-nil") + } + if len(allocator.pools) != 0 { + t.Errorf("Expected 0 pools, got %d", len(allocator.pools)) + } + }) + + t.Run("returns error for non-aligned size class", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4097, MinCount: 10}, // 4097 is not aligned to fs.BLOCK_SIZE (4096) + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + + if err != ErrSizeNotAligned { + t.Errorf("Expected ErrSizeNotAligned, got %v", err) + } + if allocator != nil { + t.Error("Expected allocator to be nil on error") + } + }) + + t.Run("returns error for mixed aligned and non-aligned size classes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, // aligned + {Size: 3000, MinCount: 5}, // not aligned + {Size: 8192, MinCount: 3}, // aligned + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + + if err != ErrSizeNotAligned { + t.Errorf("Expected ErrSizeNotAligned, got %v", err) + } + if allocator != nil { + t.Error("Expected allocator to be nil on error") + } + }) +} + +func TestSlabAlignedPageAllocator_Get(t *testing.T) { + t.Run("returns aligned page for exact size match", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(4096) + if page == nil { + t.Error("Expected page to be non-nil") + } + if len(page.Buf) != 4096 { + 
t.Errorf("Expected page buffer length 4096, got %d", len(page.Buf)) + } + if cap(page.Buf) != 4096 { + t.Errorf("Expected page buffer capacity 4096, got %d", cap(page.Buf)) + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) + + t.Run("returns aligned page for smaller size", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(2048) + if page == nil { + t.Error("Expected page to be non-nil") + } + if len(page.Buf) != 4096 { + t.Errorf("Expected page buffer length 4096, got %d", len(page.Buf)) + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) + + t.Run("returns smallest suitable size class", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 5}, + {Size: 8192, MinCount: 10}, + {Size: 16384, MinCount: 3}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(6000) + if page == nil { + t.Error("Expected page to be non-nil") + } + if len(page.Buf) != 8192 { + t.Errorf("Expected page buffer length 8192, got %d", len(page.Buf)) + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) + + t.Run("returns nil for size larger than all size classes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(8192) + if page != nil { + t.Error("Expected page to be nil for size larger than all size classes") + } + }) + + t.Run("returns nil for empty size classes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + 
SizeClasses: []SizeClass{}, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(4096) + if page != nil { + t.Error("Expected page to be nil for empty size classes") + } + }) + + t.Run("returns page for zero size request", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(0) + if page == nil { + t.Error("Expected page to be non-nil for zero size request") + } + if len(page.Buf) != 4096 { + t.Errorf("Expected page buffer length 4096, got %d", len(page.Buf)) + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) + + t.Run("returns page for negative size request", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(-1) + if page == nil { + t.Error("Expected page to be non-nil for negative size request") + } + if len(page.Buf) != 4096 { + t.Errorf("Expected page buffer length 4096, got %d", len(page.Buf)) + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) +} + +func TestSlabAlignedPageAllocator_Put(t *testing.T) { + t.Run("puts aligned page back to correct pool", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(4096) + if page == nil { + t.Fatal("Expected page to be non-nil") + } + + // Put should not panic + allocator.Put(page) + }) + + t.Run("puts page to smallest 
suitable pool", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 5}, + {Size: 8192, MinCount: 10}, + {Size: 16384, MinCount: 3}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Create a page manually (not from pool) + page := fs.NewAlignedPage(6000) + if page == nil { + t.Fatal("Failed to create aligned page") + } + + // Should not panic, even though page wasn't from the pool + allocator.Put(page) + }) + + t.Run("handles page larger than all size classes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Create a large page + page := fs.NewAlignedPage(8192) + if page == nil { + t.Fatal("Failed to create aligned page") + } + + // Should not panic, but will log error + allocator.Put(page) + + // Clean up manually since it won't be put back in pool + fs.Unmap(page) + }) + + t.Run("handles nil page", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Should not panic, but may cause issues due to nil pointer + // This test mainly ensures the method doesn't crash completely + defer func() { + if r := recover(); r != nil { + // It's expected that this might panic due to nil pointer access + t.Logf("Expected panic occurred: %v", r) + } + }() + + allocator.Put(nil) + }) +} + +func TestSlabAlignedPageAllocator_GetAndPut_Integration(t *testing.T) { + t.Run("get and put multiple times", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 
2}, + {Size: 8192, MinCount: 3}, + {Size: 16384, MinCount: 1}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Get multiple pages + pages := make([]*fs.AlignedPage, 5) + for i := 0; i < 5; i++ { + pages[i] = allocator.Get(3000) // Should get 4096 size + if pages[i] == nil { + t.Errorf("Expected page %d to be non-nil", i) + } + if len(pages[i].Buf) != 4096 { + t.Errorf("Expected page %d buffer length 4096, got %d", i, len(pages[i].Buf)) + } + } + + // Put them back + for _, page := range pages { + if page != nil { + allocator.Put(page) + } + } + + // Get them again + for i := 0; i < 5; i++ { + page := allocator.Get(3000) + if page == nil { + t.Errorf("Expected page %d to be non-nil on second get", i) + } + if page != nil && len(page.Buf) != 4096 { + t.Errorf("Expected page %d buffer length 4096 on second get, got %d", i, len(page.Buf)) + } + } + }) + + t.Run("get and put with different sizes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 2}, + {Size: 8192, MinCount: 3}, + {Size: 16384, MinCount: 1}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Get pages of different sizes + page4k := allocator.Get(3000) // Should get 4096 + page8k := allocator.Get(6000) // Should get 8192 + page16k := allocator.Get(12000) // Should get 16384 + + if len(page4k.Buf) != 4096 { + t.Errorf("Expected page4k buffer length 4096, got %d", len(page4k.Buf)) + } + if len(page8k.Buf) != 8192 { + t.Errorf("Expected page8k buffer length 8192, got %d", len(page8k.Buf)) + } + if len(page16k.Buf) != 16384 { + t.Errorf("Expected page16k buffer length 16384, got %d", len(page16k.Buf)) + } + + // Put them back + allocator.Put(page4k) + allocator.Put(page8k) + allocator.Put(page16k) + + // Get them again + newPage4k := allocator.Get(3000) + newPage8k 
:= allocator.Get(6000) + newPage16k := allocator.Get(12000) + + if len(newPage4k.Buf) != 4096 { + t.Errorf("Expected newPage4k buffer length 4096, got %d", len(newPage4k.Buf)) + } + if len(newPage8k.Buf) != 8192 { + t.Errorf("Expected newPage8k buffer length 8192, got %d", len(newPage8k.Buf)) + } + if len(newPage16k.Buf) != 16384 { + t.Errorf("Expected newPage16k buffer length 16384, got %d", len(newPage16k.Buf)) + } + }) +} + +func TestSlabAlignedPageAllocator_SizeClassSorting(t *testing.T) { + t.Run("size classes are sorted correctly", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 16384, MinCount: 3}, + {Size: 4096, MinCount: 10}, + {Size: 8192, MinCount: 5}, + {Size: 12288, MinCount: 2}, // 12288 = 3 * 4096, aligned + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Verify pools are sorted by size + if allocator.pools[0].Meta.(Meta).Size != 4096 { + t.Errorf("Expected first pool size 4096, got %d", allocator.pools[0].Meta.(Meta).Size) + } + if allocator.pools[1].Meta.(Meta).Size != 8192 { + t.Errorf("Expected second pool size 8192, got %d", allocator.pools[1].Meta.(Meta).Size) + } + if allocator.pools[2].Meta.(Meta).Size != 12288 { + t.Errorf("Expected third pool size 12288, got %d", allocator.pools[2].Meta.(Meta).Size) + } + if allocator.pools[3].Meta.(Meta).Size != 16384 { + t.Errorf("Expected fourth pool size 16384, got %d", allocator.pools[3].Meta.(Meta).Size) + } + + // Test that Get returns from the correct pool + page := allocator.Get(10000) + if page == nil { + t.Error("Expected page to be non-nil") + } + if len(page.Buf) != 12288 { + t.Errorf("Expected page buffer length 12288 (should use 12288 pool), got %d", len(page.Buf)) + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) +} + +func TestSlabAlignedPageAllocator_EdgeCases(t *testing.T) { + t.Run("single size class with exact match", func(t 
*testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 1}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(4096) + if page == nil { + t.Error("Expected page to be non-nil") + } + if len(page.Buf) != 4096 { + t.Errorf("Expected page buffer length 4096, got %d", len(page.Buf)) + } + + allocator.Put(page) + + // Get again after putting back + page2 := allocator.Get(4096) + if page2 == nil { + t.Error("Expected page2 to be non-nil") + } + if len(page2.Buf) != 4096 { + t.Errorf("Expected page2 buffer length 4096, got %d", len(page2.Buf)) + } + }) + + t.Run("duplicate size classes", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 5}, + {Size: 4096, MinCount: 10}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if len(allocator.pools) != 2 { + t.Errorf("Expected 2 pools, got %d", len(allocator.pools)) + } + + page := allocator.Get(4096) + if page == nil { + t.Error("Expected page to be non-nil") + } + if len(page.Buf) != 4096 { + t.Errorf("Expected page buffer length 4096, got %d", len(page.Buf)) + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) +} + +func TestSlabAlignedPageAllocator_MemoryAlignment(t *testing.T) { + t.Run("pages are properly aligned", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 1}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + page := allocator.Get(4096) + if page == nil { + t.Error("Expected page to be non-nil") + } + + // Test that we can write to the page without issues + if len(page.Buf) > 0 { + page.Buf[0] = 0x42 + page.Buf[len(page.Buf)-1] = 0x24 + + if 
page.Buf[0] != 0x42 { + t.Error("Failed to write to first byte of page") + } + if page.Buf[len(page.Buf)-1] != 0x24 { + t.Error("Failed to write to last byte of page") + } + } + + // Clean up + if page != nil { + fs.Unmap(page) + } + }) +} + +func TestSlabAlignedPageAllocator_PreDrefHook(t *testing.T) { + t.Run("pre deref hook is registered", func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: 4096, MinCount: 1}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // The PreDrefHook should be registered during construction + // We can't directly test the hook execution without accessing private fields + // But we can verify that pool creation succeeded + if len(allocator.pools) != 1 { + t.Errorf("Expected 1 pool to be created, got %d", len(allocator.pools)) + } + + // Test normal allocation and deallocation + page := allocator.Get(4096) + if page == nil { + t.Error("Expected page to be non-nil") + } + + // Put back should trigger the hook internally when pool is full + allocator.Put(page) + }) +} + +func TestSlabAlignedPageAllocator_AlignmentValidation(t *testing.T) { + t.Run("various alignment checks", func(t *testing.T) { + tests := []struct { + name string + size int + shouldError bool + }{ + {"aligned 4096", 4096, false}, + {"aligned 8192", 8192, false}, + {"aligned 12288", 12288, false}, + {"aligned 16384", 16384, false}, + {"unaligned 4097", 4097, true}, + {"unaligned 4000", 4000, true}, + {"unaligned 5000", 5000, true}, + {"unaligned 1024", 1024, true}, // 1024 < 4096 + {"unaligned 2048", 2048, true}, // 2048 < 4096 + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := SlabAlignedPageAllocatorConfig{ + SizeClasses: []SizeClass{ + {Size: tt.size, MinCount: 1}, + }, + } + allocator, err := NewSlabAlignedPageAllocator(config) + + if tt.shouldError { + if err != ErrSizeNotAligned { + 
t.Errorf("Expected ErrSizeNotAligned for size %d, got %v", tt.size, err) + } + if allocator != nil { + t.Errorf("Expected nil allocator for size %d", tt.size) + } + } else { + if err != nil { + t.Errorf("Expected no error for size %d, got %v", tt.size, err) + } + if allocator == nil { + t.Errorf("Expected non-nil allocator for size %d", tt.size) + } + } + }) + } + }) +} diff --git a/flashring/internal/fs/README.md b/flashring/internal/fs/README.md new file mode 100644 index 00000000..dac08884 --- /dev/null +++ b/flashring/internal/fs/README.md @@ -0,0 +1,144 @@ +# Memtable Performance Benchmark (DirectIO + Go) + +This benchmark evaluates a single-threaded, append-only, `O_DIRECT`-backed memtable implementation in Go. The design mimics ScyllaDB’s core-local memtables and flush logic, emphasizing high throughput and stable latencies. + +## 🔧 Benchmark Configuration + +- **CPU**: AMD Ryzen 7 9800X3D +- **Memtable Write Size**: 16KB per record +- **Concurrency**: Single-threaded (8 goroutines pipelined into one locked OS thread) +- **Flush Trigger**: Memtable capacity exceeded +- **IO Mode**: DirectIO (`O_DIRECT`), Append-only +- **Benchmark Tool**: `go test -bench` + +--- + +## 📊 Performance Overview (NO_DSYNC vs DSYNC) + +| Capacity | RPS (NO_DSYNC) | Latency (ns/op) | RPS (DSYNC) | Latency (ns/op) | +|---------:|---------------:|----------------:|------------:|----------------:| +| 64KB | 785 | 1,273,903 | 482 | 2,073,246 | +| 128KB | 1,568 | 637,656 | 970 | 1,030,739 | +| 256KB | 3,214 | 311,103 | 1,934 | 517,106 | +| 512KB | 6,499 | 153,871 | 3,930 | 254,432 | +| 1MB | 12,769 | 78,317 | 7,659 | 130,561 | +| 2MB | 25,013 | 39,979 | 15,186 | 65,849 | +| 4MB | 46,907 | 21,319 | 24,932 | 40,110 | +| 8MB | 84,494 | 11,835 | 41,206 | 24,268 | +| 16MB | 138,896 | 7,200 | 50,840 | 19,670 | +| 32MB | 170,877 | 5,852 | 66,387 | 15,063 | +| 64MB | 213,214 | 4,690 | 73,646 | 13,579 | +| 128MB | 250,319 | 3,995 | 76,413 | 13,087 | +| 256MB | 88,229 | 11,334 | 76,672 | 13,043 | 
+| 512MB | 81,517 | 12,267 | 77,174 | 12,958 | +| 1GB | 83,717 | 11,945 | 82,203 | 12,165 | + +--- + +## 📉 Throughput vs Latency (Log Scale) + +![Throughput vs Latency](./profile.png) + +> Left axis: Throughput in MB/s (log scale) +> Right axis: Latency in ns/op (log scale) +> X-axis: Memtable size (KB, log scale) + +--- + +## 🔁 Flush Frequency Trend + +- Smaller memtables trigger frequent flushes, degrading both throughput and latency. +- Flush frequency stabilizes beyond **8–16MB**, where throughput growth starts to plateau. + +--- + +## 🔒 `runtime.LockOSThread()` Impact + +To ensure predictable syscall behavior with `O_DIRECT` (DirectIO) and aligned memory buffers, we benchmarked with and without `runtime.LockOSThread()`. + +| Capacity | RPS (No Lock) | Latency (ns/op) | RPS (LockOSThread) | Latency (ns/op) | +|---------:|--------------:|----------------:|--------------------:|----------------:| +| 128MB | ~220,000 | ~5,500 | **250,319** | **3,995** | +| 256MB | ~85,000 | ~11,000 | **88,229** | **11,334** | +| 1GB | ~81,000 | ~12,000 | **83,717** | **11,945** | + +✅ **Locking OS threads**: +- Reduces context-switching overhead +- Ensures aligned buffers remain valid (important for `O_DIRECT`) +- Prevents `EINVAL` during write() syscalls +- Better latency consistency + +--- + +## 🧠 Final Conclusions + +- **Memtable Size Matters**: NO_DSYNC throughput scales roughly linearly with memtable size up to ~8MB and keeps improving, with diminishing returns, up to 128MB (peak ~250K RPS). At 256MB and above it drops to ~80–90K RPS and stays flat. +- **DSYNC vs NO_DSYNC**: DSYNC incurs roughly 1.6–2x higher latency up to 8MB, and the gap widens to ~3x at 64–128MB before the two modes converge at 256MB+. Use DSYNC if durability is essential. +- **DirectIO Requirements**: `runtime.LockOSThread()` is highly recommended for DMA-safe writes, especially in single-threaded core-local memtable designs. +- **Flush Design**: Scylla-like batching improves throughput. Flushes can be run on the same core if they yield properly between IO calls. 
+ +--- + +## Raw Stats + +```bash +Running tool: /usr/local/go/bin/go test -benchmem -run=^$ -bench ^BenchmarkMemtable_Write16KBWorkload$ github.com/Meesho/BharatMLStack/ssd-cache/internal/memtable + +goos: linux +goarch: amd64 +pkg: github.com/Meesho/BharatMLStack/ssd-cache/internal/memtable +cpu: AMD Ryzen 7 9800X3D 8-Core Processor +BenchmarkMemtable_Write16KBWorkload/64KB-NO-DSYNC-8 950 1273903 ns/op 15532032 file_size 237.0 flushes 195.8 flushes/sec 785.0 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/128KB-NO-DSYNC-8 2079 637656 ns/op 33947648 file_size 259.0 flushes 195.4 flushes/sec 1568 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/256KB-NO-DSYNC-8 4028 311103 ns/op 65798144 file_size 251.0 flushes 200.3 flushes/sec 3214 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/512KB-NO-DSYNC-8 8194 153871 ns/op 134217728 file_size 256.0 flushes 203.0 flushes/sec 6499 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/1024KB-NO-DSYNC-8 15468 78317 ns/op 252706816 file_size 241.0 flushes 198.9 flushes/sec 12769 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/2048KB-NO-DSYNC-8 30043 39979 ns/op 490733568 file_size 234.0 flushes 194.8 flushes/sec 25013 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/4096KB-NO-DSYNC-8 56930 21319 ns/op 931135488 file_size 222.0 flushes 182.9 flushes/sec 46907 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/8192KB-NO-DSYNC-8 103630 11835 ns/op 1694498816 file_size 202.0 flushes 164.7 flushes/sec 84494 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/16384KB-NO-DSYNC-8 175530 7200 ns/op 2868903936 file_size 171.0 flushes 135.3 flushes/sec 138896 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/32768KB-NO-DSYNC-8 271888 5852 ns/op 4429185024 file_size 132.0 flushes 82.96 flushes/sec 170877 records/sec 0 B/op 0 allocs/op 
+BenchmarkMemtable_Write16KBWorkload/65536KB-NO-DSYNC-8 235149 4690 ns/op 3825205248 file_size 57.00 flushes 51.68 flushes/sec 213214 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/131072KB-NO-DSYNC-8 304314 3995 ns/op 4966055936 file_size 37.00 flushes 30.43 flushes/sec 250319 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/262144KB-NO-DSYNC-8 542956 11334 ns/op 8858370048 file_size 33.00 flushes 5.362 flushes/sec 88229 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/524288KB-NO-DSYNC-8 540237 12267 ns/op 8589934592 file_size 16.00 flushes 2.414 flushes/sec 81517 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/1048576KB-NO-DSYNC-8 555834 11945 ns/op 8589934592 file_size 8.000 flushes 1.205 flushes/sec 83717 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/64KB-DSYNC-8 591 2073246 ns/op 9633792 file_size 147.0 flushes 120.0 flushes/sec 482.3 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/128KB-DSYNC-8 1215 1030739 ns/op 19791872 file_size 151.0 flushes 120.6 flushes/sec 970.2 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/256KB-DSYNC-8 2455 517106 ns/op 40108032 file_size 153.0 flushes 120.5 flushes/sec 1934 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/512KB-DSYNC-8 5034 254432 ns/op 82313216 file_size 157.0 flushes 122.6 flushes/sec 3930 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/1024KB-DSYNC-8 10000 130561 ns/op 163577856 file_size 156.0 flushes 119.5 flushes/sec 7659 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/2048KB-DSYNC-8 18921 65849 ns/op 308281344 file_size 147.0 flushes 118.0 flushes/sec 15186 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/4096KB-DSYNC-8 30013 40110 ns/op 490733568 file_size 117.0 flushes 97.19 flushes/sec 24932 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/8192KB-DSYNC-8 49298 24268 
ns/op 805306368 file_size 96.00 flushes 80.24 flushes/sec 41206 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/16384KB-DSYNC-8 66595 19670 ns/op 1090519040 file_size 65.00 flushes 49.62 flushes/sec 50840 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/32768KB-DSYNC-8 91797 15063 ns/op 1476395008 file_size 44.00 flushes 31.82 flushes/sec 66387 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/65536KB-DSYNC-8 97675 13579 ns/op 1543503872 file_size 23.00 flushes 17.34 flushes/sec 73646 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/131072KB-DSYNC-8 92379 13087 ns/op 1476395008 file_size 11.00 flushes 9.099 flushes/sec 76413 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/262144KB-DSYNC-8 561945 13043 ns/op 9126805504 file_size 34.00 flushes 4.639 flushes/sec 76672 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/524288KB-DSYNC-8 562118 12958 ns/op 9126805504 file_size 17.00 flushes 2.334 flushes/sec 77174 records/sec 0 B/op 0 allocs/op +BenchmarkMemtable_Write16KBWorkload/1048576KB-DSYNC-8 559707 12165 ns/op 8589934592 file_size 8.000 flushes 1.175 flushes/sec 82203 records/sec 0 B/op 0 allocs/op +PASS +ok github.com/Meesho/BharatMLStack/ssd-cache/internal/memtable 78.589s +``` + +## 🧪 Design Inspiration + +This experiment was inspired by **ScyllaDB’s core-local architecture**: +- Per-core memtables +- Flush triggered by memory thresholds +- IO parallelism via sharded threads + +This design brings similar performance characteristics to a Go-based system using low-level syscalls and memory alignment. 
+ +--- + +## 📂 Future Work + +- Add WAL benchmarking +- Integrate `io_uring` for flush batching +- Explore compression + zero-copy read path + +--- + +Made with ❤️ by [BharatMLStack](https://github.com/Meesho/BharatMLStack) diff --git a/flashring/internal/fs/aligned_page.go b/flashring/internal/fs/aligned_page.go new file mode 100644 index 00000000..099ccd9d --- /dev/null +++ b/flashring/internal/fs/aligned_page.go @@ -0,0 +1,52 @@ +//go:build linux +// +build linux + +package fs + +import ( + "golang.org/x/sys/unix" +) + +const ( + PROT_READ = unix.PROT_READ + PROT_WRITE = unix.PROT_WRITE + MAP_PRIVATE = unix.MAP_PRIVATE + MAP_ANON = unix.MAP_ANON +) + +// var mmapProf = pprof.NewProfile("mmap") // will show up in /debug/pprof/ + +type AlignedPage struct { + Buf []byte + mmap []byte +} + +func NewAlignedPage(pageSize int) *AlignedPage { + b, err := unix.Mmap(-1, 0, pageSize, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANON) + if err != nil { + panic(err) + } + // if pageSize > 0 { + // mmapProf.Add(&b[0], pageSize) // attribute sz bytes to this callsite + // } + return &AlignedPage{ + Buf: b, + mmap: b, + } +} + +func Unmap(p *AlignedPage) error { + // if len(p.mmap) > 0 { + // mmapProf.Remove(&p.mmap[0]) // release from custom profile + // } + if p.mmap != nil { + err := unix.Munmap(p.mmap) + if err != nil { + return err + } + p.mmap = nil + } + p.Buf = nil + p.mmap = nil + return nil +} diff --git a/flashring/internal/fs/file_bench_test.go b/flashring/internal/fs/file_bench_test.go new file mode 100644 index 00000000..2d3da83a --- /dev/null +++ b/flashring/internal/fs/file_bench_test.go @@ -0,0 +1,161 @@ +package fs + +import ( + "path/filepath" + "testing" +) + +func BenchmarkPwrite(b *testing.B) { + tmpDir := b.TempDir() + filename := filepath.Join(tmpDir, "bench_rolling_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024 * 1024, // 1GB + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + raf, err := 
NewRollingAppendFile(config) + if err != nil { + b.Fatalf("Failed to create RollingAppendFile: %v", err) + } + defer cleanup(raf) + + // Create aligned buffer for DirectIO + data := createAlignedBuffer(4096, 4096) + for i := 0; i < 4096; i++ { + data[i] = byte(i % 256) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, err := raf.Pwrite(data) + if err != nil { + b.Fatalf("Pwrite failed: %v", err) + } + } +} + +func BenchmarkPread(b *testing.B) { + tmpDir := b.TempDir() + filename := filepath.Join(tmpDir, "bench_rolling_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024 * 1024, // 1GB + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + raf, err := NewRollingAppendFile(config) + if err != nil { + b.Fatalf("Failed to create RollingAppendFile: %v", err) + } + defer cleanup(raf) + + // Pre-populate with data using aligned buffer + writeData := createAlignedBuffer(4096, 4096) + for i := 0; i < 4096; i++ { + writeData[i] = byte(i % 256) + } + + for i := 0; i < 200000; i++ { + _, err := raf.Pwrite(writeData) + if err != nil { + b.Fatalf("Pwrite failed: %v", err) + } + } + + readData := createAlignedBuffer(4096, 4096) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + offset := int64((i % 200000) * 4096) + _, err := raf.Pread(offset, readData) + if err != nil { + b.Fatalf("Pread failed: %v", err) + } + } +} + +// Benchmarks +func BenchmarkWrapAppendFile_Pwrite(b *testing.B) { + tmpDir := b.TempDir() + filename := filepath.Join(tmpDir, "bench_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024 * 1024, // 1GB + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + b.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Create aligned buffer for DirectIO + data := createAlignedBuffer(4096, 4096) + for i := 0; i < 4096; i++ { + data[i] = byte(i % 256) + } + 
+ b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, err := waf.Pwrite(data) + if err != nil { + b.Fatalf("Pwrite failed: %v", err) + } + } +} + +func BenchmarkWrapAppendFile_Pread(b *testing.B) { + tmpDir := b.TempDir() + filename := filepath.Join(tmpDir, "bench_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024 * 1024, // 1GB + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + b.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Pre-populate with data using aligned buffer + writeData := createAlignedBuffer(4096, 4096) + for i := 0; i < 4096; i++ { + writeData[i] = byte(i % 256) + } + + for i := 0; i < 200000; i++ { + _, err := waf.Pwrite(writeData) + if err != nil { + b.Fatalf("Pwrite failed: %v", err) + } + } + + readData := createAlignedBuffer(4096, 4096) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + offset := int64((i % 200000) * 4096) + _, err := waf.Pread(offset, readData) + if err != nil { + b.Fatalf("Pread failed: %v", err) + } + } +} diff --git a/flashring/internal/fs/fs.go b/flashring/internal/fs/fs.go new file mode 100644 index 00000000..57a785bc --- /dev/null +++ b/flashring/internal/fs/fs.go @@ -0,0 +1,152 @@ +//go:build linux +// +build linux + +package fs + +import ( + "errors" + "fmt" + "os" + "sync/atomic" + "syscall" + "unsafe" + + "github.com/rs/zerolog/log" + "golang.org/x/sys/unix" +) + +const ( + O_DIRECT = 0x4000 + O_WRONLY = syscall.O_WRONLY + O_RDONLY = syscall.O_RDONLY + O_APPEND = syscall.O_APPEND + O_CREAT = syscall.O_CREAT + O_DSYNC = syscall.O_DSYNC + FALLOC_FL_PUNCH_HOLE = unix.FALLOC_FL_PUNCH_HOLE + FALLOC_FL_KEEP_SIZE = unix.FALLOC_FL_KEEP_SIZE + FILE_MODE = 0644 + BLOCK_SIZE = 4096 +) + +var ( + ErrBufNoAlign = errors.New("buffer is not aligned to block size") + ErrFileSizeExceeded = errors.New("file size exceeded. 
Please punch hole") + ErrFileOffsetOutOfRange = errors.New("file offset is out of range") + ErrOffsetNotAligned = errors.New("offset is not aligned to block size") + ErrReadTimeout = errors.New("read timeout") +) + +type Stat struct { + WriteCount atomic.Int64 + ReadCount atomic.Int64 + PunchHoleCount atomic.Int64 + CurrentLogicalSize int64 +} + +type FileConfig struct { + Filename string + MaxFileSize int64 + FilePunchHoleSize int64 + BlockSize int +} + +type File interface { + Pwrite(buf []byte) (currentPhysicalOffset int64, err error) + Pread(fileOffset int64, buf []byte) (n int32, err error) + TrimHead() (err error) + Close() +} + +type Page interface { + Unmap() error +} + +// openWithDirectIO attempts to open a file with O_DIRECT, falling back to +// regular flags if the filesystem doesn't support it. +func openWithDirectIO(filename string, baseFlags int) (int, bool, error) { + fd, err := syscall.Open(filename, baseFlags|O_DIRECT, FILE_MODE) + if err == nil { + return fd, true, nil + } + log.Warn().Msgf("DIRECT_IO not supported, falling back to regular flags: %v", err) + fd, err = syscall.Open(filename, baseFlags, FILE_MODE) + if err != nil { + return 0, false, err + } + return fd, false, nil +} + +func fdToFile(fd int, filename string) (*os.File, error) { + file := os.NewFile(uintptr(fd), filename) + if file == nil { + return nil, fmt.Errorf("failed to create file from fd") + } + return file, nil +} + +func createAppendOnlyWriteFileDescriptor(filename string) (int, *os.File, bool, error) { + fd, directIO, err := openWithDirectIO(filename, O_WRONLY|O_CREAT|O_DSYNC) + if err != nil { + return 0, nil, false, err + } + file, err := fdToFile(fd, filename) + if err != nil { + return 0, nil, false, err + } + return fd, file, directIO, nil +} + +func createPreAllocatedWriteFileDescriptor(filename string, maxFileSize int64) (int, *os.File, bool, error) { + fd, directIO, err := openWithDirectIO(filename, O_WRONLY|O_CREAT|O_DSYNC) + if err != nil { + return 0, nil, 
false, err + } + + if err = unix.Fallocate(fd, 0, 0, maxFileSize); err != nil { + log.Error().Err(err).Msg("Failed to fallocate file") + syscall.Close(fd) + return 0, nil, false, err + } + + file, err := fdToFile(fd, filename) + if err != nil { + return 0, nil, false, err + } + return fd, file, directIO, nil +} + +func createReadFileDescriptor(filename string) (int, *os.File, bool, error) { + flags := O_DIRECT | O_RDONLY + fd, err := syscall.Open(filename, flags, 0) + if err != nil { + return 0, nil, false, err + } + file, err := fdToFile(fd, filename) + if err != nil { + return 0, nil, false, err + } + return fd, file, true, nil +} + +func isAlignedBuffer(buf []byte, alignment int) bool { + pt := uintptr(alignment) + if len(buf) == 0 { + return false + } + addr := uintptr(unsafe.Pointer(&buf[0])) + return addr%pt == 0 +} + +func isAlignedOffset(offset int64, alignment int) bool { + return offset%int64(alignment) == 0 +} + +// AlignRange computes the block-aligned start offset and total aligned size +// for a read spanning [offset, offset+length). Useful for O_DIRECT reads +// where both offset and buffer size must be block-aligned. 
// AlignRange computes the block-aligned start offset and total aligned size
// for a read spanning [offset, offset+length). Useful for O_DIRECT reads
// where both offset and buffer size must be block-aligned.
//
// alignedStart is offset rounded down to a multiple of blockSize, and
// alignedSize is the distance from alignedStart to offset+length rounded up to
// a multiple of blockSize. A non-positive blockSize means "no alignment": the
// range is returned unchanged instead of dividing by zero.
func AlignRange(offset int64, length int, blockSize int64) (alignedStart, alignedSize int64) {
	if blockSize <= 0 {
		return offset, int64(length)
	}
	alignedStart = (offset / blockSize) * blockSize
	end := offset + int64(length)
	alignedEnd := ((end + blockSize - 1) / blockSize) * blockSize
	return alignedStart, alignedEnd - alignedStart
}
ReadDirectIO: rDirectIO, + blockSize: blockSize, + WriteFd: writeFd, + ReadFd: readFd, + WriteFile: writeFile, + ReadFile: readFile, + MaxFileSize: maxFileSize, + FilePunchHoleSize: filePunchHoleSize, + LogicalStartOffset: 0, + CurrentLogicalOffset: 0, + CurrentPhysicalOffset: 0, + Stat: &Stat{}, + }, nil +} + +func (r *RollingAppendFile) Pwrite(buf []byte) (currentPhysicalOffset int64, err error) { + if r.CurrentLogicalOffset+int64(len(buf)) > r.MaxFileSize { + return 0, ErrFileSizeExceeded + } + if r.WriteDirectIO { + if !isAlignedBuffer(buf, r.blockSize) { + return 0, ErrBufNoAlign + } + } + n, err := syscall.Pwrite(r.WriteFd, buf, r.CurrentPhysicalOffset) + if err != nil { + return 0, err + } + r.CurrentPhysicalOffset += int64(n) + r.Stat.WriteCount.Add(1) + return r.CurrentPhysicalOffset, nil +} + +func (r *RollingAppendFile) Pread(fileOffset int64, buf []byte) (n int32, err error) { + if fileOffset < r.LogicalStartOffset || fileOffset+int64(len(buf)) > r.CurrentPhysicalOffset { + return 0, ErrFileOffsetOutOfRange + } + if r.ReadDirectIO { + if !isAlignedOffset(fileOffset, r.blockSize) { + return 0, ErrOffsetNotAligned + } + if !isAlignedBuffer(buf, r.blockSize) { + return 0, ErrBufNoAlign + } + } + syscall.Pread(r.ReadFd, buf, fileOffset) + r.Stat.ReadCount.Add(1) + return int32(len(buf)), nil +} + +func (r *RollingAppendFile) TrimHead() (err error) { + if r.WriteDirectIO { + if !isAlignedOffset(r.LogicalStartOffset, r.blockSize) { + return ErrOffsetNotAligned + } + } + err = unix.Fallocate(r.WriteFd, FALLOC_FL_PUNCH_HOLE|FALLOC_FL_KEEP_SIZE, r.LogicalStartOffset, int64(r.FilePunchHoleSize)) + if err != nil { + return err + } + r.LogicalStartOffset += int64(r.FilePunchHoleSize) + r.CurrentLogicalOffset -= int64(r.FilePunchHoleSize) + r.Stat.PunchHoleCount.Add(1) + return nil +} + +func (r *RollingAppendFile) Close() { + syscall.Close(r.WriteFd) + syscall.Close(r.ReadFd) + os.Remove(r.WriteFile.Name()) + os.Remove(r.ReadFile.Name()) +} diff --git 
a/flashring/internal/fs/rolling_appendonly_file_test.go b/flashring/internal/fs/rolling_appendonly_file_test.go new file mode 100644 index 00000000..ce858030 --- /dev/null +++ b/flashring/internal/fs/rolling_appendonly_file_test.go @@ -0,0 +1,502 @@ +//go:build linux +// +build linux + +package fs + +import ( + "os" + "path/filepath" + "testing" + "unsafe" +) + +// Helper function to create aligned buffers for DirectIO +func createAlignedBuffer(size, alignment int) []byte { + // Allocate more memory than needed to ensure we can find an aligned address + buf := make([]byte, size+alignment) + + // Find the aligned address + addr := uintptr(unsafe.Pointer(&buf[0])) + alignedAddr := (addr + uintptr(alignment-1)) &^ uintptr(alignment-1) + + // Calculate the offset + offset := alignedAddr - addr + + // Return the aligned slice + return buf[offset : offset+uintptr(size)] +} + +func TestNewRollingAppendFile(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_rolling_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, // 1MB + FilePunchHoleSize: 64 * 1024, // 64KB + BlockSize: 4096, + } + + raf, err := NewRollingAppendFile(config) + if err != nil { + t.Fatalf("Failed to create RollingAppendFile: %v", err) + } + defer cleanup(raf) + + // Verify initial state + if raf.MaxFileSize != config.MaxFileSize { + t.Errorf("Expected MaxFileSize %d, got %d", config.MaxFileSize, raf.MaxFileSize) + } + if raf.FilePunchHoleSize != config.FilePunchHoleSize { + t.Errorf("Expected FilePunchHoleSize %d, got %d", config.FilePunchHoleSize, raf.FilePunchHoleSize) + } + if raf.blockSize != config.BlockSize { + t.Errorf("Expected BlockSize %d, got %d", config.BlockSize, raf.blockSize) + } + if raf.CurrentLogicalOffset != 0 { + t.Errorf("Expected CurrentLogicalOffset 0, got %d", raf.CurrentLogicalOffset) + } + if raf.CurrentPhysicalOffset != 0 { + t.Errorf("Expected CurrentPhysicalOffset 0, got %d", raf.CurrentPhysicalOffset) + } +} + 
+func TestNewRollingAppendFile_DefaultBlockSize(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 64 * 1024,
+		BlockSize:         0, // Should default to BLOCK_SIZE
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	if raf.blockSize != BLOCK_SIZE {
+		t.Errorf("Expected default BlockSize %d, got %d", BLOCK_SIZE, raf.blockSize)
+	}
+}
+
+func TestPwrite_Success(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 64 * 1024,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Create aligned buffer
+	data := createAlignedBuffer(4096, 4096)
+	for i := range data {
+		data[i] = byte(i % 256)
+	}
+
+	offset, err := raf.Pwrite(data)
+	if err != nil {
+		t.Fatalf("Pwrite failed: %v", err)
+	}
+
+	if offset != int64(len(data)) {
+		t.Errorf("Expected offset %d, got %d", len(data), offset)
+	}
+
+	if raf.CurrentPhysicalOffset != int64(len(data)) {
+		t.Errorf("Expected CurrentPhysicalOffset %d, got %d", len(data), raf.CurrentPhysicalOffset)
+	}
+
+	if raf.Stat.WriteCount.Load() != 1 {
+		t.Errorf("Expected WriteCount 1, got %d", raf.Stat.WriteCount.Load())
+	}
+}
+
+func TestPwrite_FileSizeExceeded(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024, // Small max size
+		FilePunchHoleSize: 512,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Try to write more than max file size
+	data := make([]byte, 2048)
+
+	_, err = raf.Pwrite(data)
+	if err != ErrFileSizeExceeded {
+		t.Errorf("Expected ErrFileSizeExceeded, got %v", err)
+	}
+}
+
+func TestPwrite_BufferNotAligned(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 64 * 1024,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Only test if using DirectIO
+	if raf.WriteDirectIO {
+		// Create unaligned buffer
+		data := make([]byte, 4097) // Not aligned to 4096
+
+		_, err = raf.Pwrite(data)
+		if err != ErrBufNoAlign {
+			t.Errorf("Expected ErrBufNoAlign, got %v", err)
+		}
+	}
+}
+
+func TestPread_Success(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 64 * 1024,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Write some data first
+	writeData := createAlignedBuffer(4096, 4096)
+	for i := range writeData {
+		writeData[i] = byte(i % 256)
+	}
+
+	_, err = raf.Pwrite(writeData)
+	if err != nil {
+		t.Fatalf("Pwrite failed: %v", err)
+	}
+
+	// Read the data back
+	readData := createAlignedBuffer(4096, 4096)
+	n, err := raf.Pread(0, readData)
+	if err != nil {
+		t.Fatalf("Pread failed: %v", err)
+	}
+
+	if n != int32(len(readData)) {
+		t.Errorf("Expected read length %d, got %d", len(readData), n)
+	}
+
+	// Verify data matches
+	for i := range readData {
+		if readData[i] != writeData[i] {
+			t.Errorf("Data mismatch at index %d: expected %d, got %d", i, writeData[i], readData[i])
+		}
+	}
+
+	if raf.Stat.ReadCount.Load() != 1 {
+		t.Errorf("Expected ReadCount 1, got %d", raf.Stat.ReadCount.Load())
+	}
+}
+
+func TestPread_FileOffsetOutOfRange(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 64 * 1024,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Try to read without writing anything
+	readData := createAlignedBuffer(4096, 4096)
+	_, err = raf.Pread(0, readData)
+	if err != ErrFileOffsetOutOfRange {
+		t.Errorf("Expected ErrFileOffsetOutOfRange, got %v", err)
+	}
+
+	// Write some data
+	writeData := createAlignedBuffer(4096, 4096)
+	_, err = raf.Pwrite(writeData)
+	if err != nil {
+		t.Fatalf("Pwrite failed: %v", err)
+	}
+
+	// Try to read beyond written data
+	_, err = raf.Pread(4096, readData)
+	if err != ErrFileOffsetOutOfRange {
+		t.Errorf("Expected ErrFileOffsetOutOfRange, got %v", err)
+	}
+}
+
+func TestPread_OffsetNotAligned(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 64 * 1024,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Only test if using DirectIO
+	if raf.ReadDirectIO {
+		// Write some data first
+		writeData := createAlignedBuffer(8192, 4096)
+		_, err = raf.Pwrite(writeData)
+		if err != nil {
+			t.Fatalf("Pwrite failed: %v", err)
+		}
+
+		// Try to read from unaligned offset
+		readData := createAlignedBuffer(4096, 4096)
+		_, err = raf.Pread(100, readData) // Not aligned to 4096
+		if err != ErrOffsetNotAligned {
+			t.Errorf("Expected ErrOffsetNotAligned, got %v", err)
+		}
+	}
+}
+
+func TestTrimHead_Success(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 4096, // One block
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Write some data first
+	writeData := createAlignedBuffer(8192, 4096) // 2 blocks
+	_, err = raf.Pwrite(writeData)
+	if err != nil {
+		t.Fatalf("Pwrite failed: %v", err)
+	}
+
+	// Trim head
+	err = raf.TrimHead()
+	if err != nil {
+		t.Fatalf("TrimHead failed: %v", err)
+	}
+
+	// Verify state changes
+	if raf.LogicalStartOffset != int64(config.FilePunchHoleSize) {
+		t.Errorf("Expected LogicalStartOffset %d, got %d", config.FilePunchHoleSize, raf.LogicalStartOffset)
+	}
+
+	if raf.Stat.PunchHoleCount.Load() != 1 {
+		t.Errorf("Expected PunchHoleCount 1, got %d", raf.Stat.PunchHoleCount.Load())
+	}
+}
+
+func TestIsAlignedOffset(t *testing.T) {
+	tests := []struct {
+		name      string
+		offset    int64
+		alignment int
+		expected  bool
+	}{
+		{"aligned_0", 0, 4096, true},
+		{"aligned_4096", 4096, 4096, true},
+		{"aligned_8192", 8192, 4096, true},
+		{"unaligned_100", 100, 4096, false},
+		{"unaligned_4097", 4097, 4096, false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := isAlignedOffset(tt.offset, tt.alignment)
+			if result != tt.expected {
+				t.Errorf("isAlignedOffset(%d, %d) = %v, expected %v", tt.offset, tt.alignment, result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestMultipleOperations(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 4096,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Write multiple blocks
+	for i := 0; i < 5; i++ {
+		data := createAlignedBuffer(4096, 4096)
+		for j := range data {
+			data[j] = byte((i*256 + j) % 256)
+		}
+
+		_, err = raf.Pwrite(data)
+		if err != nil {
+			t.Fatalf("Pwrite %d failed: %v", i, err)
+		}
+	}
+
+	// Verify total written
+	expectedPhysicalOffset := int64(5 * 4096)
+	if raf.CurrentPhysicalOffset != expectedPhysicalOffset {
+		t.Errorf("Expected CurrentPhysicalOffset %d, got %d", expectedPhysicalOffset, raf.CurrentPhysicalOffset)
+	}
+
+	// Read back data from different offsets
+	for i := 0; i < 5; i++ {
+		readData := createAlignedBuffer(4096, 4096)
+		_, err = raf.Pread(int64(i*4096), readData)
+		if err != nil {
+			t.Fatalf("Pread %d failed: %v", i, err)
+		}
+
+		// Verify data integrity
+		for j := range readData {
+			expected := byte((i*256 + j) % 256)
+			if readData[j] != expected {
+				t.Errorf("Data mismatch at block %d, index %d: expected %d, got %d", i, j, expected, readData[j])
+			}
+		}
+	}
+
+	// Verify statistics
+	if raf.Stat.WriteCount.Load() != 5 {
+		t.Errorf("Expected WriteCount 5, got %d", raf.Stat.WriteCount.Load())
+	}
+	if raf.Stat.ReadCount.Load() != 5 {
+		t.Errorf("Expected ReadCount 5, got %d", raf.Stat.ReadCount.Load())
+	}
+}
+
+func TestStatistics(t *testing.T) {
+	tmpDir := t.TempDir()
+	filename := filepath.Join(tmpDir, "test_rolling_file.dat")
+
+	config := FileConfig{
+		Filename:          filename,
+		MaxFileSize:       1024 * 1024,
+		FilePunchHoleSize: 4096,
+		BlockSize:         4096,
+	}
+
+	raf, err := NewRollingAppendFile(config)
+	if err != nil {
+		t.Fatalf("Failed to create RollingAppendFile: %v", err)
+	}
+	defer cleanup(raf)
+
+	// Initial state
+	if raf.Stat.WriteCount.Load() != 0 {
+		t.Errorf("Expected initial WriteCount 0, got %d", raf.Stat.WriteCount.Load())
+	}
+	if raf.Stat.ReadCount.Load() != 0 {
+		t.Errorf("Expected initial ReadCount 0, got %d", raf.Stat.ReadCount.Load())
+	}
+	if raf.Stat.PunchHoleCount.Load() != 0 {
+		t.Errorf("Expected initial PunchHoleCount 0, got %d", raf.Stat.PunchHoleCount.Load())
+	}
+
+	// Perform operations and verify statistics
+	data := createAlignedBuffer(4096, 4096)
+
+	// Write operation
+	_, err = raf.Pwrite(data)
+	if err != nil {
+		t.Fatalf("Pwrite failed: %v", err)
+	}
+	if raf.Stat.WriteCount.Load() != 1 {
+		t.Errorf("Expected WriteCount 1, got %d", raf.Stat.WriteCount.Load())
+	}
+
+	// Read operation
+	_, err = raf.Pread(0, data)
+	if err != nil {
+		t.Fatalf("Pread failed: %v", err)
+	}
+	if raf.Stat.ReadCount.Load() != 1 {
+		t.Errorf("Expected ReadCount 1, got %d", raf.Stat.ReadCount.Load())
+	}
+
+	// Trim operation
+	err = raf.TrimHead()
+	if err != nil {
+		t.Fatalf("TrimHead failed: %v", err)
+	}
+	if raf.Stat.PunchHoleCount.Load() != 1 {
+		t.Errorf("Expected PunchHoleCount 1, got %d", raf.Stat.PunchHoleCount.Load())
+	}
+}
+
+// Helper function to clean up resources
+func cleanup(raf *RollingAppendFile) {
+	if raf.WriteFile != nil {
+		raf.WriteFile.Close()
+	}
+	if raf.ReadFile != nil {
+		raf.ReadFile.Close()
+	}
+	if raf.WriteFile != nil {
+		os.Remove(raf.WriteFile.Name())
+	}
+}
diff --git a/flashring/internal/fs/wrap_file.go b/flashring/internal/fs/wrap_file.go
new file mode 100644
index 00000000..d8ebc1bb
--- /dev/null
+++ b/flashring/internal/fs/wrap_file.go
@@ -0,0 +1,296 @@
+//go:build linux
+// +build linux
+
+package fs
+
+import (
+	"os"
+	"syscall"
+	"time"
+
+	"github.com/Meesho/BharatMLStack/flashring/internal/iouring"
+	"github.com/Meesho/BharatMLStack/flashring/pkg/metrics"
+	"golang.org/x/sys/unix"
+)
+
+// WrapAppendFile uses a file of MaxFileSize bytes as a ring: the write offset
+// wraps back to 0 once it reaches MaxFileSize, and TrimHead advances
+// PhysicalStartOffset by punching FilePunchHoleSize bytes at the head.
+type WrapAppendFile struct {
+	WriteDirectIO        bool
+	ReadDirectIO         bool
+	wrapped              bool
+	blockSize            int
+	WriteFd              int                    // write file descriptor
+	ReadFd               int                    // read file descriptor
+	MaxFileSize          int64                  // max file size in bytes
+	FilePunchHoleSize    int64                  // file punch hole size in bytes
+	PhysicalStartOffset  int64                  // physical start offset in bytes
+	LogicalCurrentOffset int64                  // file current size in bytes
+	PhysicalWriteOffset  int64                  // file current physical offset in bytes
+	WriteFile            *os.File               // write file
+	ReadFile             *os.File               // read file
+	Stat                 *Stat                  // file statistics
+	WriteRing            *iouring.IoUringWriter // io_uring writer for batched writes
+}
+
+// NewWrapAppendFile opens config.Filename for writing (through
+// createPreAllocatedWriteFileDescriptor, sized by config.MaxFileSize) and for
+// reading, and returns a WrapAppendFile with every offset at zero.
+// A zero config.BlockSize falls back to BLOCK_SIZE. WriteRing is left nil and
+// has to be assigned by the caller before PwriteBatch is used.
+func NewWrapAppendFile(config FileConfig) (*WrapAppendFile, error) {
+	filename := config.Filename
+	maxFileSize := config.MaxFileSize
+
+	writeFd, writeFile, wDirectIO, err := createPreAllocatedWriteFileDescriptor(filename, maxFileSize)
+	if err != nil {
+		return nil, err
+	}
+	readFd, readFile, rDirectIO, err := createReadFileDescriptor(filename)
+	if err != nil {
+		// The write side is already open: release it so that a failed
+		// constructor does not leak a descriptor.
+		if writeFile != nil {
+			writeFile.Close()
+		} else {
+			syscall.Close(writeFd)
+		}
+		return nil, err
+	}
+	blockSize := config.BlockSize
+	if blockSize == 0 {
+		blockSize = BLOCK_SIZE
+	}
+	return &WrapAppendFile{
+		WriteDirectIO:        wDirectIO,
+		ReadDirectIO:         rDirectIO,
+		blockSize:            blockSize,
+		WriteFd:              writeFd,
+		ReadFd:               readFd,
+		WriteFile:            writeFile,
+		ReadFile:             readFile,
+		MaxFileSize:          maxFileSize,
+		FilePunchHoleSize:    config.FilePunchHoleSize,
+		PhysicalStartOffset:  0,
+		LogicalCurrentOffset: 0,
+		PhysicalWriteOffset:  0,
+		Stat:                 &Stat{},
+	}, nil
+}
+
+// Pwrite writes buf at PhysicalWriteOffset and returns the write offset that
+// follows it. Once the offset reaches MaxFileSize it wraps to 0, so the
+// returned value is 0 immediately after a wrap. With direct I/O on the write
+// descriptor, a buf rejected by isAlignedBuffer yields ErrBufNoAlign.
+func (r *WrapAppendFile) Pwrite(buf []byte) (currentPhysicalOffset int64, err error) {
+	if r.WriteDirectIO && !isAlignedBuffer(buf, r.blockSize) {
+		return 0, ErrBufNoAlign
+	}
+	var startTime = time.Now()
+	n, err := syscall.Pwrite(r.WriteFd, buf, r.PhysicalWriteOffset)
+	metrics.Timing(metrics.KEY_PWRITE_LATENCY, time.Since(startTime), []string{})
+	if err != nil {
+		return 0, err
+	}
+
+	r.PhysicalWriteOffset += int64(n)
+	if r.PhysicalWriteOffset >= r.MaxFileSize {
+		r.wrapped = true
+		r.PhysicalWriteOffset = 0
+	}
+	r.LogicalCurrentOffset += int64(n)
+	r.Stat.WriteCount.Add(1)
+
+	return r.PhysicalWriteOffset, nil
+}
+
+// PwriteBatch writes a large buffer in chunkSize pieces via io_uring.
+// Chunks are submitted in sub-batches that fit within the ring's SQ depth,
+// so arbitrarily large buffers work regardless of ring size.
+// Returns total bytes written and the final PhysicalWriteOffset.
+func (r *WrapAppendFile) PwriteBatch(buf []byte, chunkSize int) (totalWritten int, fileOffset int64, err error) {
+	if r.WriteDirectIO && !isAlignedBuffer(buf, r.blockSize) {
+		return 0, 0, ErrBufNoAlign
+	}
+	// A non-positive chunkSize never advances through buf (the loop below
+	// would spin forever) and a nil ring cannot accept submissions, so both
+	// are rejected as invalid arguments before any state is touched.
+	if chunkSize <= 0 || r.WriteRing == nil {
+		return 0, 0, syscall.EINVAL
+	}
+
+	// Maximum SQEs per submission -- capped to ring depth.
+	maxPerBatch := r.WriteRing.MaxBatchSize()
+	if maxPerBatch <= 0 {
+		// An empty sub-batch would make no progress either.
+		return 0, 0, syscall.EINVAL
+	}
+
+	for written := 0; written < len(buf); {
+		// Build a sub-batch that fits within the ring
+		var bufs [][]byte
+		var offsets []uint64
+
+		for i := 0; i < maxPerBatch && written < len(buf); i++ {
+			end := written + chunkSize
+			if end > len(buf) {
+				end = len(buf)
+			}
+			bufs = append(bufs, buf[written:end])
+			offsets = append(offsets, uint64(r.PhysicalWriteOffset))
+
+			// Advance write offset, handle ring-buffer wrap
+			r.PhysicalWriteOffset += int64(end - written)
+			if r.PhysicalWriteOffset >= r.MaxFileSize {
+				r.wrapped = true
+				r.PhysicalWriteOffset = 0
+			}
+			written = end
+		}
+
+		results, serr := r.WriteRing.SubmitWriteBatch(r.WriteFd, bufs, offsets)
+		if serr != nil {
+			// NOTE(review): PhysicalWriteOffset was already advanced for this
+			// sub-batch; confirm callers treat the file as unusable on error.
+			return totalWritten, r.PhysicalWriteOffset, serr
+		}
+
+		// One WriteCount tick per completed chunk.
+		for _, n := range results {
+			totalWritten += n
+			r.LogicalCurrentOffset += int64(n)
+			r.Stat.WriteCount.Add(1)
+		}
+	}
+
+	return totalWritten, r.PhysicalWriteOffset, nil
+}
+
+// AdvanceWriteOffset moves the write pointer forward by n bytes without
+// writing any data. Used to skip over uninitialized regions in staggered
+// memtables so that the memId*capacity file layout is preserved.
+func (r *WrapAppendFile) AdvanceWriteOffset(n int64) {
+	r.PhysicalWriteOffset += n
+	if r.PhysicalWriteOffset >= r.MaxFileSize {
+		r.wrapped = true
+		r.PhysicalWriteOffset = 0
+	}
+	r.LogicalCurrentOffset += n
+}
+
+// TrimHeadIfNeeded reports whether the writer has wrapped and its offset
+// equals PhysicalStartOffset (the full-ring case described below).
+func (r *WrapAppendFile) TrimHeadIfNeeded() bool {
+	return r.wrapped && r.PhysicalWriteOffset == r.PhysicalStartOffset
+}
+
+// isValidReadRegion checks whether [physOffset, physEnd) falls entirely within
+// the ring's live data.
+//
+// When wrapped, the writer and start pointers chase each other around
+// [0, MaxFileSize). Three sub-cases:
+//
+//	W == S → ring is full (writer just caught up); entire file is valid.
+//	W < S  → two valid segments: [S, MaxFileSize) and [0, W).
+//	W > S  → contiguous [S, W); the region [W, MaxFileSize) was punched
+//	         by the TrimHead that sent S past W's position.
+func (r *WrapAppendFile) isValidReadRegion(physOffset, physEnd int64) bool {
+	if !r.wrapped {
+		return physOffset >= r.PhysicalStartOffset && physEnd <= r.PhysicalWriteOffset
+	}
+	W, S := r.PhysicalWriteOffset, r.PhysicalStartOffset
+	if W == S {
+		// The lower bound matters here as well: a negative fileOffset yields a
+		// negative physOffset after the modulo in Pread/ValidateReadOffset.
+		return physOffset >= 0 && physEnd <= r.MaxFileSize
+	}
+	if W < S {
+		return (physOffset >= S && physEnd <= r.MaxFileSize) ||
+			(physOffset >= 0 && physEnd <= W)
+	}
+	return physOffset >= S && physEnd <= W
+}
+
+// Pread reads len(buf) bytes at fileOffset into buf and returns the number of
+// bytes read. Once the file has wrapped, fileOffset is reduced modulo
+// MaxFileSize to obtain the physical offset. The resulting window must pass
+// isValidReadRegion, otherwise ErrFileOffsetOutOfRange is returned; with direct
+// I/O on the read descriptor, fileOffset and buf must pass the alignment checks.
+func (r *WrapAppendFile) Pread(fileOffset int64, buf []byte) (int32, error) {
+	if r.ReadDirectIO {
+		if !isAlignedOffset(fileOffset, r.blockSize) {
+			return 0, ErrOffsetNotAligned
+		}
+		if !isAlignedBuffer(buf, r.blockSize) {
+			return 0, ErrBufNoAlign
+		}
+	}
+
+	physOffset := fileOffset
+	if r.wrapped {
+		physOffset = fileOffset % r.MaxFileSize
+	}
+	physEnd := physOffset + int64(len(buf))
+
+	if !r.isValidReadRegion(physOffset, physEnd) {
+		return 0, ErrFileOffsetOutOfRange
+	}
+
+	var startTime = time.Now()
+	n, err := syscall.Pread(r.ReadFd, buf, physOffset)
+	metrics.Timing(metrics.KEY_PREAD_LATENCY, time.Since(startTime), []string{})
+	if err != nil {
+		return 0, err
+	}
+	r.Stat.ReadCount.Add(1)
+	return int32(n), nil
+}
+
+// ValidateReadOffset checks the read window and wraps the offset for ring-buffer
+// files. Returns the physical file offset to use, or an error.
+// Mirrors the validation logic in Pread so callers that use the batched
+// io_uring path get identical safety checks.
+func (r *WrapAppendFile) ValidateReadOffset(fileOffset int64, bufLen int) (int64, error) { + if r.ReadDirectIO { + if !isAlignedOffset(fileOffset, r.blockSize) { + return 0, ErrOffsetNotAligned + } + } + + physOffset := fileOffset + if r.wrapped { + physOffset = fileOffset % r.MaxFileSize + } + physEnd := physOffset + int64(bufLen) + + if !r.isValidReadRegion(physOffset, physEnd) { + return 0, ErrFileOffsetOutOfRange + } + + return physOffset, nil +} + +func (r *WrapAppendFile) TrimHead() (err error) { + + var startTime = time.Now() + if r.WriteDirectIO { + if !isAlignedOffset(r.PhysicalStartOffset, r.blockSize) { + return ErrOffsetNotAligned + } + } + err = unix.Fallocate(r.WriteFd, FALLOC_FL_PUNCH_HOLE|FALLOC_FL_KEEP_SIZE, r.PhysicalStartOffset, int64(r.FilePunchHoleSize)) + if err != nil { + return err + } + r.PhysicalStartOffset += int64(r.FilePunchHoleSize) + if r.PhysicalStartOffset >= r.MaxFileSize { + r.PhysicalStartOffset = 0 + } + r.Stat.PunchHoleCount.Add(1) + metrics.Incr(metrics.KEY_PUNCH_HOLE_COUNT, []string{}) + metrics.Timing(metrics.KEY_TRIM_HEAD_LATENCY, time.Since(startTime), []string{}) + return nil +} + +func (r *WrapAppendFile) Close() { + syscall.Close(r.WriteFd) + syscall.Close(r.ReadFd) + os.Remove(r.WriteFile.Name()) + os.Remove(r.ReadFile.Name()) +} + +func preadv2(fd int, buf []byte, off int64, flags int) (int, error) { + if len(buf) == 0 { + return 0, nil + } + n, err := unix.Preadv2(fd, [][]byte{buf}, off, flags) + // Kernel or FS may not support preadv2/flags; fall back + if err == unix.ENOSYS || err == unix.EOPNOTSUPP || err == unix.EINVAL { + return unix.Pread(fd, buf, off) + } + return n, err +} diff --git a/flashring/internal/fs/wrap_file_test.go b/flashring/internal/fs/wrap_file_test.go new file mode 100644 index 00000000..45375ec5 --- /dev/null +++ b/flashring/internal/fs/wrap_file_test.go @@ -0,0 +1,840 @@ +//go:build linux +// +build linux + +package fs + +import ( + "os" + "path/filepath" + "testing" +) + +func 
TestNewWrapAppendFile(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, // 1MB + FilePunchHoleSize: 64 * 1024, // 64KB + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Verify initial state + if waf.MaxFileSize != config.MaxFileSize { + t.Errorf("Expected MaxFileSize %d, got %d", config.MaxFileSize, waf.MaxFileSize) + } + if waf.FilePunchHoleSize != config.FilePunchHoleSize { + t.Errorf("Expected FilePunchHoleSize %d, got %d", config.FilePunchHoleSize, waf.FilePunchHoleSize) + } + if waf.blockSize != config.BlockSize { + t.Errorf("Expected BlockSize %d, got %d", config.BlockSize, waf.blockSize) + } + if waf.LogicalCurrentOffset != 0 { + t.Errorf("Expected LogicalCurrentOffset 0, got %d", waf.LogicalCurrentOffset) + } + if waf.PhysicalWriteOffset != 0 { + t.Errorf("Expected PhysicalWriteOffset 0, got %d", waf.PhysicalWriteOffset) + } + if waf.PhysicalStartOffset != 0 { + t.Errorf("Expected PhysicalStartOffset 0, got %d", waf.PhysicalStartOffset) + } + if waf.wrapped { + t.Errorf("Expected wrapped to be false initially") + } +} + +func TestNewWrapAppendFile_DefaultBlockSize(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 0, // Should default to BLOCK_SIZE + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + if waf.blockSize != BLOCK_SIZE { + t.Errorf("Expected default BlockSize %d, got %d", BLOCK_SIZE, waf.blockSize) + } +} + +func TestWrapAppendFile_Pwrite_Success(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, 
"test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Create aligned buffer + data := createAlignedBuffer(4096, 4096) + for i := range data { + data[i] = byte(i % 256) + } + + offset, err := waf.Pwrite(data) + if err != nil { + t.Fatalf("Pwrite failed: %v", err) + } + + if offset != int64(len(data)) { + t.Errorf("Expected offset %d, got %d", len(data), offset) + } + + if waf.PhysicalWriteOffset != int64(len(data)) { + t.Errorf("Expected PhysicalWriteOffset %d, got %d", len(data), waf.PhysicalWriteOffset) + } + + if waf.LogicalCurrentOffset != int64(len(data)) { + t.Errorf("Expected LogicalCurrentOffset %d, got %d", len(data), waf.LogicalCurrentOffset) + } + + if waf.Stat.WriteCount.Load() != 1 { + t.Errorf("Expected WriteCount 1, got %d", waf.Stat.WriteCount.Load()) + } + + if waf.wrapped { + t.Errorf("Expected wrapped to be false") + } +} + +func TestPwrite_WrapAround(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 8192, // Small max size for easy wrapping + FilePunchHoleSize: 4096, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Write first block + data1 := createAlignedBuffer(4096, 4096) + for i := range data1 { + data1[i] = byte(1) + } + + _, err = waf.Pwrite(data1) + if err != nil { + t.Fatalf("First Pwrite failed: %v", err) + } + + if waf.wrapped { + t.Errorf("Should not be wrapped after first write") + } + + // Write second block - should trigger wrap + data2 := createAlignedBuffer(4096, 4096) + for i := range data2 { + data2[i] = byte(2) + } + + offset, err := 
waf.Pwrite(data2) + if err != nil { + t.Fatalf("Second Pwrite failed: %v", err) + } + + // After wrapping, should be at PhysicalStartOffset + if !waf.wrapped { + t.Errorf("Should be wrapped after exceeding MaxFileSize") + } + + if waf.PhysicalWriteOffset != waf.PhysicalStartOffset { + t.Errorf("Expected PhysicalWriteOffset %d after wrap, got %d", waf.PhysicalStartOffset, waf.PhysicalWriteOffset) + } + + if offset != waf.PhysicalStartOffset { + t.Errorf("Expected return offset %d after wrap, got %d", waf.PhysicalStartOffset, offset) + } + + if waf.LogicalCurrentOffset != int64(8192) { + t.Errorf("Expected LogicalCurrentOffset %d, got %d", 8192, waf.LogicalCurrentOffset) + } +} + +func TestWrapAppendFile_Pwrite_BufferNotAligned(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Only test if using DirectIO + if waf.WriteDirectIO { + // Create unaligned buffer + data := make([]byte, 4097) // Not aligned to 4096 + + _, err = waf.Pwrite(data) + if err != ErrBufNoAlign { + t.Errorf("Expected ErrBufNoAlign, got %v", err) + } + } +} + +func TestPread_Success_NoWrap(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Write some data first + writeData := createAlignedBuffer(4096, 4096) + for i := range writeData { + writeData[i] = byte(i % 256) + } + + _, err = waf.Pwrite(writeData) + if err != nil { + t.Fatalf("Pwrite 
failed: %v", err) + } + + // Read the data back + readData := createAlignedBuffer(4096, 4096) + n, err := waf.Pread(0, readData) + if err != nil { + t.Fatalf("Pread failed: %v", err) + } + + if n != int32(len(readData)) { + t.Errorf("Expected read length %d, got %d", len(readData), n) + } + + // Verify data matches + for i := range readData { + if readData[i] != writeData[i] { + t.Errorf("Data mismatch at index %d: expected %d, got %d", i, writeData[i], readData[i]) + } + } + + if waf.Stat.ReadCount.Load() != 1 { + t.Errorf("Expected ReadCount 1, got %d", waf.Stat.ReadCount.Load()) + } +} + +func TestPread_Success_WithWrap(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 8192, + FilePunchHoleSize: 4096, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Fill the file: data1 at [0,4096), data2 at [4096,8192). + data1 := createAlignedBuffer(4096, 4096) + for i := range data1 { + data1[i] = byte(1) + } + if _, err = waf.Pwrite(data1); err != nil { + t.Fatalf("First Pwrite failed: %v", err) + } + + data2 := createAlignedBuffer(4096, 4096) + for i := range data2 { + data2[i] = byte(2) + } + if _, err = waf.Pwrite(data2); err != nil { + t.Fatalf("Second Pwrite failed: %v", err) + } + + // Second Pwrite reaches MaxFileSize → writer wraps to 0. + if !waf.wrapped { + t.Fatalf("Expected wrapped to be true after filling file") + } + if waf.PhysicalWriteOffset != 0 { + t.Fatalf("Expected PhysicalWriteOffset=0 after wrap, got %d", waf.PhysicalWriteOffset) + } + + // W == S == 0 (full ring): both regions are readable. 
+ readData := createAlignedBuffer(4096, 4096) + if _, err = waf.Pread(0, readData); err != nil { + t.Fatalf("Pread [0,4096) in full ring failed: %v", err) + } + if _, err = waf.Pread(4096, readData); err != nil { + t.Fatalf("Pread [4096,8192) in full ring failed: %v", err) + } + + // TrimHead: punch [0,4096), S advances to 4096. + if err = waf.TrimHead(); err != nil { + t.Fatalf("TrimHead failed: %v", err) + } + if waf.PhysicalStartOffset != 4096 { + t.Fatalf("Expected PhysicalStartOffset=4096 after trim, got %d", waf.PhysicalStartOffset) + } + + // Now W(0) < S(4096). Valid region: [4096, 8192). + // Punched region [0,4096) must be rejected. + if _, err = waf.Pread(0, readData); err != ErrFileOffsetOutOfRange { + t.Errorf("Pread from punched region should fail, got %v", err) + } + n, err := waf.Pread(4096, readData) + if err != nil { + t.Fatalf("Pread from old tail failed: %v", err) + } + if n != 4096 { + t.Errorf("Expected read length 4096, got %d", n) + } + for i := range readData { + if readData[i] != byte(2) { + t.Errorf("Old tail data mismatch at %d: expected 2, got %d", i, readData[i]) + break + } + } + + // Write data3 into the freed region [0,4096). W advances to 4096. + data3 := createAlignedBuffer(4096, 4096) + for i := range data3 { + data3[i] = byte(3) + } + if _, err = waf.Pwrite(data3); err != nil { + t.Fatalf("Third Pwrite failed: %v", err) + } + + // W == S == 4096 (full ring again): both regions readable. 
+ readOld := createAlignedBuffer(4096, 4096) + n, err = waf.Pread(4096, readOld) + if err != nil { + t.Fatalf("Pread old tail [4096,8192) after refill failed: %v", err) + } + if n != 4096 { + t.Errorf("Expected 4096 bytes, got %d", n) + } + for i := range readOld { + if readOld[i] != byte(2) { + t.Errorf("Old tail mismatch at %d: expected 2, got %d", i, readOld[i]) + break + } + } + + readNew := createAlignedBuffer(4096, 4096) + n, err = waf.Pread(0, readNew) + if err != nil { + t.Fatalf("Pread new data [0,4096) after refill failed: %v", err) + } + if n != 4096 { + t.Errorf("Expected 4096 bytes, got %d", n) + } + for i := range readNew { + if readNew[i] != byte(3) { + t.Errorf("New data mismatch at %d: expected 3, got %d", i, readNew[i]) + break + } + } +} + +func TestPread_FileOffsetOutOfRange_NoWrap(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Try to read without writing anything + readData := createAlignedBuffer(4096, 4096) + _, err = waf.Pread(0, readData) + if err != ErrFileOffsetOutOfRange { + t.Errorf("Expected ErrFileOffsetOutOfRange, got %v", err) + } + + // Write some data + writeData := createAlignedBuffer(4096, 4096) + _, err = waf.Pwrite(writeData) + if err != nil { + t.Fatalf("Pwrite failed: %v", err) + } + + // Try to read beyond written data + _, err = waf.Pread(4096, readData) + if err != ErrFileOffsetOutOfRange { + t.Errorf("Expected ErrFileOffsetOutOfRange, got %v", err) + } +} + +func TestPread_FileOffsetOutOfRange_WithWrap(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 8192, + FilePunchHoleSize: 
4096, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Write 2 blocks to cause wrapping (MaxFileSize=8192, block=4096) + for i := 0; i < 2; i++ { + data := createAlignedBuffer(4096, 4096) + _, err = waf.Pwrite(data) + if err != nil { + t.Fatalf("Pwrite %d failed: %v", i, err) + } + } + + if !waf.wrapped { + t.Errorf("Expected wrapped to be true") + } + + // After 2 writes: PhysicalStartOffset=0, PhysicalWriteOffset=0 (wrapped). + // TrimHead to advance PhysicalStartOffset to 4096, creating a gap at [0, 4096). + err = waf.TrimHead() + if err != nil { + t.Fatalf("TrimHead failed: %v", err) + } + + // State: PhysicalStartOffset=4096, PhysicalWriteOffset=0, wrapped=true + // Valid regions: [4096, 8192) and [0, 0) (empty). + // Gap: [0, 4096) — reading from offset 0 should fail. + readData := createAlignedBuffer(4096, 4096) + _, err = waf.Pread(0, readData) + if err != ErrFileOffsetOutOfRange { + t.Errorf("Expected ErrFileOffsetOutOfRange for gap read, got %v", err) + } +} + +func TestWrapAppendFile_Pread_OffsetNotAligned(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Only test if using DirectIO + if waf.ReadDirectIO { + // Write some data first + writeData := createAlignedBuffer(8192, 4096) + _, err = waf.Pwrite(writeData) + if err != nil { + t.Fatalf("Pwrite failed: %v", err) + } + + // Try to read from unaligned offset + readData := createAlignedBuffer(4096, 4096) + _, err = waf.Pread(100, readData) // Not aligned to 4096 + if err != ErrOffsetNotAligned { + t.Errorf("Expected ErrOffsetNotAligned, got %v", err) 
+ } + } +} + +func TestPread_BufferNotAligned(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Only test if using DirectIO + if waf.ReadDirectIO { + // Write some data first + writeData := createAlignedBuffer(4096, 4096) + _, err = waf.Pwrite(writeData) + if err != nil { + t.Fatalf("Pwrite failed: %v", err) + } + + // Try to read with unaligned buffer + readData := make([]byte, 4097) // Not aligned + _, err = waf.Pread(0, readData) + if err != ErrBufNoAlign { + t.Errorf("Expected ErrBufNoAlign, got %v", err) + } + } +} + +func TestWrapAppendFile_TrimHead_Success(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 4096, // One block + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Write some data first + writeData := createAlignedBuffer(8192, 4096) // 2 blocks + _, err = waf.Pwrite(writeData) + if err != nil { + t.Fatalf("Pwrite failed: %v", err) + } + + initialStartOffset := waf.PhysicalStartOffset + + // Trim head + err = waf.TrimHead() + if err != nil { + t.Fatalf("TrimHead failed: %v", err) + } + + // Verify state changes + expectedStartOffset := initialStartOffset + int64(config.FilePunchHoleSize) + if waf.PhysicalStartOffset != expectedStartOffset { + t.Errorf("Expected PhysicalStartOffset %d, got %d", expectedStartOffset, waf.PhysicalStartOffset) + } + + if waf.Stat.PunchHoleCount.Load() != 1 { + t.Errorf("Expected PunchHoleCount 1, got %d", 
waf.Stat.PunchHoleCount.Load()) + } +} + +func TestTrimHead_WrapAround(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 8192, + FilePunchHoleSize: 8192, // Same as max file size + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Set PhysicalStartOffset to near end + waf.PhysicalStartOffset = 4096 + + // Trim head - should wrap around to 0 + err = waf.TrimHead() + if err != nil { + t.Fatalf("TrimHead failed: %v", err) + } + + // Should wrap to 0 since 4096 + 8192 >= 8192 + if waf.PhysicalStartOffset != 0 { + t.Errorf("Expected PhysicalStartOffset to wrap to 0, got %d", waf.PhysicalStartOffset) + } +} + +func TestTrimHead_OffsetNotAligned(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 4096, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Only test if using DirectIO + if waf.WriteDirectIO { + // Set unaligned PhysicalStartOffset + waf.PhysicalStartOffset = 100 + + err = waf.TrimHead() + if err != ErrOffsetNotAligned { + t.Errorf("Expected ErrOffsetNotAligned, got %v", err) + } + } +} + +func TestPwrite_WrapAndContinue(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 8192, + FilePunchHoleSize: 4096, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + for i := 0; i < 2; i++ { + data := createAlignedBuffer(4096, 4096) + _, err = 
waf.Pwrite(data) + if err != nil { + t.Fatalf("Pwrite %d failed: %v", i, err) + } + } + + if !waf.wrapped { + t.Errorf("Expected wrapped to be true") + } + + // After wrapping, PhysicalWriteOffset resets to PhysicalStartOffset. + // A subsequent write overwrites at that position (trim is done at + // the shard level via DeleteManager, not by Pwrite itself). + data := createAlignedBuffer(4096, 4096) + for i := range data { + data[i] = 0xAB + } + _, err = waf.Pwrite(data) + if err != nil { + t.Fatalf("Post-wrap Pwrite failed: %v", err) + } + + if waf.Stat.WriteCount.Load() != 3 { + t.Errorf("Expected WriteCount 3, got %d", waf.Stat.WriteCount.Load()) + } + + if waf.LogicalCurrentOffset != int64(3*4096) { + t.Errorf("Expected LogicalCurrentOffset %d, got %d", 3*4096, waf.LogicalCurrentOffset) + } +} + +func TestClose(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 64 * 1024, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + + // Verify file exists + if _, err := os.Stat(filename); os.IsNotExist(err) { + t.Errorf("File should exist before Close") + } + + // Close and verify cleanup + waf.Close() + + // File should be removed + if _, err := os.Stat(filename); !os.IsNotExist(err) { + t.Errorf("File should be removed after Close") + } +} + +func TestWrapAppendFile_MultipleOperations(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 16384, // 4 blocks + FilePunchHoleSize: 4096, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Write multiple blocks to test wrap behavior + for i := 0; i < 6; i++ { // 
More than max file size / block size + data := createAlignedBuffer(4096, 4096) + for j := range data { + data[j] = byte((i*256 + j) % 256) + } + + _, err = waf.Pwrite(data) + if err != nil { + t.Fatalf("Pwrite %d failed: %v", i, err) + } + } + + // Should be wrapped + if !waf.wrapped { + t.Errorf("Expected wrapped to be true after writing 6 blocks") + } + + // Verify logical offset continues to grow + expectedLogicalOffset := int64(6 * 4096) + if waf.LogicalCurrentOffset != expectedLogicalOffset { + t.Errorf("Expected LogicalCurrentOffset %d, got %d", expectedLogicalOffset, waf.LogicalCurrentOffset) + } + + // Verify statistics + if waf.Stat.WriteCount.Load() != 6 { + t.Errorf("Expected WriteCount 6, got %d", waf.Stat.WriteCount.Load()) + } +} + +func TestWrapAppendFile_Statistics(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_wrap_file.dat") + + config := FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, + FilePunchHoleSize: 4096, + BlockSize: 4096, + } + + waf, err := NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create WrapAppendFile: %v", err) + } + defer cleanupWrapFile(waf) + + // Initial state + if waf.Stat.WriteCount.Load() != 0 { + t.Errorf("Expected initial WriteCount 0, got %d", waf.Stat.WriteCount.Load()) + } + if waf.Stat.ReadCount.Load() != 0 { + t.Errorf("Expected initial ReadCount 0, got %d", waf.Stat.ReadCount.Load()) + } + if waf.Stat.PunchHoleCount.Load() != 0 { + t.Errorf("Expected initial PunchHoleCount 0, got %d", waf.Stat.PunchHoleCount.Load()) + } + + // Perform operations and verify statistics + data := createAlignedBuffer(4096, 4096) + + // Write operation + _, err = waf.Pwrite(data) + if err != nil { + t.Fatalf("Pwrite failed: %v", err) + } + if waf.Stat.WriteCount.Load() != 1 { + t.Errorf("Expected WriteCount 1, got %d", waf.Stat.WriteCount.Load()) + } + + // Read operation + _, err = waf.Pread(0, data) + if err != nil { + t.Fatalf("Pread failed: %v", err) + } + if 
// Bit layout constants for the packed index records. A field is packed as
// (value & MASK) << SHIFT and unpacked as (word >> SHIFT) & MASK.
const (
	// First 64-bit word of an Entry (bytes 0..7), from the high bits down:
	// length | deltaExptime | lastAccess | freq, 16 bits each.
	LENGTH_SHIFT        = 48
	DELTA_EXPTIME_SHIFT = 32
	LAST_ACCESS_SHIFT   = 16
	FREQ_SHIFT          = 0

	LENGTH_MASK        = 0xFFFF
	DELTA_EXPTIME_MASK = 0xFFFF
	LAST_ACCESS_MASK   = 0xFFFF
	FREQ_MASK          = 0xFFFF

	// Second 64-bit word of an Entry (bytes 8..15): memId in the upper
	// 32 bits, offset in the lower 32 bits.
	MEM_ID_SHIFT = 32
	OFFSET_SHIFT = 0

	MEM_ID_MASK = 0xFFFFFFFF
	OFFSET_MASK = 0xFFFFFFFF

	// Third word of a HashNextPrev record: prev index in the upper 32 bits,
	// next index in the lower 32 bits.
	PREV_MASK = 0xFFFFFFFF
	NEXT_MASK = 0xFFFFFFFF
)
keyIndex: keyIndex, + wrapFile: wrapFile, + deleteAmortizedStep: deleteAmortizedStep, + } +} + +func (dm *DeleteManager) IncMemtableKeyCount(memId uint32) { + dm.memtableData[memId]++ +} + +func (dm *DeleteManager) ExecuteDeleteIfNeeded() error { + if dm.deleteInProgress { + memtableId, count := dm.keyIndex.Delete(dm.deleteCount) + if count == -1 { + return fmt.Errorf("delete failed") + } + if memtableId != dm.toBeDeletedMemId { + dm.memtableData[dm.toBeDeletedMemId] -= count + log.Debug().Msgf("memtableId: %d, toBeDeletedMemId: %d", memtableId, dm.toBeDeletedMemId) + if dm.memtableData[dm.toBeDeletedMemId] != 0 { + return fmt.Errorf("memtableData[dm.toBeDeletedMemId] != 0") + } + delete(dm.memtableData, dm.toBeDeletedMemId) + dm.toBeDeletedMemId = memtableId + dm.deleteInProgress = false + dm.deleteCount = 0 + return nil + } + dm.memtableData[memtableId] -= count + return nil + } + + trimNeeded := dm.wrapFile.TrimHeadIfNeeded() + nextAddNeedsDelete := dm.keyIndex.GetRB().NextAddNeedsDelete() + + if trimNeeded || nextAddNeedsDelete { + dm.deleteInProgress = true + dm.deleteCount = dm.memtableData[dm.toBeDeletedMemId] / dm.deleteAmortizedStep + if dm.deleteCount == 0 { + dm.deleteCount = dm.memtableData[dm.toBeDeletedMemId] % dm.deleteAmortizedStep + } + memIdAtHead, err := dm.keyIndex.PeekMemIdAtHead() + if err != nil { + return err + } + if memIdAtHead != dm.toBeDeletedMemId { + return fmt.Errorf("memIdAtHead: %d, toBeDeletedMemId: %d", memIdAtHead, dm.toBeDeletedMemId) + } + + dm.wrapFile.TrimHead() + return errors.New("trim needed retry this write") + } + return nil +} diff --git a/flashring/internal/index/encoder.go b/flashring/internal/index/encoder.go new file mode 100644 index 00000000..7638277a --- /dev/null +++ b/flashring/internal/index/encoder.go @@ -0,0 +1,74 @@ +package index + +func encode(key string, length, deltaExptime, lastAccess, freq uint16, memId, offset uint32, entry *Entry) { + d1 := uint64(length&LENGTH_MASK) << LENGTH_SHIFT + d1 |= 
uint64(deltaExptime&DELTA_EXPTIME_MASK) << DELTA_EXPTIME_SHIFT + d1 |= uint64(lastAccess&LAST_ACCESS_MASK) << LAST_ACCESS_SHIFT + d1 |= uint64(freq&FREQ_MASK) << FREQ_SHIFT + + ByteOrder.PutUint64(entry[:8], d1) + + d2 := uint64(memId&MEM_ID_MASK) << MEM_ID_SHIFT + d2 |= uint64(offset&OFFSET_MASK) << OFFSET_SHIFT + + ByteOrder.PutUint64(entry[8:16], d2) +} + +func encodeHashNextPrev(hhi, hlo uint64, prev, next int32, entry *HashNextPrev) { + entry[0] = hhi + entry[1] = hlo + entry[2] = uint64(uint32(prev))<<32 | uint64(uint32(next)) +} + +func encodeUpdatePrev(prev int32, entry *HashNextPrev) { + next := entry[2] & NEXT_MASK + entry[2] = uint64(uint32(prev))<<32 | next +} + +func encodeUpdateNext(next int32, entry *HashNextPrev) { + prev := (entry[2] >> 32) & PREV_MASK + entry[2] = uint64(uint32(prev))<<32 | uint64(uint32(next)) +} + +func decodeNext(entry *HashNextPrev) int32 { + return int32(uint32(entry[2] & NEXT_MASK)) +} + +func decodePrev(entry *HashNextPrev) int32 { + return int32(uint32(entry[2]>>32) & PREV_MASK) +} + +func decodeHashLo(entry *HashNextPrev) uint64 { + return entry[1] +} + +func decode(entry *Entry) (length, deltaExptime, lastAccess, freq uint16, memId, offset uint32) { + d1 := ByteOrder.Uint64(entry[:8]) + d2 := ByteOrder.Uint64(entry[8:16]) + + length = uint16(d1>>LENGTH_SHIFT) & LENGTH_MASK + deltaExptime = uint16(d1>>DELTA_EXPTIME_SHIFT) & DELTA_EXPTIME_MASK + lastAccess = uint16(d1>>LAST_ACCESS_SHIFT) & LAST_ACCESS_MASK + freq = uint16(d1>>FREQ_SHIFT) & FREQ_MASK + + memId = uint32(d2>>MEM_ID_SHIFT) & MEM_ID_MASK + offset = uint32(d2>>OFFSET_SHIFT) & OFFSET_MASK + + return +} + +func encodeLastAccessNFreq(lastAccess, freq uint16, entry *Entry) { + d1 := ByteOrder.Uint64(entry[:8]) + // Clear lastAccess (bits 31:16) and freq (bits 15:0) before writing new values. 
+ d1 &^= uint64(LAST_ACCESS_MASK)<>MEM_ID_SHIFT) & MEM_ID_MASK + offset = uint32(d2>>OFFSET_SHIFT) & OFFSET_MASK + return +} diff --git a/flashring/internal/index/index.go b/flashring/internal/index/index.go new file mode 100644 index 00000000..26ca1d4b --- /dev/null +++ b/flashring/internal/index/index.go @@ -0,0 +1,178 @@ +package index + +import ( + "errors" + "sync" + "time" + + "github.com/Meesho/BharatMLStack/flashring/internal/maths" + "github.com/cespare/xxhash/v2" + "github.com/zeebo/xxh3" +) + +var ErrGettingHeadEntry = errors.New("getting head entry failed") + +type Status int + +const ( + StatusOK Status = iota + StatusNotFound + StatusExpired +) + +type Index struct { + mu *sync.RWMutex + rm map[uint64]int + rb *RingBuffer + mc *maths.MorrisLogCounter + startAt int64 + hashBits int +} + +func NewIndex(hashBits int, rbInitial, rbMax, deleteAmortizedStep int, mu *sync.RWMutex) *Index { + return &Index{ + mu: mu, + rm: make(map[uint64]int), + rb: NewRingBuffer(rbInitial, rbMax), + mc: maths.New(), + startAt: time.Now().Unix(), + hashBits: hashBits, + } +} + +func (i *Index) Put(key string, length, ttlInMinutes uint16, memId, offset uint32) { + hhi, hlo := hash128(key) + entry, hashNextPrev, idx, _ := i.rb.GetNextFreeSlot() + lastAccess := i.generateLastAccess() + freq := uint16(1) + expiryAt := (time.Now().Unix() / 60) + int64(ttlInMinutes) + delta := uint16(expiryAt - (i.startAt / 60)) + encode(key, length, delta, lastAccess, freq, memId, offset, entry) + + if headIdx, ok := i.rm[hlo]; !ok { + encodeHashNextPrev(hhi, hlo, -1, -1, hashNextPrev) + i.rm[hlo] = idx + } else { + _, headHashNextPrev, _ := i.rb.Get(int(headIdx)) + encodeUpdatePrev(int32(idx), headHashNextPrev) + encodeHashNextPrev(hhi, hlo, -1, int32(headIdx), hashNextPrev) + i.rm[hlo] = idx + } +} + +func (i *Index) Get(key string) (length, lastAccess, remainingTTL uint16, freq uint64, memId, offset uint32, status Status) { + hhi, hlo := hash128(key) + + i.mu.RLock() + idx, ok := i.rm[hlo] + 
// Delete removes up to count entries from the head of the ring buffer and
// keeps the hash map and chain pointers consistent with each removal.
//
// Returns (memId, n):
//   - (0, 0) when count is 0;
//   - (nextMemId, n) when, after n removals, the new head belongs to the
//     memtable id immediately following the removed entry's id;
//   - (delMemId, count) when all count removals stayed inside one memtable;
//   - (0, -1) when the new head's memtable id is neither the same nor the
//     next one, or when the loop body never runs (negative count).
func (ix *Index) Delete(count int) (uint32, int) {
	if count == 0 {
		return 0, 0
	}
	for i := 0; i < count; i++ {
		// rb.Delete pops the head and also returns the entry that is the
		// head afterwards (next).
		deleted, deletedHashNextPrev, deletedIdx, next := ix.rb.Delete()
		if deleted == nil {
			return 0, -1
		}
		delMemId, _ := DecodeMemIdOffset(deleted)
		deletedHlo := decodeHashLo(deletedHashNextPrev)
		mapIdx, ok := ix.rm[deletedHlo]
		if ok && mapIdx == deletedIdx {
			// The removed slot was the chain head for its hash: drop the
			// map entry.
			delete(ix.rm, deletedHlo)
		} else if ok && hasPrev(deletedHashNextPrev) {
			// The removed slot sat further down a chain: detach it by
			// clearing its predecessor's next pointer.
			prevIdx := decodePrev(deletedHashNextPrev)
			_, hashNextPrev, _ := ix.rb.Get(int(prevIdx))
			encodeUpdateNext(-1, hashNextPrev)
		}

		// Compare the new head's memtable id with the removed entry's id to
		// decide whether the memtable boundary has been crossed.
		nextMemId, _ := DecodeMemIdOffset(next)
		if nextMemId == delMemId+1 {
			return nextMemId, i + 1
		} else if nextMemId == delMemId && i == count-1 {
			return delMemId, i + 1
		} else if nextMemId == delMemId {
			continue
		} else {
			return 0, -1
		}
	}
	return 0, -1
}
+func (ix *Index) DeleteKey(key string) bool { + _, hlo := hash128(key) + if _, ok := ix.rm[hlo]; !ok { + return false + } + delete(ix.rm, hlo) + return true +} + +func (ki *Index) GetRB() *RingBuffer { + return ki.rb +} + +func (ki *Index) PeekMemIdAtHead() (uint32, error) { + entry, _, ok := ki.rb.Get(ki.rb.head) + if !ok { + return 0, ErrGettingHeadEntry + } + memId, _ := DecodeMemIdOffset(entry) + return memId, nil +} + +func (i *Index) generateLastAccess() uint16 { + return uint16((time.Now().Unix() - i.startAt) / 60) +} + +func (i *Index) incrFreq(freq uint16, hlo uint64) uint16 { + newFreq, _ := i.mc.Inc(uint16(freq), hlo) + return uint16(newFreq) +} + +func hash128(key string) (uint64, uint64) { + return xxhash.Sum64String(key), xxh3.HashString(key) +} + +func isHashMatch(hhi, hlo uint64, entry *HashNextPrev) bool { + return entry[0] == hhi && entry[1] == hlo +} + +func hasNext(entry *HashNextPrev) bool { + return int32(entry[2]&NEXT_MASK) != -1 +} + +func hasPrev(entry *HashNextPrev) bool { + return int32((entry[2]>>32)&PREV_MASK) != -1 +} diff --git a/flashring/internal/index/ringbuffer.go b/flashring/internal/index/ringbuffer.go new file mode 100644 index 00000000..02214403 --- /dev/null +++ b/flashring/internal/index/ringbuffer.go @@ -0,0 +1,71 @@ +package index + +// Entry represents a 16-byte index entry. +type Entry [16]byte + +// HashNextPrev stores the dual hash (for collision detection) and linked-list +// pointers for chaining entries that share the same hash-lo bucket. +type HashNextPrev [3]uint64 + +// RingBuffer is a fixed-size circular queue. It maintains a sliding window +// of the most recent entries and wraps around when full, overwriting the oldest. 
+type RingBuffer struct { + buf []Entry + hashTable []HashNextPrev + head int + tail int + size int + nextIndex int + capacity int + wrapped bool +} + +func NewRingBuffer(initial, max int) *RingBuffer { + if initial <= 0 || initial > max { + panic("invalid capacity") + } + capacity := max + return &RingBuffer{ + buf: make([]Entry, capacity), + hashTable: make([]HashNextPrev, capacity), + capacity: capacity, + } +} + +func (rb *RingBuffer) NextAddNeedsDelete() bool { + return rb.nextIndex == rb.head && rb.wrapped +} + +func (rb *RingBuffer) GetNextFreeSlot() (*Entry, *HashNextPrev, int, bool) { + idx := rb.nextIndex + rb.nextIndex = (rb.nextIndex + 1) % rb.capacity + shouldDelete := false + if rb.nextIndex == rb.head { + rb.wrapped = true + shouldDelete = true + } + return &rb.buf[idx], &rb.hashTable[idx], idx, shouldDelete +} + +func (rb *RingBuffer) Get(index int) (*Entry, *HashNextPrev, bool) { + if index > rb.capacity { + return nil, nil, false + } + return &rb.buf[index], &rb.hashTable[index], true +} + +func (rb *RingBuffer) Delete() (*Entry, *HashNextPrev, int, *Entry) { + deletedIdx := rb.head + deleted := rb.buf[rb.head] + deletedHashNextPrev := rb.hashTable[rb.head] + rb.head = (rb.head + 1) % rb.capacity + return &deleted, &deletedHashNextPrev, deletedIdx, &rb.buf[rb.head] +} + +func (rb *RingBuffer) TailIndex() int { + return rb.nextIndex +} + +func (rb *RingBuffer) ActiveEntries() int { + return (rb.nextIndex - rb.head + rb.capacity) % rb.capacity +} diff --git a/flashring/internal/index/system.go b/flashring/internal/index/system.go new file mode 100644 index 00000000..0a3868b7 --- /dev/null +++ b/flashring/internal/index/system.go @@ -0,0 +1,54 @@ +package index + +import ( + "encoding/binary" + "unsafe" +) + +var ByteOrder *CustomByteOrder + +func init() { + loadByteOrder() +} + +type CustomByteOrder struct { + binary.ByteOrder +} + +func loadByteOrder() { + buf := [2]byte{} + *(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD) + + switch buf { + 
// PutInt64 writes v into b as its uint64 bit pattern, using the embedded
// byte order.
func (c *CustomByteOrder) PutInt64(b []byte, v int64) {
	c.PutUint64(b, uint64(v))
}

// Int64 reads a uint64 from b and reinterprets it as int64.
func (c *CustomByteOrder) Int64(b []byte) int64 {
	return int64(c.Uint64(b))
}

// PutInt32 writes v into b as its uint32 bit pattern.
func (c *CustomByteOrder) PutInt32(b []byte, v int32) {
	c.PutUint32(b, uint32(v))
}

// Int32 reads a uint32 from b and reinterprets it as int32.
func (c *CustomByteOrder) Int32(b []byte) int32 {
	return int32(c.Uint32(b))
}

// PutUint32 forwards to the embedded byte order's PutUint32.
func (c *CustomByteOrder) PutUint32(b []byte, v uint32) {
	c.ByteOrder.PutUint32(b, v)
}

// Uint32 forwards to the embedded byte order's Uint32.
func (c *CustomByteOrder) Uint32(b []byte) uint32 {
	return c.ByteOrder.Uint32(b)
}
+// +// submitLoop: reqCh → collect batch → prep SQEs → io_uring_enter → loop +// completeLoop: waitCqe → dispatch result to caller → loop +type BatchIoUringWriter struct { + ring *IoUring + reqCh chan *batchWriteRequest + maxBatch int + closeCh chan struct{} + wg sync.WaitGroup + + inflight []atomic.Pointer[batchWriteRequest] + freeSlots chan uint32 + pending atomic.Int32 +} + +// NewBatchIoUringWriter creates a decoupled batch writer with its own io_uring +// ring and starts the submit + completion goroutines. +func NewBatchIoUringWriter(cfg BatchIoUringConfig) (*BatchIoUringWriter, error) { + if cfg.RingDepth == 0 { + cfg.RingDepth = 256 + } + ringDepth := int(cfg.RingDepth) + + maxInflight := cfg.MaxInflight + if maxInflight <= 0 || maxInflight > ringDepth { + maxInflight = ringDepth + } + if cfg.MaxBatch <= 0 || cfg.MaxBatch > maxInflight { + cfg.MaxBatch = maxInflight + } + if cfg.QueueSize == 0 { + cfg.QueueSize = 1024 + } + + ring, err := NewIoUring(cfg.RingDepth, 0) + if err != nil { + return nil, fmt.Errorf("batch io_uring writer init: %w", err) + } + + freeSlots := make(chan uint32, maxInflight) + for i := 0; i < maxInflight; i++ { + freeSlots <- uint32(i) + } + + b := &BatchIoUringWriter{ + ring: ring, + reqCh: make(chan *batchWriteRequest, cfg.QueueSize), + maxBatch: cfg.MaxBatch, + closeCh: make(chan struct{}), + inflight: make([]atomic.Pointer[batchWriteRequest], ringDepth), + freeSlots: freeSlots, + } + b.wg.Add(2) + go b.submitLoop() + go b.completeLoop() + return b, nil +} + +// MaxBatchSize returns the ring depth, which is the maximum number of SQEs +// that can be in-flight at once. +func (b *BatchIoUringWriter) MaxBatchSize() int { + return int(b.ring.sqEntries) +} + +// SubmitWriteBatch submits N pwrite operations and waits for all completions. +// Thread-safe. Unlike the old IoUring.SubmitWriteBatch, the ring mutex is NOT +// held during CQE drain — other batches can be submitted concurrently. 
+func (b *BatchIoUringWriter) SubmitWriteBatch(fd int, bufs [][]byte, offsets []uint64) ([]int, error) { + n := len(bufs) + if n == 0 { + return nil, nil + } + + startTime := time.Now() + + // Submit all write requests into the channel. The submitLoop will + // collect them into batches and prep SQEs. + doneChans := make([]chan WriteResult, n) + for i := 0; i < n; i++ { + req := &batchWriteRequest{ + fd: fd, + buf: bufs[i], + offset: offsets[i], + done: make(chan WriteResult, 1), + } + doneChans[i] = req.done + b.reqCh <- req + } + + // Collect all completions. + results := make([]int, n) + for i := 0; i < n; i++ { + res := <-doneChans[i] + if res.Err != nil { + return results, res.Err + } + results[i] = res.N + metrics.Timing(metrics.KEY_PWRITE_LATENCY, time.Since(startTime), []string{}) + } + + return results, nil +} + +// Close shuts down both goroutines and releases the io_uring ring. +func (b *BatchIoUringWriter) Close() { + close(b.closeCh) + + b.ring.mu.Lock() + sqe := b.ring.getSqe() + if sqe != nil { + sqe.Opcode = iouringOpNop + sqe.UserData = sentinelUserData + b.ring.submit(0) + } + b.ring.mu.Unlock() + + b.wg.Wait() + b.ring.Close() +} + +// submitLoop collects write requests and submits them as io_uring SQEs. +// Mutex held only during SQE prep + io_uring_enter. +func (b *BatchIoUringWriter) submitLoop() { + defer b.wg.Done() + + batch := make([]*batchWriteRequest, 0, b.maxBatch) + slots := make([]uint32, 0, b.maxBatch) + + for { + select { + case req := <-b.reqCh: + batch = append(batch, req) + case <-b.closeCh: + return + } + + // Non-blocking drain. 
+ for len(batch) < b.maxBatch { + select { + case req := <-b.reqCh: + batch = append(batch, req) + default: + goto submit + } + } + + submit: + for i, req := range batch { + select { + case slot := <-b.freeSlots: + slots = append(slots, slot) + b.inflight[slot].Store(req) + case <-b.closeCh: + for j := i; j < len(batch); j++ { + batch[j].done <- WriteResult{Err: fmt.Errorf("io_uring writer: shutting down")} + } + return + } + } + + b.ring.mu.Lock() + + prepared := 0 + for i, slot := range slots { + sqe := b.ring.getSqe() + if sqe == nil { + for j := i; j < len(slots); j++ { + req := b.inflight[slots[j]].Swap(nil) + b.freeSlots <- slots[j] + if req != nil { + req.done <- WriteResult{ + Err: fmt.Errorf("io_uring writer: SQ full, batch=%d depth=%d", len(batch), b.ring.sqEntries), + } + } + } + break + } + prepWrite(sqe, batch[i].fd, batch[i].buf, batch[i].offset) + sqe.UserData = uint64(slot) + prepared++ + } + + if prepared > 0 { + b.pending.Add(int32(prepared)) + _, err := b.ring.submit(0) + if err != nil { + b.pending.Add(-int32(prepared)) + for i := 0; i < prepared; i++ { + req := b.inflight[slots[i]].Swap(nil) + b.freeSlots <- slots[i] + if req != nil { + req.done <- WriteResult{Err: fmt.Errorf("io_uring_enter: %w", err)} + } + } + } + } + + b.ring.mu.Unlock() + + batch = batch[:0] + slots = slots[:0] + } +} + +// completeLoop drains CQEs and dispatches results to callers. 
func (b *BatchIoUringWriter) completeLoop() {
	defer b.wg.Done()

	for {
		cqe, err := b.ring.waitCqe()
		if err != nil {
			// A failed wait ends the loop only once the writer is closing
			// and no submitted request is still outstanding; otherwise the
			// wait is simply retried.
			select {
			case <-b.closeCh:
				if b.pending.Load() <= 0 {
					return
				}
			default:
			}
			continue
		}

		// Read the fields before seenCqe releases the entry -- presumably
		// the kernel may reuse the slot afterwards (TODO confirm).
		userData := cqe.UserData
		res := cqe.Res
		b.ring.seenCqe()

		// The sentinel is the NOP queued by Close to wake this loop. Keep
		// draining if writes are still pending.
		if userData == sentinelUserData {
			if b.pending.Load() <= 0 {
				return
			}
			continue
		}

		// For real writes UserData carries the inflight slot index that
		// submitLoop stored in the SQE.
		slot := uint32(userData)
		b.pending.Add(-1)

		// Claim the request and hand the slot back before notifying the
		// caller.
		req := b.inflight[slot].Swap(nil)
		b.freeSlots <- slot

		// No request recorded for this slot: nothing to notify.
		if req == nil {
			continue
		}

		// A negative result is a negated errno; otherwise it is reported
		// to the caller as the write's byte count.
		if res < 0 {
			req.done <- WriteResult{
				Err: fmt.Errorf("io_uring pwrite errno %d (%s), fd=%d off=%d len=%d",
					-res, syscall.Errno(-res), req.fd, req.offset, len(req.buf)),
			}
		} else {
			req.done <- WriteResult{N: int(res)}
		}
	}
}
// ioUringCqe is the 16-byte completion queue entry.
type ioUringCqe struct {
	// UserData identifies the completed request. The batch writer stores the
	// inflight slot index (or the shutdown sentinel) in the SQE's UserData
	// and reads it back here.
	UserData uint64
	// Res is the operation result: a negated errno when negative, otherwise
	// the value the batch writer reports as the write's byte count.
	Res int32
	// Flags is not read by the code visible in this section.
	Flags uint32
}
+type ioUringParams struct { + SqEntries uint32 + CqEntries uint32 + Flags uint32 + SqThreadCPU uint32 + SqThreadIdle uint32 + Features uint32 + WqFd uint32 + Resv [3]uint32 + SqOff ioUringSqringOffsets + CqOff ioUringCqringOffsets +} + +type ioUringSqringOffsets struct { + Head uint32 + Tail uint32 + RingMask uint32 + RingEntries uint32 + Flags uint32 + Dropped uint32 + Array uint32 + Resv1 uint32 + Resv2 uint64 +} + +type ioUringCqringOffsets struct { + Head uint32 + Tail uint32 + RingMask uint32 + RingEntries uint32 + Overflow uint32 + Cqes uint32 + Flags uint32 + Resv1 uint32 + Resv2 uint64 +} + +// ----------------------------------------------------------------------- +// IoUring is the main ring handle +// ----------------------------------------------------------------------- + +// IoUring wraps a single io_uring instance with SQ/CQ ring mappings. +type IoUring struct { + fd int + + // SQ ring mapped memory + sqRingPtr []byte + sqMask uint32 + sqEntries uint32 + sqHead *uint32 // kernel-updated + sqTail *uint32 // user-updated + sqFlags *uint32 // kernel-updated (NEED_WAKEUP etc.) + sqArray unsafe.Pointer + sqeTail uint32 // local tracking of next SQE slot + sqeHead uint32 // local tracking of submitted SQEs + sqesMmap []byte + sqesBase unsafe.Pointer // base pointer to SQE array + sqRingSz int + cqRingSz int + sqesSz int + singleMmap bool + + // CQ ring mapped memory + cqRingPtr []byte + cqMask uint32 + cqEntries uint32 + cqHead *uint32 // user-updated + cqTail *uint32 // kernel-updated + cqesBase unsafe.Pointer + + // Setup flags + flags uint32 + + // Mutex for concurrent SQE submission from multiple goroutines + mu sync.Mutex + + // Diagnostic counter -- limits debug output to first N failures + debugCount int +} + +// NewIoUring creates a new io_uring instance with the given queue depth. +// flags can be 0 for normal mode. 
+func NewIoUring(entries uint32, flags uint32) (*IoUring, error) { + var params ioUringParams + params.Flags = flags + if flags&iouringSetupSQPoll != 0 { + params.SqThreadIdle = 2000 // kernel poll thread sleeps after 2s idle + } + + fd, _, errno := syscall.Syscall(sysIOUringSetup, uintptr(entries), uintptr(unsafe.Pointer(¶ms)), 0) + if errno != 0 { + return nil, fmt.Errorf("io_uring_setup failed: %w", errno) + } + + ring := &IoUring{ + fd: int(fd), + flags: params.Flags, + } + + if err := ring.mapRings(¶ms); err != nil { + syscall.Close(ring.fd) + return nil, err + } + + return ring, nil +} + +func (r *IoUring) mapRings(p *ioUringParams) error { + sqOff := &p.SqOff + cqOff := &p.CqOff + + // Calculate SQ ring size + r.sqRingSz = int(sqOff.Array + p.SqEntries*4) // Array + entries*sizeof(uint32) + + // Calculate CQ ring size + r.cqRingSz = int(cqOff.Cqes + p.CqEntries*uint32(unsafe.Sizeof(ioUringCqe{}))) + + // Check if kernel supports single mmap for both rings + r.singleMmap = (p.Features & 1) != 0 // IORING_FEAT_SINGLE_MMAP = 1 + if r.singleMmap { + if r.cqRingSz > r.sqRingSz { + r.sqRingSz = r.cqRingSz + } + } + + // Map SQ ring + var err error + r.sqRingPtr, err = unix.Mmap(r.fd, iouringOffSQRing, r.sqRingSz, + unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_POPULATE) + if err != nil { + return fmt.Errorf("mmap SQ ring: %w", err) + } + + // Map CQ ring (same or separate mapping) + if r.singleMmap { + r.cqRingPtr = r.sqRingPtr + } else { + r.cqRingPtr, err = unix.Mmap(r.fd, iouringOffCQRing, r.cqRingSz, + unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_POPULATE) + if err != nil { + unix.Munmap(r.sqRingPtr) + return fmt.Errorf("mmap CQ ring: %w", err) + } + } + + // Map SQE array + r.sqesSz = int(p.SqEntries) * int(unsafe.Sizeof(ioUringSqe{})) + r.sqesMmap, err = unix.Mmap(r.fd, iouringOffSQEs, r.sqesSz, + unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_POPULATE) + if err != nil { + unix.Munmap(r.sqRingPtr) + if !r.singleMmap { + 
unix.Munmap(r.cqRingPtr) + } + return fmt.Errorf("mmap SQEs: %w", err) + } + r.sqesBase = unsafe.Pointer(&r.sqesMmap[0]) + + // Set up SQ ring pointers + sqBase := unsafe.Pointer(&r.sqRingPtr[0]) + r.sqHead = (*uint32)(unsafe.Add(sqBase, sqOff.Head)) + r.sqTail = (*uint32)(unsafe.Add(sqBase, sqOff.Tail)) + r.sqFlags = (*uint32)(unsafe.Add(sqBase, sqOff.Flags)) + r.sqMask = *(*uint32)(unsafe.Add(sqBase, sqOff.RingMask)) + r.sqEntries = *(*uint32)(unsafe.Add(sqBase, sqOff.RingEntries)) + r.sqArray = unsafe.Add(sqBase, sqOff.Array) + + // Set up CQ ring pointers + cqBase := unsafe.Pointer(&r.cqRingPtr[0]) + r.cqHead = (*uint32)(unsafe.Add(cqBase, cqOff.Head)) + r.cqTail = (*uint32)(unsafe.Add(cqBase, cqOff.Tail)) + r.cqMask = *(*uint32)(unsafe.Add(cqBase, cqOff.RingMask)) + r.cqEntries = *(*uint32)(unsafe.Add(cqBase, cqOff.RingEntries)) + r.cqesBase = unsafe.Add(cqBase, cqOff.Cqes) + + return nil +} + +// Close releases all resources associated with the ring. +func (r *IoUring) Close() { + unix.Munmap(r.sqesMmap) + unix.Munmap(r.sqRingPtr) + if !r.singleMmap { + unix.Munmap(r.cqRingPtr) + } + syscall.Close(r.fd) +} + +// ----------------------------------------------------------------------- +// SQE helpers +// ----------------------------------------------------------------------- + +func (r *IoUring) getSqeAt(idx uint32) *ioUringSqe { + return (*ioUringSqe)(unsafe.Add(r.sqesBase, uintptr(idx)*unsafe.Sizeof(ioUringSqe{}))) +} + +func (r *IoUring) getCqeAt(idx uint32) *ioUringCqe { + return (*ioUringCqe)(unsafe.Add(r.cqesBase, uintptr(idx)*unsafe.Sizeof(ioUringCqe{}))) +} + +func (r *IoUring) sqArrayAt(idx uint32) *uint32 { + return (*uint32)(unsafe.Add(r.sqArray, uintptr(idx)*4)) +} + +// getSqe returns the next available SQE, or nil if the SQ is full. 
+func (r *IoUring) getSqe() *ioUringSqe { + head := atomic.LoadUint32(r.sqHead) + next := r.sqeTail + 1 + if next-head > r.sqEntries { + return nil // SQ full + } + sqe := r.getSqeAt(r.sqeTail & r.sqMask) + r.sqeTail++ + // Zero out the SQE + *sqe = ioUringSqe{} + return sqe +} + +// flushSq flushes locally queued SQEs into the kernel-visible SQ ring. +func (r *IoUring) flushSq() uint32 { + tail := *r.sqTail + toSubmit := r.sqeTail - r.sqeHead + if toSubmit == 0 { + return tail - atomic.LoadUint32(r.sqHead) + } + for ; toSubmit > 0; toSubmit-- { + *r.sqArrayAt(tail & r.sqMask) = r.sqeHead & r.sqMask + tail++ + r.sqeHead++ + } + atomic.StoreUint32(r.sqTail, tail) + return tail - atomic.LoadUint32(r.sqHead) +} + +// ----------------------------------------------------------------------- +// Submission and completion +// ----------------------------------------------------------------------- + +func ioUringEnter(fd int, toSubmit, minComplete, flags uint32) (int, error) { + ret, _, errno := syscall.Syscall6(sysIOUringEnter, + uintptr(fd), uintptr(toSubmit), uintptr(minComplete), uintptr(flags), 0, 0) + if errno != 0 { + return int(ret), errno + } + return int(ret), nil +} + +// submit flushes SQEs and calls io_uring_enter if needed. +// Retries automatically on EINTR (signal interruption). 
+func (r *IoUring) submit(waitNr uint32) (int, error) { + submitted := r.flushSq() + var flags uint32 = 0 + + // If not using SQPOLL, we always need to enter + if r.flags&iouringSetupSQPoll == 0 { + if waitNr > 0 { + flags |= iouringEnterGetEvents + } + for { + ret, err := ioUringEnter(r.fd, submitted, waitNr, flags) + if err == syscall.EINTR { + continue + } + return ret, err + } + } + + // SQPOLL: only enter if kernel thread needs wakeup + if atomic.LoadUint32(r.sqFlags)&iouringSQNeedWakeup != 0 { + flags |= iouringEnterSQWakeup + } + if waitNr > 0 { + flags |= iouringEnterGetEvents + } + if flags != 0 { + for { + ret, err := ioUringEnter(r.fd, submitted, waitNr, flags) + if err == syscall.EINTR { + continue + } + return ret, err + } + } + return int(submitted), nil +} + +// waitCqe waits for at least one CQE to be available and returns it. +// The caller MUST call SeenCqe after processing. +func (r *IoUring) waitCqe() (*ioUringCqe, error) { + for { + head := atomic.LoadUint32(r.cqHead) + tail := atomic.LoadUint32(r.cqTail) + if head != tail { + cqe := r.getCqeAt(head & r.cqMask) + return cqe, nil + } + // No CQE available, ask the kernel + _, err := ioUringEnter(r.fd, 0, 1, iouringEnterGetEvents) + if err != nil { + if err == syscall.EINTR { + continue // signal interrupted the syscall; retry + } + return nil, err + } + } +} + +// seenCqe advances the CQ head by 1, releasing the CQE slot. 
+func (r *IoUring) seenCqe() { + atomic.StoreUint32(r.cqHead, atomic.LoadUint32(r.cqHead)+1) +} + +// ----------------------------------------------------------------------- +// PrepRead / PrepWrite helpers +// ----------------------------------------------------------------------- + +func prepRead(sqe *ioUringSqe, fd int, buf []byte, offset uint64) { + if len(buf) == 0 { + sqe.Opcode = iouringOpNop + return + } + sqe.Opcode = iouringOpRead + sqe.Fd = int32(fd) + sqe.Addr = uint64(uintptr(unsafe.Pointer(&buf[0]))) + sqe.Len = uint32(len(buf)) + sqe.Off = offset +} + +func prepWrite(sqe *ioUringSqe, fd int, buf []byte, offset uint64) { + if len(buf) == 0 { + sqe.Opcode = iouringOpNop + return + } + sqe.Opcode = iouringOpWrite + sqe.Fd = int32(fd) + sqe.Addr = uint64(uintptr(unsafe.Pointer(&buf[0]))) + sqe.Len = uint32(len(buf)) + sqe.Off = offset +} + +// ----------------------------------------------------------------------- +// High-level thread-safe API +// ----------------------------------------------------------------------- + +// SubmitRead submits a pread and waits for completion. Thread-safe. +// Returns bytes read or an error. 
+func (r *IoUring) SubmitRead(fd int, buf []byte, offset uint64) (int, error) { + if len(buf) == 0 { + return 0, nil + } + + r.mu.Lock() + + sqe := r.getSqe() + if sqe == nil { + r.mu.Unlock() + return 0, fmt.Errorf("io_uring: SQ full, no SQE available") + } + prepRead(sqe, fd, buf, offset) + // Tag the SQE so we can verify the CQE belongs to this request + sqe.UserData = offset + + submitted, err := r.submit(1) + if err != nil { + r.mu.Unlock() + return 0, fmt.Errorf("io_uring_enter failed: %w", err) + } + + cqe, err := r.waitCqe() + if err != nil { + r.mu.Unlock() + return 0, fmt.Errorf("io_uring wait cqe: %w", err) + } + + res := cqe.Res + userData := cqe.UserData + cqeFlags := cqe.Flags + r.seenCqe() + r.mu.Unlock() + + if res < 0 { + return 0, fmt.Errorf("io_uring pread errno %d (%s), fd=%d off=%d len=%d submitted=%d ud=%d", + -res, syscall.Errno(-res), fd, offset, len(buf), submitted, userData) + } + + // Diagnostic: if io_uring returned 0 (EOF) or short read, compare with syscall.Pread + if r.debugCount < 20 && int(res) != len(buf) { + r.debugCount++ + pn, perr := syscall.Pread(fd, buf, int64(offset)) + // Also stat the fd to check file size + var stat syscall.Stat_t + fstatErr := syscall.Fstat(fd, &stat) + var fsize int64 + if fstatErr == nil { + fsize = stat.Size + } + fmt.Printf("[io_uring diag] fd=%d off=%d len=%d uring_res=%d uring_ud=%d uring_flags=%d "+ + "submitted=%d pread_n=%d pread_err=%v filesize=%d fstat_err=%v sqeHead=%d sqeTail=%d\n", + fd, offset, len(buf), res, userData, cqeFlags, + submitted, pn, perr, fsize, fstatErr, r.sqeHead, r.sqeTail) + } + + return int(res), nil +} + +// SubmitWriteBatch submits N pwrite operations in a single io_uring_enter call +// and waits for all completions. Thread-safe. +// Returns per-chunk bytes written. On error, partial results may be returned. 
+func (r *IoUring) SubmitWriteBatch(fd int, bufs [][]byte, offsets []uint64) ([]int, error) { + n := len(bufs) + if n == 0 { + return nil, nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Prepare all SQEs + for i := 0; i < n; i++ { + sqe := r.getSqe() + if sqe == nil { + return nil, fmt.Errorf("io_uring: SQ full, need %d slots but ring has %d", n, r.sqEntries) + } + prepWrite(sqe, fd, bufs[i], offsets[i]) + sqe.UserData = uint64(i) + } + + // Submit all at once; kernel waits for all completions + _, err := r.submit(uint32(n)) + if err != nil { + return nil, fmt.Errorf("io_uring_enter: %w", err) + } + + var startTime = time.Now() + + // Drain all CQEs (order may differ from submission) + results := make([]int, n) + for i := 0; i < n; i++ { + cqe, err := r.waitCqe() + if err != nil { + return results, fmt.Errorf("io_uring waitCqe: %w", err) + } + idx := int(cqe.UserData) + res := cqe.Res + r.seenCqe() + + if res < 0 { + return results, fmt.Errorf("io_uring pwrite errno %d (%s), fd=%d off=%d len=%d", + -res, syscall.Errno(-res), fd, offsets[idx], len(bufs[idx])) + } + if idx >= 0 && idx < n { + results[idx] = int(res) + } + + metrics.Timing(metrics.KEY_PWRITE_LATENCY, time.Since(startTime), []string{}) + } + + return results, nil +} + +// SubmitWrite submits a pwrite and waits for completion. Thread-safe. +// Returns bytes written or an error. 
+func (r *IoUring) SubmitWrite(fd int, buf []byte, offset uint64) (int, error) { + if len(buf) == 0 { + return 0, nil + } + + r.mu.Lock() + + sqe := r.getSqe() + if sqe == nil { + r.mu.Unlock() + return 0, fmt.Errorf("io_uring: SQ full, no SQE available") + } + prepWrite(sqe, fd, buf, offset) + + _, err := r.submit(1) + if err != nil { + r.mu.Unlock() + return 0, fmt.Errorf("io_uring_enter failed: %w", err) + } + + cqe, err := r.waitCqe() + if err != nil { + r.mu.Unlock() + return 0, fmt.Errorf("io_uring wait cqe: %w", err) + } + + res := cqe.Res + r.seenCqe() + r.mu.Unlock() + + if res < 0 { + return 0, fmt.Errorf("io_uring pwrite failed: errno %d (%s)", -res, syscall.Errno(-res)) + } + return int(res), nil +} diff --git a/flashring/internal/iouring/iouring_reader.go b/flashring/internal/iouring/iouring_reader.go new file mode 100644 index 00000000..06594cde --- /dev/null +++ b/flashring/internal/iouring/iouring_reader.go @@ -0,0 +1,384 @@ +//go:build linux +// +build linux + +package iouring + +import ( + "fmt" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/Meesho/BharatMLStack/flashring/pkg/metrics" +) + +// ReadResult holds the outcome of a single io_uring pread. +type ReadResult struct { + N int + Err error +} + +// batchReadRequest is a pread submitted to the batch reader. +type batchReadRequest struct { + fd int + buf []byte + offset uint64 + done chan ReadResult +} + +var batchReqPool = sync.Pool{ + New: func() interface{} { + return &batchReadRequest{ + done: make(chan ReadResult, 1), + } + }, +} + +// sentinelUserData is stored in the NOP SQE submitted during Close to unblock +// the completion goroutine. +const sentinelUserData = ^uint64(0) + +// BatchIoUringReader collects pread requests and submits them as io_uring +// batches. 
Submission and completion run in separate goroutines so the submit +// path is never blocked by CQE draining: +// +// submitLoop: reqCh → collect batch → prep SQEs → io_uring_enter → loop +// completeLoop: waitCqe → dispatch result to caller → loop +// +// This eliminates the head-of-batch queueing delay where new requests had to +// wait for the entire previous batch's CQE drain before being submitted. +type BatchIoUringReader struct { + ring *IoUring + reqCh chan *batchReadRequest + maxBatch int + closeCh chan struct{} + wg sync.WaitGroup + + // In-flight tracking: each SQE gets a slot index as its UserData. + // The submit goroutine stores the request; the completion goroutine + // reads it back when the CQE arrives. + inflight []atomic.Pointer[batchReadRequest] + freeSlots chan uint32 // pool of available slot indices + pending atomic.Int32 // number of SQEs currently in-flight +} + +// BatchIoUringConfig configures the batch reader. +type BatchIoUringConfig struct { + RingDepth uint32 // io_uring SQ/CQ size (default 256) + MaxBatch int // max requests per batch (auto-capped to MaxInflight) + MaxInflight int // max SQEs in-flight across all batches; controls NVMe queue depth (default RingDepth) + Window time.Duration // unused in decoupled mode; kept for API compatibility + QueueSize int // channel buffer size (default 1024) + SQPoll bool // use IORING_SETUP_SQPOLL; kernel polls SQ, eliminating submit syscalls under load +} + +// NewBatchIoUringReader creates a batch reader with its own io_uring ring +// and starts the submit + completion goroutines. 
+func NewBatchIoUringReader(cfg BatchIoUringConfig) (*BatchIoUringReader, error) { + if cfg.RingDepth == 0 { + cfg.RingDepth = 256 + } + ringDepth := int(cfg.RingDepth) + + maxInflight := cfg.MaxInflight + if maxInflight <= 0 || maxInflight > ringDepth { + maxInflight = ringDepth + } + if cfg.MaxBatch <= 0 || cfg.MaxBatch > maxInflight { + cfg.MaxBatch = maxInflight + } + if cfg.QueueSize == 0 { + cfg.QueueSize = 1024 + } + + var flags uint32 + if cfg.SQPoll { + flags = iouringSetupSQPoll + } + + ring, err := NewIoUring(cfg.RingDepth, flags) + if err != nil && cfg.SQPoll { + // SQPOLL may fail without CAP_SYS_NICE on kernels < 5.13; fall back. + ring, err = NewIoUring(cfg.RingDepth, 0) + } + if err != nil { + return nil, fmt.Errorf("batch io_uring init: %w", err) + } + + freeSlots := make(chan uint32, maxInflight) + for i := 0; i < maxInflight; i++ { + freeSlots <- uint32(i) + } + + b := &BatchIoUringReader{ + ring: ring, + reqCh: make(chan *batchReadRequest, cfg.QueueSize), + maxBatch: cfg.MaxBatch, + closeCh: make(chan struct{}), + inflight: make([]atomic.Pointer[batchReadRequest], ringDepth), + freeSlots: freeSlots, + } + b.wg.Add(2) + go b.submitLoop() + go b.completeLoop() + return b, nil +} + +// Submit sends a pread request into the batch channel and blocks until the +// io_uring completion is received. Thread-safe; called from many goroutines. 
+func (b *BatchIoUringReader) Submit(fd int, buf []byte, offset uint64) (int, error) { + if len(buf) == 0 { + return 0, nil + } + + var startTime = time.Now() + + req := batchReqPool.Get().(*batchReadRequest) + req.fd = fd + req.buf = buf + req.offset = offset + + b.reqCh <- req + + result := <-req.done + n, err := result.N, result.Err + metrics.Timing(metrics.KEY_PREAD_LATENCY, time.Since(startTime), []string{}) + metrics.Incr(metrics.KEY_PREAD_COUNT, []string{}) + // Reset and return to pool + req.fd = 0 + req.buf = nil + req.offset = 0 + batchReqPool.Put(req) + + return n, err +} + +// SubmitAsync enqueues a pread request and returns immediately with a channel +// that will receive the result when the io_uring completion arrives. +// Unlike Submit, it does not block the caller. Thread-safe. +// The caller must read exactly once from the returned channel. +func (b *BatchIoUringReader) SubmitAsync(fd int, buf []byte, offset uint64) <-chan ReadResult { + if len(buf) == 0 { + ch := make(chan ReadResult, 1) + ch <- ReadResult{} + return ch + } + req := &batchReadRequest{ + fd: fd, + buf: buf, + offset: offset, + done: make(chan ReadResult, 1), + } + b.reqCh <- req + return req.done +} + +// Close shuts down both goroutines and releases the io_uring ring. +func (b *BatchIoUringReader) Close() { + close(b.closeCh) + + // Submit a NOP with sentinel UserData to unblock the completion + // goroutine if it is blocked in waitCqe with no pending I/O. + b.ring.mu.Lock() + sqe := b.ring.getSqe() + if sqe != nil { + sqe.Opcode = iouringOpNop + sqe.UserData = sentinelUserData + b.ring.submit(0) + } + b.ring.mu.Unlock() + + b.wg.Wait() + b.ring.Close() +} + +// submitLoop collects requests from reqCh and submits them as io_uring SQEs. +// It never waits for completions — that happens in completeLoop. The ring +// mutex is held only during SQE preparation + io_uring_enter (~1-5μs). 
+func (b *BatchIoUringReader) submitLoop() { + defer b.wg.Done() + + batch := make([]*batchReadRequest, 0, b.maxBatch) + slots := make([]uint32, 0, b.maxBatch) + + for { + // Block until the first request arrives. + select { + case req := <-b.reqCh: + batch = append(batch, req) + case <-b.closeCh: + return + } + + // Non-blocking drain of whatever else is already queued. + for len(batch) < b.maxBatch { + select { + case req := <-b.reqCh: + batch = append(batch, req) + default: + goto submit + } + } + + submit: + // Acquire a free slot for each request. Under normal load (~30 + // in-flight out of 256 slots) this never blocks. + for i, req := range batch { + select { + case slot := <-b.freeSlots: + slots = append(slots, slot) + b.inflight[slot].Store(req) + case <-b.closeCh: + for j := i; j < len(batch); j++ { + batch[j].done <- ReadResult{Err: fmt.Errorf("io_uring: shutting down")} + } + return + } + } + + metrics.Timing(metrics.KEY_IOURING_SIZE, time.Duration(len(batch))*time.Millisecond, []string{}) + + b.ring.mu.Lock() + + prepared := 0 + for i, slot := range slots { + sqe := b.ring.getSqe() + if sqe == nil { + for j := i; j < len(slots); j++ { + req := b.inflight[slots[j]].Swap(nil) + b.freeSlots <- slots[j] + if req != nil { + req.done <- ReadResult{ + Err: fmt.Errorf("io_uring: SQ full, batch=%d depth=%d", len(batch), b.ring.sqEntries), + } + } + } + break + } + prepRead(sqe, batch[i].fd, batch[i].buf, batch[i].offset) + sqe.UserData = uint64(slot) + prepared++ + } + + if prepared > 0 { + b.pending.Add(int32(prepared)) + _, err := b.ring.submit(0) + if err != nil { + b.pending.Add(-int32(prepared)) + for i := 0; i < prepared; i++ { + req := b.inflight[slots[i]].Swap(nil) + b.freeSlots <- slots[i] + if req != nil { + req.done <- ReadResult{Err: fmt.Errorf("io_uring_enter: %w", err)} + } + } + } + } + + b.ring.mu.Unlock() + + batch = batch[:0] + slots = slots[:0] + } +} + +// completeLoop continuously drains CQEs and dispatches results to callers. 
+// It runs independently of submitLoop — the ring's SQ and CQ are separate +// data structures, so no mutex is needed for CQ access (single consumer). +func (b *BatchIoUringReader) completeLoop() { + defer b.wg.Done() + + for { + cqe, err := b.ring.waitCqe() + if err != nil { + select { + case <-b.closeCh: + if b.pending.Load() <= 0 { + return + } + default: + } + continue + } + + userData := cqe.UserData + res := cqe.Res + b.ring.seenCqe() + + // Shutdown NOP — exit once all real I/O has been drained. + if userData == sentinelUserData { + if b.pending.Load() <= 0 { + return + } + continue + } + + slot := uint32(userData) + b.pending.Add(-1) + + req := b.inflight[slot].Swap(nil) + b.freeSlots <- slot + + if req == nil { + continue + } + + if res < 0 { + req.done <- ReadResult{ + Err: fmt.Errorf("io_uring pread errno %d (%s), fd=%d off=%d len=%d", + -res, syscall.Errno(-res), req.fd, req.offset, len(req.buf)), + } + } else { + req.done <- ReadResult{N: int(res)} + } + } +} + +// ParallelBatchIoUringReader distributes pread requests across N independent +// BatchIoUringReader instances (each with its own io_uring ring and goroutines) +// using round-robin. +type ParallelBatchIoUringReader struct { + readers []*BatchIoUringReader + next atomic.Uint64 +} + +// NewParallelBatchIoUringReader creates numRings independent batch readers. +// Each ring gets its own io_uring instance and background goroutines. +func NewParallelBatchIoUringReader(cfg BatchIoUringConfig, numRings int) (*ParallelBatchIoUringReader, error) { + if numRings <= 0 { + numRings = 1 + } + readers := make([]*BatchIoUringReader, numRings) + for i := 0; i < numRings; i++ { + r, err := NewBatchIoUringReader(cfg) + if err != nil { + for j := 0; j < i; j++ { + readers[j].Close() + } + return nil, fmt.Errorf("parallel batch reader ring %d: %w", i, err) + } + readers[i] = r + } + return &ParallelBatchIoUringReader{readers: readers}, nil +} + +// Submit routes the pread to the next ring via round-robin. 
Thread-safe. +func (p *ParallelBatchIoUringReader) Submit(fd int, buf []byte, offset uint64) (int, error) { + idx := p.next.Add(1) % uint64(len(p.readers)) + return p.readers[idx].Submit(fd, buf, offset) +} + +// SubmitAsync routes the pread to the next ring via round-robin and returns +// immediately with a channel for the result. Thread-safe. +func (p *ParallelBatchIoUringReader) SubmitAsync(fd int, buf []byte, offset uint64) <-chan ReadResult { + idx := p.next.Add(1) % uint64(len(p.readers)) + return p.readers[idx].SubmitAsync(fd, buf, offset) +} + +// Close shuts down all underlying batch readers. +func (p *ParallelBatchIoUringReader) Close() { + for _, r := range p.readers { + r.Close() + } +} diff --git a/flashring/internal/iouring/iouring_test.go b/flashring/internal/iouring/iouring_test.go new file mode 100644 index 00000000..403f9d1e --- /dev/null +++ b/flashring/internal/iouring/iouring_test.go @@ -0,0 +1,103 @@ +//go:build linux +// +build linux + +package iouring + +import ( + "os" + "syscall" + "testing" + "unsafe" +) + +func TestIoUringBasicRead(t *testing.T) { + // 1. Create a temp file with known data + f, err := os.CreateTemp("", "iouring_test_*") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + data := make([]byte, 4096) + for i := range data { + data[i] = byte(i % 251) // non-zero pattern + } + if _, err := f.Write(data); err != nil { + t.Fatal(err) + } + if err := f.Sync(); err != nil { + t.Fatal(err) + } + f.Close() + + // 2. Open with O_DIRECT | O_RDONLY + fd, err := syscall.Open(f.Name(), syscall.O_RDONLY|syscall.O_DIRECT, 0) + if err != nil { + t.Fatalf("open O_DIRECT: %v", err) + } + defer syscall.Close(fd) + + // 3. Create io_uring ring + ring, err := NewIoUring(32, 0) + if err != nil { + t.Fatalf("NewIoUring: %v", err) + } + defer ring.Close() + + // 4. Allocate aligned buffer + buf := alignedBlock(4096, 4096) + + // 5. 
Submit read via io_uring + n, err := ring.SubmitRead(fd, buf, 0) + if err != nil { + t.Fatalf("SubmitRead: %v", err) + } + if n != 4096 { + t.Fatalf("SubmitRead returned %d bytes, expected 4096", n) + } + + // 6. Verify data + for i := 0; i < 4096; i++ { + if buf[i] != data[i] { + t.Fatalf("data mismatch at byte %d: got %d, want %d", i, buf[i], data[i]) + } + } + t.Logf("io_uring read of 4096 bytes succeeded and data matches") + + // 7. Test a second read (to verify ring reuse works) + buf2 := alignedBlock(4096, 4096) + n2, err := ring.SubmitRead(fd, buf2, 0) + if err != nil { + t.Fatalf("SubmitRead #2: %v", err) + } + if n2 != 4096 { + t.Fatalf("SubmitRead #2 returned %d bytes, expected 4096", n2) + } + for i := 0; i < 4096; i++ { + if buf2[i] != data[i] { + t.Fatalf("data mismatch #2 at byte %d: got %d, want %d", i, buf2[i], data[i]) + } + } + t.Logf("io_uring second read also succeeded") + + // 8. Test multiple sequential reads to exercise ring cycling + for iter := 0; iter < 100; iter++ { + buf3 := alignedBlock(4096, 4096) + n3, err := ring.SubmitRead(fd, buf3, 0) + if err != nil { + t.Fatalf("SubmitRead iter %d: %v", iter, err) + } + if n3 != 4096 { + t.Fatalf("SubmitRead iter %d returned %d bytes, expected 4096", iter, n3) + } + } + t.Logf("100 sequential io_uring reads succeeded") +} + +// alignedBlock returns a block-aligned buffer. +func alignedBlock(size, alignment int) []byte { + raw := make([]byte, size+alignment) + addr := uintptr(unsafe.Pointer(&raw[0])) + off := (alignment - int(addr%uintptr(alignment))) % alignment + return raw[off : off+size] +} diff --git a/flashring/internal/iouring/iouring_writer.go b/flashring/internal/iouring/iouring_writer.go new file mode 100644 index 00000000..d9f7392b --- /dev/null +++ b/flashring/internal/iouring/iouring_writer.go @@ -0,0 +1,43 @@ +//go:build linux +// +build linux + +package iouring + +// IoUringWriter wraps a BatchIoUringWriter with decoupled submit/complete +// goroutines. 
The ring mutex is held only during SQE prep + io_uring_enter, +// not during CQE drain, allowing concurrent flush batches from different +// shards to interleave submission. +type IoUringWriter struct { + batch *BatchIoUringWriter +} + +// NewIoUringWriter creates an IoUringWriter backed by a decoupled batch writer. +func NewIoUringWriter(entries uint32, flags uint32) (*IoUringWriter, error) { + b, err := NewBatchIoUringWriter(BatchIoUringConfig{ + RingDepth: entries, + MaxBatch: int(entries), + MaxInflight: int(entries), + QueueSize: 1024, + }) + if err != nil { + return nil, err + } + return &IoUringWriter{batch: b}, nil +} + +// MaxBatchSize returns the maximum number of SQEs that can be submitted in +// a single SubmitWriteBatch call. +func (w *IoUringWriter) MaxBatchSize() int { + return w.batch.MaxBatchSize() +} + +// SubmitWriteBatch submits N pwrite operations and waits for all completions. +// Thread-safe. The ring mutex is NOT held during CQE drain. +func (w *IoUringWriter) SubmitWriteBatch(fd int, bufs [][]byte, offsets []uint64) ([]int, error) { + return w.batch.SubmitWriteBatch(fd, bufs, offsets) +} + +// Close releases the underlying io_uring ring and stops background goroutines. +func (w *IoUringWriter) Close() { + w.batch.Close() +} diff --git a/flashring/internal/maths/estimator.go b/flashring/internal/maths/estimator.go new file mode 100644 index 00000000..154298e1 --- /dev/null +++ b/flashring/internal/maths/estimator.go @@ -0,0 +1,185 @@ +// Package estimator implements online adaptive grid search for tuning +// weights (wFreq, wLA) to optimize cache rewrite decisions based on hit ratio. 
+package maths + +import ( + "math" + "time" + + "github.com/rs/zerolog/log" +) + +const ( + missBaseline = float64(1e-9) +) + +type WeightTuple struct { + WFreq float64 + WLA float64 +} + +type Stats struct { + HitRate float64 // averaged hit rate over time window + Trials int +} + +type GridSearchEstimator struct { + Tuples []WeightTuple + InitialTuples []WeightTuple + bestTuple WeightTuple + TupleStats map[WeightTuple]*Stats + CurrIndex int + StartTime time.Time + Duration time.Duration + LiveEstimator *Estimator + stopGridSearch bool + bestHitRate float64 + epsilon float64 +} + +type Estimator struct { + WFreq float64 + WLA float64 +} + +func NewGridSearchEstimator(duration time.Duration, initialTuples []WeightTuple, estimator *Estimator, epsilon float64) *GridSearchEstimator { + return &GridSearchEstimator{ + Tuples: initialTuples, + InitialTuples: initialTuples, + bestTuple: initialTuples[0], + TupleStats: make(map[WeightTuple]*Stats), + CurrIndex: 0, + StartTime: time.Now(), + Duration: duration, + LiveEstimator: estimator, + bestHitRate: 0, + stopGridSearch: false, + epsilon: epsilon, + } +} + +func (e *Estimator) CalculateRewriteScore(freq uint64, lastAccess uint64, keyMemId, activeMemId, maxMemTableCount uint32) float32 { + overWriteRisk := (activeMemId - keyMemId + maxMemTableCount) % maxMemTableCount + overWriteRiskScore := float32(overWriteRisk) / float32(maxMemTableCount) + + fScore := 1 - math.Exp(-e.WFreq*float64(freq)) + laScore := math.Exp(-e.WLA * float64(lastAccess)) + return float32(fScore+laScore) * overWriteRiskScore +} + +func (g *GridSearchEstimator) RecordHitRate(hitRate float64) { + if g.stopGridSearch { + tuple := g.bestTuple + if _, ok := g.TupleStats[tuple]; !ok { + g.TupleStats[tuple] = &Stats{} + } + stat := g.TupleStats[tuple] + stat.HitRate = (stat.HitRate*float64(stat.Trials) + hitRate) / float64(stat.Trials+1) + stat.Trials++ + if stat.HitRate < g.bestHitRate*0.9 { + log.Error().Msgf("GridSearchRestarted: hitRate %v bestHitRate 
%v", stat.HitRate, g.bestHitRate) + g.RestartGridSearch() + } + return + } + tuple := g.Tuples[g.CurrIndex] + if _, ok := g.TupleStats[tuple]; !ok { + g.TupleStats[tuple] = &Stats{} + } + stat := g.TupleStats[tuple] + stat.HitRate = (stat.HitRate*float64(stat.Trials) + hitRate) / float64(stat.Trials+1) + stat.Trials++ + + if time.Since(g.StartTime) < g.Duration { + return + } + // Advance to next tuple + g.CurrIndex = (g.CurrIndex + 1) % len(g.Tuples) + if g.CurrIndex == 0 { + ok := g.RefineGridAroundBest(2, 0.001) + if !ok { + g.stopGridSearch = true + return + } + } + g.StartTime = time.Now() + + // Update live estimator + next := g.Tuples[g.CurrIndex] + g.LiveEstimator.WFreq = next.WFreq + g.LiveEstimator.WLA = next.WLA +} + +func (g *GridSearchEstimator) BestTuple() WeightTuple { + + best := WeightTuple{} + bestScore := -1.0 + + for _, tup := range g.Tuples { + stat := g.TupleStats[tup] + if stat == nil || stat.Trials < 3 { + continue + } + if stat.HitRate > bestScore { + bestScore = stat.HitRate + best = tup + } + } + + return best +} + +func (g *GridSearchEstimator) GenerateRefinedGrid(base WeightTuple, steps int, delta float64) ([]WeightTuple, bool) { + refined := make([]WeightTuple, 0, (2*steps+1)*(2*steps+1)) + for i := -steps; i <= steps; i++ { + for j := -steps; j <= steps; j++ { + + if i == 0 && j == 0 { + continue + } + wf := base.WFreq + float64(i)*delta + la := base.WLA + float64(j)*delta + if math.Abs(wf-base.WFreq) < g.epsilon && math.Abs(la-base.WLA) < g.epsilon { + return refined, false + } + if wf > 0 && la > 0 { + refined = append(refined, WeightTuple{wf, la}) + } + } + } + return refined, true +} + +func (g *GridSearchEstimator) RefineGridAroundBest(steps int, delta float64) bool { + best := g.BestTuple() + refined, ok := g.GenerateRefinedGrid(best, steps, delta) + if !ok { + g.LiveEstimator.WFreq = best.WFreq + g.LiveEstimator.WLA = best.WLA + g.bestHitRate = g.TupleStats[best].HitRate + g.bestTuple = best + return false + } + g.Tuples = 
refined + g.CurrIndex = 0 + g.TupleStats = make(map[WeightTuple]*Stats) + g.LiveEstimator.WFreq = g.Tuples[0].WFreq + g.LiveEstimator.WLA = g.Tuples[0].WLA + g.StartTime = time.Now() + return true +} + +func (g *GridSearchEstimator) RestartGridSearch() { + g.stopGridSearch = false + g.Tuples = g.InitialTuples + g.CurrIndex = 0 + g.TupleStats = make(map[WeightTuple]*Stats) + g.LiveEstimator.WFreq = g.Tuples[0].WFreq + g.LiveEstimator.WLA = g.Tuples[0].WLA + g.StartTime = time.Now() + g.bestHitRate = 0 +} + +func (g *GridSearchEstimator) IsGridSearchActive() bool { + return !g.stopGridSearch +} diff --git a/flashring/internal/maths/freq.go b/flashring/internal/maths/freq.go new file mode 100644 index 00000000..3e554e6b --- /dev/null +++ b/flashring/internal/maths/freq.go @@ -0,0 +1,142 @@ +// freq.go +package maths + +/* +Package maths implements a binary Morris-style probabilistic counter +compressed into a single uint16. + +------------------------------------------------------------------------ +How the algorithm works +------------------------------------------------------------------------ + + 1. Layout (16 bits) + + ┌─ exponent (4 bits) ─┬─ mantissa (12 bits) ─┐ + │ e (0–15) │ m (0–4095) │ + └──────────────────────┴───────────────────────┘ + + The counter encodes the approximate value: m × 2ᵉ (equivalently m << e). + + 2. Increment rule + + On each key access, the counter is incremented probabilistically: + - Probability of increment = 1 / 2ᵉ. + - The caller supplies an external hash (hlo). We compare its lower + 32 bits against a precomputed threshold: th[e] = (2³² - 1) >> e. + - If uint32(hlo) < th[e] → hit → mantissa advances (m++). + - If uint32(hlo) >= th[e] → miss → counter unchanged. + + + 3. Mantissa overflow + + When m reaches 4096 (overflows 12 bits), we halve the mantissa + and bump the exponent: m = 2048, e++. 
+ + This preserves the approximate decoded value across the transition: + Before: m=4095, e=0 → Value = 4095 × 1 = 4095 + After: m=2048, e=1 → Value = 2048 × 2 = 4096 + + At max exponent (e=15), the counter saturates at m=4095, e=15 + (decoded value = 4095 × 32768 = 134,184,960). + + 4. Decoding + + Value = m << e + + Examples: + Encoded (e=0, m=42) → 42 << 0 = 42 (exact) + Encoded (e=0, m=4000) → 4000 << 0 = 4000 (exact) + Encoded (e=1, m=2048) → 2048 << 1 = 4096 (step size = 2) + Encoded (e=2, m=2500) → 2500 << 2 = 10000 (step size = 4) + + 5. Resolution + + At exponent e, the step between consecutive representable values is 2ᵉ: + e=0: step 1, exact integers 0 – 4,095 + e=1: step 2, even numbers 4,096 – 8,190 + e=2: step 4 8,192 – 16,380 + ... + e=15: step 32,768 up to ~134 million + + For cache frequency tracking, most keys stay in e=0 (exact counts up + to 4095) or e=1 (step of 2), giving ~6,600 distinct values under 10K + compared to ~37 with the previous base-10 design. + + 6. Complexity & footprint + + State per key: 2 bytes (uint16), stored inline in the index entry. + Increment: 1 compare + a few bit-ops, no floating-point. + Thresholds: precomputed once in New() (16 entries). +*/ + +// 12-bit mantissa (0–4095). 4-bit exponent (0–15). +const ( + mBits = 12 + mMask = (1 << mBits) - 1 // 0x0FFF + eShift = mBits + mOverflow = 1 << mBits // 4096 +) + +type MorrisLogCounter struct { + th []uint32 // th[e] = (2^32 - 1) >> e; increment probability = 1/2^e + pow2 []uint64 // pow2[e] = 2^e; used for decoding + expClamp uint32 // maximum exponent, fixed at 15 +} + +// New creates a MorrisLogCounter with precomputed threshold and power tables. +// The 4-bit exponent field supports exponents 0–15. 
+func New() *MorrisLogCounter { + const maxExp = 15 + + th := make([]uint32, maxExp+1) + pow2 := make([]uint64, maxExp+1) + + max32 := uint64(^uint32(0)) // 2^32 - 1 + + for e := 0; e <= maxExp; e++ { + th[e] = uint32(max32 >> e) + pow2[e] = 1 << uint(e) + } + + return &MorrisLogCounter{ + th: th, + pow2: pow2, + expClamp: maxExp, + } +} + +// Inc probabilistically increments the counter. hlo is the lower 64 bits +// of the key's hash, used as the randomness source for the Bernoulli trial. +// Returns the (possibly updated) counter and whether an increment occurred. +func (c *MorrisLogCounter) Inc(v uint16, hlo uint64) (uint16, bool) { + m := v & mMask + e := v >> eShift + + // Bernoulli trial: increment with probability 1/2^e. + // At e=0 this is ~100% (th[0] = 0xFFFFFFFF). + // At e=1 this is ~50%, at e=2 ~25%, etc. + if uint32(hlo) >= c.th[e] { + return v, false + } + + m++ + if m == mOverflow { + // Mantissa overflowed 12 bits. Halve mantissa and bump exponent + // to keep the decoded value approximately continuous. + if e < 15 { + m = m >> 1 // 4096 → 2048 + e++ + } else { + // Saturate: can't increase exponent further. + m = mOverflow - 1 // clamp at 4095 + } + } + return (e << eShift) | (m & mMask), true +} + +// Value decodes the counter into an approximate frequency: m × 2^e. 
+func (c *MorrisLogCounter) Value(v uint16) uint64 { + m := uint64(v & mMask) + e := v >> eShift + return m << e +} diff --git a/flashring/internal/maths/freq_test.go b/flashring/internal/maths/freq_test.go new file mode 100644 index 00000000..67101765 --- /dev/null +++ b/flashring/internal/maths/freq_test.go @@ -0,0 +1,308 @@ +package maths + +import ( + "testing" +) + +func TestNew(t *testing.T) { + counter := New() + if counter == nil { + t.Fatal("New() returned nil") + } + if counter.expClamp != 15 { + t.Errorf("expClamp = %v, want 15", counter.expClamp) + } + if len(counter.th) != 16 { + t.Errorf("threshold table length = %v, want 16", len(counter.th)) + } + if len(counter.pow2) != 16 { + t.Errorf("pow2 table length = %v, want 16", len(counter.pow2)) + } +} + +func TestPow2Table(t *testing.T) { + counter := New() + + expected := []uint64{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} + for i, exp := range expected { + if counter.pow2[i] != exp { + t.Errorf("pow2[%d] = %v, want %v", i, counter.pow2[i], exp) + } + } +} + +func TestThresholdTable(t *testing.T) { + counter := New() + + max32 := uint64(^uint32(0)) + + for e := uint32(0); e <= 15; e++ { + expected := uint32(max32 >> e) + if counter.th[e] != expected { + t.Errorf("th[%d] = %v, want %v", e, counter.th[e], expected) + } + } +} + +func TestValue(t *testing.T) { + counter := New() + + tests := []struct { + name string + v uint16 + expected uint64 + }{ + { + name: "mantissa 0, exponent 0", + v: 0, + expected: 0, + }, + { + name: "mantissa 5, exponent 0", + v: 5, + expected: 5, // 5 << 0 + }, + { + name: "mantissa 3, exponent 1", + v: (1 << eShift) | 3, + expected: 6, // 3 << 1 + }, + { + name: "mantissa 100, exponent 2", + v: (2 << eShift) | 100, + expected: 400, // 100 << 2 + }, + { + name: "mantissa 4095, exponent 0", + v: 4095, + expected: 4095, // 4095 << 0 + }, + { + name: "mantissa 2048, exponent 1", + v: (1 << eShift) | 2048, + expected: 4096, // 2048 << 1 + }, + { + 
name: "mantissa 2048, exponent 15", + v: (15 << eShift) | 2048, + expected: 2048 << 15, // 67108864 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := counter.Value(tt.v) + if result != tt.expected { + t.Errorf("Value(%v) = %v, want %v", tt.v, result, tt.expected) + } + }) + } +} + +func TestIncBasicBehavior(t *testing.T) { + counter := New() + + // With e=0, th[0] = 0xFFFFFFFF, so any hlo will hit (uint32(hlo) < th[0]) + v := uint16(5) // m=5, e=0 + newV, hit := counter.Inc(v, 0) + + if !hit { + t.Error("Inc() should always hit at e=0") + } + + expectedV := uint16(6) // m=6, e=0 + if newV != expectedV { + t.Errorf("Inc(%v) = %v, want %v", v, newV, expectedV) + } +} + +func TestIncMantissaOverflow(t *testing.T) { + counter := New() + + // m=4095 (mOverflow-1), e=0 -> increment should cause overflow + v := uint16(mOverflow - 1) // m=4095, e=0 + newV, hit := counter.Inc(v, 0) + + if !hit { + t.Error("Inc() should always hit at e=0") + } + + // On overflow: m becomes 4096>>1 = 2048, e becomes 1 + expectedM := uint16(mOverflow >> 1) // 2048 + expectedE := uint16(1) + expectedV := (expectedE << eShift) | expectedM + + if newV != expectedV { + t.Errorf("Inc(%v) = %v, want %v (m=2048, e=1)", v, newV, expectedV) + } + + // Verify the decoded value is reasonable + // Before: Value(4095) = 4095 << 0 = 4095 + // After: Value(newV) = 2048 << 1 = 4096 + valBefore := counter.Value(v) + valAfter := counter.Value(newV) + if valAfter <= valBefore { + t.Errorf("Value should increase after overflow: before=%v, after=%v", valBefore, valAfter) + } +} + +func TestIncExponentSaturation(t *testing.T) { + counter := New() + + // m=4095, e=15 (max exponent) -> should saturate + v := (uint16(15) << eShift) | uint16(mOverflow-1) // m=4095, e=15 + newV, hit := counter.Inc(v, 0) + + if !hit { + t.Error("Inc() should hit") + } + + // Should saturate: m stays at 4095, e stays at 15 + if newV != v { + t.Errorf("Inc(%v) = %v, want %v (saturated at max)", v, 
newV, v) + } +} + +func TestIncMissBehavior(t *testing.T) { + counter := New() + + // At e=1, th[1] = 0xFFFFFFFF >> 1 = 0x7FFFFFFF + // hlo with uint32 >= 0x7FFFFFFF should miss + v := (uint16(1) << eShift) | 5 // m=5, e=1 + hlo := uint64(0xFFFFFFFF) // uint32(hlo) = 0xFFFFFFFF >= th[1] + + newV, hit := counter.Inc(v, hlo) + + if hit { + t.Error("Inc() should miss when uint32(hlo) >= th[e]") + } + if newV != v { + t.Errorf("Inc() on miss should return original value: got %v, want %v", newV, v) + } +} + +func TestIncStatisticalBehavior(t *testing.T) { + if testing.Short() { + t.Skip("skipping statistical test in short mode") + } + + counter := New() + + // Test with e=0 (should hit ~100% of the time since th[0] = 0xFFFFFFFF) + v := uint16(5) + hits := 0 + trials := 1000 + + for i := 0; i < trials; i++ { + _, hit := counter.Inc(v, uint64(i)) + if hit { + hits++ + } + } + + hitRate := float64(hits) / float64(trials) + if hitRate < 0.99 { + t.Errorf("Hit rate for e=0 = %v, want ~1.0", hitRate) + } + + // Test with e=1 (should hit approximately 50% of the time) + // th[1] = 0x7FFFFFFF, so uint32(hlo) < th[1] means lower half hits + v = (1 << eShift) | 5 + hits = 0 + + for i := 0; i < trials; i++ { + // Use Knuth multiplicative hash to spread uint32 values evenly + hlo := uint64(uint32(i) * 2654435761) + _, hit := counter.Inc(v, hlo) + if hit { + hits++ + } + } + + hitRate = float64(hits) / float64(trials) + if hitRate < 0.35 || hitRate > 0.65 { + t.Errorf("Hit rate for e=1 = %v, want ~0.50 (0.35-0.65)", hitRate) + } +} + +func TestIntegrationCountingApproximation(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + counter := New() + + v := uint16(0) + totalEvents := 100000 + + for i := 0; i < totalEvents; i++ { + newV, hit := counter.Inc(v, uint64(i*2654435761)) // Knuth multiplicative hash for spread + if hit { + v = newV + } + } + + approxCount := counter.Value(v) + + // The approximation should be in the right ballpark + 
ratio := float64(approxCount) / float64(totalEvents) + if ratio < 0.1 || ratio > 10.0 { + t.Errorf("Approximation ratio = %v, totalEvents = %v, approxCount = %v", + ratio, totalEvents, approxCount) + } +} + +func TestBitPacking(t *testing.T) { + counter := New() + + tests := []struct { + mantissa uint16 + exponent uint16 + }{ + {0, 0}, + {4095, 0}, + {0, 15}, + {2048, 3}, + {100, 7}, + } + + for _, tt := range tests { + v := (tt.exponent << eShift) | (tt.mantissa & mMask) + + extractedM := v & mMask + extractedE := v >> eShift + + if extractedM != tt.mantissa&mMask { + t.Errorf("Mantissa packing: got %v, want %v", extractedM, tt.mantissa&mMask) + } + if extractedE != tt.exponent { + t.Errorf("Exponent packing: got %v, want %v", extractedE, tt.exponent) + } + + decoded := counter.Value(v) + expected := uint64(tt.mantissa&mMask) << tt.exponent + if decoded != expected { + t.Errorf("Value() = %v, want %v (m=%v, e=%v)", decoded, expected, tt.mantissa, tt.exponent) + } + } +} + +func BenchmarkInc(b *testing.B) { + counter := New() + v := uint16(123) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + v, _ = counter.Inc(v, uint64(i)) + } +} + +func BenchmarkValue(b *testing.B) { + counter := New() + v := uint16(123) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = counter.Value(v) + } +} diff --git a/flashring/internal/maths/predictor.go b/flashring/internal/maths/predictor.go new file mode 100644 index 00000000..86b70279 --- /dev/null +++ b/flashring/internal/maths/predictor.go @@ -0,0 +1,177 @@ +package maths + +import ( + "time" + + "github.com/Meesho/BharatMLStack/flashring/pkg/metrics" +) + +type Params struct { + Freq uint64 + LastAccess uint64 + KeyMemId uint32 + ActiveMemId uint32 +} +type Predictor struct { + Estimator *Estimator + GridSearchEstimator *GridSearchEstimator + ReWriteScoreThreshold float32 + MaxMemTableCount uint32 + freqBands FreqBands + recencyBands RecencyBands + hitRateCh chan float64 +} + +// FreqBands defines the upper bounds for frequency 
band labels. +// Keys with freq <= Cold are "cold", <= Warm are "warm", <= Hot are "hot", +// and anything above Hot is "very_hot". +type FreqBands struct { + Cold uint64 + Warm uint64 + Hot uint64 +} + +// RecencyBands defines upper-bound thresholds for recency band labels. +// lastAccess represents how long ago a key was accessed (higher = older). +// Keys with lastAccess <= Hot are "very_hot", <= Warm are "hot", +// <= Cold are "warm", and anything above Cold is "cold". +type RecencyBands struct { + Hot uint64 + Warm uint64 + Cold uint64 +} + +type PredictorConfig struct { + ReWriteScoreThreshold float32 + Weights []WeightTuple + SampleDuration time.Duration + MaxMemTableCount uint32 + GridSearchEpsilon float64 + FreqBands FreqBands + RecencyBands RecencyBands +} + +func NewPredictor(config PredictorConfig) *Predictor { + estimator := &Estimator{ + WFreq: config.Weights[0].WFreq, + WLA: config.Weights[0].WLA, + } + gridSearchEstimator := NewGridSearchEstimator(config.SampleDuration, config.Weights, estimator, config.GridSearchEpsilon) + fb := config.FreqBands + if fb.Cold == 0 && fb.Warm == 0 && fb.Hot == 0 { + fb = FreqBands{Cold: 1, Warm: 5, Hot: 20} + } + rb := config.RecencyBands + if rb.Hot == 0 && rb.Warm == 0 && rb.Cold == 0 { + rb = RecencyBands{Hot: 5, Warm: 50, Cold: 500} + } + p := &Predictor{ + Estimator: estimator, + GridSearchEstimator: gridSearchEstimator, + ReWriteScoreThreshold: config.ReWriteScoreThreshold, + MaxMemTableCount: config.MaxMemTableCount, + freqBands: fb, + recencyBands: rb, + hitRateCh: make(chan float64, 1024), + } + go func() { + for hitRate := range p.hitRateCh { + p.GridSearchEstimator.RecordHitRate(hitRate) + } + }() + return p +} + +func scoreBucket(score float32) string { + switch { + case score < 0.1: + return "0.0-0.1" + case score < 0.3: + return "0.1-0.3" + case score < 0.5: + return "0.3-0.5" + case score < 0.7: + return "0.5-0.7" + case score < 1.0: + return "0.7-1.0" + default: + return "1.0+" + } +} + +func 
ringZone(keyMemId, activeMemId, maxMemTableCount uint32) string { + risk := (activeMemId - keyMemId + maxMemTableCount) % maxMemTableCount + pct := float64(risk) / float64(maxMemTableCount) + switch { + case pct < 0.25: + return "0-25%" + case pct < 0.50: + return "25-50%" + case pct < 0.75: + return "50-75%" + default: + return "75-100%" + } +} + +func freqBand(freq uint64, fb FreqBands) string { + switch { + case freq <= fb.Cold: + return "cold" + case freq <= fb.Warm: + return "warm" + case freq <= fb.Hot: + return "hot" + default: + return "very_hot" + } +} + +func recencyBand(lastAccess uint64, rb RecencyBands) string { + switch { + case lastAccess <= rb.Hot: + return "very_hot" + case lastAccess <= rb.Warm: + return "hot" + case lastAccess <= rb.Cold: + return "warm" + default: + return "cold" + } +} + +func (p *Predictor) Predict(freq uint64, lastAccess uint64, keyMemId uint32, activeMemId uint32) bool { + score := p.Estimator.CalculateRewriteScore(freq, lastAccess, keyMemId, activeMemId, p.MaxMemTableCount) + rewrite := score > p.ReWriteScoreThreshold + + computeMetrics(keyMemId, activeMemId, p, freq, lastAccess, rewrite, score) + + return rewrite +} + +func computeMetrics(keyMemId uint32, activeMemId uint32, p *Predictor, freq uint64, lastAccess uint64, rewrite bool, score float32) { + zone := ringZone(keyMemId, activeMemId, p.MaxMemTableCount) + fBand := freqBand(freq, p.freqBands) + rBand := recencyBand(lastAccess, p.recencyBands) + decision := "skip" + if rewrite { + decision = "rewrite" + } + + metrics.Timing(metrics.KEY_ACCESS_FREQ, time.Duration(freq)*time.Millisecond, nil) + metrics.Timing(metrics.KEY_LAST_ACCESS, time.Duration(lastAccess)*time.Millisecond, nil) + metrics.Incr(metrics.KEY_REWRITE_SCORE, metrics.BuildTag(metrics.NewTag(metrics.TAG_SCORE_BUCKET, scoreBucket(score)))) + metrics.Incr(metrics.KEY_REWRITE_DECISION, metrics.BuildTag( + metrics.NewTag(metrics.TAG_DECISION, decision), + metrics.NewTag(metrics.TAG_RING_ZONE, zone), + 
metrics.NewTag(metrics.TAG_FREQ_BAND, fBand), + metrics.NewTag(metrics.TAG_RECENCY_BAND, rBand), + )) +} + +func (p *Predictor) Observe(hitRate float64) { + select { + case p.hitRateCh <- hitRate: + default: + } +} diff --git a/flashring/internal/maths/predictor_test.go b/flashring/internal/maths/predictor_test.go new file mode 100644 index 00000000..f9bdfa0e --- /dev/null +++ b/flashring/internal/maths/predictor_test.go @@ -0,0 +1,481 @@ +package maths + +import ( + "testing" + "time" +) + +func TestNewPredictor(t *testing.T) { + config := PredictorConfig{ + ReWriteScoreThreshold: 0.5, + Weights: []WeightTuple{ + {WFreq: 0.1, WLA: 0.2}, + {WFreq: 0.2, WLA: 0.3}, + }, + SampleDuration: 100 * time.Millisecond, + MaxMemTableCount: 10, + GridSearchEpsilon: 0.001, + } + + predictor := NewPredictor(config) + + // Verify predictor initialization + if predictor == nil { + t.Fatal("NewPredictor returned nil") + } + if predictor.ReWriteScoreThreshold != 0.5 { + t.Errorf("Expected ReWriteScoreThreshold 0.5, got %f", predictor.ReWriteScoreThreshold) + } + if predictor.MaxMemTableCount != 10 { + t.Errorf("Expected MaxMemTableCount 10, got %d", predictor.MaxMemTableCount) + } + + // Verify estimator initialization + if predictor.Estimator == nil { + t.Fatal("Estimator not initialized") + } + if predictor.Estimator.WFreq != 0.1 { + t.Errorf("Expected WFreq 0.1, got %f", predictor.Estimator.WFreq) + } + if predictor.Estimator.WLA != 0.2 { + t.Errorf("Expected WLA 0.2, got %f", predictor.Estimator.WLA) + } + + // Verify grid search estimator initialization + if predictor.GridSearchEstimator == nil { + t.Fatal("GridSearchEstimator not initialized") + } + + // Verify channel initialization + if predictor.hitRateCh == nil { + t.Fatal("hitRateCh not initialized") + } +} + +func TestPredictorPredict(t *testing.T) { + config := PredictorConfig{ + ReWriteScoreThreshold: 0.5, + Weights: []WeightTuple{ + {WFreq: 0.1, WLA: 0.2}, + }, + SampleDuration: 100 * time.Millisecond, + 
MaxMemTableCount: 10, + GridSearchEpsilon: 0.001, + } + + predictor := NewPredictor(config) + + tests := []struct { + name string + freq uint64 + lastAccess uint64 + keyMemId uint32 + activeMemId uint32 + expectRewrite bool + }{ + { + name: "high frequency, recent access, high overwrite risk", + freq: 100, + lastAccess: 1, + keyMemId: 0, + activeMemId: 8, + expectRewrite: true, + }, + { + name: "low frequency, old access, low overwrite risk", + freq: 1, + lastAccess: 1000, + keyMemId: 5, + activeMemId: 6, + expectRewrite: false, + }, + { + name: "medium frequency, medium access, medium overwrite risk", + freq: 10, + lastAccess: 50, + keyMemId: 3, + activeMemId: 7, + expectRewrite: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := predictor.Predict(tt.freq, tt.lastAccess, tt.keyMemId, tt.activeMemId) + if result != tt.expectRewrite { + score := predictor.Estimator.CalculateRewriteScore( + tt.freq, tt.lastAccess, tt.keyMemId, tt.activeMemId, predictor.MaxMemTableCount) + t.Errorf("Expected %v, got %v (score: %f, threshold: %f)", + tt.expectRewrite, result, score, predictor.ReWriteScoreThreshold) + } + }) + } +} + +func TestPredictorObserve(t *testing.T) { + config := PredictorConfig{ + ReWriteScoreThreshold: 0.5, + Weights: []WeightTuple{ + {WFreq: 0.1, WLA: 0.2}, + }, + SampleDuration: 10 * time.Millisecond, + MaxMemTableCount: 10, + GridSearchEpsilon: 0.001, + } + + predictor := NewPredictor(config) + + // Test observing hit rates + hitRates := []float64{0.8, 0.7, 0.9, 0.6} + + for _, hitRate := range hitRates { + predictor.Observe(hitRate) + } + + // Give some time for the goroutine to process + time.Sleep(50 * time.Millisecond) + + // Channel should not block on additional observations + for i := 0; i < 10; i++ { + predictor.Observe(0.5) + } +} + +func TestEstimatorCalculateRewriteScore(t *testing.T) { + estimator := &Estimator{ + WFreq: 0.1, + WLA: 0.2, + } + + tests := []struct { + name string + freq uint64 + 
lastAccess uint64 + keyMemId uint32 + activeMemId uint32 + maxMemTableCount uint32 + expectHighScore bool + }{ + { + name: "high frequency, recent access, high overwrite risk", + freq: 100, + lastAccess: 1, + keyMemId: 0, + activeMemId: 9, + maxMemTableCount: 10, + expectHighScore: true, + }, + { + name: "low frequency, old access, low overwrite risk", + freq: 1, + lastAccess: 1000, + keyMemId: 5, + activeMemId: 6, + maxMemTableCount: 10, + expectHighScore: false, + }, + { + name: "zero frequency should give low score", + freq: 0, + lastAccess: 0, + keyMemId: 0, + activeMemId: 0, + maxMemTableCount: 10, + expectHighScore: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + score := estimator.CalculateRewriteScore( + tt.freq, tt.lastAccess, tt.keyMemId, tt.activeMemId, tt.maxMemTableCount) + + if tt.expectHighScore && score < 0.1 { + t.Errorf("Expected high score, got %f", score) + } + if !tt.expectHighScore && score > 0.5 { + t.Errorf("Expected low score, got %f", score) + } + + // Score should always be non-negative + if score < 0 { + t.Errorf("Score should be non-negative, got %f", score) + } + }) + } +} + +func TestEstimatorScoreComponents(t *testing.T) { + estimator := &Estimator{ + WFreq: 0.1, + WLA: 0.2, + } + + // Test that frequency score increases with frequency + score1 := estimator.CalculateRewriteScore(1, 100, 0, 5, 10) + score2 := estimator.CalculateRewriteScore(10, 100, 0, 5, 10) + score3 := estimator.CalculateRewriteScore(100, 100, 0, 5, 10) + + if !(score1 < score2 && score2 < score3) { + t.Errorf("Score should increase with frequency: %f, %f, %f", score1, score2, score3) + } + + // Test that last access score decreases with time + score1 = estimator.CalculateRewriteScore(10, 1, 0, 5, 10) + score2 = estimator.CalculateRewriteScore(10, 10, 0, 5, 10) + score3 = estimator.CalculateRewriteScore(10, 100, 0, 5, 10) + + if !(score1 > score2 && score2 > score3) { + t.Errorf("Score should decrease with last access time: %f, 
%f, %f", score1, score2, score3) + } + + // Test overwrite risk calculation + score1 = estimator.CalculateRewriteScore(10, 10, 0, 1, 10) // low risk + score2 = estimator.CalculateRewriteScore(10, 10, 0, 5, 10) // medium risk + score3 = estimator.CalculateRewriteScore(10, 10, 0, 9, 10) // high risk + + if !(score1 < score2 && score2 < score3) { + t.Errorf("Score should increase with overwrite risk: %f, %f, %f", score1, score2, score3) + } +} + +func TestGridSearchEstimator(t *testing.T) { + initialTuples := []WeightTuple{ + {WFreq: 0.1, WLA: 0.1}, + {WFreq: 0.2, WLA: 0.2}, + {WFreq: 0.3, WLA: 0.3}, + } + + estimator := &Estimator{WFreq: 0.1, WLA: 0.1} + gridSearch := NewGridSearchEstimator( + 50*time.Millisecond, + initialTuples, + estimator, + 0.001, + ) + + // Test initialization + if len(gridSearch.Tuples) != 3 { + t.Errorf("Expected 3 tuples, got %d", len(gridSearch.Tuples)) + } + if gridSearch.CurrIndex != 0 { + t.Errorf("Expected CurrIndex 0, got %d", gridSearch.CurrIndex) + } + + // Test recording hit rates + hitRates := []float64{0.8, 0.7, 0.9} + for i, hitRate := range hitRates { + gridSearch.RecordHitRate(hitRate) + if i < len(hitRates)-1 { + time.Sleep(60 * time.Millisecond) // Wait for duration to pass + } + } + + // Verify stats are recorded + for _, tuple := range initialTuples { + if stat, ok := gridSearch.TupleStats[tuple]; ok && stat.Trials > 0 { + if stat.HitRate < 0 || stat.HitRate > 1 { + t.Errorf("Invalid hit rate %f for tuple %+v", stat.HitRate, tuple) + } + } + } +} + +func TestGridSearchBestTuple(t *testing.T) { + initialTuples := []WeightTuple{ + {WFreq: 0.1, WLA: 0.1}, + {WFreq: 0.2, WLA: 0.2}, + {WFreq: 0.3, WLA: 0.3}, + } + + estimator := &Estimator{WFreq: 0.1, WLA: 0.1} + gridSearch := NewGridSearchEstimator( + 10*time.Millisecond, + initialTuples, + estimator, + 0.001, + ) + + // Manually add stats + gridSearch.TupleStats[initialTuples[0]] = &Stats{HitRate: 0.7, Trials: 5} + gridSearch.TupleStats[initialTuples[1]] = &Stats{HitRate: 0.9, 
Trials: 5} + gridSearch.TupleStats[initialTuples[2]] = &Stats{HitRate: 0.6, Trials: 5} + + best := gridSearch.BestTuple() + expected := initialTuples[1] // Should be the one with 0.9 hit rate + + if best.WFreq != expected.WFreq || best.WLA != expected.WLA { + t.Errorf("Expected best tuple %+v, got %+v", expected, best) + } +} + +func TestGridSearchRefinement(t *testing.T) { + initialTuples := []WeightTuple{ + {WFreq: 0.2, WLA: 0.2}, + } + + estimator := &Estimator{WFreq: 0.2, WLA: 0.2} + + t.Run("delta > epsilon: refinement possible", func(t *testing.T) { + gs := NewGridSearchEstimator(10*time.Millisecond, initialTuples, estimator, 0.01) + base := WeightTuple{WFreq: 0.2, WLA: 0.2} + refined, ok := gs.GenerateRefinedGrid(base, 1, 0.1) + if !ok { + t.Error("Expected ok=true when delta > epsilon (refinement still useful)") + } + // 3x3 grid minus the center = 8 points; all have positive wf and la + // (base 0.2 ± 0.1 yields 0.1..0.3) + if len(refined) != 8 { + t.Errorf("Expected 8 refined tuples, got %d", len(refined)) + } + }) + + t.Run("delta < epsilon: convergence detected", func(t *testing.T) { + gs := NewGridSearchEstimator(10*time.Millisecond, initialTuples, estimator, 0.1) + base := WeightTuple{WFreq: 0.2, WLA: 0.2} + _, ok := gs.GenerateRefinedGrid(base, 1, 0.01) + if ok { + t.Error("Expected ok=false when delta < epsilon (converged)") + } + }) + + t.Run("larger grid with delta > epsilon", func(t *testing.T) { + gs := NewGridSearchEstimator(10*time.Millisecond, initialTuples, estimator, 0.001) + base := WeightTuple{WFreq: 0.5, WLA: 0.5} + refined, ok := gs.GenerateRefinedGrid(base, 2, 0.1) + if !ok { + t.Error("Expected ok=true when delta >> epsilon") + } + // 5x5 grid minus center = 24 points; all positive (0.3..0.7) + if len(refined) != 24 { + t.Errorf("Expected 24 refined tuples, got %d", len(refined)) + } + }) +} + +func TestGridSearchConvergence(t *testing.T) { + initialTuples := []WeightTuple{ + {WFreq: 0.1, WLA: 0.1}, + } + + estimator := 
&Estimator{WFreq: 0.1, WLA: 0.1} + gridSearch := NewGridSearchEstimator( + 1*time.Millisecond, + initialTuples, + estimator, + 0.1, // Large epsilon for quick convergence + ) + + // Test convergence with very small delta + base := WeightTuple{WFreq: 0.1, WLA: 0.1} + _, ok := gridSearch.GenerateRefinedGrid(base, 1, 0.01) // Small delta + + if ok { + t.Error("Grid refinement should fail when delta is smaller than epsilon") + } +} + +func BenchmarkEstimatorCalculateRewriteScore(b *testing.B) { + estimator := &Estimator{ + WFreq: 0.1, + WLA: 0.2, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + estimator.CalculateRewriteScore( + uint64(i%100+1), // freq + uint64(i%1000+1), // lastAccess + uint32(i%10), // keyMemId + uint32((i+5)%10), // activeMemId + 10, // maxMemTableCount + ) + } +} + +func BenchmarkPredictorPredict(b *testing.B) { + config := PredictorConfig{ + ReWriteScoreThreshold: 0.5, + Weights: []WeightTuple{ + {WFreq: 0.1, WLA: 0.2}, + }, + SampleDuration: 100 * time.Millisecond, + MaxMemTableCount: 10, + GridSearchEpsilon: 0.001, + } + + predictor := NewPredictor(config) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + predictor.Predict( + uint64(i%100+1), // freq + uint64(i%1000+1), // lastAccess + uint32(i%10), // keyMemId + uint32((i+5)%10), // activeMemId + ) + } +} + +// Integration test that simulates a realistic cache scenario +func TestPredictorIntegration(t *testing.T) { + config := PredictorConfig{ + ReWriteScoreThreshold: 0.3, + Weights: []WeightTuple{ + {WFreq: 0.1, WLA: 0.1}, + {WFreq: 0.2, WLA: 0.2}, + {WFreq: 0.3, WLA: 0.3}, + }, + SampleDuration: 20 * time.Millisecond, + MaxMemTableCount: 8, + GridSearchEpsilon: 0.01, + } + + predictor := NewPredictor(config) + + // Simulate cache operations + type cacheOp struct { + freq uint64 + lastAccess uint64 + keyMemId uint32 + activeMemId uint32 + } + + operations := []cacheOp{ + {freq: 100, lastAccess: 1, keyMemId: 0, activeMemId: 7}, // Should rewrite + {freq: 1, lastAccess: 1000, keyMemId: 6, 
activeMemId: 7}, // Should not rewrite + {freq: 50, lastAccess: 10, keyMemId: 2, activeMemId: 6}, // Maybe rewrite + {freq: 200, lastAccess: 5, keyMemId: 1, activeMemId: 7}, // Should rewrite + } + + rewriteCount := 0 + for i, op := range operations { + shouldRewrite := predictor.Predict(op.freq, op.lastAccess, op.keyMemId, op.activeMemId) + if shouldRewrite { + rewriteCount++ + } + + // Simulate hit rate feedback + var hitRate float64 + if shouldRewrite { + hitRate = 0.8 + 0.1*float64(i%3) // Simulated good hit rate for rewrites + } else { + hitRate = 0.6 + 0.1*float64(i%2) // Simulated moderate hit rate for no rewrites + } + + predictor.Observe(hitRate) + + // Small delay to allow processing + time.Sleep(5 * time.Millisecond) + } + + // Should have made some rewrite decisions + if rewriteCount == 0 { + t.Error("Expected at least some rewrite decisions") + } + if rewriteCount == len(operations) { + t.Error("Should not rewrite everything") + } + + t.Logf("Made %d rewrites out of %d operations", rewriteCount, len(operations)) +} diff --git a/flashring/internal/memtables/manager.go b/flashring/internal/memtables/manager.go new file mode 100644 index 00000000..3227c2fb --- /dev/null +++ b/flashring/internal/memtables/manager.go @@ -0,0 +1,123 @@ +package memtables + +import ( + "github.com/Meesho/BharatMLStack/flashring/internal/allocators" + "github.com/Meesho/BharatMLStack/flashring/internal/fs" + "github.com/Meesho/BharatMLStack/flashring/pkg/metrics" + "github.com/rs/zerolog/log" +) + +type MemtableManager struct { + file *fs.WrapAppendFile + Capacity int32 + + memtable1 *Memtable + memtable2 *Memtable + activeMemtable *Memtable + nextFileOffset int64 + nextId uint32 + semaphore chan int +} + +// NewMemtableManager creates a double-buffered memtable pair. 
+// flushStaggerOffset advances the active memtable's write position so that +// different shards fill (and therefore flush) at staggered times, avoiding +// synchronized flush storms that compete with reads for NVMe bandwidth. +func NewMemtableManager(file *fs.WrapAppendFile, capacity int32, flushStaggerOffset int) (*MemtableManager, error) { + allocatorConfig := allocators.SlabAlignedPageAllocatorConfig{ + SizeClasses: []allocators.SizeClass{ + {Size: int(capacity), MinCount: 2}, + }, + } + allocator, err := allocators.NewSlabAlignedPageAllocator(allocatorConfig) + if err != nil { + return nil, err + } + page1 := allocator.Get(int(capacity)) + page2 := allocator.Get(int(capacity)) + memtable1, err := NewMemtable(MemtableConfig{ + capacity: int(capacity), + id: 0, + page: page1, + file: file, + }) + if err != nil { + return nil, err + } + memtable2, err := NewMemtable(MemtableConfig{ + capacity: int(capacity), + id: 1, + page: page2, + file: file, + }) + if err != nil { + return nil, err + } + // Pre-advance the active memtable so this shard's first flush happens + // earlier/later than its peers, spreading flush I/O over time. 
+ memtable1.currentOffset = flushStaggerOffset + memtable1.flushStartOffset = flushStaggerOffset + + memtableManager := &MemtableManager{ + file: file, + Capacity: capacity, + memtable1: memtable1, + memtable2: memtable2, + activeMemtable: memtable1, + nextFileOffset: 2 * int64(capacity), + nextId: 2, + semaphore: make(chan int, 1), + } + return memtableManager, nil +} + +func (mm *MemtableManager) GetMemtable() (*Memtable, uint32, uint64) { + return mm.activeMemtable, mm.activeMemtable.Id, uint64(mm.activeMemtable.Id) * uint64(mm.Capacity) +} + +func (mm *MemtableManager) GetMemtableById(id uint32) *Memtable { + if mm.memtable1.Id == id { + return mm.memtable1 + } + if mm.memtable2.Id == id { + return mm.memtable2 + } + return nil +} + +func (mm *MemtableManager) flushConsumer(memtable *Memtable) { + n, fileOffset, err := memtable.Flush() + if n != int(mm.Capacity) { + log.Error().Msgf("Flush size mismatch: memId:%d fileOffset:%d nextFileOffset:%d n:%d err:%v", memtable.Id, fileOffset, mm.nextFileOffset, n, err) + } + if err != nil { + log.Error().Msgf("Failed to flush memtable: memId:%d fileOffset:%d nextFileOffset:%d n:%d err:%v", memtable.Id, fileOffset, mm.nextFileOffset, n, err) + } + memtable.Id = mm.nextId + mm.nextId++ + mm.nextFileOffset += int64(n) + metrics.Incr(metrics.KEY_MEMTABLE_FLUSH_COUNT, append(metrics.GetShardTag(memtable.ShardIdx), metrics.GetMemtableTag(memtable.Id)...)) +} +func (mm *MemtableManager) Flush() error { + + memtableToFlush := mm.activeMemtable + mm.semaphore <- 1 + + // Swap to the other memtable + if mm.activeMemtable == mm.memtable1 { + mm.activeMemtable = mm.memtable2 + } else { + mm.activeMemtable = mm.memtable1 + } + go func() { + defer func() { + <-mm.semaphore + if r := recover(); r != nil { + log.Error().Msgf("Recovered from panic in goroutine: %v", r) + } + }() + mm.flushConsumer(memtableToFlush) + }() + + return nil +} diff --git a/flashring/internal/memtables/manager_bench_test.go 
b/flashring/internal/memtables/manager_bench_test.go
new file mode 100644
index 00000000..8e1b7406
--- /dev/null
+++ b/flashring/internal/memtables/manager_bench_test.go
@@ -0,0 +1,100 @@
+package memtables
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/Meesho/BharatMLStack/flashring/internal/fs"
+	"github.com/Meesho/BharatMLStack/flashring/internal/iouring"
+)
+
+// createManagerBenchmarkFile creates the backing file for a manager benchmark
+// and registers its teardown with b.
+//
+// The file is placed in the directory named by the FLASHRING_BENCH_DIR
+// environment variable (for example a dedicated NVMe mount); when that is
+// unset a per-benchmark temporary directory is used. An io_uring write ring
+// is attached the same way the unit-test helpers in this package do.
+func createManagerBenchmarkFile(b *testing.B) *fs.WrapAppendFile {
+	b.Helper()
+
+	dir := os.Getenv("FLASHRING_BENCH_DIR")
+	if dir == "" {
+		dir = b.TempDir()
+	}
+	filename := filepath.Join(dir, fmt.Sprintf("bench_memtable_%d.dat", time.Now().UnixNano()))
+
+	config := fs.FileConfig{
+		Filename:          filename,
+		MaxFileSize:       20 * 1024 * 1024 * 1024, // 20GB for benchmarks
+		FilePunchHoleSize: 1024 * 1024 * 1024,      // 1GB
+		BlockSize:         fs.BLOCK_SIZE,
+	}
+
+	file, err := fs.NewWrapAppendFile(config)
+	if err != nil {
+		b.Fatalf("Failed to create benchmark file: %v", err)
+	}
+
+	writeRing, err := iouring.NewIoUringWriter(32, 0)
+	if err != nil {
+		file.Close()
+		b.Fatalf("Failed to create io_uring write ring: %v", err)
+	}
+	file.WriteRing = writeRing
+
+	b.Cleanup(func() {
+		writeRing.Close()
+		file.Close()
+		os.Remove(filename)
+	})
+	return file
+}
+
+// Benchmark_Puts measures sustained 16 KiB Put throughput through a
+// MemtableManager with 1 GiB memtables, including the hand-off of full
+// memtables to the background flusher.
+func Benchmark_Puts(b *testing.B) {
+	file := createManagerBenchmarkFile(b)
+
+	manager, err := NewMemtableManager(file, 1024*1024*1024, 0)
+	if err != nil {
+		b.Fatalf("Failed to create memtable manager: %v", err)
+	}
+
+	buf16k := make([]byte, 16*1024)
+	for j := range buf16k {
+		buf16k[j] = byte(j % 256)
+	}
+
+	// SetBytes lets the testing package derive MB/s from the measured time.
+	b.SetBytes(int64(len(buf16k)))
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		memtable, _, _ := manager.GetMemtable()
+		if _, _, readyForFlush := memtable.Put(buf16k); !readyForFlush {
+			continue
+		}
+
+		// A full memtable rejects the record without storing it: swap to
+		// the other memtable and retry so every iteration writes 16 KiB.
+		if err := manager.Flush(); err != nil {
+			b.Fatalf("Flush failed: %v", err)
+		}
+		memtable, _, _ = manager.GetMemtable()
+		if _, _, readyForFlush := memtable.Put(buf16k); readyForFlush {
+			b.Fatalf("Put rejected by a freshly swapped memtable")
+		}
+	}
+
+	b.StopTimer()
+	// Take and release the flush slot so any in-flight background flush has
+	// finished before the cleanup closes the file.
+	manager.semaphore <- 1
+	<-manager.semaphore
+}
diff --git a/flashring/internal/memtables/manager_test.go b/flashring/internal/memtables/manager_test.go
new file mode 100644
index 00000000..4c931b9c
--- /dev/null
+++ b/flashring/internal/memtables/manager_test.go
@@ -0,0 +1,389 @@
+package memtables
+
+import (
+	"path/filepath"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/Meesho/BharatMLStack/flashring/internal/fs"
+	"github.com/Meesho/BharatMLStack/flashring/internal/iouring"
+)
+
+// Helper function to create a mock file
for testing +func createTestFileForManager(t *testing.T) *fs.WrapAppendFile { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_memtable_manager.dat") + + config := fs.FileConfig{ + Filename: filename, + MaxFileSize: 1024 * 1024, // 1MB + FilePunchHoleSize: 64 * 1024, // 64KB + BlockSize: fs.BLOCK_SIZE, + } + + file, err := fs.NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + writeRing, err := iouring.NewIoUringWriter(32, 0) + if err != nil { + t.Fatalf("Failed to create io_uring write ring: %v", err) + } + file.WriteRing = writeRing + return file +} + +func cleanupManagerFile(file *fs.WrapAppendFile) { + if file.WriteRing != nil { + file.WriteRing.Close() + } + file.Close() +} + +func TestNewMemtableManager_Success(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) // 8192 bytes + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + // Verify initial state + if manager.file != file { + t.Errorf("Expected file to be set correctly") + } + if manager.Capacity != capacity { + t.Errorf("Expected capacity %d, got %d", capacity, manager.Capacity) + } + if manager.memtable1 == nil { + t.Errorf("Expected memtable1 to be initialized") + } + if manager.memtable2 == nil { + t.Errorf("Expected memtable2 to be initialized") + } + if manager.activeMemtable != manager.memtable1 { + t.Errorf("Expected activeMemtable to be memtable1 initially") + } + if manager.nextFileOffset != 2*int64(capacity) { + t.Errorf("Expected nextFileOffset to be %d, got %d", 2*int64(capacity), manager.nextFileOffset) + } + if manager.nextId != 2 { + t.Errorf("Expected nextId to be 2, got %d", manager.nextId) + } + if cap(manager.semaphore) != 1 { + t.Errorf("Expected semaphore capacity to be 1, got %d", cap(manager.semaphore)) + } + + // Verify memtable initial IDs + if 
manager.memtable1.Id != 0 { + t.Errorf("Expected memtable1 ID to be 0, got %d", manager.memtable1.Id) + } + if manager.memtable2.Id != 1 { + t.Errorf("Expected memtable2 ID to be 1, got %d", manager.memtable2.Id) + } +} + +func TestNewMemtableManager_InvalidCapacity(t *testing.T) { + // Test with capacity not aligned to block size + capacity := int32(fs.BLOCK_SIZE + 1) // Should fail alignment check + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + _, err := NewMemtableManager(file, capacity, 0) + if err == nil { + t.Errorf("Expected NewMemtableManager to fail with invalid capacity") + } +} + +func TestNewMemtableManager_NilFile(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + + _, err := NewMemtableManager(nil, capacity, 0) + if err == nil { + t.Errorf("Expected NewMemtableManager to fail with nil file") + } +} + +func TestMemtableManager_GetMemtable(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + memtable, id, offset := manager.GetMemtable() + + // Initially should return memtable1 + if memtable != manager.memtable1 { + t.Errorf("Expected to get memtable1") + } + if id != 0 { + t.Errorf("Expected ID 0, got %d", id) + } + expectedOffset := uint64(0) * uint64(capacity) + if offset != expectedOffset { + t.Errorf("Expected offset %d, got %d", expectedOffset, offset) + } +} + +func TestMemtableManager_GetMemtableById(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + // Test getting memtable1 by ID + memtable := manager.GetMemtableById(0) + if memtable != manager.memtable1 { + t.Errorf("Expected to get memtable1 for ID 0") + 
} + + // Test getting memtable2 by ID + memtable = manager.GetMemtableById(1) + if memtable != manager.memtable2 { + t.Errorf("Expected to get memtable2 for ID 1") + } + + // Test getting non-existent memtable + memtable = manager.GetMemtableById(999) + if memtable != nil { + t.Errorf("Expected nil for non-existent ID, got %v", memtable) + } +} + +func TestMemtableManager_Flush(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + // Verify initial state + originalActive := manager.activeMemtable + originalNextId := manager.nextId + + // Perform flush + err = manager.Flush() + if err != nil { + t.Fatalf("Flush failed: %v", err) + } + + // Verify active memtable swapped + if manager.activeMemtable == originalActive { + t.Errorf("Expected active memtable to be swapped") + } + + // Active should now be the other memtable + if originalActive == manager.memtable1 { + if manager.activeMemtable != manager.memtable2 { + t.Errorf("Expected active memtable to be memtable2") + } + } else { + if manager.activeMemtable != manager.memtable1 { + t.Errorf("Expected active memtable to be memtable1") + } + } + + // Give time for background goroutine to complete + time.Sleep(100 * time.Millisecond) + + // Verify nextId was incremented (this happens in background) + if manager.nextId <= originalNextId { + t.Errorf("Expected nextId to be incremented, got %d, expected > %d", manager.nextId, originalNextId) + } +} + +func TestMemtableManager_FlushSwapsBetweenMemtables(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + // Initially active is memtable1 + if manager.activeMemtable != 
manager.memtable1 { + t.Fatalf("Expected initial active to be memtable1") + } + + // First flush - should swap to memtable2 + err = manager.Flush() + if err != nil { + t.Fatalf("First flush failed: %v", err) + } + if manager.activeMemtable != manager.memtable2 { + t.Errorf("Expected active to be memtable2 after first flush") + } + + // Second flush - should swap back to memtable1 + err = manager.Flush() + if err != nil { + t.Fatalf("Second flush failed: %v", err) + } + if manager.activeMemtable != manager.memtable1 { + t.Errorf("Expected active to be memtable1 after second flush") + } +} + +func TestMemtableManager_FlushConcurrency(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + const numConcurrentFlushes = 10 + var wg sync.WaitGroup + errors := make(chan error, numConcurrentFlushes) + + // Launch multiple concurrent flushes + for i := 0; i < numConcurrentFlushes; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := manager.Flush(); err != nil { + errors <- err + } + }() + } + + wg.Wait() + close(errors) + + // Check for errors + for err := range errors { + t.Errorf("Concurrent flush failed: %v", err) + } + + // Give time for all background operations to complete + time.Sleep(200 * time.Millisecond) + + // Verify manager is still in a valid state + memtable, id, offset := manager.GetMemtable() + if memtable == nil { + t.Errorf("Active memtable should not be nil") + } + if id != memtable.Id { + t.Errorf("Returned ID %d should match memtable ID %d", id, memtable.Id) + } + expectedOffset := uint64(memtable.Id) * uint64(capacity) + if offset != expectedOffset { + t.Errorf("Expected offset %d, got %d", expectedOffset, offset) + } +} + +func TestMemtableManager_GetMemtableAfterFlush(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + file := 
createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + // Get initial memtable + initialMemtable, initialId, _ := manager.GetMemtable() + + // Perform flush + err = manager.Flush() + if err != nil { + t.Fatalf("Flush failed: %v", err) + } + + // Get memtable after flush + newMemtable, newId, newOffset := manager.GetMemtable() + + // Should be different memtable + if newMemtable == initialMemtable { + t.Errorf("Expected different memtable after flush") + } + if newId == initialId { + t.Errorf("Expected different ID after flush") + } + + // Offset calculation should be correct + expectedOffset := uint64(newId) * uint64(capacity) + if newOffset != expectedOffset { + t.Errorf("Expected offset %d, got %d", expectedOffset, newOffset) + } +} + +func TestMemtableManager_Integration(t *testing.T) { + capacity := int32(fs.BLOCK_SIZE * 2) + file := createTestFileForManager(t) + defer cleanupManagerFile(file) + + manager, err := NewMemtableManager(file, capacity, 0) + if err != nil { + t.Fatalf("NewMemtableManager failed: %v", err) + } + + // Test complete workflow: get memtable, put data, flush, repeat + testData := []byte("Hello, MemtableManager!") + + // Get initial memtable and put some data + memtable, id, _ := manager.GetMemtable() + offset, length, readyForFlush := memtable.Put(testData) + if readyForFlush { + t.Errorf("Memtable should not be ready for flush after small put") + } + + // Verify data can be retrieved + data, err := memtable.Get(offset, length) + if err != nil { + t.Fatalf("Failed to get data: %v", err) + } + if string(data) != string(testData) { + t.Errorf("Expected %s, got %s", testData, data) + } + + // Verify GetMemtableById works + retrievedMemtable := manager.GetMemtableById(id) + if retrievedMemtable != memtable { + t.Errorf("GetMemtableById should return the same memtable") + } + + // Perform flush and verify 
state changes
+	err = manager.Flush()
+	if err != nil {
+		t.Fatalf("Flush failed: %v", err)
+	}
+
+	// Get new active memtable
+	newMemtable, newId, _ := manager.GetMemtable()
+	if newMemtable == memtable {
+		t.Errorf("Active memtable should have changed after flush")
+	}
+	if newId == id {
+		t.Errorf("Active memtable ID should have changed after flush")
+	}
+
+	// Old memtable should still be retrievable by its original ID
+	oldMemtable := manager.GetMemtableById(id)
+	if oldMemtable != memtable {
+		t.Errorf("Should still be able to retrieve old memtable by ID")
+	}
+
+	// Give background flush time to complete
+	time.Sleep(100 * time.Millisecond)
+}
diff --git a/flashring/internal/memtables/memtable.go b/flashring/internal/memtables/memtable.go
new file mode 100644
index 00000000..7eb1ef49
--- /dev/null
+++ b/flashring/internal/memtables/memtable.go
@@ -0,0 +1,181 @@
+package memtables
+
+import (
+	"errors"
+
+	"github.com/Meesho/BharatMLStack/flashring/internal/fs"
+)
+
+var (
+	ErrCapacityNotAligned         = errors.New("capacity must be aligned to block size")
+	ErrPageNotProvided            = errors.New("page must be provided")
+	ErrFileNotProvided            = errors.New("file must be provided")
+	ErrPageBufferCapacityMismatch = errors.New("page buffer must be provided and must be of size capacity")
+	ErrOffsetOutOfBounds          = errors.New("offset out of bounds")
+	ErrMemtableNotReadyForFlush   = errors.New("memtable not ready for flush")
+)
+
+// Memtable is a fixed-capacity, append-only buffer backed by one aligned page.
+// Records are appended with Put or GetBufForAppend until one no longer fits;
+// the memtable is then marked ready and Flush writes the page to the backing
+// file. No method takes a lock.
+type Memtable struct {
+	Id            uint32
+	capacity      int
+	currentOffset int
+	file          *fs.WrapAppendFile
+	page          *fs.AlignedPage
+	readyForFlush bool
+	next          *Memtable
+	prev          *Memtable
+	ShardIdx      uint32
+
+	// flushStartOffset is the byte offset within the page where real data
+	// begins. On the first (staggered) memtable this equals the stagger
+	// offset so Flush() skips the uninitialized region. Reset to 0 after
+	// the first flush.
+	flushStartOffset int
+}
+
+// MemtableConfig carries the construction parameters for NewMemtable.
+type MemtableConfig struct {
+	capacity int
+	id       uint32
+	page     *fs.AlignedPage
+	file     *fs.WrapAppendFile
+	shardIdx uint32
+}
+
+// NewMemtable validates config and returns an empty memtable over config.page.
+//
+// The checks run in this order: ErrCapacityNotAligned when capacity is not a
+// multiple of fs.BLOCK_SIZE, ErrPageNotProvided for a nil page,
+// ErrFileNotProvided for a nil file, and ErrPageBufferCapacityMismatch when
+// the page buffer is nil or its length differs from capacity.
+func NewMemtable(config MemtableConfig) (*Memtable, error) {
+	switch {
+	case config.capacity%fs.BLOCK_SIZE != 0:
+		return nil, ErrCapacityNotAligned
+	case config.page == nil:
+		return nil, ErrPageNotProvided
+	case config.file == nil:
+		return nil, ErrFileNotProvided
+	case config.page.Buf == nil || len(config.page.Buf) != config.capacity:
+		return nil, ErrPageBufferCapacityMismatch
+	}
+	return &Memtable{
+		Id:       config.id,
+		ShardIdx: config.shardIdx,
+		capacity: config.capacity,
+		file:     config.file,
+		page:     config.page,
+	}, nil
+}
+
+// Get returns the length bytes stored at offset. The result aliases the page
+// buffer; no copy is made.
+//
+// ErrOffsetOutOfBounds is returned when offset is negative (Put and
+// GetBufForAppend report -1 for a record that did not fit) or when the
+// requested range runs past capacity.
+func (m *Memtable) Get(offset int, length uint16) ([]byte, error) {
+	buf, ok := m.GetBufForRead(offset, length)
+	if !ok {
+		return nil, ErrOffsetOutOfBounds
+	}
+	return buf, nil
+}
+
+// Put copies buf to the current write position and advances it.
+//
+// When buf does not fit in the remaining space nothing is copied, the
+// memtable is marked ready for flush, and (-1, 0, true) is returned. Otherwise
+// the record's offset and length are returned with readyForFlush false.
+//
+// length is uint16(len(buf)): for buffers longer than 65535 bytes the whole
+// buffer is still copied but the reported length wraps.
+func (m *Memtable) Put(buf []byte) (offset int, length uint16, readyForFlush bool) {
+	offset = m.currentOffset
+	if offset+len(buf) > m.capacity {
+		m.readyForFlush = true
+		return -1, 0, true
+	}
+	copy(m.page.Buf[offset:], buf)
+	m.currentOffset += len(buf)
+	return offset, uint16(len(buf)), false
+}
+
+// GetBufForAppend reserves size bytes at the current write position and
+// returns that region of the page so the caller can fill it in place,
+// avoiding the copy Put performs.
+//
+// When size does not fit, the memtable is marked ready for flush and
+// (nil, -1, 0, true) is returned.
+func (m *Memtable) GetBufForAppend(size uint16) (bbuf []byte, offset int, length uint16, readyForFlush bool) {
+	offset = m.currentOffset
+	end := offset + int(size)
+	if end > m.capacity {
+		m.readyForFlush = true
+		return nil, -1, 0, true
+	}
+	m.currentOffset = end
+	return m.page.Buf[offset:end], offset, size, false
+}
+
+// GetBufForRead returns the length bytes stored at offset as a slice aliasing
+// the page buffer. exists is false when offset is negative or the range runs
+// past capacity.
+func (m *Memtable) GetBufForRead(offset int, length uint16) (bbuf []byte, exists bool) {
+	end := offset + int(length)
+	if offset < 0 || end > m.capacity {
+		return nil, false
+	}
+	return m.page.Buf[offset:end], true
+}
+
+// Flush writes the page to the backing file in fs.BLOCK_SIZE chunks and
+// resets the memtable for reuse.
+//
+// It returns ErrMemtableNotReadyForFlush unless a Put or GetBufForAppend has
+// marked the memtable ready, and passes PwriteBatch errors through; in both
+// cases n and fileOffset are 0 and the memtable's own fields are unchanged.
+// On success n is the skipped stagger region plus the byte count PwriteBatch
+// reports, and fileOffset is the offset PwriteBatch reports.
+func (m *Memtable) Flush() (n int, fileOffset int64, err error) {
+	if !m.readyForFlush {
+		return 0, 0, ErrMemtableNotReadyForFlush
+	}
+
+	chunkSize := fs.BLOCK_SIZE
+
+	// When the memtable has a stagger offset (first cycle only), skip the
+	// uninitialized region: advance the file's write pointer past it, then
+	// write only the real data. Total file advancement = flushStartOffset +
+	// len(usedBuf) = capacity, preserving the memId*capacity layout.
+	//
+	// NOTE(review): if PwriteBatch fails, the write pointer has already been
+	// advanced and flushStartOffset is still set, so retrying Flush would
+	// advance it a second time — confirm callers never retry.
+	startOff := m.flushStartOffset
+	if startOff > 0 {
+		m.file.AdvanceWriteOffset(int64(startOff))
+	}
+
+	// PwriteBatch submits all chunks via io_uring.
+	totalWritten, fileOffset, err := m.file.PwriteBatch(m.page.Buf[startOff:], chunkSize)
+	if err != nil {
+		return 0, 0, err
+	}
+
+	m.currentOffset = 0
+	m.readyForFlush = false
+	m.flushStartOffset = 0 // subsequent flushes write the full page
+	return startOff + totalWritten, fileOffset, nil
+}
+
+// Discard drops the memtable's references to its file and page. It does not
+// close the file or release the page.
+func (m *Memtable) Discard() {
+	m.file = nil
+	m.page = nil
+}
diff --git a/flashring/internal/memtables/memtable_bench_test.go b/flashring/internal/memtables/memtable_bench_test.go
new file mode 100644
index 00000000..40175e62
--- /dev/null
+++ b/flashring/internal/memtables/memtable_bench_test.go
@@ -0,0 +1,580 @@
+// Benchmark tests for Memtable operations optimized for single-threaded performance
+// Uses 50GB max file size and 1GB memtable page size as specified
+package memtables
+
+import (
+	"crypto/rand"
+	"fmt"
+	"path/filepath"
+	"testing"
+
+	"github.com/Meesho/BharatMLStack/flashring/internal/fs"
+)
+
+const (
+	// Configuration for single-threaded benchmarks
+	BENCH_MAX_FILE_SIZE   = 50 * 1024 * 1024 * 1024 // 50GB max file size
+	BENCH_PAGE_SIZE       = 1024 * 1024 * 1024      // 1GB memtable page size
+	BENCH_PUNCH_HOLE_SIZE = 64 * 1024 * 1024        // 64MB punch hole size
+
+	// Data sizes for single-threaded performance testing
+	SMALL_DATA_SIZE      = 256         // 256 bytes - typical small record
+	MEDIUM_DATA_SIZE     = 4096        // 4KB - typical medium record
+	LARGE_DATA_SIZE      = 64 * 1024   // 64KB - large record
+	VERY_LARGE_DATA_SIZE = 1024 * 1024 // 1MB - very large record
+)
+
+// Helper function to
create benchmark file +func createBenchmarkFile(b *testing.B) *fs.WrapAppendFile { + filename := filepath.Join("/media/a0d00kc/freedom/tmp/bench_memtable.dat") + + config := fs.FileConfig{ + Filename: filename, + MaxFileSize: BENCH_MAX_FILE_SIZE, + FilePunchHoleSize: BENCH_PUNCH_HOLE_SIZE, + BlockSize: fs.BLOCK_SIZE, + } + + file, err := fs.NewWrapAppendFile(config) + if err != nil { + b.Fatalf("Failed to create benchmark file: %v", err) + } + return file +} + +// Helper function to create benchmark page +func createBenchmarkPage() *fs.AlignedPage { + return fs.NewAlignedPage(BENCH_PAGE_SIZE) +} + +// Helper function to create benchmark memtable +func createBenchmarkMemtable(b *testing.B) (*Memtable, *fs.WrapAppendFile, *fs.AlignedPage) { + file := createBenchmarkFile(b) + page := createBenchmarkPage() + + config := MemtableConfig{ + capacity: BENCH_PAGE_SIZE, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + cleanup(file, page) + b.Fatalf("Failed to create benchmark memtable: %v", err) + } + + return memtable, file, page +} + +// Helper function to generate random data +func generateRandomData(size int) []byte { + data := make([]byte, size) + rand.Read(data) + return data +} + +// Benchmark Put operations with different data sizes +func BenchmarkMemtable_Put_Small(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + data := generateRandomData(SMALL_DATA_SIZE) + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(SMALL_DATA_SIZE) + + for i := 0; i < b.N; i++ { + if memtable.readyForFlush { + // Reset memtable for continued benchmarking + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + + _, _, readyForFlush := memtable.Put(data) + if readyForFlush { + // Don't count flush operations in this benchmark + b.StopTimer() + memtable.currentOffset = 0 + memtable.readyForFlush = false + b.StartTimer() + } + } +} + +func BenchmarkMemtable_Put_Medium(b 
*testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + data := generateRandomData(MEDIUM_DATA_SIZE) + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(MEDIUM_DATA_SIZE) + + for i := 0; i < b.N; i++ { + if memtable.readyForFlush { + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + + _, _, readyForFlush := memtable.Put(data) + if readyForFlush { + b.StopTimer() + memtable.currentOffset = 0 + memtable.readyForFlush = false + b.StartTimer() + } + } +} + +func BenchmarkMemtable_Put_Large(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + data := generateRandomData(LARGE_DATA_SIZE) + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(LARGE_DATA_SIZE) + + for i := 0; i < b.N; i++ { + if memtable.readyForFlush { + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + + _, _, readyForFlush := memtable.Put(data) + if readyForFlush { + b.StopTimer() + memtable.currentOffset = 0 + memtable.readyForFlush = false + b.StartTimer() + } + } +} + +func BenchmarkMemtable_Put_VeryLarge(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + data := generateRandomData(VERY_LARGE_DATA_SIZE) + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(VERY_LARGE_DATA_SIZE) + + for i := 0; i < b.N; i++ { + if memtable.readyForFlush { + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + + _, _, readyForFlush := memtable.Put(data) + if readyForFlush { + b.StopTimer() + memtable.currentOffset = 0 + memtable.readyForFlush = false + b.StartTimer() + } + } +} + +// Benchmark Get operations +func BenchmarkMemtable_Get_Small(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + // Pre-populate memtable with data + data := generateRandomData(SMALL_DATA_SIZE) + numEntries := BENCH_PAGE_SIZE / SMALL_DATA_SIZE / 2 // Fill half the memtable + + offsets := make([]int, numEntries) + 
lengths := make([]uint16, numEntries) + + for i := 0; i < numEntries; i++ { + offset, length, _ := memtable.Put(data) + offsets[i] = offset + lengths[i] = length + } + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(SMALL_DATA_SIZE) + + for i := 0; i < b.N; i++ { + idx := i % numEntries + _, err := memtable.Get(offsets[idx], lengths[idx]) + if err != nil { + b.Fatalf("Get failed: %v", err) + } + } +} + +func BenchmarkMemtable_Get_Medium(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + data := generateRandomData(MEDIUM_DATA_SIZE) + numEntries := BENCH_PAGE_SIZE / MEDIUM_DATA_SIZE / 2 + + offsets := make([]int, numEntries) + lengths := make([]uint16, numEntries) + + for i := 0; i < numEntries; i++ { + offset, length, _ := memtable.Put(data) + offsets[i] = offset + lengths[i] = length + } + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(MEDIUM_DATA_SIZE) + + for i := 0; i < b.N; i++ { + idx := i % numEntries + _, err := memtable.Get(offsets[idx], lengths[idx]) + if err != nil { + b.Fatalf("Get failed: %v", err) + } + } +} + +func BenchmarkMemtable_Get_Large(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + data := generateRandomData(LARGE_DATA_SIZE) + numEntries := BENCH_PAGE_SIZE / LARGE_DATA_SIZE / 2 + + offsets := make([]int, numEntries) + lengths := make([]uint16, numEntries) + + for i := 0; i < numEntries; i++ { + offset, length, _ := memtable.Put(data) + offsets[i] = offset + lengths[i] = length + } + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(LARGE_DATA_SIZE) + + for i := 0; i < b.N; i++ { + idx := i % numEntries + _, err := memtable.Get(offsets[idx], lengths[idx]) + if err != nil { + b.Fatalf("Get failed: %v", err) + } + } +} + +// Benchmark Flush operations +func BenchmarkMemtable_Flush(b *testing.B) { + file := createBenchmarkFile(b) + defer cleanup(file, nil) + + // Create fresh memtable for each iteration + page := createBenchmarkPage() + config 
:= MemtableConfig{ + capacity: BENCH_PAGE_SIZE, + id: uint32(0), + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + b.Fatalf("Failed to create memtable: %v", err) + } + + // Fill memtable to near capacity then trigger flush with overflow + fillData := generateRandomData(BENCH_PAGE_SIZE - 1000) + memtable.Put(fillData) + + // Now add data that will exceed capacity to trigger flush + overflowData := generateRandomData(2000) // This will exceed capacity + _, _, readyForFlush := memtable.Put(overflowData) + if !readyForFlush { + b.Fatalf("Failed to trigger flush - memtable should be ready for flush") + } + b.ReportAllocs() + b.SetBytes(BENCH_PAGE_SIZE) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + _, _, err = memtable.Flush() + if err != nil { + b.Fatalf("Flush failed: %v", err) + } + // Force re-flush same data in each iteration + memtable.readyForFlush = true + } + fs.Unmap(page) +} + +// Benchmark mixed operations (realistic usage pattern) +func BenchmarkMemtable_MixedOperations(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + // Pre-populate with some data + initialData := generateRandomData(MEDIUM_DATA_SIZE) + numInitial := 1000 + offsets := make([]int, numInitial) + lengths := make([]uint16, numInitial) + + for i := 0; i < numInitial; i++ { + offset, length, readyForFlush := memtable.Put(initialData) + if readyForFlush { + break + } + offsets[i] = offset + lengths[i] = length + } + + putData := generateRandomData(SMALL_DATA_SIZE) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Mix of operations: 70% gets, 30% puts + if i%10 < 7 { + // Get operation + idx := i % len(offsets) + if idx < len(offsets) && lengths[idx] > 0 { + _, err := memtable.Get(offsets[idx], lengths[idx]) + if err != nil && err != ErrOffsetOutOfBounds { + b.Fatalf("Get failed: %v", err) + } + } + } else { + // Put operation + if memtable.readyForFlush { + // Reset for 
continued benchmarking + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + memtable.Put(putData) + } + } +} + +// Benchmark sequential writes to measure throughput +func BenchmarkMemtable_SequentialWrites(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + data := generateRandomData(MEDIUM_DATA_SIZE) + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(MEDIUM_DATA_SIZE) + + for i := 0; i < b.N; i++ { + if memtable.readyForFlush { + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + + _, _, readyForFlush := memtable.Put(data) + if readyForFlush { + b.StopTimer() + memtable.currentOffset = 0 + memtable.readyForFlush = false + b.StartTimer() + } + } +} + +// Benchmark random access patterns +func BenchmarkMemtable_RandomAccess(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + // Pre-populate memtable + data := generateRandomData(SMALL_DATA_SIZE) + numEntries := BENCH_PAGE_SIZE / SMALL_DATA_SIZE / 4 // Fill quarter of memtable + + offsets := make([]int, numEntries) + lengths := make([]uint16, numEntries) + + for i := 0; i < numEntries; i++ { + offset, length, _ := memtable.Put(data) + offsets[i] = offset + lengths[i] = length + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Random access pattern + idx := (i * 7919) % numEntries // Use prime number for better distribution + _, err := memtable.Get(offsets[idx], lengths[idx]) + if err != nil { + b.Fatalf("Get failed: %v", err) + } + } +} + +// Benchmark memory copying efficiency +func BenchmarkMemtable_MemoryCopy(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + // Test different copy sizes + sizes := []int{64, 256, 1024, 4096, 16384, 65536} + + for _, size := range sizes { + b.Run(fmt.Sprintf("Size%d", size), func(b *testing.B) { + data := generateRandomData(size) + + b.ResetTimer() + b.ReportAllocs() + 
b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + if memtable.readyForFlush { + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + + _, _, readyForFlush := memtable.Put(data) + if readyForFlush { + b.StopTimer() + memtable.currentOffset = 0 + memtable.readyForFlush = false + b.StartTimer() + } + } + }) + } +} + +// Benchmark full memtable lifecycle +func BenchmarkMemtable_FullLifecycle(b *testing.B) { + file := createBenchmarkFile(b) + defer cleanup(file, nil) + + entrySize := MEDIUM_DATA_SIZE + entriesPerMemtable := BENCH_PAGE_SIZE / entrySize + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(int64(entriesPerMemtable * entrySize)) + + for i := 0; i < b.N; i++ { + // Create memtable + page := createBenchmarkPage() + config := MemtableConfig{ + capacity: BENCH_PAGE_SIZE, + id: uint32(i), + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + b.Fatalf("Failed to create memtable: %v", err) + } + + // Fill memtable to near capacity then trigger flush with overflow + fillData := generateRandomData(BENCH_PAGE_SIZE - 1000) + memtable.Put(fillData) + + // Add data that will exceed capacity to trigger flush + overflowData := generateRandomData(2000) + _, _, readyForFlush := memtable.Put(overflowData) + if !readyForFlush { + b.Fatalf("Failed to trigger flush in lifecycle test") + } + + // Flush + _, _, err = memtable.Flush() + if err != nil { + b.Fatalf("Flush failed: %v", err) + } + + // Cleanup + memtable.Discard() + fs.Unmap(page) + } +} + +// Benchmark single-threaded workload patterns (read-heavy, write-heavy, mixed) +func BenchmarkMemtable_SingleThreadedWorkload(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + // Pre-populate with test data + data := generateRandomData(SMALL_DATA_SIZE) + numEntries := 10000 + offsets := make([]int, numEntries) + lengths := make([]uint16, numEntries) + validEntries := 0 + + for i := 0; i < numEntries; i++ { + offset, length, 
readyForFlush := memtable.Put(data) + if readyForFlush { + break + } + offsets[validEntries] = offset + lengths[validEntries] = length + validEntries++ + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Single-threaded workload pattern: 80% reads, 20% writes + if i%5 < 4 { + // Read operation (80%) + if validEntries > 0 { + idx := i % validEntries + memtable.Get(offsets[idx], lengths[idx]) + } + } else { + // Write operation (20%) - only if space available + if !memtable.readyForFlush { + memtable.Put(data) + } + } + } +} + +// Benchmark CPU-intensive single-threaded operations +func BenchmarkMemtable_CPUIntensive(b *testing.B) { + memtable, file, page := createBenchmarkMemtable(b) + defer cleanup(file, page) + + // Use medium-sized data for CPU-intensive operations + data := generateRandomData(MEDIUM_DATA_SIZE) + + b.ResetTimer() + b.ReportAllocs() + b.SetBytes(MEDIUM_DATA_SIZE) + + for i := 0; i < b.N; i++ { + if memtable.readyForFlush { + // Reset for continued benchmarking + memtable.currentOffset = 0 + memtable.readyForFlush = false + } + + // Perform put operation + offset, length, readyForFlush := memtable.Put(data) + if !readyForFlush { + // Immediately read back the data to stress CPU + _, err := memtable.Get(offset, length) + if err != nil { + b.Fatalf("Get failed: %v", err) + } + } + } +} diff --git a/flashring/internal/memtables/memtable_test.go b/flashring/internal/memtables/memtable_test.go new file mode 100644 index 00000000..4f60d07b --- /dev/null +++ b/flashring/internal/memtables/memtable_test.go @@ -0,0 +1,604 @@ +package memtables + +import ( + "path/filepath" + "testing" + + "github.com/Meesho/BharatMLStack/flashring/internal/fs" + "github.com/Meesho/BharatMLStack/flashring/internal/iouring" +) + +// Helper function to create a mock file for testing +func createTestFile(t *testing.T) *fs.WrapAppendFile { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test_memtable.dat") + + config := fs.FileConfig{ + 
Filename: filename, + MaxFileSize: 1024 * 1024, // 1MB + FilePunchHoleSize: 64 * 1024, // 64KB + BlockSize: fs.BLOCK_SIZE, + } + + file, err := fs.NewWrapAppendFile(config) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + writeRing, err := iouring.NewIoUringWriter(32, 0) + if err != nil { + t.Fatalf("Failed to create io_uring write ring: %v", err) + } + file.WriteRing = writeRing + return file +} + +// Helper function to create a test page +func createTestPage(size int) *fs.AlignedPage { + return fs.NewAlignedPage(size) +} + +// Helper function to cleanup resources +func cleanup(file *fs.WrapAppendFile, page *fs.AlignedPage) { + if file != nil { + if file.WriteRing != nil { + file.WriteRing.Close() + } + file.Close() + } + if page != nil { + fs.Unmap(page) + } +} + +func TestNewMemtable_Success(t *testing.T) { + capacity := fs.BLOCK_SIZE * 2 // 8192 bytes + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + if memtable.Id != 1 { + t.Errorf("Expected Id 1, got %d", memtable.Id) + } + if memtable.capacity != capacity { + t.Errorf("Expected capacity %d, got %d", capacity, memtable.capacity) + } + if memtable.currentOffset != 0 { + t.Errorf("Expected currentOffset 0, got %d", memtable.currentOffset) + } + if memtable.readyForFlush != false { + t.Errorf("Expected readyForFlush false, got %v", memtable.readyForFlush) + } +} + +func TestNewMemtable_CapacityNotAligned(t *testing.T) { + capacity := fs.BLOCK_SIZE + 100 // Not aligned to block size + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + _, err := NewMemtable(config) + if err != ErrCapacityNotAligned { + 
t.Errorf("Expected ErrCapacityNotAligned, got %v", err) + } +} + +func TestNewMemtable_PageNotProvided(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + defer cleanup(file, nil) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: nil, + file: file, + } + + _, err := NewMemtable(config) + if err != ErrPageNotProvided { + t.Errorf("Expected ErrPageNotProvided, got %v", err) + } +} + +func TestNewMemtable_FileNotProvided(t *testing.T) { + capacity := fs.BLOCK_SIZE + page := createTestPage(capacity) + defer cleanup(nil, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: nil, + } + + _, err := NewMemtable(config) + if err != ErrFileNotProvided { + t.Errorf("Expected ErrFileNotProvided, got %v", err) + } +} + +func TestNewMemtable_PageBufferCapacityMismatch(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity * 2) // Wrong size + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + _, err := NewMemtable(config) + if err != ErrPageBufferCapacityMismatch { + t.Errorf("Expected ErrPageBufferCapacityMismatch, got %v", err) + } +} + +func TestNewMemtable_PageBufferNil(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + defer cleanup(file, nil) + + // Create page with nil buffer + page := &fs.AlignedPage{Buf: nil} + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + _, err := NewMemtable(config) + if err != ErrPageBufferCapacityMismatch { + t.Errorf("Expected ErrPageBufferCapacityMismatch, got %v", err) + } +} + +func TestMemtable_Get_Success(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if 
err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Write some test data to the page buffer + testData := []byte("Hello, World!") + copy(page.Buf[:len(testData)], testData) + + // Get the data + result, err := memtable.Get(0, uint16(len(testData))) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if string(result) != string(testData) { + t.Errorf("Expected %s, got %s", testData, result) + } +} + +func TestMemtable_Get_OffsetOutOfBounds(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Try to get data beyond capacity + _, err = memtable.Get(capacity-10, 20) + if err != ErrOffsetOutOfBounds { + t.Errorf("Expected ErrOffsetOutOfBounds, got %v", err) + } +} + +func TestMemtable_Put_Success(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + testData := []byte("Hello, World!") + offset, length, readyForFlush := memtable.Put(testData) + + if offset != 0 { + t.Errorf("Expected offset 0, got %d", offset) + } + if length != uint16(len(testData)) { + t.Errorf("Expected length %d, got %d", len(testData), length) + } + if readyForFlush { + t.Errorf("Expected readyForFlush false, got %v", readyForFlush) + } + if memtable.currentOffset != len(testData) { + t.Errorf("Expected currentOffset %d, got %d", len(testData), memtable.currentOffset) + } + + // Verify data was written to buffer + result, err := memtable.Get(0, uint16(len(testData))) + if err != nil { + t.Fatalf("Get failed: %v", err) 
+ } + if string(result) != string(testData) { + t.Errorf("Expected %s, got %s", testData, result) + } +} + +func TestMemtable_Put_ExceedsCapacity(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Fill the memtable to near capacity + testData := make([]byte, capacity-100) + _, _, _ = memtable.Put(testData) + + // Try to put data that exceeds capacity + largeData := make([]byte, 200) + offset, length, readyForFlush := memtable.Put(largeData) + + if offset != -1 { + t.Errorf("Expected offset -1, got %d", offset) + } + if length != 0 { + t.Errorf("Expected length 0, got %d", length) + } + if !readyForFlush { + t.Errorf("Expected readyForFlush true, got %v", readyForFlush) + } + if !memtable.readyForFlush { + t.Errorf("Expected memtable.readyForFlush true, got %v", memtable.readyForFlush) + } +} + +func TestMemtable_Put_MultiplePuts(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Put multiple pieces of data + data1 := []byte("First") + data2 := []byte("Second") + data3 := []byte("Third") + + offset1, length1, _ := memtable.Put(data1) + offset2, length2, _ := memtable.Put(data2) + offset3, length3, _ := memtable.Put(data3) + + if offset1 != 0 { + t.Errorf("Expected offset1 0, got %d", offset1) + } + if offset2 != len(data1) { + t.Errorf("Expected offset2 %d, got %d", len(data1), offset2) + } + if offset3 != len(data1)+len(data2) { + t.Errorf("Expected offset3 %d, got %d", len(data1)+len(data2), 
offset3) + } + + // Verify all data can be retrieved + result1, err := memtable.Get(offset1, length1) + if err != nil || string(result1) != string(data1) { + t.Errorf("Failed to retrieve data1: %v", err) + } + + result2, err := memtable.Get(offset2, length2) + if err != nil || string(result2) != string(data2) { + t.Errorf("Failed to retrieve data2: %v", err) + } + + result3, err := memtable.Get(offset3, length3) + if err != nil || string(result3) != string(data3) { + t.Errorf("Failed to retrieve data3: %v", err) + } +} + +func TestMemtable_Flush_Success(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Fill the memtable to trigger ready for flush + testData := make([]byte, capacity-100) + memtable.Put(testData) + + // Put data that exceeds capacity to trigger ready for flush + memtable.Put(make([]byte, 200)) + + if !memtable.readyForFlush { + t.Fatalf("Expected memtable to be ready for flush") + } + + n, fileOffset, err := memtable.Flush() + if err != nil { + t.Fatalf("Flush failed: %v", err) + } + + if n != len(page.Buf) { + t.Errorf("Expected n %d, got %d", len(page.Buf), n) + } + if fileOffset < 0 { + t.Errorf("Expected positive fileOffset, got %d", fileOffset) + } + if memtable.readyForFlush { + t.Errorf("Expected readyForFlush to be false after flush, got %v", memtable.readyForFlush) + } +} + +func TestMemtable_Flush_NotReadyForFlush(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Try to 
flush without being ready + _, _, err = memtable.Flush() + if err != ErrMemtableNotReadyForFlush { + t.Errorf("Expected ErrMemtableNotReadyForFlush, got %v", err) + } +} + +func TestMemtable_Discard(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + memtable.Discard() + + if memtable.file != nil { + t.Errorf("Expected file to be nil after discard") + } + if memtable.page != nil { + t.Errorf("Expected page to be nil after discard") + } +} + +func TestMemtable_Integration(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 42, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Test complete workflow: multiple puts, get, trigger flush, and flush + testCases := [][]byte{ + []byte("First entry"), + []byte("Second entry with more data"), + []byte("Third entry"), + } + + var offsets []int + var lengths []uint16 + + // Put multiple entries + for i, data := range testCases { + offset, length, readyForFlush := memtable.Put(data) + if readyForFlush { + t.Logf("Memtable ready for flush after entry %d", i) + break + } + offsets = append(offsets, offset) + lengths = append(lengths, length) + } + + // Verify all entries can be retrieved + for i := range offsets { + result, err := memtable.Get(offsets[i], lengths[i]) + if err != nil { + t.Fatalf("Get failed for entry %d: %v", i, err) + } + if string(result) != string(testCases[i]) { + t.Errorf("Entry %d mismatch: expected %s, got %s", i, testCases[i], result) + } + } + + // Fill up the memtable to trigger ready for flush 
+ for !memtable.readyForFlush { + memtable.Put([]byte("filler")) + } + + // Test flush + n, fileOffset, err := memtable.Flush() + if err != nil { + t.Fatalf("Flush failed: %v", err) + } + + if n != capacity { + t.Errorf("Expected flush size %d, got %d", capacity, n) + } + if fileOffset <= 0 { + t.Errorf("Expected positive file offset, got %d", fileOffset) + } +} + +func TestMemtable_EdgeCases(t *testing.T) { + capacity := fs.BLOCK_SIZE + file := createTestFile(t) + page := createTestPage(capacity) + defer cleanup(file, page) + + config := MemtableConfig{ + capacity: capacity, + id: 1, + page: page, + file: file, + } + + memtable, err := NewMemtable(config) + if err != nil { + t.Fatalf("NewMemtable failed: %v", err) + } + + // Test zero-length put + offset, length, readyForFlush := memtable.Put([]byte{}) + if offset != 0 || length != 0 || readyForFlush { + t.Errorf("Zero-length put: offset=%d, length=%d, readyForFlush=%v", offset, length, readyForFlush) + } + + // Test zero-length get + result, err := memtable.Get(0, 0) + if err != nil { + t.Fatalf("Zero-length get failed: %v", err) + } + if len(result) != 0 { + t.Errorf("Expected zero-length result, got %d", len(result)) + } + + // Test get at exact capacity boundary with zero length (should succeed) + result, err = memtable.Get(capacity, 0) + if err != nil { + t.Errorf("Expected no error for boundary get with zero length, got %v", err) + } + if len(result) != 0 { + t.Errorf("Expected zero-length result for boundary get, got %d", len(result)) + } + + // Test get beyond capacity boundary + _, err = memtable.Get(capacity, 1) + if err != ErrOffsetOutOfBounds { + t.Errorf("Expected ErrOffsetOutOfBounds for beyond boundary get, got %v", err) + } + + // Test put that exactly fills capacity + exactData := make([]byte, capacity) + offset, length, readyForFlush = memtable.Put(exactData) + if offset != 0 || length != uint16(capacity) || readyForFlush { + t.Errorf("Exact capacity put: offset=%d, length=%d, readyForFlush=%v", 
offset, length, readyForFlush) + } + + // Next put should trigger ready for flush + offset, length, readyForFlush = memtable.Put([]byte("overflow")) + if offset != -1 || length != 0 || !readyForFlush { + t.Errorf("Overflow put: offset=%d, length=%d, readyForFlush=%v", offset, length, readyForFlush) + } +} diff --git a/flashring/internal/pools/leaky_pool.go b/flashring/internal/pools/leaky_pool.go new file mode 100644 index 00000000..81bb42dd --- /dev/null +++ b/flashring/internal/pools/leaky_pool.go @@ -0,0 +1,65 @@ +package pools + +import "sync" + +// LeakyPool is a bounded object pool. When all objects are in use, Get creates +// new ones via createFunc. When returned objects exceed capacity, the excess is +// dropped (optionally via a pre-deref hook for cleanup like unmapping pages). +type LeakyPool[T any] struct { + available []T + Meta any + createFunc func() T + preDrefHook func(obj T) + capacity int + usage int + idx int + mu sync.Mutex +} + +type LeakyPoolConfig[T any] struct { + Capacity int + Meta any + CreateFunc func() T +} + +func NewLeakyPool[T any](config LeakyPoolConfig[T]) *LeakyPool[T] { + return &LeakyPool[T]{ + available: make([]T, config.Capacity), + Meta: config.Meta, + capacity: config.Capacity, + createFunc: config.CreateFunc, + usage: 0, + idx: -1, + } +} + +func (p *LeakyPool[T]) RegisterPreDrefHook(hook func(obj T)) { + p.preDrefHook = hook +} + +func (p *LeakyPool[T]) Get() T { + p.mu.Lock() + defer p.mu.Unlock() + p.usage++ + if p.idx == -1 { + return p.createFunc() + } + o := p.available[p.idx] + p.idx-- + return o +} + +func (p *LeakyPool[T]) Put(obj T) { + p.mu.Lock() + defer p.mu.Unlock() + p.usage-- + p.idx++ + if p.idx == p.capacity { + if p.preDrefHook != nil { + p.preDrefHook(obj) + } + p.idx-- + return + } + p.available[p.idx] = obj +} diff --git a/flashring/internal/pools/pool.go b/flashring/internal/pools/pool.go new file mode 100644 index 00000000..fb62dff0 --- /dev/null +++ b/flashring/internal/pools/pool.go @@ -0,0 +1,7 
@@ +package pools + +// Pool is a generic object pool that reuses pre-allocated objects. +type Pool[T any] interface { + Get() T + Put(obj T) +} diff --git a/flashring/internal/server/resp.go b/flashring/internal/server/resp.go new file mode 100644 index 00000000..3021e961 --- /dev/null +++ b/flashring/internal/server/resp.go @@ -0,0 +1,232 @@ +package server + +import ( + "bufio" + "bytes" + "errors" + "io" + "net" + "strconv" + "time" +) + +// Cache is the minimal interface required by the RESP server. +// Implementations should be safe for concurrent use. +type Cache interface { + Put(key string, value []byte, ttl time.Duration) error + Get(key string) ([]byte, bool, bool) +} + +// ServeRESP starts a minimal RESP (Redis) protocol server over TCP supporting +// GET and SET only. It is optimized for low overhead and pipelined requests. +// +// Supported commands (case-insensitive): +// - GET key +// - SET key value [EX seconds] +func ServeRESP(addr string, cache Cache) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + for { + conn, err := ln.Accept() + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Temporary() { + time.Sleep(50 * time.Millisecond) + continue + } + return err + } + if tc, ok := conn.(*net.TCPConn); ok { + _ = tc.SetNoDelay(true) + _ = tc.SetKeepAlive(true) + _ = tc.SetKeepAlivePeriod(3 * time.Minute) + } + go handleConn(conn, cache) + } +} + +func handleConn(conn net.Conn, cache Cache) { + defer conn.Close() + r := bufio.NewReaderSize(conn, 64*1024) + w := bufio.NewWriterSize(conn, 64*1024) + + for { + cmd, args, perr := readRESPArray(r) + if perr != nil { + if perr == io.EOF || errors.Is(perr, net.ErrClosed) { + return + } + return + } + if len(cmd) == 0 { + continue + } + + switch { + case len(cmd) == 3 && (cmd[0]|0x20) == 'g' && (cmd[1]|0x20) == 'e' && (cmd[2]|0x20) == 't': + if len(args) != 1 { + writeError(w, "wrong number of arguments for 'get'") + } else { + val, found, expired := 
cache.Get(b2s(args[0])) + if !found || expired { + writeBulkNil(w) + } else { + writeBulk(w, val) + } + } + + case len(cmd) >= 3 && (cmd[0]|0x20) == 's' && (cmd[1]|0x20) == 'e' && (cmd[2]|0x20) == 't': + if len(args) != 2 && len(args) != 4 { + writeError(w, "wrong number of arguments for 'set'") + } else { + key := b2s(args[0]) + value := args[1] + var ttl time.Duration + if len(args) == 4 { + if !bytes.EqualFold(args[2], []byte("EX")) { + writeError(w, "only EX option is supported") + if w.Flush() != nil { + return + } + continue + } + secs, err := parseUint(args[3]) + if err != nil { + writeError(w, "invalid expire seconds") + if w.Flush() != nil { + return + } + continue + } + ttl = time.Duration(secs) * time.Second + } + _ = cache.Put(key, value, ttl) + writeSimpleString(w, "OK") + } + + default: + writeError(w, "unknown command") + } + + if w.Flush() != nil { + return + } + } +} + +func readRESPArray(r *bufio.Reader) (cmd []byte, args [][]byte, err error) { + b, err := r.ReadByte() + if err != nil { + return nil, nil, err + } + if b != '*' { + return nil, nil, io.ErrUnexpectedEOF + } + n, err := readIntCRLF(r) + if err != nil { + return nil, nil, err + } + if n <= 0 { + return nil, nil, nil + } + bs, err := readBulkString(r) + if err != nil { + return nil, nil, err + } + cmd = bs + if n > 1 { + args = make([][]byte, 0, n-1) + for i := 1; i < n; i++ { + bsi, err := readBulkString(r) + if err != nil { + return nil, nil, err + } + args = append(args, bsi) + } + } + return +} + +func readBulkString(r *bufio.Reader) ([]byte, error) { + b, err := r.ReadByte() + if err != nil { + return nil, err + } + if b != '$' { + return nil, io.ErrUnexpectedEOF + } + n, err := readIntCRLF(r) + if err != nil { + return nil, err + } + if n < 0 { + return nil, nil + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + if err := expectCRLF(r); err != nil { + return nil, err + } + return buf, nil +} + +func readIntCRLF(r *bufio.Reader) 
(int, error) { + line, err := r.ReadSlice('\r') + if err != nil { + return 0, err + } + if b, err := r.ReadByte(); err != nil || b != '\n' { + if err == nil { + err = io.ErrUnexpectedEOF + } + return 0, err + } + line = line[:len(line)-1] + return strconv.Atoi(b2s(line)) +} + +func expectCRLF(r *bufio.Reader) error { + c1, err := r.ReadByte() + if err != nil { + return err + } + c2, err := r.ReadByte() + if err != nil { + return err + } + if c1 != '\r' || c2 != '\n' { + return io.ErrUnexpectedEOF + } + return nil +} + +func writeSimpleString(w *bufio.Writer, s string) { + w.WriteByte('+') + w.WriteString(s) + w.WriteString("\r\n") +} + +func writeError(w *bufio.Writer, s string) { + w.WriteByte('-') + w.WriteString("ERR ") + w.WriteString(s) + w.WriteString("\r\n") +} + +func writeBulk(w *bufio.Writer, p []byte) { + w.WriteByte('$') + w.WriteString(strconv.Itoa(len(p))) + w.WriteString("\r\n") + w.Write(p) + w.WriteString("\r\n") +} + +func writeBulkNil(w *bufio.Writer) { + w.WriteString("$-1\r\n") +} + +func b2s(b []byte) string { return string(b) } +func parseUint(b []byte) (uint64, error) { return strconv.ParseUint(string(b), 10, 64) } diff --git a/flashring/internal/shard/shard_cache.go b/flashring/internal/shard/shard_cache.go new file mode 100644 index 00000000..ca501f6d --- /dev/null +++ b/flashring/internal/shard/shard_cache.go @@ -0,0 +1,417 @@ +package filecache + +import ( + "fmt" + "hash/crc32" + "sync" + "time" + + "github.com/Meesho/BharatMLStack/flashring/internal/allocators" + "github.com/Meesho/BharatMLStack/flashring/internal/fs" + "github.com/Meesho/BharatMLStack/flashring/internal/index" + "github.com/Meesho/BharatMLStack/flashring/internal/iouring" + "github.com/Meesho/BharatMLStack/flashring/internal/maths" + "github.com/Meesho/BharatMLStack/flashring/internal/memtables" + "github.com/Meesho/BharatMLStack/flashring/pkg/metrics" + "github.com/rs/zerolog/log" +) + +type ShardCache struct { + keyIndex *index.Index + file *fs.WrapAppendFile + 
iouringReader *iouring.ParallelBatchIoUringReader + mm *memtables.MemtableManager + readPageAllocator *allocators.SlabAlignedPageAllocator + dm *index.DeleteManager + predictor *maths.Predictor + startAt int64 + ShardIdx uint32 +} + +type ShardCacheConfig struct { + Rounds int + RbInitial int + RbMax int + DeleteAmortizedStep int + MemtableSize int32 + MaxFileSize int64 + BlockSize int + Directory string + Predictor *maths.Predictor + + // Global batched io_uring reader (shared across all shards). + IoUringReader *iouring.ParallelBatchIoUringReader + + // Dedicated io_uring writer for batched writes (shared across all shards). + IoUringWriter *iouring.IoUringWriter + + // FlushStaggerOffset pre-advances the first memtable so shards flush at + // staggered times instead of all at once. + FlushStaggerOffset int +} + +func NewShardCache(config ShardCacheConfig, sl *sync.RWMutex) (*ShardCache, error) { + filename := fmt.Sprintf("%s/%d.bin", config.Directory, time.Now().UnixNano()) + punchHoleSize := config.MemtableSize + fsConf := fs.FileConfig{ + Filename: filename, + MaxFileSize: config.MaxFileSize, + FilePunchHoleSize: int64(punchHoleSize), + BlockSize: config.BlockSize, + } + file, err := fs.NewWrapAppendFile(fsConf) + if err != nil { + return nil, fmt.Errorf("create shard file: %w", err) + } + memtableManager, err := memtables.NewMemtableManager(file, config.MemtableSize, config.FlushStaggerOffset) + if err != nil { + file.Close() + return nil, fmt.Errorf("create memtable manager: %w", err) + } + ki := index.NewIndex(0, config.RbInitial, config.RbMax, config.DeleteAmortizedStep, sl) + + sizeClasses := make([]allocators.SizeClass, 0) + i := fs.BLOCK_SIZE + minCount := 24 + iMax := (1 << 17) + for i < iMax { + sizeClasses = append(sizeClasses, allocators.SizeClass{Size: i, MinCount: minCount}) + i *= 2 + minCount /= 2 + } + readPageAllocator, err := allocators.NewSlabAlignedPageAllocator(allocators.SlabAlignedPageAllocatorConfig{SizeClasses: sizeClasses}) + if err 
!= nil { + file.Close() + return nil, fmt.Errorf("create read page allocator: %w", err) + } + dm := index.NewDeleteManager(ki, file, config.DeleteAmortizedStep) + + file.WriteRing = config.IoUringWriter + + sc := &ShardCache{ + keyIndex: ki, + mm: memtableManager, + file: file, + readPageAllocator: readPageAllocator, + dm: dm, + predictor: config.Predictor, + startAt: time.Now().Unix(), + } + + if config.IoUringReader == nil { + file.Close() + return nil, fmt.Errorf("BatchIoUringReader is required") + } + sc.iouringReader = config.IoUringReader + + return sc, nil +} + +func (fc *ShardCache) Put(key string, value []byte, ttlMinutes uint16) error { + size := 4 + len(key) + len(value) + mt, mtId, _ := fc.mm.GetMemtable() + if err := fc.dm.ExecuteDeleteIfNeeded(); err != nil { + return err + } + buf, offset, length, readyForFlush := mt.GetBufForAppend(uint16(size)) + if readyForFlush { + fc.mm.Flush() + mt, mtId, _ = fc.mm.GetMemtable() + buf, offset, length, _ = mt.GetBufForAppend(uint16(size)) + } + copy(buf[4:], key) + copy(buf[4+len(key):], value) + crc := crc32.ChecksumIEEE(buf[4:]) + index.ByteOrder.PutUint32(buf[0:4], crc) + fc.keyIndex.Put(key, length, ttlMinutes, mtId, uint32(offset)) + fc.dm.IncMemtableKeyCount(mtId) + return nil +} + +func (fc *ShardCache) Get(key string) (bool, []byte, uint16, bool, bool) { + length, lastAccess, remainingTTL, freq, memId, offset, status := fc.keyIndex.Get(key) + if status == index.StatusNotFound { + metrics.Incr(metrics.KEY_KEY_NOT_FOUND_COUNT, []string{}) + return false, nil, 0, false, false + } + + metrics.Timing(metrics.KEY_DATA_LENGTH, time.Duration(length), []string{}) + + if status == index.StatusExpired { + metrics.Incr(metrics.KEY_KEY_EXPIRED_COUNT, []string{}) + return false, nil, 0, true, false + } + + _, currMemId, _ := fc.mm.GetMemtable() + shouldReWrite := fc.predictor.Predict(uint64(freq), uint64(lastAccess), memId, currMemId) + + var buf []byte + mt := fc.mm.GetMemtableById(memId) + if mt == nil { + 
metrics.Incr(metrics.KEY_MEMTABLE_MISS, []string{}) + buf = make([]byte, length) + fileOffset := uint64(memId)*uint64(fc.mm.Capacity) + uint64(offset) + n := fc.readFromDiskAsync(int64(fileOffset), length, buf) + if n != int(length) { + metrics.Incr(metrics.KEY_BAD_LENGTH_COUNT, []string{}) + return false, nil, 0, false, shouldReWrite + } + } else { + metrics.Incr(metrics.KEY_MEMTABLE_HIT, []string{}) + var exists bool + buf, exists = mt.GetBufForRead(int(offset), length) + if !exists { + return false, nil, 0, false, shouldReWrite + } + } + gotCR32 := index.ByteOrder.Uint32(buf[0:4]) + computedCR32 := crc32.ChecksumIEEE(buf[4:length]) + gotKey := string(buf[4 : 4+len(key)]) + if gotCR32 != computedCR32 { + metrics.Incr(metrics.KEY_BAD_CR32_COUNT, []string{}) + return false, nil, 0, false, shouldReWrite + } + if gotKey != key { + metrics.Incr(metrics.KEY_BAD_KEY_COUNT, []string{}) + return false, nil, 0, false, shouldReWrite + } + valLen := int(length) - 4 - len(key) + return true, buf[4+len(key) : 4+len(key)+valLen], remainingTTL, false, shouldReWrite +} + +func (fc *ShardCache) readFromDiskAsync(fileOffset int64, length uint16, buf []byte) int { + alignedStart, alignedSize := fs.AlignRange(fileOffset, int(length), fs.BLOCK_SIZE) + page := fc.readPageAllocator.Get(int(alignedSize)) + + readBuf := page.Buf[:alignedSize] + + var n int + var err error + var validOffset int64 + validOffset, err = fc.file.ValidateReadOffset(alignedStart, int(alignedSize)) + if err == nil { + n, err = fc.iouringReader.Submit(fc.file.ReadFd, readBuf, uint64(validOffset)) + } + + if err != nil || n != int(alignedSize) { + if err != nil && err != fs.ErrFileOffsetOutOfRange { + log.Warn().Err(err). + Int64("offset", alignedStart). + Int64("alignedReadSize", alignedSize). + Int("n", n). 
+ Msg("io_uring pread failed") + } + fc.readPageAllocator.Put(page) + return 0 + } + + start := int(fileOffset - alignedStart) + copied := copy(buf, page.Buf[start:start+int(length)]) + fc.readPageAllocator.Put(page) + return copied +} + +func (fc *ShardCache) Flush() { + fc.mm.Flush() +} + +func (fc *ShardCache) Close() { + fc.file.Close() +} + +// DeleteKey removes the key from the index only. Debug use only. +func (fc *ShardCache) DeleteKey(key string) bool { + return fc.keyIndex.DeleteKey(key) +} + +func (fc *ShardCache) GetRingBufferActiveEntries() int { + return fc.keyIndex.GetRB().ActiveEntries() +} + +// --------------------------------------------------------------------------- +// MGet support — separate functions that duplicate parts of Get/readFromDiskAsync +// to allow the caller to split index lookups from disk I/O. +// --------------------------------------------------------------------------- + +// MGetMeta holds the result of an index lookup for batch gets. +type MGetMeta struct { + Found bool + Expired bool + ShouldReWrite bool + RemainingTTL uint16 + // Value is non-nil when the data was found in a memtable (no disk read needed). + Value []byte + NeedsDiskRead bool + Length uint16 + FileOffset int64 +} + +// PendingRead represents an in-flight async io_uring disk read. +type PendingRead struct { + done <-chan iouring.ReadResult + page *fs.AlignedPage + alignedSize int + pageOffset int + length uint16 +} + +// GetMetaForMGet performs an index lookup and memtable check for a single key +// without issuing any disk I/O. This is the first phase of an MGet operation. 
+func (fc *ShardCache) GetMetaForMGet(key string) MGetMeta { + length, lastAccess, remainingTTL, freq, memId, offset, status := fc.keyIndex.Get(key) + + if status == index.StatusNotFound { + metrics.Incr(metrics.KEY_KEY_NOT_FOUND_COUNT, []string{}) + return MGetMeta{} + } + + metrics.Timing(metrics.KEY_DATA_LENGTH, time.Duration(length), []string{}) + + if status == index.StatusExpired { + metrics.Incr(metrics.KEY_KEY_EXPIRED_COUNT, []string{}) + return MGetMeta{Expired: true} + } + + _, currMemId, _ := fc.mm.GetMemtable() + shouldReWrite := fc.predictor.Predict(uint64(freq), uint64(lastAccess), memId, currMemId) + + mt := fc.mm.GetMemtableById(memId) + if mt != nil { + metrics.Incr(metrics.KEY_MEMTABLE_HIT, []string{}) + buf, exists := mt.GetBufForRead(int(offset), length) + if !exists { + return MGetMeta{ShouldReWrite: shouldReWrite} + } + return MGetMeta{ + Found: true, + Value: buf, + Length: length, + RemainingTTL: remainingTTL, + ShouldReWrite: shouldReWrite, + } + } + + metrics.Incr(metrics.KEY_MEMTABLE_MISS, []string{}) + fileOffset := int64(uint64(memId)*uint64(fc.mm.Capacity) + uint64(offset)) + + return MGetMeta{ + Found: true, + NeedsDiskRead: true, + Length: length, + FileOffset: fileOffset, + RemainingTTL: remainingTTL, + ShouldReWrite: shouldReWrite, + } +} + +// SubmitDiskReadAsync enqueues an aligned disk read via io_uring without +// blocking for completion. Returns a PendingRead handle for CollectDiskRead. 
+func (fc *ShardCache) SubmitDiskReadAsync(fileOffset int64, length uint16) (*PendingRead, error) { + alignedStart, alignedSize := fs.AlignRange(fileOffset, int(length), fs.BLOCK_SIZE) + page := fc.readPageAllocator.Get(int(alignedSize)) + readBuf := page.Buf[:alignedSize] + + validOffset, err := fc.file.ValidateReadOffset(alignedStart, int(alignedSize)) + if err != nil { + fc.readPageAllocator.Put(page) + return nil, err + } + + done := fc.iouringReader.SubmitAsync(fc.file.ReadFd, readBuf, uint64(validOffset)) + + return &PendingRead{ + done: done, + page: page, + alignedSize: int(alignedSize), + pageOffset: int(fileOffset - alignedStart), + length: length, + }, nil +} + +// CollectDiskRead blocks until the pending io_uring read completes, copies +// the result into a new buffer, and frees the aligned page. Returns nil on failure. +func (fc *ShardCache) CollectDiskRead(pr *PendingRead) []byte { + result := <-pr.done + defer fc.readPageAllocator.Put(pr.page) + + if result.Err != nil || result.N != pr.alignedSize { + if result.Err != nil { + log.Warn().Err(result.Err).Msg("io_uring pread failed in MGet") + } + return nil + } + + buf := make([]byte, pr.length) + copy(buf, pr.page.Buf[pr.pageOffset:pr.pageOffset+int(pr.length)]) + return buf +} + +// CoalescedPendingRead represents an in-flight async io_uring disk read that +// covers a merged aligned region shared by multiple keys. +type CoalescedPendingRead struct { + done <-chan iouring.ReadResult + page *fs.AlignedPage + alignedSize int +} + +// SubmitCoalescedReadAsync enqueues a single aligned disk read that covers +// multiple keys whose file offsets fall within [alignedStart, alignedStart+alignedSize). 
+func (fc *ShardCache) SubmitCoalescedReadAsync(alignedStart int64, alignedSize int) (*CoalescedPendingRead, error) { + page := fc.readPageAllocator.Get(alignedSize) + readBuf := page.Buf[:alignedSize] + + validOffset, err := fc.file.ValidateReadOffset(alignedStart, alignedSize) + if err != nil { + fc.readPageAllocator.Put(page) + return nil, err + } + + done := fc.iouringReader.SubmitAsync(fc.file.ReadFd, readBuf, uint64(validOffset)) + + return &CoalescedPendingRead{ + done: done, + page: page, + alignedSize: alignedSize, + }, nil +} + +// CollectCoalescedRead blocks until the coalesced io_uring read completes and +// returns the full aligned buffer. The caller extracts individual key regions +// using each key's offset relative to the aligned start. +func (fc *ShardCache) CollectCoalescedRead(pr *CoalescedPendingRead) []byte { + result := <-pr.done + defer fc.readPageAllocator.Put(pr.page) + + if result.Err != nil || result.N != pr.alignedSize { + if result.Err != nil { + log.Warn().Err(result.Err).Msg("io_uring coalesced pread failed in MGet") + } + return nil + } + + buf := make([]byte, pr.alignedSize) + copy(buf, pr.page.Buf[:pr.alignedSize]) + return buf +} + +// ValidateAndExtract checks the CRC32 and key, then extracts the value from +// a raw data buffer. Used by MGet for both memtable and disk-read results. 
+func (fc *ShardCache) ValidateAndExtract(buf []byte, key string, length uint16) ([]byte, bool) { + if int(length) > len(buf) || length < 4 { + metrics.Incr(metrics.KEY_BAD_LENGTH_COUNT, []string{}) + return nil, false + } + gotCRC := index.ByteOrder.Uint32(buf[0:4]) + computedCRC := crc32.ChecksumIEEE(buf[4:length]) + if gotCRC != computedCRC { + metrics.Incr(metrics.KEY_BAD_CR32_COUNT, []string{}) + return nil, false + } + gotKey := string(buf[4 : 4+len(key)]) + if gotKey != key { + metrics.Incr(metrics.KEY_BAD_KEY_COUNT, []string{}) + return nil, false + } + valLen := int(length) - 4 - len(key) + return buf[4+len(key) : 4+len(key)+valLen], true +} diff --git a/flashring/main.go b/flashring/main.go new file mode 100644 index 00000000..66f4cfa9 --- /dev/null +++ b/flashring/main.go @@ -0,0 +1,412 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "runtime" + "strings" + "sync" + "syscall" + "time" + "unsafe" +) + +const ( + // Common page sizes (4KB is most common) + PageSize4K = 4 * 1024 + PageSize8K = 8 * 1024 + PageSize16K = 16 * 1024 + PageSize64K = 64 * 1024 + + // Test data sizes + SmallRecord = 128 // 128 bytes + MediumRecord = 1024 // 1KB + LargeRecord = 8192 // 8KB +) + +// PageAlignedBuffer provides page-aligned buffered writing +type PageAlignedBuffer struct { + file *os.File + buffer []byte + bufferSize int + writePos int + mu sync.Mutex +} + +// NewPageAlignedBuffer creates a new page-aligned buffer +func NewPageAlignedBuffer(filename string, bufferSize int) (*PageAlignedBuffer, error) { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, err + } + + // Align buffer to page boundary + buffer := make([]byte, bufferSize) + + return &PageAlignedBuffer{ + file: file, + buffer: buffer, + bufferSize: bufferSize, + writePos: 0, + }, nil +} + +// Write writes data to the buffer, flushing when page size is reached +func (pab *PageAlignedBuffer) Write(data []byte) error { + pab.mu.Lock() + defer 
pab.mu.Unlock() + + dataLen := len(data) + + // If data is larger than buffer, write directly + if dataLen > pab.bufferSize { + if pab.writePos > 0 { + if err := pab.flushUnsafe(); err != nil { + return err + } + } + _, err := pab.file.Write(data) + return err + } + + // If data doesn't fit in current buffer, flush first + if pab.writePos+dataLen > pab.bufferSize { + if err := pab.flushUnsafe(); err != nil { + return err + } + } + + // Copy data to buffer + copy(pab.buffer[pab.writePos:], data) + pab.writePos += dataLen + + return nil +} + +// Flush flushes the buffer to disk +func (pab *PageAlignedBuffer) Flush() error { + pab.mu.Lock() + defer pab.mu.Unlock() + return pab.flushUnsafe() +} + +func (pab *PageAlignedBuffer) flushUnsafe() error { + if pab.writePos == 0 { + return nil + } + + _, err := pab.file.Write(pab.buffer[:pab.writePos]) + if err != nil { + return err + } + + pab.writePos = 0 + return nil +} + +// Sync syncs the file to disk +func (pab *PageAlignedBuffer) Sync() error { + if err := pab.Flush(); err != nil { + return err + } + return pab.file.Sync() +} + +// Close closes the buffer and file +func (pab *PageAlignedBuffer) Close() error { + if err := pab.Flush(); err != nil { + return err + } + return pab.file.Close() +} + +// DirectWriter wraps direct file writing +type DirectWriter struct { + file *os.File +} + +func NewDirectWriter(filename string) (*DirectWriter, error) { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, err + } + return &DirectWriter{file: file}, nil +} + +func (dw *DirectWriter) Write(data []byte) error { + _, err := dw.file.Write(data) + return err +} + +func (dw *DirectWriter) Sync() error { + return dw.file.Sync() +} + +func (dw *DirectWriter) Close() error { + return dw.file.Close() +} + +// MemoryMappedWriter uses memory mapping for writing +type MemoryMappedWriter struct { + file *os.File + data []byte + size int64 + writePos int64 + mu sync.Mutex +} + +func 
NewMemoryMappedWriter(filename string, size int64) (*MemoryMappedWriter, error) { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_RDWR, 0644) + if err != nil { + return nil, err + } + + // Truncate file to desired size + if err := file.Truncate(size); err != nil { + file.Close() + return nil, err + } + + // Memory map the file + data, err := syscall.Mmap(int(file.Fd()), 0, int(size), syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + file.Close() + return nil, err + } + + return &MemoryMappedWriter{ + file: file, + data: data, + size: size, + writePos: 0, + }, nil +} + +func (mmw *MemoryMappedWriter) Write(data []byte) error { + mmw.mu.Lock() + defer mmw.mu.Unlock() + + dataLen := int64(len(data)) + if mmw.writePos+dataLen > mmw.size { + return fmt.Errorf("write would exceed mapped region") + } + + copy(mmw.data[mmw.writePos:], data) + mmw.writePos += dataLen + + return nil +} + +func (mmw *MemoryMappedWriter) Sync() error { + // Use manual msync syscall since syscall.Msync might not be available on all platforms + _, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(unsafe.Pointer(&mmw.data[0])), uintptr(len(mmw.data)), uintptr(syscall.MS_SYNC)) + if errno != 0 { + return errno + } + return nil +} + +func (mmw *MemoryMappedWriter) Close() error { + if err := syscall.Munmap(mmw.data); err != nil { + return err + } + return mmw.file.Close() +} + +// Benchmark functions +func benchmarkPageAlignedBuffer(recordSize, numRecords, bufferSize int) time.Duration { + filename := fmt.Sprintf("test_page_aligned_%d_%d_%d.log", recordSize, numRecords, bufferSize) + defer os.Remove(filename) + + writer, err := NewPageAlignedBuffer(filename, bufferSize) + if err != nil { + panic(err) + } + defer writer.Close() + + data := make([]byte, recordSize) + for i := 0; i < recordSize; i++ { + data[i] = byte(i % 256) + } + + start := time.Now() + + for i := 0; i < numRecords; i++ { + if err := writer.Write(data); err != nil { + panic(err) + } + } + + if err := 
writer.Sync(); err != nil { + panic(err) + } + + return time.Since(start) +} + +func benchmarkDirectWrite(recordSize, numRecords int) time.Duration { + filename := fmt.Sprintf("test_direct_%d_%d.log", recordSize, numRecords) + defer os.Remove(filename) + + writer, err := NewDirectWriter(filename) + if err != nil { + panic(err) + } + defer writer.Close() + + data := make([]byte, recordSize) + for i := 0; i < recordSize; i++ { + data[i] = byte(i % 256) + } + + start := time.Now() + + for i := 0; i < numRecords; i++ { + if err := writer.Write(data); err != nil { + panic(err) + } + } + + if err := writer.Sync(); err != nil { + panic(err) + } + + return time.Since(start) +} + +func benchmarkBufferedWrite(recordSize, numRecords, bufferSize int) time.Duration { + filename := fmt.Sprintf("test_buffered_%d_%d_%d.log", recordSize, numRecords, bufferSize) + defer os.Remove(filename) + + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + panic(err) + } + defer file.Close() + + writer := bufio.NewWriterSize(file, bufferSize) + + data := make([]byte, recordSize) + for i := 0; i < recordSize; i++ { + data[i] = byte(i % 256) + } + + start := time.Now() + + for i := 0; i < numRecords; i++ { + if _, err := writer.Write(data); err != nil { + panic(err) + } + } + + if err := writer.Flush(); err != nil { + panic(err) + } + + if err := file.Sync(); err != nil { + panic(err) + } + + return time.Since(start) +} + +func benchmarkMemoryMapped(recordSize, numRecords int) time.Duration { + filename := fmt.Sprintf("test_mmap_%d_%d.log", recordSize, numRecords) + defer os.Remove(filename) + + totalSize := int64(recordSize * numRecords) + writer, err := NewMemoryMappedWriter(filename, totalSize) + if err != nil { + panic(err) + } + defer writer.Close() + + data := make([]byte, recordSize) + for i := 0; i < recordSize; i++ { + data[i] = byte(i % 256) + } + + start := time.Now() + + for i := 0; i < numRecords; i++ { + if err := writer.Write(data); err 
!= nil { + panic(err) + } + } + + if err := writer.Sync(); err != nil { + panic(err) + } + + return time.Since(start) +} + +func printResults(name string, duration time.Duration, recordSize, numRecords int) { + totalBytes := int64(recordSize * numRecords) + throughputMBps := float64(totalBytes) / duration.Seconds() / (1024 * 1024) + recordsPerSec := float64(numRecords) / duration.Seconds() + + fmt.Printf("%-30s: %10s | %8.2f MB/s | %10.0f records/s | %8.2f MB total\n", + name, duration.Round(time.Microsecond), throughputMBps, recordsPerSec, float64(totalBytes)/(1024*1024)) +} + +func runBenchmarks() { + fmt.Println("=== Append-Only File Writing Benchmarks ===") + fmt.Printf("Go Version: %s, OS: %s, Arch: %s\n", runtime.Version(), runtime.GOOS, runtime.GOARCH) + fmt.Printf("CPUs: %d\n\n", runtime.NumCPU()) + + testCases := []struct { + recordSize int + numRecords int + name string + }{ + {SmallRecord, 100000, "Small Records (128B x 100K)"}, + {MediumRecord, 50000, "Medium Records (1KB x 50K)"}, + {LargeRecord, 10000, "Large Records (8KB x 10K)"}, + } + + bufferSizes := []int{PageSize4K, PageSize8K, PageSize16K, PageSize64K} + + for _, tc := range testCases { + fmt.Printf("\n=== %s ===\n", tc.name) + fmt.Printf("%-30s: %10s | %8s | %10s | %8s\n", "Method", "Duration", "MB/s", "Records/s", "Total MB") + fmt.Println(strings.Repeat("-", 80)) + + // Direct write benchmark + duration := benchmarkDirectWrite(tc.recordSize, tc.numRecords) + printResults("Direct Write", duration, tc.recordSize, tc.numRecords) + + // Buffered write benchmarks with different buffer sizes + for _, bufSize := range bufferSizes { + duration := benchmarkBufferedWrite(tc.recordSize, tc.numRecords, bufSize) + name := fmt.Sprintf("Buffered (%dK)", bufSize/1024) + printResults(name, duration, tc.recordSize, tc.numRecords) + } + + // Page-aligned buffer benchmarks + for _, bufSize := range bufferSizes { + duration := benchmarkPageAlignedBuffer(tc.recordSize, tc.numRecords, bufSize) + name := 
fmt.Sprintf("Page-Aligned (%dK)", bufSize/1024) + printResults(name, duration, tc.recordSize, tc.numRecords) + } + + // Memory-mapped benchmark (if total size is reasonable) + totalSize := int64(tc.recordSize * tc.numRecords) + if totalSize < 1024*1024*1024 { // Less than 1GB + duration := benchmarkMemoryMapped(tc.recordSize, tc.numRecords) + printResults("Memory Mapped", duration, tc.recordSize, tc.numRecords) + } + } + + fmt.Println("\n=== Recommendations ===") + fmt.Println("1. For high-throughput workloads: Use page-aligned buffers with 16KB-64KB buffer sizes") + fmt.Println("2. For low-latency workloads: Use smaller buffers (4KB-8KB) with frequent flushing") + fmt.Println("3. For large sequential writes: Consider memory-mapped files") + fmt.Println("4. Always align buffer sizes to page boundaries for optimal performance") + fmt.Println("5. Use fdatasync instead of fsync when metadata updates aren't critical") +} + +func main() { + runBenchmarks() +} diff --git a/flashring/pkg/cache/badger.go b/flashring/pkg/cache/badger.go new file mode 100644 index 00000000..f53f0c3e --- /dev/null +++ b/flashring/pkg/cache/badger.go @@ -0,0 +1,52 @@ +package cache + +import ( + "time" + + badger "github.com/dgraph-io/badger/v4" +) + +type Badger struct { + cache *badger.DB +} + +func NewBadger(config Config, dir string) (*Badger, error) { + options := badger.DefaultOptions(dir) + options.MetricsEnabled = false + options.BlockCacheSize = 1024 << 20 + options.IndexCacheSize = 512 << 20 + options.NumMemtables = 40 + options.MemTableSize = 1024 << 20 + options.ValueThreshold = 1024 + options.SyncWrites = false + + db, err := badger.Open(options) + if err != nil { + return nil, err + } + return &Badger{cache: db}, nil +} + +func (b *Badger) Put(key string, value []byte, ttl time.Duration) error { + return b.cache.Update(func(txn *badger.Txn) error { + entry := badger.NewEntry([]byte(key), value).WithTTL(ttl) + return txn.SetEntry(entry) + }) +} + +func (b *Badger) Get(key string) 
([]byte, bool, bool) { + var val []byte + err := b.cache.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(key)) + if err != nil { + return err + } + val, err = item.ValueCopy(val) + return err + }) + return val, err != badger.ErrKeyNotFound, false +} + +func (b *Badger) Close() error { + return b.cache.Close() +} diff --git a/flashring/pkg/cache/cache.go b/flashring/pkg/cache/cache.go new file mode 100644 index 00000000..0bcb3826 --- /dev/null +++ b/flashring/pkg/cache/cache.go @@ -0,0 +1,457 @@ +package cache + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "sync" + "time" + + "github.com/Meesho/BharatMLStack/flashring/internal/fs" + "github.com/Meesho/BharatMLStack/flashring/internal/iouring" + "github.com/Meesho/BharatMLStack/flashring/internal/maths" + filecache "github.com/Meesho/BharatMLStack/flashring/internal/shard" + "github.com/cespare/xxhash/v2" + "github.com/rs/zerolog/log" + + "github.com/Meesho/BharatMLStack/flashring/pkg/metrics" +) + +const ( + rounds = 1 + maxKeysShard = (1 << 26) // 67M + blockSize = 4096 + maxCoalescedReadSz = 65536 // must match the largest slab allocator size class +) + +// Cache is the common interface for all cache backends. +type Cache interface { + Put(key string, value []byte, ttl time.Duration) error + Get(key string) (value []byte, found bool, expired bool) + Close() error +} + +// Config holds all parameters for creating a WrapCache. 
+type Config struct { + NumShards int + KeysPerShard int + FileSize int64 + MemtableSize int32 + ReWriteScoreThreshold float32 + GridSearchEpsilon float64 + SampleDuration time.Duration + FreqBands []int + RecencyBands []int +} + +var ( + ErrNumShardLessThan1 = fmt.Errorf("num shards must be greater than 0") + ErrKeysPerShardLessThan1 = fmt.Errorf("keys per shard must be greater than 0") + ErrKeysPerShardGreaterThan67M = fmt.Errorf("keys per shard must be less than 67M") + ErrMemtableSizeLessThan1 = fmt.Errorf("memtable size must be greater than 0") + ErrMemtableSizeGreaterThan1GB = fmt.Errorf("memtable size must be less than 1GB") + ErrMemtableSizeNotMultipleOf4KB = fmt.Errorf("memtable size must be a multiple of 4KB") + ErrFileSizeLessThan1 = fmt.Errorf("file size must be greater than 0") + ErrFileSizeNotMultipleOf4KB = fmt.Errorf("file size must be a multiple of 4KB") +) + +func (c *Config) validate() error { + checks := []struct { + cond bool + err error + }{ + {c.NumShards <= 0, ErrNumShardLessThan1}, + {c.KeysPerShard <= 0, ErrKeysPerShardLessThan1}, + {c.KeysPerShard > maxKeysShard, ErrKeysPerShardGreaterThan67M}, + {c.MemtableSize <= 0, ErrMemtableSizeLessThan1}, + {c.MemtableSize > 1<<30, ErrMemtableSizeGreaterThan1GB}, + {c.MemtableSize%blockSize != 0, ErrMemtableSizeNotMultipleOf4KB}, + {c.FileSize <= 0, ErrFileSizeLessThan1}, + {c.FileSize%blockSize != 0, ErrFileSizeNotMultipleOf4KB}, + } + for _, ch := range checks { + if ch.cond { + return ch.err + } + } + return nil +} + +// WrapCache is the primary disk-backed NVMe cache. 
+type WrapCache struct { + shards []*filecache.ShardCache + shardLocks []sync.RWMutex + predictor *maths.Predictor + iouringReader *iouring.ParallelBatchIoUringReader + iouringWriter *iouring.IoUringWriter + seed uint64 +} + +var defaultWeights = []maths.WeightTuple{ + {WFreq: 0.1, WLA: 0.1}, + {WFreq: 0.45, WLA: 0.1}, + {WFreq: 0.9, WLA: 0.1}, + {WFreq: 0.1, WLA: 0.45}, + {WFreq: 0.45, WLA: 0.45}, + {WFreq: 0.9, WLA: 0.45}, + {WFreq: 0.1, WLA: 0.9}, + {WFreq: 0.45, WLA: 0.9}, + {WFreq: 0.9, WLA: 0.9}, +} + +func NewWrapCache(config Config, mountPoint string) (*WrapCache, error) { + if err := config.validate(); err != nil { + return nil, err + } + + files, err := os.ReadDir(mountPoint) + if err != nil { + return nil, fmt.Errorf("read mount point: %w", err) + } + for _, file := range files { + os.Remove(filepath.Join(mountPoint, file.Name())) + } + + maxMemTableCount := config.FileSize / int64(config.MemtableSize) + predictor := maths.NewPredictor(maths.PredictorConfig{ + ReWriteScoreThreshold: config.ReWriteScoreThreshold, + Weights: defaultWeights, + SampleDuration: config.SampleDuration, + MaxMemTableCount: uint32(maxMemTableCount), + GridSearchEpsilon: config.GridSearchEpsilon, + FreqBands: maths.FreqBands{Cold: uint64(config.FreqBands[0]), Warm: uint64(config.FreqBands[1]), Hot: uint64(config.FreqBands[2])}, + RecencyBands: maths.RecencyBands{Hot: uint64(config.RecencyBands[0]), Warm: uint64(config.RecencyBands[1]), Cold: uint64(config.RecencyBands[2])}, + }) + + readRing, err := iouring.NewParallelBatchIoUringReader(iouring.BatchIoUringConfig{ + RingDepth: 512, + MaxBatch: 512, + MaxInflight: 512, + QueueSize: 2048, + Window: 0, + SQPoll: true, + }, 1) + if err != nil { + log.Panic().Err(err).Msg("Failed to create batched io_uring reader") + } + + writeRing, err := iouring.NewIoUringWriter(256, 0) + if err != nil { + log.Panic().Err(err).Msg("Failed to create io_uring write ring") + } + + seed := xxhash.Sum64String(strconv.Itoa(int(time.Now().UnixNano()))) + + 
metrics.BuildShardTags(config.NumShards) + shardLocks := make([]sync.RWMutex, config.NumShards) + shards := make([]*filecache.ShardCache, config.NumShards) + + // Stagger each shard's first memtable fill level so flushes are spread + // evenly over time instead of all firing at once. Shard i starts with + // i/N of its memtable already "used", so it fills sooner by that fraction. + // After the first cycle the stagger is self-sustaining. + staggerStep := (int(config.MemtableSize) / config.NumShards) &^ (blockSize - 1) // block-align + + for i := 0; i < config.NumShards; i++ { + shards[i], err = filecache.NewShardCache(filecache.ShardCacheConfig{ + MemtableSize: config.MemtableSize, + Rounds: rounds, + RbInitial: config.KeysPerShard, + RbMax: config.KeysPerShard, + DeleteAmortizedStep: 10000, + MaxFileSize: config.FileSize, + BlockSize: blockSize, + Directory: mountPoint, + Predictor: predictor, + IoUringReader: readRing, + IoUringWriter: writeRing, + FlushStaggerOffset: i * staggerStep, + }, &shardLocks[i]) + if err != nil { + for j := 0; j < i; j++ { + shards[j].Close() + } + readRing.Close() + writeRing.Close() + return nil, fmt.Errorf("create shard %d: %w", i, err) + } + } + + return &WrapCache{ + shards: shards, + shardLocks: shardLocks, + predictor: predictor, + iouringReader: readRing, + iouringWriter: writeRing, + seed: seed, + }, nil +} + +func (wc *WrapCache) Put(key string, value []byte, ttl time.Duration) error { + h32 := wc.hash(key) + shardIdx := h32 % uint32(len(wc.shards)) + + start := time.Now() + defer func() { + metrics.Timing(metrics.KEY_PUT_LATENCY, time.Since(start), metrics.GetShardTag(shardIdx)) + }() + + wc.shardLocks[shardIdx].Lock() + metrics.Timing(metrics.LATENCY_WLOCK, time.Since(start), []string{}) + defer wc.shardLocks[shardIdx].Unlock() + + ttlMinutes := uint16(ttl.Minutes()) + if ttlMinutes == 0 && ttl > 0 { + ttlMinutes = 1 + } + + if err := wc.shards[shardIdx].Put(key, value, ttlMinutes); err != nil { + return fmt.Errorf("put 
failed for key %s: %w", key, err) + } + metrics.Incr(metrics.KEY_PUTS, metrics.GetShardTag(shardIdx)) + if h32%100 < 10 { + metrics.Incr(metrics.KEY_RINGBUFFER_ACTIVE_ENTRIES, metrics.GetShardTag(shardIdx)) + } + return nil +} + +func (wc *WrapCache) Get(key string) ([]byte, bool, bool) { + h32 := wc.hash(key) + shardIdx := h32 % uint32(len(wc.shards)) + + start := time.Now() + defer func() { + metrics.Timing(metrics.KEY_GET_LATENCY, time.Since(start), metrics.GetShardTag(shardIdx)) + }() + + keyFound, val, remainingTTL, expired, shouldReWrite := wc.shards[shardIdx].Get(key) + + if keyFound && !expired { + metrics.Incr(metrics.KEY_HITS, metrics.GetShardTag(shardIdx)) + } + if expired { + metrics.Incr(metrics.KEY_EXPIRED_ENTRIES, metrics.GetShardTag(shardIdx)) + } + metrics.Incr(metrics.KEY_GETS, metrics.GetShardTag(shardIdx)) + + if shouldReWrite { + metrics.Incr(metrics.KEY_REWRITES, metrics.GetShardTag(shardIdx)) + valCopy := make([]byte, len(val)) + copy(valCopy, val) + go wc.rewrite(key, valCopy, remainingTTL) + } + + return val, keyFound, expired +} + +// Delete removes the key from the index only. The data remains on disk +// but becomes unreachable via Get. Debug use only. +func (wc *WrapCache) Delete(key string) bool { + h32 := wc.hash(key) + shardIdx := h32 % uint32(len(wc.shards)) + + wc.shardLocks[shardIdx].Lock() + defer wc.shardLocks[shardIdx].Unlock() + + return wc.shards[shardIdx].DeleteKey(key) +} + +// rewrite puts the value back into the cache asynchronously to move +// hot data closer to the write head. 
+func (wc *WrapCache) rewrite(key string, value []byte, remainingTTLMinutes uint16) { + wc.Put(key, value, time.Duration(remainingTTLMinutes)*time.Minute) +} + +func (wc *WrapCache) Close() error { + for i := range wc.shards { + wc.shardLocks[i].Lock() + wc.shards[i].Flush() + wc.shards[i].Close() + wc.shardLocks[i].Unlock() + } + wc.iouringReader.Close() + wc.iouringWriter.Close() + return nil +} + +// MGetResult holds the result for a single key in a batch get. +type MGetResult struct { + Value []byte + Found bool + Expired bool +} + +// MGet fetches multiple keys in a single call, batching disk I/O through +// io_uring for significantly lower and more consistent latency than issuing +// individual Get calls (even concurrently via goroutines). +// +// The operation runs in four phases on a single goroutine: +// 1. Index lookups + memtable checks for every key (in-memory, fast). +// 2. Coalesce: sort pending reads by (shard, offset), merge overlapping +// aligned ranges so nearby keys share a single io_uring pread. +// 3. Submit one pread per coalesced group in a tight loop. +// 4. Collect completions, scatter to individual keys, validate CRC32. +func (wc *WrapCache) MGet(keys []string) []MGetResult { + results := make([]MGetResult, len(keys)) + + type pendingEntry struct { + keyIdx int + key string + shardIdx uint32 + meta filecache.MGetMeta + } + + var diskReads []pendingEntry + + // ── Phase 1: index lookups + memtable checks (sequential, all in-memory) ── + for i, key := range keys { + h32 := wc.hash(key) + shardIdx := h32 % uint32(len(wc.shards)) + + meta := wc.shards[shardIdx].GetMetaForMGet(key) + metrics.Incr(metrics.KEY_GETS, metrics.GetShardTag(shardIdx)) + + if meta.Expired { + metrics.Incr(metrics.KEY_EXPIRED_ENTRIES, metrics.GetShardTag(shardIdx)) + results[i] = MGetResult{Expired: true} + continue + } + + if !meta.Found { + continue + } + + // Memtable hit — validate and return inline. 
+ if meta.Value != nil { + val, ok := wc.shards[shardIdx].ValidateAndExtract(meta.Value, key, meta.Length) + if ok { + metrics.Incr(metrics.KEY_HITS, metrics.GetShardTag(shardIdx)) + results[i] = MGetResult{Value: val, Found: true} + } + if meta.ShouldReWrite && ok { + metrics.Incr(metrics.KEY_REWRITES, metrics.GetShardTag(shardIdx)) + valCopy := make([]byte, len(val)) + copy(valCopy, val) + go wc.rewrite(key, valCopy, meta.RemainingTTL) + } + continue + } + + // Needs disk read — collect for coalescing. + if meta.NeedsDiskRead { + diskReads = append(diskReads, pendingEntry{ + keyIdx: i, + key: key, + shardIdx: shardIdx, + meta: meta, + }) + } + } + + if len(diskReads) == 0 { + return results + } + + // ── Phase 2: coalesce nearby disk reads ── + // Sort by (shard, file offset) so keys that map to overlapping or + // adjacent 4KB-aligned blocks end up next to each other. + sort.Slice(diskReads, func(i, j int) bool { + if diskReads[i].shardIdx != diskReads[j].shardIdx { + return diskReads[i].shardIdx < diskReads[j].shardIdx + } + return diskReads[i].meta.FileOffset < diskReads[j].meta.FileOffset + }) + + type coalescedGroup struct { + shardIdx uint32 + alignedStart int64 + alignedEnd int64 // exclusive + members []int // indices into diskReads + pending *filecache.CoalescedPendingRead + } + + groups := make([]coalescedGroup, 0, len(diskReads)) + for i, dr := range diskReads { + aStart, aSize := fs.AlignRange(dr.meta.FileOffset, int(dr.meta.Length), fs.BLOCK_SIZE) + aEnd := aStart + aSize + + if len(groups) > 0 { + last := &groups[len(groups)-1] + // Merge if same shard, overlapping/adjacent, and the result + // still fits within the slab allocator's largest size class. 
+ mergedEnd := last.alignedEnd + if aEnd > mergedEnd { + mergedEnd = aEnd + } + if dr.shardIdx == last.shardIdx && aStart <= last.alignedEnd && + mergedEnd-last.alignedStart <= maxCoalescedReadSz { + last.alignedEnd = mergedEnd + last.members = append(last.members, i) + continue + } + } + + groups = append(groups, coalescedGroup{ + shardIdx: dr.shardIdx, + alignedStart: aStart, + alignedEnd: aEnd, + members: []int{i}, + }) + } + + // ── Phase 3: submit one io_uring pread per coalesced group ── + for g := range groups { + size := int(groups[g].alignedEnd - groups[g].alignedStart) + pr, err := wc.shards[groups[g].shardIdx].SubmitCoalescedReadAsync( + groups[g].alignedStart, size) + if err != nil { + continue + } + groups[g].pending = pr + } + + // ── Phase 4: collect completions, scatter to individual keys ── + for _, grp := range groups { + if grp.pending == nil { + continue + } + + coalescedBuf := wc.shards[grp.shardIdx].CollectCoalescedRead(grp.pending) + if coalescedBuf == nil { + continue + } + + for _, memberIdx := range grp.members { + dr := diskReads[memberIdx] + bufOffset := int(dr.meta.FileOffset - grp.alignedStart) + if bufOffset < 0 || bufOffset+int(dr.meta.Length) > len(coalescedBuf) { + continue + } + + keyBuf := coalescedBuf[bufOffset : bufOffset+int(dr.meta.Length)] + val, ok := wc.shards[dr.shardIdx].ValidateAndExtract(keyBuf, dr.key, dr.meta.Length) + if ok { + metrics.Incr(metrics.KEY_HITS, metrics.GetShardTag(dr.shardIdx)) + results[dr.keyIdx] = MGetResult{Value: val, Found: true} + } + if dr.meta.ShouldReWrite && ok { + metrics.Incr(metrics.KEY_REWRITES, metrics.GetShardTag(dr.shardIdx)) + valCopy := make([]byte, len(val)) + copy(valCopy, val) + go wc.rewrite(dr.key, valCopy, dr.meta.RemainingTTL) + } + } + } + + return results +} + +func (wc *WrapCache) hash(key string) uint32 { + return uint32(xxhash.Sum64String(key) ^ wc.seed) +} + +func (wc *WrapCache) Hash(key string) uint32 { + return wc.hash(key) +} diff --git 
a/flashring/pkg/cache/freecache.go b/flashring/pkg/cache/freecache.go new file mode 100644 index 00000000..d9047153 --- /dev/null +++ b/flashring/pkg/cache/freecache.go @@ -0,0 +1,35 @@ +package cache + +import ( + "runtime/debug" + "time" + + "github.com/coocood/freecache" +) + +type Freecache struct { + cache *freecache.Cache +} + +func NewFreecache(sizeBytes int) (*Freecache, error) { + cache := freecache.NewCache(sizeBytes) + debug.SetGCPercent(20) + return &Freecache{cache: cache}, nil +} + +func (c *Freecache) Put(key string, value []byte, ttl time.Duration) error { + c.cache.Set([]byte(key), value, int(ttl.Seconds())) + return nil +} + +func (c *Freecache) Get(key string) ([]byte, bool, bool) { + val, err := c.cache.Get([]byte(key)) + if err != nil { + return nil, false, false + } + return val, true, false +} + +func (c *Freecache) Close() error { + return nil +} diff --git a/flashring/pkg/metrics/metric.go b/flashring/pkg/metrics/metric.go new file mode 100644 index 00000000..ab62bf0b --- /dev/null +++ b/flashring/pkg/metrics/metric.go @@ -0,0 +1,221 @@ +package metrics + +import ( + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/DataDog/datadog-go/v5/statsd" + "github.com/rs/zerolog/log" + "github.com/spf13/viper" +) + +// Flashring metric keys +const ( + KEY_GET_LATENCY = "flashring_get_latency" + KEY_PUT_LATENCY = "flashring_put_latency" + KEY_RTHROUGHPUT = "flashring_rthroughput" + KEY_WTHROUGHPUT = "flashring_wthroughput" + KEY_HITRATE = "flashring_hitrate" + KEY_ACTIVE_ENTRIES = "flashring_active_entries" + KEY_EXPIRED_ENTRIES = "flashring_expired_entries" + KEY_REWRITES = "flashring_rewrites" + KEY_GETS = "flashring_gets" + KEY_PUTS = "flashring_puts" + KEY_HITS = "flashring_hits" + + KEY_KEY_NOT_FOUND_COUNT = "flashring_key_not_found_count" + KEY_KEY_EXPIRED_COUNT = "flashring_key_expired_count" + KEY_BAD_DATA_COUNT = "flashring_bad_data_count" + KEY_BAD_LENGTH_COUNT = "flashring_bad_length_count" + KEY_BAD_CR32_COUNT = 
"flashring_bad_cr32_count" + KEY_BAD_KEY_COUNT = "flashring_bad_key_count" + KEY_DELETED_KEY_COUNT = "flashring_deleted_key_count" + + KEY_WRITE_COUNT = "flashring_write_count" + KEY_PUNCH_HOLE_COUNT = "flashring_punch_hole_count" + KEY_PREAD_COUNT = "flashring_pread_count" + + KEY_TRIM_HEAD_LATENCY = "flashring_wrap_file_trim_head_latency" + KEY_PREAD_LATENCY = "flashring_pread_latency" + KEY_PWRITE_LATENCY = "flashring_pwrite_latency" + + KEY_MEMTABLE_FLUSH_COUNT = "flashring_memtable_flush_count" + + LATENCY_RLOCK = "flashring_rlock_latency" + LATENCY_WLOCK = "flashring_wlock_latency" + + KEY_RINGBUFFER_ACTIVE_ENTRIES = "flashring_ringbuffer_active_entries" + KEY_MEMTABLE_ENTRY_COUNT = "flashring_memtable_entry_count" + KEY_MEMTABLE_HIT = "flashring_memtable_hit" + KEY_MEMTABLE_MISS = "flashring_memtable_miss" + KEY_DATA_LENGTH = "flashring_data_length" + KEY_IOURING_SIZE = "flashring_iouring_size" + KEY_REWRITE_SCORE = "flashring_rewrite_score" + KEY_REWRITE_DECISION = "flashring_rewrite_decision" + KEY_ACCESS_FREQ = "flashring_access_freq" + KEY_LAST_ACCESS = "flashring_last_access" +) + +// Rewrite predictor tag keys +const ( + TAG_SCORE_BUCKET = "score_bucket" + TAG_DECISION = "decision" + TAG_RING_ZONE = "ring_zone" + TAG_FREQ_BAND = "freq_band" + TAG_RECENCY_BAND = "recency_band" +) + +// Flashring tag keys +const ( + TAG_LATENCY_PERCENTILE = "latency_percentile" + TAG_VALUE_P25 = "p25" + TAG_VALUE_P50 = "p50" + TAG_VALUE_P99 = "p99" + TAG_SHARD_IDX = "shard_idx" + TAG_MEMTABLE_ID = "memtable_id" +) + +// Application-level metric keys +const ( + ApiRequestCount = "api_request_count" + ApiRequestLatency = "api_request_latency" + ExternalApiRequestCount = "external_api_request_count" + ExternalApiRequestLatency = "external_api_request_latency" + DBCallLatency = "db_call_latency" + DBCallCount = "db_call_count" + MethodLatency = "method_latency" + MethodCount = "method_count" +) + +var ( + statsDClient = getDefaultClient() + samplingRate = 0.1 + 
telegrafAddress = "localhost:8125" + initialized = false + once sync.Once + + // When false, all Timing/Count/Incr/Gauge calls are no-ops (zero allocations). + // Controlled by FLASHRING_METRICS_ENABLED env var ("true"/"1" to enable). + // Defaults to true for backward compatibility. + metricsEnabled = loadMetricsEnabled() + + shardTags []string +) + +func loadMetricsEnabled() bool { + v := os.Getenv("FLASHRING_METRICS_ENABLED") + if v == "" { + return false + } + return strings.EqualFold(v, "true") || v == "1" +} + +// Init initializes the metrics client +func Init() { + if initialized { + log.Debug().Msgf("Metrics already initialized!") + return + } + once.Do(func() { + var err error + samplingRate = viper.GetFloat64("APP_METRIC_SAMPLING_RATE") + globalTags := getGlobalTags() + + statsDClient, err = statsd.New( + telegrafAddress, + statsd.WithTags(globalTags), + ) + + if err != nil { + log.Panic().AnErr("StatsD client initialization failed", err) + } + log.Info().Msgf("Metrics client initialized with telegraf address - %s, global tags - %v, and "+ + "sampling rate - %f, flashring metrics enabled - %v", telegrafAddress, globalTags, samplingRate, metricsEnabled) + initialized = true + }) +} + +func getDefaultClient() *statsd.Client { + client, _ := statsd.New("localhost:8125") + return client +} + +func getGlobalTags() []string { + env := viper.GetString("APP_ENV") + if len(env) == 0 { + log.Warn().Msg("APP_ENV is not set") + } + service := viper.GetString("APP_NAME") + if len(service) == 0 { + log.Warn().Msg("APP_NAME is not set") + } + return []string{ + TagAsString(TagEnv, env), + TagAsString(TagService, service), + } +} + +// Timing sends timing information. No-op when metrics are disabled. 
+func Timing(name string, value time.Duration, tags []string) { + if !metricsEnabled { + return + } + err := statsDClient.Timing(name, value, tags, samplingRate) + if err != nil { + log.Warn().AnErr("Error occurred while doing statsd timing", err) + } +} + +// Count increases metric counter by value. No-op when metrics are disabled. +func Count(name string, value int64, tags []string) { + if !metricsEnabled { + return + } + err := statsDClient.Count(name, value, tags, samplingRate) + if err != nil { + log.Warn().AnErr("Error occurred while doing statsd count", err) + } +} + +// Incr increases metric counter by 1. No-op when metrics are disabled. +func Incr(name string, tags []string) { + if !metricsEnabled { + return + } + Count(name, 1, tags) +} + +// Gauge sets a gauge value. No-op when metrics are disabled. +func Gauge(name string, value float64, tags []string) { + if !metricsEnabled { + return + } + err := statsDClient.Gauge(name, value, tags, samplingRate) + if err != nil { + log.Warn().AnErr("Error occurred while doing statsd gauge", err) + } +} + +// Enabled returns whether flashring metrics are enabled. +func Enabled() bool { + return metricsEnabled +} + +func GetShardTag(shardIdx uint32) []string { + return shardTags[shardIdx : shardIdx+1] +} + +func GetMemtableTag(memtableId uint32) []string { + return BuildTag(NewTag(TAG_MEMTABLE_ID, strconv.Itoa(int(memtableId)))) +} + +func BuildShardTags(shardCount int) { + tags := make([]string, 0, shardCount) + for i := 0; i < shardCount; i++ { + tags = append(tags, BuildTag(NewTag(TAG_SHARD_IDX, strconv.Itoa(int(i))))...) 
+ } + shardTags = tags +} diff --git a/flashring/pkg/metrics/tag.go b/flashring/pkg/metrics/tag.go new file mode 100644 index 00000000..d77ac38e --- /dev/null +++ b/flashring/pkg/metrics/tag.go @@ -0,0 +1,55 @@ +package metrics + +// Tag constants +const ( + TagEnv = "env" + TagService = "service" + TagPath = "path" + TagMethod = "method" + TagHttpStatusCode = "http_status_code" + TagGrpcStatusCode = "grpc_status_code" + TagExternalService = "external_service" + TagExternalServicePath = "external_service_path" + TagExternalServiceMethod = "external_service_method" + TagExternalServiceStatusCode = "external_service_status_code" + TagZkRealtimeTotalUpdateEvent = "zk_realtime_total_update_event" + TagZkRealtimeFailureEvent = "zk_realtime_failure_event" + TagZkRealtimeSuccessEvent = "zk_realtime_success_event" + TagZkRealtimeEventUpdateLatency = "zk_realtime_event_update_latency" + TagCommunicationProtocol = "communication_protocol" + TagUserContext = "user_context" + + TagValueCommunicationProtocolHttp = "http" + TagValueCommunicationProtocolGrpc = "grpc" +) + +type Tag struct { + Name string + Value string +} + +func NewTag(name, value string) Tag { + return Tag{ + Name: name, + Value: value, + } +} + +// BuildTag builds a tag from the given name and value +func BuildTag(tags ...Tag) []string { + allTags := make([]string, 0) + for _, tag := range tags { + allTags = append(allTags, TagAsString(tag.Name, tag.Value)) + } + return allTags +} + +func TagAsString(name string, value string) string { + return name + ":" + value +} + +func UpdateTags(tags *[]string, newTags ...Tag) { + for _, tag := range newTags { + *tags = append(*tags, TagAsString(tag.Name, tag.Value)) + } +} diff --git a/flashring/prep_ssd.sh b/flashring/prep_ssd.sh new file mode 100644 index 00000000..f8e33b3e --- /dev/null +++ b/flashring/prep_ssd.sh @@ -0,0 +1,202 @@ +#!/usr/bin/env bash +# Mount all non-root NVMe SSDs (/dev/nvme*n1) as ext4 under /mnt/localssd1, /mnt/localssd2, ... 
# Uses hourly fstrim (systemd timer or cron fallback). Safe to re-run.
#
# Usage: prep_ssd.sh [--umount|--status]
#   (no argument)  format blank NVMe namespaces as ext4 and mount them
#   --umount       unmount every ${MOUNT_BASE}N and drop its /etc/fstab entry
#   --status       list the current ${MOUNT_BASE}N mounts
#
# -E (errtrace) lets the ERR trap below fire inside functions too; without it
# a failure inside a helper would abort the script with no diagnostic.
set -Eeuo pipefail

MOUNT_BASE="/mnt/localssd"

# Print a timestamped message to stdout.
log() { echo "[$(date +'%F %T')] $*"; }
# Report to stderr so the message can never end up inside a captured $(...).
trap 'log "ERROR: Command failed: $BASH_COMMAND (line $LINENO)" >&2' ERR

# ---------- Helpers ----------
# Filesystem type lsblk reports for device $1 (empty when none is detected).
fs_type() { lsblk -ndo FSTYPE "$1" 2>/dev/null | tr -d ' '; }

# Succeeds when device $1 itself is the source of some mount.
is_mounted_anywhere() { findmnt -S "$1" >/dev/null 2>&1; }

# First mount target of device $1 (empty if none). Only the first line is kept
# so a device mounted in several places still yields one usable path.
current_mountpoint() { findmnt -S "$1" -no TARGET 2>/dev/null | head -n 1 || true; }

# Source device of the root filesystem (empty if it cannot be determined).
root_source() { findmnt -no SOURCE / 2>/dev/null || true; }

# Map a partition path to its parent disk:
#   /dev/nvme0n1p2 -> /dev/nvme0n1, /dev/sda10 -> /dev/sda
# Any other path is echoed unchanged.
parent_of() {
  local src="$1"
  if [[ "$src" =~ ^(/dev/nvme[0-9]+n[0-9]+)p[0-9]+$ ]]; then
    echo "${BASH_REMATCH[1]}"
  elif [[ "$src" =~ ^(/dev/sd[a-z]+)[0-9]+$ ]]; then
    # Strip the whole partition number, not just its last digit.
    echo "${BASH_REMATCH[1]}"
  else
    echo "$src"
  fi
}

# Succeeds when $1 is the root filesystem's device or that device's parent disk.
is_boot_dev() {
  local dev="$1"
  local rsrc; rsrc="$(root_source)"
  [[ -z "$rsrc" ]] && return 1
  local rparent; rparent="$(parent_of "$rsrc")"
  [[ "$dev" == "$rsrc" || "$dev" == "$rparent" ]]
}

# Succeeds when lsblk lists anything below device $1 (partitions or other
# dependent devices). Such a disk reports no FSTYPE of its own, yet formatting
# it would destroy whatever lives on its children.
has_child_devices() {
  local rows
  rows="$(lsblk -nr -o NAME "$1" 2>/dev/null | wc -l || true)"
  (( ${rows:-0} > 1 ))
}

# Echo the first ${MOUNT_BASE}N (N = 1, 2, ...) that is not currently a mountpoint.
next_mountpoint() {
  local n=1
  while :; do
    local mp="${MOUNT_BASE}${n}"
    if ! mountpoint -q "$mp"; then
      echo "$mp"
      return 0
    fi
    ((n+=1))
  done
}

# Replace any /etc/fstab line for UUID $1 with an ext4 entry mounting it at $2.
ensure_fstab_entry() {
  local uuid="$1" mp="$2"
  local line="UUID=${uuid} ${mp} ext4 defaults,nofail,noatime,nodiratime 0 2"
  sed -i -E "/^UUID=${uuid}[[:space:]]/d" /etc/fstab 2>/dev/null || true
  # The grep only matters when the sed above failed and left an entry behind.
  grep -q "UUID=${uuid} ${mp} ext4" /etc/fstab 2>/dev/null || echo "$line" >> /etc/fstab
}

# Remove the 'discard' mount option from /mnt/localssd* ext4 lines in /etc/fstab.
# 'discard' is matched only as a complete option, so 'nodiscard' is left intact
# and no dangling comma remains when it was the first or the only option.
sanitize_fstab_discard() {
  if grep -Eq '/mnt/localssd[0-9]+[[:space:]]+ext4' /etc/fstab 2>/dev/null; then
    log "Sanitizing /etc/fstab to remove ',discard' on /mnt/localssd* entries"
    sed -i -E '/\/mnt\/localssd[0-9]+[[:space:]]+ext4/ {
      s/([[:space:],])discard,/\1/g
      s/,discard([[:space:],]|$)/\1/g
      s/([[:space:]])discard([[:space:]]|$)/\1defaults\2/g
    }' /etc/fstab
  fi
}

# Print every current mount target of the form ${MOUNT_BASE}N, one per line
# (prints nothing when there is none).
localssd_mountpoints() {
  findmnt -no TARGET | grep -E "^${MOUNT_BASE}[0-9]+$" || true
}

# Remount every ${MOUNT_BASE}N with noatime,nodiratime (best effort).
remount_localssd_no_discard() {
  local -a mps=()
  local mp
  mapfile -t mps < <(localssd_mountpoints)
  # Nothing mounted: do not run the loop with an empty path.
  (( ${#mps[@]} > 0 )) || return 0
  for mp in "${mps[@]}"; do
    log "Remounting $mp without 'discard'"
    mount -o remount,noatime,nodiratime "$mp" || true
  done
}

# Schedule an hourly fstrim: a systemd fstrim.timer override when systemctl is
# available, otherwise an /etc/cron.hourly script. Warns if fstrim is missing.
setup_fstrim_hourly() {
  if command -v systemctl >/dev/null 2>&1 && command -v fstrim >/dev/null 2>&1; then
    log "Configuring systemd fstrim.timer to run hourly"
    mkdir -p /etc/systemd/system/fstrim.timer.d
    cat >/etc/systemd/system/fstrim.timer.d/override.conf <<'EOF'
[Timer]
OnCalendar=hourly
Persistent=true
EOF
    systemctl daemon-reload
    systemctl enable --now fstrim.timer
    systemctl status fstrim.timer --no-pager -l || true
  else
    if command -v fstrim >/dev/null 2>&1; then
      log "Configuring cron.hourly for fstrim (systemd not available)"
      mkdir -p /etc/cron.hourly
      cat >/etc/cron.hourly/fstrim-localssd <<'EOF'
#!/bin/sh
/sbin/fstrim --all --quiet || /usr/sbin/fstrim --all --quiet || true
EOF
      chmod +x /etc/cron.hourly/fstrim-localssd
    else
      log "WARN: fstrim not found; install util-linux to enable trimming."
    fi
  fi
}

# ---------- Modes ----------
# --umount: unmount every ${MOUNT_BASE}N, delete the matching fstab lines, exit 0.
umount_mode() {
  local -a mps=()
  local mp
  log "Unmounting /mnt/localssd* and cleaning /etc/fstab entries"
  mapfile -t mps < <(localssd_mountpoints)
  if (( ${#mps[@]} > 0 )); then
    for mp in "${mps[@]}"; do
      log "Umount $mp"
      umount "$mp" || true
    done
  fi
  sed -i -E '/\/mnt\/localssd[0-9]+[[:space:]]+ext4/d' /etc/fstab || true
  systemctl daemon-reload || true
  log "Done. Re-run this script to mount afresh."
  exit 0
}

# --status: list the current ${MOUNT_BASE}N mounts (or "None"), exit 0.
status_mode() {
  log "Current localssd mounts:"
  findmnt -no TARGET,SOURCE,FSTYPE | grep -E "^${MOUNT_BASE}[0-9]+" || echo "None"
  exit 0
}

# Print the usage line to stderr and exit 1.
usage() {
  echo "Usage: $0 [--umount|--status]" >&2
  exit 1
}

case "${1:-}" in
  --umount) umount_mode ;;
  --status) status_mode ;;
  "") ;;
  *) usage ;;
esac

# ---------- Preconditions ----------
# findmnt and mountpoint are checked as well: without mountpoint every device
# would be assigned ${MOUNT_BASE}1, and without findmnt mounted devices would
# look unmounted.
for tool in lsblk blkid mkfs.ext4 findmnt mountpoint; do
  command -v "$tool" >/dev/null || { echo "$tool not found" >&2; exit 1; }
done

# Sync systemd with current fstab before mounts
systemctl daemon-reload || true

# Enumerate all NVMe namespaces deterministically. A glob is already sorted by
# the shell; nullglob turns "no match" into an empty array.
shopt -s nullglob
NVME_DEVS=(/dev/nvme*n1)
shopt -u nullglob
if [[ ${#NVME_DEVS[@]} -eq 0 ]]; then
  log "No NVMe namespaces (/dev/nvme*n1) found."
  exit 0
fi
log "Scanning devices: ${NVME_DEVS[*]}"

processed=0

for dev in "${NVME_DEVS[@]}"; do
  [[ -e "$dev" ]] || { log "$dev not found at runtime — skipping."; continue; }

  if is_boot_dev "$dev"; then
    log "Skipping boot/root device: $dev"
    continue
  fi

  log "Found $dev"

  if is_mounted_anywhere "$dev"; then
    mp_now="$(current_mountpoint "$dev")"
    log "$dev already mounted at $mp_now — leaving as-is."
    mounted_fs="$(findmnt -S "$dev" -no FSTYPE 2>/dev/null | head -n 1 || true)"
    uuid="$(blkid -s UUID -o value "$dev" || true)"
    if [[ "$mounted_fs" != "ext4" ]]; then
      # ensure_fstab_entry always writes an ext4 line; using it here would
      # replace a valid entry for another filesystem with a broken one.
      log "WARN: $dev is mounted as '${mounted_fs:-unknown}', not ext4; leaving /etc/fstab untouched."
    elif [[ -n "$uuid" && -n "$mp_now" ]]; then
      ensure_fstab_entry "$uuid" "$mp_now"
    fi
    ((processed+=1))
    continue
  fi

  fstype="$(fs_type "$dev")"
  if [[ -z "$fstype" ]]; then
    if has_child_devices "$dev"; then
      log "Skipping $dev (no filesystem on the whole device, but it has partitions/child devices; refusing to format)."
      continue
    fi
    log "Formatting $dev as ext4"
    mkfs.ext4 -F -m 0 -E lazy_itable_init=1,lazy_journal_init=1 "$dev"
    fstype="ext4"
  else
    log "$dev already has filesystem: $fstype"
  fi

  if [[ "$fstype" != "ext4" ]]; then
    log "Skipping $dev (unsupported fs: $fstype)."
    continue
  fi

  mp="$(next_mountpoint)"
  mkdir -p "$mp"
  log "Mounting $dev at $mp (no 'discard')"
  mount -o noatime,nodiratime "$dev" "$mp"

  uuid="$(blkid -s UUID -o value "$dev" || true)"
  if [[ -n "$uuid" ]]; then
    ensure_fstab_entry "$uuid" "$mp"
    systemctl daemon-reload || true
  else
    log "WARN: Could not read UUID for $dev; skipping fstab entry."
  fi

  ((processed+=1))
done

sanitize_fstab_discard
remount_localssd_no_discard

systemctl daemon-reload || true
setup_fstrim_hourly

if [[ "$processed" -eq 0 ]]; then
  log "No devices processed. Existing mounts sanitized and hourly fstrim scheduled (if available)."
else
  log "Done. Processed $processed device(s). Current mounts:"
  findmnt -no TARGET,SOURCE | grep "$MOUNT_BASE" || true
fi