diff --git a/cmd/root.go b/cmd/root.go index 50f9a1c..b6f1942 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -23,6 +23,8 @@ func init() { rootCmd.Flags().StringP("migration-archive", "m", "", "Path to the migration archive Example: /path/to/migration-archive.tar.gz") rootCmd.MarkFlagRequired("migration-archive") + rootCmd.Flags().IntP("threads", "t", 0, "Number of parallel goroutines for metadata processing (default: number of CPUs)") + rootCmd.SilenceErrors = true rootCmd.SilenceUsage = true } @@ -94,7 +96,8 @@ var rootCmd = &cobra.Command{ pterm.DefaultSection.Println("Remap") remapSpinner, _ := pterm.DefaultSpinner.Start("Remapping SHAs...") - stats, err := commitremap.ProcessFiles(extractedDir, commitremap.DefaultPrefixes(), commitMap) + threads, _ := cmd.Flags().GetInt("threads") + stats, err := commitremap.ProcessFiles(extractedDir, commitremap.DefaultPrefixes(), commitMap, threads) if err != nil { remapSpinner.Fail("Remap failed") renderSummaryTable(stats, extractedDir) diff --git a/go.mod b/go.mod index 1526d78..7a45969 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/mona-actions/gh-commit-remap go 1.25 require ( + github.com/klauspost/pgzip v1.2.6 github.com/pterm/pterm v0.12.83 github.com/spf13/cobra v1.8.1 ) @@ -15,6 +16,7 @@ require ( github.com/containerd/console v1.0.5 // indirect github.com/gookit/color v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/klauspost/compress v1.18.6 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/mattn/go-runewidth v0.0.20 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index 1bdec4c..9f1ecd2 100644 --- a/go.sum +++ b/go.sum @@ -33,11 +33,15 @@ github.com/gookit/color v1.6.0 h1:JjJXBTk1ETNyqyilJhkTXJYYigHG24TM9Xa2M1xAhRA= github.com/gookit/color v1.6.0/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao= +github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= +github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index baa2f9c..bd8bb91 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -8,7 +8,10 @@ import ( "io/fs" "os" "path/filepath" + "runtime" "strings" + + pgzip "github.com/klauspost/pgzip" ) // UnTar decompresses a .tar.gz file into destDir, returning the directory containing the extracted contents. @@ -147,7 +150,8 @@ func ReTarDir(srcDir, outPath string) (retErr error) { if err != nil { return fmt.Errorf("failed to create archive: %w", err) } - gzipWriter := gzip.NewWriter(outFile) + gzipWriter, _ := pgzip.NewWriterLevel(outFile, pgzip.BestSpeed) + gzipWriter.SetConcurrency(256<<10, runtime.NumCPU()) tarWriter := tar.NewWriter(gzipWriter) // The success path closes each writer explicitly (in tar -> gzip -> file order) diff --git a/pkg/commitremap/benchmark_test.go b/pkg/commitremap/benchmark_test.go new file mode 100644 index 0000000..697864c --- /dev/null +++ b/pkg/commitremap/benchmark_test.go @@ -0,0 +1,225 @@ +package commitremap + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" +) + +// ── Helpers ────────────────────────────────────────────────────────────────── + +// generateCommitMap creates a commit map with n entries of 40-char hex SHAs. +func generateCommitMap(n int) map[string]string { + m := make(map[string]string, n) + buf := make([]byte, 20) // 20 bytes = 40 hex chars + for i := 0; i < n; i++ { + rand.Read(buf) + old := hex.EncodeToString(buf) + rand.Read(buf) + new_ := hex.EncodeToString(buf) + m[old] = new_ + } + return m +} + +// generateJSONWithSHAs creates a JSON byte slice containing numObjects objects, +// each with shaFields fields containing SHAs from the commit map. +// hitRate controls what fraction of SHAs are in the commit map (0.0-1.0). +func generateJSONWithSHAs(commitMap map[string]string, numObjects, shaFields int, hitRate float64) []byte { + // Collect some real keys for hits + keys := make([]string, 0, len(commitMap)) + for k := range commitMap { + keys = append(keys, k) + if len(keys) >= numObjects*shaFields { + break + } + } + + buf := make([]byte, 20) + objects := make([]map[string]interface{}, numObjects) + hitCount := int(float64(numObjects*shaFields) * hitRate) + idx := 0 + for i := 0; i < numObjects; i++ { + obj := map[string]interface{}{ + "id": i, + "created_at": "2024-01-15T10:30:00Z", + "url": fmt.Sprintf("https://github.com/org/repo/pull/%d", i), + } + for f := 0; f < shaFields; f++ { + fieldName := fmt.Sprintf("sha_%d", f) + if idx < hitCount && len(keys) > 0 { + obj[fieldName] = keys[idx%len(keys)] + } else { + rand.Read(buf) + obj[fieldName] = hex.EncodeToString(buf) + } + idx++ + } + objects[i] = obj + } + + data, _ := json.Marshal(objects) + return data +} + +// writeJSONFixtureFiles creates numFiles JSON files in dir, each containing +// numObjects objects with SHA fields. +func writeJSONFixtureFiles(tb testing.TB, dir string, prefix string, numFiles, numObjects int, commitMap map[string]string) { + tb.Helper() + for i := 0; i < numFiles; i++ { + name := fmt.Sprintf("%s_%06d.json", prefix, i+1) + data := generateJSONWithSHAs(commitMap, numObjects, 3, 0.5) + if err := os.WriteFile(filepath.Join(dir, name), data, 0644); err != nil { + tb.Fatal(err) + } + } +} + +// ── Benchmarks ─────────────────────────────────────────────────────────────── + +// BenchmarkReplaceSHABytes measures the core byte-sliding window replacement. +// This is the innermost hot loop — called once per metadata file. +func BenchmarkReplaceSHABytes(b *testing.B) { + sizes := []struct { + name string + mapSize int + jsonObjs int + shaFields int + }{ + {"small-map/small-json", 100, 50, 2}, + {"large-map/small-json", 1_000_000, 50, 2}, + {"large-map/medium-json", 1_000_000, 500, 3}, + {"large-map/large-json", 1_000_000, 5000, 3}, + {"monorepo-scale", 1_834_000, 1000, 4}, + } + + for _, s := range sizes { + b.Run(s.name, func(b *testing.B) { + commitMap := generateCommitMap(s.mapSize) + data := generateJSONWithSHAs(commitMap, s.jsonObjs, s.shaFields, 0.5) + shaLen := 40 + + // Pre-allocate a working buffer to avoid measuring make+copy overhead + input := make([]byte, len(data)) + b.SetBytes(int64(len(data))) + for b.Loop() { + copy(input, data) + replaceSHABytes(input, commitMap, shaLen) + } + }) + } +} + +// BenchmarkReplaceSHABytes_NoHits measures scanning overhead when no SHAs match. +func BenchmarkReplaceSHABytes_NoHits(b *testing.B) { + commitMap := generateCommitMap(1_834_000) + data := generateJSONWithSHAs(commitMap, 1000, 4, 0.0) + differentMap := generateCommitMap(1_834_000) + shaLen := 40 + + input := make([]byte, len(data)) + b.SetBytes(int64(len(data))) + for b.Loop() { + copy(input, data) + replaceSHABytes(input, differentMap, shaLen) + } +} + +// BenchmarkUpdateMetadataFile measures single-file remap (read + replace + write). +func BenchmarkUpdateMetadataFile(b *testing.B) { + commitMap := generateCommitMap(1_834_000) + shaLen := 40 + data := generateJSONWithSHAs(commitMap, 500, 3, 0.5) + + dir := b.TempDir() + filePath := filepath.Join(dir, "test_000001.json") + + b.SetBytes(int64(len(data))) + for b.Loop() { + os.WriteFile(filePath, data, 0644) + updateMetadataFile(filePath, commitMap, shaLen) + } +} + +// BenchmarkParseCommitMap measures parsing a commit-map file. +func BenchmarkParseCommitMap(b *testing.B) { + sizes := []struct { + name string + n int + }{ + {"1K", 1_000}, + {"100K", 100_000}, + {"1M", 1_000_000}, + {"1.8M", 1_834_000}, + } + + for _, s := range sizes { + b.Run(s.name, func(b *testing.B) { + commitMap := generateCommitMap(s.n) + dir := b.TempDir() + filePath := filepath.Join(dir, "commit-map") + + // Write commit-map file + f, _ := os.Create(filePath) + fmt.Fprintln(f, "old new") + for old, new_ := range commitMap { + fmt.Fprintf(f, "%s %s\n", old, new_) + } + f.Close() + + fi, _ := os.Stat(filePath) + b.SetBytes(fi.Size()) + for b.Loop() { + ParseCommitMap(filePath) + } + }) + } +} + +// BenchmarkProcessFiles measures the full parallel pipeline. +func BenchmarkProcessFiles(b *testing.B) { + configs := []struct { + name string + mapSize int + numFiles int + objPerFile int + workers int + }{ + {"10-files/8-workers", 100_000, 10, 200, 8}, + {"100-files/8-workers", 100_000, 100, 200, 8}, + {"100-files/16-workers", 100_000, 100, 200, 16}, + } + + for _, c := range configs { + b.Run(c.name, func(b *testing.B) { + commitMap := generateCommitMap(c.mapSize) + + // Create fixture dir once + baseDir := b.TempDir() + writeJSONFixtureFiles(b, baseDir, "pull_requests", c.numFiles, c.objPerFile, commitMap) + + b.ResetTimer() + for b.Loop() { + b.StopTimer() + writeJSONFixtureFiles(b, baseDir, "pull_requests", c.numFiles, c.objPerFile, commitMap) + b.StartTimer() + + ProcessFiles(baseDir, []string{"pull_requests"}, commitMap, c.workers) + } + }) + } +} + +// BenchmarkIsHexByte measures the hex byte check (called per byte in the hot loop). +func BenchmarkIsHexByte(b *testing.B) { + inputs := []byte("0123456789abcdefABCDEF!@#$%^&*()ghijklmnopqrstuvwxyz") + for b.Loop() { + for _, c := range inputs { + isHexByte(c) + } + } +} diff --git a/pkg/commitremap/commitremap.go b/pkg/commitremap/commitremap.go index dd5eb01..31087d1 100644 --- a/pkg/commitremap/commitremap.go +++ b/pkg/commitremap/commitremap.go @@ -2,18 +2,28 @@ package commitremap import ( - "encoding/json" "fmt" "os" "path/filepath" + "runtime" "strings" + "sync" ) // DefaultPrefixes returns the set of archive metadata file prefixes that // gh-commit-remap rewrites by default. A fresh slice is returned on each // call so callers can mutate the result without affecting other callers. func DefaultPrefixes() []string { - return []string{"pull_requests", "issues", "issue_events"} + return []string{ + "issues", + "issue_events", + "issue_comments", + "pull_requests", + "pull_request_reviews", + "pull_request_review_comments", + "pull_request_review_threads", + "commit_comments", + } } type invalidCommitMapLineError struct { @@ -39,7 +49,7 @@ func ParseCommitMap(filePath string) (map[string]string, error) { return nil, fmt.Errorf("reading commit map %s: %w", filePath, err) } - for _, line := range strings.Split(string(content), "\n") { + for i, line := range strings.Split(string(content), "\n") { if strings.TrimSpace(line) == "" { continue } @@ -50,6 +60,11 @@ func ParseCommitMap(filePath string) (map[string]string, error) { return nil, fmt.Errorf("invalid commit map line: %w", lineErr) } + // Skip the header line produced by git-filter-repo ("old new") + if i == 0 && fields[0] == "old" && fields[1] == "new" { + continue + } + commitMap[fields[0]] = fields[1] } @@ -58,57 +73,89 @@ func ParseCommitMap(filePath string) (map[string]string, error) { // ProcessFiles rewrites SHAs in JSON metadata files matching _*.json inside archiveDir. // -// Each file is walked once, replacing string values that exactly match a key in -// commitMap. Only whole-string SHA values are replaced. SHAs embedded in URLs, -// markdown, or composite strings are not rewritten. -func ProcessFiles(archiveDir string, prefixes []string, commitMap map[string]string) (Stats, error) { +// Each file is scanned byte-by-byte using a sliding window that matches +// SHA-length hex sequences against the commit map. SHAs are replaced +// wherever they appear — including inside URLs, markdown, etc. +// +// numWorkers controls how many goroutines process files in parallel. +// If numWorkers <= 0, it defaults to runtime.NumCPU(). +func ProcessFiles(archiveDir string, prefixes []string, commitMap map[string]string, numWorkers int) (Stats, error) { stats := Stats{PerFile: make(map[string]int)} + shaLen, err := commitMapSHALen(commitMap) + if err != nil { + return stats, fmt.Errorf("validating commit map: %w", err) + } + + if numWorkers <= 0 { + numWorkers = runtime.NumCPU() + } + + // Collect all files to process + var allFiles []string for _, prefix := range prefixes { pattern := filepath.Join(archiveDir, prefix+"_*.json") files, err := filepath.Glob(pattern) if err != nil { return stats, fmt.Errorf("globbing %s: %w", pattern, err) } + allFiles = append(allFiles, files...) + } - for _, file := range files { - stats.FilesScanned++ - n, err := updateMetadataFile(file, commitMap) - if err != nil { - return stats, fmt.Errorf("updating metadata file %s: %w", file, err) - } - if n > 0 { - stats.PerFile[file] = n + stats.FilesScanned = len(allFiles) + + type fileResult struct { + file string + count int + err error + } + + results := make([]fileResult, len(allFiles)) + workCh := make(chan int, len(allFiles)) + for i := range allFiles { + workCh <- i + } + close(workCh) + + var wg sync.WaitGroup + for w := 0; w < numWorkers; w++ { + wg.Add(1) + go func() { + defer wg.Done() + for idx := range workCh { + n, err := updateMetadataFile(allFiles[idx], commitMap, shaLen) + results[idx] = fileResult{file: allFiles[idx], count: n, err: err} } + }() + } + wg.Wait() + + // Merge results in order, returning partial stats on first error + var firstErr error + for _, res := range results { + if res.count > 0 { + stats.PerFile[res.file] = res.count + } + if res.err != nil && firstErr == nil { + firstErr = fmt.Errorf("updating metadata file %s: %w", res.file, res.err) } } - return stats, nil + return stats, firstErr } -func updateMetadataFile(filePath string, commitMap map[string]string) (int, error) { +func updateMetadataFile(filePath string, commitMap map[string]string, shaLen int) (int, error) { data, err := os.ReadFile(filePath) if err != nil { return 0, fmt.Errorf("reading data: %w", err) } - var dataMap interface{} - err = json.Unmarshal(data, &dataMap) - if err != nil { - return 0, fmt.Errorf("unmarshaling data: %w", err) - } - - count := replaceSHA(dataMap, commitMap) + data, count := replaceSHABytes(data, commitMap, shaLen) if count == 0 { return 0, nil } - updatedData, err := json.MarshalIndent(dataMap, "", " ") - if err != nil { - return count, fmt.Errorf("marshaling updated data: %w", err) - } - - err = os.WriteFile(filePath, updatedData, 0644) + err = os.WriteFile(filePath, data, 0644) if err != nil { return count, fmt.Errorf("writing updated data: %w", err) } @@ -116,37 +163,87 @@ func updateMetadataFile(filePath string, commitMap map[string]string) (int, erro return count, nil } -// replaceSHA walks data in place, rewriting whole-string values that match a -// key in commitMap. It returns the number of replacements performed. -func replaceSHA(data interface{}, commitMap map[string]string) int { - count := 0 - switch v := data.(type) { - case map[string]interface{}: - for key, value := range v { - if str, ok := value.(string); ok { - if newSHA, hit := commitMap[str]; hit { - v[key] = newSHA - count++ - } - continue +var hexTable [256]bool + +// init initializes the hexTable with valid hexadecimal characters (valid sha1 and sha256 characters). +func init() { + for _, b := range []byte("0123456789abcdefABCDEF") { + hexTable[b] = true + } +} + +// isHexByte reports whether b is a valid hexadecimal byte (0-9, a-f, A-F). +// Uses a precomputed lookup table for branchless evaluation. +func isHexByte(b byte) bool { + return hexTable[b] +} + +// commitMapSHALen returns the SHA length common to every key in commitMap. +// It returns an error if the map is empty or if keys/values have different lengths. +func commitMapSHALen(commitMap map[string]string) (int, error) { + shaLen := 0 + for old, new_ := range commitMap { + if shaLen == 0 { + shaLen = len(old) + if shaLen == 0 { + return 0, fmt.Errorf("commit map contains an empty key") } + } + if len(old) != shaLen || len(new_) != shaLen { + return 0, fmt.Errorf("commit map SHAs have inconsistent lengths: expected %d, got key len %d / value len %d", shaLen, len(old), len(new_)) + } + } + if shaLen == 0 { + return 0, fmt.Errorf("commit map is empty") + } + return shaLen, nil +} - count += replaceSHA(value, commitMap) +// replaceSHABytes scans data byte-by-byte using a sliding window of shaLen. +// +// Algorithm: +// 1. Walk each byte, counting consecutive valid hex (SHA) bytes. +// 2. When a non-hex byte is hit, reset the counter, no SHA can span it. +// 3. Once we have shaLen consecutive hex bytes, extract that window and +// look it up in commitMap. +// 4. On match: replace in-place, skip past the replaced bytes. The next +// window starts fresh from the byte after the replacement, avoiding +// re-scanning the bytes we just wrote. +// 5. On no match: keep going. The counter grows past shaLen so the +// window slides forward by one byte each step, checking every +// overlapping shaLen-sized substring. For example with shaLen=40, +// if bytes 0–39 don't match, bytes 1–40 are checked next, etc. +// +// Returns the (potentially modified) byte slice and the replacement count. +func replaceSHABytes(data []byte, commitMap map[string]string, shaLen int) ([]byte, int) { + count := 0 + consecutiveHex := 0 + + for i := 0; i < len(data); i++ { + if isHexByte(data[i]) { + consecutiveHex++ + } else { + // Non-hex byte breaks any potential SHA sequence. + consecutiveHex = 0 + continue } - case []interface{}: - for i, value := range v { - if str, ok := value.(string); ok { - if newSHA, hit := commitMap[str]; hit { - v[i] = newSHA - count++ - } - continue - } - count += replaceSHA(value, commitMap) + // Once we have enough consecutive hex bytes, check if the last + // shaLen bytes match an entry in the commit map. + if consecutiveHex >= shaLen { + start := i - shaLen + 1 + candidate := string(data[start : i+1]) + if newSHA, ok := commitMap[candidate]; ok { + copy(data[start:i+1], newSHA) + count++ + consecutiveHex = 0 + } + // If no match, consecutiveHex keeps growing and the window + // slides forward on the next iteration. } } - return count + + return data, count } // summarize the work performed by a ProcessFiles call. diff --git a/pkg/commitremap/commitremap_test.go b/pkg/commitremap/commitremap_test.go index 7e269a4..f905613 100644 --- a/pkg/commitremap/commitremap_test.go +++ b/pkg/commitremap/commitremap_test.go @@ -91,6 +91,33 @@ func TestParseCommitMap(t *testing.T) { content: "oldSHA1 newSHA1\noldSHA2 newSHA2 extra\noldSHA3 newSHA3", errContains: "oldSHA2 newSHA2 extra", }, + { + name: "skips git-filter-repo header line", + content: "old new\n" + + "abc123 def456\n" + + "ghi789 jkl012", + expected: map[string]string{ + "abc123": "def456", + "ghi789": "jkl012", + }, + }, + { + name: "header with trailing CRLF", + content: "old new\r\n" + + "abc123 def456\r\n", + expected: map[string]string{ + "abc123": "def456", + }, + }, + { + name: "old new as data when not on first line", + content: "abc123 def456\n" + + "old new", + expected: map[string]string{ + "abc123": "def456", + "old": "new", + }, + }, } for _, tt := range tests { @@ -139,30 +166,181 @@ func TestParseCommitMap(t *testing.T) { }) } +func TestIsHexByte(t *testing.T) { + valid := "0123456789abcdefABCDEF" + for _, b := range []byte(valid) { + if !isHexByte(b) { + t.Fatalf("isHexByte(%q) = false, want true", b) + } + } + invalid := "ghijklGHIJKL!@#$%^&*() \t\n{}\"/:" + for _, b := range []byte(invalid) { + if isHexByte(b) { + t.Fatalf("isHexByte(%q) = true, want false", b) + } + } +} + +func TestCommitMapSHALen(t *testing.T) { + tests := []struct { + name string + commitMap map[string]string + wantLen int + errContains string + }{ + { + name: "40-char SHAs", + commitMap: map[string]string{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}, + wantLen: 40, + }, + { + name: "empty map", + commitMap: map[string]string{}, + errContains: "empty", + }, + { + name: "inconsistent lengths", + commitMap: map[string]string{"aabb": "ccdd", "aabbcc": "ddeeff"}, + errContains: "inconsistent", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := commitMapSHALen(tt.commitMap) + if tt.errContains != "" { + if err == nil { + t.Fatalf("expected error containing %q", tt.errContains) + } + if !strings.Contains(err.Error(), tt.errContains) { + t.Fatalf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.wantLen { + t.Fatalf("commitMapSHALen = %d, want %d", got, tt.wantLen) + } + }) + } +} + +func TestReplaceSHABytes(t *testing.T) { + tests := []struct { + name string + input string + commitMap map[string]string + shaLen int + wantOut string + wantCount int + }{ + { + name: "exact SHA replaced", + input: `{"sha":"aabbccdd"}`, + commitMap: map[string]string{"aabbccdd": "11223344"}, + shaLen: 8, + wantOut: `{"sha":"11223344"}`, + wantCount: 1, + }, + { + name: "SHA in URL replaced", + input: `{"url":"https://example.com/commit/aabbccdd/details"}`, + commitMap: map[string]string{"aabbccdd": "11223344"}, + shaLen: 8, + wantOut: `{"url":"https://example.com/commit/11223344/details"}`, + wantCount: 1, + }, + { + name: "SHA in markdown replaced", + input: `{"body":"Fixed in aabbccdd, see also eeff0011"}`, + commitMap: map[string]string{"aabbccdd": "11223344", "eeff0011": "55667788"}, + shaLen: 8, + wantOut: `{"body":"Fixed in 11223344, see also 55667788"}`, + wantCount: 2, + }, + { + name: "non-hex byte breaks window", + input: `aabbXccdd`, + commitMap: map[string]string{"aabbccdd": "11223344"}, + shaLen: 8, + wantOut: `aabbXccdd`, + wantCount: 0, + }, + { + name: "no match leaves data unchanged", + input: `{"sha":"aabbccdd"}`, + commitMap: map[string]string{"11223344": "55667788"}, + shaLen: 8, + wantOut: `{"sha":"aabbccdd"}`, + wantCount: 0, + }, + { + name: "adjacent SHAs both replaced", + input: `aabbccddeeff0011`, + commitMap: map[string]string{"aabbccdd": "11111111", "eeff0011": "22222222"}, + shaLen: 8, + wantOut: `1111111122222222`, + wantCount: 2, + }, + { + name: "40-char SHA replacement", + input: `commit aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa done`, + commitMap: map[string]string{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}, + shaLen: 40, + wantOut: `commit bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb done`, + wantCount: 1, + }, + { + name: "multiple occurrences of same SHA", + input: `aabbccdd and aabbccdd`, + commitMap: map[string]string{"aabbccdd": "11223344"}, + shaLen: 8, + wantOut: `11223344 and 11223344`, + wantCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := []byte(tt.input) + out, count := replaceSHABytes(data, tt.commitMap, tt.shaLen) + if string(out) != tt.wantOut { + t.Fatalf("output = %q, want %q", string(out), tt.wantOut) + } + if count != tt.wantCount { + t.Fatalf("count = %d, want %d", count, tt.wantCount) + } + }) + } +} + +// Use 40-char hex SHAs for ProcessFiles tests since commitMapSHALen validates. +var testCommitMap = map[string]string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "1111111111111111111111111111111111111111", + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb": "2222222222222222222222222222222222222222", + "cccccccccccccccccccccccccccccccccccccccc": "3333333333333333333333333333333333333333", +} + func TestProcessFiles(t *testing.T) { t.Run("happy path", func(t *testing.T) { dir := t.TempDir() - commitMap := map[string]string{ - "oldSHA1": "newSHA1", - "oldSHA2": "newSHA2", - "oldSHA3": "newSHA3", - } fixtures := map[string]struct { input string want string }{ "pull_requests_000001.json": { - input: `{"sha":"oldSHA1","url":"https://example.invalid/oldSHA1","nested":[{"head":"oldSHA2","body":"mention oldSHA3"}],"untouched":"keep"}`, - want: `{"sha":"newSHA1","url":"https://example.invalid/oldSHA1","nested":[{"head":"newSHA2","body":"mention oldSHA3"}],"untouched":"keep"}`, + input: `{"sha":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa","url":"https://example.invalid/bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb","nested":[{"head":"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb","body":"mention cccccccccccccccccccccccccccccccccccccccc"}],"untouched":"keep"}`, + want: `{"sha":"1111111111111111111111111111111111111111","url":"https://example.invalid/2222222222222222222222222222222222222222","nested":[{"head":"2222222222222222222222222222222222222222","body":"mention 3333333333333333333333333333333333333333"}],"untouched":"keep"}`, }, "issues_000001.json": { - input: `[{"events":[{"commit_id":"oldSHA2"},{"commit_id":"unknownSHA"}],"title":"oldSHA2 in title"}]`, - want: `[{"events":[{"commit_id":"newSHA2"},{"commit_id":"unknownSHA"}],"title":"oldSHA2 in title"}]`, + input: `[{"events":[{"commit_id":"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"},{"commit_id":"dddddddddddddddddddddddddddddddddddddddd"}],"title":"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb in title"}]`, + want: `[{"events":[{"commit_id":"2222222222222222222222222222222222222222"},{"commit_id":"dddddddddddddddddddddddddddddddddddddddd"}],"title":"2222222222222222222222222222222222222222 in title"}]`, }, "issue_events_000001.json": { - input: `{"items":[{"payload":{"before":"oldSHA3","after":"oldSHA1"}}],"count":1}`, - want: `{"items":[{"payload":{"before":"newSHA3","after":"newSHA1"}}],"count":1}`, + input: `{"items":[{"payload":{"before":"cccccccccccccccccccccccccccccccccccccccc","after":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}],"count":1}`, + want: `{"items":[{"payload":{"before":"3333333333333333333333333333333333333333","after":"1111111111111111111111111111111111111111"}}],"count":1}`, }, } @@ -173,13 +351,19 @@ func TestProcessFiles(t *testing.T) { writeFile(t, p, fixture.input) } - stats, err := ProcessFiles(dir, DefaultPrefixes(), commitMap) + stats, err := ProcessFiles(dir, DefaultPrefixes(), testCommitMap, 0) if err != nil { t.Fatalf("ProcessFiles returned error: %v", err) } for name, fixture := range fixtures { - assertJSONFileEqual(t, filepath.Join(dir, name), fixture.want) + got, err := os.ReadFile(filepath.Join(dir, name)) + if err != nil { + t.Fatalf("read %s: %v", name, err) + } + if string(got) != fixture.want { + t.Fatalf("file %s = %q, want %q", name, string(got), fixture.want) + } } if got, want := stats.FilesScanned, 3; got != want { @@ -188,10 +372,10 @@ func TestProcessFiles(t *testing.T) { if got, want := stats.FilesChanged(), 3; got != want { t.Fatalf("FilesChanged() = %d, want %d", got, want) } - // pull_requests: sha=oldSHA1, nested[0].head=oldSHA2 -> 2 - // issues: events[0].commit_id=oldSHA2 -> 1 - // issue_events: payload.before=oldSHA3, payload.after=oldSHA1 -> 2 - if got, want := stats.TotalReplacements(), 5; got != want { + // pull_requests: sha + url + nested.head + nested.body = 4 + // issues: events[0].commit_id + title = 2 + // issue_events: payload.before + payload.after = 2 + if got, want := stats.TotalReplacements(), 8; got != want { t.Fatalf("TotalReplacements() = %d, want %d", got, want) } for name := range fixtures { @@ -204,18 +388,18 @@ func TestProcessFiles(t *testing.T) { t.Run("empty commit map", func(t *testing.T) { dir := t.TempDir() filePath := filepath.Join(dir, "pull_requests_000001.json") - want := `{"sha":"oldSHA1","nested":[{"sha":"oldSHA2"}]}` + want := `{"sha":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa","nested":[{"sha":"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}]}` writeFile(t, filePath, want) - if _, err := ProcessFiles(dir, []string{"pull_requests"}, map[string]string{}); err != nil { - t.Fatalf("ProcessFiles returned error: %v", err) + _, err := ProcessFiles(dir, []string{"pull_requests"}, map[string]string{}, 0) + if err == nil { + t.Fatal("expected error for empty commit map") } - - assertJSONFileEqual(t, filePath, want) }) t.Run("no matching files", func(t *testing.T) { - if _, err := ProcessFiles(t.TempDir(), DefaultPrefixes(), map[string]string{"oldSHA1": "newSHA1"}); err != nil { + _, err := ProcessFiles(t.TempDir(), DefaultPrefixes(), testCommitMap, 0) + if err != nil { t.Fatalf("ProcessFiles returned error: %v", err) } }) @@ -224,16 +408,22 @@ func TestProcessFiles(t *testing.T) { dir := t.TempDir() fooPath := filepath.Join(dir, "foo_000001.json") pullPath := filepath.Join(dir, "pull_requests_000001.json") - writeFile(t, fooPath, `{"sha":"oldSHA1"}`) - writeFile(t, pullPath, `{"sha":"oldSHA1"}`) + writeFile(t, fooPath, `{"sha":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`) + writeFile(t, pullPath, `{"sha":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`) - stats, err := ProcessFiles(dir, []string{"foo"}, map[string]string{"oldSHA1": "newSHA1"}) + stats, err := ProcessFiles(dir, []string{"foo"}, testCommitMap, 0) if err != nil { t.Fatalf("ProcessFiles returned error: %v", err) } - assertJSONFileEqual(t, fooPath, `{"sha":"newSHA1"}`) - assertJSONFileEqual(t, pullPath, `{"sha":"oldSHA1"}`) + got, _ := os.ReadFile(fooPath) + if !strings.Contains(string(got), "1111111111111111111111111111111111111111") { + t.Fatalf("foo file should have SHA replaced, got %s", string(got)) + } + got2, _ := os.ReadFile(pullPath) + if strings.Contains(string(got2), "1111111111111111111111111111111111111111") { + t.Fatal("pull_requests file should NOT have been processed with custom prefix") + } if got, want := len(stats.PerFile), 1; got != want { t.Fatalf("len(stats.PerFile) = %d, want %d", got, want) @@ -249,37 +439,53 @@ func TestProcessFiles(t *testing.T) { t.Run("single-pass behavior remaps all keys", func(t *testing.T) { dir := t.TempDir() filePath := filepath.Join(dir, "pull_requests_000001.json") - writeFile(t, filePath, `{"items":[{"sha":"oldSHA1"},{"nested":{"sha":"oldSHA2","children":["oldSHA3","oldSHA4"]}}]}`) + writeFile(t, filePath, `{"items":[{"sha":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"},{"nested":{"sha":"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb","children":["cccccccccccccccccccccccccccccccccccccccc"]}}]}`) - commitMap := map[string]string{ - "oldSHA1": "newSHA1", - "oldSHA2": "newSHA2", - "oldSHA3": "newSHA3", - "oldSHA4": "newSHA4", - } - if _, err := ProcessFiles(dir, []string{"pull_requests"}, commitMap); err != nil { + if _, err := ProcessFiles(dir, []string{"pull_requests"}, testCommitMap, 0); err != nil { t.Fatalf("ProcessFiles returned error: %v", err) } - assertJSONFileEqual(t, filePath, `{"items":[{"sha":"newSHA1"},{"nested":{"sha":"newSHA2","children":["newSHA3","newSHA4"]}}]}`) + got, _ := os.ReadFile(filePath) + gotStr := string(got) + for _, expected := range []string{"1111111111111111111111111111111111111111", "2222222222222222222222222222222222222222", "3333333333333333333333333333333333333333"} { + if !strings.Contains(gotStr, expected) { + t.Fatalf("expected %s in output, got %s", expected, gotStr) + } + } }) - t.Run("non matching strings unchanged", func(t *testing.T) { + t.Run("SHAs in URLs and markdown are now replaced", func(t *testing.T) { dir := t.TempDir() filePath := filepath.Join(dir, "issues_000001.json") - original := `{"title":"no SHA here","body":"https://example.invalid/oldSHA1","labels":["bug","help wanted"]}` - writeFile(t, filePath, original) + input := `{"title":"no SHA here","body":"https://example.invalid/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa","labels":["bug","help wanted"]}` + writeFile(t, filePath, input) - if _, err := ProcessFiles(dir, []string{"issues"}, map[string]string{"oldSHA1": "newSHA1"}); err != nil { + stats, err := ProcessFiles(dir, []string{"issues"}, testCommitMap, 0) + if err != nil { t.Fatalf("ProcessFiles returned error: %v", err) } - assertJSONFileEqual(t, filePath, original) + got, _ := os.ReadFile(filePath) + if !strings.Contains(string(got), "1111111111111111111111111111111111111111") { + t.Fatalf("SHA in URL should be replaced, got %s", string(got)) + } + if stats.TotalReplacements() != 1 { + t.Fatalf("TotalReplacements() = %d, want 1", stats.TotalReplacements()) + } }) } func TestDefaultPrefixes(t *testing.T) { - want := []string{"pull_requests", "issues", "issue_events"} + want := []string{ + "issues", + "issue_events", + "issue_comments", + "pull_requests", + "pull_request_reviews", + "pull_request_review_comments", + "pull_request_review_threads", + "commit_comments", + } got := DefaultPrefixes() if !reflect.DeepEqual(got, want) { t.Fatalf("DefaultPrefixes() = %#v, want %#v", got, want) @@ -296,7 +502,7 @@ func TestDefaultPrefixes(t *testing.T) { func TestProcessFiles_SkipsWriteWhenNoReplacements(t *testing.T) { dir := t.TempDir() filePath := filepath.Join(dir, "pull_requests_000001.json") - original := []byte(`{"sha":"someSHA","nested":[{"sha":"otherSHA"}]}`) + original := []byte(`{"sha":"dddddddddddddddddddddddddddddddddddddddd","nested":[{"sha":"eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"}]}`) if err := os.WriteFile(filePath, original, 0644); err != nil { t.Fatalf("write fixture: %v", err) } @@ -305,7 +511,11 @@ func TestProcessFiles_SkipsWriteWhenNoReplacements(t *testing.T) { t.Fatalf("stat before: %v", err) } - stats, err := ProcessFiles(dir, []string{"pull_requests"}, map[string]string{"unrelated": "x"}) + // commitMap doesn't contain the SHAs in the file + noMatchMap := map[string]string{ + "ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00": "1100110011001100110011001100110011001100", + } + stats, err := ProcessFiles(dir, []string{"pull_requests"}, noMatchMap, 0) if err != nil { t.Fatalf("ProcessFiles returned error: %v", err) } @@ -337,29 +547,6 @@ func TestProcessFiles_SkipsWriteWhenNoReplacements(t *testing.T) { } } -func TestProcessFiles_ReturnsPartialStatsOnError(t *testing.T) { - dir := t.TempDir() - // filepath.Glob returns sorted results, so pull_requests_000001.json is processed before pull_requests_000002.json. - validPath := filepath.Join(dir, "pull_requests_000001.json") - badPath := filepath.Join(dir, "pull_requests_000002.json") - writeFile(t, validPath, `{"sha":"oldSHA1"}`) - writeFile(t, badPath, `{not valid json`) - - stats, err := ProcessFiles(dir, []string{"pull_requests"}, map[string]string{"oldSHA1": "newSHA1"}) - if err == nil { - t.Fatal("expected error from malformed JSON file") - } - if stats.FilesScanned != 2 { - t.Fatalf("FilesScanned = %d, want 2", stats.FilesScanned) - } - if len(stats.PerFile) < 1 { - t.Fatalf("len(PerFile) = %d, want >= 1", len(stats.PerFile)) - } - if _, ok := stats.PerFile[validPath]; !ok { - t.Fatalf("PerFile must contain validPath %q; got %#v", validPath, stats.PerFile) - } -} - func TestStats_HelperMethods(t *testing.T) { s := Stats{FilesScanned: 5, PerFile: map[string]int{"a": 2, "b": 3}} if got, want := s.FilesChanged(), 2; got != want { diff --git a/pkg/commitremap/stream.go b/pkg/commitremap/stream.go new file mode 100644 index 0000000..7edd54b --- /dev/null +++ b/pkg/commitremap/stream.go @@ -0,0 +1,224 @@ +package commitremap + +import ( + "archive/tar" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "strings" + + pgzip "github.com/klauspost/pgzip" +) + +const maxMatchedFileSize = 2 << 30 // 2 GB guard for matched entries + +// StreamRemap reads a .tar.gz migration archive, remaps SHAs in matching +// JSON metadata files in-flight, and writes the result to a new .tar.gz. +// This eliminates the extract→modify→retar cycle entirely. +// +// Files whose base name matches "_.json" for any prefix in +// prefixes are read into memory, processed by replaceSHABytes, and written +// back. All other entries are streamed through unchanged. +// +// The output archive uses parallel gzip (pgzip) at BestSpeed for fast +// compression. Tar entry order from the input is preserved. +func StreamRemap(inArchive, outArchive string, commitMap map[string]string, prefixes []string) (Stats, error) { + stats := Stats{PerFile: make(map[string]int)} + + if len(commitMap) == 0 { + return stats, fmt.Errorf("commit map is empty; nothing to remap") + } + + shaLen, err := commitMapSHALen(commitMap) + if err != nil { + return stats, fmt.Errorf("validating commit map: %w", err) + } + + prefixSet := make(map[string]bool, len(prefixes)) + for _, p := range prefixes { + prefixSet[p] = true + } + + // Open input tar.gz + inFile, err := os.Open(inArchive) + if err != nil { + return stats, fmt.Errorf("open input archive: %w", err) + } + defer inFile.Close() + + gzReader, err := gzip.NewReader(inFile) + if err != nil { + return stats, fmt.Errorf("create gzip reader: %w", err) + } + defer gzReader.Close() + + tarReader := tar.NewReader(gzReader) + + // Open output tar.gz + outFile, err := os.Create(outArchive) + if err != nil { + return stats, fmt.Errorf("create output archive: %w", err) + } + + gzWriter, _ := pgzip.NewWriterLevel(outFile, pgzip.BestSpeed) + gzWriter.SetConcurrency(256<<10, runtime.NumCPU()) + tarWriter := tar.NewWriter(gzWriter) + + var tarClosed, gzClosed, fileClosed bool + cleanup := func(retErr error) { + if !tarClosed { + _ = tarWriter.Close() + } + if !gzClosed { + _ = gzWriter.Close() + } + if !fileClosed { + _ = outFile.Close() + } + if retErr != nil { + _ = os.Remove(outArchive) + } + } + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + cleanup(err) + return stats, fmt.Errorf("reading tar entry: %w", err) + } + + // Safety validation (matching UnTar/ReTarDir behavior) + if filepath.IsAbs(header.Name) { + err := fmt.Errorf("archive entry %q has absolute path", header.Name) + cleanup(err) + return stats, err + } + if pathHasParentRef(header.Name) { + err := fmt.Errorf("archive entry %q escapes destination", header.Name) + cleanup(err) + return stats, err + } + if header.Typeflag != tar.TypeDir && header.Typeflag != tar.TypeReg { + err := fmt.Errorf("unsupported tar entry type %d for %q", header.Typeflag, header.Name) + cleanup(err) + return stats, err + } + + // Clone the header to avoid aliasing + hdr := *header + + if hdr.Typeflag == tar.TypeDir { + if err := tarWriter.WriteHeader(&hdr); err != nil { + cleanup(err) + return stats, fmt.Errorf("writing dir header %q: %w", hdr.Name, err) + } + continue + } + + // Check if this file matches a SHA-bearing prefix + if hdr.Size >= 0 && shouldRemap(hdr.Name, prefixSet) { + if hdr.Size > maxMatchedFileSize { + err := fmt.Errorf("matched file %q is too large (%d bytes, max %d)", hdr.Name, hdr.Size, maxMatchedFileSize) + cleanup(err) + return stats, err + } + + data, err := io.ReadAll(tarReader) + if err != nil { + cleanup(err) + return stats, fmt.Errorf("reading matched file %q: %w", hdr.Name, err) + } + + data, count := replaceSHABytes(data, commitMap, shaLen) + stats.FilesScanned++ + if count > 0 { + stats.PerFile[hdr.Name] = count + } + + // Size is unchanged (same-length SHA replacement), but set it + // from actual data length for correctness. + hdr.Size = int64(len(data)) + + if err := tarWriter.WriteHeader(&hdr); err != nil { + cleanup(err) + return stats, fmt.Errorf("writing header for %q: %w", hdr.Name, err) + } + if _, err := tarWriter.Write(data); err != nil { + cleanup(err) + return stats, fmt.Errorf("writing data for %q: %w", hdr.Name, err) + } + } else { + // Pass through unchanged + if err := tarWriter.WriteHeader(&hdr); err != nil { + cleanup(err) + return stats, fmt.Errorf("writing passthrough header %q: %w", hdr.Name, err) + } + if hdr.Size > 0 { + if _, err := io.Copy(tarWriter, tarReader); err != nil { + cleanup(err) + return stats, fmt.Errorf("copying passthrough file %q: %w", hdr.Name, err) + } + } + } + } + + // Close in order: tar → gzip → file + if err := tarWriter.Close(); err != nil { + cleanup(err) + return stats, fmt.Errorf("closing tar writer: %w", err) + } + tarClosed = true + if err := gzWriter.Close(); err != nil { + cleanup(err) + return stats, fmt.Errorf("closing gzip writer: %w", err) + } + gzClosed = true + if err := outFile.Close(); err != nil { + cleanup(err) + return stats, fmt.Errorf("closing output file: %w", err) + } + fileClosed = true + + return stats, nil +} + +// shouldRemap checks if a tar entry name matches "_.json" +// for any prefix in the set. +func shouldRemap(name string, prefixSet map[string]bool) bool { + base := filepath.Base(name) + if !strings.HasSuffix(base, ".json") { + return false + } + stem := strings.TrimSuffix(base, ".json") + idx := strings.LastIndex(stem, "_") + if idx <= 0 { + return false + } + suffix := stem[idx+1:] + if len(suffix) == 0 { + return false + } + for _, r := range suffix { + if r < '0' || r > '9' { + return false + } + } + prefix := stem[:idx] + return prefixSet[prefix] +} + +// pathHasParentRef checks for ".." path components. +func pathHasParentRef(name string) bool { + for _, part := range strings.Split(filepath.ToSlash(name), "/") { + if part == ".." { + return true + } + } + return false +} diff --git a/pkg/commitremap/stream_test.go b/pkg/commitremap/stream_test.go new file mode 100644 index 0000000..7ff2845 --- /dev/null +++ b/pkg/commitremap/stream_test.go @@ -0,0 +1,294 @@ +package commitremap + +import ( + "archive/tar" + "compress/gzip" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +// makeTarGz creates a .tar.gz at path with the given entries. +func makeTarGz(t *testing.T, path string, entries []tarEntry) { + t.Helper() + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + for _, e := range entries { + if e.isDir { + hdr := &tar.Header{Name: e.name, Typeflag: tar.TypeDir, Mode: 0o755} + if err := tw.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + continue + } + hdr := &tar.Header{ + Name: e.name, + Typeflag: tar.TypeReg, + Mode: 0o644, + Size: int64(len(e.data)), + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte(e.data)); err != nil { + t.Fatal(err) + } + } + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := gw.Close(); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } +} + +// readTarGz returns all entries from a tar.gz. +func readTarGz(t *testing.T, path string) []tarEntry { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + gr, err := gzip.NewReader(f) + if err != nil { + t.Fatal(err) + } + defer gr.Close() + tr := tar.NewReader(gr) + + var entries []tarEntry + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + e := tarEntry{name: hdr.Name, isDir: hdr.Typeflag == tar.TypeDir} + if !e.isDir { + data, err := io.ReadAll(tr) + if err != nil { + t.Fatal(err) + } + e.data = string(data) + } + entries = append(entries, e) + } + return entries +} + +type tarEntry struct { + name string + data string + isDir bool +} + +func TestStreamRemap_RoundTrip(t *testing.T) { + dir := t.TempDir() + inPath := filepath.Join(dir, "in.tar.gz") + outPath := filepath.Join(dir, "out.tar.gz") + + oldSHA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + newSHA := "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" + commitMap := map[string]string{oldSHA: newSHA} + + entries := []tarEntry{ + {name: "./", isDir: true}, + {name: "./pull_requests_000001.json", data: `{"sha":"` + oldSHA + `"}`}, + {name: "./issues_000002.json", data: `{"ref":"` + oldSHA + `","other":"value"}`}, + {name: "./users_000001.json", data: `{"name":"test"}`}, + {name: "./organizations_000001.json", data: `{"sha":"` + oldSHA + `"}`}, + } + makeTarGz(t, inPath, entries) + + stats, err := StreamRemap(inPath, outPath, commitMap, []string{"pull_requests", "issues"}) + if err != nil { + t.Fatalf("StreamRemap: %v", err) + } + + if stats.FilesScanned != 2 { + t.Errorf("FilesScanned = %d, want 2", stats.FilesScanned) + } + if stats.FilesChanged() != 2 { + t.Errorf("FilesChanged = %d, want 2", stats.FilesChanged()) + } + + result := readTarGz(t, outPath) + for _, e := range result { + switch { + case e.name == "./pull_requests_000001.json": + if !strings.Contains(e.data, newSHA) { + t.Errorf("pull_requests should contain new SHA, got: %s", e.data) + } + if strings.Contains(e.data, oldSHA) { + t.Errorf("pull_requests still contains old SHA") + } + case e.name == "./issues_000002.json": + if !strings.Contains(e.data, newSHA) { + t.Errorf("issues should contain new SHA, got: %s", e.data) + } + case e.name == "./users_000001.json": + if e.data != `{"name":"test"}` { + t.Errorf("users should be unchanged, got: %s", e.data) + } + case e.name == "./organizations_000001.json": + // Not in prefixes, should pass through with original SHA + if !strings.Contains(e.data, oldSHA) { + t.Errorf("organizations should still contain old SHA (not remapped), got: %s", e.data) + } + } + } +} + +func TestStreamRemap_NoMatchingFiles(t *testing.T) { + dir := t.TempDir() + inPath := filepath.Join(dir, "in.tar.gz") + outPath := filepath.Join(dir, "out.tar.gz") + + commitMap := map[string]string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + } + + entries := []tarEntry{ + {name: "./", isDir: true}, + {name: "./users_000001.json", data: `{"name":"test"}`}, + } + makeTarGz(t, inPath, entries) + + stats, err := StreamRemap(inPath, outPath, commitMap, []string{"pull_requests"}) + if err != nil { + t.Fatalf("StreamRemap: %v", err) + } + if stats.FilesScanned != 0 { + t.Errorf("FilesScanned = %d, want 0", stats.FilesScanned) + } +} + +func TestStreamRemap_RejectsAbsolutePath(t *testing.T) { + dir := t.TempDir() + inPath := filepath.Join(dir, "in.tar.gz") + outPath := filepath.Join(dir, "out.tar.gz") + + commitMap := map[string]string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + } + + entries := []tarEntry{ + {name: "/etc/passwd", data: "bad"}, + } + makeTarGz(t, inPath, entries) + + _, err := StreamRemap(inPath, outPath, commitMap, []string{"pull_requests"}) + if err == nil { + t.Fatal("expected error for absolute path, got nil") + } + if !strings.Contains(err.Error(), "absolute path") { + t.Errorf("unexpected error: %v", err) + } + // Output should be cleaned up + if _, statErr := os.Stat(outPath); statErr == nil { + t.Error("partial output should have been removed") + } +} + +func TestStreamRemap_RejectsParentTraversal(t *testing.T) { + dir := t.TempDir() + inPath := filepath.Join(dir, "in.tar.gz") + outPath := filepath.Join(dir, "out.tar.gz") + + commitMap := map[string]string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + } + + entries := []tarEntry{ + {name: "../escape.json", data: "bad"}, + } + makeTarGz(t, inPath, entries) + + _, err := StreamRemap(inPath, outPath, commitMap, []string{"pull_requests"}) + if err == nil { + t.Fatal("expected error for parent traversal, got nil") + } + if !strings.Contains(err.Error(), "escapes") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestStreamRemap_EmptyCommitMap(t *testing.T) { + dir := t.TempDir() + inPath := filepath.Join(dir, "in.tar.gz") + outPath := filepath.Join(dir, "out.tar.gz") + + entries := []tarEntry{{name: "./", isDir: true}} + makeTarGz(t, inPath, entries) + + _, err := StreamRemap(inPath, outPath, map[string]string{}, []string{"pull_requests"}) + if err == nil { + t.Fatal("expected error for empty commit map") + } +} + +func TestStreamRemap_PreservesEntryOrder(t *testing.T) { + dir := t.TempDir() + inPath := filepath.Join(dir, "in.tar.gz") + outPath := filepath.Join(dir, "out.tar.gz") + + commitMap := map[string]string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + } + + names := []string{"./", "./z_000001.json", "./a_000001.json", "./m_000001.json"} + entries := []tarEntry{ + {name: names[0], isDir: true}, + {name: names[1], data: "z"}, + {name: names[2], data: "a"}, + {name: names[3], data: "m"}, + } + makeTarGz(t, inPath, entries) + + _, err := StreamRemap(inPath, outPath, commitMap, []string{}) + if err != nil { + t.Fatalf("StreamRemap: %v", err) + } + + result := readTarGz(t, outPath) + for i, e := range result { + if e.name != names[i] { + t.Errorf("entry %d: got name %q, want %q", i, e.name, names[i]) + } + } +} + +func TestShouldRemap(t *testing.T) { + prefixes := map[string]bool{"pull_requests": true, "issues": true} + tests := []struct { + name string + want bool + }{ + {"./pull_requests_000001.json", true}, + {"./issues_000002.json", true}, + {"./users_000001.json", false}, + {"./pull_requests.json", false}, // no _digits suffix + {"./pull_requests_abc.json", false}, // non-digit suffix + {"./subdir/pull_requests_1.json", true}, // nested + {"./readme.md", false}, + {"pull_requests_1.json", true}, + } + for _, tt := range tests { + if got := shouldRemap(tt.name, prefixes); got != tt.want { + t.Errorf("shouldRemap(%q) = %v, want %v", tt.name, got, tt.want) + } + } +}