diff --git a/Makefile b/Makefile index 2dd4c8a8a1..1fbc1bd73b 100644 --- a/Makefile +++ b/Makefile @@ -174,6 +174,9 @@ filter_helper: config-converter: $(GOBUILD) -ldflags '$(LDFLAGS)' -o bin/cdc_config_converter ./cmd/config-converter/main.go +multi-cluster-consistency-checker: + $(GOBUILD) -ldflags '$(LDFLAGS)' -o bin/multi-cluster-consistency-checker ./cmd/multi-cluster-consistency-checker + fmt: tools/bin/gofumports tools/bin/shfmt tools/bin/gci @echo "run gci (format imports)" tools/bin/gci write $(FILES) 2>&1 | $(FAIL_ON_STDOUT) diff --git a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go new file mode 100644 index 0000000000..f243403761 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go @@ -0,0 +1,314 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package advancer
+
+import (
+	"context"
+	"maps"
+	"sync"
+
+	"github.com/pingcap/log"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher"
+	"github.com/pingcap/ticdc/pkg/errors"
+	"github.com/pingcap/ticdc/pkg/sink/cloudstorage"
+	"github.com/tikv/client-go/v2/oracle"
+	pd "github.com/tikv/pd/client"
+	"go.uber.org/zap"
+	"golang.org/x/sync/errgroup"
+)
+
+type TimeWindowAdvancer struct {
+	// round is the current round of the time window
+	round uint64
+
+	// timeWindowTriplet is the triplet of adjacent time windows, mapping from cluster ID to the triplet
+	timeWindowTriplet map[string][3]types.TimeWindow
+
+	// checkpointWatcher is the Active-Active checkpoint watcher for each cluster,
+	// mapping from local cluster ID to replicated cluster ID to the checkpoint watcher
+	checkpointWatcher map[string]map[string]watcher.Watcher
+
+	// s3Watcher is the S3 checkpoint watcher for each cluster, mapping from cluster ID to the s3 checkpoint watcher
+	s3Watcher map[string]*watcher.S3Watcher
+
+	// pdClients is the pd clients for each cluster, mapping from cluster ID to the pd client
+	pdClients map[string]pd.Client
+}
+
+func NewTimeWindowAdvancer(
+	ctx context.Context,
+	checkpointWatchers map[string]map[string]watcher.Watcher,
+	s3Watchers map[string]*watcher.S3Watcher,
+	pdClients map[string]pd.Client,
+	checkpoint *recorder.Checkpoint,
+) (*TimeWindowAdvancer, map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, error) {
+	timeWindowTriplet := make(map[string][3]types.TimeWindow)
+	for clusterID := range pdClients {
+		timeWindowTriplet[clusterID] = [3]types.TimeWindow{}
+	}
+	advancer := &TimeWindowAdvancer{
+		round:             0,
+		timeWindowTriplet: timeWindowTriplet,
+		checkpointWatcher: checkpointWatchers,
+		s3Watcher:         s3Watchers,
+		pdClients:         pdClients,
+	}
+	newDataMap, err := 
advancer.initializeFromCheckpoint(ctx, checkpoint) + if err != nil { + return nil, nil, errors.Trace(err) + } + return advancer, newDataMap, nil +} + +func (t *TimeWindowAdvancer) initializeFromCheckpoint( + ctx context.Context, + checkpoint *recorder.Checkpoint, +) (map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, error) { + if checkpoint == nil { + return nil, nil + } + if checkpoint.CheckpointItems[2] == nil { + return nil, nil + } + t.round = checkpoint.CheckpointItems[2].Round + 1 + for clusterID := range t.timeWindowTriplet { + newTimeWindows := [3]types.TimeWindow{} + clusterInfo, exists := checkpoint.CheckpointItems[2].ClusterInfo[clusterID] + if !exists { + return nil, errors.Errorf("cluster %s not found in checkpoint item[2]", clusterID) + } + newTimeWindows[2] = clusterInfo.TimeWindow + if checkpoint.CheckpointItems[1] != nil { + clusterInfo, exists = checkpoint.CheckpointItems[1].ClusterInfo[clusterID] + if !exists { + return nil, errors.Errorf("cluster %s not found in checkpoint item[1]", clusterID) + } + newTimeWindows[1] = clusterInfo.TimeWindow + } + if checkpoint.CheckpointItems[0] != nil { + clusterInfo, exists = checkpoint.CheckpointItems[0].ClusterInfo[clusterID] + if !exists { + return nil, errors.Errorf("cluster %s not found in checkpoint item[0]", clusterID) + } + newTimeWindows[0] = clusterInfo.TimeWindow + } + t.timeWindowTriplet[clusterID] = newTimeWindows + } + + var mu sync.Mutex + newDataMap := make(map[string]map[cloudstorage.DmlPathKey]types.IncrementalData) + eg, egCtx := errgroup.WithContext(ctx) + for clusterID, s3Watcher := range t.s3Watcher { + eg.Go(func() error { + newData, err := s3Watcher.InitializeFromCheckpoint(egCtx, clusterID, checkpoint) + if err != nil { + return errors.Trace(err) + } + mu.Lock() + newDataMap[clusterID] = newData + mu.Unlock() + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + return newDataMap, nil +} + +// AdvanceTimeWindow advances the time 
window for each cluster. Here are the steps:
+// 1. Advance the checkpoint ts for each local-to-replicated changefeed.
+//
+//	For any local-to-replicated changefeed, the checkpoint ts should be advanced to
+//	the maximum of pd timestamp after previous time window of the replicated cluster
+//	advanced and the right boundary of previous time window of every clusters.
+//
+// 2. Advance the right boundary for each cluster.
+//
+//	For any cluster, the right boundary should be advanced to the maximum of pd timestamp of
+//	the cluster after the checkpoint ts of its local cluster advanced and the previous
+//	time window's checkpoint ts of changefeed where the cluster is the local or the replicated.
+//
+// 3. Update the time window for each cluster.
+//
+//	For any cluster, the time window should be updated to the new time window.
+func (t *TimeWindowAdvancer) AdvanceTimeWindow(
+	pctx context.Context,
+) (map[string]types.TimeWindowData, error) {
+	log.Debug("advance time window", zap.Uint64("round", t.round))
+	// mapping from local cluster ID to replicated cluster ID to the min checkpoint timestamp
+	minCheckpointTsMap := make(map[string]map[string]uint64)
+	maxTimeWindowRightBoundary := uint64(0)
+	for replicatedClusterID, triplet := range t.timeWindowTriplet {
+		for localClusterID, pdTimestampAfterTimeWindow := range triplet[2].PDTimestampAfterTimeWindow {
+			if _, ok := minCheckpointTsMap[localClusterID]; !ok {
+				minCheckpointTsMap[localClusterID] = make(map[string]uint64)
+			}
+			minCheckpointTsMap[localClusterID][replicatedClusterID] = max(minCheckpointTsMap[localClusterID][replicatedClusterID], pdTimestampAfterTimeWindow)
+		}
+		maxTimeWindowRightBoundary = max(maxTimeWindowRightBoundary, triplet[2].RightBoundary)
+	}
+
+	var lock sync.Mutex
+	newTimeWindow := make(map[string]types.TimeWindow)
+	maxPDTimestampAfterCheckpointTs := make(map[string]uint64)
+	// for cluster ID, the max checkpoint timestamp is maximum of checkpoint from cluster to other clusters and 
checkpoint from other clusters to cluster + maxCheckpointTs := make(map[string]uint64) + // Advance the checkpoint ts for each cluster + eg, ctx := errgroup.WithContext(pctx) + for localClusterID, replicatedCheckpointWatcherMap := range t.checkpointWatcher { + for replicatedClusterID, checkpointWatcher := range replicatedCheckpointWatcherMap { + minCheckpointTs := max(minCheckpointTsMap[localClusterID][replicatedClusterID], maxTimeWindowRightBoundary) + eg.Go(func() error { + checkpointTs, err := checkpointWatcher.AdvanceCheckpointTs(ctx, minCheckpointTs) + if err != nil { + return errors.Trace(err) + } + // TODO: optimize this by getting pd ts in the end of all checkpoint ts advance + pdtsos, err := t.getPDTsFromOtherClusters(ctx, localClusterID) + if err != nil { + return errors.Trace(err) + } + lock.Lock() + timeWindow := newTimeWindow[localClusterID] + if timeWindow.CheckpointTs == nil { + timeWindow.CheckpointTs = make(map[string]uint64) + } + timeWindow.CheckpointTs[replicatedClusterID] = checkpointTs + newTimeWindow[localClusterID] = timeWindow + for otherClusterID, pdtso := range pdtsos { + maxPDTimestampAfterCheckpointTs[otherClusterID] = max(maxPDTimestampAfterCheckpointTs[otherClusterID], pdtso) + } + maxCheckpointTs[localClusterID] = max(maxCheckpointTs[localClusterID], checkpointTs) + maxCheckpointTs[replicatedClusterID] = max(maxCheckpointTs[replicatedClusterID], checkpointTs) + lock.Unlock() + return nil + }) + } + } + if err := eg.Wait(); err != nil { + return nil, errors.Annotate(err, "advance checkpoint timestamp failed") + } + + // Update the time window for each cluster + newDataMap := make(map[string]map[cloudstorage.DmlPathKey]types.IncrementalData) + maxVersionMap := make(map[string]map[types.SchemaTableKey]types.VersionKey) + eg, ctx = errgroup.WithContext(pctx) + for clusterID, triplet := range t.timeWindowTriplet { + minTimeWindowRightBoundary := max(maxCheckpointTs[clusterID], maxPDTimestampAfterCheckpointTs[clusterID], 
triplet[2].NextMinLeftBoundary) + s3Watcher := t.s3Watcher[clusterID] + eg.Go(func() error { + s3CheckpointTs, err := s3Watcher.AdvanceS3CheckpointTs(ctx, minTimeWindowRightBoundary) + if err != nil { + return errors.Trace(err) + } + newData, maxClusterVersionMap, err := s3Watcher.ConsumeNewFiles(ctx) + if err != nil { + return errors.Trace(err) + } + pdtso, err := t.getPDTsFromCluster(ctx, clusterID) + if err != nil { + return errors.Trace(err) + } + pdtsos, err := t.getPDTsFromOtherClusters(ctx, clusterID) + if err != nil { + return errors.Trace(err) + } + lock.Lock() + newDataMap[clusterID] = newData + maxVersionMap[clusterID] = maxClusterVersionMap + timeWindow := newTimeWindow[clusterID] + timeWindow.LeftBoundary = triplet[2].RightBoundary + timeWindow.RightBoundary = s3CheckpointTs + timeWindow.PDTimestampAfterTimeWindow = make(map[string]uint64) + timeWindow.NextMinLeftBoundary = pdtso + maps.Copy(timeWindow.PDTimestampAfterTimeWindow, pdtsos) + newTimeWindow[clusterID] = timeWindow + lock.Unlock() + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Annotate(err, "advance time window failed") + } + t.updateTimeWindow(newTimeWindow) + t.round++ + return newTimeWindowData(newTimeWindow, newDataMap, maxVersionMap), nil +} + +func (t *TimeWindowAdvancer) updateTimeWindow(newTimeWindow map[string]types.TimeWindow) { + for clusterID, timeWindow := range newTimeWindow { + triplet := t.timeWindowTriplet[clusterID] + triplet[0] = triplet[1] + triplet[1] = triplet[2] + triplet[2] = timeWindow + t.timeWindowTriplet[clusterID] = triplet + log.Debug("update time window", zap.String("clusterID", clusterID), zap.Any("timeWindow", timeWindow)) + } +} + +func (t *TimeWindowAdvancer) getPDTsFromCluster(ctx context.Context, clusterID string) (uint64, error) { + pdClient := t.pdClients[clusterID] + phyTs, logicTs, err := pdClient.GetTS(ctx) + if err != nil { + return 0, errors.Trace(err) + } + ts := oracle.ComposeTS(phyTs, logicTs) + return ts, nil +} 
+ +func (t *TimeWindowAdvancer) getPDTsFromOtherClusters(pctx context.Context, clusterID string) (map[string]uint64, error) { + var lock sync.Mutex + pdtsos := make(map[string]uint64) + eg, ctx := errgroup.WithContext(pctx) + for otherClusterID := range t.pdClients { + if otherClusterID == clusterID { + continue + } + pdClient := t.pdClients[otherClusterID] + eg.Go(func() error { + phyTs, logicTs, err := pdClient.GetTS(ctx) + if err != nil { + return errors.Trace(err) + } + ts := oracle.ComposeTS(phyTs, logicTs) + lock.Lock() + pdtsos[otherClusterID] = ts + lock.Unlock() + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + return pdtsos, nil +} + +func newTimeWindowData( + newTimeWindow map[string]types.TimeWindow, + newDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, + maxVersionMap map[string]map[types.SchemaTableKey]types.VersionKey, +) map[string]types.TimeWindowData { + timeWindowDatas := make(map[string]types.TimeWindowData) + for clusterID, timeWindow := range newTimeWindow { + timeWindowDatas[clusterID] = types.TimeWindowData{ + TimeWindow: timeWindow, + Data: newDataMap[clusterID], + MaxVersion: maxVersionMap[clusterID], + } + } + return timeWindowDatas +} diff --git a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go new file mode 100644 index 0000000000..c0d5a68429 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go @@ -0,0 +1,258 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package advancer + +import ( + "context" + "maps" + "sync" + "sync/atomic" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/stretchr/testify/require" + pd "github.com/tikv/pd/client" +) + +// mockPDClient mocks pd.Client for testing. +// Each call to GetTS returns a monotonically increasing TSO (physical part increases by 1000ms per call). +type mockPDClient struct { + pd.Client + seq int64 // accessed atomically +} + +func (m *mockPDClient) GetTS(ctx context.Context) (int64, int64, error) { + n := atomic.AddInt64(&m.seq, 1) + // Physical timestamp starts at 11000ms and increases by 1000ms per call. + // oracle.ComposeTS(physical, 0) = physical << 18, so each step is ~262 million. + return 10000 + n*1000, 0, nil +} + +func (m *mockPDClient) Close() {} + +// mockAdvancerWatcher mocks watcher.Watcher for testing. +// Returns minCheckpointTs + delta, ensuring the result is always > minCheckpointTs and monotonically increasing. 
+type mockAdvancerWatcher struct { + mu sync.Mutex + delta uint64 + history []uint64 +} + +func (m *mockAdvancerWatcher) AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := minCheckpointTs + m.delta + m.history = append(m.history, result) + return result, nil +} + +func (m *mockAdvancerWatcher) Close() {} + +func (m *mockAdvancerWatcher) getHistory() []uint64 { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]uint64, len(m.history)) + copy(out, m.history) + return out +} + +func TestNewTimeWindowAdvancer(t *testing.T) { + checkpointWatchers := map[string]map[string]watcher.Watcher{ + "cluster1": {}, + "cluster2": {}, + } + s3Watchers := map[string]*watcher.S3Watcher{ + "cluster1": nil, + "cluster2": nil, + } + pdClients := map[string]pd.Client{ + "cluster1": nil, + "cluster2": nil, + } + + advancer, _, err := NewTimeWindowAdvancer(context.Background(), checkpointWatchers, s3Watchers, pdClients, nil) + require.NoError(t, err) + require.NotNil(t, advancer) + require.Equal(t, uint64(0), advancer.round) + require.Len(t, advancer.timeWindowTriplet, 2) + require.Contains(t, advancer.timeWindowTriplet, "cluster1") + require.Contains(t, advancer.timeWindowTriplet, "cluster2") +} + +func TestNewTimeWindowAdvancerInitializeFromCheckpointMissingClusterInfo(t *testing.T) { + t.Parallel() + checkpointWatchers := map[string]map[string]watcher.Watcher{ + "cluster1": {}, + "cluster2": {}, + } + s3Watchers := map[string]*watcher.S3Watcher{ + "cluster1": watcher.NewS3Watcher(&mockAdvancerWatcher{delta: 1}, storage.NewMemStorage(), nil), + "cluster2": watcher.NewS3Watcher(&mockAdvancerWatcher{delta: 1}, storage.NewMemStorage(), nil), + } + pdClients := map[string]pd.Client{ + "cluster1": &mockPDClient{}, + "cluster2": &mockPDClient{}, + } + + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "cluster1": { + TimeWindow: types.TimeWindow{ + 
LeftBoundary: 0, + RightBoundary: 100, + }, + }, + }) + + _, _, err := NewTimeWindowAdvancer(context.Background(), checkpointWatchers, s3Watchers, pdClients, checkpoint) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster cluster2 not found in checkpoint item[2]") +} + +// TestTimeWindowAdvancer_AdvanceMultipleRounds simulates 4 rounds of AdvanceTimeWindow +// with 2 clusters (c1, c2) performing bidirectional replication. +// +// The test verifies: +// - Time windows advance correctly (LeftBoundary == previous RightBoundary) +// - RightBoundary > LeftBoundary for each time window +// - Checkpoint timestamps are monotonically increasing across rounds +// - PD TSOs are always greater than checkpoint timestamps +// - NextMinLeftBoundary (PD TSO) > RightBoundary (S3 checkpoint) +// - Mock watcher checkpoint histories are strictly increasing +func TestTimeWindowAdvancer_AdvanceMultipleRounds(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create mock PD clients for each cluster (monotonically increasing TSO) + pdC1 := &mockPDClient{} + pdC2 := &mockPDClient{} + pdClients := map[string]pd.Client{ + "c1": pdC1, + "c2": pdC2, + } + + // Create mock checkpoint watchers for bidirectional replication (c1->c2, c2->c1) + // Each returns minCheckpointTs + 100 + cpWatcherC1C2 := &mockAdvancerWatcher{delta: 100} + cpWatcherC2C1 := &mockAdvancerWatcher{delta: 100} + checkpointWatchers := map[string]map[string]watcher.Watcher{ + "c1": {"c2": cpWatcherC1C2}, + "c2": {"c1": cpWatcherC2C1}, + } + + // Create S3 watchers with mock checkpoint watchers (returns minCheckpointTs + 50) + // and empty in-memory storage (no actual S3 data) + s3WatcherMockC1 := &mockAdvancerWatcher{delta: 50} + s3WatcherMockC2 := &mockAdvancerWatcher{delta: 50} + s3Watchers := map[string]*watcher.S3Watcher{ + "c1": watcher.NewS3Watcher(s3WatcherMockC1, storage.NewMemStorage(), nil), + "c2": watcher.NewS3Watcher(s3WatcherMockC2, storage.NewMemStorage(), nil), + } + + advancer, 
_, err := NewTimeWindowAdvancer(ctx, checkpointWatchers, s3Watchers, pdClients, nil) + require.NoError(t, err) + require.Equal(t, uint64(0), advancer.round) + + // Track previous round values for cross-round assertions + prevRightBoundaries := map[string]uint64{"c1": 0, "c2": 0} + prevCheckpointTs := map[string]map[string]uint64{ + "c1": {"c2": 0}, + "c2": {"c1": 0}, + } + prevRightBoundary := uint64(0) // max across all clusters + + for round := range 4 { + result, err := advancer.AdvanceTimeWindow(ctx) + require.NoError(t, err, "round %d", round) + require.Len(t, result, 2, "round %d: should have data for both clusters", round) + + for clusterID, twData := range result { + tw := twData.TimeWindow + + // 1. LeftBoundary == previous RightBoundary + require.Equal(t, prevRightBoundaries[clusterID], tw.LeftBoundary, + "round %d, cluster %s: LeftBoundary should equal previous RightBoundary", round, clusterID) + + // 2. RightBoundary > LeftBoundary (time window is non-empty) + require.Greater(t, tw.RightBoundary, tw.LeftBoundary, + "round %d, cluster %s: RightBoundary should be > LeftBoundary", round, clusterID) + + // 3. CheckpointTs should be populated and strictly increasing across rounds + require.NotEmpty(t, tw.CheckpointTs, + "round %d, cluster %s: CheckpointTs should be populated", round, clusterID) + for replicatedCluster, cpTs := range tw.CheckpointTs { + require.Greater(t, cpTs, prevCheckpointTs[clusterID][replicatedCluster], + "round %d, %s->%s: checkpoint should be strictly increasing", round, clusterID, replicatedCluster) + } + + // 4. PDTimestampAfterTimeWindow should be populated + require.NotEmpty(t, tw.PDTimestampAfterTimeWindow, + "round %d, cluster %s: PDTimestampAfterTimeWindow should be populated", round, clusterID) + + // 5. 
NextMinLeftBoundary > RightBoundary + // (PD TSO is obtained after S3 checkpoint, and PD TSO >> S3 checkpoint) + require.Greater(t, tw.NextMinLeftBoundary, tw.RightBoundary, + "round %d, cluster %s: NextMinLeftBoundary (PD TSO) should be > RightBoundary (S3 checkpoint)", round, clusterID) + + // 6. PD TSO values in PDTimestampAfterTimeWindow > all CheckpointTs values + // (PD TSOs are obtained after checkpoint advance) + for otherCluster, pdTs := range tw.PDTimestampAfterTimeWindow { + for replicatedCluster, cpTs := range tw.CheckpointTs { + require.Greater(t, pdTs, cpTs, + "round %d, cluster %s: PD TSO (from %s) should be > checkpoint (%s->%s)", + round, clusterID, otherCluster, clusterID, replicatedCluster) + } + } + + // 7. RightBoundary > previous round's max RightBoundary (time window advances) + require.Greater(t, tw.RightBoundary, prevRightBoundary, + "round %d, cluster %s: RightBoundary should be > previous max RightBoundary", round, clusterID) + } + + // Save current values for next round + maxRB := uint64(0) + for clusterID, twData := range result { + prevRightBoundaries[clusterID] = twData.TimeWindow.RightBoundary + if twData.TimeWindow.RightBoundary > maxRB { + maxRB = twData.TimeWindow.RightBoundary + } + maps.Copy(prevCheckpointTs[clusterID], twData.TimeWindow.CheckpointTs) + } + prevRightBoundary = maxRB + } + + // After 4 rounds, round counter should be 4 + require.Equal(t, uint64(4), advancer.round) + + // Verify all mock watcher checkpoint histories are strictly monotonically increasing + allWatchers := []*mockAdvancerWatcher{cpWatcherC1C2, cpWatcherC2C1, s3WatcherMockC1, s3WatcherMockC2} + watcherNames := []string{"cp c1->c2", "cp c2->c1", "s3 c1", "s3 c2"} + for idx, w := range allWatchers { + history := w.getHistory() + require.GreaterOrEqual(t, len(history), 4, + "%s: should have at least 4 checkpoint values (one per round)", watcherNames[idx]) + for i := 1; i < len(history); i++ { + require.Greater(t, history[i], history[i-1], + "%s: 
checkpoint values should be strictly increasing (index %d: %d -> %d)", + watcherNames[idx], i, history[i-1], history[i]) + } + } + + // Verify PD clients were called (monotonically increasing due to atomic counter) + require.Greater(t, atomic.LoadInt64(&pdC1.seq), int64(0), "pd-c1 should have been called") + require.Greater(t, atomic.LoadInt64(&pdC2.seq), int64(0), "pd-c2 should have been called") +} diff --git a/cmd/multi-cluster-consistency-checker/checker/checker.go b/cmd/multi-cluster-consistency-checker/checker/checker.go new file mode 100644 index 0000000000..77972240ac --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/checker/checker.go @@ -0,0 +1,722 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package checker + +import ( + "context" + "sort" + + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/decoder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "go.uber.org/zap" +) + +type versionCacheEntry struct { + previous int + cdcVersion types.CdcVersion +} + +type clusterViolationChecker struct { + clusterID string + twoPreviousTimeWindowKeyVersionCache map[string]map[types.PkType]versionCacheEntry +} + +func newClusterViolationChecker(clusterID string) *clusterViolationChecker { + return &clusterViolationChecker{ + clusterID: clusterID, + twoPreviousTimeWindowKeyVersionCache: make(map[string]map[types.PkType]versionCacheEntry), + } +} + +func (c *clusterViolationChecker) NewRecordFromCheckpoint(schemaKey string, record *decoder.Record, previous int) { + tableSchemaKeyVersionCache, exists := c.twoPreviousTimeWindowKeyVersionCache[schemaKey] + if !exists { + tableSchemaKeyVersionCache = make(map[types.PkType]versionCacheEntry) + c.twoPreviousTimeWindowKeyVersionCache[schemaKey] = tableSchemaKeyVersionCache + } + entry, exists := tableSchemaKeyVersionCache[record.Pk] + if !exists { + tableSchemaKeyVersionCache[record.Pk] = versionCacheEntry{ + previous: previous, + cdcVersion: record.CdcVersion, + } + return + } + entryCompareTs := entry.cdcVersion.GetCompareTs() + recordCompareTs := record.GetCompareTs() + if entryCompareTs < recordCompareTs { + tableSchemaKeyVersionCache[record.Pk] = versionCacheEntry{ + previous: previous, + cdcVersion: record.CdcVersion, + } + } +} + +func (c *clusterViolationChecker) Check(schemaKey string, r *decoder.Record, report *recorder.ClusterReport) { + tableSchemaKeyVersionCache, exists := 
c.twoPreviousTimeWindowKeyVersionCache[schemaKey] + if !exists { + tableSchemaKeyVersionCache = make(map[types.PkType]versionCacheEntry) + c.twoPreviousTimeWindowKeyVersionCache[schemaKey] = tableSchemaKeyVersionCache + } + entry, exists := tableSchemaKeyVersionCache[r.Pk] + if !exists { + tableSchemaKeyVersionCache[r.Pk] = versionCacheEntry{ + previous: 0, + cdcVersion: r.CdcVersion, + } + return + } + if entry.cdcVersion.CommitTs >= r.CommitTs { + // duplicated old version, just skip it + return + } + entryCompareTs := entry.cdcVersion.GetCompareTs() + recordCompareTs := r.GetCompareTs() + if entryCompareTs >= recordCompareTs { + // violation detected + log.Error("LWW violation detected", + zap.String("clusterID", c.clusterID), + zap.Any("entry", entry), + zap.String("pk", r.PkStr)) + report.AddLWWViolationItem(schemaKey, r.PkMap, r.PkStr, entry.cdcVersion.OriginTs, entry.cdcVersion.CommitTs, r.OriginTs, r.CommitTs) + return + } + tableSchemaKeyVersionCache[r.Pk] = versionCacheEntry{ + previous: 0, + cdcVersion: r.CdcVersion, + } +} + +func (c *clusterViolationChecker) UpdateCache() { + newTwoPreviousTimeWindowKeyVersionCache := make(map[string]map[types.PkType]versionCacheEntry) + for schemaKey, tableSchemaKeyVersionCache := range c.twoPreviousTimeWindowKeyVersionCache { + newTableSchemaKeyVersionCache := make(map[types.PkType]versionCacheEntry) + for primaryKey, entry := range tableSchemaKeyVersionCache { + if entry.previous >= 2 { + continue + } + newTableSchemaKeyVersionCache[primaryKey] = versionCacheEntry{ + previous: entry.previous + 1, + cdcVersion: entry.cdcVersion, + } + } + if len(newTableSchemaKeyVersionCache) > 0 { + newTwoPreviousTimeWindowKeyVersionCache[schemaKey] = newTableSchemaKeyVersionCache + } + } + c.twoPreviousTimeWindowKeyVersionCache = newTwoPreviousTimeWindowKeyVersionCache +} + +type RecordsMapWithMaxCompareTs struct { + records map[uint64]*decoder.Record + maxCompareTs uint64 +} + +type tableDataCache struct { + // localDataCache is a 
map of primary key to a map of commit ts to a record + localDataCache map[types.PkType]*RecordsMapWithMaxCompareTs + + // replicatedDataCache is a map of primary key to a map of origin ts to a record + replicatedDataCache map[types.PkType]*RecordsMapWithMaxCompareTs +} + +func newTableDataCache() *tableDataCache { + return &tableDataCache{ + localDataCache: make(map[types.PkType]*RecordsMapWithMaxCompareTs), + replicatedDataCache: make(map[types.PkType]*RecordsMapWithMaxCompareTs), + } +} + +func (tdc *tableDataCache) newLocalRecord(record *decoder.Record) { + recordsMap, exists := tdc.localDataCache[record.Pk] + if !exists { + recordsMap = &RecordsMapWithMaxCompareTs{ + records: make(map[uint64]*decoder.Record), + maxCompareTs: 0, + } + tdc.localDataCache[record.Pk] = recordsMap + } + recordsMap.records[record.CommitTs] = record + if record.CommitTs > recordsMap.maxCompareTs { + recordsMap.maxCompareTs = record.CommitTs + } +} + +func (tdc *tableDataCache) newReplicatedRecord(record *decoder.Record) { + recordsMap, exists := tdc.replicatedDataCache[record.Pk] + if !exists { + recordsMap = &RecordsMapWithMaxCompareTs{ + records: make(map[uint64]*decoder.Record), + maxCompareTs: 0, + } + tdc.replicatedDataCache[record.Pk] = recordsMap + } + recordsMap.records[record.OriginTs] = record + compareTs := record.GetCompareTs() + if compareTs > recordsMap.maxCompareTs { + recordsMap.maxCompareTs = compareTs + } +} + +type timeWindowDataCache struct { + tableDataCaches map[string]*tableDataCache + + leftBoundary uint64 + rightBoundary uint64 + checkpointTs map[string]uint64 +} + +func newTimeWindowDataCache(leftBoundary, rightBoundary uint64, checkpointTs map[string]uint64) timeWindowDataCache { + return timeWindowDataCache{ + tableDataCaches: make(map[string]*tableDataCache), + leftBoundary: leftBoundary, + rightBoundary: rightBoundary, + checkpointTs: checkpointTs, + } +} + +func (twdc *timeWindowDataCache) NewRecord(schemaKey string, record *decoder.Record) { + if 
record.CommitTs <= twdc.leftBoundary { + // record is before the left boundary, just skip it + return + } + tableDataCache, exists := twdc.tableDataCaches[schemaKey] + if !exists { + tableDataCache = newTableDataCache() + twdc.tableDataCaches[schemaKey] = tableDataCache + } + if record.OriginTs == 0 { + tableDataCache.newLocalRecord(record) + } else { + tableDataCache.newReplicatedRecord(record) + } +} + +type clusterDataChecker struct { + clusterID string + // true if more than 2 clusters are involved in the check + multiCluster bool + + thisRoundTimeWindow types.TimeWindow + + timeWindowDataCaches [3]timeWindowDataCache + + rightBoundary uint64 + + overDataCaches map[string][]*decoder.Record + + clusterViolationChecker *clusterViolationChecker + + report *recorder.ClusterReport + + lwwSkippedRecordsCount int + checkedRecordsCount int + newTimeWindowRecordsCount int +} + +func newClusterDataChecker(clusterID string, multiCluster bool) *clusterDataChecker { + return &clusterDataChecker{ + clusterID: clusterID, + multiCluster: multiCluster, + timeWindowDataCaches: [3]timeWindowDataCache{}, + rightBoundary: 0, + overDataCaches: make(map[string][]*decoder.Record), + clusterViolationChecker: newClusterViolationChecker(clusterID), + } +} + +func (cd *clusterDataChecker) InitializeFromCheckpoint( + ctx context.Context, + checkpointDataMap map[cloudstorage.DmlPathKey]types.IncrementalData, + checkpoint *recorder.Checkpoint, +) error { + if checkpoint == nil { + return nil + } + if checkpoint.CheckpointItems[2] == nil { + return nil + } + clusterInfo, exists := checkpoint.CheckpointItems[2].ClusterInfo[cd.clusterID] + if !exists { + return errors.Errorf("cluster %s not found in checkpoint item[2]", cd.clusterID) + } + cd.rightBoundary = clusterInfo.TimeWindow.RightBoundary + cd.timeWindowDataCaches[2] = newTimeWindowDataCache( + clusterInfo.TimeWindow.LeftBoundary, clusterInfo.TimeWindow.RightBoundary, clusterInfo.TimeWindow.CheckpointTs) + if checkpoint.CheckpointItems[1] 
!= nil { + clusterInfo, exists = checkpoint.CheckpointItems[1].ClusterInfo[cd.clusterID] + if !exists { + return errors.Errorf("cluster %s not found in checkpoint item[1]", cd.clusterID) + } + cd.timeWindowDataCaches[1] = newTimeWindowDataCache( + clusterInfo.TimeWindow.LeftBoundary, clusterInfo.TimeWindow.RightBoundary, clusterInfo.TimeWindow.CheckpointTs) + } + for schemaPathKey, incrementalData := range checkpointDataMap { + schemaKey := schemaPathKey.GetKey() + for _, contents := range incrementalData.DataContentSlices { + for _, content := range contents { + records, err := decoder.Decode(content, incrementalData.ColumnFieldTypes) + if err != nil { + return errors.Trace(err) + } + for _, record := range records { + cd.newRecordFromCheckpoint(schemaKey, record) + } + } + } + } + return nil +} + +func (cd *clusterDataChecker) newRecordFromCheckpoint(schemaKey string, record *decoder.Record) { + if record.CommitTs > cd.rightBoundary { + cd.overDataCaches[schemaKey] = append(cd.overDataCaches[schemaKey], record) + return + } + if cd.timeWindowDataCaches[2].leftBoundary < record.CommitTs { + cd.timeWindowDataCaches[2].NewRecord(schemaKey, record) + cd.clusterViolationChecker.NewRecordFromCheckpoint(schemaKey, record, 1) + + } else if cd.timeWindowDataCaches[1].leftBoundary < record.CommitTs { + cd.timeWindowDataCaches[1].NewRecord(schemaKey, record) + cd.clusterViolationChecker.NewRecordFromCheckpoint(schemaKey, record, 2) + } +} + +func (cd *clusterDataChecker) PrepareNextTimeWindowData(timeWindow types.TimeWindow) error { + if timeWindow.LeftBoundary != cd.rightBoundary { + return errors.Errorf("time window left boundary(%d) mismatch right boundary ts(%d)", timeWindow.LeftBoundary, cd.rightBoundary) + } + cd.timeWindowDataCaches[0] = cd.timeWindowDataCaches[1] + cd.timeWindowDataCaches[1] = cd.timeWindowDataCaches[2] + newTimeWindowDataCache := newTimeWindowDataCache(timeWindow.LeftBoundary, timeWindow.RightBoundary, timeWindow.CheckpointTs) + cd.rightBoundary = 
timeWindow.RightBoundary + newOverDataCache := make(map[string][]*decoder.Record) + for schemaKey, overRecords := range cd.overDataCaches { + newTableOverDataCache := make([]*decoder.Record, 0, len(overRecords)) + for _, overRecord := range overRecords { + if overRecord.CommitTs > timeWindow.RightBoundary { + newTableOverDataCache = append(newTableOverDataCache, overRecord) + } else { + newTimeWindowDataCache.NewRecord(schemaKey, overRecord) + } + } + newOverDataCache[schemaKey] = newTableOverDataCache + } + cd.timeWindowDataCaches[2] = newTimeWindowDataCache + cd.overDataCaches = newOverDataCache + cd.lwwSkippedRecordsCount = 0 + cd.checkedRecordsCount = 0 + cd.newTimeWindowRecordsCount = 0 + return nil +} + +func (cd *clusterDataChecker) NewRecord(schemaKey string, record *decoder.Record) { + if record.CommitTs > cd.rightBoundary { + cd.overDataCaches[schemaKey] = append(cd.overDataCaches[schemaKey], record) + return + } + cd.timeWindowDataCaches[2].NewRecord(schemaKey, record) +} + +func (cd *clusterDataChecker) findClusterReplicatedDataInTimeWindow(timeWindowIdx int, schemaKey string, pk types.PkType, originTs uint64) (*decoder.Record, bool) { + tableDataCache, exists := cd.timeWindowDataCaches[timeWindowIdx].tableDataCaches[schemaKey] + if !exists { + return nil, false + } + if recordsMap, exists := tableDataCache.replicatedDataCache[pk]; exists { + if record, exists := recordsMap.records[originTs]; exists { + return record, false + } + if recordsMap.maxCompareTs > originTs { + return nil, true + } + } + // If no replicated record is found, a newer/equal local record with the + // same PK in the peer cluster also indicates this origin write can be + // considered overwritten by LWW semantics instead of hard data loss. 
+ if recordsMap, exists := tableDataCache.localDataCache[pk]; exists { + if recordsMap.maxCompareTs > originTs { + return nil, true + } + } + return nil, false +} + +func (cd *clusterDataChecker) findClusterLocalDataInTimeWindow(timeWindowIdx int, schemaKey string, pk types.PkType, commitTs uint64) bool { + tableDataCache, exists := cd.timeWindowDataCaches[timeWindowIdx].tableDataCaches[schemaKey] + if !exists { + return false + } + recordsMap, exists := tableDataCache.localDataCache[pk] + if !exists { + return false + } + _, exists = recordsMap.records[commitTs] + return exists +} + +// diffColumns compares column values between local written and replicated records +// and returns the list of inconsistent columns. +func diffColumns(local, replicated *decoder.Record) []recorder.InconsistentColumn { + var result []recorder.InconsistentColumn + for colName, localVal := range local.ColumnValues { + replicatedVal, ok := replicated.ColumnValues[colName] + if !ok { + result = append(result, recorder.InconsistentColumn{ + Column: colName, + Local: localVal, + Replicated: nil, + }) + } else if localVal != replicatedVal { // safe: ColumnValues only holds comparable types (see decoder.go) + result = append(result, recorder.InconsistentColumn{ + Column: colName, + Local: localVal, + Replicated: replicatedVal, + }) + } + } + for colName, replicatedVal := range replicated.ColumnValues { + if _, ok := local.ColumnValues[colName]; !ok { + result = append(result, recorder.InconsistentColumn{ + Column: colName, + Local: nil, + Replicated: replicatedVal, + }) + } + } + sort.Slice(result, func(i, j int) bool { + return result[i].Column < result[j].Column + }) + return result +} + +// dataLossDetection iterates through the local-written data caches [1] and [2] and selects the records +// whose commit ts falls within the (checkpoint[1], checkpoint[2]]. 
The record must be present +// in the replicated data cache [1] or [2] or another new record is present in the replicated data +// cache [1] or [2]. +func (cd *clusterDataChecker) dataLossDetection(checker *DataChecker) { + // Time window [1]: skip records whose commitTs <= checkpoint (already checked in previous round) + cd.checkLocalRecordsForDataLoss(1, func(commitTs, checkpointTs uint64) bool { + return commitTs <= checkpointTs + }, checker) + // Time window [2]: skip records whose commitTs > checkpoint (will be checked in next round) + cd.checkLocalRecordsForDataLoss(2, func(commitTs, checkpointTs uint64) bool { + return commitTs > checkpointTs + }, checker) +} + +// checkLocalRecordsForDataLoss iterates through the local-written data cache at timeWindowIdx +// and checks each record against the replicated data cache. Records for which shouldSkip returns +// true are skipped. This helper unifies the logic for time windows [1] and [2]. +func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( + timeWindowIdx int, + shouldSkip func(commitTs, checkpointTs uint64) bool, + checker *DataChecker, +) { + for schemaKey, tableDataCache := range cd.timeWindowDataCaches[timeWindowIdx].tableDataCaches { + for _, localDataCache := range tableDataCache.localDataCache { + for _, record := range localDataCache.records { + for replicatedClusterID, checkpointTs := range cd.timeWindowDataCaches[timeWindowIdx].checkpointTs { + if shouldSkip(record.CommitTs, checkpointTs) { + continue + } + cd.checkedRecordsCount++ + replicatedRecord, skipped := checker.FindClusterReplicatedData(replicatedClusterID, schemaKey, record.Pk, record.CommitTs, cd.multiCluster) + if skipped { + failpoint.Inject("multiClusterConsistencyCheckerLWWViolation", func() { + Write("multiClusterConsistencyCheckerLWWViolation", []RowRecord{ + { + CommitTs: record.CommitTs, + PrimaryKeys: record.PkMap, + }, + }) + }) + log.Debug("replicated record skipped by LWW", + zap.String("localClusterID", cd.clusterID), + 
zap.String("replicatedClusterID", replicatedClusterID), + zap.String("schemaKey", schemaKey), + zap.String("pk", record.PkStr), + zap.Uint64("commitTs", record.CommitTs)) + cd.lwwSkippedRecordsCount++ + continue + } + if replicatedRecord == nil { + // data loss detected + log.Error("data loss detected", + zap.String("localClusterID", cd.clusterID), + zap.String("replicatedClusterID", replicatedClusterID), + zap.Any("record", record)) + cd.report.AddDataLossItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, record.CommitTs) + } else if !record.EqualReplicatedRecord(replicatedRecord) { + // data inconsistent detected + log.Error("data inconsistent detected", + zap.String("localClusterID", cd.clusterID), + zap.String("replicatedClusterID", replicatedClusterID), + zap.Any("record", record)) + cd.report.AddDataInconsistentItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, record.CommitTs, replicatedRecord.CommitTs, diffColumns(record, replicatedRecord)) + } + } + } + } + } +} + +// dataRedundantDetection iterates through the replicated data cache [1]. The record must be present +// in the source local data cache across recent windows. 
+func (cd *clusterDataChecker) dataRedundantDetection(checker *DataChecker) { + for schemaKey, tableDataCache := range cd.timeWindowDataCaches[1].tableDataCaches { + for _, replicatedDataCache := range tableDataCache.replicatedDataCache { + for _, record := range replicatedDataCache.records { + cd.checkedRecordsCount++ + // For replicated records, OriginTs is the local commit ts + if !checker.FindSourceLocalData(cd.clusterID, schemaKey, record.Pk, record.OriginTs) { + // data redundant detected + log.Error("data redundant detected", + zap.String("replicatedClusterID", cd.clusterID), + zap.Any("record", record)) + cd.report.AddDataRedundantItem(schemaKey, record.PkMap, record.PkStr, record.OriginTs, record.CommitTs) + } + } + } + } +} + +// lwwViolationDetection check the orderliness of the records +func (cd *clusterDataChecker) lwwViolationDetection() { + for schemaKey, tableDataCache := range cd.timeWindowDataCaches[2].tableDataCaches { + for pk, localRecords := range tableDataCache.localDataCache { + replicatedRecords := tableDataCache.replicatedDataCache[pk] + replicatedCount := 0 + if replicatedRecords != nil { + replicatedCount = len(replicatedRecords.records) + } + pkRecords := make([]*decoder.Record, 0, len(localRecords.records)+replicatedCount) + for _, localRecord := range localRecords.records { + pkRecords = append(pkRecords, localRecord) + } + if replicatedRecords != nil { + for _, replicatedRecord := range replicatedRecords.records { + pkRecords = append(pkRecords, replicatedRecord) + } + } + sort.Slice(pkRecords, func(i, j int) bool { + return pkRecords[i].CommitTs < pkRecords[j].CommitTs + }) + for _, record := range pkRecords { + cd.newTimeWindowRecordsCount++ + cd.clusterViolationChecker.Check(schemaKey, record, cd.report) + } + } + + for pk, replicatedRecords := range tableDataCache.replicatedDataCache { + if _, exists := tableDataCache.localDataCache[pk]; exists { + continue + } + pkRecords := make([]*decoder.Record, 0, 
len(replicatedRecords.records)) + for _, replicatedRecord := range replicatedRecords.records { + pkRecords = append(pkRecords, replicatedRecord) + } + sort.Slice(pkRecords, func(i, j int) bool { + return pkRecords[i].CommitTs < pkRecords[j].CommitTs + }) + for _, record := range pkRecords { + cd.newTimeWindowRecordsCount++ + cd.clusterViolationChecker.Check(schemaKey, record, cd.report) + } + } + } + + cd.clusterViolationChecker.UpdateCache() +} + +func (cd *clusterDataChecker) Check(checker *DataChecker, enableDataLoss, enableDataRedundant bool) { + cd.report = recorder.NewClusterReport(cd.clusterID, cd.thisRoundTimeWindow) + if enableDataLoss { + // CHECK 1 - Data Loss / Inconsistency Detection (round 2+) + // Needs [1] and [2] populated. + cd.dataLossDetection(checker) + } + if enableDataRedundant { + // CHECK 2 - Data Redundant Detection (round 3+) + // Needs [0], [1] and [2] all populated with real data; + // at round 2 [0] is still round 0 (empty), which would cause false positives. + cd.dataRedundantDetection(checker) + } + // CHECK 3 - LWW Violation Detection + // Always runs to keep the version cache up-to-date; meaningful results + // start from round 1 once the cache has been seeded. 
+ cd.lwwViolationDetection() +} + +func (cd *clusterDataChecker) GetReport() *recorder.ClusterReport { + return cd.report +} + +type DataChecker struct { + round uint64 + checkableRound uint64 + clusterDataCheckers map[string]*clusterDataChecker +} + +func NewDataChecker(ctx context.Context, clusterConfig map[string]config.ClusterConfig, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) (*DataChecker, error) { + clusterDataChecker := make(map[string]*clusterDataChecker) + for clusterID := range clusterConfig { + clusterDataChecker[clusterID] = newClusterDataChecker(clusterID, len(clusterConfig) > 2) + } + checker := &DataChecker{ + round: 0, + checkableRound: 0, + clusterDataCheckers: clusterDataChecker, + } + if err := checker.initializeFromCheckpoint(ctx, checkpointDataMap, checkpoint); err != nil { + return nil, errors.Trace(err) + } + return checker, nil +} + +func (c *DataChecker) initializeFromCheckpoint(ctx context.Context, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) error { + if checkpoint == nil { + return nil + } + if checkpoint.CheckpointItems[2] == nil { + return nil + } + c.round = checkpoint.CheckpointItems[2].Round + 1 + c.checkableRound = checkpoint.CheckpointItems[2].Round + 1 + for _, clusterDataChecker := range c.clusterDataCheckers { + if err := clusterDataChecker.InitializeFromCheckpoint(ctx, checkpointDataMap[clusterDataChecker.clusterID], checkpoint); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// FindClusterReplicatedData checks whether the record is present in the replicated data +// cache [1] or [2] or another new record is present in the replicated data cache [1] or [2]. 
func (c *DataChecker) FindClusterReplicatedData(clusterID string, schemaKey string, pk types.PkType, originTs uint64, multiCluster bool) (*decoder.Record, bool) {
	clusterDataChecker, exists := c.clusterDataCheckers[clusterID]
	if !exists {
		return nil, false
	}
	// With more than two clusters, window [0] is additionally searched.
	// NOTE(review): presumably to cover the extra replication lag of a
	// multi-hop topology — confirm with the advancer's window semantics.
	if multiCluster {
		record, skipped := clusterDataChecker.findClusterReplicatedDataInTimeWindow(0, schemaKey, pk, originTs)
		if skipped || record != nil {
			return record, skipped
		}
	}
	record, skipped := clusterDataChecker.findClusterReplicatedDataInTimeWindow(1, schemaKey, pk, originTs)
	if skipped || record != nil {
		return record, skipped
	}
	return clusterDataChecker.findClusterReplicatedDataInTimeWindow(2, schemaKey, pk, originTs)
}

// FindSourceLocalData reports whether any cluster other than localClusterID
// holds a locally-written record with the given primary key and commit ts in
// time window [0], [1] or [2]. Used by the data redundant detection to prove
// a replicated record has a matching source write.
func (c *DataChecker) FindSourceLocalData(localClusterID string, schemaKey string, pk types.PkType, commitTs uint64) bool {
	for _, clusterDataChecker := range c.clusterDataCheckers {
		if clusterDataChecker.clusterID == localClusterID {
			continue
		}
		if clusterDataChecker.findClusterLocalDataInTimeWindow(0, schemaKey, pk, commitTs) {
			return true
		}
		if clusterDataChecker.findClusterLocalDataInTimeWindow(1, schemaKey, pk, commitTs) {
			return true
		}
		if clusterDataChecker.findClusterLocalDataInTimeWindow(2, schemaKey, pk, commitTs) {
			return true
		}
	}
	return false
}

// CheckInNextTimeWindow ingests one round of per-cluster time window data,
// runs every detection that is enabled for the current round, and returns the
// aggregated report. It advances c.round (and c.checkableRound, capped at 3)
// on success.
func (c *DataChecker) CheckInNextTimeWindow(newTimeWindowData map[string]types.TimeWindowData) (*recorder.Report, error) {
	if err := c.decodeNewTimeWindowData(newTimeWindowData); err != nil {
		log.Error("failed to decode new time window data", zap.Error(err))
		return nil, errors.Annotate(err, "failed to decode new time window data")
	}
	report := recorder.NewReport(c.round)

	// Round 0: seed the LWW cache (round 0 data is empty by convention).
	// Round 1+: LWW violation detection produces meaningful results.
	// Round 2+: data loss / inconsistency detection (needs [1] and [2]).
	// Round 3+: data redundant detection (needs [0], [1] and [2] with real data).
	enableDataLoss := c.checkableRound >= 3 || (len(c.clusterDataCheckers) > 2 && c.checkableRound >= 2)
	enableDataRedundant := c.checkableRound >= 3

	for clusterID, clusterDataChecker := range c.clusterDataCheckers {
		clusterDataChecker.Check(c, enableDataLoss, enableDataRedundant)
		log.Info("checked records count",
			zap.String("clusterID", clusterID),
			zap.Uint64("round", c.round),
			zap.Bool("enableDataLoss", enableDataLoss),
			zap.Bool("enableDataRedundant", enableDataRedundant),
			zap.Int("checkedRecordsCount", clusterDataChecker.checkedRecordsCount),
			zap.Int("newTimeWindowRecordsCount", clusterDataChecker.newTimeWindowRecordsCount),
			zap.Int("lwwSkippedRecordsCount", clusterDataChecker.lwwSkippedRecordsCount))
		report.AddClusterReport(clusterID, clusterDataChecker.GetReport())
	}

	if c.checkableRound < 3 {
		c.checkableRound++
	}
	c.round++
	return report, nil
}

// decodeNewTimeWindowData validates the incoming per-cluster data (every
// configured cluster present, checkpoint targets sane), rotates each
// cluster's time window caches, then decodes the canal-JSON contents into
// records and appends them to the new window [2].
func (c *DataChecker) decodeNewTimeWindowData(newTimeWindowData map[string]types.TimeWindowData) error {
	if len(newTimeWindowData) != len(c.clusterDataCheckers) {
		return errors.Errorf("number of clusters mismatch, expected %d, got %d", len(c.clusterDataCheckers), len(newTimeWindowData))
	}
	for clusterID, timeWindowData := range newTimeWindowData {
		clusterDataChecker, exists := c.clusterDataCheckers[clusterID]
		if !exists {
			return errors.Errorf("cluster %s not found", clusterID)
		}
		// A cluster must not report a checkpoint ts for itself, and every
		// checkpoint target must be a configured cluster.
		for replicatedClusterID := range timeWindowData.TimeWindow.CheckpointTs {
			if replicatedClusterID == clusterID {
				return errors.Errorf("cluster %s has invalid checkpoint ts target to itself", clusterID)
			}
			if _, ok := c.clusterDataCheckers[replicatedClusterID]; !ok {
				return errors.Errorf("cluster %s has checkpoint ts target %s not found", clusterID, replicatedClusterID)
			}
		}
		clusterDataChecker.thisRoundTimeWindow = timeWindowData.TimeWindow
		if err := clusterDataChecker.PrepareNextTimeWindowData(timeWindowData.TimeWindow); err != nil {
			return errors.Trace(err)
		}
		for schemaPathKey, incrementalData := range timeWindowData.Data {
			schemaKey := schemaPathKey.GetKey()
			for _, contents := range incrementalData.DataContentSlices {
				for _, content := range contents {
					records, err := decoder.Decode(content, incrementalData.ColumnFieldTypes)
					if err != nil {
						return errors.Trace(err)
					}
					for _, record := range records {
						clusterDataChecker.NewRecord(schemaKey, record)
					}
				}
			}
		}
	}

	return nil
}
diff --git a/cmd/multi-cluster-consistency-checker/checker/checker_test.go b/cmd/multi-cluster-consistency-checker/checker/checker_test.go new file mode 100644 index 0000000000..0d940e5ee0 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/checker/checker_test.go @@ -0,0 +1,961 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License.
package checker

import (
	"context"
	"fmt"
	"strings"
	"testing"

	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config"
	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/decoder"
	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder"
	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types"
	"github.com/pingcap/ticdc/pkg/sink/cloudstorage"
	"github.com/stretchr/testify/require"
)

// TestNewDataChecker verifies basic construction of a DataChecker without a
// checkpoint: round starts at 0 and one per-cluster checker exists per config entry.
func TestNewDataChecker(t *testing.T) {
	t.Parallel()

	t.Run("create data checker", func(t *testing.T) {
		t.Parallel()
		clusterConfig := map[string]config.ClusterConfig{
			"cluster1": {
				PDAddrs:        []string{"127.0.0.1:2379"},
				S3SinkURI:      "s3://bucket/cluster1/",
				S3ChangefeedID: "s3-cf-1",
			},
			"cluster2": {
				PDAddrs:        []string{"127.0.0.1:2479"},
				S3SinkURI:      "s3://bucket/cluster2/",
				S3ChangefeedID: "s3-cf-2",
			},
		}

		checker, initErr := NewDataChecker(context.Background(), clusterConfig, nil, nil)
		require.NoError(t, initErr)
		require.NotNil(t, checker)
		require.Equal(t, uint64(0), checker.round)
		require.Len(t, checker.clusterDataCheckers, 2)
		require.Contains(t, checker.clusterDataCheckers, "cluster1")
		require.Contains(t, checker.clusterDataCheckers, "cluster2")
	})
}

// TestNewDataCheckerInitializeFromCheckpointError verifies that undecodable
// checkpoint data content surfaces a decode error from the constructor.
func TestNewDataCheckerInitializeFromCheckpointError(t *testing.T) {
	t.Parallel()

	clusterConfig := map[string]config.ClusterConfig{
		"c1": {},
	}
	checkpoint := recorder.NewCheckpoint()
	checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{
		"c1": {
			TimeWindow: types.TimeWindow{
				LeftBoundary:  0,
				RightBoundary: 100,
			},
		},
	})
	// The row payload only carries column "val" while pkNames/mysqlType
	// declare "id"; decoding must fail on the missing column value.
	invalidContent := []byte(`{"pkNames":["id"],"isDdl":false,"type":"INSERT","mysqlType":{"id":"int"},"data":[{"val":"x"}],"_tidb":{"commitTs":1}}`)
	checkpointDataMap := map[string]map[cloudstorage.DmlPathKey]types.IncrementalData{
		"c1": {
			{}: {
				DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{
					{}: {invalidContent},
				},
			},
		},
	}

	_, err := NewDataChecker(context.Background(), clusterConfig, checkpointDataMap, checkpoint)
	require.Error(t, err)
	require.Contains(t, err.Error(), "column value of column id not found")
}

// TestNewDataCheckerInitializeFromCheckpointMissingClusterInfo verifies that a
// checkpoint lacking a configured cluster's info in item [2] or item [1] is rejected.
func TestNewDataCheckerInitializeFromCheckpointMissingClusterInfo(t *testing.T) {
	t.Parallel()

	t.Run("cluster missing in checkpoint item[2]", func(t *testing.T) {
		t.Parallel()
		clusterConfig := map[string]config.ClusterConfig{
			"c1": {},
			"c2": {},
		}
		checkpoint := recorder.NewCheckpoint()
		checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{
			"c1": {
				TimeWindow: types.TimeWindow{
					LeftBoundary:  0,
					RightBoundary: 100,
				},
			},
		})

		_, err := NewDataChecker(context.Background(), clusterConfig, nil, checkpoint)
		require.Error(t, err)
		require.Contains(t, err.Error(), "cluster c2 not found in checkpoint item[2]")
	})

	t.Run("cluster missing in checkpoint item[1]", func(t *testing.T) {
		t.Parallel()
		clusterConfig := map[string]config.ClusterConfig{
			"c1": {},
			"c2": {},
		}
		checkpoint := recorder.NewCheckpoint()
		checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{
			"c1": {
				TimeWindow: types.TimeWindow{
					LeftBoundary:  0,
					RightBoundary: 100,
				},
			},
		})
		checkpoint.NewTimeWindowData(1, map[string]types.TimeWindowData{
			"c1": {
				TimeWindow: types.TimeWindow{
					LeftBoundary:  100,
					RightBoundary: 200,
				},
			},
			"c2": {
				TimeWindow: types.TimeWindow{
					LeftBoundary:  100,
					RightBoundary: 200,
				},
			},
		})

		_, err := NewDataChecker(context.Background(), clusterConfig, nil, checkpoint)
		require.Error(t, err)
		require.Contains(t, err.Error(), "cluster c2 not found in checkpoint item[1]")
	})
}

// TestNewClusterDataChecker verifies the zero state of a new per-cluster checker.
func TestNewClusterDataChecker(t *testing.T) {
	t.Parallel()

	t.Run("create cluster data checker", func(t *testing.T) {
		t.Parallel()
		checker := newClusterDataChecker("cluster1", false)
		require.NotNil(t, checker)
		require.Equal(t, "cluster1", checker.clusterID)
		require.Equal(t, uint64(0), checker.rightBoundary)
		require.NotNil(t, checker.timeWindowDataCaches)
		require.NotNil(t, checker.overDataCaches)
		require.NotNil(t, checker.clusterViolationChecker)
	})
}

// TestNewClusterViolationChecker verifies the zero state of the LWW checker.
func TestNewClusterViolationChecker(t *testing.T) {
	t.Parallel()

	t.Run("create cluster violation checker", func(t *testing.T) {
		t.Parallel()
		checker := newClusterViolationChecker("cluster1")
		require.NotNil(t, checker)
		require.Equal(t, "cluster1", checker.clusterID)
		require.NotNil(t, checker.twoPreviousTimeWindowKeyVersionCache)
	})
}

// TestClusterViolationChecker_Check exercises the LWW version ordering check:
// a fresh record seeds the cache, an older duplicate is skipped, and a record
// whose OriginTs is behind the cached CommitTs raises a violation.
func TestClusterViolationChecker_Check(t *testing.T) {
	t.Parallel()

	const schemaKey = "test_schema"

	t.Run("check new record", func(t *testing.T) {
		t.Parallel()
		checker := newClusterViolationChecker("cluster1")
		report := recorder.NewClusterReport("cluster1", types.TimeWindow{})

		record := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 100,
				OriginTs: 0,
			},
		}

		checker.Check(schemaKey, record, report)
		require.Empty(t, report.TableFailureItems)
		require.Contains(t, checker.twoPreviousTimeWindowKeyVersionCache, schemaKey)
		require.Contains(t, checker.twoPreviousTimeWindowKeyVersionCache[schemaKey], record.Pk)
	})

	t.Run("check duplicate old version", func(t *testing.T) {
		t.Parallel()
		checker := newClusterViolationChecker("cluster1")
		report := recorder.NewClusterReport("cluster1", types.TimeWindow{})

		record1 := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 100,
				OriginTs: 0,
			},
		}
		record2 := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 50,
				OriginTs: 0,
			},
		}

		checker.Check(schemaKey, record1, report)
		checker.Check(schemaKey, record2, report)
		require.Empty(t, report.TableFailureItems) // Should skip duplicate old version
	})

	t.Run("check lww violation", func(t *testing.T) {
		t.Parallel()
		checker := newClusterViolationChecker("cluster1")
		report := recorder.NewClusterReport("cluster1", types.TimeWindow{})

		record1 := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 100,
				OriginTs: 0,
			},
		}
		record2 := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 150,
				OriginTs: 50, // OriginTs is less than record1's CommitTs, causing violation
			},
		}

		checker.Check(schemaKey, record1, report)
		checker.Check(schemaKey, record2, report)
		require.Len(t, report.TableFailureItems, 1)
		require.Contains(t, report.TableFailureItems, schemaKey)
		tableItems := report.TableFailureItems[schemaKey]
		require.Len(t, tableItems.LWWViolationItems, 1)
		require.Equal(t, map[string]any{"id": "1"}, tableItems.LWWViolationItems[0].PK)
		require.Equal(t, uint64(0), tableItems.LWWViolationItems[0].ExistingOriginTS)
		require.Equal(t, uint64(100), tableItems.LWWViolationItems[0].ExistingCommitTS)
		require.Equal(t, uint64(50), tableItems.LWWViolationItems[0].OriginTS)
		require.Equal(t, uint64(150), tableItems.LWWViolationItems[0].CommitTS)
	})
}

// TestClusterViolationChecker_UpdateCache verifies that a cached version ages
// by one on each UpdateCache call and is evicted after two rotations.
func TestClusterViolationChecker_UpdateCache(t *testing.T) {
	t.Parallel()

	const schemaKey = "test_schema"

	t.Run("update cache", func(t *testing.T) {
		t.Parallel()
		checker := newClusterViolationChecker("cluster1")
		report := recorder.NewClusterReport("cluster1", types.TimeWindow{})

		record := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 100,
				OriginTs: 0,
			},
		}

		checker.Check(schemaKey, record, report)
		require.Contains(t, checker.twoPreviousTimeWindowKeyVersionCache, schemaKey)
		entry := checker.twoPreviousTimeWindowKeyVersionCache[schemaKey][record.Pk]
		require.Equal(t, 0, entry.previous)

		checker.UpdateCache()
		entry = checker.twoPreviousTimeWindowKeyVersionCache[schemaKey][record.Pk]
		require.Equal(t, 1, entry.previous)

		checker.UpdateCache()
		entry = checker.twoPreviousTimeWindowKeyVersionCache[schemaKey][record.Pk]
		require.Equal(t, 2, entry.previous)

		checker.UpdateCache()
		// Entry should be removed after 2 updates
		_, exists := checker.twoPreviousTimeWindowKeyVersionCache[schemaKey]
		require.False(t, exists)
	})
}

// TestNewTimeWindowDataCache verifies the boundaries and checkpoint map of a
// freshly created time window cache.
func TestNewTimeWindowDataCache(t *testing.T) {
	t.Parallel()

	t.Run("create time window data cache", func(t *testing.T) {
		t.Parallel()
		leftBoundary := uint64(100)
		rightBoundary := uint64(200)
		checkpointTs := map[string]uint64{
			"cluster2": 150,
		}

		cache := newTimeWindowDataCache(leftBoundary, rightBoundary, checkpointTs)
		require.Equal(t, leftBoundary, cache.leftBoundary)
		require.Equal(t, rightBoundary, cache.rightBoundary)
		require.Equal(t, checkpointTs, cache.checkpointTs)
		require.NotNil(t, cache.tableDataCaches)
	})
}

// TestTimeWindowDataCache_NewRecord verifies routing of incoming records:
// OriginTs == 0 goes to the local cache, OriginTs > 0 goes to the replicated
// cache, and records before the left boundary are dropped.
func TestTimeWindowDataCache_NewRecord(t *testing.T) {
	t.Parallel()

	const schemaKey = "test_schema"

	t.Run("add local record", func(t *testing.T) {
		t.Parallel()
		cache := newTimeWindowDataCache(100, 200, map[string]uint64{})
		record := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 150,
				OriginTs: 0,
			},
		}

		cache.NewRecord(schemaKey, record)
		require.Contains(t, cache.tableDataCaches, schemaKey)
		require.Contains(t, cache.tableDataCaches[schemaKey].localDataCache, record.Pk)
		require.Contains(t, cache.tableDataCaches[schemaKey].localDataCache[record.Pk].records, record.CommitTs)
	})

	t.Run("add replicated record", func(t *testing.T) {
		t.Parallel()
		cache := newTimeWindowDataCache(100, 200, map[string]uint64{})
		record := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 150,
				OriginTs: 100,
			},
		}

		cache.NewRecord(schemaKey, record)
		require.Contains(t, cache.tableDataCaches, schemaKey)
		require.Contains(t, cache.tableDataCaches[schemaKey].replicatedDataCache, record.Pk)
		require.Contains(t, cache.tableDataCaches[schemaKey].replicatedDataCache[record.Pk].records, record.OriginTs)
	})

	t.Run("skip record before left boundary", func(t *testing.T) {
		t.Parallel()
		cache := newTimeWindowDataCache(100, 200, map[string]uint64{})
		record := &decoder.Record{
			Pk:    "pk1",
			PkMap: map[string]any{"id": "1"},
			CdcVersion: types.CdcVersion{
				CommitTs: 50,
				OriginTs: 0,
			},
		}

		cache.NewRecord(schemaKey, record)
		require.NotContains(t, cache.tableDataCaches, schemaKey)
	})
}

// TestClusterDataChecker_PrepareNextTimeWindowData verifies window rotation:
// the new window must start exactly at the previous right boundary.
func TestClusterDataChecker_PrepareNextTimeWindowData(t *testing.T) {
	t.Parallel()

	t.Run("prepare next time window data", func(t *testing.T) {
		t.Parallel()
		checker := newClusterDataChecker("cluster1", false)
		checker.rightBoundary = 100

		timeWindow := types.TimeWindow{
			LeftBoundary:  100,
			RightBoundary: 200,
			CheckpointTs:  map[string]uint64{"cluster2": 150},
		}

		err := checker.PrepareNextTimeWindowData(timeWindow)
		require.NoError(t, err)
		require.Equal(t, uint64(200), checker.rightBoundary)
	})

	t.Run("mismatch left boundary", func(t *testing.T) {
		t.Parallel()
		checker := newClusterDataChecker("cluster1", false)
		checker.rightBoundary = 100

		timeWindow := types.TimeWindow{
			LeftBoundary:  150,
			RightBoundary: 200,
			CheckpointTs:  map[string]uint64{"cluster2": 150},
		}

		err := checker.PrepareNextTimeWindowData(timeWindow)
		require.Error(t, err)
		require.Contains(t, err.Error(), "mismatch")
	})
}

// makeCanalJSON builds a canal-JSON formatted record for testing.
// pkID is the primary key value, commitTs is the TiDB commit timestamp,
// originTs is the origin timestamp (0 for locally-written records, non-zero for replicated records),
// val is a varchar column value.
+func makeCanalJSON(pkID int, commitTs uint64, originTs uint64, val string) string { + originTsVal := "null" + if originTs > 0 { + originTsVal = fmt.Sprintf(`"%d"`, originTs) + } + return fmt.Sprintf( + `{"id":0,"database":"test","table":"t1","pkNames":["id"],"isDdl":false,"type":"INSERT",`+ + `"es":0,"ts":0,"sql":"","sqlType":{"id":4,"val":12,"_tidb_origin_ts":-5},`+ + `"mysqlType":{"id":"int","val":"varchar","_tidb_origin_ts":"bigint"},`+ + `"old":null,"data":[{"id":"%d","val":"%s","_tidb_origin_ts":%s}],`+ + `"_tidb":{"commitTs":%d}}`, + pkID, val, originTsVal, commitTs) +} + +// makeContent combines canal-JSON records with CRLF terminator. +func makeContent(records ...string) []byte { + return []byte(strings.Join(records, "\r\n")) +} + +// makeTWData builds a TimeWindowData for testing. +func makeTWData(left, right uint64, checkpointTs map[string]uint64, content []byte) types.TimeWindowData { + data := map[cloudstorage.DmlPathKey]types.IncrementalData{} + if content != nil { + data[cloudstorage.DmlPathKey{}] = types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {}: {content}, + }, + } + } + return types.TimeWindowData{ + TimeWindow: types.TimeWindow{ + LeftBoundary: left, + RightBoundary: right, + CheckpointTs: checkpointTs, + }, + Data: data, + } +} + +// defaultSchemaKey is the schema key produced by DmlPathKey{}.GetKey() +// which is QuoteSchema("", "") = "“.“" +var defaultSchemaKey = (&cloudstorage.DmlPathKey{}).GetKey() + +// TestDataChecker_FourRoundsCheck simulates 4 rounds with increasing data and verifies check results. +// Setup: 2 clusters (c1 locally-written, c2 replicated from c1). +// - Round 0: LWW cache is seeded (data is empty by convention). +// - Round 1+: LWW violation detection is active. +// - Round 2+: data loss / inconsistent detection is active. +// - Round 3+: data redundant detection is also active (needs [0],[1],[2] all populated). 
+func TestDataChecker_FourRoundsCheck(t *testing.T) { + t.Parallel() + ctx := context.Background() + + clusterCfg := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + + // makeBaseRounds creates shared rounds 0 and 1 data for all subtests. + // c1 produces locally-written data, c2 receives matching replicated data from c1. + makeBaseRounds := func() [2]map[string]types.TimeWindowData { + return [2]map[string]types.TimeWindowData{ + // Round 0: [0, 100] + { + "c1": makeTWData(0, 100, map[string]uint64{"c2": 80}, + makeContent(makeCanalJSON(1, 50, 0, "a"))), + "c2": makeTWData(0, 100, nil, + makeContent(makeCanalJSON(1, 60, 50, "a"))), + }, + // Round 1: [100, 200] + { + "c1": makeTWData(100, 200, map[string]uint64{"c2": 180}, + makeContent(makeCanalJSON(2, 150, 0, "b"))), + "c2": makeTWData(100, 200, nil, + makeContent(makeCanalJSON(2, 160, 150, "b"))), + }, + } + } + + t.Run("all consistent", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + base := makeBaseRounds() + + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 260, 250, "c"))), + } + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + require.Equal(t, uint64(i), report.Round) + // Every round now produces cluster reports (LWW always runs). 
+ require.Len(t, report.ClusterReports, 2, "round %d should have 2 cluster reports", i) + require.False(t, report.NeedFlush(), "round %d should not need flush (all consistent)", i) + for clusterID, cr := range report.ClusterReports { + require.Empty(t, cr.TableFailureItems, "round %d cluster %s should have no table failure items", i, clusterID) + } + } + }) + + t.Run("data loss detected", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + base := makeBaseRounds() + + // Round 2: c1 has locally-written pk=3 but c2 has NO matching replicated data + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, nil), + } + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + // c1 should detect data loss: pk=3 (commitTs=250) missing in c2's replicated data + c1Report := lastReport.ClusterReports["c1"] + require.NotNil(t, c1Report) + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + tableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, tableItems.DataLossItems, 1) + require.Equal(t, "c2", tableItems.DataLossItems[0].PeerClusterID) + require.Equal(t, uint64(250), tableItems.DataLossItems[0].LocalCommitTS) + // c2 should have no issues + c2Report := lastReport.ClusterReports["c2"] + require.Empty(t, c2Report.TableFailureItems) + }) + + 
t.Run("no data loss when peer local newer version exists for same pk", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + base := makeBaseRounds() + + // Round 2: c1 and c2 both write the same PK locally. + // c2 does not contain a replicated row with originTs=250, but has a newer + // local row (commitTs=270) for the same PK. This should be treated as LWW + // overwrite/skip instead of data loss. + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c1-local"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 270, 0, "c2-local"))), + } + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + c1Report := lastReport.ClusterReports["c1"] + if tableItems, ok := c1Report.TableFailureItems[defaultSchemaKey]; ok { + require.Empty(t, tableItems.DataLossItems) + } + }) + + t.Run("data inconsistent detected", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + base := makeBaseRounds() + + // Round 2: c2 has replicated data for pk=3 but with wrong column value + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 260, 250, "WRONG"))), + } + round3 := map[string]types.TimeWindowData{ + "c1": 
makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + c1Report := lastReport.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + tableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Empty(t, tableItems.DataLossItems) + require.Len(t, tableItems.DataInconsistentItems, 1) + require.Equal(t, "c2", tableItems.DataInconsistentItems[0].PeerClusterID) + require.Equal(t, uint64(250), tableItems.DataInconsistentItems[0].LocalCommitTS) + require.Equal(t, uint64(260), tableItems.DataInconsistentItems[0].ReplicatedCommitTS) + require.Len(t, tableItems.DataInconsistentItems[0].InconsistentColumns, 1) + require.Equal(t, "val", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Column) + require.Equal(t, "c", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Local) + require.Equal(t, "WRONG", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Replicated) + }) + + t.Run("data redundant detected", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + base := makeBaseRounds() + + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, + makeContent( + makeCanalJSON(3, 260, 250, "c"), + makeCanalJSON(99, 240, 230, "x"), + )), + } + // Round 3: keep normal data so data redundant detection at round 3 + // checks round2 ([1]) and catches the orphan from 
round2. + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + // c1 should have no data loss + c1Report := lastReport.ClusterReports["c1"] + require.Empty(t, c1Report.TableFailureItems) + // c2 should detect data redundant: pk=99 has no matching locally-written record in c1 + c2Report := lastReport.ClusterReports["c2"] + require.Contains(t, c2Report.TableFailureItems, defaultSchemaKey) + tableItems := c2Report.TableFailureItems[defaultSchemaKey] + require.Len(t, tableItems.DataRedundantItems, 1) + require.Equal(t, uint64(230), tableItems.DataRedundantItems[0].OriginTS) + require.Equal(t, uint64(240), tableItems.DataRedundantItems[0].ReplicatedCommitTS) + }) + + t.Run("lww violation detected", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + base := makeBaseRounds() + + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 260, 250, "c"))), + } + // Round 3: c1 has locally-written pk=5 (commitTs=350, compareTs=350) and + // replicated pk=5 from c2 (commitTs=370, originTs=310, compareTs=310). + // Since 350 >= 310 with commitTs 350 < 370, this is an LWW violation. + // c2 also has matching records to avoid data loss/redundant noise. 
+ round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent( + makeCanalJSON(5, 350, 0, "e"), + makeCanalJSON(5, 370, 310, "e"), + )), + "c2": makeTWData(300, 400, nil, + makeContent( + makeCanalJSON(5, 310, 0, "e"), + makeCanalJSON(5, 360, 350, "e"), + )), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + c1Report := lastReport.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + c1TableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, c1TableItems.LWWViolationItems, 1) + require.Equal(t, uint64(0), c1TableItems.LWWViolationItems[0].ExistingOriginTS) + require.Equal(t, uint64(350), c1TableItems.LWWViolationItems[0].ExistingCommitTS) + require.Equal(t, uint64(310), c1TableItems.LWWViolationItems[0].OriginTS) + require.Equal(t, uint64(370), c1TableItems.LWWViolationItems[0].CommitTS) + // c2 should have no LWW violation (its records are ordered correctly: + // locally-written commitTs=310 compareTs=310, replicated commitTs=360 compareTs=350, 310 < 350) + c2Report := lastReport.ClusterReports["c2"] + if c2TableItems, ok := c2Report.TableFailureItems[defaultSchemaKey]; ok { + require.Empty(t, c2TableItems.LWWViolationItems) + } + }) + + // lww violation detected at round 1: LWW is active from round 1, + // so a violation introduced in round 1 data should surface immediately. 
+ t.Run("lww violation detected at round 1", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + + // Round 0: [0, 100] — c1 writes pk=1 (commitTs=50, compareTs=50) + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, nil, + makeContent(makeCanalJSON(1, 50, 0, "a"))), + "c2": makeTWData(0, 100, nil, nil), + } + // Round 1: [100, 200] — c1 writes pk=1 again: + // locally-written pk=1 (commitTs=150, originTs=0, compareTs=150) + // replicated pk=1 (commitTs=180, originTs=120, compareTs=120) + // The LWW cache already has pk=1 compareTs=50 from round 0. + // Record order: commitTs=150 (compareTs=150 > cached 50 → update cache), + // commitTs=180 (compareTs=120 < cached 150 → VIOLATION) + round1 := map[string]types.TimeWindowData{ + "c1": makeTWData(100, 200, nil, + makeContent( + makeCanalJSON(1, 150, 0, "b"), + makeCanalJSON(1, 180, 120, "b"), + )), + "c2": makeTWData(100, 200, nil, nil), + } + + report0, err := checker.CheckInNextTimeWindow(round0) + require.NoError(t, err) + require.False(t, report0.NeedFlush(), "round 0 should not need flush") + + report1, err := checker.CheckInNextTimeWindow(round1) + require.NoError(t, err) + + // LWW violation should be detected at round 1 + require.True(t, report1.NeedFlush(), "round 1 should detect LWW violation") + c1Report := report1.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + c1TableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, c1TableItems.LWWViolationItems, 1) + require.Equal(t, uint64(0), c1TableItems.LWWViolationItems[0].ExistingOriginTS) + require.Equal(t, uint64(150), c1TableItems.LWWViolationItems[0].ExistingCommitTS) + require.Equal(t, uint64(120), c1TableItems.LWWViolationItems[0].OriginTS) + require.Equal(t, uint64(180), c1TableItems.LWWViolationItems[0].CommitTS) + }) + + // data loss detected at round 3 for 2-cluster mode: + // enableDataLoss 
is active from round 3 when only 2 clusters are involved. + // A record in round 2 whose commitTs > checkpointTs enters [1] at round 3, + // and if the replicated counterpart is missing, data loss is detected at round 3. + t.Run("data loss detected at round 3", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, nil, nil), + "c2": makeTWData(0, 100, nil, nil), + } + // Round 1: consistent data. + round1 := map[string]types.TimeWindowData{ + "c1": makeTWData(100, 200, map[string]uint64{"c2": 180}, + makeContent(makeCanalJSON(1, 150, 0, "a"))), + "c2": makeTWData(100, 200, nil, + makeContent(makeCanalJSON(1, 160, 150, "a"))), + } + // Round 2: c1 writes pk=2 (commitTs=250), checkpointTs["c2"]=240. + // Since 250 > 240, this record requires replication checking. + // c2 has no matching replicated data in this round. + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(2, 250, 0, "b"))), + "c2": makeTWData(200, 300, nil, nil), + } + // Round 3: data loss detection is enabled in 2-cluster mode. 
+ round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(3, 350, 0, "c"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(3, 360, 350, "c"))), + } + + report0, err := checker.CheckInNextTimeWindow(round0) + require.NoError(t, err) + require.False(t, report0.NeedFlush()) + + report1, err := checker.CheckInNextTimeWindow(round1) + require.NoError(t, err) + require.False(t, report1.NeedFlush(), "round 1 should not detect data loss yet") + + report2, err := checker.CheckInNextTimeWindow(round2) + require.NoError(t, err) + require.False(t, report2.NeedFlush(), "round 2 should not detect data loss yet") + + report3, err := checker.CheckInNextTimeWindow(round3) + require.NoError(t, err) + require.True(t, report3.NeedFlush(), "round 3 should detect data loss") + c1Report := report3.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + tableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, tableItems.DataLossItems, 1) + require.Equal(t, "c2", tableItems.DataLossItems[0].PeerClusterID) + require.Equal(t, uint64(250), tableItems.DataLossItems[0].LocalCommitTS) + }) + + // data redundant detected at round 3 (not round 2): + // dataRedundantDetection checks timeWindowDataCaches[1]. + // At round 2 [0]=round 0 (empty) so FindSourceLocalData may miss data in + // that window → enableDataRedundant is false to avoid false positives. + // At round 3 [0]=round 1, [1]=round 2, [2]=round 3 are all populated + // with real data, so enableDataRedundant=true and an orphan in [1] is caught. + // + // This test puts an orphan pk=99 in round 2 only: + // - Round 2: orphan in [2] but enableDataRedundant=false → NOT flagged. + // - Round 3: orphan moved to [1] and enableDataRedundant=true → flagged. 
+ t.Run("data redundant detected at round 3 not round 2", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, nil, nil), + "c2": makeTWData(0, 100, nil, nil), + } + // Round 1: normal consistent data. + round1 := map[string]types.TimeWindowData{ + "c1": makeTWData(100, 200, map[string]uint64{"c2": 180}, + makeContent(makeCanalJSON(1, 150, 0, "a"))), + "c2": makeTWData(100, 200, nil, + makeContent(makeCanalJSON(1, 160, 150, "a"))), + } + // Round 2: c2 has orphan replicated pk=99 (originTs=230) in [2]. + // enableDataRedundant=false at round 2, so it must NOT be flagged. + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 280}, + makeContent(makeCanalJSON(2, 250, 0, "b"))), + "c2": makeTWData(200, 300, nil, + makeContent( + makeCanalJSON(2, 260, 250, "b"), + makeCanalJSON(99, 240, 230, "x"), // orphan replicated + )), + } + // Round 3: no new orphan in [2]; enableDataRedundant=true at round 3 should + // catch the orphan that is now in [1] (from round 2). + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(3, 350, 0, "c"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(3, 360, 350, "c"))), + } + + report0, err := checker.CheckInNextTimeWindow(round0) + require.NoError(t, err) + require.False(t, report0.NeedFlush(), "round 0 should not need flush") + + report1, err := checker.CheckInNextTimeWindow(round1) + require.NoError(t, err) + require.False(t, report1.NeedFlush(), "round 1 should not need flush") + + report2, err := checker.CheckInNextTimeWindow(round2) + require.NoError(t, err) + // Round 2: redundant detection is NOT enabled; the orphan pk=99 should NOT be flagged. 
+ require.False(t, report2.NeedFlush(), "round 2 should not flag data redundant yet") + + report3, err := checker.CheckInNextTimeWindow(round3) + require.NoError(t, err) + // Round 3: redundant detection is enabled; the orphan pk=99 in [1] (round 2) + // is now caught. + require.True(t, report3.NeedFlush(), "round 3 should detect data redundant") + c2Report := report3.ClusterReports["c2"] + require.Contains(t, c2Report.TableFailureItems, defaultSchemaKey) + c2TableItems := c2Report.TableFailureItems[defaultSchemaKey] + require.Len(t, c2TableItems.DataRedundantItems, 1) + require.Equal(t, uint64(230), c2TableItems.DataRedundantItems[0].OriginTS) + require.Equal(t, uint64(240), c2TableItems.DataRedundantItems[0].ReplicatedCommitTS) + }) +} + +func TestDataChecker_CheckInNextTimeWindowInvalidCheckpointTarget(t *testing.T) { + t.Parallel() + ctx := context.Background() + + clusterCfg := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + + t.Run("unknown target cluster", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, map[string]uint64{"c3": 80}, nil), + "c2": makeTWData(0, 100, nil, nil), + } + _, err := checker.CheckInNextTimeWindow(round0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster c1 has checkpoint ts target c3 not found") + }) + + t.Run("self target cluster", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, map[string]uint64{"c1": 80}, nil), + "c2": makeTWData(0, 100, nil, nil), + } + _, err := checker.CheckInNextTimeWindow(round0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster c1 has invalid checkpoint ts target to itself") + }) +} diff --git a/cmd/multi-cluster-consistency-checker/checker/failpoint.go 
b/cmd/multi-cluster-consistency-checker/checker/failpoint.go
new file mode 100644
index 0000000000..0339f9c389
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/checker/failpoint.go
@@ -0,0 +1,99 @@
// Copyright 2026 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package checker

import (
	"encoding/json"
	"os"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pingcap/log"
	"go.uber.org/zap"
)

// envKey is the environment variable that controls the output file path.
// When it is unset, failpoint recording is disabled entirely.
const envKey = "TICDC_MULTI_CLUSTER_CONSISTENCY_CHECKER_FAILPOINT_RECORD_FILE"

// RowRecord captures the essential identity of a single affected row.
type RowRecord struct {
	CommitTs    uint64         `json:"commitTs"`
	OriginTs    uint64         `json:"originTs"`
	PrimaryKeys map[string]any `json:"primaryKeys"`
}

// Record is one line written to the JSONL file.
type Record struct {
	// Time is the event timestamp, formatted as RFC3339Nano in UTC.
	Time string `json:"time"`
	// Failpoint is the name of the failpoint that fired.
	Failpoint string `json:"failpoint"`
	// Rows are the rows affected by the event.
	Rows []RowRecord `json:"rows"`
}

var (
	// initOnce guards the one-time opening of the record file.
	initOnce sync.Once
	// mu serializes appends so concurrent Write calls never interleave lines.
	mu sync.Mutex
	// file is the open record file; nil when recording is disabled or open failed.
	file *os.File
	// disabled is the fast-path flag: true once recording is known to be off.
	disabled atomic.Bool
)

// ensureFile lazily opens the record file named by envKey exactly once.
// When envKey is unset or the open fails, recording is permanently disabled.
// The file is intentionally never closed; it lives for the process lifetime.
func ensureFile() {
	initOnce.Do(func() {
		path := os.Getenv(envKey)
		if path == "" {
			disabled.Store(true)
			return
		}
		var err error
		file, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
		if err != nil {
			log.Warn("failed to open failpoint record file, recording disabled",
				zap.String("path", path), zap.Error(err))
			disabled.Store(true)
			return
		}
		log.Info("failpoint record file opened", zap.String("path", path))
	})
}

// Write persists one failpoint event to the JSONL file.
// It is safe for concurrent use.
// If the env var is not set the call is a no-op (zero allocation).
// NOTE(review): the zero-allocation fast path only holds from the second call
// on, after ensureFile has run once and set the disabled flag.
func Write(failpoint string, rows []RowRecord) {
	if disabled.Load() {
		return
	}
	ensureFile()
	// file stays nil when recording was just disabled by ensureFile.
	if file == nil {
		return
	}

	rec := Record{
		Time:      time.Now().UTC().Format(time.RFC3339Nano),
		Failpoint: failpoint,
		Rows:      rows,
	}
	data, err := json.Marshal(rec)
	if err != nil {
		log.Warn("failed to marshal failpoint record", zap.Error(err))
		return
	}
	data = append(data, '\n')

	mu.Lock()
	defer mu.Unlock()
	if _, err := file.Write(data); err != nil {
		log.Warn("failed to write failpoint record", zap.Error(err))
	}
}
diff --git a/cmd/multi-cluster-consistency-checker/config/config.example.toml b/cmd/multi-cluster-consistency-checker/config/config.example.toml
new file mode 100644
index 0000000000..b5430add8e
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/config/config.example.toml
@@ -0,0 +1,46 @@
# Example configuration file for multi-cluster consistency checker

# Global configuration
[global]

# Log level: debug, info, warn, error, fatal, panic
log-level = "info"

# Data directory configuration, contains report and checkpoint data
data-dir = "/tmp/multi-cluster-consistency-checker-data"
+ + # Tables configuration + [global.tables] + schema1 = ["table1", "table2"] + schema2 = ["table1", "table2"] + +# Cluster configurations +[clusters] + # First cluster configuration + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket-name/cluster1/" + s3-changefeed-id = "s3-changefeed-id-1" + # security-config = { ca-path = "ca.crt", cert-path = "cert.crt", key-path = "key.crt" } + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "active-active-changefeed-id-from-cluster1-to-cluster2" } + + # Second cluster configuration + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket-name/cluster2/" + s3-changefeed-id = "s3-changefeed-id-2" + # security-config = { ca-path = "ca.crt", cert-path = "cert.crt", key-path = "key.crt" } + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "active-active-changefeed-id-from-cluster2-to-cluster1" } + + # Third cluster configuration (optional) + # [clusters.cluster3] + # pd-addrs = ["127.0.0.1:2579"] + # cdc-addr = "127.0.0.1:8500" + # s3-sink-uri = "s3://bucket-name/cluster3/" + # s3-changefeed-id = "s3-changefeed-id-3" + # security-config = { ca-path = "ca.crt", cert-path = "cert.crt", key-path = "key.crt" } + # [clusters.cluster3.peer-cluster-changefeed-config] + # cluster1 = { changefeed-id = "active-active-changefeed-id-from-cluster3-to-cluster1" } + # cluster2 = { changefeed-id = "active-active-changefeed-id-from-cluster3-to-cluster2" } diff --git a/cmd/multi-cluster-consistency-checker/config/config.go b/cmd/multi-cluster-consistency-checker/config/config.go new file mode 100644 index 0000000000..7875c4a6af --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/config/config.go @@ -0,0 +1,153 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package config loads and validates the TOML configuration of the
// multi-cluster consistency checker.
package config

import (
	"fmt"
	"os"

	"github.com/BurntSushi/toml"
	"github.com/pingcap/ticdc/pkg/security"
)

// Config represents the configuration for multi-cluster consistency checker
type Config struct {
	// GlobalConfig contains global settings (reserved for future use)
	GlobalConfig GlobalConfig `toml:"global" json:"global"`

	// Clusters contains configurations for multiple clusters
	Clusters map[string]ClusterConfig `toml:"clusters" json:"clusters"`
}

// DefaultMaxReportFiles is used when max-report-files is unset or non-positive.
const DefaultMaxReportFiles = 1000

// GlobalConfig contains global configuration settings
type GlobalConfig struct {
	// LogLevel is the log verbosity (e.g. "info"); see config.example.toml.
	LogLevel string `toml:"log-level" json:"log-level"`
	// DataDir is the directory holding report and checkpoint data; required.
	DataDir string `toml:"data-dir" json:"data-dir"`
	// MaxReportFiles defaults to DefaultMaxReportFiles when <= 0.
	MaxReportFiles int `toml:"max-report-files" json:"max-report-files"`
	// Tables maps schema name to the table names to check; at least one
	// schema with at least one table is required.
	Tables map[string][]string `toml:"tables" json:"tables"`
}

// PeerClusterChangefeedConfig describes one peer-cluster replication changefeed.
type PeerClusterChangefeedConfig struct {
	// ChangefeedID is the changefeed ID for the changefeed
	ChangefeedID string `toml:"changefeed-id" json:"changefeed-id"`
}

// ClusterConfig represents configuration for a single cluster
type ClusterConfig struct {
	// PDAddrs is the addresses of the PD (Placement Driver) servers
	PDAddrs []string `toml:"pd-addrs" json:"pd-addrs"`

	// S3SinkURI is the S3 sink URI for this cluster
	S3SinkURI string `toml:"s3-sink-uri" json:"s3-sink-uri"`

	// S3ChangefeedID is the changefeed ID for the S3 changefeed
	S3ChangefeedID string `toml:"s3-changefeed-id" json:"s3-changefeed-id"`

	// SecurityConfig is the security configuration for the cluster
	SecurityConfig security.Credential `toml:"security-config" json:"security-config"`

	// PeerClusterChangefeedConfig is the configuration for the changefeed of the peer cluster
	// mapping from peer cluster ID to the changefeed configuration
	PeerClusterChangefeedConfig map[string]PeerClusterChangefeedConfig `toml:"peer-cluster-changefeed-config" json:"peer-cluster-changefeed-config"`
}

// LoadConfig loads the configuration from the TOML file at path, applies
// defaults (MaxReportFiles), and validates the result. It returns an error
// when the file is missing or malformed, or when required settings (data-dir,
// tables, at least two fully cross-referenced clusters) are absent.
func LoadConfig(path string) (*Config, error) {
	// Check if file exists
	if _, err := os.Stat(path); os.IsNotExist(err) {
		return nil, fmt.Errorf("config file does not exist: %s", path)
	}

	cfg := &Config{
		Clusters: make(map[string]ClusterConfig),
	}

	meta, err := toml.DecodeFile(path, cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to decode config file: %w", err)
	}

	// Apply defaults
	if cfg.GlobalConfig.MaxReportFiles <= 0 {
		cfg.GlobalConfig.MaxReportFiles = DefaultMaxReportFiles
	}

	// Validate DataDir
	if cfg.GlobalConfig.DataDir == "" {
		return nil, fmt.Errorf("global: data-dir is required")
	}

	// Validate Tables
	if len(cfg.GlobalConfig.Tables) == 0 {
		return nil, fmt.Errorf("global: at least one schema must be configured in tables")
	}
	for schema, tables := range cfg.GlobalConfig.Tables {
		if len(tables) == 0 {
			return nil, fmt.Errorf("global: tables[%s]: at least one table must be configured", schema)
		}
	}

	// Validate that at least two clusters are configured.
	// Single-cluster mode cannot provide cross-cluster consistency checks.
	if len(cfg.Clusters) < 2 {
		return nil, fmt.Errorf("at least two clusters must be configured")
	}

	// Validate cluster configurations: every cluster needs PD addresses, an
	// S3 sink, and a changefeed entry for every other configured cluster.
	for name, cluster := range cfg.Clusters {
		if len(cluster.PDAddrs) == 0 {
			return nil, fmt.Errorf("cluster '%s': pd-addrs is required", name)
		}
		if cluster.S3SinkURI == "" {
			return nil, fmt.Errorf("cluster '%s': s3-sink-uri is required", name)
		}
		if cluster.S3ChangefeedID == "" {
			return nil, fmt.Errorf("cluster '%s': s3-changefeed-id is required", name)
		}
		if len(cluster.PeerClusterChangefeedConfig) != len(cfg.Clusters)-1 {
			return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config is not entirely configured", name)
		}
		for peerClusterID, peerClusterChangefeedConfig := range cluster.PeerClusterChangefeedConfig {
			if peerClusterID == name {
				return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config references itself", name)
			}
			if _, ok := cfg.Clusters[peerClusterID]; !ok {
				return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config references unknown cluster '%s'", name, peerClusterID)
			}
			if peerClusterChangefeedConfig.ChangefeedID == "" {
				return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config[%s]: changefeed-id is required", name, peerClusterID)
			}
		}
	}

	// Check for unknown configuration keys
	if undecoded := meta.Undecoded(); len(undecoded) > 0 {
		// NOTE(review): this only suppresses the bare top-level "global" and
		// "clusters" keys; dotted keys nested under those sections (e.g.
		// "global.foo") are still reported, which is the intended warning.
		var unknownKeys []string
		for _, key := range undecoded {
			keyStr := key.String()
			// Skip only the exact top-level section names themselves.
			if keyStr != "global" && keyStr != "clusters" {
				unknownKeys = append(unknownKeys, keyStr)
			}
		}
		if len(unknownKeys) > 0 {
			fmt.Fprintf(os.Stderr, "Warning: unknown configuration keys found: %v\n", unknownKeys)
		}
	}

	return cfg, nil
}
diff --git a/cmd/multi-cluster-consistency-checker/config/config_test.go b/cmd/multi-cluster-consistency-checker/config/config_test.go
new file
mode 100644 index 0000000000..8a1e03c9b1 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/config/config_test.go @@ -0,0 +1,446 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadConfig(t *testing.T) { + t.Parallel() + + t.Run("valid config", func(t *testing.T) { + t.Parallel() + // Create a temporary config file + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1", "table2"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.NoError(t, err) + require.NotNil(t, cfg) + require.Equal(t, "info", cfg.GlobalConfig.LogLevel) + require.Equal(t, "/tmp/data", cfg.GlobalConfig.DataDir) + require.Len(t, cfg.Clusters, 2) + require.Contains(t, cfg.Clusters, "cluster1") + require.Contains(t, 
cfg.Clusters, "cluster2") + require.Equal(t, []string{"127.0.0.1:2379"}, cfg.Clusters["cluster1"].PDAddrs) + require.Equal(t, "s3://bucket/cluster1/", cfg.Clusters["cluster1"].S3SinkURI) + require.Equal(t, "s3-cf-1", cfg.Clusters["cluster1"].S3ChangefeedID) + require.Len(t, cfg.Clusters["cluster1"].PeerClusterChangefeedConfig, 1) + require.Equal(t, "cf-1-to-2", cfg.Clusters["cluster1"].PeerClusterChangefeedConfig["cluster2"].ChangefeedID) + // max-report-files not set, should default to DefaultMaxReportFiles + require.Equal(t, DefaultMaxReportFiles, cfg.GlobalConfig.MaxReportFiles) + }) + + t.Run("custom max-report-files", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" +max-report-files = 50 + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.NoError(t, err) + require.Equal(t, 50, cfg.GlobalConfig.MaxReportFiles) + }) + + t.Run("file not exists", func(t *testing.T) { + t.Parallel() + cfg, err := LoadConfig("/nonexistent/path/config.toml") + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "config file does not exist") + }) + + t.Run("invalid toml", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := `invalid toml content [` + err := 
os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "failed to decode config file") + }) + + t.Run("missing data-dir", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "data-dir is required") + }) + + t.Run("missing tables", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + 
require.Contains(t, err.Error(), "at least one schema must be configured in tables") + }) + + t.Run("empty table list in schema", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = [] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "at least one table must be configured") + }) + + t.Run("no clusters", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "at least two clusters must be configured") + }) + + t.Run("single cluster is invalid", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" +` + 
err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "at least two clusters must be configured") + }) + + t.Run("missing pd-addrs", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "pd-addrs is required") + }) + + t.Run("missing s3-sink-uri", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := 
LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "s3-sink-uri is required") + }) + + t.Run("missing s3-changefeed-id", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "s3-changefeed-id is required") + }) + + t.Run("incomplete replicated cluster changefeed config", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "peer-cluster-changefeed-config is not entirely 
configured") + }) + + t.Run("missing changefeed-id in peer cluster config", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = {} + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "changefeed-id is required") + }) +} diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer.go b/cmd/multi-cluster-consistency-checker/consumer/consumer.go new file mode 100644 index 0000000000..458db8c225 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/consumer/consumer.go @@ -0,0 +1,874 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package consumer + +import ( + "context" + "encoding/json" + "fmt" + "path" + "strings" + "sync" + + perrors "github.com/pingcap/errors" + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/pingcap/tidb/br/pkg/storage" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +type ( + fileIndexRange map[cloudstorage.FileIndexKey]indexRange + fileIndexKeyMap map[cloudstorage.FileIndexKey]uint64 +) + +type indexRange struct { + start uint64 + end uint64 +} + +const defaultGlobalReadConcurrencyLimit = 128 +const defaultGlobalWalkConcurrencyLimit = 64 +const defaultTableWorkerConcurrencyLimit = 256 + +func updateTableDMLIdxMap( + tableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap, + dmlkey cloudstorage.DmlPathKey, + fileIdx *cloudstorage.FileIndex, +) { + m, ok := tableDMLIdxMap[dmlkey] + if !ok { + tableDMLIdxMap[dmlkey] = fileIndexKeyMap{ + fileIdx.FileIndexKey: fileIdx.Idx, + } + } else if fileIdx.Idx > m[fileIdx.FileIndexKey] { + m[fileIdx.FileIndexKey] = fileIdx.Idx + } +} + +type schemaDefinition struct { + path string + columnFieldTypes map[string]*ptypes.FieldType +} + +type schemaKey struct { + schema string + table string +} + +var ErrWalkDirEnd = perrors.Normalize("walk dir end", perrors.RFCCodeText("CDC:ErrWalkDirEnd")) + +type CurrentTableVersion struct { + mu sync.RWMutex + currentTableVersionMap map[schemaKey]types.VersionKey +} + +func NewCurrentTableVersion() *CurrentTableVersion { + return &CurrentTableVersion{ + currentTableVersionMap: make(map[schemaKey]types.VersionKey), + } +} + +// GetCurrentTableVersion returns the current table version for a given schema and table +func (cvt *CurrentTableVersion) 
GetCurrentTableVersion(schema, table string) types.VersionKey { + cvt.mu.RLock() + defer cvt.mu.RUnlock() + return cvt.currentTableVersionMap[schemaKey{schema: schema, table: table}] +} + +// UpdateCurrentTableVersion updates the current table version for a given schema and table +func (cvt *CurrentTableVersion) UpdateCurrentTableVersion(schema, table string, version types.VersionKey) { + cvt.mu.Lock() + defer cvt.mu.Unlock() + cvt.currentTableVersionMap[schemaKey{schema: schema, table: table}] = version +} + +type SchemaDefinitions struct { + mu sync.RWMutex + schemaDefinitionMap map[cloudstorage.SchemaPathKey]schemaDefinition +} + +func NewSchemaDefinitions() *SchemaDefinitions { + return &SchemaDefinitions{ + schemaDefinitionMap: make(map[cloudstorage.SchemaPathKey]schemaDefinition), + } +} + +// GetColumnFieldTypes returns the pre-parsed column field types for a given schema and table version +func (sp *SchemaDefinitions) GetColumnFieldTypes(schema, table string, version uint64) (map[string]*ptypes.FieldType, error) { + schemaPathKey := cloudstorage.SchemaPathKey{ + Schema: schema, + Table: table, + TableVersion: version, + } + sp.mu.RLock() + schemaDefinition, ok := sp.schemaDefinitionMap[schemaPathKey] + sp.mu.RUnlock() + if !ok { + return nil, errors.Errorf("schema definition not found for schema: %s, table: %s, version: %d", schema, table, version) + } + return schemaDefinition.columnFieldTypes, nil +} + +// SetSchemaDefinition sets the schema definition for a given schema and table version. +// It pre-parses the column field types from the table definition for later use by the decoder. 
+func (sp *SchemaDefinitions) SetSchemaDefinition(schemaPathKey cloudstorage.SchemaPathKey, filePath string, tableDefinition *cloudstorage.TableDefinition) error { + columnFieldTypes := make(map[string]*ptypes.FieldType) + if tableDefinition != nil { + for i, col := range tableDefinition.Columns { + colInfo, err := col.ToTiColumnInfo(int64(i)) + if err != nil { + return errors.Annotatef(err, "failed to convert column %s to FieldType", col.Name) + } + columnFieldTypes[col.Name] = &colInfo.FieldType + } + } + sp.mu.Lock() + sp.schemaDefinitionMap[schemaPathKey] = schemaDefinition{ + path: filePath, + columnFieldTypes: columnFieldTypes, + } + sp.mu.Unlock() + return nil +} + +// RemoveSchemaDefinitionWithCondition removes the schema definition for a given condition +func (sp *SchemaDefinitions) RemoveSchemaDefinitionWithCondition(condition func(schemaPathKey cloudstorage.SchemaPathKey) bool) { + sp.mu.Lock() + for schemaPathkey := range sp.schemaDefinitionMap { + if condition(schemaPathkey) { + delete(sp.schemaDefinitionMap, schemaPathkey) + } + } + sp.mu.Unlock() +} + +type TableDMLIdx struct { + mu sync.Mutex + tableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap +} + +func NewTableDMLIdx() *TableDMLIdx { + return &TableDMLIdx{ + tableDMLIdxMap: make(map[cloudstorage.DmlPathKey]fileIndexKeyMap), + } +} + +func (t *TableDMLIdx) UpdateDMLIdxMapByStartPath(dmlkey cloudstorage.DmlPathKey, fileIdx *cloudstorage.FileIndex) { + t.mu.Lock() + defer t.mu.Unlock() + if originalFileIndexKeyMap, ok := t.tableDMLIdxMap[dmlkey]; !ok { + t.tableDMLIdxMap[dmlkey] = fileIndexKeyMap{ + fileIdx.FileIndexKey: fileIdx.Idx, + } + } else { + if fileIdx.Idx > originalFileIndexKeyMap[fileIdx.FileIndexKey] { + originalFileIndexKeyMap[fileIdx.FileIndexKey] = fileIdx.Idx + } + } +} + +func (t *TableDMLIdx) DiffNewTableDMLIdxMap( + newTableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap, +) map[cloudstorage.DmlPathKey]fileIndexRange { + resMap := 
make(map[cloudstorage.DmlPathKey]fileIndexRange) + t.mu.Lock() + defer t.mu.Unlock() + for newDMLPathKey, newFileIndexKeyMap := range newTableDMLIdxMap { + origFileIndexKeyMap, ok := t.tableDMLIdxMap[newDMLPathKey] + if !ok { + t.tableDMLIdxMap[newDMLPathKey] = newFileIndexKeyMap + resMap[newDMLPathKey] = make(fileIndexRange) + for indexKey, newEndVal := range newFileIndexKeyMap { + resMap[newDMLPathKey][indexKey] = indexRange{ + start: 1, + end: newEndVal, + } + } + continue + } + for indexKey, newEndVal := range newFileIndexKeyMap { + origEndVal := origFileIndexKeyMap[indexKey] + if newEndVal > origEndVal { + origFileIndexKeyMap[indexKey] = newEndVal + if _, ok := resMap[newDMLPathKey]; !ok { + resMap[newDMLPathKey] = make(fileIndexRange) + } + resMap[newDMLPathKey][indexKey] = indexRange{ + start: origEndVal + 1, + end: newEndVal, + } + } + } + } + return resMap +} + +type S3Consumer struct { + s3Storage storage.ExternalStorage + fileExtension string + dateSeparator string + fileIndexWidth int + tables map[string][]string + // readLimiter limits the total number of concurrent ReadFile calls. + readLimiter chan struct{} + // walkLimiter limits the total number of concurrent WalkDir calls. + walkLimiter chan struct{} + // tableWorkerConcurrencyLimit limits table-level goroutines in top-level flows. 
+ tableWorkerConcurrencyLimit int + + // skip the first round data download + skipDownloadData bool + + currentTableVersion *CurrentTableVersion + tableDMLIdx *TableDMLIdx + schemaDefinitions *SchemaDefinitions +} + +func NewS3Consumer( + s3Storage storage.ExternalStorage, + tables map[string][]string, +) *S3Consumer { + return &S3Consumer{ + s3Storage: s3Storage, + fileExtension: ".json", + dateSeparator: config.DateSeparatorDay.String(), + fileIndexWidth: config.DefaultFileIndexWidth, + tables: tables, + readLimiter: make( + chan struct{}, + defaultGlobalReadConcurrencyLimit, + ), + walkLimiter: make( + chan struct{}, + defaultGlobalWalkConcurrencyLimit, + ), + tableWorkerConcurrencyLimit: defaultTableWorkerConcurrencyLimit, + + skipDownloadData: true, + + currentTableVersion: NewCurrentTableVersion(), + tableDMLIdx: NewTableDMLIdx(), + schemaDefinitions: NewSchemaDefinitions(), + } +} + +func (c *S3Consumer) acquireReadSlot(ctx context.Context) error { + if c.readLimiter == nil { + return nil + } + select { + case c.readLimiter <- struct{}{}: + return nil + case <-ctx.Done(): + return errors.Trace(ctx.Err()) + } +} + +func (c *S3Consumer) releaseReadSlot() { + if c.readLimiter == nil { + return + } + <-c.readLimiter +} + +func (c *S3Consumer) acquireWalkSlot(ctx context.Context) error { + if c.walkLimiter == nil { + return nil + } + select { + case c.walkLimiter <- struct{}{}: + return nil + case <-ctx.Done(): + return errors.Trace(ctx.Err()) + } +} + +func (c *S3Consumer) releaseWalkSlot() { + if c.walkLimiter == nil { + return + } + <-c.walkLimiter +} + +func (c *S3Consumer) InitializeFromCheckpoint( + ctx context.Context, clusterID string, checkpoint *recorder.Checkpoint, +) (map[cloudstorage.DmlPathKey]types.IncrementalData, error) { + if checkpoint == nil { + return nil, nil + } + if checkpoint.CheckpointItems[2] == nil { + return nil, nil + } + c.skipDownloadData = false + scanRanges, err := checkpoint.ToScanRange(clusterID) + if err != nil { + return nil, 
errors.Trace(err) + } + var mu sync.Mutex + // Combine DML data and schema data into result + result := make(map[cloudstorage.DmlPathKey]types.IncrementalData) + eg, egCtx := errgroup.WithContext(ctx) + eg.SetLimit(c.tableWorkerConcurrencyLimit) + for schemaTableKey, scanRange := range scanRanges { + eg.Go(func() error { + scanVersions, err := c.downloadSchemaFilesWithScanRange( + egCtx, schemaTableKey.Schema, schemaTableKey.Table, scanRange.StartVersionKey, scanRange.EndVersionKey, scanRange.EndDataPath) + if err != nil { + return errors.Trace(err) + } + err = c.downloadDataFilesWithScanRange( + egCtx, schemaTableKey.Schema, schemaTableKey.Table, scanVersions, scanRange, + func( + dmlPathKey cloudstorage.DmlPathKey, + dmlSlices map[cloudstorage.FileIndexKey][][]byte, + columnFieldTypes map[string]*ptypes.FieldType, + ) { + mu.Lock() + result[dmlPathKey] = types.IncrementalData{ + DataContentSlices: dmlSlices, + ColumnFieldTypes: columnFieldTypes, + } + mu.Unlock() + }, + ) + if err != nil { + return errors.Trace(err) + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + return result, nil +} + +func (c *S3Consumer) downloadSchemaFilesWithScanRange( + ctx context.Context, + schema, table string, + startVersionKey string, + endVersionKey string, + endDataPath string, +) ([]types.VersionKey, error) { + metaSubDir := fmt.Sprintf("%s/%s/meta/", schema, table) + opt := &storage.WalkOption{ + SubDir: metaSubDir, + ObjPrefix: "schema_", + // TODO: StartAfter: startVersionKey, + } + + var startSchemaKey, endSchemaKey cloudstorage.SchemaPathKey + _, err := startSchemaKey.ParseSchemaFilePath(startVersionKey) + if err != nil { + return nil, errors.Trace(err) + } + _, err = endSchemaKey.ParseSchemaFilePath(endVersionKey) + if err != nil { + return nil, errors.Trace(err) + } + + var scanVersions []types.VersionKey + newVersionPaths := make(map[cloudstorage.SchemaPathKey]string) + scanVersions = append(scanVersions, types.VersionKey{ 
+ Version: startSchemaKey.TableVersion, + VersionPath: startVersionKey, + }) + newVersionPaths[startSchemaKey] = startVersionKey + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) + } + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + if endVersionKey < filePath { + return ErrWalkDirEnd + } + if !cloudstorage.IsSchemaFile(filePath) { + return nil + } + var schemaKey cloudstorage.SchemaPathKey + _, err := schemaKey.ParseSchemaFilePath(filePath) + if err != nil { + log.Error("failed to parse schema file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + if schemaKey.TableVersion > startSchemaKey.TableVersion { + if _, exists := newVersionPaths[schemaKey]; !exists { + scanVersions = append(scanVersions, types.VersionKey{ + Version: schemaKey.TableVersion, + VersionPath: filePath, + }) + } + newVersionPaths[schemaKey] = filePath + } + return nil + }) + }(); err != nil && !errors.Is(err, ErrWalkDirEnd) { + return nil, errors.Trace(err) + } + + if err := c.downloadSchemaFiles(ctx, newVersionPaths); err != nil { + return nil, errors.Trace(err) + } + + c.currentTableVersion.UpdateCurrentTableVersion(schema, table, types.VersionKey{ + Version: endSchemaKey.TableVersion, + VersionPath: endVersionKey, + DataPath: endDataPath, + }) + + return scanVersions, nil +} + +// downloadDataFilesWithScanRange downloads data files for a given scan range. +// consumeFunc is called from multiple goroutines concurrently and must be goroutine-safe. 
+func (c *S3Consumer) downloadDataFilesWithScanRange( + ctx context.Context, + schema, table string, + scanVersions []types.VersionKey, + scanRange *recorder.ScanRange, + consumeFunc func( + dmlPathKey cloudstorage.DmlPathKey, + dmlSlices map[cloudstorage.FileIndexKey][][]byte, + columnFieldTypes map[string]*ptypes.FieldType, + ), +) error { + eg, egCtx := errgroup.WithContext(ctx) + for _, version := range scanVersions { + eg.Go(func() error { + newFiles, err := c.getNewFilesForSchemaPathKeyWithEndPath(egCtx, schema, table, version.Version, scanRange.StartDataPath, scanRange.EndDataPath) + if err != nil { + return errors.Trace(err) + } + dmlData, err := c.downloadDMLFiles(egCtx, newFiles) + if err != nil { + return errors.Trace(err) + } + columnFieldTypes, err := c.schemaDefinitions.GetColumnFieldTypes(schema, table, version.Version) + if err != nil { + return errors.Trace(err) + } + for dmlPathKey, dmlSlices := range dmlData { + consumeFunc(dmlPathKey, dmlSlices, columnFieldTypes) + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return errors.Trace(err) + } + return nil +} + +func (c *S3Consumer) getNewFilesForSchemaPathKeyWithEndPath( + ctx context.Context, + schema, table string, + version uint64, + startDataPath string, + endDataPath string, +) (map[cloudstorage.DmlPathKey]fileIndexRange, error) { + schemaPrefix := path.Join(schema, table, fmt.Sprintf("%d", version)) + opt := &storage.WalkOption{ + SubDir: schemaPrefix, + // TODO: StartAfter: startDataPath, + } + newTableDMLIdxMap := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) + } + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + if endDataPath < filePath { + return ErrWalkDirEnd + } + // Try to parse DML file path if it matches the expected extension + if strings.HasSuffix(filePath, c.fileExtension) { + var dmlkey 
cloudstorage.DmlPathKey + fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath) + if err != nil { + log.Error("failed to parse dml file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + if filePath == startDataPath { + c.tableDMLIdx.UpdateDMLIdxMapByStartPath(dmlkey, fileIdx) + } else { + updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx) + } + } + return nil + }) + }(); err != nil && !errors.Is(err, ErrWalkDirEnd) { + return nil, errors.Trace(err) + } + return c.tableDMLIdx.DiffNewTableDMLIdxMap(newTableDMLIdxMap), nil +} + +// downloadSchemaFiles downloads schema files concurrently for given schema path keys +func (c *S3Consumer) downloadSchemaFiles( + ctx context.Context, + newVersionPaths map[cloudstorage.SchemaPathKey]string, +) error { + eg, egCtx := errgroup.WithContext(ctx) + + log.Debug("starting concurrent schema file download", zap.Int("totalSchemas", len(newVersionPaths))) + for schemaPathKey, filePath := range newVersionPaths { + eg.Go(func() error { + if err := c.readAndParseSchemaFile(egCtx, schemaPathKey, filePath); err != nil { + return errors.Trace(err) + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return errors.Trace(err) + } + return nil +} + +// readAndParseSchemaFile reads a schema file with global read concurrency control. 
+func (c *S3Consumer) readAndParseSchemaFile( + ctx context.Context, + schemaPathKey cloudstorage.SchemaPathKey, + filePath string, +) error { + if err := c.acquireReadSlot(ctx); err != nil { + return errors.Trace(err) + } + defer c.releaseReadSlot() + + content, err := c.s3Storage.ReadFile(ctx, filePath) + if err != nil { + return errors.Annotatef(err, "failed to read schema file: %s", filePath) + } + + tableDefinition := &cloudstorage.TableDefinition{} + if err := json.Unmarshal(content, tableDefinition); err != nil { + return errors.Annotatef(err, "failed to unmarshal schema file: %s", filePath) + } + if err := c.schemaDefinitions.SetSchemaDefinition(schemaPathKey, filePath, tableDefinition); err != nil { + return errors.Trace(err) + } + return nil +} + +func (c *S3Consumer) discoverAndDownloadNewTableVersions( + ctx context.Context, + schema, table string, +) ([]types.VersionKey, error) { + currentVersion := c.currentTableVersion.GetCurrentTableVersion(schema, table) + metaSubDir := fmt.Sprintf("%s/%s/meta/", schema, table) + opt := &storage.WalkOption{ + SubDir: metaSubDir, + ObjPrefix: "schema_", + // TODO: StartAfter: currentVersion.versionPath, + } + + var scanVersions []types.VersionKey + newVersionPaths := make(map[cloudstorage.SchemaPathKey]string) + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) + } + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + if !cloudstorage.IsSchemaFile(filePath) { + return nil + } + var schemaKey cloudstorage.SchemaPathKey + _, err := schemaKey.ParseSchemaFilePath(filePath) + if err != nil { + log.Error("failed to parse schema file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + version := schemaKey.TableVersion + if version > currentVersion.Version { + if _, exists := newVersionPaths[schemaKey]; !exists { + scanVersions = append(scanVersions, types.VersionKey{ + Version: 
version, + VersionPath: filePath, + }) + } + newVersionPaths[schemaKey] = filePath + } + return nil + }) + }(); err != nil { + return nil, errors.Trace(err) + } + + // download new version schema files concurrently + if err := c.downloadSchemaFiles(ctx, newVersionPaths); err != nil { + return nil, errors.Trace(err) + } + + if currentVersion.Version > 0 { + scanVersions = append(scanVersions, currentVersion) + } + return scanVersions, nil +} + +func (c *S3Consumer) getNewFilesForSchemaPathKey( + ctx context.Context, + schema, table string, + version *types.VersionKey, +) (map[cloudstorage.DmlPathKey]fileIndexRange, error) { + schemaPrefix := path.Join(schema, table, fmt.Sprintf("%d", version.Version)) + opt := &storage.WalkOption{ + SubDir: schemaPrefix, + // TODO: StartAfter: version.dataPath, + } + + newTableDMLIdxMap := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + maxFilePath := "" + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) + } + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + // Try to parse DML file path if it matches the expected extension + if strings.HasSuffix(filePath, c.fileExtension) { + var dmlkey cloudstorage.DmlPathKey + fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath) + if err != nil { + log.Error("failed to parse dml file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx) + maxFilePath = filePath + } + return nil + }) + }(); err != nil { + return nil, errors.Trace(err) + } + + version.DataPath = maxFilePath + return c.tableDMLIdx.DiffNewTableDMLIdxMap(newTableDMLIdxMap), nil +} + +func (c *S3Consumer) downloadDMLFiles( + ctx context.Context, + newFiles map[cloudstorage.DmlPathKey]fileIndexRange, +) (map[cloudstorage.DmlPathKey]map[cloudstorage.FileIndexKey][][]byte, error) { + if len(newFiles) == 0 || 
c.skipDownloadData { + return nil, nil + } + + result := make(map[cloudstorage.DmlPathKey]map[cloudstorage.FileIndexKey][][]byte) + type downloadTask struct { + dmlPathKey cloudstorage.DmlPathKey + fileIndex cloudstorage.FileIndex + } + + var tasks []downloadTask + for dmlPathKey, fileRange := range newFiles { + for indexKey, indexRange := range fileRange { + log.Debug("prepare to download new dml file in index range", + zap.String("schema", dmlPathKey.Schema), + zap.String("table", dmlPathKey.Table), + zap.Uint64("version", dmlPathKey.TableVersion), + zap.Int64("partitionNum", dmlPathKey.PartitionNum), + zap.String("date", dmlPathKey.Date), + zap.String("dispatcherID", indexKey.DispatcherID), + zap.Bool("enableTableAcrossNodes", indexKey.EnableTableAcrossNodes), + zap.Uint64("startIndex", indexRange.start), + zap.Uint64("endIndex", indexRange.end)) + for i := indexRange.start; i <= indexRange.end; i++ { + tasks = append(tasks, downloadTask{ + dmlPathKey: dmlPathKey, + fileIndex: cloudstorage.FileIndex{ + FileIndexKey: indexKey, + Idx: i, + }, + }) + } + } + } + + log.Debug("starting concurrent DML file download", zap.Int("totalFiles", len(tasks))) + + // Concurrently download files + type fileContent struct { + dmlPathKey cloudstorage.DmlPathKey + indexKey cloudstorage.FileIndexKey + idx uint64 + content []byte + } + + fileContents := make(chan fileContent, len(tasks)) + eg, egCtx := errgroup.WithContext(ctx) + for _, task := range tasks { + eg.Go(func() error { + if err := c.acquireReadSlot(egCtx); err != nil { + return errors.Trace(err) + } + defer c.releaseReadSlot() + + filePath := task.dmlPathKey.GenerateDMLFilePath( + &task.fileIndex, + c.fileExtension, + c.fileIndexWidth, + ) + + content, err := c.s3Storage.ReadFile(egCtx, filePath) + if err != nil { + return errors.Annotatef(err, "failed to read file: %s", filePath) + } + + // Channel writes are thread-safe, no mutex needed + fileContents <- fileContent{ + dmlPathKey: task.dmlPathKey, + indexKey: 
task.fileIndex.FileIndexKey,
+			idx:     task.fileIndex.Idx,
+			content: content,
+		}
+		return nil
+	})
+	}
+	if err := eg.Wait(); err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Close the channel to signal no more writes
+	close(fileContents)
+
+	// Process the downloaded file contents
+	for fc := range fileContents {
+		if result[fc.dmlPathKey] == nil {
+			result[fc.dmlPathKey] = make(map[cloudstorage.FileIndexKey][][]byte)
+		}
+		result[fc.dmlPathKey][fc.indexKey] = append(
+			result[fc.dmlPathKey][fc.indexKey],
+			fc.content,
+		)
+	}
+
+	return result, nil
+}
+
+// downloadNewFilesWithVersions downloads new files for given schema versions.
+// consumeFunc is called from multiple goroutines concurrently and must be goroutine-safe.
+func (c *S3Consumer) downloadNewFilesWithVersions(
+	ctx context.Context,
+	schema, table string,
+	scanVersions []types.VersionKey,
+	consumeFunc func(
+		dmlPathKey cloudstorage.DmlPathKey,
+		dmlSlices map[cloudstorage.FileIndexKey][][]byte,
+		columnFieldTypes map[string]*ptypes.FieldType,
+	),
+) (*types.VersionKey, error) {
+	// Track the highest schema version seen while kicking off one download
+	// goroutine per version; all downloads share a single errgroup.
+	var newest *types.VersionKey
+	grp, grpCtx := errgroup.WithContext(ctx)
+	for _, ver := range scanVersions {
+		verRef := &ver
+		if newest == nil || newest.Version < ver.Version {
+			newest = verRef
+		}
+		grp.Go(func() error {
+			pending, err := c.getNewFilesForSchemaPathKey(grpCtx, schema, table, verRef)
+			if err != nil {
+				return errors.Trace(err)
+			}
+			downloaded, err := c.downloadDMLFiles(grpCtx, pending)
+			if err != nil {
+				return errors.Trace(err)
+			}
+			fieldTypes, err := c.schemaDefinitions.GetColumnFieldTypes(schema, table, verRef.Version)
+			if err != nil {
+				return errors.Trace(err)
+			}
+			for pathKey, slicesByIdx := range downloaded {
+				consumeFunc(pathKey, slicesByIdx, fieldTypes)
+			}
+			return nil
+		})
+	}
+	if err := grp.Wait(); err != nil {
+		return nil, errors.Trace(err)
+	}
+	if newest != nil {
+		// Every version downloaded successfully; record the newest one as the
+		// table's current version for the next discovery round.
+		c.currentTableVersion.UpdateCurrentTableVersion(schema, table, *newest)
+	}
+	return newest, nil
+}
+
+// ConsumeNewFiles scans every configured table for newly written schema and
+// DML files, downloads them concurrently (bounded by the table worker limit),
+// and returns the incremental data plus the newest schema version per table.
+func (c *S3Consumer) ConsumeNewFiles(
+	ctx context.Context,
+) (map[cloudstorage.DmlPathKey]types.IncrementalData, map[types.SchemaTableKey]types.VersionKey, error) {
+	// Combine DML data and schema data into result; each map has its own lock
+	// because the per-table goroutines write to them independently.
+	result := make(map[cloudstorage.DmlPathKey]types.IncrementalData)
+	maxVersionMap := make(map[types.SchemaTableKey]types.VersionKey)
+	var (
+		resultMu  sync.Mutex // guards result
+		versionMu sync.Mutex // guards maxVersionMap
+	)
+	grp, grpCtx := errgroup.WithContext(ctx)
+	grp.SetLimit(c.tableWorkerConcurrencyLimit)
+	for schema, tables := range c.tables {
+		for _, table := range tables {
+			grp.Go(func() error {
+				scanVersions, err := c.discoverAndDownloadNewTableVersions(grpCtx, schema, table)
+				if err != nil {
+					return errors.Trace(err)
+				}
+				maxVersion, err := c.downloadNewFilesWithVersions(
+					grpCtx, schema, table, scanVersions,
+					func(
+						dmlPathKey cloudstorage.DmlPathKey,
+						dmlSlices map[cloudstorage.FileIndexKey][][]byte,
+						columnFieldTypes map[string]*ptypes.FieldType,
+					) {
+						resultMu.Lock()
+						defer resultMu.Unlock()
+						result[dmlPathKey] = types.IncrementalData{
+							DataContentSlices: dmlSlices,
+							ColumnFieldTypes:  columnFieldTypes,
+						}
+					},
+				)
+				if err != nil {
+					return errors.Trace(err)
+				}
+				if maxVersion != nil {
+					versionMu.Lock()
+					maxVersionMap[types.SchemaTableKey{Schema: schema, Table: table}] = *maxVersion
+					versionMu.Unlock()
+				}
+				return nil
+			})
+		}
+	}
+
+	if err := grp.Wait(); err != nil {
+		return nil, nil, errors.Trace(err)
+	}
+	c.skipDownloadData = false
+	return result, maxVersionMap, nil
+}
diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go
new file mode 100644
index 0000000000..28328c7597
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go
@@ -0,0 +1,804 @@
+// Copyright 2026 PingCAP, Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumer + +import ( + "bytes" + "context" + "path" + "slices" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/parser/mysql" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + "github.com/stretchr/testify/require" +) + +func TestUpdateTableDMLIdxMap(t *testing.T) { + t.Parallel() + + t.Run("insert new entry", func(t *testing.T) { + t.Parallel() + m := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1}, + Date: "2026-01-01", + } + fileIdx := &cloudstorage.FileIndex{ + FileIndexKey: cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}, + Idx: 5, + } + + updateTableDMLIdxMap(m, dmlKey, fileIdx) + require.Len(t, m, 1) + require.Equal(t, uint64(5), m[dmlKey][fileIdx.FileIndexKey]) + }) + + t.Run("update with higher index", func(t *testing.T) { + t.Parallel() + m := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1}, + Date: "2026-01-01", + } + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} 
+ fileIdx1 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 3} + fileIdx2 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 7} + + updateTableDMLIdxMap(m, dmlKey, fileIdx1) + updateTableDMLIdxMap(m, dmlKey, fileIdx2) + require.Equal(t, uint64(7), m[dmlKey][indexKey]) + }) + + t.Run("skip lower index", func(t *testing.T) { + t.Parallel() + m := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1}, + Date: "2026-01-01", + } + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} + fileIdx1 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 10} + fileIdx2 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 5} + + updateTableDMLIdxMap(m, dmlKey, fileIdx1) + updateTableDMLIdxMap(m, dmlKey, fileIdx2) + require.Equal(t, uint64(10), m[dmlKey][indexKey]) + }) +} + +func TestCurrentTableVersion(t *testing.T) { + t.Parallel() + + t.Run("get returns zero value for missing key", func(t *testing.T) { + t.Parallel() + cvt := NewCurrentTableVersion() + v := cvt.GetCurrentTableVersion("db", "tbl") + require.Equal(t, types.VersionKey{}, v) + }) + + t.Run("update and get", func(t *testing.T) { + t.Parallel() + cvt := NewCurrentTableVersion() + vk := types.VersionKey{Version: 100, VersionPath: "db/tbl/meta/schema_100_0000000000.json"} + cvt.UpdateCurrentTableVersion("db", "tbl", vk) + got := cvt.GetCurrentTableVersion("db", "tbl") + require.Equal(t, vk, got) + }) + + t.Run("update overwrites previous value", func(t *testing.T) { + t.Parallel() + cvt := NewCurrentTableVersion() + vk1 := types.VersionKey{Version: 1} + vk2 := types.VersionKey{Version: 2} + cvt.UpdateCurrentTableVersion("db", "tbl", vk1) + cvt.UpdateCurrentTableVersion("db", "tbl", vk2) + got := cvt.GetCurrentTableVersion("db", "tbl") + require.Equal(t, vk2, got) + }) + + t.Run("different tables are independent", func(t *testing.T) { + 
t.Parallel() + cvt := NewCurrentTableVersion() + vk1 := types.VersionKey{Version: 10} + vk2 := types.VersionKey{Version: 20} + cvt.UpdateCurrentTableVersion("db", "tbl1", vk1) + cvt.UpdateCurrentTableVersion("db", "tbl2", vk2) + require.Equal(t, vk1, cvt.GetCurrentTableVersion("db", "tbl1")) + require.Equal(t, vk2, cvt.GetCurrentTableVersion("db", "tbl2")) + }) +} + +func TestSchemaDefinitions(t *testing.T) { + t.Parallel() + + t.Run("get returns error for missing key", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + _, err := sp.GetColumnFieldTypes("db", "tbl", 1) + require.Error(t, err) + require.Contains(t, err.Error(), "schema definition not found") + }) + + t.Run("set and get empty table definition", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1} + td := &cloudstorage.TableDefinition{} + err := sp.SetSchemaDefinition(key, "/path/to/schema.json", td) + require.NoError(t, err) + + got, err := sp.GetColumnFieldTypes("db", "tbl", 1) + require.NoError(t, err) + require.Equal(t, map[string]*ptypes.FieldType{}, got) + }) + + t.Run("set and get with columns parses field types correctly", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1} + td := &cloudstorage.TableDefinition{ + Table: "tbl", + Schema: "db", + Columns: []cloudstorage.TableCol{ + {Name: "id", Tp: "INT", IsPK: "true", Precision: "11"}, + {Name: "name", Tp: "VARCHAR", Precision: "255"}, + {Name: "score", Tp: "DECIMAL", Precision: "10", Scale: "2"}, + {Name: "duration", Tp: "TIME", Scale: "3"}, + {Name: "created_at", Tp: "TIMESTAMP", Scale: "6"}, + {Name: "big_id", Tp: "BIGINT UNSIGNED", Precision: "20"}, + }, + TotalColumns: 6, + } + err := sp.SetSchemaDefinition(key, "/path/to/schema.json", td) + require.NoError(t, err) + + got, err := sp.GetColumnFieldTypes("db", "tbl", 1) + 
require.NoError(t, err) + require.Len(t, got, 6) + + // INT PK + require.Equal(t, mysql.TypeLong, got["id"].GetType()) + require.True(t, mysql.HasPriKeyFlag(got["id"].GetFlag())) + require.Equal(t, 11, got["id"].GetFlen()) + + // VARCHAR(255) + require.Equal(t, mysql.TypeVarchar, got["name"].GetType()) + require.Equal(t, 255, got["name"].GetFlen()) + + // DECIMAL(10,2) + require.Equal(t, mysql.TypeNewDecimal, got["score"].GetType()) + require.Equal(t, 10, got["score"].GetFlen()) + require.Equal(t, 2, got["score"].GetDecimal()) + + // TIME(3) — decimal stores FSP + require.Equal(t, mysql.TypeDuration, got["duration"].GetType()) + require.Equal(t, 3, got["duration"].GetDecimal()) + + // TIMESTAMP(6) — decimal stores FSP + require.Equal(t, mysql.TypeTimestamp, got["created_at"].GetType()) + require.Equal(t, 6, got["created_at"].GetDecimal()) + + // BIGINT UNSIGNED + require.Equal(t, mysql.TypeLonglong, got["big_id"].GetType()) + require.True(t, mysql.HasUnsignedFlag(got["big_id"].GetFlag())) + require.Equal(t, 20, got["big_id"].GetFlen()) + }) + + t.Run("set returns error for invalid column definition", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1} + td := &cloudstorage.TableDefinition{ + Table: "tbl", + Schema: "db", + Columns: []cloudstorage.TableCol{ + {Name: "id", Tp: "INT", Precision: "not_a_number"}, + }, + TotalColumns: 1, + } + err := sp.SetSchemaDefinition(key, "/path/to/schema.json", td) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to convert column id to FieldType") + + // Verify the definition was NOT stored + _, err = sp.GetColumnFieldTypes("db", "tbl", 1) + require.Error(t, err) + require.Contains(t, err.Error(), "schema definition not found") + }) + + t.Run("remove with condition", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key1 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl1", TableVersion: 1} + 
key2 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl2", TableVersion: 2} + require.NoError(t, sp.SetSchemaDefinition(key1, "/path1", nil)) + require.NoError(t, sp.SetSchemaDefinition(key2, "/path2", nil)) + + // Remove only entries for tbl1 + sp.RemoveSchemaDefinitionWithCondition(func(k cloudstorage.SchemaPathKey) bool { + return k.Table == "tbl1" + }) + + _, err := sp.GetColumnFieldTypes("db", "tbl1", 1) + require.Error(t, err) + + _, err = sp.GetColumnFieldTypes("db", "tbl2", 2) + require.NoError(t, err) + }) + + t.Run("remove with condition matching all", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key1 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl1", TableVersion: 1} + key2 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl2", TableVersion: 2} + require.NoError(t, sp.SetSchemaDefinition(key1, "/path1", nil)) + require.NoError(t, sp.SetSchemaDefinition(key2, "/path2", nil)) + + sp.RemoveSchemaDefinitionWithCondition(func(k cloudstorage.SchemaPathKey) bool { + return true + }) + + _, err := sp.GetColumnFieldTypes("db", "tbl1", 1) + require.Error(t, err) + _, err = sp.GetColumnFieldTypes("db", "tbl2", 2) + require.Error(t, err) + }) +} + +func TestTableDMLIdx_DiffNewTableDMLIdxMap(t *testing.T) { + t.Parallel() + + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1}, + Date: "2026-01-01", + } + + t.Run("new entry starts from 1", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + newMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + + result := idx.DiffNewTableDMLIdxMap(newMap) + require.Len(t, result, 1) + require.Equal(t, indexRange{start: 1, end: 5}, result[dmlKey][indexKey]) + }) + + t.Run("existing entry increments from previous end + 1", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + + // 
First call: set initial state + firstMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 3}, + } + idx.DiffNewTableDMLIdxMap(firstMap) + + // Second call: new end is 7, should get range [4, 7] + secondMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 7}, + } + result := idx.DiffNewTableDMLIdxMap(secondMap) + require.Len(t, result, 1) + require.Equal(t, indexRange{start: 4, end: 7}, result[dmlKey][indexKey]) + }) + + t.Run("same end value returns no diff", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + + firstMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + idx.DiffNewTableDMLIdxMap(firstMap) + + secondMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + result := idx.DiffNewTableDMLIdxMap(secondMap) + require.Empty(t, result) + }) + + t.Run("lower end value returns no diff", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + + firstMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 10}, + } + idx.DiffNewTableDMLIdxMap(firstMap) + + secondMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + result := idx.DiffNewTableDMLIdxMap(secondMap) + require.Empty(t, result) + }) + + t.Run("empty new map returns empty result", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + result := idx.DiffNewTableDMLIdxMap(map[cloudstorage.DmlPathKey]fileIndexKeyMap{}) + require.Empty(t, result) + }) + + t.Run("multiple keys", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + dmlKey2 := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl2", TableVersion: 1}, + Date: "2026-01-02", + } + + newMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 3}, + dmlKey2: {indexKey: 5}, + } + result := idx.DiffNewTableDMLIdxMap(newMap) + require.Len(t, result, 2) + require.Equal(t, indexRange{start: 1, end: 3}, 
result[dmlKey][indexKey]) + require.Equal(t, indexRange{start: 1, end: 5}, result[dmlKey2][indexKey]) + }) + + t.Run("multiple index keys for same dml path", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + indexKey2 := cloudstorage.FileIndexKey{DispatcherID: "dispatcher1", EnableTableAcrossNodes: true} + + newMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 3, indexKey2: 5}, + } + result := idx.DiffNewTableDMLIdxMap(newMap) + require.Len(t, result, 1) + require.Equal(t, indexRange{start: 1, end: 3}, result[dmlKey][indexKey]) + require.Equal(t, indexRange{start: 1, end: 5}, result[dmlKey][indexKey2]) + }) +} + +type mockFile struct { + name string + content []byte +} + +type mockS3Storage struct { + storage.ExternalStorage + + fileOffset map[string]int + sortedFiles []mockFile +} + +type trackingMockS3Storage struct { + *mockS3Storage + currentConcurrent int64 + maxConcurrent int64 + delay time.Duration +} + +func NewMockS3Storage(sortedFiles []mockFile) *mockS3Storage { + s3Storage := &mockS3Storage{} + s3Storage.UpdateFiles(sortedFiles) + return s3Storage +} + +func NewTrackingMockS3Storage(sortedFiles []mockFile, delay time.Duration) *trackingMockS3Storage { + return &trackingMockS3Storage{ + mockS3Storage: NewMockS3Storage(sortedFiles), + delay: delay, + } +} + +func (m *mockS3Storage) ReadFile(ctx context.Context, name string) ([]byte, error) { + return m.sortedFiles[m.fileOffset[name]].content, nil +} + +func (m *trackingMockS3Storage) ReadFile(ctx context.Context, name string) ([]byte, error) { + current := atomic.AddInt64(&m.currentConcurrent, 1) + for { + max := atomic.LoadInt64(&m.maxConcurrent) + if current <= max || atomic.CompareAndSwapInt64(&m.maxConcurrent, max, current) { + break + } + } + defer atomic.AddInt64(&m.currentConcurrent, -1) + + timer := time.NewTimer(m.delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + return m.mockS3Storage.ReadFile(ctx, 
name) +} + +func (m *mockS3Storage) WalkDir(ctx context.Context, opt *storage.WalkOption, fn func(path string, size int64) error) error { + filenamePrefix := path.Join(opt.SubDir, opt.ObjPrefix) + for _, file := range m.sortedFiles { + if strings.HasPrefix(file.name, filenamePrefix) { + if err := fn(file.name, 0); err != nil { + return err + } + } + } + return nil +} + +func (m *mockS3Storage) UpdateFiles(sortedFiles []mockFile) { + fileOffset := make(map[string]int) + for i, file := range sortedFiles { + fileOffset[file.name] = i + } + m.fileOffset = fileOffset + m.sortedFiles = sortedFiles +} + +func (m *trackingMockS3Storage) MaxConcurrent() int64 { + return atomic.LoadInt64(&m.maxConcurrent) +} + +func TestDownloadDMLFilesGlobalConcurrencyLimit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} + dmlPathKey1 := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + Date: "2026-01-01", + } + dmlPathKey2 := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + Date: "2026-01-02", + } + + files := []mockFile{ + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: []byte("f1")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000002.json", content: []byte("f2")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000003.json", content: []byte("f3")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000001.json", content: []byte("f4")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000002.json", content: []byte("f5")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000003.json", content: []byte("f6")}, + } + s3Storage := NewTrackingMockS3Storage(files, 40*time.Millisecond) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + s3Consumer.skipDownloadData = false + s3Consumer.readLimiter = make(chan 
struct{}, 2) + + newFiles1 := map[cloudstorage.DmlPathKey]fileIndexRange{ + dmlPathKey1: {indexKey: {start: 1, end: 3}}, + } + newFiles2 := map[cloudstorage.DmlPathKey]fileIndexRange{ + dmlPathKey2: {indexKey: {start: 1, end: 3}}, + } + + var wg sync.WaitGroup + errCh := make(chan error, 2) + wg.Add(2) + go func() { + defer wg.Done() + _, err := s3Consumer.downloadDMLFiles(ctx, newFiles1) + errCh <- err + }() + go func() { + defer wg.Done() + _, err := s3Consumer.downloadDMLFiles(ctx, newFiles2) + errCh <- err + }() + wg.Wait() + close(errCh) + for err := range errCh { + require.NoError(t, err) + } + + require.Greater(t, s3Storage.MaxConcurrent(), int64(1)) + require.LessOrEqual(t, s3Storage.MaxConcurrent(), int64(2)) +} + +func TestS3Consumer(t *testing.T) { + t.Parallel() + ctx := context.Background() + round1Files := []mockFile{ + {name: "test/t1/meta/schema_1_0000000001.json", content: []byte("{}")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: []byte("1_2026-01-01_1.json")}, + } + round1TimeWindowData := types.TimeWindowData{ + TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}, + Data: map[cloudstorage.DmlPathKey]types.IncrementalData{}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "test", Table: "t1"}: { + Version: 1, + VersionPath: "test/t1/meta/schema_1_0000000001.json", + DataPath: "test/t1/1/2026-01-01/CDC00000000000000000001.json", + }, + }, + } + expectedMaxVersionMap1 := func(maxVersionMap map[types.SchemaTableKey]types.VersionKey) { + require.Len(t, maxVersionMap, 1) + require.Equal(t, types.VersionKey{ + Version: 1, VersionPath: "test/t1/meta/schema_1_0000000001.json", DataPath: "test/t1/1/2026-01-01/CDC00000000000000000001.json", + }, maxVersionMap[types.SchemaTableKey{Schema: "test", Table: "t1"}]) + } + round2Files := []mockFile{ + {name: "test/t1/meta/schema_1_0000000001.json", content: []byte("{}")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: 
[]byte("1_2026-01-01_1.json")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000002.json", content: []byte("1_2026-01-01_2.json")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000001.json", content: []byte("1_2026-01-02_1.json")}, + } + round2TimeWindowData := types.TimeWindowData{ + TimeWindow: types.TimeWindow{LeftBoundary: 10, RightBoundary: 20}, + Data: map[cloudstorage.DmlPathKey]types.IncrementalData{}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "test", Table: "t1"}: { + Version: 1, + VersionPath: "test/t1/meta/schema_1_0000000001.json", + DataPath: "test/t1/1/2026-01-02/CDC00000000000000000001.json", + }, + }, + } + expectedNewData2 := func(newData map[cloudstorage.DmlPathKey]types.IncrementalData) { + require.Len(t, newData, 2) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("1_2026-01-01_2.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-01", + }]) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("1_2026-01-02_1.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + } + expectedMaxVersionMap2 := func(maxVersionMap map[types.SchemaTableKey]types.VersionKey) { + require.Len(t, maxVersionMap, 1) + require.Equal(t, types.VersionKey{ + Version: 1, VersionPath: "test/t1/meta/schema_1_0000000001.json", DataPath: "test/t1/1/2026-01-02/CDC00000000000000000001.json", + }, maxVersionMap[types.SchemaTableKey{Schema: "test", Table: "t1"}]) + } + 
round3Files := []mockFile{ + {name: "test/t1/meta/schema_1_0000000001.json", content: []byte("{}")}, + {name: "test/t1/meta/schema_2_0000000001.json", content: []byte("{}")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: []byte("1_2026-01-01_1.json")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000002.json", content: []byte("1_2026-01-01_2.json")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000001.json", content: []byte("1_2026-01-02_1.json")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000002.json", content: []byte("1_2026-01-02_2.json")}, + {name: "test/t1/2/2026-01-02/CDC00000000000000000001.json", content: []byte("2_2026-01-02_1.json")}, + {name: "test/t1/2/2026-01-03/CDC00000000000000000001.json", content: []byte("2_2026-01-03_1.json")}, + {name: "test/t1/2/2026-01-03/CDC00000000000000000002.json", content: []byte("2_2026-01-03_2.json")}, + } + round3TimeWindowData := types.TimeWindowData{ + TimeWindow: types.TimeWindow{LeftBoundary: 20, RightBoundary: 30}, + Data: map[cloudstorage.DmlPathKey]types.IncrementalData{}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "test", Table: "t1"}: { + Version: 2, + VersionPath: "test/t1/meta/schema_2_0000000001.json", + DataPath: "test/t1/2/2026-01-03/CDC00000000000000000002.json", + }, + }, + } + expectedNewData3 := func(newData map[cloudstorage.DmlPathKey]types.IncrementalData) { + require.Len(t, newData, 3) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("1_2026-01-02_2.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", 
EnableTableAcrossNodes: false}: {[]byte("2_2026-01-02_1.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + newDataContent := newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-03", + }] + require.Len(t, newDataContent.DataContentSlices, 1) + contents := newDataContent.DataContentSlices[cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}] + require.Len(t, contents, 2) + slices.SortFunc(contents, func(a, b []byte) int { + return bytes.Compare(a, b) + }) + require.Equal(t, [][]byte{[]byte("2_2026-01-03_1.json"), []byte("2_2026-01-03_2.json")}, contents) + } + expectedMaxVersionMap3 := func(maxVersionMap map[types.SchemaTableKey]types.VersionKey) { + require.Len(t, maxVersionMap, 1) + require.Equal(t, types.VersionKey{ + Version: 2, VersionPath: "test/t1/meta/schema_2_0000000001.json", DataPath: "test/t1/2/2026-01-03/CDC00000000000000000002.json", + }, maxVersionMap[types.SchemaTableKey{Schema: "test", Table: "t1"}]) + } + expectedCheckpoint23 := func(data map[cloudstorage.DmlPathKey]types.IncrementalData) { + require.Len(t, data, 4) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("1_2026-01-01_2.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-01", + }]) + dataContent := data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-02", + }] + require.Len(t, 
dataContent.DataContentSlices, 1) + contents := dataContent.DataContentSlices[cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}] + require.Len(t, contents, 2) + slices.SortFunc(contents, func(a, b []byte) int { + return bytes.Compare(a, b) + }) + require.Equal(t, [][]byte{[]byte("1_2026-01-02_1.json"), []byte("1_2026-01-02_2.json")}, contents) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("2_2026-01-02_1.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + dataContent = data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-03", + }] + require.Len(t, dataContent.DataContentSlices, 1) + contents = dataContent.DataContentSlices[cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}] + require.Len(t, contents, 2) + slices.SortFunc(contents, func(a, b []byte) int { + return bytes.Compare(a, b) + }) + require.Equal(t, [][]byte{[]byte("2_2026-01-03_1.json"), []byte("2_2026-01-03_2.json")}, contents) + } + + t.Run("checkpoint with nil items returns nil", func(t *testing.T) { + t.Parallel() + s3Storage := NewMockS3Storage(round1Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "test", nil) + require.NoError(t, err) + require.Empty(t, data) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + require.Empty(t, newData) + expectedMaxVersionMap1(maxVersionMap) + s3Storage.UpdateFiles(round2Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + 
expectedNewData2(newData) + expectedMaxVersionMap2(maxVersionMap) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with empty items returns nil", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + s3Storage := NewMockS3Storage(round1Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "test", checkpoint) + require.NoError(t, err) + require.Empty(t, data) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + require.Empty(t, newData) + expectedMaxVersionMap1(maxVersionMap) + s3Storage.UpdateFiles(round2Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData2(newData) + expectedMaxVersionMap2(maxVersionMap) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with 1 item", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "clusterX": round1TimeWindowData, + }) + s3Storage := NewMockS3Storage(round1Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "clusterX", checkpoint) + require.NoError(t, err) + require.Empty(t, data) + s3Storage.UpdateFiles(round2Files) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData2(newData) + expectedMaxVersionMap2(maxVersionMap) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + 
expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with 2 items", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "clusterX": round1TimeWindowData, + }) + checkpoint.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "clusterX": round2TimeWindowData, + }) + s3Storage := NewMockS3Storage(round2Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "clusterX", checkpoint) + require.NoError(t, err) + expectedNewData2(data) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with 3 items", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "clusterX": round1TimeWindowData, + }) + checkpoint.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "clusterX": round2TimeWindowData, + }) + checkpoint.NewTimeWindowData(2, map[string]types.TimeWindowData{ + "clusterX": round3TimeWindowData, + }) + s3Storage := NewMockS3Storage(round3Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "clusterX", checkpoint) + require.NoError(t, err) + expectedCheckpoint23(data) + }) +} diff --git a/cmd/multi-cluster-consistency-checker/decoder/decoder.go b/cmd/multi-cluster-consistency-checker/decoder/decoder.go new file mode 100644 index 0000000000..a35ebeeebe --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/decoder/decoder.go @@ -0,0 +1,471 @@ +// Copyright 2026 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package decoder + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "fmt" + "slices" + "strconv" + "strings" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/common/event" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/codec/common" + "github.com/pingcap/tidb/pkg/parser/mysql" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + tiTypes "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/codec" + "go.uber.org/zap" + "golang.org/x/text/encoding/charmap" +) + +const tidbWaterMarkType = "TIDB_WATERMARK" + +type canalValueDecoderJSONMessage struct { + PkNames []string `json:"pkNames"` + IsDDL bool `json:"isDdl"` + EventType string `json:"type"` + MySQLType map[string]string `json:"mysqlType"` + Data []map[string]any `json:"data"` +} + +func (c *canalValueDecoderJSONMessage) messageType() common.MessageType { + if c.IsDDL { + return common.MessageTypeDDL + } + + if c.EventType == tidbWaterMarkType { + return common.MessageTypeResolved + } + + return common.MessageTypeRow +} + +type TiDBCommitTsExtension struct { + CommitTs uint64 `json:"commitTs"` +} + +type canalValueDecoderJSONMessageWithTiDBExtension struct { + canalValueDecoderJSONMessage + + TiDBCommitTsExtension *TiDBCommitTsExtension `json:"_tidb"` +} + +func defaultCanalJSONCodecConfig() *common.Config { + codecConfig := 
common.NewConfig(config.ProtocolCanalJSON) + // Always enable tidb extension for canal-json protocol + // because we need to get the commit ts from the extension field. + codecConfig.EnableTiDBExtension = true + codecConfig.Terminator = config.CRLF + return codecConfig +} + +type Record struct { + types.CdcVersion + Pk types.PkType + PkStr string + PkMap map[string]any + ColumnValues map[string]any +} + +func (r *Record) EqualReplicatedRecord(replicatedRecord *Record) bool { + if replicatedRecord == nil { + return false + } + if r.CommitTs != replicatedRecord.OriginTs { + return false + } + if r.Pk != replicatedRecord.Pk { + return false + } + if len(r.ColumnValues) != len(replicatedRecord.ColumnValues) { + return false + } + for columnName, columnValue := range r.ColumnValues { + replicatedColumnValue, ok := replicatedRecord.ColumnValues[columnName] + if !ok { + return false + } + // NOTE: This comparison is safe because ColumnValues only holds comparable + // types (nil, string, int64, float64, etc.) as produced by the canal-json + // decoder. If a non-comparable type (e.g. []byte or map) were ever stored, + // the != operator would panic at runtime. 
+ if columnValue != replicatedColumnValue { + return false + } + } + return true +} + +type columnValueDecoder struct { + data []byte + config *common.Config + + msg *canalValueDecoderJSONMessageWithTiDBExtension + columnFieldTypes map[string]*ptypes.FieldType +} + +func newColumnValueDecoder(data []byte) (*columnValueDecoder, error) { + config := defaultCanalJSONCodecConfig() + data, err := common.Decompress(config.LargeMessageHandle.LargeMessageHandleCompression, data) + if err != nil { + log.Error("decompress data failed", + zap.String("compression", config.LargeMessageHandle.LargeMessageHandleCompression), + zap.Any("data", data), + zap.Error(err)) + return nil, errors.Annotatef(err, "decompress data failed") + } + return &columnValueDecoder{ + config: config, + data: data, + }, nil +} + +func Decode(data []byte, columnFieldTypes map[string]*ptypes.FieldType) ([]*Record, error) { + decoder, err := newColumnValueDecoder(data) + if err != nil { + return nil, errors.Trace(err) + } + + decoder.columnFieldTypes = columnFieldTypes + + records := make([]*Record, 0) + for { + msgType, hasNext := decoder.tryNext() + if !hasNext { + break + } + if msgType == common.MessageTypeRow { + record, err := decoder.decodeNext() + if err != nil { + return nil, errors.Trace(err) + } + records = append(records, record) + } + } + + return records, nil +} + +func (d *columnValueDecoder) tryNext() (common.MessageType, bool) { + if d.data == nil { + return common.MessageTypeUnknown, false + } + var ( + msg = &canalValueDecoderJSONMessageWithTiDBExtension{} + encodedData []byte + ) + + idx := bytes.IndexAny(d.data, d.config.Terminator) + if idx >= 0 { + encodedData = d.data[:idx] + d.data = d.data[idx+len(d.config.Terminator):] + } else { + encodedData = d.data + d.data = nil + } + + if len(encodedData) == 0 { + return common.MessageTypeUnknown, false + } + + if err := json.Unmarshal(encodedData, msg); err != nil { + log.Error("canal json decoder unmarshal data failed", + zap.Error(err), 
zap.ByteString("data", encodedData)) + d.msg = nil + return common.MessageTypeUnknown, true + } + d.msg = msg + return d.msg.messageType(), true +} + +func (d *columnValueDecoder) decodeNext() (*Record, error) { + if d.msg == nil || len(d.msg.Data) == 0 || d.msg.messageType() != common.MessageTypeRow { + log.Error("invalid message", zap.Any("msg", d.msg)) + return nil, errors.New("invalid message") + } + + var pkStrBuilder strings.Builder + pkStrBuilder.WriteString("[") + pkValues := make([]tiTypes.Datum, 0, len(d.msg.PkNames)) + pkMap := make(map[string]any, len(d.msg.PkNames)) + slices.Sort(d.msg.PkNames) + for i, pkName := range d.msg.PkNames { + columnValue, ok := d.msg.Data[0][pkName] + if !ok { + log.Error("column value not found", zap.String("pkName", pkName), zap.Any("msg", d.msg)) + return nil, errors.Errorf("column value of column %s not found", pkName) + } + if i > 0 { + pkStrBuilder.WriteString(", ") + } + fmt.Fprintf(&pkStrBuilder, "%s: %v", pkName, columnValue) + pkMap[pkName] = columnValue + ft := d.getColumnFieldType(pkName) + if ft == nil { + log.Error("field type not found", zap.String("pkName", pkName), zap.Any("msg", d.msg)) + return nil, errors.Errorf("field type of column %s not found", pkName) + } + datum, err := safeValueToDatum(columnValue, ft) + if err != nil { + log.Error("failed to convert primary key column value", + zap.String("pkName", pkName), + zap.Any("columnValue", columnValue), + zap.Error(err)) + return nil, errors.Annotatef(err, "failed to convert primary key column %s", pkName) + } + if datum.IsNull() { + log.Error("column value is null", zap.String("pkName", pkName), zap.Any("msg", d.msg)) + return nil, errors.Errorf("column value of column %s is null", pkName) + } + pkValues = append(pkValues, *datum) + delete(d.msg.Data[0], pkName) + } + pkStrBuilder.WriteString("]") + pkEncoded, err := codec.EncodeKey(time.UTC, nil, pkValues...) 
+ if err != nil { + return nil, errors.Annotate(err, "failed to encode primary key") + } + pk := hex.EncodeToString(pkEncoded) + originTs := uint64(0) + columnValues := make(map[string]any) + for columnName, columnValue := range d.msg.Data[0] { + if columnName == event.OriginTsColumn { + if columnValue != nil { + originTs, err = strconv.ParseUint(columnValue.(string), 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + } + } else { + columnValues[columnName] = columnValue + } + } + commitTs := d.msg.TiDBCommitTsExtension.CommitTs + d.msg = nil + return &Record{ + Pk: types.PkType(pk), + PkStr: pkStrBuilder.String(), + PkMap: pkMap, + ColumnValues: columnValues, + CdcVersion: types.CdcVersion{ + CommitTs: commitTs, + OriginTs: originTs, + }, + }, nil +} + +func safeValueToDatum(value any, ft *ptypes.FieldType) (datum *tiTypes.Datum, err error) { + defer func() { + if r := recover(); r != nil { + err = errors.Errorf("value to datum conversion panic: %v", r) + datum = nil + } + }() + return valueToDatum(value, ft), nil +} + +// getColumnFieldType returns the FieldType for a column. +// It first looks up from the tableDefinition-based columnFieldTypes map, +// then falls back to parsing the MySQLType string from the canal-json message. 
+func (d *columnValueDecoder) getColumnFieldType(columnName string) *ptypes.FieldType {
+	if d.columnFieldTypes != nil {
+		if ft, ok := d.columnFieldTypes[columnName]; ok {
+			return ft
+		}
+	}
+	// Fallback: parse from MySQLType in the canal-json message
+	mysqlType, ok := d.msg.MySQLType[columnName]
+	if !ok {
+		// Neither the schema-derived map nor the message knows this column.
+		return nil
+	}
+	return newPKColumnFieldTypeFromMysqlType(mysqlType)
+}
+
+// newPKColumnFieldTypeFromMysqlType builds a FieldType from a canal-json
+// MySQLType string (e.g. "varchar(255)", "int unsigned"). It restores the
+// binary/unsigned flags, charset/collation, flen/decimal, and enum/set
+// elements that valueToDatum needs to reconstruct a Datum.
+func newPKColumnFieldTypeFromMysqlType(mysqlType string) *ptypes.FieldType {
+	tp := ptypes.NewFieldType(common.ExtractBasicMySQLType(mysqlType))
+	if common.IsBinaryMySQLType(mysqlType) {
+		tp.AddFlag(mysql.BinaryFlag)
+		tp.SetCharset("binary")
+		tp.SetCollate("binary")
+	}
+	// Textual types get the binary utf8mb4 collation for comparison.
+	if strings.HasPrefix(mysqlType, "char") ||
+		strings.HasPrefix(mysqlType, "varchar") ||
+		strings.Contains(mysqlType, "text") ||
+		strings.Contains(mysqlType, "enum") ||
+		strings.Contains(mysqlType, "set") {
+		tp.SetCharset("utf8mb4")
+		tp.SetCollate("utf8mb4_bin")
+	}
+
+	if common.IsUnsignedMySQLType(mysqlType) {
+		tp.AddFlag(mysql.UnsignedFlag)
+	}
+
+	flen, decimal := common.ExtractFlenDecimal(mysqlType, tp.GetType())
+	tp.SetFlen(flen)
+	tp.SetDecimal(decimal)
+	switch tp.GetType() {
+	case mysql.TypeEnum, mysql.TypeSet:
+		tp.SetElems(common.ExtractElements(mysqlType))
+	case mysql.TypeDuration:
+		// DURATION carries its fractional-seconds precision separately;
+		// re-extract it to override the generic flen/decimal result above.
+		decimal = common.ExtractDecimal(mysqlType)
+		tp.SetDecimal(decimal)
+	default:
+	}
+	return tp
+}
+
+// valueToDatum converts a canal-json column value (always a string, or nil)
+// into a tiTypes.Datum according to ft. It panics via log.Panic on malformed
+// values; callers that must not crash should use safeValueToDatum instead.
+func valueToDatum(value any, ft *ptypes.FieldType) *tiTypes.Datum {
+	d := &tiTypes.Datum{}
+	if value == nil {
+		d.SetNull()
+		return d
+	}
+	rawValue, ok := value.(string)
+	if !ok {
+		log.Panic("canal-json encoded message should have type in `string`")
+	}
+	if mysql.HasBinaryFlag(ft.GetFlag()) {
+		// when encoding the `JavaSQLTypeBLOB`, use `ISO8859_1` decoder, now reverse it back.
+ result, err := charmap.ISO8859_1.NewEncoder().String(rawValue) + if err != nil { + log.Panic("invalid column value, please report a bug", zap.Any("rawValue", rawValue), zap.Error(err)) + } + rawValue = result + } + + switch ft.GetType() { + case mysql.TypeLonglong, mysql.TypeLong, mysql.TypeInt24, mysql.TypeShort, mysql.TypeTiny: + if mysql.HasUnsignedFlag(ft.GetFlag()) { + data, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for unsigned integer", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetUint64(data) + return d + } + data, err := strconv.ParseInt(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for integer", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetInt64(data) + return d + case mysql.TypeYear: + data, err := strconv.ParseInt(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for year", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetInt64(data) + return d + case mysql.TypeFloat: + data, err := strconv.ParseFloat(rawValue, 32) + if err != nil { + log.Panic("invalid column value for float", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetFloat32(float32(data)) + return d + case mysql.TypeDouble: + data, err := strconv.ParseFloat(rawValue, 64) + if err != nil { + log.Panic("invalid column value for double", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetFloat64(data) + return d + case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString, + mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + d.SetString(rawValue, ft.GetCollate()) + return d + case mysql.TypeNewDecimal: + data := new(tiTypes.MyDecimal) + err := data.FromString([]byte(rawValue)) + if err != nil { + log.Panic("invalid column value for decimal", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlDecimal(data) + d.SetLength(ft.GetFlen()) + if ft.GetDecimal() == tiTypes.UnspecifiedLength { + 
d.SetFrac(int(data.GetDigitsFrac())) + } else { + d.SetFrac(ft.GetDecimal()) + } + return d + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + data, err := tiTypes.ParseTime(tiTypes.DefaultStmtNoWarningContext, rawValue, ft.GetType(), ft.GetDecimal()) + if err != nil { + log.Panic("invalid column value for time", zap.Any("rawValue", rawValue), + zap.Int("flen", ft.GetFlen()), zap.Int("decimal", ft.GetDecimal()), + zap.Error(err)) + } + d.SetMysqlTime(data) + return d + case mysql.TypeDuration: + data, _, err := tiTypes.ParseDuration(tiTypes.DefaultStmtNoWarningContext, rawValue, ft.GetDecimal()) + if err != nil { + log.Panic("invalid column value for duration", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlDuration(data) + return d + case mysql.TypeEnum: + enumValue, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for enum", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlEnum(tiTypes.Enum{ + Name: "", + Value: enumValue, + }, ft.GetCollate()) + return d + case mysql.TypeSet: + setValue, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for set", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlSet(tiTypes.Set{ + Name: "", + Value: setValue, + }, ft.GetCollate()) + return d + case mysql.TypeBit: + data, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for bit", zap.Any("rawValue", rawValue), zap.Error(err)) + } + byteSize := (ft.GetFlen() + 7) >> 3 + d.SetMysqlBit(tiTypes.NewBinaryLiteralFromUint(data, byteSize)) + return d + case mysql.TypeJSON: + data, err := tiTypes.ParseBinaryJSONFromString(rawValue) + if err != nil { + log.Panic("invalid column value for json", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlJSON(data) + return d + case mysql.TypeTiDBVectorFloat32: + data, err := tiTypes.ParseVectorFloat32(rawValue) + if err != nil { + 
log.Panic("cannot parse vector32 value from string", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetVectorFloat32(data) + return d + } + return d +} diff --git a/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go b/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go new file mode 100644 index 0000000000..ba04e4ef3d --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go @@ -0,0 +1,326 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package decoder_test + +import ( + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/decoder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + "github.com/stretchr/testify/require" +) + +// buildColumnFieldTypes converts a TableDefinition into a map of column name → FieldType, +// mimicking what SchemaDefinitions.SetSchemaDefinition does. 
+func buildColumnFieldTypes(t *testing.T, td *cloudstorage.TableDefinition) map[string]*ptypes.FieldType { + t.Helper() + result := make(map[string]*ptypes.FieldType, len(td.Columns)) + for i, col := range td.Columns { + colInfo, err := col.ToTiColumnInfo(int64(i)) + require.NoError(t, err) + result[col.Name] = &colInfo.FieldType + } + return result +} + +// DataContent uses CRLF (\r\n) as line terminator to match the codec config +const DataContent1 string = "" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770184540709,"ts":1770184542274,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp","id":"int","first_name":"varchar"},"old":null,"data":[{"id":"20","first_name":"t","last_name":"TT","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464043256649875456}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770184540709,"ts":1770184542274,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"21","first_name":"u","last_name":"UU","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464043256649875456}}` + "\r\n" + + 
`{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770301693150,"ts":1770301693833,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp","id":"int"},"old":null,"data":[{"id":"5","first_name":"e","last_name":"E","_tidb_origin_ts":"464073966942421014","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464073967049113629}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770301693150,"ts":1770301693833,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"_tidb_softdelete_time":"timestamp","id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint"},"old":null,"data":[{"id":"6","first_name":"f","last_name":"F","_tidb_origin_ts":"464073966942421014","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464073967049113629}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770303499850,"ts":1770303500498,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"7","first_name":"g","last_name":"G","_tidb_origin_ts":"464074440387592202","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464074440664678441}}` + "\r\n" + + 
`{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"UPDATE","es":1770303520951,"ts":1770303522531,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp","id":"int","first_name":"varchar"},"old":[{"id":"7","first_name":"g","last_name":"G","_tidb_origin_ts":"464074440387592202","_tidb_softdelete_time":null}],"data":[{"id":"7","first_name":"g","last_name":"G","_tidb_origin_ts":null,"_tidb_softdelete_time":"2026-02-05 22:58:40.992217"}],"_tidb":{"commitTs":464074446196178963}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770303498793,"ts":1770303499864,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"8","first_name":"h","last_name":"H","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464074440387592202}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"UPDATE","es":1770303522494,"ts":1770303523900,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":[{"id":"8","first_name":"h","last_name":"H","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"data":[{"id":"8","first_name":"h","last_name":"H","_tidb_origin_ts":"464074446196178963","_tidb_softdelete_time":"2026-02-05 22:58:40.992217"}],"_tidb":{"commitTs":464074446600667164}}` + +var ExpectedRecords1 = []decoder.Record{ + {CdcVersion: types.CdcVersion{CommitTs: 
464043256649875456, OriginTs: 0}, Pk: "038000000000000014", PkStr: "[id: 20]", ColumnValues: map[string]any{"first_name": "t", "last_name": "TT", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464043256649875456, OriginTs: 0}, Pk: "038000000000000015", PkStr: "[id: 21]", ColumnValues: map[string]any{"first_name": "u", "last_name": "UU", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464073967049113629, OriginTs: 464073966942421014}, Pk: "038000000000000005", PkStr: "[id: 5]", ColumnValues: map[string]any{"first_name": "e", "last_name": "E", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464073967049113629, OriginTs: 464073966942421014}, Pk: "038000000000000006", PkStr: "[id: 6]", ColumnValues: map[string]any{"first_name": "f", "last_name": "F", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074440664678441, OriginTs: 464074440387592202}, Pk: "038000000000000007", PkStr: "[id: 7]", ColumnValues: map[string]any{"first_name": "g", "last_name": "G", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074446196178963, OriginTs: 0}, Pk: "038000000000000007", PkStr: "[id: 7]", ColumnValues: map[string]any{"first_name": "g", "last_name": "G", "_tidb_softdelete_time": "2026-02-05 22:58:40.992217"}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074440387592202, OriginTs: 0}, Pk: "038000000000000008", PkStr: "[id: 8]", ColumnValues: map[string]any{"first_name": "h", "last_name": "H", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074446600667164, OriginTs: 464074446196178963}, Pk: "038000000000000008", PkStr: "[id: 8]", ColumnValues: map[string]any{"first_name": "h", "last_name": "H", "_tidb_softdelete_time": "2026-02-05 22:58:40.992217"}}, +} + +// tableDefinition1 describes the "message" table: id(INT PK), first_name(VARCHAR), last_name(VARCHAR), +// _tidb_origin_ts(BIGINT), _tidb_softdelete_time(TIMESTAMP) 
+var tableDefinition1 = &cloudstorage.TableDefinition{ + Table: "message", + Schema: "test_active", + Version: 1, + Columns: []cloudstorage.TableCol{ + {Name: "id", Tp: "INT", IsPK: "true", Precision: "11"}, + {Name: "first_name", Tp: "VARCHAR", Precision: "255"}, + {Name: "last_name", Tp: "VARCHAR", Precision: "255"}, + {Name: "_tidb_origin_ts", Tp: "BIGINT", Precision: "20"}, + {Name: "_tidb_softdelete_time", Tp: "TIMESTAMP"}, + }, + TotalColumns: 5, +} + +func TestCanalJSONDecoder1(t *testing.T) { + records, err := decoder.Decode([]byte(DataContent1), buildColumnFieldTypes(t, tableDefinition1)) + require.NoError(t, err) + require.Len(t, records, 8) + for i, actualRecord := range records { + expectedRecord := ExpectedRecords1[i] + require.Equal(t, actualRecord.Pk, expectedRecord.Pk) + require.Equal(t, actualRecord.PkStr, expectedRecord.PkStr) + require.Equal(t, actualRecord.ColumnValues, expectedRecord.ColumnValues) + require.Equal(t, actualRecord.CdcVersion.CommitTs, expectedRecord.CdcVersion.CommitTs) + require.Equal(t, actualRecord.CdcVersion.OriginTs, expectedRecord.CdcVersion.OriginTs) + } +} + +const DataContent2 string = "" + + `{"id":0,"database":"test_active","table":"message2","pkNames":["id","first_name"],"isDdl":false,"type":"INSERT","es":1770344412751,"ts":1770344413749,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"100","first_name":"a","last_name":"A","_tidb_origin_ts":"464085165262503958","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464085165736198159}}` + "\r\n" + + 
`{"id":0,"database":"test_active","table":"message2","pkNames":["id","first_name"],"isDdl":false,"type":"INSERT","es":1770344427851,"ts":1770344429772,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"101","first_name":"b","last_name":"B","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464085169694572575}}` + "\r\n" + +var ExpectedRecords2 = []decoder.Record{ + {CdcVersion: types.CdcVersion{CommitTs: 464085165736198159, OriginTs: 464085165262503958}, Pk: "016100000000000000f8038000000000000064", PkStr: "[first_name: a, id: 100]", ColumnValues: map[string]any{"last_name": "A", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464085169694572575, OriginTs: 0}, Pk: "016200000000000000f8038000000000000065", PkStr: "[first_name: b, id: 101]", ColumnValues: map[string]any{"last_name": "B", "_tidb_softdelete_time": nil}}, +} + +// tableDefinition2 describes the "message2" table: id(INT PK), first_name(VARCHAR PK), last_name(VARCHAR), +// _tidb_origin_ts(BIGINT), _tidb_softdelete_time(TIMESTAMP) +var tableDefinition2 = &cloudstorage.TableDefinition{ + Table: "message2", + Schema: "test_active", + Version: 1, + Columns: []cloudstorage.TableCol{ + {Name: "id", Tp: "INT", IsPK: "true", Precision: "11"}, + {Name: "first_name", Tp: "VARCHAR", IsPK: "true", Precision: "255"}, + {Name: "last_name", Tp: "VARCHAR", Precision: "255"}, + {Name: "_tidb_origin_ts", Tp: "BIGINT", Precision: "20"}, + {Name: "_tidb_softdelete_time", Tp: "TIMESTAMP"}, + }, + TotalColumns: 5, +} + +func TestCanalJSONDecoder2(t *testing.T) { + records, err := decoder.Decode([]byte(DataContent2), buildColumnFieldTypes(t, tableDefinition2)) + require.NoError(t, err) + require.Len(t, records, 2) + for i, actualRecord := range records { + 
expectedRecord := ExpectedRecords2[i] + require.Equal(t, actualRecord.Pk, expectedRecord.Pk) + require.Equal(t, actualRecord.PkStr, expectedRecord.PkStr) + require.Equal(t, actualRecord.ColumnValues, expectedRecord.ColumnValues) + require.Equal(t, actualRecord.CdcVersion.CommitTs, expectedRecord.CdcVersion.CommitTs) + require.Equal(t, actualRecord.CdcVersion.OriginTs, expectedRecord.CdcVersion.OriginTs) + } +} + +// TestCanalJSONDecoderWithInvalidMessage verifies that when a malformed message appears in +// the data stream, it is skipped gracefully and subsequent valid messages are still decoded. +// This covers the fix where d.msg is cleared to nil on unmarshal failure to prevent stale +// message data from leaking into decodeNext. +func TestCanalJSONDecoderWithInvalidMessage(t *testing.T) { + // First line is invalid JSON, second line is a valid message. + dataWithInvalidLine := `{invalid json}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message2","pkNames":["id","first_name"],"isDdl":false,"type":"INSERT","es":1770344412751,"ts":1770344413749,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"100","first_name":"a","last_name":"A","_tidb_origin_ts":"464085165262503958","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464085165736198159}}` + "\r\n" + + records, err := decoder.Decode([]byte(dataWithInvalidLine), buildColumnFieldTypes(t, tableDefinition2)) + require.NoError(t, err) + // The invalid line should be skipped, only the valid record should be returned. 
+ require.Len(t, records, 1) + require.Equal(t, ExpectedRecords2[0].Pk, records[0].Pk) + require.Equal(t, ExpectedRecords2[0].PkStr, records[0].PkStr) + require.Equal(t, ExpectedRecords2[0].CdcVersion.CommitTs, records[0].CdcVersion.CommitTs) + require.Equal(t, ExpectedRecords2[0].CdcVersion.OriginTs, records[0].CdcVersion.OriginTs) +} + +// TestCanalJSONDecoderAllInvalidMessages verifies that when all messages are malformed, +// the decoder returns an empty result without errors. +func TestCanalJSONDecoderAllInvalidMessages(t *testing.T) { + allInvalid := `{broken}` + "\r\n" + `{also broken}` + "\r\n" + records, err := decoder.Decode([]byte(allInvalid), nil) + require.NoError(t, err) + require.Empty(t, records) +} + +func TestCanalJSONDecoderInvalidPrimaryKeyValueNoPanic(t *testing.T) { + // id is intentionally encoded as a JSON number (not string), which would + // trigger a panic inside valueToDatum without recover protection. + invalidPKType := `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1,"ts":1,"sql":"","sqlType":{"id":4},"mysqlType":{"id":"int"},"old":null,"data":[{"id":20}],"_tidb":{"commitTs":100}}` + "\r\n" + + var ( + records []*decoder.Record + err error + ) + require.NotPanics(t, func() { + records, err = decoder.Decode([]byte(invalidPKType), buildColumnFieldTypes(t, tableDefinition1)) + }) + require.Error(t, err) + require.Empty(t, records) + require.Contains(t, err.Error(), "failed to convert primary key column id") +} + +func TestRecord_EqualReplicatedRecord(t *testing.T) { + tests := []struct { + name string + local *decoder.Record + replicated *decoder.Record + expectedEqual bool + }{ + { + name: "equal records", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + "col2": 42, + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", 
+ ColumnValues: map[string]any{ + "col1": "value1", + "col2": 42, + }, + }, + expectedEqual: true, + }, + { + name: "replicated is nil", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + }, + replicated: nil, + expectedEqual: false, + }, + { + name: "different CommitTs and OriginTs", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 200}, + Pk: "pk1", + }, + expectedEqual: false, + }, + { + name: "different primary keys", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk2", + }, + expectedEqual: false, + }, + { + name: "different column count", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + "col2": "value2", + }, + }, + expectedEqual: false, + }, + { + name: "different column names", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col2": "value1", + }, + }, + expectedEqual: false, + }, + { + name: "different column values", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: 
map[string]any{ + "col1": "value2", + }, + }, + expectedEqual: false, + }, + { + name: "empty column values", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{}, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{}, + }, + expectedEqual: true, + }, + { + name: "nil column values", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: nil, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: nil, + }, + expectedEqual: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.local.EqualReplicatedRecord(tt.replicated) + require.Equal(t, tt.expectedEqual, result) + }) + } +} diff --git a/cmd/multi-cluster-consistency-checker/decoder/value_to_datum_test.go b/cmd/multi-cluster-consistency-checker/decoder/value_to_datum_test.go new file mode 100644 index 0000000000..95bc8d43e9 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/decoder/value_to_datum_test.go @@ -0,0 +1,898 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package decoder + +import ( + "math" + "testing" + + "github.com/pingcap/tidb/pkg/parser/mysql" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + tiTypes "github.com/pingcap/tidb/pkg/types" + "github.com/stretchr/testify/require" +) + +func TestValueToDatum_NilValue(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + d := valueToDatum(nil, ft) + require.True(t, d.IsNull()) +} + +func TestValueToDatum_NonStringPanics(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + require.Panics(t, func() { + valueToDatum(123, ft) + }) +} + +func TestValueToDatum_SignedIntegers(t *testing.T) { + t.Parallel() + + intTypes := []struct { + name string + tp byte + }{ + {"TypeTiny", mysql.TypeTiny}, + {"TypeShort", mysql.TypeShort}, + {"TypeInt24", mysql.TypeInt24}, + {"TypeLong", mysql.TypeLong}, + {"TypeLonglong", mysql.TypeLonglong}, + } + + for _, it := range intTypes { + t.Run(it.name, func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(it.tp) + + t.Run("positive", func(t *testing.T) { + t.Parallel() + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(42), d.GetInt64()) + }) + + t.Run("zero", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(0), d.GetInt64()) + }) + + t.Run("negative", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-100", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(-100), d.GetInt64()) + }) + + t.Run("max int64", func(t *testing.T) { + t.Parallel() + d := valueToDatum("9223372036854775807", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(math.MaxInt64), d.GetInt64()) + }) + + t.Run("min int64", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-9223372036854775808", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(math.MinInt64), d.GetInt64()) + }) + }) 
+ } +} + +func TestValueToDatum_UnsignedIntegers(t *testing.T) { + t.Parallel() + + intTypes := []struct { + name string + tp byte + }{ + {"TypeTiny", mysql.TypeTiny}, + {"TypeShort", mysql.TypeShort}, + {"TypeInt24", mysql.TypeInt24}, + {"TypeLong", mysql.TypeLong}, + {"TypeLonglong", mysql.TypeLonglong}, + } + + for _, it := range intTypes { + t.Run(it.name, func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(it.tp) + ft.AddFlag(mysql.UnsignedFlag) + + t.Run("positive", func(t *testing.T) { + t.Parallel() + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(42), d.GetUint64()) + }) + + t.Run("zero", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(0), d.GetUint64()) + }) + + t.Run("max uint64", func(t *testing.T) { + t.Parallel() + d := valueToDatum("18446744073709551615", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(math.MaxUint64), d.GetUint64()) + }) + }) + } +} + +func TestValueToDatum_InvalidIntegerPanics(t *testing.T) { + t.Parallel() + + t.Run("signed invalid", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + require.Panics(t, func() { + valueToDatum("not_a_number", ft) + }) + }) + + t.Run("unsigned invalid", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + ft.AddFlag(mysql.UnsignedFlag) + require.Panics(t, func() { + valueToDatum("not_a_number", ft) + }) + }) +} + +func TestValueToDatum_Year(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeYear) + + t.Run("normal year", func(t *testing.T) { + t.Parallel() + d := valueToDatum("2026", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(2026), d.GetInt64()) + }) + + t.Run("zero year", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + 
require.Equal(t, int64(0), d.GetInt64()) + }) + + t.Run("invalid year panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("abc", ft) + }) + }) +} + +func TestValueToDatum_Float(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeFloat) + + t.Run("positive float", func(t *testing.T) { + t.Parallel() + d := valueToDatum("3.14", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.InDelta(t, float32(3.14), d.GetFloat32(), 0.001) + }) + + t.Run("negative float", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-2.5", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.InDelta(t, float32(-2.5), d.GetFloat32(), 0.001) + }) + + t.Run("zero float", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.Equal(t, float32(0), d.GetFloat32()) + }) + + t.Run("invalid float panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("not_a_float", ft) + }) + }) +} + +func TestValueToDatum_Double(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeDouble) + + t.Run("positive double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("3.141592653589793", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, 3.141592653589793, d.GetFloat64(), 1e-15) + }) + + t.Run("negative double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-1.23456789", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, -1.23456789, d.GetFloat64(), 1e-9) + }) + + t.Run("zero double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.Equal(t, float64(0), d.GetFloat64()) + }) + + t.Run("very large double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("1.7976931348623157e+308", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, 
math.MaxFloat64, d.GetFloat64(), 1e+293) + }) + + t.Run("invalid double panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("not_a_double", ft) + }) + }) +} + +func TestValueToDatum_StringTypes(t *testing.T) { + t.Parallel() + + stringTypes := []struct { + name string + tp byte + }{ + {"TypeVarString", mysql.TypeVarString}, + {"TypeVarchar", mysql.TypeVarchar}, + {"TypeString", mysql.TypeString}, + {"TypeBlob", mysql.TypeBlob}, + {"TypeTinyBlob", mysql.TypeTinyBlob}, + {"TypeMediumBlob", mysql.TypeMediumBlob}, + {"TypeLongBlob", mysql.TypeLongBlob}, + } + + for _, st := range stringTypes { + t.Run(st.name, func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(st.tp) + ft.SetCollate("utf8mb4_bin") + + t.Run("normal string", func(t *testing.T) { + t.Parallel() + d := valueToDatum("hello world", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "hello world", d.GetString()) + }) + + t.Run("empty string", func(t *testing.T) { + t.Parallel() + d := valueToDatum("", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "", d.GetString()) + }) + + t.Run("unicode string", func(t *testing.T) { + t.Parallel() + d := valueToDatum("你好世界🌍", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "你好世界🌍", d.GetString()) + }) + }) + } +} + +func TestValueToDatum_BinaryFlag(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeString) + ft.AddFlag(mysql.BinaryFlag) + ft.SetCharset("binary") + ft.SetCollate("binary") + + t.Run("ascii content", func(t *testing.T) { + t.Parallel() + d := valueToDatum("abc", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "abc", d.GetString()) + }) + + t.Run("empty binary", func(t *testing.T) { + t.Parallel() + d := valueToDatum("", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "", d.GetString()) + }) +} + +func TestValueToDatum_Decimal(t *testing.T) { + t.Parallel() + + t.Run("simple 
decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(2) + + d := valueToDatum("123.45", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "123.45", d.GetMysqlDecimal().String()) + require.Equal(t, 10, d.Length()) + require.Equal(t, 2, d.Frac()) + }) + + t.Run("negative decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(3) + + d := valueToDatum("-99.999", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "-99.999", d.GetMysqlDecimal().String()) + require.Equal(t, 3, d.Frac()) + }) + + t.Run("zero decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(0) + + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "0", d.GetMysqlDecimal().String()) + }) + + t.Run("large decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(65) + ft.SetDecimal(30) + + d := valueToDatum("12345678901234567890.123456789012345678", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, 65, d.Length()) + require.Equal(t, 30, d.Frac()) + }) + + t.Run("unspecified decimal uses actual frac", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(tiTypes.UnspecifiedLength) + + d := valueToDatum("12.345", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, 3, d.Frac()) // actual digits frac from the value + }) + + t.Run("invalid decimal panics", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(2) + require.Panics(t, func() { + valueToDatum("not_decimal", ft) + }) + }) +} + +func TestValueToDatum_Date(t *testing.T) { + t.Parallel() + + ft 
:= ptypes.NewFieldType(mysql.TypeDate) + ft.SetDecimal(0) + + t.Run("normal date", func(t *testing.T) { + t.Parallel() + d := valueToDatum("2026-02-11", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11", d.GetMysqlTime().String()) + }) + + t.Run("zero date", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0000-00-00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "0000-00-00", d.GetMysqlTime().String()) + }) +} + +func TestValueToDatum_Datetime(t *testing.T) { + t.Parallel() + + t.Run("datetime without fractional seconds", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDatetime) + ft.SetDecimal(0) + + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11 10:30:00", d.GetMysqlTime().String()) + }) + + t.Run("datetime with fractional seconds", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDatetime) + ft.SetDecimal(6) + + d := valueToDatum("2026-02-11 10:30:00.123456", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11 10:30:00.123456", d.GetMysqlTime().String()) + }) +} + +func TestValueToDatum_Timestamp(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeTimestamp) + ft.SetDecimal(0) + + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11 10:30:00", d.GetMysqlTime().String()) +} + +func TestValueToDatum_Duration(t *testing.T) { + t.Parallel() + + t.Run("positive duration", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(0) + + d := valueToDatum("12:30:45", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "12:30:45", d.GetMysqlDuration().String()) + }) + + t.Run("negative duration", func(t *testing.T) { + t.Parallel() + ft := 
ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(0) + + d := valueToDatum("-01:00:00", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "-01:00:00", d.GetMysqlDuration().String()) + }) + + t.Run("duration with fractional seconds", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(3) + + d := valueToDatum("10:20:30.123", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "10:20:30.123", d.GetMysqlDuration().String()) + }) + + t.Run("zero duration", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(0) + + d := valueToDatum("00:00:00", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "00:00:00", d.GetMysqlDuration().String()) + }) +} + +func TestValueToDatum_Enum(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeEnum) + ft.SetCharset("utf8mb4") + ft.SetCollate("utf8mb4_bin") + ft.SetElems([]string{"a", "b", "c"}) + + t.Run("valid enum value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("1", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(1), d.GetMysqlEnum().Value) + }) + + t.Run("enum value 2", func(t *testing.T) { + t.Parallel() + d := valueToDatum("2", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(2), d.GetMysqlEnum().Value) + }) + + t.Run("enum value 0", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(0), d.GetMysqlEnum().Value) + }) + + t.Run("invalid enum panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("abc", ft) + }) + }) +} + +func TestValueToDatum_Set(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeSet) + ft.SetCharset("utf8mb4") + ft.SetCollate("utf8mb4_bin") + ft.SetElems([]string{"a", "b", "c"}) + + 
t.Run("single set value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("1", ft) + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(1), d.GetMysqlSet().Value) + }) + + t.Run("combined set value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("3", ft) // a,b + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(3), d.GetMysqlSet().Value) + }) + + t.Run("zero set value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(0), d.GetMysqlSet().Value) + }) + + t.Run("invalid set panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("xyz", ft) + }) + }) +} + +func TestValueToDatum_Bit(t *testing.T) { + t.Parallel() + + t.Run("bit(1) value 1", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(1) + + d := valueToDatum("1", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("bit(8) value 255", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(8) + + d := valueToDatum("255", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("bit(64) large value", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(64) + + d := valueToDatum("18446744073709551615", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("bit zero", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(8) + + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("invalid bit panics", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(8) + require.Panics(t, func() { + valueToDatum("not_a_bit", ft) + }) + }) +} + +func TestValueToDatum_JSON(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeJSON) + + t.Run("json 
object", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`{"key": "value"}`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + require.Contains(t, d.GetMysqlJSON().String(), "key") + require.Contains(t, d.GetMysqlJSON().String(), "value") + }) + + t.Run("json array", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`[1, 2, 3]`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json string", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`"hello"`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json number", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`42`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json null", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`null`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json boolean", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`true`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("nested json", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`{"a": [1, {"b": "c"}], "d": null}`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("invalid json panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum(`{invalid`, ft) + }) + }) +} + +func TestValueToDatum_VectorFloat32(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeTiDBVectorFloat32) + + t.Run("simple vector", func(t *testing.T) { + t.Parallel() + d := valueToDatum("[1,2,3]", ft) + require.False(t, d.IsNull()) + }) + + t.Run("single element vector", func(t *testing.T) { + t.Parallel() + d := valueToDatum("[0.5]", ft) + require.False(t, d.IsNull()) + }) + + t.Run("invalid vector panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("not_a_vector", ft) + }) + }) +} + +func TestValueToDatum_UnknownType(t *testing.T) { + t.Parallel() + // Use a type that doesn't match any case 
in the switch (TypeGeometry). + // The default datum returned is a zero-value datum, which is null. + ft := ptypes.NewFieldType(mysql.TypeGeometry) + d := valueToDatum("some_value", ft) + require.True(t, d.IsNull()) +} + +func TestValueToDatum_ViaNewPKColumnFieldType(t *testing.T) { + t.Parallel() + // Test valueToDatum using FieldTypes produced by newPKColumnFieldTypeFromMysqlType, + // which is the real caller in production code. + + t.Run("int", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("int") + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(42), d.GetInt64()) + }) + + t.Run("int unsigned", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("int unsigned") + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(42), d.GetUint64()) + }) + + t.Run("bigint", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("bigint") + d := valueToDatum("9223372036854775807", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(math.MaxInt64), d.GetInt64()) + }) + + t.Run("bigint unsigned", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("bigint unsigned") + d := valueToDatum("18446744073709551615", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(math.MaxUint64), d.GetUint64()) + }) + + t.Run("varchar", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("varchar") + d := valueToDatum("hello", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "hello", d.GetString()) + }) + + t.Run("char", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("char") + d := valueToDatum("x", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "x", d.GetString()) + }) + + t.Run("decimal(10,2)", func(t *testing.T) { + t.Parallel() + ft 
:= newPKColumnFieldTypeFromMysqlType("decimal(10,2)") + d := valueToDatum("123.45", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "123.45", d.GetMysqlDecimal().String()) + require.Equal(t, 10, d.Length()) + require.Equal(t, 2, d.Frac()) + }) + + t.Run("float", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("float") + d := valueToDatum("3.14", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.InDelta(t, float32(3.14), d.GetFloat32(), 0.001) + }) + + t.Run("double", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("double") + d := valueToDatum("3.141592653589793", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, 3.141592653589793, d.GetFloat64(), 1e-15) + }) + + t.Run("binary", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("binary") + require.True(t, mysql.HasBinaryFlag(ft.GetFlag())) + d := valueToDatum("abc", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + }) + + t.Run("varbinary", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("varbinary") + require.True(t, mysql.HasBinaryFlag(ft.GetFlag())) + d := valueToDatum("abc", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + }) + + t.Run("tinyint", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("tinyint") + d := valueToDatum("127", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(127), d.GetInt64()) + }) + + t.Run("smallint unsigned", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("smallint unsigned") + d := valueToDatum("65535", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(65535), d.GetUint64()) + }) + + t.Run("date", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("date") + d := valueToDatum("2026-02-11", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + 
require.Equal(t, "2026-02-11", d.GetMysqlTime().String()) + }) + + t.Run("datetime", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("datetime") + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + }) + + t.Run("timestamp", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("timestamp") + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + }) + + t.Run("time", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("time") + d := valueToDatum("12:30:45", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + }) + + t.Run("year", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("year") + d := valueToDatum("2026", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(2026), d.GetInt64()) + }) + + t.Run("enum('a','b','c')", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("enum('a','b','c')") + d := valueToDatum("2", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(2), d.GetMysqlEnum().Value) + }) + + t.Run("set('x','y','z')", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("set('x','y','z')") + d := valueToDatum("5", ft) + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(5), d.GetMysqlSet().Value) + }) + + t.Run("bit(8)", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("bit(8)") + d := valueToDatum("255", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("json", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("json") + d := valueToDatum(`{"key":"value"}`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) +} diff --git a/cmd/multi-cluster-consistency-checker/integration/integration_test.go 
b/cmd/multi-cluster-consistency-checker/integration/integration_test.go new file mode 100644 index 0000000000..9a6dbc78be --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/integration/integration_test.go @@ -0,0 +1,733 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "context" + "fmt" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/advancer" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/checker" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/stretchr/testify/require" +) + +// schemaKey is the schema key for data stored via S3 path "test/t1/1/...". +// It equals QuoteSchema("test", "t1") = "`test`.`t1`". +var schemaKey = (&cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, +}).GetKey() + +// testEnv holds the initialized test environment. +type testEnv struct { + ctx context.Context + mc *MockMultiCluster + advancer *advancer.TimeWindowAdvancer + checker *checker.DataChecker +} + +// setupEnv creates a test environment with 2 clusters (c1, c2), both +// replicating to each other, with in-memory S3 storage and mock PD/watchers. 
+func setupEnv(t *testing.T) *testEnv { + t.Helper() + ctx := context.Background() + tables := map[string][]string{"test": {"t1"}} + + mc := NewMockMultiCluster( + []string{"c1", "c2"}, + tables, + 0, // pdBase: start physical time at 0ms + 100, // pdStep: 100ms per PD GetTS call + 100, // cpDelta: checkpoint = minCheckpointTs + 100 + 50, // s3Delta: s3 checkpoint = minCheckpointTs + 50 + ) + + require.NoError(t, mc.InitSchemaFiles(ctx)) + + twa, _, err := advancer.NewTimeWindowAdvancer( + ctx, mc.CPWatchers, mc.S3Watchers, mc.GetPDClients(), nil, + ) + require.NoError(t, err) + + clusterCfg := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + dc, err := checker.NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, err) + + return &testEnv{ctx: ctx, mc: mc, advancer: twa, checker: dc} +} + +// roundResult holds the output of a single round. +type roundResult struct { + report *recorder.Report + twData map[string]types.TimeWindowData +} + +// executeRound writes data to clusters' S3 storage, advances the time window, +// and runs the checker for one round. +func (e *testEnv) executeRound(t *testing.T, c1Content, c2Content []byte) roundResult { + t.Helper() + if c1Content != nil { + require.NoError(t, e.mc.WriteDMLFile(e.ctx, "c1", c1Content)) + } + if c2Content != nil { + require.NoError(t, e.mc.WriteDMLFile(e.ctx, "c2", c2Content)) + } + + twData, err := e.advancer.AdvanceTimeWindow(e.ctx) + require.NoError(t, err) + + report, err := e.checker.CheckInNextTimeWindow(twData) + require.NoError(t, err) + + return roundResult{report: report, twData: twData} +} + +// maxRightBoundary returns the maximum RightBoundary across all clusters. 
+func maxRightBoundary(twData map[string]types.TimeWindowData) uint64 { + maxRB := uint64(0) + for _, tw := range twData { + if tw.TimeWindow.RightBoundary > maxRB { + maxRB = tw.TimeWindow.RightBoundary + } + } + return maxRB +} + +// The test architecture simulates a 2-cluster active-active setup: +// +// c1 (locally-written records) ──CDC──> c2 (replicated records) +// c2 (locally-written records) ──CDC──> c1 (replicated records) +// +// Each cluster writes locally-written records (originTs=0) and receives replicated +// records from the other cluster (originTs>0). +// +// The checker needs 3 warm-up rounds before it starts checking (checkableRound >= 3). +// Data written in round 0 is tracked by the S3 consumer but not downloaded +// (skipDownloadData=true for the first round). From round 1 onwards, only +// NEW files (with higher indices) are downloaded. +// +// Data commitTs is set to prevMaxRightBoundary+1 to ensure records fall +// within the current time window (leftBoundary, rightBoundary]. +// +// TestIntegration_AllConsistent verifies that no errors are reported +// when all locally-written records have matching replicated records. 
+func TestIntegration_AllConsistent(t *testing.T) {
+	t.Parallel()
+	env := setupEnv(t)
+	defer env.mc.Close()
+
+	prevMaxRB := uint64(0)
+
+	for round := 0; round < 6; round++ {
+		cts := prevMaxRB + 1
+		// c1: locally-written records write (originTs=0)
+		c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)))
+		// c2: replicated records replicated from c1 (originTs = c1's commitTs)
+		c2 := MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)))
+
+		result := env.executeRound(t, c1, c2)
+		prevMaxRB = maxRightBoundary(result.twData)
+
+		t.Logf("Round %d: c1 TW=[%d, %d], c2 TW=[%d, %d], commitTs=%d",
+			round,
+			result.twData["c1"].TimeWindow.LeftBoundary, result.twData["c1"].TimeWindow.RightBoundary,
+			result.twData["c2"].TimeWindow.LeftBoundary, result.twData["c2"].TimeWindow.RightBoundary,
+			cts)
+
+		if round >= 3 {
+			require.Len(t, result.report.ClusterReports, 2, "round %d", round)
+			require.False(t, result.report.NeedFlush(),
+				"round %d: all data should be consistent, no report needed", round)
+			for clusterID, cr := range result.report.ClusterReports {
+				require.Empty(t, cr.TableFailureItems,
+					"round %d, cluster %s: should have no failures", round, clusterID)
+			}
+		}
+	}
+}
+
+// TestIntegration_AllConsistent_CrossRoundReplicatedRecords verifies that the checker
+// treats data as consistent when a locally-written record's commitTs exceeds the
+// round's checkpointTs, and the matching replicated records only appear in the next
+// round.
+//
+// This occurs when locally-written records' commitTs happen late in the time window, after
+// the checkpoint has already been determined. For TW[2], records with
+// commitTs > checkpointTs are deferred (skipped). In the next round they
+// become TW[1], where the check condition is commitTs > checkpointTs (checked),
+// and the replicated records are searched in TW[1] + TW[2] — finding the match in
+// the current round's TW[2].
+func TestIntegration_AllConsistent_CrossRoundReplicatedRecords(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + + // Offset to place commitTs between checkpointTs and rightBoundary. + // With pdStep=100 and 2 clusters, each round's time window spans + // approximately ComposeTS(300, 0) = 78643200, and checkpointTs sits + // at roughly ComposeTS(200, 0) from leftBoundary. + // Using ComposeTS(250, 0) = 65536000 lands safely between them. + crossRoundOffset := uint64(250 << 18) // ComposeTS(250, 0) = 65536000 + + var lateLocallyWrittenRecordsCommitTs uint64 + + for round := 0; round < 7; round++ { + cts := prevMaxRB + 1 + + var c1, c2 []byte + + switch round { + case 4: + // Round N: c1 local has two records: + // pk=round+1 normal commitTs (checked in this round's TW[2]) + // pk=100 large commitTs > checkpointTs + // (deferred in TW[2], checked via TW[1] next round) + // c2 replicated only matches pk=round+1. + lateLocallyWrittenRecordsCommitTs = prevMaxRB + crossRoundOffset + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, lateLocallyWrittenRecordsCommitTs, 0, "late"), + ) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + + case 5: + // Round N+1: c2 now includes the replicated record for pk=100. + // The checker evaluates TW[1] (= round 4), finds pk=100 with + // commitTs > checkpointTs, and searches c2's TW[1] + TW[2]. + // pk=100's matching replicated record is in c2's TW[2] (this round). 
+ c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, cts+2, lateLocallyWrittenRecordsCommitTs, "late"), + ) + + default: + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: c1 TW=[%d, %d], cpTs=%v, commitTs=%d", + round, + result.twData["c1"].TimeWindow.LeftBoundary, + result.twData["c1"].TimeWindow.RightBoundary, + result.twData["c1"].TimeWindow.CheckpointTs, + cts) + + if round == 4 { + // Verify the late commitTs falls between checkpointTs and rightBoundary. + c1TW := result.twData["c1"].TimeWindow + cpTs := c1TW.CheckpointTs["c2"] + require.Greater(t, lateLocallyWrittenRecordsCommitTs, cpTs, + "lateLocallyWrittenRecordsCommitTs must be > checkpointTs for cross-round detection") + require.LessOrEqual(t, lateLocallyWrittenRecordsCommitTs, c1TW.RightBoundary, + "lateLocallyWrittenRecordsCommitTs must be <= rightBoundary to stay in this time window") + t.Logf("Round 4 verification: lateCommitTs=%d, checkpointTs=%d, rightBoundary=%d", + lateLocallyWrittenRecordsCommitTs, cpTs, c1TW.RightBoundary) + } + + if round >= 3 { + require.Len(t, result.report.ClusterReports, 2, "round %d", round) + require.False(t, result.report.NeedFlush(), + "round %d: data should be consistent (cross-round matching should work)", round) + for clusterID, cr := range result.report.ClusterReports { + require.Empty(t, cr.TableFailureItems, + "round %d, cluster %s: should have no failures", round, clusterID) + } + } + } +} + +// TestIntegration_AllConsistent_LWWSkippedReplicatedRecords verifies that no errors +// are reported when a replicated record is "LWW-skipped" during data-loss +// detection, combined with cross-time-window matching. 
+// +// pk=100: single-cluster overwrite (c1 writes old+new, c2 only has newer replicated records) +// +// Round N: c1 locally-written records pk=100 × 2 (commitTs=A, B; both > checkpointTs) +// c2 has NO replicated records for pk=100 +// Round N+1: c2 replicated records pk=100 (originTs=B, matches newer locally-written records only) +// → old locally-written records LWW-skipped (c2 replicated records compareTs=B >= A) +// +// pk=200: bidirectional write (c1 and c2 both write the same pk) +// +// Round N: c1 locally-written records pk=200 (commitTs=A, deferred) +// Round N+1: c1 locally-written records pk=200 (commitTs=E, newer), c1 replicated records pk=200 (originTs=D, from c2) +// c2 replicated records pk=200 (commitTs=D, D < E), c2 replicated records pk=200 (originTs=E, from c1) +// +// Key constraint: c1 local commitTs (E) > c2 local commitTs (D). +// This ensures that on c2, the replicated (compareTs=E) > local (compareTs=D), +// so the LWW violation checker sees monotonically increasing compareTs. +// +// c1 data loss for old pk=200 (commitTs=A): +// → c2 replicated has originTs=E, compareTs=E >= A → LWW-skipped ✓ +// c1 data loss for new pk=200 (commitTs=E): +// → c2 replicated has originTs=E → exact match ✓ +// c2 data loss for c2 local pk=200 (commitTs=D): +// → c1 replicated has originTs=D → exact match ✓ +func TestIntegration_AllConsistent_LWWSkippedReplicatedRecords(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + + // Place both commitTs values between checkpointTs and rightBoundary. + // With pdStep=100 and 2 clusters: + // window width ≈ ComposeTS(300, 0), checkpointTs ≈ leftBoundary + ComposeTS(200, 0) + // Using ComposeTS(250, 0) puts us safely in the gap. 
+ crossRoundOffset := uint64(250 << 18) // ComposeTS(250, 0) = 65536000 + + var oldCommitTs, newCommitTs uint64 + + for round := 0; round < 7; round++ { + cts := prevMaxRB + 1 + + var c1, c2 []byte + + switch round { + case 4: + // Round N: c1 local writes pk=100 twice + pk=200 once, all > checkpointTs. + // c2 has NO replicated record for pk=100 or pk=200; they arrive next round. + oldCommitTs = prevMaxRB + crossRoundOffset + newCommitTs = oldCommitTs + 5 + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, oldCommitTs, 0, "old_write"), + MakeCanalJSON(100, newCommitTs, 0, "new_write"), + MakeCanalJSON(200, oldCommitTs, 0, "old_write"), + ) + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + ) + + case 5: + // Round N+1: replicated data arrives for both pk=100 and pk=200. + // + // pk=200 bidirectional: c1 local at cts+5 (> c2 local at cts+2) + // ensures c2's LWW check sees increasing compareTs. + // c1: replicated(commitTs=cts+4, originTs=cts+2) then local(commitTs=cts+5) + // → compareTs order: cts+2 < cts+5 ✓ + // c2: local(commitTs=cts+2) then replicated(commitTs=cts+6, originTs=cts+5) + // → compareTs order: cts+2 < cts+5 ✓ + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(200, cts+4, cts+2, "pk200_c2"), // c1 replicated pk=200 from c2 + MakeCanalJSON(200, cts+5, 0, "pk200_c1"), // c1 local pk=200 (newer) + ) + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, cts+2, newCommitTs, "new_write"), + MakeCanalJSON(200, cts+2, 0, "pk200_c2"), // c2 local pk=200 + MakeCanalJSON(200, cts+6, cts+5, "pk200_c1"), // c2 replicated pk=200 from c1 + ) + + default: + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = 
maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round == 4 { + // Verify both commitTs fall between checkpointTs and rightBoundary. + c1TW := result.twData["c1"].TimeWindow + cpTs := c1TW.CheckpointTs["c2"] + require.Greater(t, oldCommitTs, cpTs, + "oldCommitTs must be > checkpointTs for cross-round deferral") + require.LessOrEqual(t, newCommitTs, c1TW.RightBoundary, + "newCommitTs must be <= rightBoundary to stay in this time window") + t.Logf("Round 4 verification: oldCommitTs=%d, newCommitTs=%d, checkpointTs=%d, rightBoundary=%d", + oldCommitTs, newCommitTs, cpTs, c1TW.RightBoundary) + } + + if round >= 3 { + require.Len(t, result.report.ClusterReports, 2, "round %d", round) + require.False(t, result.report.NeedFlush(), + "round %d: cross-round LWW-skipped replicated should not cause errors", round) + for clusterID, cr := range result.report.ClusterReports { + require.Empty(t, cr.TableFailureItems, + "round %d, cluster %s: should have no failures", round, clusterID) + } + } + } +} + +// TestIntegration_DataLoss verifies that the checker detects data loss +// when a locally-written record has no matching replicated record in the other cluster. 
+func TestIntegration_DataLoss(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + dataLossDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + // c1 always produces local data + c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + + var c2 []byte + if round == 4 { + // Round 4: c2 has NO matching replicated record → data loss expected + // (round 4's data is checked in the same round since checkableRound >= 3) + c2 = nil + } else { + // Normal: c2 has matching replicated record + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c1Report := result.report.ClusterReports["c1"] + if c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + if len(items.DataLossItems) > 0 { + t.Logf("Round %d: detected data loss: %+v", round, items.DataLossItems) + dataLossDetected = true + // Verify the data loss item + for _, item := range items.DataLossItems { + require.Equal(t, "c2", item.PeerClusterID) + } + } + } + } + } + } + + require.True(t, dataLossDetected, "data loss should have been detected") +} + +// TestIntegration_DataInconsistent verifies that the checker detects data +// inconsistency when a replicated record has different column values +// from the locally-written record. 
+func TestIntegration_DataInconsistent(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + inconsistentDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + + var c2 []byte + if round == 4 { + // Round 4: c2 has replicated record with WRONG column value + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, "WRONG_VALUE")) + } else { + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c1Report := result.report.ClusterReports["c1"] + if c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + for _, item := range items.DataInconsistentItems { + t.Logf("Round %d: detected data inconsistency: %+v", round, item) + inconsistentDetected = true + require.Equal(t, "c2", item.PeerClusterID) + } + } + } + } + } + + require.True(t, inconsistentDetected, "data inconsistency should have been detected") +} + +// TestIntegration_DataRedundant verifies that the checker detects redundant +// replicated data that has no matching locally-written record. +func TestIntegration_DataRedundant(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + redundantDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 := MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + + if round == 4 { + // Round 4: c2 has an EXTRA replicated record (pk=999) with a fake + // originTs that doesn't match any c1 local commitTs. 
+ fakeOriginTs := cts - 5 // Doesn't match any c1 local commitTs + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + MakeCanalJSON(999, cts+2, fakeOriginTs, "extra"), + ) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c2Report := result.report.ClusterReports["c2"] + if c2Report != nil { + if items, ok := c2Report.TableFailureItems[schemaKey]; ok { + if len(items.DataRedundantItems) > 0 { + t.Logf("Round %d: detected data redundant: %+v", round, items.DataRedundantItems) + redundantDetected = true + } + } + } + } + } + + require.True(t, redundantDetected, "data redundancy should have been detected") +} + +// TestIntegration_LWWViolation verifies that the checker detects Last Write Wins +// violations when records for the same primary key have non-monotonic origin timestamps. +func TestIntegration_LWWViolation(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + lwwViolationDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + var c1, c2 []byte + + if round == 4 { + // Round 4: inject LWW violation in c1. + // Record A: pk=5, commitTs=cts, originTs=0 → compareTs = cts + // Record B: pk=5, commitTs=cts+2, originTs=cts-10 → compareTs = cts-10 + // Since A's compareTs (cts) >= B's compareTs (cts-10) and A's commitTs < B's commitTs, + // this is a Last Write Wins violation. 
+			c1 = MakeContent(
+				MakeCanalJSON(5, cts, 0, "original"),
+				MakeCanalJSON(5, cts+2, cts-10, "replicated"),
+			)
+			// c2: provide matching replicated record to avoid data loss noise
+			c2 = MakeContent(
+				MakeCanalJSON(5, cts+1, cts, "original"),
+			)
+		} else {
+			c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)))
+			c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)))
+		}
+
+		result := env.executeRound(t, c1, c2)
+		prevMaxRB = maxRightBoundary(result.twData)
+
+		t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts)
+
+		if round >= 3 && result.report.NeedFlush() {
+			c1Report := result.report.ClusterReports["c1"]
+			if c1Report != nil {
+				if items, ok := c1Report.TableFailureItems[schemaKey]; ok {
+					if len(items.LWWViolationItems) > 0 {
+						t.Logf("Round %d: detected LWW violation: %+v", round, items.LWWViolationItems)
+						lwwViolationDetected = true
+					}
+				}
+			}
+		}
+	}
+
+	require.True(t, lwwViolationDetected, "LWW violation should have been detected")
+}
+
+// TestIntegration_LWWViolation_AcrossRounds verifies that the checker detects
+// LWW violations when conflicting records for the same pk appear in rounds N
+// and N+2, with no data for that pk in round N+1.
+//
+// The clusterViolationChecker keeps cache entries for up to 3 rounds
+// (previous: 0 → 1 → 2). Since Check runs before UpdateCache, an entry
+// created in round N (previous=0) is still available at previous=2 when
+// round N+2 runs.
+//
+// Timeline:
+//
+//	Round N:   c1 local pk=50 (originTs=0, compareTs=A) → cached
+//	Round N+1: no pk=50 data → cache ages (prev 1→2)
+//	Round N+2: c1 replicated pk=50 (originTs=B, with B < A) → old.compareTs (A) >= new.compareTs (B)
+//	→ LWW violation across 2-round gap.
+ violatingOriginTs := firstRecordCommitTs - 10 + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(50, cts+2, violatingOriginTs, "second"), + ) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + + default: + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c1Report := result.report.ClusterReports["c1"] + if c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + if len(items.LWWViolationItems) > 0 { + t.Logf("Round %d: LWW violation across rounds: %+v", + round, items.LWWViolationItems) + lwwViolationDetected = true + // Verify the violation details + item := items.LWWViolationItems[0] + require.Equal(t, uint64(0), item.ExistingOriginTS, + "existing record should be local (originTs=0)") + require.Equal(t, firstRecordCommitTs, item.ExistingCommitTS, + "existing record should be from round N") + } + } + } + } + } + + require.True(t, lwwViolationDetected, + "LWW violation across round N and N+2 should have been detected") +} + +// TestIntegration_MultipleErrorTypes verifies that the checker can detect +// multiple error types simultaneously across different clusters and rounds. 
+func TestIntegration_MultipleErrorTypes(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + dataLossDetected := false + redundantDetected := false + + for round := 0; round < 7; round++ { + cts := prevMaxRB + 1 + + var c1, c2 []byte + + switch round { + case 4: + // Data loss: c1 local pk=5, c2 has NO replicated record + c1 = MakeContent(MakeCanalJSON(5, cts, 0, "lost")) + c2 = nil + case 5: + // Data redundant: c2 has extra replicated pk=888 + c1 = MakeContent(MakeCanalJSON(6, cts, 0, "normal")) + fakeOriginTs := cts - 3 + c2 = MakeContent( + MakeCanalJSON(6, cts+1, cts, "normal"), + MakeCanalJSON(888, cts+2, fakeOriginTs, "ghost"), + ) + default: + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d, ClusterReports=%d", + round, result.report.NeedFlush(), cts, len(result.report.ClusterReports)) + + if round >= 3 && result.report.NeedFlush() { + // Check c1 for data loss + if c1Report := result.report.ClusterReports["c1"]; c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + if len(items.DataLossItems) > 0 { + dataLossDetected = true + t.Logf("Round %d: data loss detected in c1: %d items", + round, len(items.DataLossItems)) + } + } + } + // Check c2 for data redundant + if c2Report := result.report.ClusterReports["c2"]; c2Report != nil { + if items, ok := c2Report.TableFailureItems[schemaKey]; ok { + if len(items.DataRedundantItems) > 0 { + redundantDetected = true + t.Logf("Round %d: data redundant detected in c2: %d items", + round, len(items.DataRedundantItems)) + } + } + } + } + } + + require.True(t, dataLossDetected, "data loss should have been detected") + require.True(t, redundantDetected, "data redundancy should have been detected") 
+} diff --git a/cmd/multi-cluster-consistency-checker/integration/mock_cluster.go b/cmd/multi-cluster-consistency-checker/integration/mock_cluster.go new file mode 100644 index 0000000000..96216e09b7 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/integration/mock_cluster.go @@ -0,0 +1,205 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "context" + "fmt" + "strings" + "sync/atomic" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher" + "github.com/pingcap/tidb/br/pkg/storage" + pd "github.com/tikv/pd/client" +) + +// mockPDClient simulates a PD server's TSO service. +// Each GetTS call returns a monotonically increasing physical timestamp. +type mockPDClient struct { + pd.Client + seq atomic.Int64 + base int64 // base physical time in milliseconds + step int64 // increment per call in milliseconds +} + +func (m *mockPDClient) GetTS(_ context.Context) (int64, int64, error) { + n := m.seq.Add(1) + return m.base + n*m.step, 0, nil +} + +func (m *mockPDClient) Close() {} + +// ---------- Mock Checkpoint Watcher ---------- + +// mockWatcher simulates a checkpoint watcher. +// It always returns minCheckpointTs + delta, ensuring the result exceeds the minimum. 
+type mockWatcher struct { + delta uint64 +} + +func (m *mockWatcher) AdvanceCheckpointTs(_ context.Context, minCheckpointTs uint64) (uint64, error) { + return minCheckpointTs + m.delta, nil +} + +func (m *mockWatcher) Close() {} + +// MockMultiCluster manages the mock infrastructure for simulating multiple +// TiCDC clusters. It provides: +// - Mock PD clients for TSO generation +// - Mock checkpoint watchers for inter-cluster replication checkpoints +// - Mock S3 checkpoint watchers and S3 watchers for cloud storage +// - In-memory S3 storage for each cluster +// - Helpers to write canal-JSON formatted data files +type MockMultiCluster struct { + ClusterIDs []string + Tables map[string][]string // schema -> table names + + S3Storages map[string]storage.ExternalStorage + pdClients map[string]*mockPDClient + CPWatchers map[string]map[string]watcher.Watcher + S3Watchers map[string]*watcher.S3Watcher + + // fileCounters tracks the next DML file index per cluster. + // Files are written with monotonically increasing indices so the + // S3Consumer discovers only new files in each round. + fileCounters map[string]uint64 + + date string // fixed date used in all DML file paths +} + +// NewMockMultiCluster creates a new mock multi-cluster environment. +// +// Parameters: +// - clusterIDs: identifiers for the clusters (e.g. ["c1", "c2"]) +// - tables: schema -> table names mapping (e.g. 
{"test": ["t1"]}) +// - pdBase: base physical time (ms) for mock PD TSO generation +// - pdStep: physical time increment (ms) per GetTS call +// - cpDelta: checkpoint watcher returns minCheckpointTs + cpDelta +// - s3Delta: S3 checkpoint watcher returns minCheckpointTs + s3Delta +func NewMockMultiCluster( + clusterIDs []string, + tables map[string][]string, + pdBase, pdStep int64, + cpDelta, s3Delta uint64, +) *MockMultiCluster { + mc := &MockMultiCluster{ + ClusterIDs: clusterIDs, + Tables: tables, + S3Storages: make(map[string]storage.ExternalStorage), + pdClients: make(map[string]*mockPDClient), + CPWatchers: make(map[string]map[string]watcher.Watcher), + S3Watchers: make(map[string]*watcher.S3Watcher), + fileCounters: make(map[string]uint64), + date: "2026-02-11", + } + + for _, id := range clusterIDs { + mc.S3Storages[id] = storage.NewMemStorage() + mc.pdClients[id] = &mockPDClient{base: pdBase, step: pdStep} + + // Checkpoint watchers: one per replicated cluster + watchers := make(map[string]watcher.Watcher) + for _, other := range clusterIDs { + if other != id { + watchers[other] = &mockWatcher{delta: cpDelta} + } + } + mc.CPWatchers[id] = watchers + + // S3 watcher: uses in-memory storage + mock checkpoint watcher + s3CpWatcher := &mockWatcher{delta: s3Delta} + mc.S3Watchers[id] = watcher.NewS3Watcher( + s3CpWatcher, + mc.S3Storages[id], + tables, + ) + } + + return mc +} + +// InitSchemaFiles writes initial schema files for all tables in all clusters. +// The schema file content is empty (parser is nil in the current implementation). 
+func (mc *MockMultiCluster) InitSchemaFiles(ctx context.Context) error { + for _, s3 := range mc.S3Storages { + for schema, tableList := range mc.Tables { + for _, table := range tableList { + path := fmt.Sprintf("%s/%s/meta/schema_1_0000000000.json", schema, table) + if err := s3.WriteFile(ctx, path, []byte("{}")); err != nil { + return err + } + } + } + } + return nil +} + +// WriteDMLFile writes a canal-JSON DML file to a cluster's S3 storage. +// Each call increments the file index for that cluster, ensuring the +// S3Consumer discovers it as a new file. +func (mc *MockMultiCluster) WriteDMLFile(ctx context.Context, clusterID string, content []byte) error { + mc.fileCounters[clusterID]++ + idx := mc.fileCounters[clusterID] + for schema, tableList := range mc.Tables { + for _, table := range tableList { + path := fmt.Sprintf("%s/%s/1/%s/CDC%020d.json", schema, table, mc.date, idx) + if err := mc.S3Storages[clusterID].WriteFile(ctx, path, content); err != nil { + return err + } + } + } + return nil +} + +// GetPDClients returns mock PD clients as the pd.Client interface. +func (mc *MockMultiCluster) GetPDClients() map[string]pd.Client { + clients := make(map[string]pd.Client) + for id, c := range mc.pdClients { + clients[id] = c + } + return clients +} + +// Close closes all S3 watchers. +func (mc *MockMultiCluster) Close() { + for _, sw := range mc.S3Watchers { + sw.Close() + } +} + +// MakeCanalJSON builds a canal-JSON formatted record for testing. 
+// +// Parameters: +// - pkID: primary key value (int column "id") +// - commitTs: TiDB commit timestamp +// - originTs: origin timestamp (0 for locally-written records, non-zero for replicated records) +// - val: value for the "val" varchar column +func MakeCanalJSON(pkID int, commitTs uint64, originTs uint64, val string) string { + originTsVal := "null" + if originTs > 0 { + originTsVal = fmt.Sprintf(`"%d"`, originTs) + } + return fmt.Sprintf( + `{"id":0,"database":"test","table":"t1","pkNames":["id"],"isDdl":false,"type":"INSERT",`+ + `"es":0,"ts":0,"sql":"","sqlType":{"id":4,"val":12,"_tidb_origin_ts":-5},`+ + `"mysqlType":{"id":"int","val":"varchar","_tidb_origin_ts":"bigint"},`+ + `"old":null,"data":[{"id":"%d","val":"%s","_tidb_origin_ts":%s}],`+ + `"_tidb":{"commitTs":%d}}`, + pkID, val, originTsVal, commitTs) +} + +// MakeContent combines canal-JSON records with CRLF terminator (matching codec config). +func MakeContent(records ...string) []byte { + return []byte(strings.Join(records, "\r\n")) +} diff --git a/cmd/multi-cluster-consistency-checker/integration/validation_test.go b/cmd/multi-cluster-consistency-checker/integration/validation_test.go new file mode 100644 index 0000000000..97a6ef94d8 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/integration/validation_test.go @@ -0,0 +1,315 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package integration + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/checker" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/stretchr/testify/require" +) + +const ( + envKeyCDCOutput = "ACTIVE_ACTIVE_FAILPOINT_CDC_OUTPUT" + + envKeyCheckerOutput = "ACTIVE_ACTIVE_FAILPOINT_CHECKER_OUTPUT" + + envKeyReportDir = "ACTIVE_ACTIVE_FAILPOINT_REPORT_DIR" +) + +func readRecordJSONL(path string) ([]checker.Record, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + var records []checker.Record + scanner := bufio.NewScanner(file) + // JSONL lines may contain large row sets; increase scanner limit to avoid + // "bufio.Scanner: token too long". + scanner.Buffer(make([]byte, 0, 1024*1024), 64*1024*1024) + lineNo := 0 + for scanner.Scan() { + lineNo++ + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var record checker.Record + err := json.Unmarshal([]byte(line), &record) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal jsonl record in %s line %d: %w", path, lineNo, err) + } + records = append(records, record) + } + + err = scanner.Err() + if err != nil { + return nil, fmt.Errorf("failed to scan jsonl file %s: %w", path, err) + } + return records, nil +} + +func getCdcRecords(t *testing.T, cdcOutputPath string) []checker.Record { + records, err := readRecordJSONL(cdcOutputPath) + require.NoError(t, err) + return records +} + +func getCheckerRecords(t *testing.T, checkerOutputPath string) []checker.Record { + records, err := readRecordJSONL(checkerOutputPath) + require.NoError(t, err) + return records +} + +func readReports(t *testing.T, reportDir string) []recorder.Report { + entries, err := os.ReadDir(reportDir) + require.NoError(t, err) + + var reportFiles []string + for _, entry := range entries { + if 
entry.IsDir() { + continue + } + if strings.HasSuffix(entry.Name(), ".json") { + reportFiles = append(reportFiles, filepath.Join(reportDir, entry.Name())) + } + } + sort.Strings(reportFiles) + + var reports []recorder.Report + for _, path := range reportFiles { + content, err := os.ReadFile(path) + require.NoError(t, err) + + var report recorder.Report + err = json.Unmarshal(content, &report) + require.NoError(t, err) + reports = append(reports, report) + } + return reports +} + +type dataLossKey struct { + PK string + LocalCommitTS uint64 +} + +type dataInconsistentKey struct { + PK string + LocalCommitTS uint64 +} + +type dataRedundantKey struct { + PK string + OriginTS uint64 +} + +type lwwViolationKey struct { + PK string + OriginTS uint64 +} + +func pkMapToKey(pk map[string]any) string { + keys := make([]string, 0, len(pk)) + for key := range pk { + keys = append(keys, key) + } + sort.Strings(keys) + parts := make([]string, 0, len(keys)) + for _, key := range keys { + parts = append(parts, fmt.Sprintf("%s=%s", key, normalizePKValue(pk[key]))) + } + return strings.Join(parts, ",") +} + +func normalizePKValue(v any) string { + switch value := v.(type) { + case string: + if u, err := strconv.ParseUint(value, 10, 64); err == nil { + return strconv.FormatUint(u, 10) + } + if i, err := strconv.ParseInt(value, 10, 64); err == nil { + return strconv.FormatInt(i, 10) + } + return value + case json.Number: + return value.String() + case float64: + return strconv.FormatFloat(value, 'f', -1, 64) + case float32: + return strconv.FormatFloat(float64(value), 'f', -1, 32) + case int: + return strconv.FormatInt(int64(value), 10) + case int8: + return strconv.FormatInt(int64(value), 10) + case int16: + return strconv.FormatInt(int64(value), 10) + case int32: + return strconv.FormatInt(int64(value), 10) + case int64: + return strconv.FormatInt(value, 10) + case uint: + return strconv.FormatUint(uint64(value), 10) + case uint8: + return strconv.FormatUint(uint64(value), 10) + 
case uint16: + return strconv.FormatUint(uint64(value), 10) + case uint32: + return strconv.FormatUint(uint64(value), 10) + case uint64: + return strconv.FormatUint(value, 10) + default: + return fmt.Sprintf("%v", value) + } +} + +func validate(t *testing.T, + cdcRecords, checkerRecords []checker.Record, + reports []recorder.Report, +) { + dataLossItems := make(map[dataLossKey]struct{}) + dataInconsistentItems := make(map[dataInconsistentKey]struct{}) + dataRedundantItems := make(map[dataRedundantKey]struct{}) + lwwViolationItems := make(map[lwwViolationKey]struct{}) + for _, report := range reports { + for _, clusterReport := range report.ClusterReports { + for _, tableFailureItems := range clusterReport.TableFailureItems { + for _, dataLossItem := range tableFailureItems.DataLossItems { + key := dataLossKey{ + PK: pkMapToKey(dataLossItem.PK), + LocalCommitTS: dataLossItem.LocalCommitTS, + } + dataLossItems[key] = struct{}{} + } + for _, dataInconsistentItem := range tableFailureItems.DataInconsistentItems { + key := dataInconsistentKey{ + PK: pkMapToKey(dataInconsistentItem.PK), + LocalCommitTS: dataInconsistentItem.LocalCommitTS, + } + dataInconsistentItems[key] = struct{}{} + } + for _, dataRedundantItem := range tableFailureItems.DataRedundantItems { + key := dataRedundantKey{ + PK: pkMapToKey(dataRedundantItem.PK), + OriginTS: dataRedundantItem.OriginTS, + } + dataRedundantItems[key] = struct{}{} + } + for _, lwwViolationItem := range tableFailureItems.LWWViolationItems { + key := lwwViolationKey{ + PK: pkMapToKey(lwwViolationItem.PK), + OriginTS: lwwViolationItem.OriginTS, + } + lwwViolationItems[key] = struct{}{} + } + } + } + } + skippedRecords := make(map[dataLossKey]struct{}) + for _, record := range checkerRecords { + for _, row := range record.Rows { + skippedRecords[dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.CommitTs}] = struct{}{} + } + } + for _, record := range cdcRecords { + for _, row := range record.Rows { + switch 
record.Failpoint { + case "cloudStorageSinkDropMessage": + if row.OriginTs > 0 { + key := dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.OriginTs} + if _, ok := skippedRecords[key]; !ok { + _, ok := dataLossItems[key] + require.True(t, ok) + delete(dataLossItems, key) + } + } else { + // the replicated record maybe skipped by LWW + key := dataRedundantKey{PK: pkMapToKey(row.PrimaryKeys), OriginTS: row.CommitTs} + _, ok := dataRedundantItems[key] + if !ok { + t.Log("replicated record maybe skipped by LWW", key) + } + delete(dataRedundantItems, key) + } + case "cloudStorageSinkMutateValue": + keyLoss := dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.CommitTs} + if _, skipped := skippedRecords[keyLoss]; !skipped { + key := dataInconsistentKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.CommitTs} + _, ok := dataInconsistentItems[key] + require.True(t, ok) + delete(dataInconsistentItems, key) + } + case "cloudStorageSinkMutateValueTidbOriginTs": + keyLoss := dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.OriginTs} + if _, ok := skippedRecords[keyLoss]; !ok { + _, ok := dataLossItems[keyLoss] + require.True(t, ok) + delete(dataLossItems, keyLoss) + } + + keyRedundant := dataRedundantKey{PK: pkMapToKey(row.PrimaryKeys), OriginTS: row.OriginTs + 1} + _, ok := dataRedundantItems[keyRedundant] + require.True(t, ok) + delete(dataRedundantItems, keyRedundant) + } + } + } + require.Empty(t, dataLossItems) + require.Empty(t, dataInconsistentItems) + require.Empty(t, dataRedundantItems) + require.Empty(t, lwwViolationItems) + require.Fail(t, "success") +} + +func TestValidation(t *testing.T) { + var ( + cdcOutputPath string + checkerOutputPath string + reportDir string + ) + + cdcOutputPath = os.Getenv(envKeyCDCOutput) + if cdcOutputPath == "" { + t.Log("skipped because ACTIVE_ACTIVE_FAILPOINT_CDC_OUTPUT is not set") + return + } + checkerOutputPath = os.Getenv(envKeyCheckerOutput) + if checkerOutputPath == "" { + 
t.Log("skipped because ACTIVE_ACTIVE_FAILPOINT_CHECKER_OUTPUT is not set") + return + } + reportDir = os.Getenv(envKeyReportDir) + if reportDir == "" { + t.Log("skipped because ACTIVE_ACTIVE_FAILPOINT_REPORT_DIR is not set") + return + } + + cdcRecords := getCdcRecords(t, cdcOutputPath) + checkerRecords := getCheckerRecords(t, checkerOutputPath) + reports := readReports(t, reportDir) + + validate(t, cdcRecords, checkerRecords, reports) +} diff --git a/cmd/multi-cluster-consistency-checker/main.go b/cmd/multi-cluster-consistency-checker/main.go new file mode 100644 index 0000000000..7f2390f99e --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/main.go @@ -0,0 +1,160 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/logger" + "github.com/spf13/cobra" + "go.uber.org/zap" +) + +var ( + cfgPath string + dryRun bool +) + +// Exit codes for multi-cluster-consistency-checker. 
//
// 0 – clean shutdown (normal exit or graceful signal handling)
// 1 – transient error, safe to restart (network, I/O, temporary failures)
// 2 – invalid configuration (missing required flags / fields)
// 3 – configuration decode failure (malformed config file)
// 4 – checkpoint corruption, requires manual intervention
// 5 – unrecoverable internal error
const (
	ExitCodeTransient            = iota + 1 // transient failure, safe to restart
	ExitCodeInvalidConfig                   // required flags / fields missing
	ExitCodeDecodeConfigFailed              // config file could not be decoded
	ExitCodeCheckpointCorruption            // persisted checkpoint is corrupt
	ExitCodeUnrecoverable                   // internal error
)

// ExitError pairs an underlying error with the process exit status the
// program should terminate with, so that callers higher in the stack can
// translate domain errors into the correct exit code.
type ExitError struct {
	Code int   // process exit status to report
	Err  error // underlying cause
}

// Error implements the error interface by delegating to the wrapped error.
func (e *ExitError) Error() string {
	return e.Err.Error()
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *ExitError) Unwrap() error {
	return e.Err
}

// exitCodeFromError extracts the exit code from an error.
// If the error is an *ExitError the embedded code is returned;
// otherwise the fallback code is returned.
+func exitCodeFromError(err error, fallback int) int { + var ee *ExitError + if errors.As(err, &ee) { + return ee.Code + } + return fallback +} + +const ( + FlagConfig = "config" + FlagDryRun = "dry-run" +) + +func main() { + rootCmd := &cobra.Command{ + Use: "multi-cluster-consistency-checker", + Short: "A tool to check consistency across multiple TiCDC clusters", + Long: "A tool to check consistency across multiple TiCDC clusters by comparing data from different clusters' S3 sink locations", + Run: run, + } + + rootCmd.Flags().StringVarP(&cfgPath, FlagConfig, "c", "", "configuration file path (required)") + rootCmd.MarkFlagRequired(FlagConfig) + rootCmd.Flags().BoolVar(&dryRun, FlagDryRun, false, "validate config and connectivity without running the checker") + + if err := rootCmd.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(ExitCodeUnrecoverable) + } +} + +func run(cmd *cobra.Command, args []string) { + if cfgPath == "" { + fmt.Fprintln(os.Stderr, "error: --config flag is required") + os.Exit(ExitCodeInvalidConfig) + } + + cfg, err := config.LoadConfig(cfgPath) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err) + os.Exit(ExitCodeDecodeConfigFailed) + } + + // Initialize logger with configured log level + logLevel := cfg.GlobalConfig.LogLevel + if logLevel == "" { + logLevel = "info" // default log level + } + loggerConfig := &logger.Config{ + Level: logLevel, + } + err = logger.InitLogger(loggerConfig) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err) + os.Exit(ExitCodeUnrecoverable) + } + log.Info("Logger initialized", zap.String("level", logLevel)) + + // Create a context that can be cancelled by signals + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set up signal handling for graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + // Start the task in a goroutine + errChan := 
make(chan error, 1) + go func() { + err := runTask(ctx, cfg, dryRun) + if err != nil { + log.Error("task error", zap.Error(err)) + } + errChan <- err + }() + + // Wait for either a signal or task completion + select { + case sig := <-sigChan: + fmt.Fprintf(os.Stdout, "\nReceived signal: %v, shutting down gracefully...\n", sig) + cancel() + // Wait for the task to finish + if err := <-errChan; err != nil && !errors.Is(err, context.Canceled) { + fmt.Fprintf(os.Stderr, "task error during shutdown: %v\n", err) + code := exitCodeFromError(err, ExitCodeTransient) + os.Exit(code) + } + fmt.Fprintf(os.Stdout, "Shutdown complete\n") + case err := <-errChan: + if err != nil { + fmt.Fprintf(os.Stderr, "failed to run task: %v\n", err) + code := exitCodeFromError(err, ExitCodeTransient) + os.Exit(code) + } + } +} diff --git a/cmd/multi-cluster-consistency-checker/main_test.go b/cmd/multi-cluster-consistency-checker/main_test.go new file mode 100644 index 0000000000..ee3fc71f23 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/main_test.go @@ -0,0 +1,192 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "fmt" + "testing" + + "github.com/pingcap/ticdc/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestExitError_Error(t *testing.T) { + t.Parallel() + inner := fmt.Errorf("something went wrong") + ee := &ExitError{Code: ExitCodeTransient, Err: inner} + require.Equal(t, "something went wrong", ee.Error()) +} + +func TestExitError_Unwrap(t *testing.T) { + t.Parallel() + inner := fmt.Errorf("root cause") + ee := &ExitError{Code: ExitCodeCheckpointCorruption, Err: inner} + require.ErrorIs(t, ee, inner) + require.Equal(t, inner, ee.Unwrap()) +} + +func TestExitError_Unwrap_deep(t *testing.T) { + t.Parallel() + root := errors.New("root") + wrapped := fmt.Errorf("layer1: %w", root) + ee := &ExitError{Code: ExitCodeTransient, Err: wrapped} + require.ErrorIs(t, ee, root) +} + +func TestValidateS3BucketPrefix(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + changefeedURI string + configURI string + wantErr bool + errContains string + }{ + { + name: "exact match", + changefeedURI: "s3://my-bucket/cluster1/", + configURI: "s3://my-bucket/cluster1/", + wantErr: false, + }, + { + name: "match ignoring trailing slash", + changefeedURI: "s3://my-bucket/cluster1", + configURI: "s3://my-bucket/cluster1/", + wantErr: false, + }, + { + name: "match with query params in changefeed URI", + changefeedURI: "s3://my-bucket/prefix/?protocol=canal-json&date-separator=day", + configURI: "s3://my-bucket/prefix/", + wantErr: false, + }, + { + name: "bucket mismatch", + changefeedURI: "s3://bucket-a/prefix/", + configURI: "s3://bucket-b/prefix/", + wantErr: true, + errContains: "bucket/prefix mismatch", + }, + { + name: "prefix mismatch", + changefeedURI: "s3://my-bucket/cluster1/", + configURI: "s3://my-bucket/cluster2/", + wantErr: true, + errContains: "bucket/prefix mismatch", + }, + { + name: "scheme mismatch", + changefeedURI: "gcs://my-bucket/prefix/", + configURI: "s3://my-bucket/prefix/", + wantErr: true, + errContains: 
"bucket/prefix mismatch", + }, + { + name: "deeper prefix mismatch", + changefeedURI: "s3://my-bucket/a/b/c/", + configURI: "s3://my-bucket/a/b/d/", + wantErr: true, + errContains: "bucket/prefix mismatch", + }, + { + name: "empty config URI", + changefeedURI: "s3://my-bucket/prefix/", + configURI: "", + wantErr: true, + errContains: "bucket/prefix mismatch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateS3BucketPrefix(tt.changefeedURI, tt.configURI, "test-cluster", "cf-1") + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestExitCodeFromError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + fallback int + expected int + }{ + { + name: "ExitError with transient code", + err: &ExitError{Code: ExitCodeTransient, Err: fmt.Errorf("timeout")}, + fallback: ExitCodeUnrecoverable, + expected: ExitCodeTransient, + }, + { + name: "ExitError with checkpoint corruption code", + err: &ExitError{Code: ExitCodeCheckpointCorruption, Err: fmt.Errorf("bad checkpoint")}, + fallback: ExitCodeTransient, + expected: ExitCodeCheckpointCorruption, + }, + { + name: "ExitError with unrecoverable code", + err: &ExitError{Code: ExitCodeUnrecoverable, Err: fmt.Errorf("fatal")}, + fallback: ExitCodeTransient, + expected: ExitCodeUnrecoverable, + }, + { + name: "ExitError with invalid config code", + err: &ExitError{Code: ExitCodeInvalidConfig, Err: fmt.Errorf("missing field")}, + fallback: ExitCodeTransient, + expected: ExitCodeInvalidConfig, + }, + { + name: "ExitError with decode config failed code", + err: &ExitError{Code: ExitCodeDecodeConfigFailed, Err: fmt.Errorf("bad toml")}, + fallback: ExitCodeTransient, + expected: ExitCodeDecodeConfigFailed, + }, + { + name: "plain error returns fallback", + err: fmt.Errorf("some plain error"), + fallback: ExitCodeTransient, + expected: 
ExitCodeTransient, + }, + { + name: "plain error returns different fallback", + err: fmt.Errorf("another plain error"), + fallback: ExitCodeUnrecoverable, + expected: ExitCodeUnrecoverable, + }, + { + name: "wrapped ExitError is still extracted", + err: fmt.Errorf("outer: %w", &ExitError{Code: ExitCodeCheckpointCorruption, Err: fmt.Errorf("inner")}), + fallback: ExitCodeTransient, + expected: ExitCodeCheckpointCorruption, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + code := exitCodeFromError(tt.err, tt.fallback) + require.Equal(t, tt.expected, code) + }) + } +} diff --git a/cmd/multi-cluster-consistency-checker/recorder/recorder.go b/cmd/multi-cluster-consistency-checker/recorder/recorder.go new file mode 100644 index 0000000000..06fb7c69bd --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/recorder/recorder.go @@ -0,0 +1,270 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package recorder + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" + "go.uber.org/zap" +) + +// ErrCheckpointCorruption is a sentinel error indicating that the persisted +// checkpoint data is corrupted and requires manual intervention to fix. 
+var ErrCheckpointCorruption = errors.New("checkpoint corruption") + +type Recorder struct { + reportDir string + checkpointDir string + maxReportFiles int + + // reportFiles caches the sorted list of report file names in reportDir. + // Updated in-memory after flush and cleanup to avoid repeated os.ReadDir calls. + reportFiles []string + + checkpoint *Checkpoint +} + +func NewRecorder(dataDir string, clusters map[string]config.ClusterConfig, maxReportFiles int) (*Recorder, error) { + if err := os.MkdirAll(filepath.Join(dataDir, "report"), 0o755); err != nil { + return nil, errors.Trace(err) + } + if err := os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755); err != nil { + return nil, errors.Trace(err) + } + if maxReportFiles <= 0 { + maxReportFiles = config.DefaultMaxReportFiles + } + + // Read existing report files once at startup + entries, err := os.ReadDir(filepath.Join(dataDir, "report")) + if err != nil { + return nil, errors.Trace(err) + } + reportFiles := make([]string, 0, len(entries)) + for _, entry := range entries { + if !entry.IsDir() { + reportFiles = append(reportFiles, entry.Name()) + } + } + sort.Strings(reportFiles) + + r := &Recorder{ + reportDir: filepath.Join(dataDir, "report"), + checkpointDir: filepath.Join(dataDir, "checkpoint"), + maxReportFiles: maxReportFiles, + reportFiles: reportFiles, + + checkpoint: NewCheckpoint(), + } + if err := r.initializeCheckpoint(); err != nil { + return nil, errors.Trace(err) + } + for _, item := range r.checkpoint.CheckpointItems { + if item == nil { + continue + } + if len(item.ClusterInfo) != len(clusters) { + return nil, errors.Annotatef(ErrCheckpointCorruption, "checkpoint item (round %d) cluster info length mismatch, expected %d, got %d", item.Round, len(clusters), len(item.ClusterInfo)) + } + for clusterID := range clusters { + if _, ok := item.ClusterInfo[clusterID]; !ok { + return nil, errors.Annotatef(ErrCheckpointCorruption, "checkpoint item (round %d) cluster info missing for cluster %s", 
item.Round, clusterID) + } + } + } + + return r, nil +} + +func (r *Recorder) GetCheckpoint() *Checkpoint { + return r.checkpoint +} + +func (r *Recorder) initializeCheckpoint() error { + checkpointFile := filepath.Join(r.checkpointDir, "checkpoint.json") + bakFile := filepath.Join(r.checkpointDir, "checkpoint.json.bak") + + // If checkpoint.json exists, use it directly. + if _, err := os.Stat(checkpointFile); err == nil { + data, err := os.ReadFile(checkpointFile) + if err != nil { + return errors.Trace(err) // transient I/O error + } + if err := json.Unmarshal(data, r.checkpoint); err != nil { + return errors.Annotatef(ErrCheckpointCorruption, "failed to unmarshal checkpoint.json: %v", err) + } + return nil + } else if !os.IsNotExist(err) { + return errors.Annotatef(ErrCheckpointCorruption, "failed to stat checkpoint.json: %v", err) + } + + // checkpoint.json is missing — try recovering from the backup. + // This can happen when the process crashed after rename but before the + // new temp file was renamed into place. + if _, err := os.Stat(bakFile); err == nil { + log.Warn("checkpoint.json not found, recovering from checkpoint.json.bak") + data, err := os.ReadFile(bakFile) + if err != nil { + return errors.Trace(err) // transient I/O error + } + if err := json.Unmarshal(data, r.checkpoint); err != nil { + return errors.Annotatef(ErrCheckpointCorruption, "failed to unmarshal checkpoint.json.bak: %v", err) + } + // Restore the backup as the primary file + if err := os.Rename(bakFile, checkpointFile); err != nil { + return errors.Trace(err) // transient I/O error + } + return nil + } else if !os.IsNotExist(err) { + return errors.Annotatef(ErrCheckpointCorruption, "failed to stat checkpoint.json.bak: %v", err) + } + + // Neither file exists — fresh start. 
+ return nil +} + +func (r *Recorder) RecordTimeWindow(timeWindowData map[string]types.TimeWindowData, report *Report) error { + for clusterID, timeWindow := range timeWindowData { + log.Info("time window advanced", + zap.Uint64("round", report.Round), + zap.String("clusterID", clusterID), + zap.Uint64("windowLeftBoundary", timeWindow.LeftBoundary), + zap.Uint64("windowRightBoundary", timeWindow.RightBoundary), + zap.Any("checkpointTs", timeWindow.CheckpointTs)) + } + if report.NeedFlush() { + if err := r.flushReport(report); err != nil { + return errors.Trace(err) + } + r.cleanupOldReports() + } + if err := r.flushCheckpoint(report.Round, timeWindowData); err != nil { + return errors.Trace(err) + } + return nil +} + +func (r *Recorder) flushReport(report *Report) error { + reportName := fmt.Sprintf("report-%d.report", report.Round) + if err := atomicWriteFile(filepath.Join(r.reportDir, reportName), []byte(report.MarshalReport())); err != nil { + return errors.Trace(err) + } + + jsonName := fmt.Sprintf("report-%d.json", report.Round) + dataBytes, err := json.Marshal(report) + if err != nil { + return errors.Trace(err) + } + if err := atomicWriteFile(filepath.Join(r.reportDir, jsonName), dataBytes); err != nil { + return errors.Trace(err) + } + + // Append new file names to the cache (they are always the latest, so append at end) + r.reportFiles = append(r.reportFiles, reportName, jsonName) + return nil +} + +// atomicWriteFile writes data to a temporary file, fsyncs it to ensure +// durability, and then atomically renames it to the target path. +// This prevents partial / corrupt files on crash. 
+func atomicWriteFile(targetPath string, data []byte) error { + tempPath := targetPath + ".tmp" + if err := syncWriteFile(tempPath, data); err != nil { + return errors.Trace(err) + } + if err := os.Rename(tempPath, targetPath); err != nil { + return errors.Trace(err) + } + return nil +} + +// syncWriteFile writes data to a file and fsyncs it before returning, +// guaranteeing that the content is durable on disk. +func syncWriteFile(path string, data []byte) error { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return errors.Trace(err) + } + if _, err := f.Write(data); err != nil { + f.Close() + return errors.Trace(err) + } + if err := f.Sync(); err != nil { + f.Close() + return errors.Trace(err) + } + return errors.Trace(f.Close()) +} + +// cleanupOldReports removes the oldest report files from the in-memory cache +// when the total number exceeds maxReportFiles * 2 (each round produces .report + .json). +func (r *Recorder) cleanupOldReports() { + if len(r.reportFiles) <= r.maxReportFiles*2 { + return + } + + toDelete := len(r.reportFiles) - r.maxReportFiles*2 + for i := 0; i < toDelete; i++ { + path := filepath.Join(r.reportDir, r.reportFiles[i]) + if err := os.Remove(path); err != nil { + log.Warn("failed to remove old report file", + zap.String("path", path), + zap.Error(err)) + } else { + log.Info("removed old report file", zap.String("path", path)) + } + } + r.reportFiles = r.reportFiles[toDelete:] +} + +func (r *Recorder) flushCheckpoint(round uint64, timeWindowData map[string]types.TimeWindowData) error { + r.checkpoint.NewTimeWindowData(round, timeWindowData) + + checkpointFile := filepath.Join(r.checkpointDir, "checkpoint.json") + bakFile := filepath.Join(r.checkpointDir, "checkpoint.json.bak") + tempFile := filepath.Join(r.checkpointDir, "checkpoint_temp.json") + + data, err := json.Marshal(r.checkpoint) + if err != nil { + return errors.Trace(err) + } + + // 1. 
Write the new content to a temp file first and fsync it. + if err := syncWriteFile(tempFile, data); err != nil { + return errors.Trace(err) + } + + // 2. Rename the existing checkpoint to .bak (ignore error if it doesn't exist yet). + if err := os.Rename(checkpointFile, bakFile); err != nil && !os.IsNotExist(err) { + return errors.Trace(err) + } + + // 3. Rename the temp file to be the new checkpoint. + if err := os.Rename(tempFile, checkpointFile); err != nil { + return errors.Trace(err) + } + + // 4. Remove the backup — no longer needed. + _ = os.Remove(bakFile) + + return nil +} diff --git a/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go b/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go new file mode 100644 index 0000000000..75f1209b9a --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go @@ -0,0 +1,561 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package recorder + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestNewRecorder(t *testing.T) { + t.Parallel() + + t.Run("creates directories", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + require.NotNil(t, r) + + // Verify directories exist + info, err := os.Stat(filepath.Join(dataDir, "report")) + require.NoError(t, err) + require.True(t, info.IsDir()) + + info, err = os.Stat(filepath.Join(dataDir, "checkpoint")) + require.NoError(t, err) + require.True(t, info.IsDir()) + }) + + t.Run("checkpoint is initialized empty", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + cp := r.GetCheckpoint() + require.NotNil(t, cp) + require.Nil(t, cp.CheckpointItems[0]) + require.Nil(t, cp.CheckpointItems[1]) + require.Nil(t, cp.CheckpointItems[2]) + }) + + t.Run("loads existing checkpoint on startup", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // First recorder: write a checkpoint + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + twData := map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "db", Table: "tbl"}: {Version: 1, VersionPath: "vp1", DataPath: "dp1"}, + }, + }, + } + report := NewReport(0) + err = r1.RecordTimeWindow(twData, report) + require.NoError(t, err) + + // Second recorder: should load the checkpoint + r2, err := NewRecorder(dataDir, 
map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + cp := r2.GetCheckpoint() + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(0), cp.CheckpointItems[2].Round) + info := cp.CheckpointItems[2].ClusterInfo["c1"] + require.Equal(t, uint64(1), info.TimeWindow.LeftBoundary) + require.Equal(t, uint64(10), info.TimeWindow.RightBoundary) + }) + + t.Run("cluster count mismatch rejects checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write a checkpoint with 2 clusters + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Try to load with only 1 cluster — should fail + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster info length mismatch") + }) + + t.Run("cluster ID missing rejects checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write a checkpoint with clusters c1 and c2 + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Try to load with c1 and c3 (same count, different ID) — should fail + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c3": {}}, 0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster info missing for cluster c3") + }) + + t.Run("matching 
clusters loads checkpoint successfully", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + clusters := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + + // Write checkpoint across 3 rounds so all 3 slots are filled + r1, err := NewRecorder(dataDir, clusters, 0) + require.NoError(t, err) + for i := range 3 { + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: uint64(i * 10), RightBoundary: uint64((i + 1) * 10)}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: uint64(i * 10), RightBoundary: uint64((i + 1) * 10)}}, + } + err = r1.RecordTimeWindow(twData, NewReport(uint64(i))) + require.NoError(t, err) + } + + // Reload with the same clusters — should succeed + r2, err := NewRecorder(dataDir, clusters, 0) + require.NoError(t, err) + cp := r2.GetCheckpoint() + require.NotNil(t, cp.CheckpointItems[0]) + require.NotNil(t, cp.CheckpointItems[1]) + require.NotNil(t, cp.CheckpointItems[2]) + }) + + t.Run("nil checkpoint items are skipped during validation", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write only 1 round — items[0] and items[1] stay nil + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Reload — should succeed even with a different cluster count since + // only the non-nil item[2] is validated + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + }) + + t.Run("no checkpoint file skips validation", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Fresh start with any cluster config — should always succeed + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}, "c3": {}}, 0) + require.NoError(t, err) + 
require.NotNil(t, r) + + cp := r.GetCheckpoint() + require.Nil(t, cp.CheckpointItems[0]) + require.Nil(t, cp.CheckpointItems[1]) + require.Nil(t, cp.CheckpointItems[2]) + }) +} + +func TestRecorder_RecordTimeWindow(t *testing.T) { + t.Parallel() + + t.Run("without report flush writes only checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}}, + } + report := NewReport(0) // needFlush = false + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + + // checkpoint.json should exist + _, err = os.Stat(filepath.Join(dataDir, "checkpoint", "checkpoint.json")) + require.NoError(t, err) + + // No report files + entries, err := os.ReadDir(filepath.Join(dataDir, "report")) + require.NoError(t, err) + require.Empty(t, entries) + }) + + t.Run("with report flush writes both checkpoint and report", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}}, + } + report := NewReport(5) + cr := NewClusterReport("c1", types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}) + cr.AddDataLossItem("d1", "test_table", map[string]any{"id": "1"}, `[id: 1]`, 200) + report.AddClusterReport("c1", cr) + require.True(t, report.NeedFlush()) + + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + + // checkpoint.json should exist + _, err = os.Stat(filepath.Join(dataDir, "checkpoint", "checkpoint.json")) + require.NoError(t, err) + + // Report files should exist + _, err = os.Stat(filepath.Join(dataDir, "report", "report-5.report")) + require.NoError(t, err) + _, err = 
os.Stat(filepath.Join(dataDir, "report", "report-5.json")) + require.NoError(t, err) + + // Verify report content + reportData, err := os.ReadFile(filepath.Join(dataDir, "report", "report-5.report")) + require.NoError(t, err) + require.Contains(t, string(reportData), "round: 5") + require.Contains(t, string(reportData), `[id: 1]`) + + // Verify json report is valid JSON + jsonData, err := os.ReadFile(filepath.Join(dataDir, "report", "report-5.json")) + require.NoError(t, err) + var parsed Report + err = json.Unmarshal(jsonData, &parsed) + require.NoError(t, err) + require.Equal(t, uint64(5), parsed.Round) + }) + + t.Run("multiple rounds advance checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + for i := uint64(0); i < 4; i++ { + twData := map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "db", Table: "tbl"}: {Version: i + 1}, + }, + }, + } + report := NewReport(i) + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + } + + // After 4 rounds, checkpoint should have rounds 1, 2, 3 (oldest evicted) + cp := r.GetCheckpoint() + require.NotNil(t, cp.CheckpointItems[0]) + require.NotNil(t, cp.CheckpointItems[1]) + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(1), cp.CheckpointItems[0].Round) + require.Equal(t, uint64(2), cp.CheckpointItems[1].Round) + require.Equal(t, uint64(3), cp.CheckpointItems[2].Round) + }) +} + +func TestRecorder_CheckpointPersistence(t *testing.T) { + t.Parallel() + + t.Run("checkpoint survives restart", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + stk := types.SchemaTableKey{Schema: "db", Table: "tbl"} + + // Simulate 3 rounds + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + 
require.NoError(t, err)
		// Record three rounds so every checkpoint slot is populated before restart.
		for i := uint64(0); i < 3; i++ {
			twData := map[string]types.TimeWindowData{
				"c1": {
					TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10},
					MaxVersion: map[types.SchemaTableKey]types.VersionKey{
						stk: {Version: i + 1, VersionPath: fmt.Sprintf("vp%d", i), DataPath: fmt.Sprintf("dp%d", i)},
					},
				},
			}
			report := NewReport(i)
			err = r1.RecordTimeWindow(twData, report)
			require.NoError(t, err)
		}

		// Restart: new recorder from same dir
		r2, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.NoError(t, err)

		cp := r2.GetCheckpoint()
		require.Equal(t, uint64(0), cp.CheckpointItems[0].Round)
		require.Equal(t, uint64(1), cp.CheckpointItems[1].Round)
		require.Equal(t, uint64(2), cp.CheckpointItems[2].Round)

		// Verify ToScanRange works after restart
		scanRange, err := cp.ToScanRange("c1")
		require.NoError(t, err)
		require.Len(t, scanRange, 1)
		sr := scanRange[stk]
		require.Equal(t, "vp0", sr.StartVersionKey)
		require.Equal(t, "vp2", sr.EndVersionKey)
		require.Equal(t, "dp0", sr.StartDataPath)
		require.Equal(t, "dp2", sr.EndDataPath)
	})

	t.Run("old report files are cleaned up when exceeding max", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()
		maxReportFiles := 3
		r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, maxReportFiles)
		require.NoError(t, err)

		// Write 5 reports (each creates 2 files: .report and .json)
		for i := uint64(0); i < 5; i++ {
			twData := map[string]types.TimeWindowData{
				"c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}},
			}
			report := NewReport(i)
			cr := NewClusterReport("c1", types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10})
			cr.AddDataLossItem("d1", "test_table", map[string]any{"id": "1"}, `[id: 1]`, i+1)
			report.AddClusterReport("c1", cr)
			require.True(t, report.NeedFlush())

			err = r.RecordTimeWindow(twData, report)
			require.NoError(t, err)
		}

		// Should have at most maxReportFiles * 2 = 6 files (rounds 2, 3, 4)
		entries, err := os.ReadDir(filepath.Join(dataDir, "report"))
		require.NoError(t, err)
		require.Equal(t, maxReportFiles*2, len(entries))

		// Oldest files (round 0 and 1) should be deleted
		_, err = os.Stat(filepath.Join(dataDir, "report", "report-0.report"))
		require.True(t, os.IsNotExist(err))
		_, err = os.Stat(filepath.Join(dataDir, "report", "report-1.report"))
		require.True(t, os.IsNotExist(err))

		// Newest files should still exist
		_, err = os.Stat(filepath.Join(dataDir, "report", "report-2.report"))
		require.NoError(t, err)
		_, err = os.Stat(filepath.Join(dataDir, "report", "report-3.report"))
		require.NoError(t, err)
		_, err = os.Stat(filepath.Join(dataDir, "report", "report-4.report"))
		require.NoError(t, err)
	})

	t.Run("no cleanup when under max report files", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()
		maxReportFiles := 10
		r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, maxReportFiles)
		require.NoError(t, err)

		// Write 3 reports
		for i := uint64(0); i < 3; i++ {
			twData := map[string]types.TimeWindowData{
				"c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}},
			}
			report := NewReport(i)
			cr := NewClusterReport("c1", types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10})
			cr.AddDataLossItem("d1", "test_table", map[string]any{"id": "1"}, `[id: 1]`, i+1)
			report.AddClusterReport("c1", cr)

			err = r.RecordTimeWindow(twData, report)
			require.NoError(t, err)
		}

		// All 6 files should exist (3 rounds * 2 files each)
		entries, err := os.ReadDir(filepath.Join(dataDir, "report"))
		require.NoError(t, err)
		require.Equal(t, 6, len(entries))
	})

	t.Run("checkpoint json is valid", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()
		r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.NoError(t, err)

		twData := map[string]types.TimeWindowData{
			"c1": {
				TimeWindow: types.TimeWindow{
					LeftBoundary:  100,
					RightBoundary: 200,
					CheckpointTs:  map[string]uint64{"c2": 150},
				},
			},
		}
		report := NewReport(0)
		err = r.RecordTimeWindow(twData, report)
		require.NoError(t, err)

		// Read and parse checkpoint.json
		data, err := os.ReadFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json"))
		require.NoError(t, err)

		var cp Checkpoint
		err = json.Unmarshal(data, &cp)
		require.NoError(t, err)
		require.NotNil(t, cp.CheckpointItems[2])
		require.Equal(t, uint64(100), cp.CheckpointItems[2].ClusterInfo["c1"].TimeWindow.LeftBoundary)
		require.Equal(t, uint64(200), cp.CheckpointItems[2].ClusterInfo["c1"].TimeWindow.RightBoundary)
	})
}

// TestErrCheckpointCorruption verifies that only genuinely corrupted
// checkpoint data (bad JSON, cluster-set mismatch) is classified as
// ErrCheckpointCorruption, while plain I/O failures are not.
func TestErrCheckpointCorruption(t *testing.T) {
	t.Parallel()

	t.Run("corrupted checkpoint file returns ErrCheckpointCorruption", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()

		// Create report and checkpoint directories
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755))
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755))

		// Write invalid JSON to checkpoint.json
		err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), []byte("{bad json"), 0o600)
		require.NoError(t, err)

		_, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.Error(t, err)
		require.True(t, errors.Is(err, ErrCheckpointCorruption))
	})

	t.Run("cluster count mismatch returns ErrCheckpointCorruption", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()

		// Write a valid checkpoint with 2 clusters
		r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0)
		require.NoError(t, err)
		twData := map[string]types.TimeWindowData{
			"c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}},
			"c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}},
		}
		err = r1.RecordTimeWindow(twData, NewReport(0))
		require.NoError(t, err)

		// Reload with 1 cluster — should be ErrCheckpointCorruption
		_, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.Error(t, err)
		require.True(t, errors.Is(err, ErrCheckpointCorruption))
	})

	t.Run("missing cluster ID returns ErrCheckpointCorruption", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()

		// Write a valid checkpoint with clusters c1 and c2
		r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0)
		require.NoError(t, err)
		twData := map[string]types.TimeWindowData{
			"c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}},
			"c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}},
		}
		err = r1.RecordTimeWindow(twData, NewReport(0))
		require.NoError(t, err)

		// Reload with c1 and c3 — should be ErrCheckpointCorruption
		_, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c3": {}}, 0)
		require.Error(t, err)
		require.True(t, errors.Is(err, ErrCheckpointCorruption))
	})

	t.Run("fresh start does not return ErrCheckpointCorruption", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()

		_, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.NoError(t, err)
	})

	t.Run("unreadable checkpoint file does not return ErrCheckpointCorruption", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()

		// Create directories
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755))
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755))

		// Create checkpoint.json as a directory so ReadFile fails with a non-corruption I/O error
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), 0o755))

		_, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.Error(t, err)
		require.False(t, errors.Is(err, ErrCheckpointCorruption),
			"I/O errors should NOT be classified as ErrCheckpointCorruption, got: %v", err)
	})

	t.Run("unreadable backup checkpoint does not return ErrCheckpointCorruption", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()

		// Create directories
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755))
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755))

		// Make checkpoint.json.bak a directory so ReadFile fails with an I/O error
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), 0o755))

		_, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.Error(t, err)
		require.False(t, errors.Is(err, ErrCheckpointCorruption),
			"I/O errors should NOT be classified as ErrCheckpointCorruption, got: %v", err)
	})

	t.Run("corrupted backup checkpoint returns ErrCheckpointCorruption", func(t *testing.T) {
		t.Parallel()
		dataDir := t.TempDir()

		// Create directories
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755))
		require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755))

		// Write invalid JSON to checkpoint.json.bak (simulate crash recovery with corrupted backup)
		err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), []byte("not valid json"), 0o600)
		require.NoError(t, err)

		_, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0)
		require.Error(t, err)
		require.True(t, errors.Is(err, ErrCheckpointCorruption))
	})
}
diff --git a/cmd/multi-cluster-consistency-checker/recorder/types.go b/cmd/multi-cluster-consistency-checker/recorder/types.go
new file mode 100644
index 0000000000..a3c8cbd234
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/recorder/types.go
@@ -0,0 +1,408 @@
+// Copyright 2026 PingCAP, Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package recorder + +import ( + "fmt" + "sort" + "strings" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" +) + +type DataLossItem struct { + PeerClusterID string `json:"peer_cluster_id"` + PK map[string]any `json:"pk"` + LocalCommitTS uint64 `json:"local_commit_ts"` + + PKStr string `json:"-"` +} + +func (item *DataLossItem) String() string { + return fmt.Sprintf("peer cluster: %s, pk: %s, local commit ts: %d", item.PeerClusterID, item.PKStr, item.LocalCommitTS) +} + +type InconsistentColumn struct { + Column string `json:"column"` + Local any `json:"local"` + Replicated any `json:"replicated"` +} + +func (c *InconsistentColumn) String() string { + return fmt.Sprintf("column: %s, local: %v, replicated: %v", c.Column, c.Local, c.Replicated) +} + +type DataInconsistentItem struct { + PeerClusterID string `json:"peer_cluster_id"` + PK map[string]any `json:"pk"` + LocalCommitTS uint64 `json:"local_commit_ts"` + ReplicatedCommitTS uint64 `json:"replicated_commit_ts"` + InconsistentColumns []InconsistentColumn `json:"inconsistent_columns,omitempty"` + + PKStr string `json:"-"` +} + +func (item *DataInconsistentItem) String() string { + var sb strings.Builder + fmt.Fprintf(&sb, "peer cluster: %s, pk: %s, local commit ts: %d, replicated commit ts: %d", + item.PeerClusterID, item.PKStr, item.LocalCommitTS, item.ReplicatedCommitTS) + if len(item.InconsistentColumns) > 0 { + sb.WriteString(", inconsistent columns: [") + for i, 
col := range item.InconsistentColumns { + if i > 0 { + sb.WriteString("; ") + } + sb.WriteString(col.String()) + } + sb.WriteString("]") + } + return sb.String() +} + +type DataRedundantItem struct { + PK map[string]any `json:"pk"` + OriginTS uint64 `json:"origin_ts"` + ReplicatedCommitTS uint64 `json:"replicated_commit_ts"` + + PKStr string `json:"-"` +} + +func (item *DataRedundantItem) String() string { + return fmt.Sprintf("pk: %s, origin ts: %d, replicated commit ts: %d", item.PKStr, item.OriginTS, item.ReplicatedCommitTS) +} + +type LWWViolationItem struct { + PK map[string]any `json:"pk"` + ExistingOriginTS uint64 `json:"existing_origin_ts"` + ExistingCommitTS uint64 `json:"existing_commit_ts"` + OriginTS uint64 `json:"origin_ts"` + CommitTS uint64 `json:"commit_ts"` + + PKStr string `json:"-"` +} + +func (item *LWWViolationItem) String() string { + return fmt.Sprintf( + "pk: %s, existing origin ts: %d, existing commit ts: %d, origin ts: %d, commit ts: %d", + item.PKStr, item.ExistingOriginTS, item.ExistingCommitTS, item.OriginTS, item.CommitTS) +} + +type TableFailureItems struct { + DataLossItems []DataLossItem `json:"data_loss_items"` // data loss items + DataInconsistentItems []DataInconsistentItem `json:"data_inconsistent_items"` // data inconsistent items + DataRedundantItems []DataRedundantItem `json:"data_redundant_items"` // data redundant items + LWWViolationItems []LWWViolationItem `json:"lww_violation_items"` // lww violation items +} + +func NewTableFailureItems() *TableFailureItems { + return &TableFailureItems{ + DataLossItems: make([]DataLossItem, 0), + DataInconsistentItems: make([]DataInconsistentItem, 0), + DataRedundantItems: make([]DataRedundantItem, 0), + LWWViolationItems: make([]LWWViolationItem, 0), + } +} + +type ClusterReport struct { + ClusterID string `json:"cluster_id"` + + TimeWindow types.TimeWindow `json:"time_window"` + + TableFailureItems map[string]*TableFailureItems `json:"table_failure_items"` // table failure items + + 
needFlush bool `json:"-"` +} + +func NewClusterReport(clusterID string, timeWindow types.TimeWindow) *ClusterReport { + return &ClusterReport{ + ClusterID: clusterID, + TimeWindow: timeWindow, + TableFailureItems: make(map[string]*TableFailureItems), + needFlush: false, + } +} + +func (r *ClusterReport) AddDataLossItem( + peerClusterID, schemaKey string, + pk map[string]any, + pkStr string, + localCommitTS uint64, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.DataLossItems = append(tableFailureItems.DataLossItems, DataLossItem{ + PeerClusterID: peerClusterID, + PK: pk, + LocalCommitTS: localCommitTS, + + PKStr: pkStr, + }) + r.needFlush = true +} + +func (r *ClusterReport) AddDataInconsistentItem( + peerClusterID, schemaKey string, + pk map[string]any, + pkStr string, + localCommitTS, replicatedCommitTS uint64, + inconsistentColumns []InconsistentColumn, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.DataInconsistentItems = append(tableFailureItems.DataInconsistentItems, DataInconsistentItem{ + PeerClusterID: peerClusterID, + PK: pk, + LocalCommitTS: localCommitTS, + ReplicatedCommitTS: replicatedCommitTS, + InconsistentColumns: inconsistentColumns, + + PKStr: pkStr, + }) + r.needFlush = true +} + +func (r *ClusterReport) AddDataRedundantItem( + schemaKey string, + pk map[string]any, + pkStr string, + originTS, replicatedCommitTS uint64, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.DataRedundantItems = append(tableFailureItems.DataRedundantItems, DataRedundantItem{ + PK: pk, + OriginTS: originTS, + 
ReplicatedCommitTS: replicatedCommitTS, + + PKStr: pkStr, + }) + r.needFlush = true +} + +func (r *ClusterReport) AddLWWViolationItem( + schemaKey string, + pk map[string]any, + pkStr string, + existingOriginTS, existingCommitTS uint64, + originTS, commitTS uint64, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.LWWViolationItems = append(tableFailureItems.LWWViolationItems, LWWViolationItem{ + PK: pk, + ExistingOriginTS: existingOriginTS, + ExistingCommitTS: existingCommitTS, + OriginTS: originTS, + CommitTS: commitTS, + + PKStr: pkStr, + }) + r.needFlush = true +} + +type Report struct { + Round uint64 `json:"round"` + ClusterReports map[string]*ClusterReport `json:"cluster_reports"` + needFlush bool `json:"-"` +} + +func NewReport(round uint64) *Report { + return &Report{ + Round: round, + ClusterReports: make(map[string]*ClusterReport), + needFlush: false, + } +} + +func (r *Report) AddClusterReport(clusterID string, clusterReport *ClusterReport) { + r.ClusterReports[clusterID] = clusterReport + r.needFlush = r.needFlush || clusterReport.needFlush +} + +func (r *Report) MarshalReport() string { + var reportMsg strings.Builder + fmt.Fprintf(&reportMsg, "round: %d\n", r.Round) + + // Sort cluster IDs for deterministic output + clusterIDs := make([]string, 0, len(r.ClusterReports)) + for clusterID := range r.ClusterReports { + clusterIDs = append(clusterIDs, clusterID) + } + sort.Strings(clusterIDs) + + for _, clusterID := range clusterIDs { + clusterReport := r.ClusterReports[clusterID] + if !clusterReport.needFlush { + continue + } + fmt.Fprintf(&reportMsg, "\n[cluster: %s]\n", clusterID) + fmt.Fprintf(&reportMsg, "time window: %s\n", clusterReport.TimeWindow.String()) + + // Sort schema keys for deterministic output + schemaKeys := make([]string, 0, len(clusterReport.TableFailureItems)) + for schemaKey := 
range clusterReport.TableFailureItems { + schemaKeys = append(schemaKeys, schemaKey) + } + sort.Strings(schemaKeys) + + for _, schemaKey := range schemaKeys { + tableFailureItems := clusterReport.TableFailureItems[schemaKey] + fmt.Fprintf(&reportMsg, " - [table name: %s]\n", schemaKey) + if len(tableFailureItems.DataLossItems) > 0 { + fmt.Fprintf(&reportMsg, " - [data loss items: %d]\n", len(tableFailureItems.DataLossItems)) + for _, dataLossItem := range tableFailureItems.DataLossItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", dataLossItem.String()) + } + } + if len(tableFailureItems.DataInconsistentItems) > 0 { + fmt.Fprintf(&reportMsg, " - [data inconsistent items: %d]\n", len(tableFailureItems.DataInconsistentItems)) + for _, dataInconsistentItem := range tableFailureItems.DataInconsistentItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", dataInconsistentItem.String()) + } + } + if len(tableFailureItems.DataRedundantItems) > 0 { + fmt.Fprintf(&reportMsg, " - [data redundant items: %d]\n", len(tableFailureItems.DataRedundantItems)) + for _, dataRedundantItem := range tableFailureItems.DataRedundantItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", dataRedundantItem.String()) + } + } + if len(tableFailureItems.LWWViolationItems) > 0 { + fmt.Fprintf(&reportMsg, " - [lww violation items: %d]\n", len(tableFailureItems.LWWViolationItems)) + for _, lwwViolationItem := range tableFailureItems.LWWViolationItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", lwwViolationItem.String()) + } + } + } + } + reportMsg.WriteString("\n") + return reportMsg.String() +} + +func (r *Report) NeedFlush() bool { + return r.needFlush +} + +type SchemaTableVersionKey struct { + types.SchemaTableKey + types.VersionKey +} + +func NewSchemaTableVersionKeyFromVersionKeyMap(versionKeyMap map[types.SchemaTableKey]types.VersionKey) []SchemaTableVersionKey { + result := make([]SchemaTableVersionKey, 0, len(versionKeyMap)) + for schemaTableKey, versionKey := range versionKeyMap { + result = append(result, 
SchemaTableVersionKey{ + SchemaTableKey: schemaTableKey, + VersionKey: versionKey, + }) + } + return result +} + +type CheckpointClusterInfo struct { + TimeWindow types.TimeWindow `json:"time_window"` + MaxVersion []SchemaTableVersionKey `json:"max_version"` +} + +type CheckpointItem struct { + Round uint64 `json:"round"` + ClusterInfo map[string]CheckpointClusterInfo `json:"cluster_info"` +} + +type Checkpoint struct { + CheckpointItems [3]*CheckpointItem `json:"checkpoint_items"` +} + +func NewCheckpoint() *Checkpoint { + return &Checkpoint{ + CheckpointItems: [3]*CheckpointItem{ + nil, + nil, + nil, + }, + } +} + +func (c *Checkpoint) NewTimeWindowData(round uint64, timeWindowData map[string]types.TimeWindowData) { + newCheckpointItem := CheckpointItem{ + Round: round, + ClusterInfo: make(map[string]CheckpointClusterInfo), + } + for clusterID, timeWindow := range timeWindowData { + newCheckpointItem.ClusterInfo[clusterID] = CheckpointClusterInfo{ + TimeWindow: timeWindow.TimeWindow, + MaxVersion: NewSchemaTableVersionKeyFromVersionKeyMap(timeWindow.MaxVersion), + } + } + c.CheckpointItems[0] = c.CheckpointItems[1] + c.CheckpointItems[1] = c.CheckpointItems[2] + c.CheckpointItems[2] = &newCheckpointItem +} + +type ScanRange struct { + StartVersionKey string + EndVersionKey string + StartDataPath string + EndDataPath string +} + +func (c *Checkpoint) ToScanRange(clusterID string) (map[types.SchemaTableKey]*ScanRange, error) { + result := make(map[types.SchemaTableKey]*ScanRange) + if c.CheckpointItems[2] == nil { + return result, nil + } + for _, versionKey := range c.CheckpointItems[2].ClusterInfo[clusterID].MaxVersion { + result[versionKey.SchemaTableKey] = &ScanRange{ + StartVersionKey: versionKey.VersionPath, + EndVersionKey: versionKey.VersionPath, + StartDataPath: versionKey.DataPath, + EndDataPath: versionKey.DataPath, + } + } + if c.CheckpointItems[1] == nil { + return result, nil + } + for _, versionKey := range 
c.CheckpointItems[1].ClusterInfo[clusterID].MaxVersion { + scanRange, ok := result[versionKey.SchemaTableKey] + if !ok { + return nil, errors.Errorf("schema table key %s.%s not found in result", versionKey.Schema, versionKey.Table) + } + scanRange.StartVersionKey = versionKey.VersionPath + scanRange.StartDataPath = versionKey.DataPath + } + if c.CheckpointItems[0] == nil { + return result, nil + } + for _, versionKey := range c.CheckpointItems[0].ClusterInfo[clusterID].MaxVersion { + scanRange, ok := result[versionKey.SchemaTableKey] + if !ok { + return nil, errors.Errorf("schema table key %s.%s not found in result", versionKey.Schema, versionKey.Table) + } + scanRange.StartVersionKey = versionKey.VersionPath + scanRange.StartDataPath = versionKey.DataPath + } + return result, nil +} diff --git a/cmd/multi-cluster-consistency-checker/recorder/types_test.go b/cmd/multi-cluster-consistency-checker/recorder/types_test.go new file mode 100644 index 0000000000..02c23908d6 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/recorder/types_test.go @@ -0,0 +1,640 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package recorder + +import ( + "fmt" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/stretchr/testify/require" +) + +func TestDataLossItem_String(t *testing.T) { + t.Parallel() + item := &DataLossItem{ + PeerClusterID: "cluster-2", + PK: map[string]any{"id": "1"}, + LocalCommitTS: 200, + PKStr: `[id: 1]`, + } + s := item.String() + require.Equal(t, `peer cluster: cluster-2, pk: [id: 1], local commit ts: 200`, s) +} + +func TestDataInconsistentItem_String(t *testing.T) { + t.Parallel() + + t.Run("without inconsistent columns", func(t *testing.T) { + t.Parallel() + item := &DataInconsistentItem{ + PeerClusterID: "cluster-3", + PK: map[string]any{"id": "2"}, + LocalCommitTS: 400, + ReplicatedCommitTS: 410, + PKStr: `[id: 2]`, + } + s := item.String() + require.Equal(t, `peer cluster: cluster-3, pk: [id: 2], local commit ts: 400, replicated commit ts: 410`, s) + }) + + t.Run("with inconsistent columns", func(t *testing.T) { + t.Parallel() + item := &DataInconsistentItem{ + PeerClusterID: "cluster-3", + PK: map[string]any{"id": "2"}, + LocalCommitTS: 400, + ReplicatedCommitTS: 410, + PKStr: `[id: 2]`, + InconsistentColumns: []InconsistentColumn{ + {Column: "col1", Local: "val_a", Replicated: "val_b"}, + {Column: "col2", Local: 100, Replicated: 200}, + }, + } + s := item.String() + require.Equal(t, + `peer cluster: cluster-3, pk: [id: 2], local commit ts: 400, replicated commit ts: 410, `+ + "inconsistent columns: [column: col1, local: val_a, replicated: val_b; column: col2, local: 100, replicated: 200]", + s) + }) + + t.Run("with missing column in replicated", func(t *testing.T) { + t.Parallel() + item := &DataInconsistentItem{ + PeerClusterID: "cluster-3", + PK: map[string]any{"id": "2"}, + LocalCommitTS: 400, + ReplicatedCommitTS: 410, + PKStr: `[id: 2]`, + InconsistentColumns: []InconsistentColumn{ + {Column: "col1", Local: "val_a", Replicated: nil}, + }, + } + s := item.String() + require.Equal(t, + `peer 
cluster: cluster-3, pk: [id: 2], local commit ts: 400, replicated commit ts: 410, `+ + "inconsistent columns: [column: col1, local: val_a, replicated: ]", + s) + }) +} + +func TestDataRedundantItem_String(t *testing.T) { + t.Parallel() + item := &DataRedundantItem{ + PK: map[string]any{"id": "x"}, + PKStr: `[id: x]`, + OriginTS: 10, + ReplicatedCommitTS: 20, + } + s := item.String() + require.Equal(t, `pk: [id: x], origin ts: 10, replicated commit ts: 20`, s) +} + +func TestLWWViolationItem_String(t *testing.T) { + t.Parallel() + item := &LWWViolationItem{ + PK: map[string]any{"id": "y"}, + PKStr: `[id: y]`, + ExistingOriginTS: 1, + ExistingCommitTS: 2, + OriginTS: 3, + CommitTS: 4, + } + s := item.String() + require.Equal(t, `pk: [id: y], existing origin ts: 1, existing commit ts: 2, origin ts: 3, commit ts: 4`, s) +} + +const testSchemaKey = "test_table" + +func TestClusterReport(t *testing.T) { + t.Parallel() + + t.Run("new cluster report is empty and does not need flush", func(t *testing.T) { + t.Parallel() + cr := NewClusterReport("c1", types.TimeWindow{}) + require.Equal(t, "c1", cr.ClusterID) + require.Empty(t, cr.TableFailureItems) + require.False(t, cr.needFlush) + }) + + t.Run("add data loss item sets needFlush", func(t *testing.T) { + t.Parallel() + cr := NewClusterReport("c1", types.TimeWindow{}) + cr.AddDataLossItem("peer-cluster-1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 200) + require.Len(t, cr.TableFailureItems, 1) + require.Contains(t, cr.TableFailureItems, testSchemaKey) + tableItems := cr.TableFailureItems[testSchemaKey] + require.Len(t, tableItems.DataLossItems, 1) + require.True(t, cr.needFlush) + require.Equal(t, "peer-cluster-1", tableItems.DataLossItems[0].PeerClusterID) + require.Equal(t, map[string]any{"id": "1"}, tableItems.DataLossItems[0].PK) + require.Equal(t, uint64(200), tableItems.DataLossItems[0].LocalCommitTS) + }) + + t.Run("add data inconsistent item sets needFlush", func(t *testing.T) { + t.Parallel() + cr := 
NewClusterReport("c1", types.TimeWindow{}) + cols := []InconsistentColumn{ + {Column: "val", Local: "a", Replicated: "b"}, + } + cr.AddDataInconsistentItem("peer-cluster-2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 400, 410, cols) + require.Len(t, cr.TableFailureItems, 1) + require.Contains(t, cr.TableFailureItems, testSchemaKey) + tableItems := cr.TableFailureItems[testSchemaKey] + require.Len(t, tableItems.DataInconsistentItems, 1) + require.True(t, cr.needFlush) + require.Equal(t, "peer-cluster-2", tableItems.DataInconsistentItems[0].PeerClusterID) + require.Equal(t, map[string]any{"id": "2"}, tableItems.DataInconsistentItems[0].PK) + require.Equal(t, uint64(400), tableItems.DataInconsistentItems[0].LocalCommitTS) + require.Equal(t, uint64(410), tableItems.DataInconsistentItems[0].ReplicatedCommitTS) + require.Len(t, tableItems.DataInconsistentItems[0].InconsistentColumns, 1) + require.Equal(t, "val", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Column) + require.Equal(t, "a", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Local) + require.Equal(t, "b", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Replicated) + }) + + t.Run("add data redundant item sets needFlush", func(t *testing.T) { + t.Parallel() + cr := NewClusterReport("c1", types.TimeWindow{}) + cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "2"}, `id: 2`, 300, 400) + require.Len(t, cr.TableFailureItems, 1) + tableItems := cr.TableFailureItems[testSchemaKey] + require.Len(t, tableItems.DataRedundantItems, 1) + require.True(t, cr.needFlush) + }) + + t.Run("add lww violation item sets needFlush", func(t *testing.T) { + t.Parallel() + cr := NewClusterReport("c1", types.TimeWindow{}) + cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "3"}, `id: 3`, 1, 2, 3, 4) + require.Len(t, cr.TableFailureItems, 1) + tableItems := cr.TableFailureItems[testSchemaKey] + require.Len(t, tableItems.LWWViolationItems, 1) + require.True(t, cr.needFlush) + 
require.Equal(t, uint64(1), tableItems.LWWViolationItems[0].ExistingOriginTS) + require.Equal(t, uint64(2), tableItems.LWWViolationItems[0].ExistingCommitTS) + require.Equal(t, uint64(3), tableItems.LWWViolationItems[0].OriginTS) + require.Equal(t, uint64(4), tableItems.LWWViolationItems[0].CommitTS) + }) + + t.Run("add multiple items", func(t *testing.T) { + t.Parallel() + cr := NewClusterReport("c1", types.TimeWindow{}) + cr.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `id: 1`, 2) + cr.AddDataInconsistentItem("d2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 4, 5, nil) + cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "3"}, `[id: 3]`, 5, 6) + cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "4"}, `[id: 4]`, 7, 8, 9, 10) + require.Len(t, cr.TableFailureItems, 1) + tableItems := cr.TableFailureItems[testSchemaKey] + require.Len(t, tableItems.DataLossItems, 1) + require.Len(t, tableItems.DataInconsistentItems, 1) + require.Len(t, tableItems.DataRedundantItems, 1) + require.Len(t, tableItems.LWWViolationItems, 1) + }) +} + +func TestReport(t *testing.T) { + t.Parallel() + + t.Run("new report does not need flush", func(t *testing.T) { + t.Parallel() + r := NewReport(1) + require.Equal(t, uint64(1), r.Round) + require.Empty(t, r.ClusterReports) + require.False(t, r.NeedFlush()) + }) + + t.Run("add empty cluster report does not set needFlush", func(t *testing.T) { + t.Parallel() + r := NewReport(1) + cr := NewClusterReport("c1", types.TimeWindow{}) + r.AddClusterReport("c1", cr) + require.Len(t, r.ClusterReports, 1) + require.False(t, r.NeedFlush()) + }) + + t.Run("add non-empty cluster report sets needFlush", func(t *testing.T) { + t.Parallel() + r := NewReport(1) + cr := NewClusterReport("c1", types.TimeWindow{}) + cr.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 2) + r.AddClusterReport("c1", cr) + require.True(t, r.NeedFlush()) + }) + + t.Run("needFlush propagates from any cluster report", 
func(t *testing.T) { + t.Parallel() + r := NewReport(1) + cr1 := NewClusterReport("c1", types.TimeWindow{}) + cr2 := NewClusterReport("c2", types.TimeWindow{}) + cr2.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 1, 2) + r.AddClusterReport("c1", cr1) + r.AddClusterReport("c2", cr2) + require.True(t, r.NeedFlush()) + }) +} + +func TestReport_MarshalReport(t *testing.T) { + t.Parallel() + + tw := types.TimeWindow{LeftBoundary: 0, RightBoundary: 0} + twStr := tw.String() + + t.Run("empty report", func(t *testing.T) { + t.Parallel() + r := NewReport(5) + s := r.MarshalReport() + require.Equal(t, "round: 5\n\n", s) + }) + + t.Run("report with data loss items", func(t *testing.T) { + t.Parallel() + r := NewReport(1) + cr := NewClusterReport("c1", tw) + cr.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 200) + r.AddClusterReport("c1", cr) + s := r.MarshalReport() + require.Equal(t, "round: 1\n\n"+ + "[cluster: c1]\n"+ + "time window: "+twStr+"\n"+ + " - [table name: "+testSchemaKey+"]\n"+ + " - [data loss items: 1]\n"+ + ` - [peer cluster: d1, pk: [id: 1], local commit ts: 200]`+"\n\n", + s) + }) + + t.Run("report with data redundant items", func(t *testing.T) { + t.Parallel() + r := NewReport(2) + cr := NewClusterReport("c2", tw) + cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "r"}, `[id: r]`, 10, 20) + r.AddClusterReport("c2", cr) + s := r.MarshalReport() + require.Equal(t, "round: 2\n\n"+ + "[cluster: c2]\n"+ + "time window: "+twStr+"\n"+ + " - [table name: "+testSchemaKey+"]\n"+ + " - [data redundant items: 1]\n"+ + ` - [pk: [id: r], origin ts: 10, replicated commit ts: 20]`+"\n\n", + s) + }) + + t.Run("report with lww violation items", func(t *testing.T) { + t.Parallel() + r := NewReport(3) + cr := NewClusterReport("c3", tw) + cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "v"}, `[id: v]`, 1, 2, 3, 4) + r.AddClusterReport("c3", cr) + s := r.MarshalReport() + require.Equal(t, "round: 
3\n\n"+ + "[cluster: c3]\n"+ + "time window: "+twStr+"\n"+ + " - [table name: "+testSchemaKey+"]\n"+ + " - [lww violation items: 1]\n"+ + ` - [pk: [id: v], existing origin ts: 1, existing commit ts: 2, origin ts: 3, commit ts: 4]`+"\n\n", + s) + }) + + t.Run("skips cluster reports that do not need flush", func(t *testing.T) { + t.Parallel() + r := NewReport(1) + crEmpty := NewClusterReport("empty-cluster", tw) + crFull := NewClusterReport("full-cluster", tw) + crFull.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 2) + r.AddClusterReport("empty-cluster", crEmpty) + r.AddClusterReport("full-cluster", crFull) + s := r.MarshalReport() + require.Equal(t, "round: 1\n\n"+ + "[cluster: full-cluster]\n"+ + "time window: "+twStr+"\n"+ + " - [table name: "+testSchemaKey+"]\n"+ + " - [data loss items: 1]\n"+ + ` - [peer cluster: d1, pk: [id: 1], local commit ts: 2]`+"\n\n", + s) + }) + + t.Run("report with mixed items", func(t *testing.T) { + t.Parallel() + r := NewReport(10) + cr := NewClusterReport("c1", tw) + cr.AddDataLossItem("d0", testSchemaKey, map[string]any{"id": "0"}, `[id: 0]`, 1) + cr.AddDataInconsistentItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 2, 3, []InconsistentColumn{ + {Column: "val", Local: "x", Replicated: "y"}, + }) + cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 3, 4) + cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "3"}, `[id: 3]`, 5, 6, 7, 8) + r.AddClusterReport("c1", cr) + s := r.MarshalReport() + require.Equal(t, "round: 10\n\n"+ + "[cluster: c1]\n"+ + "time window: "+twStr+"\n"+ + " - [table name: "+testSchemaKey+"]\n"+ + " - [data loss items: 1]\n"+ + ` - [peer cluster: d0, pk: [id: 0], local commit ts: 1]`+"\n"+ + " - [data inconsistent items: 1]\n"+ + ` - [peer cluster: d1, pk: [id: 1], local commit ts: 2, replicated commit ts: 3, inconsistent columns: [column: val, local: x, replicated: y]]`+"\n"+ + " - [data redundant items: 1]\n"+ + ` - [pk: [id: 2], 
origin ts: 3, replicated commit ts: 4]`+"\n"+ + " - [lww violation items: 1]\n"+ + ` - [pk: [id: 3], existing origin ts: 5, existing commit ts: 6, origin ts: 7, commit ts: 8]`+"\n\n", + s) + }) +} + +func TestNewSchemaTableVersionKeyFromVersionKeyMap(t *testing.T) { + t.Parallel() + + t.Run("empty map", func(t *testing.T) { + t.Parallel() + result := NewSchemaTableVersionKeyFromVersionKeyMap(nil) + require.Empty(t, result) + }) + + t.Run("single entry", func(t *testing.T) { + t.Parallel() + m := map[types.SchemaTableKey]types.VersionKey{ + {Schema: "db", Table: "tbl"}: {Version: 1, VersionPath: "path1"}, + } + result := NewSchemaTableVersionKeyFromVersionKeyMap(m) + require.Len(t, result, 1) + require.Equal(t, "db", result[0].Schema) + require.Equal(t, "tbl", result[0].Table) + require.Equal(t, uint64(1), result[0].Version) + require.Equal(t, "path1", result[0].VersionPath) + }) + + t.Run("multiple entries", func(t *testing.T) { + t.Parallel() + m := map[types.SchemaTableKey]types.VersionKey{ + {Schema: "db1", Table: "t1"}: {Version: 1}, + {Schema: "db2", Table: "t2"}: {Version: 2}, + } + result := NewSchemaTableVersionKeyFromVersionKeyMap(m) + require.Len(t, result, 2) + }) +} + +func TestCheckpoint_NewTimeWindowData(t *testing.T) { + t.Parallel() + + t.Run("first call populates slot 2", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}}, + }) + require.Nil(t, cp.CheckpointItems[0]) + require.Nil(t, cp.CheckpointItems[1]) + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(0), cp.CheckpointItems[2].Round) + }) + + t.Run("second call shifts slots", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}}, + }) + cp.NewTimeWindowData(1, map[string]types.TimeWindowData{ + 
"c1": {TimeWindow: types.TimeWindow{LeftBoundary: 10, RightBoundary: 20}}, + }) + require.Nil(t, cp.CheckpointItems[0]) + require.NotNil(t, cp.CheckpointItems[1]) + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(0), cp.CheckpointItems[1].Round) + require.Equal(t, uint64(1), cp.CheckpointItems[2].Round) + }) + + t.Run("third call fills all slots", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + for i := uint64(0); i < 3; i++ { + cp.NewTimeWindowData(i, map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}}, + }) + } + require.NotNil(t, cp.CheckpointItems[0]) + require.NotNil(t, cp.CheckpointItems[1]) + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(0), cp.CheckpointItems[0].Round) + require.Equal(t, uint64(1), cp.CheckpointItems[1].Round) + require.Equal(t, uint64(2), cp.CheckpointItems[2].Round) + }) + + t.Run("fourth call evicts oldest", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + for i := uint64(0); i < 4; i++ { + cp.NewTimeWindowData(i, map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}}, + }) + } + require.Equal(t, uint64(1), cp.CheckpointItems[0].Round) + require.Equal(t, uint64(2), cp.CheckpointItems[1].Round) + require.Equal(t, uint64(3), cp.CheckpointItems[2].Round) + }) + + t.Run("stores max version from time window data", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "db", Table: "tbl"}: {Version: 5, VersionPath: "vp", DataPath: "dp"}, + }, + }, + }) + info := cp.CheckpointItems[2].ClusterInfo["c1"] + require.Len(t, info.MaxVersion, 1) + require.Equal(t, uint64(5), info.MaxVersion[0].Version) + require.Equal(t, "vp", 
info.MaxVersion[0].VersionPath) + require.Equal(t, "dp", info.MaxVersion[0].DataPath) + }) +} + +func TestCheckpoint_ToScanRange(t *testing.T) { + t.Parallel() + + stk := types.SchemaTableKey{Schema: "db", Table: "tbl"} + + t.Run("empty checkpoint returns empty", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + result, err := cp.ToScanRange("c1") + require.NoError(t, err) + require.Empty(t, result) + }) + + t.Run("only item[2] set", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 2, VersionPath: "vp2", DataPath: "dp2"}, + }, + }, + }) + result, err := cp.ToScanRange("c1") + require.NoError(t, err) + require.Len(t, result, 1) + sr := result[stk] + // With only item[2], Start and End are both from item[2] + require.Equal(t, "vp2", sr.StartVersionKey) + require.Equal(t, "vp2", sr.EndVersionKey) + require.Equal(t, "dp2", sr.StartDataPath) + require.Equal(t, "dp2", sr.EndDataPath) + }) + + t.Run("items[1] and items[2] set", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 1, VersionPath: "vp1", DataPath: "dp1"}, + }, + }, + }) + cp.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 2, VersionPath: "vp2", DataPath: "dp2"}, + }, + }, + }) + result, err := cp.ToScanRange("c1") + require.NoError(t, err) + require.Len(t, result, 1) + sr := result[stk] + // End comes from item[2], Start overridden by item[1] + require.Equal(t, "vp1", sr.StartVersionKey) + require.Equal(t, "vp2", sr.EndVersionKey) + require.Equal(t, "dp1", sr.StartDataPath) + require.Equal(t, "dp2", sr.EndDataPath) + }) + + t.Run("all three items set", func(t *testing.T) { + t.Parallel() + cp := 
NewCheckpoint() + for i := uint64(0); i < 3; i++ { + cp.NewTimeWindowData(i, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: { + Version: i + 1, + VersionPath: fmt.Sprintf("vp%d", i), + DataPath: fmt.Sprintf("dp%d", i), + }, + }, + }, + }) + } + result, err := cp.ToScanRange("c1") + require.NoError(t, err) + require.Len(t, result, 1) + sr := result[stk] + // End from item[2], Start overridden by item[0] (oldest) + require.Equal(t, "vp0", sr.StartVersionKey) + require.Equal(t, "vp2", sr.EndVersionKey) + require.Equal(t, "dp0", sr.StartDataPath) + require.Equal(t, "dp2", sr.EndDataPath) + }) + + t.Run("missing key in item[1] returns error", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "other", Table: "other"}: {Version: 1, VersionPath: "vp1"}, + }, + }, + }) + cp.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 2, VersionPath: "vp2"}, + }, + }, + }) + _, err := cp.ToScanRange("c1") + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("missing key in item[0] returns error", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "other", Table: "other"}: {Version: 1, VersionPath: "vp1"}, + }, + }, + }) + cp.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 2, VersionPath: "vp2"}, + }, + }, + }) + cp.NewTimeWindowData(2, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 3, VersionPath: "vp3"}, + }, + }, + }) + _, err := cp.ToScanRange("c1") + 
require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("unknown cluster returns empty", func(t *testing.T) { + t.Parallel() + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 1, VersionPath: "vp1"}, + }, + }, + }) + result, err := cp.ToScanRange("unknown-cluster") + require.NoError(t, err) + require.Empty(t, result) + }) + + t.Run("multiple tables", func(t *testing.T) { + t.Parallel() + stk2 := types.SchemaTableKey{Schema: "db2", Table: "tbl2"} + cp := NewCheckpoint() + cp.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: 1, VersionPath: "vp1-t1", DataPath: "dp1-t1"}, + stk2: {Version: 1, VersionPath: "vp1-t2", DataPath: "dp1-t2"}, + }, + }, + }) + result, err := cp.ToScanRange("c1") + require.NoError(t, err) + require.Len(t, result, 2) + require.Contains(t, result, stk) + require.Contains(t, result, stk2) + }) +} diff --git a/cmd/multi-cluster-consistency-checker/task.go b/cmd/multi-cluster-consistency-checker/task.go new file mode 100644 index 0000000000..fa93ffcd55 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/task.go @@ -0,0 +1,333 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/advancer" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/checker" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher" + "github.com/pingcap/ticdc/pkg/common" + cdcconfig "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/etcd" + "github.com/pingcap/ticdc/pkg/security" + "github.com/pingcap/ticdc/pkg/util" + pd "github.com/tikv/pd/client" + pdopt "github.com/tikv/pd/client/opt" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +func runTask(ctx context.Context, cfg *config.Config, dryRun bool) error { + checkpointWatchers, s3Watchers, pdClients, etcdClients, err := initClients(ctx, cfg) + if err != nil { + // Client initialisation is typically a transient (network) failure. + return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } + // Ensure cleanup happens even if there's an error + defer cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers) + + if dryRun { + log.Info("Dry-run mode: config validation and connectivity check passed, exiting") + return nil + } + + rec, err := recorder.NewRecorder(cfg.GlobalConfig.DataDir, cfg.Clusters, cfg.GlobalConfig.MaxReportFiles) + if err != nil { + if errors.Is(err, recorder.ErrCheckpointCorruption) { + return &ExitError{Code: ExitCodeCheckpointCorruption, Err: err} + } + // Other recorder init errors (e.g. mkdir, readdir) are transient. 
+ return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } + timeWindowAdvancer, checkpointDataMap, err := advancer.NewTimeWindowAdvancer(ctx, checkpointWatchers, s3Watchers, pdClients, rec.GetCheckpoint()) + if err != nil { + return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } + dataChecker, err := checker.NewDataChecker(ctx, cfg.Clusters, checkpointDataMap, rec.GetCheckpoint()) + if err != nil { + return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } + + log.Info("Starting consistency checker task") + for { + // Check if context is cancelled before starting a new iteration + select { + case <-ctx.Done(): + log.Info("Context cancelled, shutting down gracefully") + return ctx.Err() + default: + } + + newTimeWindowData, err := timeWindowAdvancer.AdvanceTimeWindow(ctx) + if err != nil { + return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } + + report, err := dataChecker.CheckInNextTimeWindow(newTimeWindowData) + if err != nil { + return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } + + if err := rec.RecordTimeWindow(newTimeWindowData, report); err != nil { + return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } + } +} + +func initClients(ctx context.Context, cfg *config.Config) ( + map[string]map[string]watcher.Watcher, + map[string]*watcher.S3Watcher, + map[string]pd.Client, + map[string]*etcd.CDCEtcdClientImpl, + error, +) { + checkpointWatchers := make(map[string]map[string]watcher.Watcher) + s3Watchers := make(map[string]*watcher.S3Watcher) + pdClients := make(map[string]pd.Client) + etcdClients := make(map[string]*etcd.CDCEtcdClientImpl) + + for clusterID, clusterConfig := range cfg.Clusters { + pdClient, etcdClient, err := newPDClient(ctx, clusterConfig.PDAddrs, &clusterConfig.SecurityConfig) + if err != nil { + // Clean up already created clients before returning error + cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers) + return nil, 
nil, nil, nil, errors.Trace(err) + } + etcdClients[clusterID] = etcdClient + + clusterCheckpointWatchers := make(map[string]watcher.Watcher) + for peerClusterID, peerClusterChangefeedConfig := range clusterConfig.PeerClusterChangefeedConfig { + checkpointWatcher := watcher.NewCheckpointWatcher(ctx, clusterID, peerClusterID, peerClusterChangefeedConfig.ChangefeedID, etcdClient) + clusterCheckpointWatchers[peerClusterID] = checkpointWatcher + } + checkpointWatchers[clusterID] = clusterCheckpointWatchers + + // Validate s3 changefeed sink config from etcd + if err := validateS3ChangefeedSinkConfig(ctx, etcdClient, clusterID, clusterConfig.S3ChangefeedID, clusterConfig.S3SinkURI); err != nil { + cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers) + return nil, nil, nil, nil, errors.Trace(err) + } + + s3Storage, err := util.GetExternalStorageWithDefaultTimeout(ctx, clusterConfig.S3SinkURI) + if err != nil { + // Clean up already created clients before returning error + cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers) + return nil, nil, nil, nil, errors.Trace(err) + } + s3Watcher := watcher.NewS3Watcher( + watcher.NewCheckpointWatcher(ctx, clusterID, "s3", clusterConfig.S3ChangefeedID, etcdClient), + s3Storage, + cfg.GlobalConfig.Tables, + ) + s3Watchers[clusterID] = s3Watcher + pdClients[clusterID] = pdClient + } + + return checkpointWatchers, s3Watchers, pdClients, etcdClients, nil +} + +// validateS3ChangefeedSinkConfig fetches the changefeed info from etcd and validates that: +// 1. The changefeed SinkURI bucket/prefix matches the configured s3SinkURI +// 2. The protocol must be canal-json +// 3. The date separator must be "day" +// 4. 
The file index width must be DefaultFileIndexWidth +func validateS3ChangefeedSinkConfig(ctx context.Context, etcdClient *etcd.CDCEtcdClientImpl, clusterID string, s3ChangefeedID string, s3SinkURI string) error { + displayName := common.NewChangeFeedDisplayName(s3ChangefeedID, "default") + cfInfo, err := etcdClient.GetChangeFeedInfo(ctx, displayName) + if err != nil { + return errors.Annotate(err, fmt.Sprintf("failed to get changefeed info for s3 changefeed %s in cluster %s", s3ChangefeedID, clusterID)) + } + + // 1. Validate that the changefeed's SinkURI bucket/prefix matches the configured s3SinkURI. + // This prevents the checker from reading data that was written by a different changefeed. + if err := validateS3BucketPrefix(cfInfo.SinkURI, s3SinkURI, clusterID, s3ChangefeedID); err != nil { + return err + } + + if cfInfo.Config == nil || cfInfo.Config.Sink == nil { + return fmt.Errorf("cluster %s: s3 changefeed %s has no sink configuration", clusterID, s3ChangefeedID) + } + + sinkConfig := cfInfo.Config.Sink + + // 2. Validate protocol must be canal-json + protocolStr := strings.ToLower(util.GetOrZero(sinkConfig.Protocol)) + if protocolStr == "" { + return fmt.Errorf("cluster %s: s3 changefeed %s has no protocol configured in sink config", clusterID, s3ChangefeedID) + } + protocol, err := cdcconfig.ParseSinkProtocolFromString(protocolStr) + if err != nil { + return errors.Annotate(err, fmt.Sprintf("cluster %s: s3 changefeed %s has invalid protocol", clusterID, s3ChangefeedID)) + } + if protocol != cdcconfig.ProtocolCanalJSON { + return fmt.Errorf("cluster %s: s3 changefeed %s protocol is %q, but only %q is supported", + clusterID, s3ChangefeedID, protocolStr, cdcconfig.ProtocolCanalJSON.String()) + } + + // 3. 
Validate date separator must be "day" + dateSeparatorStr := util.GetOrZero(sinkConfig.DateSeparator) + if dateSeparatorStr == "" { + dateSeparatorStr = cdcconfig.DateSeparatorNone.String() + } + var dateSep cdcconfig.DateSeparator + if err := dateSep.FromString(dateSeparatorStr); err != nil { + return errors.Annotate(err, fmt.Sprintf("cluster %s: s3 changefeed %s has invalid date-separator %q", clusterID, s3ChangefeedID, dateSeparatorStr)) + } + if dateSep != cdcconfig.DateSeparatorDay { + return fmt.Errorf("cluster %s: s3 changefeed %s date-separator is %q, but only %q is supported", + clusterID, s3ChangefeedID, dateSep.String(), cdcconfig.DateSeparatorDay.String()) + } + + // 4. Validate file index width must be DefaultFileIndexWidth + fileIndexWidth := util.GetOrZero(sinkConfig.FileIndexWidth) + if fileIndexWidth != cdcconfig.DefaultFileIndexWidth { + return fmt.Errorf("cluster %s: s3 changefeed %s file-index-width is %d, but only %d is supported", + clusterID, s3ChangefeedID, fileIndexWidth, cdcconfig.DefaultFileIndexWidth) + } + + log.Info("Validated s3 changefeed sink config from etcd", + zap.String("clusterID", clusterID), + zap.String("s3ChangefeedID", s3ChangefeedID), + zap.String("protocol", protocolStr), + zap.String("dateSeparator", dateSep.String()), + zap.Int("fileIndexWidth", fileIndexWidth), + ) + + return nil +} + +// validateS3BucketPrefix checks that the changefeed's SinkURI and the +// configured s3-sink-uri point to the same S3 bucket and prefix. +// This is a critical sanity check — a mismatch means the checker would +// read data from a different location than where the changefeed writes. 
+func validateS3BucketPrefix(changefeedSinkURI, configS3SinkURI, clusterID, s3ChangefeedID string) error { + cfURL, err := url.Parse(changefeedSinkURI) + if err != nil { + return fmt.Errorf("cluster %s: s3 changefeed %s has invalid sink URI %q: %v", + clusterID, s3ChangefeedID, changefeedSinkURI, err) + } + cfgURL, err := url.Parse(configS3SinkURI) + if err != nil { + return fmt.Errorf("cluster %s: configured s3-sink-uri %q is invalid: %v", + clusterID, configS3SinkURI, err) + } + + // Compare scheme (s3, gcs, …), bucket (Host) and prefix (Path). + // Path is normalized by trimming trailing slashes so that + // "s3://bucket/prefix" and "s3://bucket/prefix/" are considered equal. + cfScheme := strings.ToLower(cfURL.Scheme) + cfgScheme := strings.ToLower(cfgURL.Scheme) + cfBucket := cfURL.Host + cfgBucket := cfgURL.Host + cfPrefix := strings.TrimRight(cfURL.Path, "/") + cfgPrefix := strings.TrimRight(cfgURL.Path, "/") + + if cfScheme != cfgScheme || cfBucket != cfgBucket || cfPrefix != cfgPrefix { + return fmt.Errorf("cluster %s: s3 changefeed %s sink URI bucket/prefix mismatch: "+ + "changefeed has %s://%s%s but config has %s://%s%s", + clusterID, s3ChangefeedID, + cfScheme, cfBucket, cfURL.Path, + cfgScheme, cfgBucket, cfgURL.Path) + } + return nil +} + +func newPDClient(ctx context.Context, pdAddrs []string, securityConfig *security.Credential) (pd.Client, *etcd.CDCEtcdClientImpl, error) { + pdClient, err := pd.NewClientWithContext( + ctx, "consistency-checker", pdAddrs, securityConfig.PDSecurityOption(), + pdopt.WithCustomTimeoutOption(10*time.Second), + ) + if err != nil { + return nil, nil, errors.Trace(err) + } + + etcdCli, err := etcd.CreateRawEtcdClient(securityConfig, grpc.EmptyDialOption{}, pdAddrs...) 
+ if err != nil { + // Clean up PD client if etcd client creation fails + if pdClient != nil { + pdClient.Close() + } + return nil, nil, errors.Trace(err) + } + + cdcEtcdClient, err := etcd.NewCDCEtcdClient(ctx, etcdCli, "default") + if err != nil { + // Clean up resources if CDC etcd client creation fails + etcdCli.Close() + pdClient.Close() + return nil, nil, errors.Trace(err) + } + + return pdClient, cdcEtcdClient, nil +} + +// cleanupClients closes all PD and etcd clients gracefully +func cleanupClients( + pdClients map[string]pd.Client, + etcdClients map[string]*etcd.CDCEtcdClientImpl, + checkpointWatchers map[string]map[string]watcher.Watcher, + s3Watchers map[string]*watcher.S3Watcher, +) { + log.Info("Cleaning up clients", + zap.Int("pdClients", len(pdClients)), + zap.Int("etcdClients", len(etcdClients)), + zap.Int("checkpointWatchers", len(checkpointWatchers)), + zap.Int("s3Watchers", len(s3Watchers)), + ) + + // Close PD clients + for clusterID, pdClient := range pdClients { + if pdClient != nil { + pdClient.Close() + log.Debug("PD client closed", zap.String("clusterID", clusterID)) + } + } + + // Close etcd clients + for clusterID, etcdClient := range etcdClients { + if etcdClient != nil { + if err := etcdClient.Close(); err != nil { + log.Warn("Failed to close etcd client", + zap.String("clusterID", clusterID), + zap.Error(err)) + } else { + log.Debug("Etcd client closed", zap.String("clusterID", clusterID)) + } + } + } + + // Close checkpoint watchers + for _, clusterWatchers := range checkpointWatchers { + for _, watcher := range clusterWatchers { + watcher.Close() + } + } + + // Close s3 watchers + for _, s3Watcher := range s3Watchers { + s3Watcher.Close() + } + + log.Info("Client cleanup completed") +} diff --git a/cmd/multi-cluster-consistency-checker/types/types.go b/cmd/multi-cluster-consistency-checker/types/types.go new file mode 100644 index 0000000000..cdb7e1760a --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/types/types.go @@ -0,0 
+1,98 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "fmt" + "sort" + "strings" + + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + ptypes "github.com/pingcap/tidb/pkg/parser/types" +) + +// PkType is a distinct type for encoded primary key strings, making it clear +// at the type level that the value has been serialized / encoded. +type PkType string + +type CdcVersion struct { + CommitTs uint64 + OriginTs uint64 +} + +func (v *CdcVersion) GetCompareTs() uint64 { + if v.OriginTs > 0 { + return v.OriginTs + } + return v.CommitTs +} + +type SchemaTableKey struct { + Schema string + Table string +} + +type VersionKey struct { + Version uint64 + // Version Path is a hint for the next version path to scan + VersionPath string + // Data Path is a hint for the next data path to scan + DataPath string +} + +// TimeWindow is the time window of the cluster, including the left boundary, right boundary and checkpoint ts +// Assert 1: LeftBoundary < CheckpointTs < RightBoundary +// Assert 2: The other cluster's checkpoint timestamp of next time window should be larger than the PDTimestampAfterTimeWindow saved in this cluster's time window +// Assert 3: CheckpointTs of this cluster should be larger than other clusters' RightBoundary of previous time window +// Assert 4: RightBoundary of this cluster should be larger than other clusters' CheckpointTs of this time window +type TimeWindow struct { + LeftBoundary uint64 `json:"left_boundary"` + RightBoundary uint64 
`json:"right_boundary"` + // CheckpointTs is the checkpoint timestamp for each local-to-replicated changefeed, + // mapping from replicated cluster ID to the checkpoint timestamp + CheckpointTs map[string]uint64 `json:"checkpoint_ts"` + // PDTimestampAfterTimeWindow is the max PD timestamp after the time window for each replicated cluster, + // mapping from local cluster ID to the max PD timestamp + PDTimestampAfterTimeWindow map[string]uint64 `json:"pd_timestamp_after_time_window"` + // NextMinLeftBoundary is the minimum left boundary of the next time window for the cluster + NextMinLeftBoundary uint64 `json:"next_min_left_boundary"` +} + +func (t *TimeWindow) String() string { + var builder strings.Builder + fmt.Fprintf(&builder, "time window boundary: (%d, %d]\n", t.LeftBoundary, t.RightBoundary) + + // Sort cluster IDs for deterministic output + clusterIDs := make([]string, 0, len(t.CheckpointTs)) + for id := range t.CheckpointTs { + clusterIDs = append(clusterIDs, id) + } + sort.Strings(clusterIDs) + + for _, replicatedClusterID := range clusterIDs { + fmt.Fprintf(&builder, "checkpoint ts [replicated cluster: %s]: %d\n", replicatedClusterID, t.CheckpointTs[replicatedClusterID]) + } + return builder.String() +} + +type TimeWindowData struct { + TimeWindow + Data map[cloudstorage.DmlPathKey]IncrementalData + MaxVersion map[SchemaTableKey]VersionKey +} + +type IncrementalData struct { + DataContentSlices map[cloudstorage.FileIndexKey][][]byte + ColumnFieldTypes map[string]*ptypes.FieldType +} diff --git a/cmd/multi-cluster-consistency-checker/types/types_test.go b/cmd/multi-cluster-consistency-checker/types/types_test.go new file mode 100644 index 0000000000..7fe90a9739 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/types/types_test.go @@ -0,0 +1,69 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package types_test + +import ( + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/stretchr/testify/require" +) + +func TestCdcVersion_GetCompareTs(t *testing.T) { + tests := []struct { + name string + version types.CdcVersion + expected uint64 + }{ + { + name: "OriginTs is set", + version: types.CdcVersion{ + CommitTs: 100, + OriginTs: 200, + }, + expected: 200, + }, + { + name: "OriginTs is smaller than CommitTs", + version: types.CdcVersion{ + CommitTs: 200, + OriginTs: 100, + }, + expected: 100, + }, + { + name: "OriginTs is zero", + version: types.CdcVersion{ + CommitTs: 100, + OriginTs: 0, + }, + expected: 100, + }, + { + name: "Both are zero", + version: types.CdcVersion{ + CommitTs: 0, + OriginTs: 0, + }, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.version.GetCompareTs() + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go new file mode 100644 index 0000000000..1fb96771cd --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go @@ -0,0 +1,320 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package watcher + +import ( + "context" + "sync" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/pkg/common" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/etcd" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +// errChangefeedKeyDeleted is a sentinel error indicating that the changefeed +// status key has been deleted from etcd. This is a non-recoverable error +// that should not be retried. +var errChangefeedKeyDeleted = errors.New("changefeed status key is deleted") + +const ( + // retryBackoffBase is the initial backoff duration for retries + retryBackoffBase = 500 * time.Millisecond + // retryBackoffMax is the maximum backoff duration for retries + retryBackoffMax = 30 * time.Second + // retryBackoffMultiplier is the multiplier for exponential backoff + retryBackoffMultiplier = 2.0 +) + +type Watcher interface { + AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) + Close() +} + +type waitCheckpointTask struct { + respCh chan uint64 + minCheckpointTs uint64 +} + +type CheckpointWatcher struct { + localClusterID string + replicatedClusterID string + changefeedID common.ChangeFeedID + etcdClient etcd.CDCEtcdClient + + ctx context.Context + cancel context.CancelFunc + + mu sync.Mutex + latestCheckpoint uint64 + pendingTasks []*waitCheckpointTask + watchErr error + closed bool +} + +// failPendingTasksLocked wakes all pending tasks after a terminal watcher failure. +// Must be called with mu locked. 
+func (cw *CheckpointWatcher) failPendingTasksLocked() { + for _, task := range cw.pendingTasks { + close(task.respCh) + } + cw.pendingTasks = nil +} + +func NewCheckpointWatcher( + ctx context.Context, + localClusterID, replicatedClusterID, changefeedID string, + etcdClient etcd.CDCEtcdClient, +) *CheckpointWatcher { + cctx, cancel := context.WithCancel(ctx) + watcher := &CheckpointWatcher{ + localClusterID: localClusterID, + replicatedClusterID: replicatedClusterID, + changefeedID: common.NewChangeFeedIDWithName(changefeedID, "default"), + etcdClient: etcdClient, + + ctx: cctx, + cancel: cancel, + } + go watcher.run() + return watcher +} + +// AdvanceCheckpointTs waits for the checkpoint to exceed minCheckpointTs +func (cw *CheckpointWatcher) AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + cw.mu.Lock() + + // Check if watcher has encountered an error + if cw.watchErr != nil { + err := cw.watchErr + cw.mu.Unlock() + return 0, err + } + + // Check if watcher is closed + if cw.closed { + cw.mu.Unlock() + return 0, errors.Errorf("checkpoint watcher is closed") + } + + // Check if current checkpoint already exceeds minCheckpointTs + if cw.latestCheckpoint > minCheckpointTs { + checkpoint := cw.latestCheckpoint + cw.mu.Unlock() + return checkpoint, nil + } + + // Create a task and wait for the background goroutine to notify + task := &waitCheckpointTask{ + respCh: make(chan uint64, 1), + minCheckpointTs: minCheckpointTs, + } + cw.pendingTasks = append(cw.pendingTasks, task) + cw.mu.Unlock() + + // Wait for response or context cancellation + select { + case <-ctx.Done(): + // Remove the task from pending list + cw.mu.Lock() + for i, t := range cw.pendingTasks { + if t == task { + cw.pendingTasks = append(cw.pendingTasks[:i], cw.pendingTasks[i+1:]...) 
+ break + } + } + cw.mu.Unlock() + return 0, errors.Annotate(ctx.Err(), "context canceled while waiting for checkpoint") + case <-cw.ctx.Done(): + return 0, errors.Annotate(cw.ctx.Err(), "watcher context canceled") + case checkpoint, ok := <-task.respCh: + if !ok { + cw.mu.Lock() + err := cw.watchErr + closed := cw.closed + cw.mu.Unlock() + if err != nil { + return 0, err + } + if closed { + return 0, errors.Errorf("checkpoint watcher is closed") + } + return 0, errors.Errorf("checkpoint watcher failed") + } + return checkpoint, nil + } +} + +// Close stops the watcher +func (cw *CheckpointWatcher) Close() { + cw.cancel() + cw.mu.Lock() + cw.closed = true + // Notify all pending tasks that watcher is closing + for _, task := range cw.pendingTasks { + close(task.respCh) + } + cw.pendingTasks = nil + cw.mu.Unlock() +} + +func (cw *CheckpointWatcher) run() { + backoff := retryBackoffBase + for { + select { + case <-cw.ctx.Done(): + cw.mu.Lock() + cw.watchErr = errors.Annotate(cw.ctx.Err(), "context canceled") + cw.mu.Unlock() + return + default: + } + + err := cw.watchOnce() + if err == nil { + // Normal exit (context canceled) + return + } + + // Check if this is a non-recoverable error + if errors.Is(err, errChangefeedKeyDeleted) { + cw.mu.Lock() + cw.watchErr = err + cw.failPendingTasksLocked() + cw.mu.Unlock() + return + } + + // Log and retry with backoff + log.Warn("checkpoint watcher encountered error, will retry", + zap.String("changefeedID", cw.changefeedID.String()), + zap.Duration("backoff", backoff), + zap.Error(err)) + + select { + case <-cw.ctx.Done(): + cw.mu.Lock() + cw.watchErr = errors.Annotate(cw.ctx.Err(), "context canceled") + cw.mu.Unlock() + return + case <-time.After(backoff): + } + + // Increase backoff for next retry (exponential backoff with cap) + backoff = time.Duration(float64(backoff) * retryBackoffMultiplier) + backoff = min(backoff, retryBackoffMax) + } +} + +// watchOnce performs one watch cycle. 
Returns nil if context is canceled, +// returns error if watch fails and should be retried. +func (cw *CheckpointWatcher) watchOnce() error { + // First, get the current checkpoint status from etcd + status, modRev, err := cw.etcdClient.GetChangeFeedStatus(cw.ctx, cw.changefeedID) + if err != nil { + // Check if context is canceled + if cw.ctx.Err() != nil { + return nil + } + return errors.Annotate(err, "failed to get changefeed status") + } + + // Update latest checkpoint + cw.mu.Lock() + cw.latestCheckpoint = status.CheckpointTs + cw.notifyPendingTasksLocked() + cw.mu.Unlock() + + statusKey := etcd.GetEtcdKeyJob(cw.etcdClient.GetClusterID(), cw.changefeedID.DisplayName) + + log.Debug("Starting to watch checkpoint", + zap.String("changefeedID", cw.changefeedID.String()), + zap.String("statusKey", statusKey), + zap.String("localClusterID", cw.localClusterID), + zap.String("replicatedClusterID", cw.replicatedClusterID), + zap.Uint64("checkpoint", status.CheckpointTs), + zap.Int64("startRev", modRev+1)) + + watchCh := cw.etcdClient.GetEtcdClient().Watch( + cw.ctx, + statusKey, + "checkpoint-watcher", + clientv3.WithRev(modRev+1), + ) + + for { + select { + case <-cw.ctx.Done(): + return nil + case watchResp, ok := <-watchCh: + if !ok { + return errors.Errorf("[changefeedID: %s] watch channel closed", cw.changefeedID.String()) + } + + if err := watchResp.Err(); err != nil { + return errors.Annotatef(err, "[changefeedID: %s] watch error", cw.changefeedID.String()) + } + + for _, event := range watchResp.Events { + if event.Type == clientv3.EventTypeDelete { + return errors.Annotatef(errChangefeedKeyDeleted, "[changefeedID: %s]", cw.changefeedID.String()) + } + + // Parse the updated status + newStatus := &config.ChangeFeedStatus{} + if err := newStatus.Unmarshal(event.Kv.Value); err != nil { + log.Warn("failed to unmarshal changefeed status, skipping", + zap.String("changefeedID", cw.changefeedID.String()), + zap.Error(err)) + continue + } + + checkpointTs := 
newStatus.CheckpointTs + log.Debug("Checkpoint updated", + zap.String("changefeedID", cw.changefeedID.String()), + zap.Uint64("checkpoint", checkpointTs)) + + // Update latest checkpoint and notify pending tasks + cw.mu.Lock() + if checkpointTs > cw.latestCheckpoint { + cw.latestCheckpoint = checkpointTs + cw.notifyPendingTasksLocked() + } + cw.mu.Unlock() + } + } + } +} + +// notifyPendingTasksLocked notifies pending tasks whose minCheckpointTs has been exceeded +// Must be called with mu locked +func (cw *CheckpointWatcher) notifyPendingTasksLocked() { + remaining := cw.pendingTasks[:0] + for _, task := range cw.pendingTasks { + if cw.latestCheckpoint > task.minCheckpointTs { + // Non-blocking send since channel has buffer of 1 + select { + case task.respCh <- cw.latestCheckpoint: + default: + } + } else { + remaining = append(remaining, task) + } + } + cw.pendingTasks = remaining +} diff --git a/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go new file mode 100644 index 0000000000..4862ce2990 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go @@ -0,0 +1,601 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package watcher + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pingcap/errors" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/etcd" + "github.com/stretchr/testify/require" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" +) + +func TestCheckpointWatcher_AdvanceCheckpointTs_AlreadyExceeds(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + // Create a watch channel that won't send anything during this test + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Request checkpoint that's already exceeded + minCheckpointTs := uint64(500) + checkpoint, err := watcher.AdvanceCheckpointTs(t.Context(), minCheckpointTs) + require.NoError(t, err) + require.Equal(t, initialCheckpoint, checkpoint) +} + +func TestCheckpointWatcher_AdvanceCheckpointTs_WaitForUpdate(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + updatedCheckpoint := uint64(2000) + + // Setup 
mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Start waiting for checkpoint in a goroutine + var checkpoint uint64 + var advanceErr error + done := make(chan struct{}) + go func() { + checkpoint, advanceErr = watcher.AdvanceCheckpointTs(context.Background(), uint64(1500)) + close(done) + }() + + // Give some time for the task to be registered + time.Sleep(50 * time.Millisecond) + + // Simulate checkpoint update via watch channel + newStatus := &config.ChangeFeedStatus{CheckpointTs: updatedCheckpoint} + statusStr, err := newStatus.Marshal() + require.NoError(t, err) + + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypePut, + Kv: &mvccpb.KeyValue{ + Value: []byte(statusStr), + }, + }, + }, + } + + select { + case <-done: + require.NoError(t, advanceErr) + require.Equal(t, updatedCheckpoint, checkpoint) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for checkpoint advance") + } +} + +func TestCheckpointWatcher_AdvanceCheckpointTs_ContextCanceled(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + 
mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Create a context that will be canceled + advanceCtx, advanceCancel := context.WithCancel(t.Context()) + + var advanceErr error + done := make(chan struct{}) + go func() { + _, advanceErr = watcher.AdvanceCheckpointTs(advanceCtx, uint64(2000)) + close(done) + }() + + // Give some time for the task to be registered + time.Sleep(50 * time.Millisecond) + + // Cancel the context + advanceCancel() + + select { + case <-done: + require.Error(t, advanceErr) + require.Contains(t, advanceErr.Error(), "context canceled") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for context cancellation") + } +} + +func TestCheckpointWatcher_Close(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + 
mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Start waiting for checkpoint in a goroutine + var advanceErr error + done := make(chan struct{}) + go func() { + _, advanceErr = watcher.AdvanceCheckpointTs(context.Background(), uint64(2000)) + close(done) + }() + + // Give some time for the task to be registered + time.Sleep(50 * time.Millisecond) + + // Close the watcher + watcher.Close() + + select { + case <-done: + require.Error(t, advanceErr) + // Error can be "closed" or "canceled" depending on timing + errMsg := advanceErr.Error() + require.True(t, + strings.Contains(errMsg, "closed") || strings.Contains(errMsg, "canceled"), + "expected error to contain 'closed' or 'canceled', got: %s", errMsg) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for watcher close") + } + + // Verify that subsequent calls return error + _, err := watcher.AdvanceCheckpointTs(context.Background(), uint64(100)) + require.Error(t, err) + // Error can be "closed" or "canceled" depending on timing + errMsg := err.Error() + require.True(t, + strings.Contains(errMsg, "closed") || strings.Contains(errMsg, "canceled"), + "expected error to contain 'closed' or 'canceled', got: %s", errMsg) +} + +func TestCheckpointWatcher_MultiplePendingTasks(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + 
mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 10) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Start multiple goroutines waiting for different checkpoints + var wg sync.WaitGroup + results := make([]struct { + checkpoint uint64 + err error + }, 3) + + for i := range 3 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + minTs := uint64(1100 + idx*500) // 1100, 1600, 2100 + results[idx].checkpoint, results[idx].err = watcher.AdvanceCheckpointTs(context.Background(), minTs) + }(i) + } + + // Give some time for tasks to be registered + time.Sleep(50 * time.Millisecond) + + // Send checkpoint updates + checkpoints := []uint64{1500, 2000, 2500} + for _, cp := range checkpoints { + newStatus := &config.ChangeFeedStatus{CheckpointTs: cp} + statusStr, err := newStatus.Marshal() + require.NoError(t, err) + + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypePut, + Kv: &mvccpb.KeyValue{ + Value: []byte(statusStr), + }, + }, + }, + } + // Give some time between updates + time.Sleep(20 * time.Millisecond) + } + + // Wait for all goroutines to complete + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Verify results + for i := range 3 { + require.NoError(t, results[i].err, "task %d should not have error", i) + minTs := uint64(1100 + i*500) + require.Greater(t, results[i].checkpoint, minTs, "task %d checkpoint should exceed minTs", i) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for all 
tasks to complete") + } +} + +func TestCheckpointWatcher_NotifyPendingTasksLocked(t *testing.T) { + cw := &CheckpointWatcher{ + latestCheckpoint: 2000, + pendingTasks: []*waitCheckpointTask{ + {respCh: make(chan uint64, 1), minCheckpointTs: 1000}, + {respCh: make(chan uint64, 1), minCheckpointTs: 1500}, + {respCh: make(chan uint64, 1), minCheckpointTs: 2500}, + {respCh: make(chan uint64, 1), minCheckpointTs: 3000}, + }, + } + + cw.notifyPendingTasksLocked() + + // Tasks with minCheckpointTs < 2000 should be notified and removed + require.Len(t, cw.pendingTasks, 2) + require.Equal(t, uint64(2500), cw.pendingTasks[0].minCheckpointTs) + require.Equal(t, uint64(3000), cw.pendingTasks[1].minCheckpointTs) +} + +func TestCheckpointWatcher_InitialCheckpointNotifiesPendingTasks(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(5000) + + // Setup mock expectations - initial checkpoint is high enough + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize and get the initial checkpoint + time.Sleep(100 * time.Millisecond) + + // Request checkpoint that's already exceeded by initial checkpoint + checkpoint, err := watcher.AdvanceCheckpointTs(context.Background(), uint64(1000)) + require.NoError(t, err) + 
require.Equal(t, initialCheckpoint, checkpoint) +} + +func TestCheckpointWatcher_WatchErrorRetry(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + retryCheckpoint := uint64(2000) + + // First call succeeds, second call (retry) also succeeds with updated checkpoint + firstCall := mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: retryCheckpoint}, + int64(101), + nil, + ).After(firstCall) + + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + // First watch channel will be closed (simulating error), second watch channel will work + watchCh1 := make(chan clientv3.WatchResponse) + watchCh2 := make(chan clientv3.WatchResponse, 1) + firstWatch := mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh2).After(firstWatch) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Close the first watch channel to trigger retry + close(watchCh1) + + // Wait for retry to happen (backoff + processing time) + time.Sleep(700 * time.Millisecond) + + // After retry, checkpoint should be updated to retryCheckpoint + checkpoint, err := watcher.AdvanceCheckpointTs(t.Context(), uint64(1500)) + require.NoError(t, err) + require.Equal(t, retryCheckpoint, 
checkpoint) +} + +func TestCheckpointWatcher_GetStatusRetry(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + successCheckpoint := uint64(2000) + + // First call fails, second call succeeds + firstCall := mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + nil, + int64(0), + errors.Errorf("connection refused"), + ) + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: successCheckpoint}, + int64(100), + nil, + ).After(firstCall) + + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for retry to happen (backoff + processing time) + time.Sleep(700 * time.Millisecond) + + // After retry, checkpoint should be available + checkpoint, err := watcher.AdvanceCheckpointTs(t.Context(), uint64(1000)) + require.NoError(t, err) + require.Equal(t, successCheckpoint, checkpoint) +} + +func TestCheckpointWatcher_KeyDeleted(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + 
mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Send delete event + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypeDelete, + }, + }, + } + + // Give time for the error to be processed + time.Sleep(50 * time.Millisecond) + + // Now trying to advance should return an error + _, err := watcher.AdvanceCheckpointTs(context.Background(), uint64(2000)) + require.Error(t, err) + require.Contains(t, err.Error(), "deleted") +} + +func TestCheckpointWatcher_KeyDeleted_UnblocksPendingTasks(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + time.Sleep(50 * time.Millisecond) + + var ( + checkpoint uint64 + advanceErr error + ) + done := make(chan struct{}) + go func() { + checkpoint, advanceErr = 
watcher.AdvanceCheckpointTs(context.Background(), uint64(2000)) + close(done) + }() + + time.Sleep(50 * time.Millisecond) + + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypeDelete, + }, + }, + } + + select { + case <-done: + require.Zero(t, checkpoint) + require.Error(t, advanceErr) + require.Contains(t, advanceErr.Error(), "deleted") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for pending task to be unblocked") + } +} diff --git a/cmd/multi-cluster-consistency-checker/watcher/s3_watcher.go b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher.go new file mode 100644 index 0000000000..b6e7d92b62 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher.go @@ -0,0 +1,70 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package watcher + +import ( + "context" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/consumer" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/pingcap/tidb/br/pkg/storage" +) + +type S3Watcher struct { + checkpointWatcher Watcher + consumer *consumer.S3Consumer +} + +func NewS3Watcher( + checkpointWatcher Watcher, + s3Storage storage.ExternalStorage, + tables map[string][]string, +) *S3Watcher { + consumer := consumer.NewS3Consumer(s3Storage, tables) + return &S3Watcher{ + checkpointWatcher: checkpointWatcher, + consumer: consumer, + } +} + +func (sw *S3Watcher) Close() { + sw.checkpointWatcher.Close() +} + +func (sw *S3Watcher) InitializeFromCheckpoint(ctx context.Context, clusterID string, checkpoint *recorder.Checkpoint) (map[cloudstorage.DmlPathKey]types.IncrementalData, error) { + return sw.consumer.InitializeFromCheckpoint(ctx, clusterID, checkpoint) +} + +func (sw *S3Watcher) AdvanceS3CheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + checkpointTs, err := sw.checkpointWatcher.AdvanceCheckpointTs(ctx, minCheckpointTs) + if err != nil { + return 0, errors.Annotate(err, "advance s3 checkpoint timestamp failed") + } + + return checkpointTs, nil +} + +func (sw *S3Watcher) ConsumeNewFiles( + ctx context.Context, +) (map[cloudstorage.DmlPathKey]types.IncrementalData, map[types.SchemaTableKey]types.VersionKey, error) { + // TODO: get the index updated from the s3 + newData, maxVersionMap, err := sw.consumer.ConsumeNewFiles(ctx) + if err != nil { + return nil, nil, errors.Annotate(err, "consume new files failed") + } + return newData, maxVersionMap, nil +} diff --git a/cmd/multi-cluster-consistency-checker/watcher/s3_watcher_test.go b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher_test.go new file mode 
100644 index 0000000000..b3ecfa16d6 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher_test.go @@ -0,0 +1,202 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package watcher + +import ( + "context" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/consumer" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/stretchr/testify/require" +) + +// mockWatcher is a mock implementation of the Watcher interface for testing. 
+type mockWatcher struct { + advanceCheckpointTsFn func(ctx context.Context, minCheckpointTs uint64) (uint64, error) + closeFn func() + closed bool +} + +func (m *mockWatcher) AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + if m.advanceCheckpointTsFn != nil { + return m.advanceCheckpointTsFn(ctx, minCheckpointTs) + } + return 0, nil +} + +func (m *mockWatcher) Close() { + m.closed = true + if m.closeFn != nil { + m.closeFn() + } +} + +func TestS3Watcher_Close(t *testing.T) { + t.Parallel() + + t.Run("close delegates to checkpoint watcher", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + sw.Close() + require.True(t, mock.closed) + }) + + t.Run("close calls custom close function", func(t *testing.T) { + t.Parallel() + closeCalled := false + mock := &mockWatcher{ + closeFn: func() { + closeCalled = true + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + sw.Close() + require.True(t, closeCalled) + }) +} + +func TestS3Watcher_AdvanceS3CheckpointTs(t *testing.T) { + t.Parallel() + + t.Run("advance checkpoint ts success", func(t *testing.T) { + t.Parallel() + expectedCheckpoint := uint64(5000) + mock := &mockWatcher{ + advanceCheckpointTsFn: func(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + require.Equal(t, uint64(3000), minCheckpointTs) + return expectedCheckpoint, nil + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint, err := sw.AdvanceS3CheckpointTs(context.Background(), uint64(3000)) + require.NoError(t, err) + require.Equal(t, expectedCheckpoint, checkpoint) + }) + + t.Run("advance checkpoint ts error", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{ + advanceCheckpointTsFn: func(ctx context.Context, 
minCheckpointTs uint64) (uint64, error) { + return 0, errors.Errorf("connection lost") + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint, err := sw.AdvanceS3CheckpointTs(context.Background(), uint64(3000)) + require.Error(t, err) + require.Equal(t, uint64(0), checkpoint) + require.Contains(t, err.Error(), "advance s3 checkpoint timestamp failed") + require.Contains(t, err.Error(), "connection lost") + }) + + t.Run("advance checkpoint ts with context canceled", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{ + advanceCheckpointTsFn: func(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + return 0, context.Canceled + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint, err := sw.AdvanceS3CheckpointTs(context.Background(), uint64(3000)) + require.Error(t, err) + require.Equal(t, uint64(0), checkpoint) + require.Contains(t, err.Error(), "advance s3 checkpoint timestamp failed") + }) +} + +func TestS3Watcher_InitializeFromCheckpoint(t *testing.T) { + t.Parallel() + + t.Run("nil checkpoint returns nil", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + result, err := sw.InitializeFromCheckpoint(context.Background(), "cluster1", nil) + require.NoError(t, err) + require.Nil(t, result) + }) + + t.Run("empty checkpoint returns nil", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint := recorder.NewCheckpoint() + result, err := sw.InitializeFromCheckpoint(context.Background(), "cluster1", checkpoint) + require.NoError(t, err) + require.Nil(t, result) + }) +} + +func TestS3Watcher_ConsumeNewFiles(t 
*testing.T) { + t.Parallel() + + t.Run("consume new files with empty tables", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), map[string][]string{}), + } + + newData, maxVersionMap, err := sw.ConsumeNewFiles(context.Background()) + require.NoError(t, err) + require.Empty(t, newData) + require.Empty(t, maxVersionMap) + }) + + t.Run("consume new files with nil tables", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + newData, maxVersionMap, err := sw.ConsumeNewFiles(context.Background()) + require.NoError(t, err) + require.Empty(t, newData) + require.Empty(t, maxVersionMap) + }) +} diff --git a/pkg/sink/cloudstorage/path_key.go b/pkg/sink/cloudstorage/path_key.go index d948af9873..3de0512400 100644 --- a/pkg/sink/cloudstorage/path_key.go +++ b/pkg/sink/cloudstorage/path_key.go @@ -162,3 +162,72 @@ func (d *DmlPathKey) ParseIndexFilePath(dateSeparator, path string) (string, err return matches[6], nil } + +// ParseDMLFilePath parses the dml file path and returns the max file index. +// DML file path pattern is as follows: +// {schema}/{table}/{table-version-separator}/{partition-separator}/{date-separator}/, where +// partition-separator and date-separator could be empty. 
+// DML file name pattern is as follows: CDC_{dispatcherID}_{num}.extension or CDC{num}.extension +func (d *DmlPathKey) ParseDMLFilePath(dateSeparator, path string) (*FileIndex, error) { + var partitionNum int64 + + str := `(\w+)\/(\w+)\/(\d+)\/(\d+)?\/*` + switch dateSeparator { + case config.DateSeparatorNone.String(): + str += `(\d{4})*` + case config.DateSeparatorYear.String(): + str += `(\d{4})\/` + case config.DateSeparatorMonth.String(): + str += `(\d{4}-\d{2})\/` + case config.DateSeparatorDay.String(): + str += `(\d{4}-\d{2}-\d{2})\/` + } + matchesLen := 8 + matchesFileIdx := 7 + // CDC[_{dispatcherID}_]{num}.extension + str += `CDC(?:_(\w+)_)?(\d+).\w+` + pathRE, err := regexp.Compile(str) + if err != nil { + return nil, err + } + + matches := pathRE.FindStringSubmatch(path) + if len(matches) != matchesLen { + return nil, fmt.Errorf("cannot match dml path pattern for %s", path) + } + + version, err := strconv.ParseUint(matches[3], 10, 64) + if err != nil { + return nil, err + } + + if len(matches[4]) > 0 { + partitionNum, err = strconv.ParseInt(matches[4], 10, 64) + if err != nil { + return nil, err + } + } + + fileIdx, err := strconv.ParseUint(strings.TrimLeft(matches[matchesFileIdx], "0"), 10, 64) + if err != nil { + return nil, err + } + + *d = DmlPathKey{ + SchemaPathKey: SchemaPathKey{ + Schema: matches[1], + Table: matches[2], + TableVersion: version, + }, + PartitionNum: partitionNum, + Date: matches[5], + } + + return &FileIndex{ + FileIndexKey: FileIndexKey{ + DispatcherID: matches[6], + EnableTableAcrossNodes: matches[6] != "", + }, + Idx: fileIdx, + }, nil +} diff --git a/pkg/sink/cloudstorage/path_key_test.go b/pkg/sink/cloudstorage/path_key_test.go index f768f7cf74..f59145df7b 100644 --- a/pkg/sink/cloudstorage/path_key_test.go +++ b/pkg/sink/cloudstorage/path_key_test.go @@ -59,7 +59,7 @@ func TestSchemaPathKey(t *testing.T) { } } -func TestDmlPathKey(t *testing.T) { +func TestIndexFileKey(t *testing.T) { t.Parallel() dispatcherID := 
common.NewDispatcherID() @@ -107,3 +107,42 @@ func TestDmlPathKey(t *testing.T) { require.Equal(t, tc.path, fileName) } } + +func TestDmlPathKey(t *testing.T) { + t.Parallel() + + dispatcherID := common.NewDispatcherID() + testCases := []struct { + index int + fileIndexWidth int + extension string + path string + dmlkey DmlPathKey + }{ + { + index: 10, + fileIndexWidth: 20, + extension: ".csv", + path: fmt.Sprintf("schema1/table1/123456/2023-05-09/CDC_%s_00000000000000000010.csv", dispatcherID.String()), + dmlkey: DmlPathKey{ + SchemaPathKey: SchemaPathKey{ + Schema: "schema1", + Table: "table1", + TableVersion: 123456, + }, + PartitionNum: 0, + Date: "2023-05-09", + }, + }, + } + + for _, tc := range testCases { + var dmlkey DmlPathKey + fileIndex, err := dmlkey.ParseDMLFilePath("day", tc.path) + require.NoError(t, err) + require.Equal(t, tc.dmlkey, dmlkey) + + fileName := dmlkey.GenerateDMLFilePath(fileIndex, tc.extension, tc.fileIndexWidth) + require.Equal(t, tc.path, fileName) + } +}