From f334d70ef76263c5fdc01065616c144895c0dcbe Mon Sep 17 00:00:00 2001 From: Jianjun Liao <36503113+Leavrth@users.noreply.github.com> Date: Tue, 24 Feb 2026 12:43:29 +0800 Subject: [PATCH 01/10] Support multi-cluster active-active data consistency checker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces a consistency validation tool for multi-cluster active-active TiCDC deployments. It connects to each cluster’s PD(etcd) and S3 sink, verifies critical changefeed settings (for example canal-json, date-separator=day, and expected file index width), then continuously watches checkpoints and sink files to advance a unified time window across clusters. Within each window, the checker compares local writes and replicated records by primary key and timestamp to detect anomalies such as missing data, redundant data, out-of-order updates, and LWW violations. The recorder persists checkpoints and reports atomically, so the process is resumable after failures and suitable for long-running, auditable consistency monitoring. 
--- Makefile | 3 + .../advancer/time_window_advancer.go | 302 ++++++ .../advancer/time_window_advancer_test.go | 226 +++++ .../checker/checker.go | 657 +++++++++++++ .../checker/checker_test.go | 780 +++++++++++++++ .../config/config.example.toml | 46 + .../config/config.go | 152 +++ .../config/config_test.go | 393 ++++++++ .../consumer/consumer.go | 769 +++++++++++++++ .../consumer/consumer_test.go | 706 ++++++++++++++ .../decoder/decoder.go | 454 +++++++++ .../decoder/decoder_test.go | 309 ++++++ .../decoder/value_to_datum_test.go | 898 ++++++++++++++++++ .../integration/integration_test.go | 732 ++++++++++++++ .../integration/mock_cluster.go | 205 ++++ cmd/multi-cluster-consistency-checker/main.go | 160 ++++ .../main_test.go | 192 ++++ .../recorder/recorder.go | 266 ++++++ .../recorder/recorder_test.go | 561 +++++++++++ .../recorder/types.go | 410 ++++++++ .../recorder/types_test.go | 639 +++++++++++++ cmd/multi-cluster-consistency-checker/task.go | 330 +++++++ .../types/types.go | 98 ++ .../types/types_test.go | 69 ++ .../watcher/checkpoint_watcher.go | 300 ++++++ .../watcher/checkpoint_watcher_test.go | 543 +++++++++++ .../watcher/s3_watcher.go | 70 ++ .../watcher/s3_watcher_test.go | 202 ++++ 28 files changed, 10472 insertions(+) create mode 100644 cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go create mode 100644 cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go create mode 100644 cmd/multi-cluster-consistency-checker/checker/checker.go create mode 100644 cmd/multi-cluster-consistency-checker/checker/checker_test.go create mode 100644 cmd/multi-cluster-consistency-checker/config/config.example.toml create mode 100644 cmd/multi-cluster-consistency-checker/config/config.go create mode 100644 cmd/multi-cluster-consistency-checker/config/config_test.go create mode 100644 cmd/multi-cluster-consistency-checker/consumer/consumer.go create mode 100644 cmd/multi-cluster-consistency-checker/consumer/consumer_test.go 
create mode 100644 cmd/multi-cluster-consistency-checker/decoder/decoder.go create mode 100644 cmd/multi-cluster-consistency-checker/decoder/decoder_test.go create mode 100644 cmd/multi-cluster-consistency-checker/decoder/value_to_datum_test.go create mode 100644 cmd/multi-cluster-consistency-checker/integration/integration_test.go create mode 100644 cmd/multi-cluster-consistency-checker/integration/mock_cluster.go create mode 100644 cmd/multi-cluster-consistency-checker/main.go create mode 100644 cmd/multi-cluster-consistency-checker/main_test.go create mode 100644 cmd/multi-cluster-consistency-checker/recorder/recorder.go create mode 100644 cmd/multi-cluster-consistency-checker/recorder/recorder_test.go create mode 100644 cmd/multi-cluster-consistency-checker/recorder/types.go create mode 100644 cmd/multi-cluster-consistency-checker/recorder/types_test.go create mode 100644 cmd/multi-cluster-consistency-checker/task.go create mode 100644 cmd/multi-cluster-consistency-checker/types/types.go create mode 100644 cmd/multi-cluster-consistency-checker/types/types_test.go create mode 100644 cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go create mode 100644 cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go create mode 100644 cmd/multi-cluster-consistency-checker/watcher/s3_watcher.go create mode 100644 cmd/multi-cluster-consistency-checker/watcher/s3_watcher_test.go diff --git a/Makefile b/Makefile index 2dd4c8a8a1..1fbc1bd73b 100644 --- a/Makefile +++ b/Makefile @@ -174,6 +174,9 @@ filter_helper: config-converter: $(GOBUILD) -ldflags '$(LDFLAGS)' -o bin/cdc_config_converter ./cmd/config-converter/main.go +multi-cluster-consistency-checker: + $(GOBUILD) -ldflags '$(LDFLAGS)' -o bin/multi-cluster-consistency-checker ./cmd/multi-cluster-consistency-checker + fmt: tools/bin/gofumports tools/bin/shfmt tools/bin/gci @echo "run gci (format imports)" tools/bin/gci write $(FILES) 2>&1 | $(FAIL_ON_STDOUT) diff --git 
a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go new file mode 100644 index 0000000000..afa93bec49 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go @@ -0,0 +1,302 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package advancer + +import ( + "context" + "maps" + "sync" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/tikv/client-go/v2/oracle" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +type TimeWindowAdvancer struct { + // round is the current round of the time window + round uint64 + + // timeWindowTriplet is the triplet of adjacent time windows, mapping from cluster ID to the triplet + timeWindowTriplet map[string][3]types.TimeWindow + + // checkpointWatcher is the Active-Active checkpoint watcher for each cluster, + // mapping from local cluster ID to replicated cluster ID to the checkpoint watcher + checkpointWatcher map[string]map[string]watcher.Watcher + + // s3checkpointWatcher is the S3 checkpoint watcher for each cluster, mapping from cluster ID to the s3 checkpoint watcher + s3Watcher 
map[string]*watcher.S3Watcher + + // pdClients is the pd clients for each cluster, mapping from cluster ID to the pd client + pdClients map[string]pd.Client +} + +func NewTimeWindowAdvancer( + ctx context.Context, + checkpointWatchers map[string]map[string]watcher.Watcher, + s3Watchers map[string]*watcher.S3Watcher, + pdClients map[string]pd.Client, + checkpoint *recorder.Checkpoint, +) (*TimeWindowAdvancer, map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, error) { + timeWindowTriplet := make(map[string][3]types.TimeWindow) + for clusterID := range pdClients { + timeWindowTriplet[clusterID] = [3]types.TimeWindow{} + } + advancer := &TimeWindowAdvancer{ + round: 0, + timeWindowTriplet: timeWindowTriplet, + checkpointWatcher: checkpointWatchers, + s3Watcher: s3Watchers, + pdClients: pdClients, + } + newDataMap, err := advancer.initializeFromCheckpoint(ctx, checkpoint) + if err != nil { + return nil, nil, errors.Trace(err) + } + return advancer, newDataMap, nil +} + +func (t *TimeWindowAdvancer) initializeFromCheckpoint( + ctx context.Context, + checkpoint *recorder.Checkpoint, +) (map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, error) { + if checkpoint == nil { + return nil, nil + } + if checkpoint.CheckpointItems[2] == nil { + return nil, nil + } + t.round = checkpoint.CheckpointItems[2].Round + 1 + for clusterID := range t.timeWindowTriplet { + newTimeWindows := [3]types.TimeWindow{} + newTimeWindows[2] = checkpoint.CheckpointItems[2].ClusterInfo[clusterID].TimeWindow + if checkpoint.CheckpointItems[1] != nil { + newTimeWindows[1] = checkpoint.CheckpointItems[1].ClusterInfo[clusterID].TimeWindow + } + if checkpoint.CheckpointItems[0] != nil { + newTimeWindows[0] = checkpoint.CheckpointItems[0].ClusterInfo[clusterID].TimeWindow + } + t.timeWindowTriplet[clusterID] = newTimeWindows + } + + var mu sync.Mutex + newDataMap := make(map[string]map[cloudstorage.DmlPathKey]types.IncrementalData) + eg, egCtx := errgroup.WithContext(ctx) + for 
clusterID, s3Watcher := range t.s3Watcher { + eg.Go(func() error { + newData, err := s3Watcher.InitializeFromCheckpoint(egCtx, clusterID, checkpoint) + if err != nil { + return errors.Trace(err) + } + mu.Lock() + newDataMap[clusterID] = newData + mu.Unlock() + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + return newDataMap, nil +} + +// AdvanceTimeWindow advances the time window for each cluster. Here are the steps: +// 1. Advance the checkpoint ts for each local-to-replicated changefeed. +// +// For any local-to-replicated changefeed, the checkpoint ts should be advanced to +// the maximum of pd timestamp after previous time window of the replicated cluster +// advanced and the right boundary of previous time window of every cluster. +// +// 2. Advance the right boundary for each cluster. +// +// For any cluster, the right boundary should be advanced to the maximum of pd timestamp of +// the cluster after the checkpoint ts of its local cluster advanced and the previous +// time window's checkpoint ts of changefeed where the cluster is the local or the replicated cluster. +// +// 3. Update the time window for each cluster. +// +// For any cluster, the time window should be updated to the new time window. 
+func (t *TimeWindowAdvancer) AdvanceTimeWindow( + pctx context.Context, +) (map[string]types.TimeWindowData, error) { + log.Debug("advance time window", zap.Uint64("round", t.round)) + // mapping from local cluster ID to replicated cluster ID to the min checkpoint timestamp + minCheckpointTsMap := make(map[string]map[string]uint64) + maxTimeWindowRightBoundary := uint64(0) + for replicatedClusterID, triplet := range t.timeWindowTriplet { + for localClusterID, pdTimestampAfterTimeWindow := range triplet[2].PDTimestampAfterTimeWindow { + if _, ok := minCheckpointTsMap[localClusterID]; !ok { + minCheckpointTsMap[localClusterID] = make(map[string]uint64) + } + minCheckpointTsMap[localClusterID][replicatedClusterID] = max(minCheckpointTsMap[localClusterID][replicatedClusterID], pdTimestampAfterTimeWindow) + } + maxTimeWindowRightBoundary = max(maxTimeWindowRightBoundary, triplet[2].RightBoundary) + } + + var lock sync.Mutex + newTimeWindow := make(map[string]types.TimeWindow) + maxPDTimestampAfterCheckpointTs := make(map[string]uint64) + // for cluster ID, the max checkpoint timestamp is maximum of checkpoint from cluster to other clusters and checkpoint from other clusters to cluster + maxCheckpointTs := make(map[string]uint64) + // Advance the checkpoint ts for each cluster + eg, ctx := errgroup.WithContext(pctx) + for localClusterID, replicatedCheckpointWatcherMap := range t.checkpointWatcher { + for replicatedClusterID, checkpointWatcher := range replicatedCheckpointWatcherMap { + minCheckpointTs := max(minCheckpointTsMap[localClusterID][replicatedClusterID], maxTimeWindowRightBoundary) + eg.Go(func() error { + checkpointTs, err := checkpointWatcher.AdvanceCheckpointTs(ctx, minCheckpointTs) + if err != nil { + return errors.Trace(err) + } + // TODO: optimize this by getting pd ts in the end of all checkpoint ts advance + pdtsos, err := t.getPDTsFromOtherClusters(ctx, localClusterID) + if err != nil { + return errors.Trace(err) + } + lock.Lock() + timeWindow := 
newTimeWindow[localClusterID] + if timeWindow.CheckpointTs == nil { + timeWindow.CheckpointTs = make(map[string]uint64) + } + timeWindow.CheckpointTs[replicatedClusterID] = checkpointTs + newTimeWindow[localClusterID] = timeWindow + for otherClusterID, pdtso := range pdtsos { + maxPDTimestampAfterCheckpointTs[otherClusterID] = max(maxPDTimestampAfterCheckpointTs[otherClusterID], pdtso) + } + maxCheckpointTs[localClusterID] = max(maxCheckpointTs[localClusterID], checkpointTs) + maxCheckpointTs[replicatedClusterID] = max(maxCheckpointTs[replicatedClusterID], checkpointTs) + lock.Unlock() + return nil + }) + } + } + if err := eg.Wait(); err != nil { + return nil, errors.Annotate(err, "advance checkpoint timestamp failed") + } + + // Update the time window for each cluster + newDataMap := make(map[string]map[cloudstorage.DmlPathKey]types.IncrementalData) + maxVersionMap := make(map[string]map[types.SchemaTableKey]types.VersionKey) + eg, ctx = errgroup.WithContext(pctx) + for clusterID, triplet := range t.timeWindowTriplet { + minTimeWindowRightBoundary := max(maxCheckpointTs[clusterID], maxPDTimestampAfterCheckpointTs[clusterID], triplet[2].NextMinLeftBoundary) + s3Watcher := t.s3Watcher[clusterID] + eg.Go(func() error { + s3CheckpointTs, err := s3Watcher.AdvanceS3CheckpointTs(ctx, minTimeWindowRightBoundary) + if err != nil { + return errors.Trace(err) + } + newData, maxClusterVersionMap, err := s3Watcher.ConsumeNewFiles(ctx) + if err != nil { + return errors.Trace(err) + } + pdtso, err := t.getPDTsFromCluster(ctx, clusterID) + if err != nil { + return errors.Trace(err) + } + pdtsos, err := t.getPDTsFromOtherClusters(ctx, clusterID) + if err != nil { + return errors.Trace(err) + } + lock.Lock() + newDataMap[clusterID] = newData + maxVersionMap[clusterID] = maxClusterVersionMap + timeWindow := newTimeWindow[clusterID] + timeWindow.LeftBoundary = triplet[2].RightBoundary + timeWindow.RightBoundary = s3CheckpointTs + timeWindow.PDTimestampAfterTimeWindow = 
make(map[string]uint64) + timeWindow.NextMinLeftBoundary = pdtso + maps.Copy(timeWindow.PDTimestampAfterTimeWindow, pdtsos) + newTimeWindow[clusterID] = timeWindow + lock.Unlock() + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Annotate(err, "advance time window failed") + } + t.updateTimeWindow(newTimeWindow) + t.round++ + return newTimeWindowData(newTimeWindow, newDataMap, maxVersionMap), nil +} + +func (t *TimeWindowAdvancer) updateTimeWindow(newTimeWindow map[string]types.TimeWindow) { + for clusterID, timeWindow := range newTimeWindow { + triplet := t.timeWindowTriplet[clusterID] + triplet[0] = triplet[1] + triplet[1] = triplet[2] + triplet[2] = timeWindow + t.timeWindowTriplet[clusterID] = triplet + log.Debug("update time window", zap.String("clusterID", clusterID), zap.Any("timeWindow", timeWindow)) + } +} + +func (t *TimeWindowAdvancer) getPDTsFromCluster(ctx context.Context, clusterID string) (uint64, error) { + pdClient := t.pdClients[clusterID] + phyTs, logicTs, err := pdClient.GetTS(ctx) + if err != nil { + return 0, errors.Trace(err) + } + ts := oracle.ComposeTS(phyTs, logicTs) + return ts, nil +} + +func (t *TimeWindowAdvancer) getPDTsFromOtherClusters(pctx context.Context, clusterID string) (map[string]uint64, error) { + var lock sync.Mutex + pdtsos := make(map[string]uint64) + eg, ctx := errgroup.WithContext(pctx) + for otherClusterID := range t.pdClients { + if otherClusterID == clusterID { + continue + } + pdClient := t.pdClients[otherClusterID] + eg.Go(func() error { + phyTs, logicTs, err := pdClient.GetTS(ctx) + if err != nil { + return errors.Trace(err) + } + ts := oracle.ComposeTS(phyTs, logicTs) + lock.Lock() + pdtsos[otherClusterID] = ts + lock.Unlock() + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + return pdtsos, nil +} + +func newTimeWindowData( + newTimeWindow map[string]types.TimeWindow, + newDataMap 
map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, + maxVersionMap map[string]map[types.SchemaTableKey]types.VersionKey, +) map[string]types.TimeWindowData { + timeWindowDatas := make(map[string]types.TimeWindowData) + for clusterID, timeWindow := range newTimeWindow { + timeWindowDatas[clusterID] = types.TimeWindowData{ + TimeWindow: timeWindow, + Data: newDataMap[clusterID], + MaxVersion: maxVersionMap[clusterID], + } + } + return timeWindowDatas +} diff --git a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go new file mode 100644 index 0000000000..d327f85a7d --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go @@ -0,0 +1,226 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package advancer + +import ( + "context" + "maps" + "sync" + "sync/atomic" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/stretchr/testify/require" + pd "github.com/tikv/pd/client" +) + +// mockPDClient mocks pd.Client for testing. +// Each call to GetTS returns a monotonically increasing TSO (physical part increases by 1000ms per call). 
+type mockPDClient struct { + pd.Client + seq int64 // accessed atomically +} + +func (m *mockPDClient) GetTS(ctx context.Context) (int64, int64, error) { + n := atomic.AddInt64(&m.seq, 1) + // Physical timestamp starts at 11000ms and increases by 1000ms per call. + // oracle.ComposeTS(physical, 0) = physical << 18, so each step is ~262 million. + return 10000 + n*1000, 0, nil +} + +func (m *mockPDClient) Close() {} + +// mockAdvancerWatcher mocks watcher.Watcher for testing. +// Returns minCheckpointTs + delta, ensuring the result is always > minCheckpointTs and monotonically increasing. +type mockAdvancerWatcher struct { + mu sync.Mutex + delta uint64 + history []uint64 +} + +func (m *mockAdvancerWatcher) AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := minCheckpointTs + m.delta + m.history = append(m.history, result) + return result, nil +} + +func (m *mockAdvancerWatcher) Close() {} + +func (m *mockAdvancerWatcher) getHistory() []uint64 { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]uint64, len(m.history)) + copy(out, m.history) + return out +} + +func TestNewTimeWindowAdvancer(t *testing.T) { + checkpointWatchers := map[string]map[string]watcher.Watcher{ + "cluster1": {}, + "cluster2": {}, + } + s3Watchers := map[string]*watcher.S3Watcher{ + "cluster1": nil, + "cluster2": nil, + } + pdClients := map[string]pd.Client{ + "cluster1": nil, + "cluster2": nil, + } + + advancer, _, err := NewTimeWindowAdvancer(context.Background(), checkpointWatchers, s3Watchers, pdClients, nil) + require.NoError(t, err) + require.NotNil(t, advancer) + require.Equal(t, uint64(0), advancer.round) + require.Len(t, advancer.timeWindowTriplet, 2) + require.Contains(t, advancer.timeWindowTriplet, "cluster1") + require.Contains(t, advancer.timeWindowTriplet, "cluster2") +} + +// TestTimeWindowAdvancer_AdvanceMultipleRounds simulates 4 rounds of AdvanceTimeWindow +// with 2 clusters (c1, c2) performing 
bidirectional replication. +// +// The test verifies: +// - Time windows advance correctly (LeftBoundary == previous RightBoundary) +// - RightBoundary > LeftBoundary for each time window +// - Checkpoint timestamps are monotonically increasing across rounds +// - PD TSOs are always greater than checkpoint timestamps +// - NextMinLeftBoundary (PD TSO) > RightBoundary (S3 checkpoint) +// - Mock watcher checkpoint histories are strictly increasing +func TestTimeWindowAdvancer_AdvanceMultipleRounds(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create mock PD clients for each cluster (monotonically increasing TSO) + pdC1 := &mockPDClient{} + pdC2 := &mockPDClient{} + pdClients := map[string]pd.Client{ + "c1": pdC1, + "c2": pdC2, + } + + // Create mock checkpoint watchers for bidirectional replication (c1->c2, c2->c1) + // Each returns minCheckpointTs + 100 + cpWatcherC1C2 := &mockAdvancerWatcher{delta: 100} + cpWatcherC2C1 := &mockAdvancerWatcher{delta: 100} + checkpointWatchers := map[string]map[string]watcher.Watcher{ + "c1": {"c2": cpWatcherC1C2}, + "c2": {"c1": cpWatcherC2C1}, + } + + // Create S3 watchers with mock checkpoint watchers (returns minCheckpointTs + 50) + // and empty in-memory storage (no actual S3 data) + s3WatcherMockC1 := &mockAdvancerWatcher{delta: 50} + s3WatcherMockC2 := &mockAdvancerWatcher{delta: 50} + s3Watchers := map[string]*watcher.S3Watcher{ + "c1": watcher.NewS3Watcher(s3WatcherMockC1, storage.NewMemStorage(), nil), + "c2": watcher.NewS3Watcher(s3WatcherMockC2, storage.NewMemStorage(), nil), + } + + advancer, _, err := NewTimeWindowAdvancer(ctx, checkpointWatchers, s3Watchers, pdClients, nil) + require.NoError(t, err) + require.Equal(t, uint64(0), advancer.round) + + // Track previous round values for cross-round assertions + prevRightBoundaries := map[string]uint64{"c1": 0, "c2": 0} + prevCheckpointTs := map[string]map[string]uint64{ + "c1": {"c2": 0}, + "c2": {"c1": 0}, + } + prevRightBoundary := uint64(0) // max 
across all clusters + + for round := range 4 { + result, err := advancer.AdvanceTimeWindow(ctx) + require.NoError(t, err, "round %d", round) + require.Len(t, result, 2, "round %d: should have data for both clusters", round) + + for clusterID, twData := range result { + tw := twData.TimeWindow + + // 1. LeftBoundary == previous RightBoundary + require.Equal(t, prevRightBoundaries[clusterID], tw.LeftBoundary, + "round %d, cluster %s: LeftBoundary should equal previous RightBoundary", round, clusterID) + + // 2. RightBoundary > LeftBoundary (time window is non-empty) + require.Greater(t, tw.RightBoundary, tw.LeftBoundary, + "round %d, cluster %s: RightBoundary should be > LeftBoundary", round, clusterID) + + // 3. CheckpointTs should be populated and strictly increasing across rounds + require.NotEmpty(t, tw.CheckpointTs, + "round %d, cluster %s: CheckpointTs should be populated", round, clusterID) + for replicatedCluster, cpTs := range tw.CheckpointTs { + require.Greater(t, cpTs, prevCheckpointTs[clusterID][replicatedCluster], + "round %d, %s->%s: checkpoint should be strictly increasing", round, clusterID, replicatedCluster) + } + + // 4. PDTimestampAfterTimeWindow should be populated + require.NotEmpty(t, tw.PDTimestampAfterTimeWindow, + "round %d, cluster %s: PDTimestampAfterTimeWindow should be populated", round, clusterID) + + // 5. NextMinLeftBoundary > RightBoundary + // (PD TSO is obtained after S3 checkpoint, and PD TSO >> S3 checkpoint) + require.Greater(t, tw.NextMinLeftBoundary, tw.RightBoundary, + "round %d, cluster %s: NextMinLeftBoundary (PD TSO) should be > RightBoundary (S3 checkpoint)", round, clusterID) + + // 6. 
PD TSO values in PDTimestampAfterTimeWindow > all CheckpointTs values + // (PD TSOs are obtained after checkpoint advance) + for otherCluster, pdTs := range tw.PDTimestampAfterTimeWindow { + for replicatedCluster, cpTs := range tw.CheckpointTs { + require.Greater(t, pdTs, cpTs, + "round %d, cluster %s: PD TSO (from %s) should be > checkpoint (%s->%s)", + round, clusterID, otherCluster, clusterID, replicatedCluster) + } + } + + // 7. RightBoundary > previous round's max RightBoundary (time window advances) + require.Greater(t, tw.RightBoundary, prevRightBoundary, + "round %d, cluster %s: RightBoundary should be > previous max RightBoundary", round, clusterID) + } + + // Save current values for next round + maxRB := uint64(0) + for clusterID, twData := range result { + prevRightBoundaries[clusterID] = twData.TimeWindow.RightBoundary + if twData.TimeWindow.RightBoundary > maxRB { + maxRB = twData.TimeWindow.RightBoundary + } + maps.Copy(prevCheckpointTs[clusterID], twData.TimeWindow.CheckpointTs) + } + prevRightBoundary = maxRB + } + + // After 4 rounds, round counter should be 4 + require.Equal(t, uint64(4), advancer.round) + + // Verify all mock watcher checkpoint histories are strictly monotonically increasing + allWatchers := []*mockAdvancerWatcher{cpWatcherC1C2, cpWatcherC2C1, s3WatcherMockC1, s3WatcherMockC2} + watcherNames := []string{"cp c1->c2", "cp c2->c1", "s3 c1", "s3 c2"} + for idx, w := range allWatchers { + history := w.getHistory() + require.GreaterOrEqual(t, len(history), 4, + "%s: should have at least 4 checkpoint values (one per round)", watcherNames[idx]) + for i := 1; i < len(history); i++ { + require.Greater(t, history[i], history[i-1], + "%s: checkpoint values should be strictly increasing (index %d: %d -> %d)", + watcherNames[idx], i, history[i-1], history[i]) + } + } + + // Verify PD clients were called (monotonically increasing due to atomic counter) + require.Greater(t, atomic.LoadInt64(&pdC1.seq), int64(0), "pd-c1 should have been called") 
+ require.Greater(t, atomic.LoadInt64(&pdC2.seq), int64(0), "pd-c2 should have been called") +} diff --git a/cmd/multi-cluster-consistency-checker/checker/checker.go b/cmd/multi-cluster-consistency-checker/checker/checker.go new file mode 100644 index 0000000000..1365b9142b --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/checker/checker.go @@ -0,0 +1,657 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package checker + +import ( + "context" + "sort" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/decoder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "go.uber.org/zap" +) + +type versionCacheEntry struct { + previous int + cdcVersion types.CdcVersion +} + +type clusterViolationChecker struct { + clusterID string + twoPreviousTimeWindowKeyVersionCache map[string]map[types.PkType]versionCacheEntry +} + +func newClusterViolationChecker(clusterID string) *clusterViolationChecker { + return &clusterViolationChecker{ + clusterID: clusterID, + twoPreviousTimeWindowKeyVersionCache: make(map[string]map[types.PkType]versionCacheEntry), + } +} + +func (c *clusterViolationChecker) NewRecordFromCheckpoint(schemaKey string, record *decoder.Record, previous int) { + 
tableSchemaKeyVersionCache, exists := c.twoPreviousTimeWindowKeyVersionCache[schemaKey] + if !exists { + tableSchemaKeyVersionCache = make(map[types.PkType]versionCacheEntry) + c.twoPreviousTimeWindowKeyVersionCache[schemaKey] = tableSchemaKeyVersionCache + } + entry, exists := tableSchemaKeyVersionCache[record.Pk] + if !exists { + tableSchemaKeyVersionCache[record.Pk] = versionCacheEntry{ + previous: previous, + cdcVersion: record.CdcVersion, + } + return + } + entryCompareTs := entry.cdcVersion.GetCompareTs() + recordCompareTs := record.GetCompareTs() + if entryCompareTs < recordCompareTs { + tableSchemaKeyVersionCache[record.Pk] = versionCacheEntry{ + previous: previous, + cdcVersion: record.CdcVersion, + } + } +} + +func (c *clusterViolationChecker) Check(schemaKey string, r *decoder.Record, report *recorder.ClusterReport) { + tableSchemaKeyVersionCache, exists := c.twoPreviousTimeWindowKeyVersionCache[schemaKey] + if !exists { + tableSchemaKeyVersionCache = make(map[types.PkType]versionCacheEntry) + c.twoPreviousTimeWindowKeyVersionCache[schemaKey] = tableSchemaKeyVersionCache + } + entry, exists := tableSchemaKeyVersionCache[r.Pk] + if !exists { + tableSchemaKeyVersionCache[r.Pk] = versionCacheEntry{ + previous: 0, + cdcVersion: r.CdcVersion, + } + return + } + if entry.cdcVersion.CommitTs >= r.CommitTs { + // duplicated old version, just skip it + return + } + entryCompareTs := entry.cdcVersion.GetCompareTs() + recordCompareTs := r.GetCompareTs() + if entryCompareTs >= recordCompareTs { + // violation detected + log.Error("LWW violation detected", + zap.String("clusterID", c.clusterID), + zap.Any("entry", entry), + zap.String("pk", r.PkStr)) + report.AddLWWViolationItem(schemaKey, r.PkMap, r.PkStr, entry.cdcVersion.OriginTs, entry.cdcVersion.CommitTs, r.OriginTs, r.CommitTs) + return + } + tableSchemaKeyVersionCache[r.Pk] = versionCacheEntry{ + previous: 0, + cdcVersion: r.CdcVersion, + } +} + +func (c *clusterViolationChecker) UpdateCache() { + 
newTwoPreviousTimeWindowKeyVersionCache := make(map[string]map[types.PkType]versionCacheEntry) + for schemaKey, tableSchemaKeyVersionCache := range c.twoPreviousTimeWindowKeyVersionCache { + newTableSchemaKeyVersionCache := make(map[types.PkType]versionCacheEntry) + for primaryKey, entry := range tableSchemaKeyVersionCache { + if entry.previous >= 2 { + continue + } + newTableSchemaKeyVersionCache[primaryKey] = versionCacheEntry{ + previous: entry.previous + 1, + cdcVersion: entry.cdcVersion, + } + } + if len(newTableSchemaKeyVersionCache) > 0 { + newTwoPreviousTimeWindowKeyVersionCache[schemaKey] = newTableSchemaKeyVersionCache + } + } + c.twoPreviousTimeWindowKeyVersionCache = newTwoPreviousTimeWindowKeyVersionCache +} + +type tableDataCache struct { + // localDataCache is a map of primary key to a map of commit ts to a record + localDataCache map[types.PkType]map[uint64]*decoder.Record + + // replicatedDataCache is a map of primary key to a map of origin ts to a record + replicatedDataCache map[types.PkType]map[uint64]*decoder.Record +} + +func newTableDataCache() *tableDataCache { + return &tableDataCache{ + localDataCache: make(map[types.PkType]map[uint64]*decoder.Record), + replicatedDataCache: make(map[types.PkType]map[uint64]*decoder.Record), + } +} + +func (tdc *tableDataCache) newLocalRecord(record *decoder.Record) { + recordsMap, exists := tdc.localDataCache[record.Pk] + if !exists { + recordsMap = make(map[uint64]*decoder.Record) + tdc.localDataCache[record.Pk] = recordsMap + } + recordsMap[record.CommitTs] = record +} + +func (tdc *tableDataCache) newReplicatedRecord(record *decoder.Record) { + recordsMap, exists := tdc.replicatedDataCache[record.Pk] + if !exists { + recordsMap = make(map[uint64]*decoder.Record) + tdc.replicatedDataCache[record.Pk] = recordsMap + } + recordsMap[record.OriginTs] = record +} + +type timeWindowDataCache struct { + tableDataCaches map[string]*tableDataCache + + leftBoundary uint64 + rightBoundary uint64 + checkpointTs 
map[string]uint64 +} + +func newTimeWindowDataCache(leftBoundary, rightBoundary uint64, checkpointTs map[string]uint64) timeWindowDataCache { + return timeWindowDataCache{ + tableDataCaches: make(map[string]*tableDataCache), + leftBoundary: leftBoundary, + rightBoundary: rightBoundary, + checkpointTs: checkpointTs, + } +} + +func (twdc *timeWindowDataCache) NewRecord(schemaKey string, record *decoder.Record) { + if record.CommitTs <= twdc.leftBoundary { + // record is before the left boundary, just skip it + return + } + tableDataCache, exists := twdc.tableDataCaches[schemaKey] + if !exists { + tableDataCache = newTableDataCache() + twdc.tableDataCaches[schemaKey] = tableDataCache + } + if record.OriginTs == 0 { + tableDataCache.newLocalRecord(record) + } else { + tableDataCache.newReplicatedRecord(record) + } +} + +type clusterDataChecker struct { + clusterID string + + thisRoundTimeWindow types.TimeWindow + + timeWindowDataCaches [3]timeWindowDataCache + + rightBoundary uint64 + + overDataCaches map[string][]*decoder.Record + + clusterViolationChecker *clusterViolationChecker + + report *recorder.ClusterReport + + lwwSkippedRecordsCount int + checkedRecordsCount int + newTimeWindowRecordsCount int +} + +func newClusterDataChecker(clusterID string) *clusterDataChecker { + return &clusterDataChecker{ + clusterID: clusterID, + timeWindowDataCaches: [3]timeWindowDataCache{}, + rightBoundary: 0, + overDataCaches: make(map[string][]*decoder.Record), + clusterViolationChecker: newClusterViolationChecker(clusterID), + } +} + +func (cd *clusterDataChecker) InitializeFromCheckpoint( + ctx context.Context, + checkpointDataMap map[cloudstorage.DmlPathKey]types.IncrementalData, + checkpoint *recorder.Checkpoint, +) error { + if checkpoint == nil { + return nil + } + if checkpoint.CheckpointItems[2] == nil { + return nil + } + clusterInfo := checkpoint.CheckpointItems[2].ClusterInfo[cd.clusterID] + cd.rightBoundary = clusterInfo.TimeWindow.RightBoundary + 
cd.timeWindowDataCaches[2] = newTimeWindowDataCache( + clusterInfo.TimeWindow.LeftBoundary, clusterInfo.TimeWindow.RightBoundary, clusterInfo.TimeWindow.CheckpointTs) + if checkpoint.CheckpointItems[1] != nil { + clusterInfo = checkpoint.CheckpointItems[1].ClusterInfo[cd.clusterID] + cd.timeWindowDataCaches[1] = newTimeWindowDataCache( + clusterInfo.TimeWindow.LeftBoundary, clusterInfo.TimeWindow.RightBoundary, clusterInfo.TimeWindow.CheckpointTs) + } + for schemaPathKey, incrementalData := range checkpointDataMap { + schemaKey := schemaPathKey.GetKey() + for _, contents := range incrementalData.DataContentSlices { + for _, content := range contents { + records, err := decoder.Decode(content, incrementalData.ColumnFieldTypes) + if err != nil { + return errors.Trace(err) + } + for _, record := range records { + cd.newRecordFromCheckpoint(schemaKey, record) + } + } + } + } + return nil +} + +func (cd *clusterDataChecker) newRecordFromCheckpoint(schemaKey string, record *decoder.Record) { + if record.CommitTs > cd.rightBoundary { + cd.overDataCaches[schemaKey] = append(cd.overDataCaches[schemaKey], record) + return + } + if cd.timeWindowDataCaches[2].leftBoundary < record.CommitTs { + cd.timeWindowDataCaches[2].NewRecord(schemaKey, record) + cd.clusterViolationChecker.NewRecordFromCheckpoint(schemaKey, record, 1) + + } else if cd.timeWindowDataCaches[1].leftBoundary < record.CommitTs { + cd.timeWindowDataCaches[1].NewRecord(schemaKey, record) + cd.clusterViolationChecker.NewRecordFromCheckpoint(schemaKey, record, 2) + } +} + +func (cd *clusterDataChecker) PrepareNextTimeWindowData(timeWindow types.TimeWindow) error { + if timeWindow.LeftBoundary != cd.rightBoundary { + return errors.Errorf("time window left boundary(%d) mismatch right boundary ts(%d)", timeWindow.LeftBoundary, cd.rightBoundary) + } + cd.timeWindowDataCaches[0] = cd.timeWindowDataCaches[1] + cd.timeWindowDataCaches[1] = cd.timeWindowDataCaches[2] + newTimeWindowDataCache := 
newTimeWindowDataCache(timeWindow.LeftBoundary, timeWindow.RightBoundary, timeWindow.CheckpointTs) + cd.rightBoundary = timeWindow.RightBoundary + newOverDataCache := make(map[string][]*decoder.Record) + for schemaKey, overRecords := range cd.overDataCaches { + newTableOverDataCache := make([]*decoder.Record, 0, len(overRecords)) + for _, overRecord := range overRecords { + if overRecord.CommitTs > timeWindow.RightBoundary { + newTableOverDataCache = append(newTableOverDataCache, overRecord) + } else { + newTimeWindowDataCache.NewRecord(schemaKey, overRecord) + } + } + newOverDataCache[schemaKey] = newTableOverDataCache + } + cd.timeWindowDataCaches[2] = newTimeWindowDataCache + cd.overDataCaches = newOverDataCache + cd.lwwSkippedRecordsCount = 0 + cd.checkedRecordsCount = 0 + cd.newTimeWindowRecordsCount = 0 + return nil +} + +func (cd *clusterDataChecker) NewRecord(schemaKey string, record *decoder.Record) { + if record.CommitTs > cd.rightBoundary { + cd.overDataCaches[schemaKey] = append(cd.overDataCaches[schemaKey], record) + return + } + cd.timeWindowDataCaches[2].NewRecord(schemaKey, record) +} + +func (cd *clusterDataChecker) findClusterReplicatedDataInTimeWindow(timeWindowIdx int, schemaKey string, pk types.PkType, originTs uint64) (*decoder.Record, bool) { + tableDataCache, exists := cd.timeWindowDataCaches[timeWindowIdx].tableDataCaches[schemaKey] + if !exists { + return nil, false + } + records, exists := tableDataCache.replicatedDataCache[pk] + if !exists { + return nil, false + } + if record, exists := records[originTs]; exists { + return record, false + } + for _, record := range records { + if record.GetCompareTs() >= originTs { + return nil, true + } + } + return nil, false +} + +func (cd *clusterDataChecker) findClusterLocalDataInTimeWindow(timeWindowIdx int, schemaKey string, pk types.PkType, commitTs uint64) bool { + tableDataCache, exists := cd.timeWindowDataCaches[timeWindowIdx].tableDataCaches[schemaKey] + if !exists { + return false + } + 
records, exists := tableDataCache.localDataCache[pk] + if !exists { + return false + } + _, exists = records[commitTs] + return exists +} + +// diffColumns compares column values between local written and replicated records +// and returns the list of inconsistent columns. +func diffColumns(local, replicated *decoder.Record) []recorder.InconsistentColumn { + var result []recorder.InconsistentColumn + for colName, localVal := range local.ColumnValues { + replicatedVal, ok := replicated.ColumnValues[colName] + if !ok { + result = append(result, recorder.InconsistentColumn{ + Column: colName, + Local: localVal, + Replicated: nil, + }) + } else if localVal != replicatedVal { // safe: ColumnValues only holds comparable types (see decoder.go) + result = append(result, recorder.InconsistentColumn{ + Column: colName, + Local: localVal, + Replicated: replicatedVal, + }) + } + } + for colName, replicatedVal := range replicated.ColumnValues { + if _, ok := local.ColumnValues[colName]; !ok { + result = append(result, recorder.InconsistentColumn{ + Column: colName, + Local: nil, + Replicated: replicatedVal, + }) + } + } + sort.Slice(result, func(i, j int) bool { + return result[i].Column < result[j].Column + }) + return result +} + +// datalossDetection iterates through the local-written data cache [1] and [2] and filter out the records +// whose checkpoint ts falls within the (checkpoint[1], checkpoint[2]]. The record must be present +// in the replicated data cache [1] or [2] or another new record is present in the replicated data +// cache [1] or [2]. 
+func (cd *clusterDataChecker) dataLossDetection(checker *DataChecker) { + // Time window [1]: skip records whose commitTs <= checkpoint (already checked in previous round) + cd.checkLocalRecordsForDataLoss(1, func(commitTs, checkpointTs uint64) bool { + return commitTs <= checkpointTs + }, checker) + // Time window [2]: skip records whose commitTs > checkpoint (will be checked in next round) + cd.checkLocalRecordsForDataLoss(2, func(commitTs, checkpointTs uint64) bool { + return commitTs > checkpointTs + }, checker) +} + +// checkLocalRecordsForDataLoss iterates through the local-written data cache at timeWindowIdx +// and checks each record against the replicated data cache. Records for which shouldSkip returns +// true are skipped. This helper unifies the logic for time windows [1] and [2]. +func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( + timeWindowIdx int, + shouldSkip func(commitTs, checkpointTs uint64) bool, + checker *DataChecker, +) { + for schemaKey, tableDataCache := range cd.timeWindowDataCaches[timeWindowIdx].tableDataCaches { + for _, localDataCache := range tableDataCache.localDataCache { + for _, record := range localDataCache { + for replicatedClusterID, checkpointTs := range cd.timeWindowDataCaches[timeWindowIdx].checkpointTs { + if shouldSkip(record.CommitTs, checkpointTs) { + continue + } + cd.checkedRecordsCount++ + replicatedRecord, skipped := checker.FindClusterReplicatedData(replicatedClusterID, schemaKey, record.Pk, record.CommitTs) + if skipped { + log.Debug("replicated record skipped by LWW", + zap.String("local cluster ID", cd.clusterID), + zap.String("replicated cluster ID", replicatedClusterID), + zap.String("schemaKey", schemaKey), + zap.String("pk", record.PkStr), + zap.Uint64("commitTs", record.CommitTs)) + cd.lwwSkippedRecordsCount++ + continue + } + if replicatedRecord == nil { + // data loss detected + log.Error("data loss detected", + zap.String("local cluster ID", cd.clusterID), + zap.String("replicated cluster 
ID", replicatedClusterID), + zap.Any("record", record)) + cd.report.AddDataLossItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, record.CommitTs) + } else if !record.EqualReplicatedRecord(replicatedRecord) { + // data inconsistent detected + log.Error("data inconsistent detected", + zap.String("local cluster ID", cd.clusterID), + zap.String("replicated cluster ID", replicatedClusterID), + zap.Any("record", record)) + cd.report.AddDataInconsistentItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, replicatedRecord.OriginTs, record.CommitTs, replicatedRecord.CommitTs, diffColumns(record, replicatedRecord)) + } + } + } + } + } +} + +// dataRedundantDetection iterates through the replicated data cache [2]. The record must be present +// in the local data cache [1] [2] or [3]. +func (cd *clusterDataChecker) dataRedundantDetection(checker *DataChecker) { + for schemaKey, tableDataCache := range cd.timeWindowDataCaches[2].tableDataCaches { + for _, replicatedDataCache := range tableDataCache.replicatedDataCache { + for _, record := range replicatedDataCache { + cd.checkedRecordsCount++ + // For replicated records, OriginTs is the local commit ts + if !checker.FindSourceLocalData(cd.clusterID, schemaKey, record.Pk, record.OriginTs) { + // data redundant detected + log.Error("data redundant detected", + zap.String("replicated cluster ID", cd.clusterID), + zap.Any("record", record)) + cd.report.AddDataRedundantItem(schemaKey, record.PkMap, record.PkStr, record.OriginTs, record.CommitTs) + } + } + } + } +} + +// lwwViolationDetection check the orderliness of the records +func (cd *clusterDataChecker) lwwViolationDetection() { + for schemaKey, tableDataCache := range cd.timeWindowDataCaches[2].tableDataCaches { + for pk, localRecords := range tableDataCache.localDataCache { + replicatedRecords := tableDataCache.replicatedDataCache[pk] + pkRecords := make([]*decoder.Record, 0, len(localRecords)+len(replicatedRecords)) + for _, localRecord := range 
localRecords { + pkRecords = append(pkRecords, localRecord) + } + for _, replicatedRecord := range replicatedRecords { + pkRecords = append(pkRecords, replicatedRecord) + } + sort.Slice(pkRecords, func(i, j int) bool { + return pkRecords[i].CommitTs < pkRecords[j].CommitTs + }) + for _, record := range pkRecords { + cd.newTimeWindowRecordsCount++ + cd.clusterViolationChecker.Check(schemaKey, record, cd.report) + } + } + + for pk, replicatedRecords := range tableDataCache.replicatedDataCache { + if _, exists := tableDataCache.localDataCache[pk]; exists { + continue + } + pkRecords := make([]*decoder.Record, 0, len(replicatedRecords)) + for _, replicatedRecord := range replicatedRecords { + pkRecords = append(pkRecords, replicatedRecord) + } + sort.Slice(pkRecords, func(i, j int) bool { + return pkRecords[i].CommitTs < pkRecords[j].CommitTs + }) + for _, record := range pkRecords { + cd.newTimeWindowRecordsCount++ + cd.clusterViolationChecker.Check(schemaKey, record, cd.report) + } + } + } + + cd.clusterViolationChecker.UpdateCache() +} + +func (cd *clusterDataChecker) Check(checker *DataChecker, enableDataLoss, enableDataRedundant bool) { + cd.report = recorder.NewClusterReport(cd.clusterID, cd.thisRoundTimeWindow) + if enableDataLoss { + // CHECK 1 - Data Loss / Inconsistency Detection (round 2+) + // Needs [1] and [2] populated. + cd.dataLossDetection(checker) + } + if enableDataRedundant { + // CHECK 2 - Data Redundant Detection (round 3+) + // Needs [0], [1] and [2] all populated with real data; + // at round 2 [0] is still round 0 (empty), which would cause false positives. + cd.dataRedundantDetection(checker) + } + // CHECK 3 - LWW Violation Detection + // Always runs to keep the version cache up-to-date; meaningful results + // start from round 1 once the cache has been seeded. 
+ cd.lwwViolationDetection() +} + +func (cd *clusterDataChecker) GetReport() *recorder.ClusterReport { + return cd.report +} + +type DataChecker struct { + round uint64 + checkableRound uint64 + clusterDataCheckers map[string]*clusterDataChecker +} + +func NewDataChecker(ctx context.Context, clusterConfig map[string]config.ClusterConfig, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) *DataChecker { + clusterDataChecker := make(map[string]*clusterDataChecker) + for clusterID := range clusterConfig { + clusterDataChecker[clusterID] = newClusterDataChecker(clusterID) + } + checker := &DataChecker{ + round: 0, + checkableRound: 0, + clusterDataCheckers: clusterDataChecker, + } + checker.initializeFromCheckpoint(ctx, checkpointDataMap, checkpoint) + return checker +} + +func (c *DataChecker) initializeFromCheckpoint(ctx context.Context, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) { + if checkpoint == nil { + return + } + if checkpoint.CheckpointItems[2] == nil { + return + } + c.round = checkpoint.CheckpointItems[2].Round + 1 + c.checkableRound = checkpoint.CheckpointItems[2].Round + 1 + for _, clusterDataChecker := range c.clusterDataCheckers { + clusterDataChecker.InitializeFromCheckpoint(ctx, checkpointDataMap[clusterDataChecker.clusterID], checkpoint) + } +} + +// FindClusterReplicatedData checks whether the record is present in the replicated data +// cache [1] or [2] or another new record is present in the replicated data cache [1] or [2]. 
+func (c *DataChecker) FindClusterReplicatedData(clusterID string, schemaKey string, pk types.PkType, originTs uint64) (*decoder.Record, bool) { + clusterDataChecker, exists := c.clusterDataCheckers[clusterID] + if !exists { + return nil, false + } + record, skipped := clusterDataChecker.findClusterReplicatedDataInTimeWindow(1, schemaKey, pk, originTs) + if skipped || record != nil { + return record, skipped + } + return clusterDataChecker.findClusterReplicatedDataInTimeWindow(2, schemaKey, pk, originTs) +} + +func (c *DataChecker) FindSourceLocalData(localClusterID string, schemaKey string, pk types.PkType, commitTs uint64) bool { + for _, clusterDataChecker := range c.clusterDataCheckers { + if clusterDataChecker.clusterID == localClusterID { + continue + } + if clusterDataChecker.findClusterLocalDataInTimeWindow(0, schemaKey, pk, commitTs) { + return true + } + if clusterDataChecker.findClusterLocalDataInTimeWindow(1, schemaKey, pk, commitTs) { + return true + } + if clusterDataChecker.findClusterLocalDataInTimeWindow(2, schemaKey, pk, commitTs) { + return true + } + } + return false +} + +func (c *DataChecker) CheckInNextTimeWindow(newTimeWindowData map[string]types.TimeWindowData) (*recorder.Report, error) { + if err := c.decodeNewTimeWindowData(newTimeWindowData); err != nil { + log.Error("failed to decode new time window data", zap.Error(err)) + return nil, errors.Annotate(err, "failed to decode new time window data") + } + report := recorder.NewReport(c.round) + + // Round 0: seed the LWW cache (round 0 data is empty by convention). + // Round 1+: LWW violation detection produces meaningful results. + // Round 2+: data loss / inconsistency detection (needs [1] and [2]). + // Round 3+: data redundant detection (needs [0], [1] and [2] with real data). 
+ enableDataLoss := c.checkableRound >= 2 + enableDataRedundant := c.checkableRound >= 3 + + for clusterID, clusterDataChecker := range c.clusterDataCheckers { + clusterDataChecker.Check(c, enableDataLoss, enableDataRedundant) + log.Info("checked records count", + zap.String("clusterID", clusterID), + zap.Uint64("round", c.round), + zap.Bool("enableDataLoss", enableDataLoss), + zap.Bool("enableDataRedundant", enableDataRedundant), + zap.Int("checked records count", clusterDataChecker.checkedRecordsCount), + zap.Int("new time window records count", clusterDataChecker.newTimeWindowRecordsCount), + zap.Int("lww skipped records count", clusterDataChecker.lwwSkippedRecordsCount)) + report.AddClusterReport(clusterID, clusterDataChecker.GetReport()) + } + + if c.checkableRound < 3 { + c.checkableRound++ + } + c.round++ + return report, nil +} + +func (c *DataChecker) decodeNewTimeWindowData(newTimeWindowData map[string]types.TimeWindowData) error { + if len(newTimeWindowData) != len(c.clusterDataCheckers) { + return errors.Errorf("number of clusters mismatch, expected %d, got %d", len(c.clusterDataCheckers), len(newTimeWindowData)) + } + for clusterID, timeWindowData := range newTimeWindowData { + clusterDataChecker, exists := c.clusterDataCheckers[clusterID] + if !exists { + return errors.Errorf("cluster %s not found", clusterID) + } + clusterDataChecker.thisRoundTimeWindow = timeWindowData.TimeWindow + if err := clusterDataChecker.PrepareNextTimeWindowData(timeWindowData.TimeWindow); err != nil { + return errors.Trace(err) + } + for schemaPathKey, incrementalData := range timeWindowData.Data { + schemaKey := schemaPathKey.GetKey() + for _, contents := range incrementalData.DataContentSlices { + for _, content := range contents { + records, err := decoder.Decode(content, incrementalData.ColumnFieldTypes) + if err != nil { + return errors.Trace(err) + } + for _, record := range records { + clusterDataChecker.NewRecord(schemaKey, record) + } + } + } + } + } + + return nil 
+} diff --git a/cmd/multi-cluster-consistency-checker/checker/checker_test.go b/cmd/multi-cluster-consistency-checker/checker/checker_test.go new file mode 100644 index 0000000000..18453f35c2 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/checker/checker_test.go @@ -0,0 +1,780 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package checker + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/decoder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/stretchr/testify/require" +) + +func TestNewDataChecker(t *testing.T) { + t.Parallel() + + t.Run("create data checker", func(t *testing.T) { + t.Parallel() + clusterConfig := map[string]config.ClusterConfig{ + "cluster1": { + PDAddrs: []string{"127.0.0.1:2379"}, + S3SinkURI: "s3://bucket/cluster1/", + S3ChangefeedID: "s3-cf-1", + }, + "cluster2": { + PDAddrs: []string{"127.0.0.1:2479"}, + S3SinkURI: "s3://bucket/cluster2/", + S3ChangefeedID: "s3-cf-2", + }, + } + + checker := NewDataChecker(context.Background(), clusterConfig, nil, nil) + require.NotNil(t, checker) + require.Equal(t, uint64(0), checker.round) + require.Len(t, checker.clusterDataCheckers, 2) + require.Contains(t, checker.clusterDataCheckers, "cluster1") + 
require.Contains(t, checker.clusterDataCheckers, "cluster2") + }) +} + +func TestNewClusterDataChecker(t *testing.T) { + t.Parallel() + + t.Run("create cluster data checker", func(t *testing.T) { + t.Parallel() + checker := newClusterDataChecker("cluster1") + require.NotNil(t, checker) + require.Equal(t, "cluster1", checker.clusterID) + require.Equal(t, uint64(0), checker.rightBoundary) + require.NotNil(t, checker.timeWindowDataCaches) + require.NotNil(t, checker.overDataCaches) + require.NotNil(t, checker.clusterViolationChecker) + }) +} + +func TestNewClusterViolationChecker(t *testing.T) { + t.Parallel() + + t.Run("create cluster violation checker", func(t *testing.T) { + t.Parallel() + checker := newClusterViolationChecker("cluster1") + require.NotNil(t, checker) + require.Equal(t, "cluster1", checker.clusterID) + require.NotNil(t, checker.twoPreviousTimeWindowKeyVersionCache) + }) +} + +func TestClusterViolationChecker_Check(t *testing.T) { + t.Parallel() + + const schemaKey = "test_schema" + + t.Run("check new record", func(t *testing.T) { + t.Parallel() + checker := newClusterViolationChecker("cluster1") + report := recorder.NewClusterReport("cluster1", types.TimeWindow{}) + + record := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 100, + OriginTs: 0, + }, + } + + checker.Check(schemaKey, record, report) + require.Empty(t, report.TableFailureItems) + require.Contains(t, checker.twoPreviousTimeWindowKeyVersionCache, schemaKey) + require.Contains(t, checker.twoPreviousTimeWindowKeyVersionCache[schemaKey], record.Pk) + }) + + t.Run("check duplicate old version", func(t *testing.T) { + t.Parallel() + checker := newClusterViolationChecker("cluster1") + report := recorder.NewClusterReport("cluster1", types.TimeWindow{}) + + record1 := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 100, + OriginTs: 0, + }, + } + record2 := &decoder.Record{ 
+ Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 50, + OriginTs: 0, + }, + } + + checker.Check(schemaKey, record1, report) + checker.Check(schemaKey, record2, report) + require.Empty(t, report.TableFailureItems) // Should skip duplicate old version + }) + + t.Run("check lww violation", func(t *testing.T) { + t.Parallel() + checker := newClusterViolationChecker("cluster1") + report := recorder.NewClusterReport("cluster1", types.TimeWindow{}) + + record1 := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 100, + OriginTs: 0, + }, + } + record2 := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 150, + OriginTs: 50, // OriginTs is less than record1's CommitTs, causing violation + }, + } + + checker.Check(schemaKey, record1, report) + checker.Check(schemaKey, record2, report) + require.Len(t, report.TableFailureItems, 1) + require.Contains(t, report.TableFailureItems, schemaKey) + tableItems := report.TableFailureItems[schemaKey] + require.Len(t, tableItems.LWWViolationItems, 1) + require.Equal(t, map[string]any{"id": "1"}, tableItems.LWWViolationItems[0].PK) + require.Equal(t, uint64(0), tableItems.LWWViolationItems[0].ExistingOriginTS) + require.Equal(t, uint64(100), tableItems.LWWViolationItems[0].ExistingCommitTS) + require.Equal(t, uint64(50), tableItems.LWWViolationItems[0].OriginTS) + require.Equal(t, uint64(150), tableItems.LWWViolationItems[0].CommitTS) + }) +} + +func TestClusterViolationChecker_UpdateCache(t *testing.T) { + t.Parallel() + + const schemaKey = "test_schema" + + t.Run("update cache", func(t *testing.T) { + t.Parallel() + checker := newClusterViolationChecker("cluster1") + report := recorder.NewClusterReport("cluster1", types.TimeWindow{}) + + record := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 100, + OriginTs: 0, + 
}, + } + + checker.Check(schemaKey, record, report) + require.Contains(t, checker.twoPreviousTimeWindowKeyVersionCache, schemaKey) + entry := checker.twoPreviousTimeWindowKeyVersionCache[schemaKey][record.Pk] + require.Equal(t, 0, entry.previous) + + checker.UpdateCache() + entry = checker.twoPreviousTimeWindowKeyVersionCache[schemaKey][record.Pk] + require.Equal(t, 1, entry.previous) + + checker.UpdateCache() + entry = checker.twoPreviousTimeWindowKeyVersionCache[schemaKey][record.Pk] + require.Equal(t, 2, entry.previous) + + checker.UpdateCache() + // Entry should be removed after 2 updates + _, exists := checker.twoPreviousTimeWindowKeyVersionCache[schemaKey] + require.False(t, exists) + }) +} + +func TestNewTimeWindowDataCache(t *testing.T) { + t.Parallel() + + t.Run("create time window data cache", func(t *testing.T) { + t.Parallel() + leftBoundary := uint64(100) + rightBoundary := uint64(200) + checkpointTs := map[string]uint64{ + "cluster2": 150, + } + + cache := newTimeWindowDataCache(leftBoundary, rightBoundary, checkpointTs) + require.Equal(t, leftBoundary, cache.leftBoundary) + require.Equal(t, rightBoundary, cache.rightBoundary) + require.Equal(t, checkpointTs, cache.checkpointTs) + require.NotNil(t, cache.tableDataCaches) + }) +} + +func TestTimeWindowDataCache_NewRecord(t *testing.T) { + t.Parallel() + + const schemaKey = "test_schema" + + t.Run("add local record", func(t *testing.T) { + t.Parallel() + cache := newTimeWindowDataCache(100, 200, map[string]uint64{}) + record := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 150, + OriginTs: 0, + }, + } + + cache.NewRecord(schemaKey, record) + require.Contains(t, cache.tableDataCaches, schemaKey) + require.Contains(t, cache.tableDataCaches[schemaKey].localDataCache, record.Pk) + require.Contains(t, cache.tableDataCaches[schemaKey].localDataCache[record.Pk], record.CommitTs) + }) + + t.Run("add replicated record", func(t *testing.T) { + 
t.Parallel() + cache := newTimeWindowDataCache(100, 200, map[string]uint64{}) + record := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 150, + OriginTs: 100, + }, + } + + cache.NewRecord(schemaKey, record) + require.Contains(t, cache.tableDataCaches, schemaKey) + require.Contains(t, cache.tableDataCaches[schemaKey].replicatedDataCache, record.Pk) + require.Contains(t, cache.tableDataCaches[schemaKey].replicatedDataCache[record.Pk], record.OriginTs) + }) + + t.Run("skip record before left boundary", func(t *testing.T) { + t.Parallel() + cache := newTimeWindowDataCache(100, 200, map[string]uint64{}) + record := &decoder.Record{ + Pk: "pk1", + PkMap: map[string]any{"id": "1"}, + CdcVersion: types.CdcVersion{ + CommitTs: 50, + OriginTs: 0, + }, + } + + cache.NewRecord(schemaKey, record) + require.NotContains(t, cache.tableDataCaches, schemaKey) + }) +} + +func TestClusterDataChecker_PrepareNextTimeWindowData(t *testing.T) { + t.Parallel() + + t.Run("prepare next time window data", func(t *testing.T) { + t.Parallel() + checker := newClusterDataChecker("cluster1") + checker.rightBoundary = 100 + + timeWindow := types.TimeWindow{ + LeftBoundary: 100, + RightBoundary: 200, + CheckpointTs: map[string]uint64{"cluster2": 150}, + } + + err := checker.PrepareNextTimeWindowData(timeWindow) + require.NoError(t, err) + require.Equal(t, uint64(200), checker.rightBoundary) + }) + + t.Run("mismatch left boundary", func(t *testing.T) { + t.Parallel() + checker := newClusterDataChecker("cluster1") + checker.rightBoundary = 100 + + timeWindow := types.TimeWindow{ + LeftBoundary: 150, + RightBoundary: 200, + CheckpointTs: map[string]uint64{"cluster2": 150}, + } + + err := checker.PrepareNextTimeWindowData(timeWindow) + require.Error(t, err) + require.Contains(t, err.Error(), "mismatch") + }) +} + +// makeCanalJSON builds a canal-JSON formatted record for testing. 
+// pkID is the primary key value, commitTs is the TiDB commit timestamp, +// originTs is the origin timestamp (0 for locally-written records, non-zero for replicated records), +// val is a varchar column value. +func makeCanalJSON(pkID int, commitTs uint64, originTs uint64, val string) string { + originTsVal := "null" + if originTs > 0 { + originTsVal = fmt.Sprintf(`"%d"`, originTs) + } + return fmt.Sprintf( + `{"id":0,"database":"test","table":"t1","pkNames":["id"],"isDdl":false,"type":"INSERT",`+ + `"es":0,"ts":0,"sql":"","sqlType":{"id":4,"val":12,"_tidb_origin_ts":-5},`+ + `"mysqlType":{"id":"int","val":"varchar","_tidb_origin_ts":"bigint"},`+ + `"old":null,"data":[{"id":"%d","val":"%s","_tidb_origin_ts":%s}],`+ + `"_tidb":{"commitTs":%d}}`, + pkID, val, originTsVal, commitTs) +} + +// makeContent combines canal-JSON records with CRLF terminator. +func makeContent(records ...string) []byte { + return []byte(strings.Join(records, "\r\n")) +} + +// makeTWData builds a TimeWindowData for testing. +func makeTWData(left, right uint64, checkpointTs map[string]uint64, content []byte) types.TimeWindowData { + data := map[cloudstorage.DmlPathKey]types.IncrementalData{} + if content != nil { + data[cloudstorage.DmlPathKey{}] = types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {}: {content}, + }, + } + } + return types.TimeWindowData{ + TimeWindow: types.TimeWindow{ + LeftBoundary: left, + RightBoundary: right, + CheckpointTs: checkpointTs, + }, + Data: data, + } +} + +// defaultSchemaKey is the schema key produced by DmlPathKey{}.GetKey() +// which is QuoteSchema("", "") = "“.“" +var defaultSchemaKey = (&cloudstorage.DmlPathKey{}).GetKey() + +// TestDataChecker_FourRoundsCheck simulates 4 rounds with increasing data and verifies check results. +// Setup: 2 clusters (c1 locally-written, c2 replicated from c1). +// - Round 0: LWW cache is seeded (data is empty by convention). +// - Round 1+: LWW violation detection is active. 
+// - Round 2+: data loss / inconsistent detection is active. +// - Round 3+: data redundant detection is also active (needs [0],[1],[2] all populated). +func TestDataChecker_FourRoundsCheck(t *testing.T) { + t.Parallel() + ctx := context.Background() + + clusterCfg := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + + // makeBaseRounds creates shared rounds 0 and 1 data for all subtests. + // c1 produces locally-written data, c2 receives matching replicated data from c1. + makeBaseRounds := func() [2]map[string]types.TimeWindowData { + return [2]map[string]types.TimeWindowData{ + // Round 0: [0, 100] + { + "c1": makeTWData(0, 100, map[string]uint64{"c2": 80}, + makeContent(makeCanalJSON(1, 50, 0, "a"))), + "c2": makeTWData(0, 100, nil, + makeContent(makeCanalJSON(1, 60, 50, "a"))), + }, + // Round 1: [100, 200] + { + "c1": makeTWData(100, 200, map[string]uint64{"c2": 180}, + makeContent(makeCanalJSON(2, 150, 0, "b"))), + "c2": makeTWData(100, 200, nil, + makeContent(makeCanalJSON(2, 160, 150, "b"))), + }, + } + } + + t.Run("all consistent", func(t *testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + base := makeBaseRounds() + + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 260, 250, "c"))), + } + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + require.Equal(t, uint64(i), report.Round) + // Every round now produces cluster reports (LWW always runs). 
+ require.Len(t, report.ClusterReports, 2, "round %d should have 2 cluster reports", i) + require.False(t, report.NeedFlush(), "round %d should not need flush (all consistent)", i) + for clusterID, cr := range report.ClusterReports { + require.Empty(t, cr.TableFailureItems, "round %d cluster %s should have no table failure items", i, clusterID) + } + } + }) + + t.Run("data loss detected", func(t *testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + base := makeBaseRounds() + + // Round 2: c1 has locally-written pk=3 but c2 has NO matching replicated data + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, nil), + } + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + // c1 should detect data loss: pk=3 (commitTs=250) missing in c2's replicated data + c1Report := lastReport.ClusterReports["c1"] + require.NotNil(t, c1Report) + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + tableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, tableItems.DataLossItems, 1) + require.Equal(t, "c2", tableItems.DataLossItems[0].PeerClusterID) + require.Equal(t, uint64(250), tableItems.DataLossItems[0].CommitTS) + // c2 should have no issues + c2Report := lastReport.ClusterReports["c2"] + require.Empty(t, c2Report.TableFailureItems) + }) + + t.Run("data inconsistent detected", func(t 
*testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + base := makeBaseRounds() + + // Round 2: c2 has replicated data for pk=3 but with wrong column value + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 260, 250, "WRONG"))), + } + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + c1Report := lastReport.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + tableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Empty(t, tableItems.DataLossItems) + require.Len(t, tableItems.DataInconsistentItems, 1) + require.Equal(t, "c2", tableItems.DataInconsistentItems[0].PeerClusterID) + require.Equal(t, uint64(250), tableItems.DataInconsistentItems[0].OriginTS) + require.Equal(t, uint64(250), tableItems.DataInconsistentItems[0].LocalCommitTS) + require.Equal(t, uint64(260), tableItems.DataInconsistentItems[0].ReplicatedCommitTS) + require.Len(t, tableItems.DataInconsistentItems[0].InconsistentColumns, 1) + require.Equal(t, "val", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Column) + require.Equal(t, "c", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Local) + require.Equal(t, "WRONG", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Replicated) + }) + + t.Run("data redundant detected", func(t 
*testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + base := makeBaseRounds() + + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 260, 250, "c"))), + } + // Round 3: c2 has an extra replicated pk=99 (originTs=330) that doesn't match + // any locally-written record in c1 + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent( + makeCanalJSON(4, 360, 350, "d"), + makeCanalJSON(99, 340, 330, "x"), + )), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + // c1 should have no data loss + c1Report := lastReport.ClusterReports["c1"] + require.Empty(t, c1Report.TableFailureItems) + // c2 should detect data redundant: pk=99 has no matching locally-written record in c1 + c2Report := lastReport.ClusterReports["c2"] + require.Contains(t, c2Report.TableFailureItems, defaultSchemaKey) + tableItems := c2Report.TableFailureItems[defaultSchemaKey] + require.Len(t, tableItems.DataRedundantItems, 1) + require.Equal(t, uint64(330), tableItems.DataRedundantItems[0].OriginTS) + require.Equal(t, uint64(340), tableItems.DataRedundantItems[0].CommitTS) + }) + + t.Run("lww violation detected", func(t *testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + base := makeBaseRounds() + + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c"))), + "c2": 
makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 260, 250, "c"))), + } + // Round 3: c1 has locally-written pk=5 (commitTs=350, compareTs=350) and + // replicated pk=5 from c2 (commitTs=370, originTs=310, compareTs=310). + // Since 350 >= 310 with commitTs 350 < 370, this is an LWW violation. + // c2 also has matching records to avoid data loss/redundant noise. + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent( + makeCanalJSON(5, 350, 0, "e"), + makeCanalJSON(5, 370, 310, "e"), + )), + "c2": makeTWData(300, 400, nil, + makeContent( + makeCanalJSON(5, 310, 0, "e"), + makeCanalJSON(5, 360, 350, "e"), + )), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + require.True(t, lastReport.NeedFlush()) + c1Report := lastReport.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + c1TableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, c1TableItems.LWWViolationItems, 1) + require.Equal(t, uint64(0), c1TableItems.LWWViolationItems[0].ExistingOriginTS) + require.Equal(t, uint64(350), c1TableItems.LWWViolationItems[0].ExistingCommitTS) + require.Equal(t, uint64(310), c1TableItems.LWWViolationItems[0].OriginTS) + require.Equal(t, uint64(370), c1TableItems.LWWViolationItems[0].CommitTS) + // c2 should have no LWW violation (its records are ordered correctly: + // locally-written commitTs=310 compareTs=310, replicated commitTs=360 compareTs=350, 310 < 350) + c2Report := lastReport.ClusterReports["c2"] + if c2TableItems, ok := c2Report.TableFailureItems[defaultSchemaKey]; ok { + require.Empty(t, c2TableItems.LWWViolationItems) + } + }) + + // lww violation detected at round 1: LWW is active from round 1, + // so a 
violation introduced in round 1 data should surface immediately. + t.Run("lww violation detected at round 1", func(t *testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + + // Round 0: [0, 100] — c1 writes pk=1 (commitTs=50, compareTs=50) + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, nil, + makeContent(makeCanalJSON(1, 50, 0, "a"))), + "c2": makeTWData(0, 100, nil, nil), + } + // Round 1: [100, 200] — c1 writes pk=1 again: + // locally-written pk=1 (commitTs=150, originTs=0, compareTs=150) + // replicated pk=1 (commitTs=180, originTs=120, compareTs=120) + // The LWW cache already has pk=1 compareTs=50 from round 0. + // Record order: commitTs=150 (compareTs=150 > cached 50 → update cache), + // commitTs=180 (compareTs=120 < cached 150 → VIOLATION) + round1 := map[string]types.TimeWindowData{ + "c1": makeTWData(100, 200, nil, + makeContent( + makeCanalJSON(1, 150, 0, "b"), + makeCanalJSON(1, 180, 120, "b"), + )), + "c2": makeTWData(100, 200, nil, nil), + } + + report0, err := checker.CheckInNextTimeWindow(round0) + require.NoError(t, err) + require.False(t, report0.NeedFlush(), "round 0 should not need flush") + + report1, err := checker.CheckInNextTimeWindow(round1) + require.NoError(t, err) + + // LWW violation should be detected at round 1 + require.True(t, report1.NeedFlush(), "round 1 should detect LWW violation") + c1Report := report1.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + c1TableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, c1TableItems.LWWViolationItems, 1) + require.Equal(t, uint64(0), c1TableItems.LWWViolationItems[0].ExistingOriginTS) + require.Equal(t, uint64(150), c1TableItems.LWWViolationItems[0].ExistingCommitTS) + require.Equal(t, uint64(120), c1TableItems.LWWViolationItems[0].OriginTS) + require.Equal(t, uint64(180), c1TableItems.LWWViolationItems[0].CommitTS) + }) + + // data loss detected at round 2: Data loss 
detection is active from round 2. + // A record in round 1 whose commitTs > checkpointTs will enter [1] at round 2, + // and if the replicated counterpart is missing, data loss is detected at round 2. + t.Run("data loss detected at round 2", func(t *testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, nil, nil), + "c2": makeTWData(0, 100, nil, nil), + } + // Round 1: c1 writes pk=1 (commitTs=150), checkpointTs["c2"]=140 + // Since 150 > 140, this record needs replication checking. + round1 := map[string]types.TimeWindowData{ + "c1": makeTWData(100, 200, map[string]uint64{"c2": 140}, + makeContent(makeCanalJSON(1, 150, 0, "a"))), + "c2": makeTWData(100, 200, nil, nil), // c2 has NO replicated data + } + // Round 2: round 1 data is now in [1], data loss detection enabled. + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 280}, + makeContent(makeCanalJSON(2, 250, 0, "b"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(2, 260, 250, "b"))), + } + + report0, err := checker.CheckInNextTimeWindow(round0) + require.NoError(t, err) + require.False(t, report0.NeedFlush()) + + report1, err := checker.CheckInNextTimeWindow(round1) + require.NoError(t, err) + require.False(t, report1.NeedFlush(), "round 1 should not detect data loss yet") + + report2, err := checker.CheckInNextTimeWindow(round2) + require.NoError(t, err) + require.True(t, report2.NeedFlush(), "round 2 should detect data loss") + c1Report := report2.ClusterReports["c1"] + require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) + tableItems := c1Report.TableFailureItems[defaultSchemaKey] + require.Len(t, tableItems.DataLossItems, 1) + require.Equal(t, "c2", tableItems.DataLossItems[0].PeerClusterID) + require.Equal(t, uint64(150), tableItems.DataLossItems[0].CommitTS) + }) + + // data redundant detected at round 3 (not round 2): + // 
dataRedundantDetection checks timeWindowDataCaches[2] (latest round). + // At round 2 [0]=round 0 (empty) so FindSourceLocalData may miss data in + // that window → enableDataRedundant is false to avoid false positives. + // At round 3 [0]=round 1, [1]=round 2, [2]=round 3 are all populated + // with real data, so enableDataRedundant=true and an orphan in [2] is caught. + // + // This test puts the SAME orphan pk=99 in both round 2 and round 3: + // - Round 2: orphan in [2] but enableDataRedundant=false → NOT flagged. + // - Round 3: orphan in [2] and enableDataRedundant=true → flagged. + t.Run("data redundant detected at round 3 not round 2", func(t *testing.T) { + t.Parallel() + checker := NewDataChecker(ctx, clusterCfg, nil, nil) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, nil, nil), + "c2": makeTWData(0, 100, nil, nil), + } + // Round 1: normal consistent data. + round1 := map[string]types.TimeWindowData{ + "c1": makeTWData(100, 200, map[string]uint64{"c2": 180}, + makeContent(makeCanalJSON(1, 150, 0, "a"))), + "c2": makeTWData(100, 200, nil, + makeContent(makeCanalJSON(1, 160, 150, "a"))), + } + // Round 2: c2 has orphan replicated pk=99 (originTs=230) in [2]. + // enableDataRedundant=false at round 2, so it must NOT be flagged. + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 280}, + makeContent(makeCanalJSON(2, 250, 0, "b"))), + "c2": makeTWData(200, 300, nil, + makeContent( + makeCanalJSON(2, 260, 250, "b"), + makeCanalJSON(99, 240, 230, "x"), // orphan replicated + )), + } + // Round 3: c2 has another orphan replicated pk=99 (originTs=330) in [2]. + // enableDataRedundant=true at round 3, so it IS caught. 
+ round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(3, 350, 0, "c"))), + "c2": makeTWData(300, 400, nil, + makeContent( + makeCanalJSON(3, 360, 350, "c"), + makeCanalJSON(99, 340, 330, "y"), // orphan replicated + )), + } + + report0, err := checker.CheckInNextTimeWindow(round0) + require.NoError(t, err) + require.False(t, report0.NeedFlush(), "round 0 should not need flush") + + report1, err := checker.CheckInNextTimeWindow(round1) + require.NoError(t, err) + require.False(t, report1.NeedFlush(), "round 1 should not need flush") + + report2, err := checker.CheckInNextTimeWindow(round2) + require.NoError(t, err) + // Round 2: redundant detection is NOT enabled; the orphan pk=99 should NOT be flagged. + require.False(t, report2.NeedFlush(), "round 2 should not flag data redundant yet") + + report3, err := checker.CheckInNextTimeWindow(round3) + require.NoError(t, err) + // Round 3: redundant detection is enabled; the orphan pk=99 in [2] (round 3) + // is now caught. 
+ require.True(t, report3.NeedFlush(), "round 3 should detect data redundant") + c2Report := report3.ClusterReports["c2"] + require.Contains(t, c2Report.TableFailureItems, defaultSchemaKey) + c2TableItems := c2Report.TableFailureItems[defaultSchemaKey] + require.Len(t, c2TableItems.DataRedundantItems, 1) + require.Equal(t, uint64(330), c2TableItems.DataRedundantItems[0].OriginTS) + require.Equal(t, uint64(340), c2TableItems.DataRedundantItems[0].CommitTS) + }) +} diff --git a/cmd/multi-cluster-consistency-checker/config/config.example.toml b/cmd/multi-cluster-consistency-checker/config/config.example.toml new file mode 100644 index 0000000000..b5430add8e --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/config/config.example.toml @@ -0,0 +1,46 @@ +# Example configuration file for multi-cluster consistency checker + +# Global configuration +[global] + +# Log level: debug, info, warn, error, fatal, panic +log-level = "info" + +# Data directory configuration, contains report and checkpoint data +data-dir = "/tmp/multi-cluster-consistency-checker-data" + + # Tables configuration + [global.tables] + schema1 = ["table1", "table2"] + schema2 = ["table1", "table2"] + +# Cluster configurations +[clusters] + # First cluster configuration + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket-name/cluster1/" + s3-changefeed-id = "s3-changefeed-id-1" + # security-config = { ca-path = "ca.crt", cert-path = "cert.crt", key-path = "key.crt" } + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "active-active-changefeed-id-from-cluster1-to-cluster2" } + + # Second cluster configuration + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket-name/cluster2/" + s3-changefeed-id = "s3-changefeed-id-2" + # security-config = { ca-path = "ca.crt", cert-path = "cert.crt", key-path = "key.crt" } + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = 
"active-active-changefeed-id-from-cluster2-to-cluster1" } + + # Third cluster configuration (optional) + # [clusters.cluster3] + # pd-addrs = ["127.0.0.1:2579"] + # cdc-addr = "127.0.0.1:8500" + # s3-sink-uri = "s3://bucket-name/cluster3/" + # s3-changefeed-id = "s3-changefeed-id-3" + # security-config = { ca-path = "ca.crt", cert-path = "cert.crt", key-path = "key.crt" } + # [clusters.cluster3.peer-cluster-changefeed-config] + # cluster1 = { changefeed-id = "active-active-changefeed-id-from-cluster3-to-cluster1" } + # cluster2 = { changefeed-id = "active-active-changefeed-id-from-cluster3-to-cluster2" } diff --git a/cmd/multi-cluster-consistency-checker/config/config.go b/cmd/multi-cluster-consistency-checker/config/config.go new file mode 100644 index 0000000000..86fac59af3 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/config/config.go @@ -0,0 +1,152 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package config
+
+import (
+	"fmt"
+	"os"
+
+	"github.com/BurntSushi/toml"
+	"github.com/pingcap/ticdc/pkg/security"
+)
+
+// Config represents the configuration for multi-cluster consistency checker
+type Config struct {
+	// GlobalConfig contains global settings such as logging, data directory, and table filters
+	GlobalConfig GlobalConfig `toml:"global" json:"global"`
+
+	// Clusters contains configurations for multiple clusters
+	Clusters map[string]ClusterConfig `toml:"clusters" json:"clusters"`
+}
+
+const DefaultMaxReportFiles = 1000
+
+// GlobalConfig contains global configuration settings
+type GlobalConfig struct {
+	LogLevel       string              `toml:"log-level" json:"log-level"`
+	DataDir        string              `toml:"data-dir" json:"data-dir"`
+	MaxReportFiles int                 `toml:"max-report-files" json:"max-report-files"`
+	Tables         map[string][]string `toml:"tables" json:"tables"`
+}
+
+type PeerClusterChangefeedConfig struct {
+	// ChangefeedID is the ID of the changefeed that replicates data from this cluster to the peer cluster
+	ChangefeedID string `toml:"changefeed-id" json:"changefeed-id"`
+}
+
+// ClusterConfig represents configuration for a single cluster
+type ClusterConfig struct {
+	// PDAddrs is the addresses of the PD (Placement Driver) servers
+	PDAddrs []string `toml:"pd-addrs" json:"pd-addrs"`
+
+	// S3SinkURI is the S3 sink URI for this cluster
+	S3SinkURI string `toml:"s3-sink-uri" json:"s3-sink-uri"`
+
+	// S3ChangefeedID is the changefeed ID for the S3 changefeed
+	S3ChangefeedID string `toml:"s3-changefeed-id" json:"s3-changefeed-id"`
+
+	// SecurityConfig is the security configuration for the cluster
+	SecurityConfig security.Credential `toml:"security-config" json:"security-config"`
+
+	// PeerClusterChangefeedConfig is the configuration for the changefeed of the peer cluster
+	// mapping from peer cluster ID to the changefeed configuration
+	PeerClusterChangefeedConfig map[string]PeerClusterChangefeedConfig `toml:"peer-cluster-changefeed-config" json:"peer-cluster-changefeed-config"`
+}
+
+// LoadConfig loads the configuration from a
TOML file +func LoadConfig(path string) (*Config, error) { + // Check if file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + return nil, fmt.Errorf("config file does not exist: %s", path) + } + + cfg := &Config{ + Clusters: make(map[string]ClusterConfig), + } + + meta, err := toml.DecodeFile(path, cfg) + if err != nil { + return nil, fmt.Errorf("failed to decode config file: %w", err) + } + + // Apply defaults + if cfg.GlobalConfig.MaxReportFiles <= 0 { + cfg.GlobalConfig.MaxReportFiles = DefaultMaxReportFiles + } + + // Validate DataDir + if cfg.GlobalConfig.DataDir == "" { + return nil, fmt.Errorf("global: data-dir is required") + } + + // Validate Tables + if len(cfg.GlobalConfig.Tables) == 0 { + return nil, fmt.Errorf("global: at least one schema must be configured in tables") + } + for schema, tables := range cfg.GlobalConfig.Tables { + if len(tables) == 0 { + return nil, fmt.Errorf("global: tables[%s]: at least one table must be configured", schema) + } + } + + // Validate that at least one cluster is configured + if len(cfg.Clusters) == 0 { + return nil, fmt.Errorf("at least one cluster must be configured") + } + + // Validate cluster configurations + for name, cluster := range cfg.Clusters { + if len(cluster.PDAddrs) == 0 { + return nil, fmt.Errorf("cluster '%s': pd-addrs is required", name) + } + if cluster.S3SinkURI == "" { + return nil, fmt.Errorf("cluster '%s': s3-sink-uri is required", name) + } + if cluster.S3ChangefeedID == "" { + return nil, fmt.Errorf("cluster '%s': s3-changefeed-id is required", name) + } + if len(cluster.PeerClusterChangefeedConfig) != len(cfg.Clusters)-1 { + return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config is not entirely configured", name) + } + for peerClusterID, peerClusterChangefeedConfig := range cluster.PeerClusterChangefeedConfig { + if peerClusterID == name { + return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config references itself", name) + } + if _, ok := 
cfg.Clusters[peerClusterID]; !ok { + return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config references unknown cluster '%s'", name, peerClusterID) + } + if peerClusterChangefeedConfig.ChangefeedID == "" { + return nil, fmt.Errorf("cluster '%s': peer-cluster-changefeed-config[%s]: changefeed-id is required", name, peerClusterID) + } + } + } + + // Check for unknown configuration keys + if undecoded := meta.Undecoded(); len(undecoded) > 0 { + // Filter out keys under [global] and [clusters] sections + var unknownKeys []string + for _, key := range undecoded { + keyStr := key.String() + // Only warn about keys that are not in the expected sections + if keyStr != "global" && keyStr != "clusters" { + unknownKeys = append(unknownKeys, keyStr) + } + } + if len(unknownKeys) > 0 { + fmt.Fprintf(os.Stderr, "Warning: unknown configuration keys found: %v\n", unknownKeys) + } + } + + return cfg, nil +} diff --git a/cmd/multi-cluster-consistency-checker/config/config_test.go b/cmd/multi-cluster-consistency-checker/config/config_test.go new file mode 100644 index 0000000000..e57f2e6a7b --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/config/config_test.go @@ -0,0 +1,393 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadConfig(t *testing.T) { + t.Parallel() + + t.Run("valid config", func(t *testing.T) { + t.Parallel() + // Create a temporary config file + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1", "table2"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.NoError(t, err) + require.NotNil(t, cfg) + require.Equal(t, "info", cfg.GlobalConfig.LogLevel) + require.Equal(t, "/tmp/data", cfg.GlobalConfig.DataDir) + require.Len(t, cfg.Clusters, 2) + require.Contains(t, cfg.Clusters, "cluster1") + require.Contains(t, cfg.Clusters, "cluster2") + require.Equal(t, []string{"127.0.0.1:2379"}, cfg.Clusters["cluster1"].PDAddrs) + require.Equal(t, "s3://bucket/cluster1/", cfg.Clusters["cluster1"].S3SinkURI) + require.Equal(t, "s3-cf-1", cfg.Clusters["cluster1"].S3ChangefeedID) + require.Len(t, cfg.Clusters["cluster1"].PeerClusterChangefeedConfig, 1) + require.Equal(t, "cf-1-to-2", cfg.Clusters["cluster1"].PeerClusterChangefeedConfig["cluster2"].ChangefeedID) + // max-report-files not set, should default to DefaultMaxReportFiles + require.Equal(t, DefaultMaxReportFiles, cfg.GlobalConfig.MaxReportFiles) + }) + + t.Run("custom max-report-files", func(t *testing.T) { + t.Parallel() + 
tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" +max-report-files = 50 + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.NoError(t, err) + require.Equal(t, 50, cfg.GlobalConfig.MaxReportFiles) + }) + + t.Run("file not exists", func(t *testing.T) { + t.Parallel() + cfg, err := LoadConfig("/nonexistent/path/config.toml") + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "config file does not exist") + }) + + t.Run("invalid toml", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := `invalid toml content [` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "failed to decode config file") + }) + + t.Run("missing data-dir", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + 
[clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "data-dir is required") + }) + + t.Run("missing tables", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "at least one schema must be configured in tables") + }) + + t.Run("empty table list in schema", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = [] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = 
"s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "at least one table must be configured") + }) + + t.Run("no clusters", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "at least one cluster must be configured") + }) + + t.Run("missing pd-addrs", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "pd-addrs is required") + }) + + t.Run("missing s3-sink-uri", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-changefeed-id = "s3-cf-1" +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err 
:= LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "s3-sink-uri is required") + }) + + t.Run("missing s3-changefeed-id", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "s3-changefeed-id is required") + }) + + t.Run("incomplete replicated cluster changefeed config", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "peer-cluster-changefeed-config is not entirely configured") + }) + + t.Run("missing changefeed-id in peer cluster config", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + 
[clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = {} + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "changefeed-id is required") + }) +} diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer.go b/cmd/multi-cluster-consistency-checker/consumer/consumer.go new file mode 100644 index 0000000000..07736f7e67 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/consumer/consumer.go @@ -0,0 +1,769 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
// Package consumer downloads and tracks cloud-storage (S3) sink files
// produced by TiCDC changefeeds: schema definition files and DML data files.
package consumer

import (
	"context"
	"encoding/json"
	"fmt"
	"path"
	"strings"
	"sync"

	perrors "github.com/pingcap/errors"
	"github.com/pingcap/log"
	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder"
	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types"
	"github.com/pingcap/ticdc/pkg/config"
	"github.com/pingcap/ticdc/pkg/errors"
	"github.com/pingcap/ticdc/pkg/sink/cloudstorage"
	"github.com/pingcap/tidb/br/pkg/storage"
	ptypes "github.com/pingcap/tidb/pkg/parser/types"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)

type (
	// fileIndexRange maps a dispatcher's file-index key to the inclusive
	// [start, end] range of file indexes that still need to be downloaded.
	fileIndexRange map[cloudstorage.FileIndexKey]indexRange
	// fileIndexKeyMap maps a dispatcher's file-index key to the largest
	// file index observed so far for that dispatcher.
	fileIndexKeyMap map[cloudstorage.FileIndexKey]uint64
)

// indexRange is an inclusive range of DML file indexes (file indexes start at 1).
type indexRange struct {
	start uint64
	end   uint64
}

// updateTableDMLIdxMap records in tableDMLIdxMap the largest file index seen
// for (dmlkey, fileIdx.FileIndexKey). A new inner map is created on first
// sight of dmlkey; otherwise the stored index is only ever increased.
func updateTableDMLIdxMap(
	tableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap,
	dmlkey cloudstorage.DmlPathKey,
	fileIdx *cloudstorage.FileIndex,
) {
	m, ok := tableDMLIdxMap[dmlkey]
	if !ok {
		tableDMLIdxMap[dmlkey] = fileIndexKeyMap{
			fileIdx.FileIndexKey: fileIdx.Idx,
		}
	} else if fileIdx.Idx > m[fileIdx.FileIndexKey] {
		// Missing keys read as 0, so any real index (>= 1) is recorded.
		m[fileIdx.FileIndexKey] = fileIdx.Idx
	}
}

// schemaDefinition caches one schema file: its storage path and the column
// field types pre-parsed from its table definition.
type schemaDefinition struct {
	path             string
	columnFieldTypes map[string]*ptypes.FieldType
}

// schemaKey identifies a table by (schema, table) name pair.
type schemaKey struct {
	schema string
	table  string
}

// ErrWalkDirEnd is a sentinel error returned from WalkDir callbacks to stop
// the walk early once the end of the requested scan range has been passed.
// Callers treat it as a normal termination, not a failure.
var ErrWalkDirEnd = perrors.Normalize("walk dir end", perrors.RFCCodeText("CDC:ErrWalkDirEnd"))

// CurrentTableVersion tracks, per (schema, table), the most recently consumed
// table version. Safe for concurrent use.
type CurrentTableVersion struct {
	mu                     sync.RWMutex
	currentTableVersionMap map[schemaKey]types.VersionKey
}

// NewCurrentTableVersion creates an empty CurrentTableVersion tracker.
func NewCurrentTableVersion() *CurrentTableVersion {
	return &CurrentTableVersion{
		currentTableVersionMap: make(map[schemaKey]types.VersionKey),
	}
}

// GetCurrentTableVersion returns the current table version for a given schema and table.
// A zero-value types.VersionKey is returned for tables that were never updated.
func (cvt *CurrentTableVersion) GetCurrentTableVersion(schema, table string) types.VersionKey {
	cvt.mu.RLock()
	defer cvt.mu.RUnlock()
	return cvt.currentTableVersionMap[schemaKey{schema: schema, table: table}]
}

// UpdateCurrentTableVersion updates the current table version for a given schema and table.
func (cvt *CurrentTableVersion) UpdateCurrentTableVersion(schema, table string, version types.VersionKey) {
	cvt.mu.Lock()
	defer cvt.mu.Unlock()
	cvt.currentTableVersionMap[schemaKey{schema: schema, table: table}] = version
}

// SchemaDefinitions caches pre-parsed schema definitions keyed by schema path
// key (schema, table, table version). Safe for concurrent use.
type SchemaDefinitions struct {
	mu                  sync.RWMutex
	schemaDefinitionMap map[cloudstorage.SchemaPathKey]schemaDefinition
}

// NewSchemaDefinitions creates an empty SchemaDefinitions cache.
func NewSchemaDefinitions() *SchemaDefinitions {
	return &SchemaDefinitions{
		schemaDefinitionMap: make(map[cloudstorage.SchemaPathKey]schemaDefinition),
	}
}

// GetColumnFieldTypes returns the pre-parsed column field types for a given
// schema and table version. It returns an error when no definition has been
// registered via SetSchemaDefinition for that exact version.
func (sp *SchemaDefinitions) GetColumnFieldTypes(schema, table string, version uint64) (map[string]*ptypes.FieldType, error) {
	schemaPathKey := cloudstorage.SchemaPathKey{
		Schema:       schema,
		Table:        table,
		TableVersion: version,
	}
	sp.mu.RLock()
	// NOTE: local variable shadows the schemaDefinition type name.
	schemaDefinition, ok := sp.schemaDefinitionMap[schemaPathKey]
	sp.mu.RUnlock()
	if !ok {
		return nil, errors.Errorf("schema definition not found for schema: %s, table: %s, version: %d", schema, table, version)
	}
	return schemaDefinition.columnFieldTypes, nil
}

// SetSchemaDefinition sets the schema definition for a given schema and table version.
// It pre-parses the column field types from the table definition for later use by the decoder.
func (sp *SchemaDefinitions) SetSchemaDefinition(schemaPathKey cloudstorage.SchemaPathKey, filePath string, tableDefinition *cloudstorage.TableDefinition) error {
	columnFieldTypes := make(map[string]*ptypes.FieldType)
	if tableDefinition != nil {
		for i, col := range tableDefinition.Columns {
			colInfo, err := col.ToTiColumnInfo(int64(i))
			if err != nil {
				// Fail the whole registration: a partially parsed definition
				// would later decode rows with missing column types.
				return errors.Annotatef(err, "failed to convert column %s to FieldType", col.Name)
			}
			columnFieldTypes[col.Name] = &colInfo.FieldType
		}
	}
	sp.mu.Lock()
	sp.schemaDefinitionMap[schemaPathKey] = schemaDefinition{
		path:             filePath,
		columnFieldTypes: columnFieldTypes,
	}
	sp.mu.Unlock()
	return nil
}

// RemoveSchemaDefinitionWithCondition removes the schema definition for a given condition.
// Every cached entry whose key satisfies condition is deleted under the write lock.
func (sp *SchemaDefinitions) RemoveSchemaDefinitionWithCondition(condition func(schemaPathKey cloudstorage.SchemaPathKey) bool) {
	sp.mu.Lock()
	for schemaPathkey := range sp.schemaDefinitionMap {
		if condition(schemaPathkey) {
			delete(sp.schemaDefinitionMap, schemaPathkey)
		}
	}
	sp.mu.Unlock()
}

// TableDMLIdx tracks, per DML path key, the largest file index already
// accounted for per dispatcher, so that only newly appeared files are
// downloaded on each round. Safe for concurrent use.
type TableDMLIdx struct {
	mu             sync.Mutex
	tableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap
}

// NewTableDMLIdx creates an empty TableDMLIdx tracker.
func NewTableDMLIdx() *TableDMLIdx {
	return &TableDMLIdx{
		tableDMLIdxMap: make(map[cloudstorage.DmlPathKey]fileIndexKeyMap),
	}
}

// UpdateDMLIdxMapByStartPath seeds the tracked index for dmlkey directly from
// a known start file (used when resuming from a checkpoint), raising the
// stored index only if fileIdx.Idx is larger.
func (t *TableDMLIdx) UpdateDMLIdxMapByStartPath(dmlkey cloudstorage.DmlPathKey, fileIdx *cloudstorage.FileIndex) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if originalFileIndexKeyMap, ok := t.tableDMLIdxMap[dmlkey]; !ok {
		t.tableDMLIdxMap[dmlkey] = fileIndexKeyMap{
			fileIdx.FileIndexKey: fileIdx.Idx,
		}
	} else {
		if fileIdx.Idx > originalFileIndexKeyMap[fileIdx.FileIndexKey] {
			originalFileIndexKeyMap[fileIdx.FileIndexKey] = fileIdx.Idx
		}
	}
}

// DiffNewTableDMLIdxMap compares newTableDMLIdxMap against the tracked state,
// advances the tracked state to the new maxima, and returns only the index
// ranges that are new since the previous call (start = previous end + 1;
// a previously unseen key yields a range starting at 1).
func (t *TableDMLIdx) DiffNewTableDMLIdxMap(
	newTableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap,
) map[cloudstorage.DmlPathKey]fileIndexRange {
	resMap := make(map[cloudstorage.DmlPathKey]fileIndexRange)
	t.mu.Lock()
	defer t.mu.Unlock()
	for newDMLPathKey, newFileIndexKeyMap := range newTableDMLIdxMap {
		origFileIndexKeyMap, ok := t.tableDMLIdxMap[newDMLPathKey]
		if !ok {
			// First time this DML path is seen: everything from index 1 is new.
			// NOTE(review): the caller's map is retained here, not copied —
			// callers must not mutate newTableDMLIdxMap afterwards.
			t.tableDMLIdxMap[newDMLPathKey] = newFileIndexKeyMap
			resMap[newDMLPathKey] = make(fileIndexRange)
			for indexKey, newEndVal := range newFileIndexKeyMap {
				resMap[newDMLPathKey][indexKey] = indexRange{
					start: 1,
					end:   newEndVal,
				}
			}
			continue
		}
		for indexKey, newEndVal := range newFileIndexKeyMap {
			origEndVal := origFileIndexKeyMap[indexKey]
			if newEndVal > origEndVal {
				origFileIndexKeyMap[indexKey] = newEndVal
				if _, ok := resMap[newDMLPathKey]; !ok {
					resMap[newDMLPathKey] = make(fileIndexRange)
				}
				resMap[newDMLPathKey][indexKey] = indexRange{
					start: origEndVal + 1,
					end:   newEndVal,
				}
			}
		}
	}
	return resMap
}

// S3Consumer incrementally consumes one cluster's cloud-storage sink:
// it discovers new schema versions and DML files and downloads their contents.
type S3Consumer struct {
	s3Storage      storage.ExternalStorage
	fileExtension  string // e.g. ".json" for canal-json sinks
	dateSeparator  string
	fileIndexWidth int
	tables         map[string][]string // schema -> tables to consume

	// skip the first round data download
	skipDownloadData bool

	currentTableVersion *CurrentTableVersion
	tableDMLIdx         *TableDMLIdx
	schemaDefinitions   *SchemaDefinitions
}

// NewS3Consumer creates an S3Consumer over the given storage for the given
// schema -> tables map. File layout parameters are fixed to the defaults the
// checker validates against (.json, day date separator, default index width).
func NewS3Consumer(
	s3Storage storage.ExternalStorage,
	tables map[string][]string,
) *S3Consumer {
	return &S3Consumer{
		s3Storage:      s3Storage,
		fileExtension:  ".json",
		dateSeparator:  config.DateSeparatorDay.String(),
		fileIndexWidth: config.DefaultFileIndexWidth,
		tables:         tables,

		skipDownloadData: true,

		currentTableVersion: NewCurrentTableVersion(),
		tableDMLIdx:         NewTableDMLIdx(),
		schemaDefinitions:   NewSchemaDefinitions(),
	}
}

// InitializeFromCheckpoint restores the consumer's scan position from a
// persisted checkpoint and re-downloads the data covered by the checkpoint's
// scan ranges, one goroutine per table. Returns nil results when there is no
// checkpoint to resume from.
func (c *S3Consumer) InitializeFromCheckpoint(
	ctx context.Context, clusterID string, checkpoint *recorder.Checkpoint,
) (map[cloudstorage.DmlPathKey]types.IncrementalData, error) {
	if checkpoint == nil {
		return nil, nil
	}
	// NOTE(review): magic index — presumably CheckpointItems[2] is the slot
	// holding the scan-range item; confirm against recorder.Checkpoint and
	// consider a named constant.
	if checkpoint.CheckpointItems[2] == nil {
		return nil, nil
	}
	c.skipDownloadData = false
	scanRanges, err := checkpoint.ToScanRange(clusterID)
	if err != nil {
		return nil, errors.Trace(err)
	}
	var mu sync.Mutex
	// Combine DML data and schema data into result
	result := make(map[cloudstorage.DmlPathKey]types.IncrementalData)
	eg, egCtx := errgroup.WithContext(ctx)
	for schemaTableKey, scanRange := range scanRanges {
		// NOTE(review): the closure captures the range variables; this relies
		// on Go 1.22+ per-iteration loop variables — confirm go.mod version.
		eg.Go(func() error {
			scanVersions, err := c.downloadSchemaFilesWithScanRange(
				egCtx, schemaTableKey.Schema, schemaTableKey.Table, scanRange.StartVersionKey, scanRange.EndVersionKey, scanRange.EndDataPath)
			if err != nil {
				return errors.Trace(err)
			}
			err = c.downloadDataFilesWithScanRange(
				egCtx, schemaTableKey.Schema, schemaTableKey.Table, scanVersions, scanRange,
				func(
					dmlPathKey cloudstorage.DmlPathKey,
					dmlSlices map[cloudstorage.FileIndexKey][][]byte,
					columnFieldTypes map[string]*ptypes.FieldType,
				) {
					// consumeFunc runs concurrently; guard the shared result map.
					mu.Lock()
					result[dmlPathKey] = types.IncrementalData{
						DataContentSlices: dmlSlices,
						ColumnFieldTypes:  columnFieldTypes,
					}
					mu.Unlock()
				},
			)
			if err != nil {
				return errors.Trace(err)
			}
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return nil, errors.Trace(err)
	}
	return result, nil
}

// downloadSchemaFilesWithScanRange walks the table's meta/ directory, collects
// every schema version in (start, end] plus the start version itself,
// downloads the corresponding schema files, and records the end version as the
// table's current version. The walk stops at the first path lexicographically
// greater than endVersionKey (assumes WalkDir yields paths in ascending
// lexicographic order — TODO confirm for the storage backend in use).
func (c *S3Consumer) downloadSchemaFilesWithScanRange(
	ctx context.Context,
	schema, table string,
	startVersionKey string,
	endVersionKey string,
	endDataPath string,
) ([]types.VersionKey, error) {
	metaSubDir := fmt.Sprintf("%s/%s/meta/", schema, table)
	opt := &storage.WalkOption{
		SubDir:    metaSubDir,
		ObjPrefix: "schema_",
		// TODO: StartAfter: startVersionKey,
	}

	var startSchemaKey, endSchemaKey cloudstorage.SchemaPathKey
	_, err := startSchemaKey.ParseSchemaFilePath(startVersionKey)
	if err != nil {
		return nil, errors.Trace(err)
	}
	_, err = endSchemaKey.ParseSchemaFilePath(endVersionKey)
	if err != nil {
		return nil, errors.Trace(err)
	}

	var scanVersions []types.VersionKey
	newVersionPaths := make(map[cloudstorage.SchemaPathKey]string)
	// The start version is always part of the scan set.
	scanVersions = append(scanVersions, types.VersionKey{
		Version:     startSchemaKey.TableVersion,
		VersionPath: startVersionKey,
	})
	newVersionPaths[startSchemaKey] = startVersionKey
	if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error {
		if endVersionKey < filePath {
			// Past the end of the range: terminate the walk early.
			return ErrWalkDirEnd
		}
		if !cloudstorage.IsSchemaFile(filePath) {
			return nil
		}
		var schemaKey cloudstorage.SchemaPathKey
		_, err := schemaKey.ParseSchemaFilePath(filePath)
		if err != nil {
			// Unparseable names are logged and skipped rather than failing the scan.
			log.Error("failed to parse schema file path, skipping",
				zap.String("path", filePath),
				zap.Error(err))
			return nil
		}
		if schemaKey.TableVersion > startSchemaKey.TableVersion {
			if _, exists := newVersionPaths[schemaKey]; !exists {
				scanVersions = append(scanVersions, types.VersionKey{
					Version:     schemaKey.TableVersion,
					VersionPath: filePath,
				})
			}
			newVersionPaths[schemaKey] = filePath
		}
		return nil
	}); err != nil && !errors.Is(err, ErrWalkDirEnd) {
		return nil, errors.Trace(err)
	}

	if err := c.downloadSchemaFiles(ctx, newVersionPaths); err != nil {
		return nil, errors.Trace(err)
	}

	c.currentTableVersion.UpdateCurrentTableVersion(schema, table, types.VersionKey{
		Version:     endSchemaKey.TableVersion,
		VersionPath: endVersionKey,
		DataPath:    endDataPath,
	})

	return scanVersions, nil
}

// downloadDataFilesWithScanRange downloads data files for a given scan range.
// consumeFunc is called from multiple goroutines concurrently and must be goroutine-safe.
func (c *S3Consumer) downloadDataFilesWithScanRange(
	ctx context.Context,
	schema, table string,
	scanVersions []types.VersionKey,
	scanRange *recorder.ScanRange,
	consumeFunc func(
		dmlPathKey cloudstorage.DmlPathKey,
		dmlSlices map[cloudstorage.FileIndexKey][][]byte,
		columnFieldTypes map[string]*ptypes.FieldType,
	),
) error {
	eg, egCtx := errgroup.WithContext(ctx)
	// One goroutine per schema version; schema files must already be cached
	// (GetColumnFieldTypes fails otherwise).
	for _, version := range scanVersions {
		eg.Go(func() error {
			newFiles, err := c.getNewFilesForSchemaPathKeyWithEndPath(egCtx, schema, table, version.Version, scanRange.StartDataPath, scanRange.EndDataPath)
			if err != nil {
				return errors.Trace(err)
			}
			dmlData, err := c.downloadDMLFiles(egCtx, newFiles)
			if err != nil {
				return errors.Trace(err)
			}
			columnFieldTypes, err := c.schemaDefinitions.GetColumnFieldTypes(schema, table, version.Version)
			if err != nil {
				return errors.Trace(err)
			}
			for dmlPathKey, dmlSlices := range dmlData {
				consumeFunc(dmlPathKey, dmlSlices, columnFieldTypes)
			}
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return errors.Trace(err)
	}
	return nil
}

// getNewFilesForSchemaPathKeyWithEndPath walks the data directory of one table
// version and returns the file index ranges that are new relative to the
// tracked state, bounded by endDataPath. The file equal to startDataPath seeds
// the tracked index (it was already consumed before the checkpoint) instead of
// being counted as new. Assumes WalkDir yields paths in ascending
// lexicographic order — TODO confirm for the storage backend in use.
func (c *S3Consumer) getNewFilesForSchemaPathKeyWithEndPath(
	ctx context.Context,
	schema, table string,
	version uint64,
	startDataPath string,
	endDataPath string,
) (map[cloudstorage.DmlPathKey]fileIndexRange, error) {
	schemaPrefix := path.Join(schema, table, fmt.Sprintf("%d", version))
	opt := &storage.WalkOption{
		SubDir: schemaPrefix,
		// TODO: StartAfter: startDataPath,
	}
	newTableDMLIdxMap := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap)
	if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error {
		if endDataPath < filePath {
			return ErrWalkDirEnd
		}
		// Try to parse DML file path if it matches the expected extension
		if strings.HasSuffix(filePath, c.fileExtension) {
			var dmlkey cloudstorage.DmlPathKey
			fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath)
			if err != nil {
				log.Error("failed to parse dml file path, skipping",
					zap.String("path", filePath),
					zap.Error(err))
				return nil
			}
			if filePath == startDataPath {
				c.tableDMLIdx.UpdateDMLIdxMapByStartPath(dmlkey, fileIdx)
			} else {
				updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx)
			}
		}
		return nil
	}); err != nil && !errors.Is(err, ErrWalkDirEnd) {
		return nil, errors.Trace(err)
	}
	return c.tableDMLIdx.DiffNewTableDMLIdxMap(newTableDMLIdxMap), nil
}

// downloadSchemaFiles downloads schema files concurrently for given schema path keys
// and registers each parsed table definition in the SchemaDefinitions cache.
func (c *S3Consumer) downloadSchemaFiles(
	ctx context.Context,
	newVersionPaths map[cloudstorage.SchemaPathKey]string,
) error {
	eg, egCtx := errgroup.WithContext(ctx)

	log.Debug("starting concurrent schema file download", zap.Int("totalSchemas", len(newVersionPaths)))
	for schemaPathKey, filePath := range newVersionPaths {
		eg.Go(func() error {
			content, err := c.s3Storage.ReadFile(egCtx, filePath)
			if err != nil {
				return errors.Annotatef(err, "failed to read schema file: %s", filePath)
			}

			tableDefinition := &cloudstorage.TableDefinition{}
			if err := json.Unmarshal(content, tableDefinition); err != nil {
				return errors.Annotatef(err, "failed to unmarshal schema file: %s", filePath)
			}
			// SetSchemaDefinition is internally locked, safe from this goroutine.
			if err := c.schemaDefinitions.SetSchemaDefinition(schemaPathKey, filePath, tableDefinition); err != nil {
				return errors.Trace(err)
			}
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return errors.Trace(err)
	}
	return nil
}

// discoverAndDownloadNewTableVersions walks the table's meta/ directory for
// schema versions newer than the tracked current version, downloads their
// schema files, and returns the versions to scan for data. The current version
// (if any, i.e. Version > 0) is appended so its directory is re-scanned for
// files added since the last round.
func (c *S3Consumer) discoverAndDownloadNewTableVersions(
	ctx context.Context,
	schema, table string,
) ([]types.VersionKey, error) {
	currentVersion := c.currentTableVersion.GetCurrentTableVersion(schema, table)
	metaSubDir := fmt.Sprintf("%s/%s/meta/", schema, table)
	opt := &storage.WalkOption{
		SubDir:    metaSubDir,
		ObjPrefix: "schema_",
		// TODO: StartAfter: currentVersion.versionPath,
	}

	var scanVersions []types.VersionKey
	newVersionPaths := make(map[cloudstorage.SchemaPathKey]string)
	if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error {
		if !cloudstorage.IsSchemaFile(filePath) {
			return nil
		}
		var schemaKey cloudstorage.SchemaPathKey
		_, err := schemaKey.ParseSchemaFilePath(filePath)
		if err != nil {
			log.Error("failed to parse schema file path, skipping",
				zap.String("path", filePath),
				zap.Error(err))
			return nil
		}
		version := schemaKey.TableVersion
		if version > currentVersion.Version {
			if _, exists := newVersionPaths[schemaKey]; !exists {
				scanVersions = append(scanVersions, types.VersionKey{
					Version:     version,
					VersionPath: filePath,
				})
			}
			newVersionPaths[schemaKey] = filePath
		}
		return nil
	}); err != nil {
		return nil, errors.Trace(err)
	}

	// download new version schema files concurrently
	if err := c.downloadSchemaFiles(ctx, newVersionPaths); err != nil {
		return nil, errors.Trace(err)
	}

	if currentVersion.Version > 0 {
		scanVersions = append(scanVersions, currentVersion)
	}
	return scanVersions, nil
}

// getNewFilesForSchemaPathKey walks one table version's data directory without
// an end bound and returns the new file index ranges. It also records the last
// walked file path into version.DataPath (relies on WalkDir's ascending
// lexicographic order for "last" to mean "max" — TODO confirm), mutating the
// caller's VersionKey.
func (c *S3Consumer) getNewFilesForSchemaPathKey(
	ctx context.Context,
	schema, table string,
	version *types.VersionKey,
) (map[cloudstorage.DmlPathKey]fileIndexRange, error) {
	schemaPrefix := path.Join(schema, table, fmt.Sprintf("%d", version.Version))
	opt := &storage.WalkOption{
		SubDir: schemaPrefix,
		// TODO: StartAfter: version.dataPath,
	}

	newTableDMLIdxMap := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap)
	maxFilePath := ""
	if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error {
		// Try to parse DML file path if it matches the expected extension
		if strings.HasSuffix(filePath, c.fileExtension) {
			var dmlkey cloudstorage.DmlPathKey
			fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath)
			if err != nil {
				log.Error("failed to parse dml file path, skipping",
					zap.String("path", filePath),
					zap.Error(err))
				return nil
			}
			updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx)
			maxFilePath = filePath
		}
		return nil
	}); err != nil {
		return nil, errors.Trace(err)
	}

	version.DataPath = maxFilePath
	return c.tableDMLIdx.DiffNewTableDMLIdxMap(newTableDMLIdxMap), nil
}

// downloadDMLFiles concurrently downloads every file in the given index
// ranges and groups the raw contents by DML path key and dispatcher. Returns
// nil when there is nothing to download or when the first-round skip flag is
// set. Note: within one dispatcher the slice order follows download completion
// order, not file index order — callers that need index order must not rely on
// slice position (TODO confirm downstream decoding is order-insensitive).
func (c *S3Consumer) downloadDMLFiles(
	ctx context.Context,
	newFiles map[cloudstorage.DmlPathKey]fileIndexRange,
) (map[cloudstorage.DmlPathKey]map[cloudstorage.FileIndexKey][][]byte, error) {
	if len(newFiles) == 0 || c.skipDownloadData {
		return nil, nil
	}

	result := make(map[cloudstorage.DmlPathKey]map[cloudstorage.FileIndexKey][][]byte)
	type downloadTask struct {
		dmlPathKey cloudstorage.DmlPathKey
		fileIndex  cloudstorage.FileIndex
	}

	// Expand each [start, end] range into one task per file index.
	var tasks []downloadTask
	for dmlPathKey, fileRange := range newFiles {
		for indexKey, indexRange := range fileRange {
			log.Debug("prepare to download new dml file in index range",
				zap.String("schema", dmlPathKey.Schema),
				zap.String("table", dmlPathKey.Table),
				zap.Uint64("version", dmlPathKey.TableVersion),
				zap.Int64("partitionNum", dmlPathKey.PartitionNum),
				zap.String("date", dmlPathKey.Date),
				zap.String("dispatcherID", indexKey.DispatcherID),
				zap.Bool("enableTableAcrossNodes", indexKey.EnableTableAcrossNodes),
				zap.Uint64("startIndex", indexRange.start),
				zap.Uint64("endIndex", indexRange.end))
			for i := indexRange.start; i <= indexRange.end; i++ {
				tasks = append(tasks, downloadTask{
					dmlPathKey: dmlPathKey,
					fileIndex: cloudstorage.FileIndex{
						FileIndexKey: indexKey,
						Idx:          i,
					},
				})
			}
		}
	}

	log.Debug("starting concurrent DML file download", zap.Int("totalFiles", len(tasks)))

	// Concurrently download files
	type fileContent struct {
		dmlPathKey cloudstorage.DmlPathKey
		indexKey   cloudstorage.FileIndexKey
		idx        uint64
		content    []byte
	}

	// Buffered to len(tasks) so every goroutine can send without blocking,
	// allowing draining to happen only after eg.Wait.
	fileContents := make(chan fileContent, len(tasks))
	eg, egCtx := errgroup.WithContext(ctx)
	for _, task := range tasks {
		eg.Go(func() error {
			filePath := task.dmlPathKey.GenerateDMLFilePath(
				&task.fileIndex,
				c.fileExtension,
				c.fileIndexWidth,
			)

			content, err := c.s3Storage.ReadFile(egCtx, filePath)
			if err != nil {
				return errors.Annotatef(err, "failed to read file: %s", filePath)
			}

			// Channel writes are thread-safe, no mutex needed
			fileContents <- fileContent{
				dmlPathKey: task.dmlPathKey,
				indexKey:   task.fileIndex.FileIndexKey,
				idx:        task.fileIndex.Idx,
				content:    content,
			}
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return nil, errors.Trace(err)
	}

	// Close the channel to signal no more writes
	close(fileContents)

	// Process the downloaded file contents
	for fc := range fileContents {
		if result[fc.dmlPathKey] == nil {
			result[fc.dmlPathKey] = make(map[cloudstorage.FileIndexKey][][]byte)
		}
		result[fc.dmlPathKey][fc.indexKey] = append(
			result[fc.dmlPathKey][fc.indexKey],
			fc.content,
		)
	}

	return result, nil
}

// downloadNewFilesWithVersions downloads new files for given schema versions.
// consumeFunc is called from multiple goroutines concurrently and must be goroutine-safe.
func (c *S3Consumer) downloadNewFilesWithVersions(
	ctx context.Context,
	schema, table string,
	scanVersions []types.VersionKey,
	consumeFunc func(
		dmlPathKey cloudstorage.DmlPathKey,
		dmlSlices map[cloudstorage.FileIndexKey][][]byte,
		columnFieldTypes map[string]*ptypes.FieldType,
	),
) (*types.VersionKey, error) {
	var maxVersion *types.VersionKey
	eg, egCtx := errgroup.WithContext(ctx)
	for _, version := range scanVersions {
		// Take the address of the per-iteration variable so the goroutine can
		// write DataPath back into it (getNewFilesForSchemaPathKey mutates it).
		// NOTE(review): relies on Go 1.22+ per-iteration loop variables —
		// confirm go.mod version; maxVersion aliases one of these and is only
		// read after eg.Wait, so the concurrent DataPath write is ordered.
		versionp := &version
		if maxVersion == nil || maxVersion.Version < version.Version {
			maxVersion = versionp
		}
		eg.Go(func() error {
			newFiles, err := c.getNewFilesForSchemaPathKey(egCtx, schema, table, versionp)
			if err != nil {
				return errors.Trace(err)
			}
			dmlData, err := c.downloadDMLFiles(egCtx, newFiles)
			if err != nil {
				return errors.Trace(err)
			}
			columnFieldTypes, err := c.schemaDefinitions.GetColumnFieldTypes(schema, table, versionp.Version)
			if err != nil {
				return errors.Trace(err)
			}
			for dmlPathKey, dmlSlices := range dmlData {
				consumeFunc(dmlPathKey, dmlSlices, columnFieldTypes)
			}
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return nil, errors.Trace(err)
	}
	// Record the highest scanned version (with its final DataPath) as the
	// table's new current version.
	if maxVersion != nil {
		c.currentTableVersion.UpdateCurrentTableVersion(schema, table, *maxVersion)
	}
	return maxVersion, nil
}

// ConsumeNewFiles runs one consumption round across all configured tables:
// per table it discovers new schema versions, downloads new DML files, and
// collects the results. It returns the per-path incremental data and the
// per-table max version reached. After the first successful round the
// skipDownloadData flag is cleared so subsequent rounds download data.
func (c *S3Consumer) ConsumeNewFiles(
	ctx context.Context,
) (map[cloudstorage.DmlPathKey]types.IncrementalData, map[types.SchemaTableKey]types.VersionKey, error) {
	var mu sync.Mutex
	// Combine DML data and schema data into result
	result := make(map[cloudstorage.DmlPathKey]types.IncrementalData)
	var versionMu sync.Mutex
	maxVersionMap := make(map[types.SchemaTableKey]types.VersionKey)
	eg, egCtx := errgroup.WithContext(ctx)
	for schema, tables := range c.tables {
		for _, table := range tables {
			// NOTE(review): closure captures schema/table — relies on Go 1.22+
			// per-iteration loop variables; confirm go.mod version.
			eg.Go(func() error {
				scanVersions, err := c.discoverAndDownloadNewTableVersions(egCtx, schema, table)
				if err != nil {
					return errors.Trace(err)
				}
				maxVersion, err := c.downloadNewFilesWithVersions(
					egCtx, schema, table, scanVersions,
					func(
						dmlPathKey cloudstorage.DmlPathKey,
						dmlSlices map[cloudstorage.FileIndexKey][][]byte,
						columnFieldTypes map[string]*ptypes.FieldType,
					) {
						// consumeFunc runs concurrently; guard the shared result map.
						mu.Lock()
						result[dmlPathKey] = types.IncrementalData{
							DataContentSlices: dmlSlices,
							ColumnFieldTypes:  columnFieldTypes,
						}
						mu.Unlock()
					},
				)
				if err != nil {
					return errors.Trace(err)
				}
				if maxVersion != nil {
					versionMu.Lock()
					maxVersionMap[types.SchemaTableKey{Schema: schema, Table: table}] = *maxVersion
					versionMu.Unlock()
				}
				return nil
			})
		}
	}

	if err := eg.Wait(); err != nil {
		return nil, nil, errors.Trace(err)
	}
	c.skipDownloadData = false
	return result, maxVersionMap, nil
}
diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go
new file mode 100644
index 0000000000..b51bdb80d1
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go
@@ -0,0 +1,706 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package consumer + +import ( + "bytes" + "context" + "path" + "slices" + "strings" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/parser/mysql" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + "github.com/stretchr/testify/require" +) + +func TestUpdateTableDMLIdxMap(t *testing.T) { + t.Parallel() + + t.Run("insert new entry", func(t *testing.T) { + t.Parallel() + m := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1}, + Date: "2026-01-01", + } + fileIdx := &cloudstorage.FileIndex{ + FileIndexKey: cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}, + Idx: 5, + } + + updateTableDMLIdxMap(m, dmlKey, fileIdx) + require.Len(t, m, 1) + require.Equal(t, uint64(5), m[dmlKey][fileIdx.FileIndexKey]) + }) + + t.Run("update with higher index", func(t *testing.T) { + t.Parallel() + m := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1}, + Date: "2026-01-01", + } + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} + fileIdx1 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 3} + fileIdx2 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 7} + + updateTableDMLIdxMap(m, dmlKey, fileIdx1) + updateTableDMLIdxMap(m, dmlKey, fileIdx2) + require.Equal(t, uint64(7), m[dmlKey][indexKey]) + }) + + t.Run("skip lower index", func(t *testing.T) { + t.Parallel() + m := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", 
TableVersion: 1}, + Date: "2026-01-01", + } + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} + fileIdx1 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 10} + fileIdx2 := &cloudstorage.FileIndex{FileIndexKey: indexKey, Idx: 5} + + updateTableDMLIdxMap(m, dmlKey, fileIdx1) + updateTableDMLIdxMap(m, dmlKey, fileIdx2) + require.Equal(t, uint64(10), m[dmlKey][indexKey]) + }) +} + +func TestCurrentTableVersion(t *testing.T) { + t.Parallel() + + t.Run("get returns zero value for missing key", func(t *testing.T) { + t.Parallel() + cvt := NewCurrentTableVersion() + v := cvt.GetCurrentTableVersion("db", "tbl") + require.Equal(t, types.VersionKey{}, v) + }) + + t.Run("update and get", func(t *testing.T) { + t.Parallel() + cvt := NewCurrentTableVersion() + vk := types.VersionKey{Version: 100, VersionPath: "db/tbl/meta/schema_100_0000000000.json"} + cvt.UpdateCurrentTableVersion("db", "tbl", vk) + got := cvt.GetCurrentTableVersion("db", "tbl") + require.Equal(t, vk, got) + }) + + t.Run("update overwrites previous value", func(t *testing.T) { + t.Parallel() + cvt := NewCurrentTableVersion() + vk1 := types.VersionKey{Version: 1} + vk2 := types.VersionKey{Version: 2} + cvt.UpdateCurrentTableVersion("db", "tbl", vk1) + cvt.UpdateCurrentTableVersion("db", "tbl", vk2) + got := cvt.GetCurrentTableVersion("db", "tbl") + require.Equal(t, vk2, got) + }) + + t.Run("different tables are independent", func(t *testing.T) { + t.Parallel() + cvt := NewCurrentTableVersion() + vk1 := types.VersionKey{Version: 10} + vk2 := types.VersionKey{Version: 20} + cvt.UpdateCurrentTableVersion("db", "tbl1", vk1) + cvt.UpdateCurrentTableVersion("db", "tbl2", vk2) + require.Equal(t, vk1, cvt.GetCurrentTableVersion("db", "tbl1")) + require.Equal(t, vk2, cvt.GetCurrentTableVersion("db", "tbl2")) + }) +} + +func TestSchemaDefinitions(t *testing.T) { + t.Parallel() + + t.Run("get returns error for missing key", func(t *testing.T) { + t.Parallel() + sp := 
NewSchemaDefinitions() + _, err := sp.GetColumnFieldTypes("db", "tbl", 1) + require.Error(t, err) + require.Contains(t, err.Error(), "schema definition not found") + }) + + t.Run("set and get empty table definition", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1} + td := &cloudstorage.TableDefinition{} + err := sp.SetSchemaDefinition(key, "/path/to/schema.json", td) + require.NoError(t, err) + + got, err := sp.GetColumnFieldTypes("db", "tbl", 1) + require.NoError(t, err) + require.Equal(t, map[string]*ptypes.FieldType{}, got) + }) + + t.Run("set and get with columns parses field types correctly", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1} + td := &cloudstorage.TableDefinition{ + Table: "tbl", + Schema: "db", + Columns: []cloudstorage.TableCol{ + {Name: "id", Tp: "INT", IsPK: "true", Precision: "11"}, + {Name: "name", Tp: "VARCHAR", Precision: "255"}, + {Name: "score", Tp: "DECIMAL", Precision: "10", Scale: "2"}, + {Name: "duration", Tp: "TIME", Scale: "3"}, + {Name: "created_at", Tp: "TIMESTAMP", Scale: "6"}, + {Name: "big_id", Tp: "BIGINT UNSIGNED", Precision: "20"}, + }, + TotalColumns: 6, + } + err := sp.SetSchemaDefinition(key, "/path/to/schema.json", td) + require.NoError(t, err) + + got, err := sp.GetColumnFieldTypes("db", "tbl", 1) + require.NoError(t, err) + require.Len(t, got, 6) + + // INT PK + require.Equal(t, mysql.TypeLong, got["id"].GetType()) + require.True(t, mysql.HasPriKeyFlag(got["id"].GetFlag())) + require.Equal(t, 11, got["id"].GetFlen()) + + // VARCHAR(255) + require.Equal(t, mysql.TypeVarchar, got["name"].GetType()) + require.Equal(t, 255, got["name"].GetFlen()) + + // DECIMAL(10,2) + require.Equal(t, mysql.TypeNewDecimal, got["score"].GetType()) + require.Equal(t, 10, got["score"].GetFlen()) + require.Equal(t, 2, got["score"].GetDecimal()) 
+ + // TIME(3) — decimal stores FSP + require.Equal(t, mysql.TypeDuration, got["duration"].GetType()) + require.Equal(t, 3, got["duration"].GetDecimal()) + + // TIMESTAMP(6) — decimal stores FSP + require.Equal(t, mysql.TypeTimestamp, got["created_at"].GetType()) + require.Equal(t, 6, got["created_at"].GetDecimal()) + + // BIGINT UNSIGNED + require.Equal(t, mysql.TypeLonglong, got["big_id"].GetType()) + require.True(t, mysql.HasUnsignedFlag(got["big_id"].GetFlag())) + require.Equal(t, 20, got["big_id"].GetFlen()) + }) + + t.Run("set returns error for invalid column definition", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1} + td := &cloudstorage.TableDefinition{ + Table: "tbl", + Schema: "db", + Columns: []cloudstorage.TableCol{ + {Name: "id", Tp: "INT", Precision: "not_a_number"}, + }, + TotalColumns: 1, + } + err := sp.SetSchemaDefinition(key, "/path/to/schema.json", td) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to convert column id to FieldType") + + // Verify the definition was NOT stored + _, err = sp.GetColumnFieldTypes("db", "tbl", 1) + require.Error(t, err) + require.Contains(t, err.Error(), "schema definition not found") + }) + + t.Run("remove with condition", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key1 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl1", TableVersion: 1} + key2 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl2", TableVersion: 2} + require.NoError(t, sp.SetSchemaDefinition(key1, "/path1", nil)) + require.NoError(t, sp.SetSchemaDefinition(key2, "/path2", nil)) + + // Remove only entries for tbl1 + sp.RemoveSchemaDefinitionWithCondition(func(k cloudstorage.SchemaPathKey) bool { + return k.Table == "tbl1" + }) + + _, err := sp.GetColumnFieldTypes("db", "tbl1", 1) + require.Error(t, err) + + _, err = sp.GetColumnFieldTypes("db", "tbl2", 2) + require.NoError(t, err) + }) + + 
t.Run("remove with condition matching all", func(t *testing.T) { + t.Parallel() + sp := NewSchemaDefinitions() + key1 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl1", TableVersion: 1} + key2 := cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl2", TableVersion: 2} + require.NoError(t, sp.SetSchemaDefinition(key1, "/path1", nil)) + require.NoError(t, sp.SetSchemaDefinition(key2, "/path2", nil)) + + sp.RemoveSchemaDefinitionWithCondition(func(k cloudstorage.SchemaPathKey) bool { + return true + }) + + _, err := sp.GetColumnFieldTypes("db", "tbl1", 1) + require.Error(t, err) + _, err = sp.GetColumnFieldTypes("db", "tbl2", 2) + require.Error(t, err) + }) +} + +func TestTableDMLIdx_DiffNewTableDMLIdxMap(t *testing.T) { + t.Parallel() + + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} + dmlKey := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl", TableVersion: 1}, + Date: "2026-01-01", + } + + t.Run("new entry starts from 1", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + newMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + + result := idx.DiffNewTableDMLIdxMap(newMap) + require.Len(t, result, 1) + require.Equal(t, indexRange{start: 1, end: 5}, result[dmlKey][indexKey]) + }) + + t.Run("existing entry increments from previous end + 1", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + + // First call: set initial state + firstMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 3}, + } + idx.DiffNewTableDMLIdxMap(firstMap) + + // Second call: new end is 7, should get range [4, 7] + secondMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 7}, + } + result := idx.DiffNewTableDMLIdxMap(secondMap) + require.Len(t, result, 1) + require.Equal(t, indexRange{start: 4, end: 7}, result[dmlKey][indexKey]) + }) + + t.Run("same end value returns no diff", func(t *testing.T) { + 
t.Parallel() + idx := NewTableDMLIdx() + + firstMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + idx.DiffNewTableDMLIdxMap(firstMap) + + secondMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + result := idx.DiffNewTableDMLIdxMap(secondMap) + require.Empty(t, result) + }) + + t.Run("lower end value returns no diff", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + + firstMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 10}, + } + idx.DiffNewTableDMLIdxMap(firstMap) + + secondMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 5}, + } + result := idx.DiffNewTableDMLIdxMap(secondMap) + require.Empty(t, result) + }) + + t.Run("empty new map returns empty result", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + result := idx.DiffNewTableDMLIdxMap(map[cloudstorage.DmlPathKey]fileIndexKeyMap{}) + require.Empty(t, result) + }) + + t.Run("multiple keys", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + dmlKey2 := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "db", Table: "tbl2", TableVersion: 1}, + Date: "2026-01-02", + } + + newMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 3}, + dmlKey2: {indexKey: 5}, + } + result := idx.DiffNewTableDMLIdxMap(newMap) + require.Len(t, result, 2) + require.Equal(t, indexRange{start: 1, end: 3}, result[dmlKey][indexKey]) + require.Equal(t, indexRange{start: 1, end: 5}, result[dmlKey2][indexKey]) + }) + + t.Run("multiple index keys for same dml path", func(t *testing.T) { + t.Parallel() + idx := NewTableDMLIdx() + indexKey2 := cloudstorage.FileIndexKey{DispatcherID: "dispatcher1", EnableTableAcrossNodes: true} + + newMap := map[cloudstorage.DmlPathKey]fileIndexKeyMap{ + dmlKey: {indexKey: 3, indexKey2: 5}, + } + result := idx.DiffNewTableDMLIdxMap(newMap) + require.Len(t, result, 1) + require.Equal(t, indexRange{start: 1, end: 3}, 
result[dmlKey][indexKey]) + require.Equal(t, indexRange{start: 1, end: 5}, result[dmlKey][indexKey2]) + }) +} + +type mockFile struct { + name string + content []byte +} + +type mockS3Storage struct { + storage.ExternalStorage + + fileOffset map[string]int + sortedFiles []mockFile +} + +func NewMockS3Storage(sortedFiles []mockFile) *mockS3Storage { + s3Storage := &mockS3Storage{} + s3Storage.UpdateFiles(sortedFiles) + return s3Storage +} + +func (m *mockS3Storage) ReadFile(ctx context.Context, name string) ([]byte, error) { + return m.sortedFiles[m.fileOffset[name]].content, nil +} + +func (m *mockS3Storage) WalkDir(ctx context.Context, opt *storage.WalkOption, fn func(path string, size int64) error) error { + filenamePrefix := path.Join(opt.SubDir, opt.ObjPrefix) + for _, file := range m.sortedFiles { + if strings.HasPrefix(file.name, filenamePrefix) { + if err := fn(file.name, 0); err != nil { + return err + } + } + } + return nil +} + +func (m *mockS3Storage) UpdateFiles(sortedFiles []mockFile) { + fileOffset := make(map[string]int) + for i, file := range sortedFiles { + fileOffset[file.name] = i + } + m.fileOffset = fileOffset + m.sortedFiles = sortedFiles +} + +func TestS3Consumer(t *testing.T) { + t.Parallel() + ctx := context.Background() + round1Files := []mockFile{ + {name: "test/t1/meta/schema_1_0000000001.json", content: []byte("{}")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: []byte("1_2026-01-01_1.json")}, + } + round1TimeWindowData := types.TimeWindowData{ + TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}, + Data: map[cloudstorage.DmlPathKey]types.IncrementalData{}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "test", Table: "t1"}: { + Version: 1, + VersionPath: "test/t1/meta/schema_1_0000000001.json", + DataPath: "test/t1/1/2026-01-01/CDC00000000000000000001.json", + }, + }, + } + expectedMaxVersionMap1 := func(maxVersionMap map[types.SchemaTableKey]types.VersionKey) { + 
require.Len(t, maxVersionMap, 1) + require.Equal(t, types.VersionKey{ + Version: 1, VersionPath: "test/t1/meta/schema_1_0000000001.json", DataPath: "test/t1/1/2026-01-01/CDC00000000000000000001.json", + }, maxVersionMap[types.SchemaTableKey{Schema: "test", Table: "t1"}]) + } + round2Files := []mockFile{ + {name: "test/t1/meta/schema_1_0000000001.json", content: []byte("{}")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: []byte("1_2026-01-01_1.json")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000002.json", content: []byte("1_2026-01-01_2.json")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000001.json", content: []byte("1_2026-01-02_1.json")}, + } + round2TimeWindowData := types.TimeWindowData{ + TimeWindow: types.TimeWindow{LeftBoundary: 10, RightBoundary: 20}, + Data: map[cloudstorage.DmlPathKey]types.IncrementalData{}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "test", Table: "t1"}: { + Version: 1, + VersionPath: "test/t1/meta/schema_1_0000000001.json", + DataPath: "test/t1/1/2026-01-02/CDC00000000000000000001.json", + }, + }, + } + expectedNewData2 := func(newData map[cloudstorage.DmlPathKey]types.IncrementalData) { + require.Len(t, newData, 2) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("1_2026-01-01_2.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-01", + }]) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("1_2026-01-02_1.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", 
Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + } + expectedMaxVersionMap2 := func(maxVersionMap map[types.SchemaTableKey]types.VersionKey) { + require.Len(t, maxVersionMap, 1) + require.Equal(t, types.VersionKey{ + Version: 1, VersionPath: "test/t1/meta/schema_1_0000000001.json", DataPath: "test/t1/1/2026-01-02/CDC00000000000000000001.json", + }, maxVersionMap[types.SchemaTableKey{Schema: "test", Table: "t1"}]) + } + round3Files := []mockFile{ + {name: "test/t1/meta/schema_1_0000000001.json", content: []byte("{}")}, + {name: "test/t1/meta/schema_2_0000000001.json", content: []byte("{}")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: []byte("1_2026-01-01_1.json")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000002.json", content: []byte("1_2026-01-01_2.json")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000001.json", content: []byte("1_2026-01-02_1.json")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000002.json", content: []byte("1_2026-01-02_2.json")}, + {name: "test/t1/2/2026-01-02/CDC00000000000000000001.json", content: []byte("2_2026-01-02_1.json")}, + {name: "test/t1/2/2026-01-03/CDC00000000000000000001.json", content: []byte("2_2026-01-03_1.json")}, + {name: "test/t1/2/2026-01-03/CDC00000000000000000002.json", content: []byte("2_2026-01-03_2.json")}, + } + round3TimeWindowData := types.TimeWindowData{ + TimeWindow: types.TimeWindow{LeftBoundary: 20, RightBoundary: 30}, + Data: map[cloudstorage.DmlPathKey]types.IncrementalData{}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "test", Table: "t1"}: { + Version: 2, + VersionPath: "test/t1/meta/schema_2_0000000001.json", + DataPath: "test/t1/2/2026-01-03/CDC00000000000000000002.json", + }, + }, + } + expectedNewData3 := func(newData map[cloudstorage.DmlPathKey]types.IncrementalData) { + require.Len(t, newData, 3) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + 
{DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("1_2026-01-02_2.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("2_2026-01-02_1.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + newDataContent := newData[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-03", + }] + require.Len(t, newDataContent.DataContentSlices, 1) + contents := newDataContent.DataContentSlices[cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}] + require.Len(t, contents, 2) + slices.SortFunc(contents, func(a, b []byte) int { + return bytes.Compare(a, b) + }) + require.Equal(t, [][]byte{[]byte("2_2026-01-03_1.json"), []byte("2_2026-01-03_2.json")}, contents) + } + expectedMaxVersionMap3 := func(maxVersionMap map[types.SchemaTableKey]types.VersionKey) { + require.Len(t, maxVersionMap, 1) + require.Equal(t, types.VersionKey{ + Version: 2, VersionPath: "test/t1/meta/schema_2_0000000001.json", DataPath: "test/t1/2/2026-01-03/CDC00000000000000000002.json", + }, maxVersionMap[types.SchemaTableKey{Schema: "test", Table: "t1"}]) + } + expectedCheckpoint23 := func(data map[cloudstorage.DmlPathKey]types.IncrementalData) { + require.Len(t, data, 4) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: 
{[]byte("1_2026-01-01_2.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-01", + }]) + dataContent := data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + PartitionNum: 0, + Date: "2026-01-02", + }] + require.Len(t, dataContent.DataContentSlices, 1) + contents := dataContent.DataContentSlices[cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}] + require.Len(t, contents, 2) + slices.SortFunc(contents, func(a, b []byte) int { + return bytes.Compare(a, b) + }) + require.Equal(t, [][]byte{[]byte("1_2026-01-02_1.json"), []byte("1_2026-01-02_2.json")}, contents) + require.Equal(t, types.IncrementalData{ + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {DispatcherID: "", EnableTableAcrossNodes: false}: {[]byte("2_2026-01-02_1.json")}, + }, + ColumnFieldTypes: map[string]*ptypes.FieldType{}, + }, data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-02", + }]) + dataContent = data[cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 2}, + PartitionNum: 0, + Date: "2026-01-03", + }] + require.Len(t, dataContent.DataContentSlices, 1) + contents = dataContent.DataContentSlices[cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false}] + require.Len(t, contents, 2) + slices.SortFunc(contents, func(a, b []byte) int { + return bytes.Compare(a, b) + }) + require.Equal(t, [][]byte{[]byte("2_2026-01-03_1.json"), []byte("2_2026-01-03_2.json")}, contents) + } + + t.Run("checkpoint with nil items returns nil", func(t *testing.T) { + t.Parallel() + s3Storage := NewMockS3Storage(round1Files) + s3Consumer := 
NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "test", nil) + require.NoError(t, err) + require.Empty(t, data) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + require.Empty(t, newData) + expectedMaxVersionMap1(maxVersionMap) + s3Storage.UpdateFiles(round2Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData2(newData) + expectedMaxVersionMap2(maxVersionMap) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with empty items returns nil", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + s3Storage := NewMockS3Storage(round1Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "test", checkpoint) + require.NoError(t, err) + require.Empty(t, data) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + require.Empty(t, newData) + expectedMaxVersionMap1(maxVersionMap) + s3Storage.UpdateFiles(round2Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData2(newData) + expectedMaxVersionMap2(maxVersionMap) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with 1 item", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "clusterX": round1TimeWindowData, + }) + s3Storage := NewMockS3Storage(round1Files) + s3Consumer := NewS3Consumer(s3Storage, 
map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "clusterX", checkpoint) + require.NoError(t, err) + require.Empty(t, data) + s3Storage.UpdateFiles(round2Files) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData2(newData) + expectedMaxVersionMap2(maxVersionMap) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err = s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with 2 items", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "clusterX": round1TimeWindowData, + }) + checkpoint.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "clusterX": round2TimeWindowData, + }) + s3Storage := NewMockS3Storage(round2Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "clusterX", checkpoint) + require.NoError(t, err) + expectedNewData2(data) + s3Storage.UpdateFiles(round3Files) + newData, maxVersionMap, err := s3Consumer.ConsumeNewFiles(ctx) + require.NoError(t, err) + expectedNewData3(newData) + expectedMaxVersionMap3(maxVersionMap) + }) + t.Run("checkpoint with 3 items", func(t *testing.T) { + t.Parallel() + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "clusterX": round1TimeWindowData, + }) + checkpoint.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "clusterX": round2TimeWindowData, + }) + checkpoint.NewTimeWindowData(2, map[string]types.TimeWindowData{ + "clusterX": round3TimeWindowData, + }) + s3Storage := NewMockS3Storage(round3Files) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + data, err := s3Consumer.InitializeFromCheckpoint(ctx, "clusterX", checkpoint) + 
require.NoError(t, err) + expectedCheckpoint23(data) + }) +} diff --git a/cmd/multi-cluster-consistency-checker/decoder/decoder.go b/cmd/multi-cluster-consistency-checker/decoder/decoder.go new file mode 100644 index 0000000000..af0f8d01d0 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/decoder/decoder.go @@ -0,0 +1,454 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package decoder + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "fmt" + "slices" + "strconv" + "strings" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/common/event" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/codec/common" + "github.com/pingcap/tidb/pkg/parser/mysql" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + tiTypes "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/codec" + "go.uber.org/zap" + "golang.org/x/text/encoding/charmap" +) + +const tidbWaterMarkType = "TIDB_WATERMARK" + +type canalValueDecoderJSONMessage struct { + PkNames []string `json:"pkNames"` + IsDDL bool `json:"isDdl"` + EventType string `json:"type"` + MySQLType map[string]string `json:"mysqlType"` + Data []map[string]any `json:"data"` +} + +func (c *canalValueDecoderJSONMessage) messageType() common.MessageType { + if c.IsDDL { + return common.MessageTypeDDL + } + + if c.EventType == tidbWaterMarkType { + return common.MessageTypeResolved + } 
+ + return common.MessageTypeRow +} + +type TiDBCommitTsExtension struct { + CommitTs uint64 `json:"commitTs"` +} + +type canalValueDecoderJSONMessageWithTiDBExtension struct { + canalValueDecoderJSONMessage + + TiDBCommitTsExtension *TiDBCommitTsExtension `json:"_tidb"` +} + +func defaultCanalJSONCodecConfig() *common.Config { + codecConfig := common.NewConfig(config.ProtocolCanalJSON) + // Always enable tidb extension for canal-json protocol + // because we need to get the commit ts from the extension field. + codecConfig.EnableTiDBExtension = true + codecConfig.Terminator = config.CRLF + return codecConfig +} + +type Record struct { + types.CdcVersion + Pk types.PkType + PkStr string + PkMap map[string]any + ColumnValues map[string]any +} + +func (r *Record) EqualReplicatedRecord(replicatedRecord *Record) bool { + if replicatedRecord == nil { + return false + } + if r.CommitTs != replicatedRecord.OriginTs { + return false + } + if r.Pk != replicatedRecord.Pk { + return false + } + if len(r.ColumnValues) != len(replicatedRecord.ColumnValues) { + return false + } + for columnName, columnValue := range r.ColumnValues { + replicatedColumnValue, ok := replicatedRecord.ColumnValues[columnName] + if !ok { + return false + } + // NOTE: This comparison is safe because ColumnValues only holds comparable + // types (nil, string, int64, float64, etc.) as produced by the canal-json + // decoder. If a non-comparable type (e.g. []byte or map) were ever stored, + // the != operator would panic at runtime. 
+ if columnValue != replicatedColumnValue { + return false + } + } + return true +} + +type columnValueDecoder struct { + data []byte + config *common.Config + + msg *canalValueDecoderJSONMessageWithTiDBExtension + columnFieldTypes map[string]*ptypes.FieldType +} + +func newColumnValueDecoder(data []byte) (*columnValueDecoder, error) { + config := defaultCanalJSONCodecConfig() + data, err := common.Decompress(config.LargeMessageHandle.LargeMessageHandleCompression, data) + if err != nil { + log.Error("decompress data failed", + zap.String("compression", config.LargeMessageHandle.LargeMessageHandleCompression), + zap.Any("data", data), + zap.Error(err)) + return nil, errors.Annotatef(err, "decompress data failed") + } + return &columnValueDecoder{ + config: config, + data: data, + }, nil +} + +func Decode(data []byte, columnFieldTypes map[string]*ptypes.FieldType) ([]*Record, error) { + decoder, err := newColumnValueDecoder(data) + if err != nil { + return nil, errors.Trace(err) + } + + decoder.columnFieldTypes = columnFieldTypes + + records := make([]*Record, 0) + for { + msgType, hasNext := decoder.tryNext() + if !hasNext { + break + } + if msgType == common.MessageTypeRow { + record, err := decoder.decodeNext() + if err != nil { + return nil, errors.Trace(err) + } + records = append(records, record) + } + } + + return records, nil +} + +func (d *columnValueDecoder) tryNext() (common.MessageType, bool) { + if d.data == nil { + return common.MessageTypeUnknown, false + } + var ( + msg = &canalValueDecoderJSONMessageWithTiDBExtension{} + encodedData []byte + ) + + idx := bytes.IndexAny(d.data, d.config.Terminator) + if idx >= 0 { + encodedData = d.data[:idx] + d.data = d.data[idx+len(d.config.Terminator):] + } else { + encodedData = d.data + d.data = nil + } + + if len(encodedData) == 0 { + return common.MessageTypeUnknown, false + } + + if err := json.Unmarshal(encodedData, msg); err != nil { + log.Error("canal-json decoder unmarshal data failed", + zap.Error(err), 
zap.ByteString("data", encodedData)) + d.msg = nil + return common.MessageTypeUnknown, true + } + d.msg = msg + return d.msg.messageType(), true +} + +func (d *columnValueDecoder) decodeNext() (*Record, error) { + if d.msg == nil || len(d.msg.Data) == 0 || d.msg.messageType() != common.MessageTypeRow { + log.Error("invalid message", zap.Any("msg", d.msg)) + return nil, errors.New("invalid message") + } + + var pkStrBuilder strings.Builder + pkStrBuilder.WriteString("[") + pkValues := make([]tiTypes.Datum, 0, len(d.msg.PkNames)) + pkMap := make(map[string]any, len(d.msg.PkNames)) + slices.Sort(d.msg.PkNames) + for i, pkName := range d.msg.PkNames { + columnValue, ok := d.msg.Data[0][pkName] + if !ok { + log.Error("column value not found", zap.String("pkName", pkName), zap.Any("msg", d.msg)) + return nil, errors.Errorf("column value of column %s not found", pkName) + } + if i > 0 { + pkStrBuilder.WriteString(", ") + } + fmt.Fprintf(&pkStrBuilder, "%s: %v", pkName, columnValue) + pkMap[pkName] = columnValue + ft := d.getColumnFieldType(pkName) + if ft == nil { + log.Error("field type not found", zap.String("pkName", pkName), zap.Any("msg", d.msg)) + return nil, errors.Errorf("field type of column %s not found", pkName) + } + datum := valueToDatum(columnValue, ft) + if datum.IsNull() { + log.Error("column value is null", zap.String("pkName", pkName), zap.Any("msg", d.msg)) + return nil, errors.Errorf("column value of column %s is null", pkName) + } + pkValues = append(pkValues, *datum) + delete(d.msg.Data[0], pkName) + } + pkStrBuilder.WriteString("]") + pkEncoded, err := codec.EncodeKey(time.UTC, nil, pkValues...) 
+ if err != nil { + return nil, errors.Annotate(err, "failed to encode primary key") + } + pk := hex.EncodeToString(pkEncoded) + originTs := uint64(0) + columnValues := make(map[string]any) + for columnName, columnValue := range d.msg.Data[0] { + if columnName == event.OriginTsColumn { + if columnValue != nil { + originTs, err = strconv.ParseUint(columnValue.(string), 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + } + } else { + columnValues[columnName] = columnValue + } + } + commitTs := d.msg.TiDBCommitTsExtension.CommitTs + d.msg = nil + return &Record{ + Pk: types.PkType(pk), + PkStr: pkStrBuilder.String(), + PkMap: pkMap, + ColumnValues: columnValues, + CdcVersion: types.CdcVersion{ + CommitTs: commitTs, + OriginTs: originTs, + }, + }, nil +} + +// getColumnFieldType returns the FieldType for a column. +// It first looks up from the tableDefinition-based columnFieldTypes map, +// then falls back to parsing the MySQLType string from the canal-json message. +func (d *columnValueDecoder) getColumnFieldType(columnName string) *ptypes.FieldType { + if d.columnFieldTypes != nil { + if ft, ok := d.columnFieldTypes[columnName]; ok { + return ft + } + } + // Fallback: parse from MySQLType in the canal-json message + mysqlType, ok := d.msg.MySQLType[columnName] + if !ok { + return nil + } + return newPKColumnFieldTypeFromMysqlType(mysqlType) +} + +func newPKColumnFieldTypeFromMysqlType(mysqlType string) *ptypes.FieldType { + tp := ptypes.NewFieldType(common.ExtractBasicMySQLType(mysqlType)) + if common.IsBinaryMySQLType(mysqlType) { + tp.AddFlag(mysql.BinaryFlag) + tp.SetCharset("binary") + tp.SetCollate("binary") + } + if strings.HasPrefix(mysqlType, "char") || + strings.HasPrefix(mysqlType, "varchar") || + strings.Contains(mysqlType, "text") || + strings.Contains(mysqlType, "enum") || + strings.Contains(mysqlType, "set") { + tp.SetCharset("utf8mb4") + tp.SetCollate("utf8mb4_bin") + } + + if common.IsUnsignedMySQLType(mysqlType) { + 
tp.AddFlag(mysql.UnsignedFlag) + } + + flen, decimal := common.ExtractFlenDecimal(mysqlType, tp.GetType()) + tp.SetFlen(flen) + tp.SetDecimal(decimal) + switch tp.GetType() { + case mysql.TypeEnum, mysql.TypeSet: + tp.SetElems(common.ExtractElements(mysqlType)) + case mysql.TypeDuration: + decimal = common.ExtractDecimal(mysqlType) + tp.SetDecimal(decimal) + default: + } + return tp +} + +func valueToDatum(value any, ft *ptypes.FieldType) *tiTypes.Datum { + d := &tiTypes.Datum{} + if value == nil { + d.SetNull() + return d + } + rawValue, ok := value.(string) + if !ok { + log.Panic("canal-json encoded message should have type in `string`") + } + if mysql.HasBinaryFlag(ft.GetFlag()) { + // when encoding the `JavaSQLTypeBLOB`, use `IS08859_1` decoder, now reverse it back. + result, err := charmap.ISO8859_1.NewEncoder().String(rawValue) + if err != nil { + log.Panic("invalid column value, please report a bug", zap.Any("rawValue", rawValue), zap.Error(err)) + } + rawValue = result + } + + switch ft.GetType() { + case mysql.TypeLonglong, mysql.TypeLong, mysql.TypeInt24, mysql.TypeShort, mysql.TypeTiny: + if mysql.HasUnsignedFlag(ft.GetFlag()) { + data, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for unsigned integer", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetUint64(data) + return d + } + data, err := strconv.ParseInt(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for integer", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetInt64(data) + return d + case mysql.TypeYear: + data, err := strconv.ParseInt(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for year", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetInt64(data) + return d + case mysql.TypeFloat: + data, err := strconv.ParseFloat(rawValue, 32) + if err != nil { + log.Panic("invalid column value for float", zap.Any("rawValue", rawValue), zap.Error(err)) + } + 
d.SetFloat32(float32(data)) + return d + case mysql.TypeDouble: + data, err := strconv.ParseFloat(rawValue, 64) + if err != nil { + log.Panic("invalid column value for double", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetFloat64(data) + return d + case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString, + mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + d.SetString(rawValue, ft.GetCollate()) + return d + case mysql.TypeNewDecimal: + data := new(tiTypes.MyDecimal) + err := data.FromString([]byte(rawValue)) + if err != nil { + log.Panic("invalid column value for decimal", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlDecimal(data) + d.SetLength(ft.GetFlen()) + if ft.GetDecimal() == tiTypes.UnspecifiedLength { + d.SetFrac(int(data.GetDigitsFrac())) + } else { + d.SetFrac(ft.GetDecimal()) + } + return d + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + data, err := tiTypes.ParseTime(tiTypes.DefaultStmtNoWarningContext, rawValue, ft.GetType(), ft.GetDecimal()) + if err != nil { + log.Panic("invalid column value for time", zap.Any("rawValue", rawValue), + zap.Int("flen", ft.GetFlen()), zap.Int("decimal", ft.GetDecimal()), + zap.Error(err)) + } + d.SetMysqlTime(data) + return d + case mysql.TypeDuration: + data, _, err := tiTypes.ParseDuration(tiTypes.DefaultStmtNoWarningContext, rawValue, ft.GetDecimal()) + if err != nil { + log.Panic("invalid column value for duration", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlDuration(data) + return d + case mysql.TypeEnum: + enumValue, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for enum", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlEnum(tiTypes.Enum{ + Name: "", + Value: enumValue, + }, ft.GetCollate()) + return d + case mysql.TypeSet: + setValue, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for set", 
zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlSet(tiTypes.Set{ + Name: "", + Value: setValue, + }, ft.GetCollate()) + return d + case mysql.TypeBit: + data, err := strconv.ParseUint(rawValue, 10, 64) + if err != nil { + log.Panic("invalid column value for bit", zap.Any("rawValue", rawValue), zap.Error(err)) + } + byteSize := (ft.GetFlen() + 7) >> 3 + d.SetMysqlBit(tiTypes.NewBinaryLiteralFromUint(data, byteSize)) + return d + case mysql.TypeJSON: + data, err := tiTypes.ParseBinaryJSONFromString(rawValue) + if err != nil { + log.Panic("invalid column value for json", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetMysqlJSON(data) + return d + case mysql.TypeTiDBVectorFloat32: + data, err := tiTypes.ParseVectorFloat32(rawValue) + if err != nil { + log.Panic("cannot parse vector32 value from string", zap.Any("rawValue", rawValue), zap.Error(err)) + } + d.SetVectorFloat32(data) + return d + } + return d +} diff --git a/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go b/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go new file mode 100644 index 0000000000..972ab377cd --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go @@ -0,0 +1,309 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package decoder_test + +import ( + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/decoder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + "github.com/stretchr/testify/require" +) + +// buildColumnFieldTypes converts a TableDefinition into a map of column name → FieldType, +// mimicking what SchemaDefinitions.SetSchemaDefinition does. +func buildColumnFieldTypes(t *testing.T, td *cloudstorage.TableDefinition) map[string]*ptypes.FieldType { + t.Helper() + result := make(map[string]*ptypes.FieldType, len(td.Columns)) + for i, col := range td.Columns { + colInfo, err := col.ToTiColumnInfo(int64(i)) + require.NoError(t, err) + result[col.Name] = &colInfo.FieldType + } + return result +} + +// DataContent uses CRLF (\r\n) as line terminator to match the codec config +const DataContent1 string = "" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770184540709,"ts":1770184542274,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp","id":"int","first_name":"varchar"},"old":null,"data":[{"id":"20","first_name":"t","last_name":"TT","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464043256649875456}}` + "\r\n" + + 
`{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770184540709,"ts":1770184542274,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"21","first_name":"u","last_name":"UU","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464043256649875456}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770301693150,"ts":1770301693833,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp","id":"int"},"old":null,"data":[{"id":"5","first_name":"e","last_name":"E","_tidb_origin_ts":"464073966942421014","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464073967049113629}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770301693150,"ts":1770301693833,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"_tidb_softdelete_time":"timestamp","id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint"},"old":null,"data":[{"id":"6","first_name":"f","last_name":"F","_tidb_origin_ts":"464073966942421014","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464073967049113629}}` + "\r\n" + + 
`{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770303499850,"ts":1770303500498,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"7","first_name":"g","last_name":"G","_tidb_origin_ts":"464074440387592202","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464074440664678441}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"UPDATE","es":1770303520951,"ts":1770303522531,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp","id":"int","first_name":"varchar"},"old":[{"id":"7","first_name":"g","last_name":"G","_tidb_origin_ts":"464074440387592202","_tidb_softdelete_time":null}],"data":[{"id":"7","first_name":"g","last_name":"G","_tidb_origin_ts":null,"_tidb_softdelete_time":"2026-02-05 22:58:40.992217"}],"_tidb":{"commitTs":464074446196178963}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1770303498793,"ts":1770303499864,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"8","first_name":"h","last_name":"H","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464074440387592202}}` + "\r\n" + + 
`{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"UPDATE","es":1770303522494,"ts":1770303523900,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":[{"id":"8","first_name":"h","last_name":"H","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"data":[{"id":"8","first_name":"h","last_name":"H","_tidb_origin_ts":"464074446196178963","_tidb_softdelete_time":"2026-02-05 22:58:40.992217"}],"_tidb":{"commitTs":464074446600667164}}` + +var ExpectedRecords1 = []decoder.Record{ + {CdcVersion: types.CdcVersion{CommitTs: 464043256649875456, OriginTs: 0}, Pk: "038000000000000014", PkStr: "[id: 20]", ColumnValues: map[string]any{"first_name": "t", "last_name": "TT", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464043256649875456, OriginTs: 0}, Pk: "038000000000000015", PkStr: "[id: 21]", ColumnValues: map[string]any{"first_name": "u", "last_name": "UU", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464073967049113629, OriginTs: 464073966942421014}, Pk: "038000000000000005", PkStr: "[id: 5]", ColumnValues: map[string]any{"first_name": "e", "last_name": "E", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464073967049113629, OriginTs: 464073966942421014}, Pk: "038000000000000006", PkStr: "[id: 6]", ColumnValues: map[string]any{"first_name": "f", "last_name": "F", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074440664678441, OriginTs: 464074440387592202}, Pk: "038000000000000007", PkStr: "[id: 7]", ColumnValues: map[string]any{"first_name": "g", "last_name": "G", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074446196178963, OriginTs: 0}, Pk: "038000000000000007", PkStr: "[id: 7]", ColumnValues: 
map[string]any{"first_name": "g", "last_name": "G", "_tidb_softdelete_time": "2026-02-05 22:58:40.992217"}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074440387592202, OriginTs: 0}, Pk: "038000000000000008", PkStr: "[id: 8]", ColumnValues: map[string]any{"first_name": "h", "last_name": "H", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464074446600667164, OriginTs: 464074446196178963}, Pk: "038000000000000008", PkStr: "[id: 8]", ColumnValues: map[string]any{"first_name": "h", "last_name": "H", "_tidb_softdelete_time": "2026-02-05 22:58:40.992217"}}, +} + +// tableDefinition1 describes the "message" table: id(INT PK), first_name(VARCHAR), last_name(VARCHAR), +// _tidb_origin_ts(BIGINT), _tidb_softdelete_time(TIMESTAMP) +var tableDefinition1 = &cloudstorage.TableDefinition{ + Table: "message", + Schema: "test_active", + Version: 1, + Columns: []cloudstorage.TableCol{ + {Name: "id", Tp: "INT", IsPK: "true", Precision: "11"}, + {Name: "first_name", Tp: "VARCHAR", Precision: "255"}, + {Name: "last_name", Tp: "VARCHAR", Precision: "255"}, + {Name: "_tidb_origin_ts", Tp: "BIGINT", Precision: "20"}, + {Name: "_tidb_softdelete_time", Tp: "TIMESTAMP"}, + }, + TotalColumns: 5, +} + +func TestCanalJSONDecoder1(t *testing.T) { + records, err := decoder.Decode([]byte(DataContent1), buildColumnFieldTypes(t, tableDefinition1)) + require.NoError(t, err) + require.Len(t, records, 8) + for i, actualRecord := range records { + expectedRecord := ExpectedRecords1[i] + require.Equal(t, actualRecord.Pk, expectedRecord.Pk) + require.Equal(t, actualRecord.PkStr, expectedRecord.PkStr) + require.Equal(t, actualRecord.ColumnValues, expectedRecord.ColumnValues) + require.Equal(t, actualRecord.CdcVersion.CommitTs, expectedRecord.CdcVersion.CommitTs) + require.Equal(t, actualRecord.CdcVersion.OriginTs, expectedRecord.CdcVersion.OriginTs) + } +} + +const DataContent2 string = "" + + 
`{"id":0,"database":"test_active","table":"message2","pkNames":["id","first_name"],"isDdl":false,"type":"INSERT","es":1770344412751,"ts":1770344413749,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"100","first_name":"a","last_name":"A","_tidb_origin_ts":"464085165262503958","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464085165736198159}}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message2","pkNames":["id","first_name"],"isDdl":false,"type":"INSERT","es":1770344427851,"ts":1770344429772,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"101","first_name":"b","last_name":"B","_tidb_origin_ts":null,"_tidb_softdelete_time":null}],"_tidb":{"commitTs":464085169694572575}}` + "\r\n" + +var ExpectedRecords2 = []decoder.Record{ + {CdcVersion: types.CdcVersion{CommitTs: 464085165736198159, OriginTs: 464085165262503958}, Pk: "016100000000000000f8038000000000000064", PkStr: "[first_name: a, id: 100]", ColumnValues: map[string]any{"last_name": "A", "_tidb_softdelete_time": nil}}, + {CdcVersion: types.CdcVersion{CommitTs: 464085169694572575, OriginTs: 0}, Pk: "016200000000000000f8038000000000000065", PkStr: "[first_name: b, id: 101]", ColumnValues: map[string]any{"last_name": "B", "_tidb_softdelete_time": nil}}, +} + +// tableDefinition2 describes the "message2" table: id(INT PK), first_name(VARCHAR PK), last_name(VARCHAR), +// _tidb_origin_ts(BIGINT), _tidb_softdelete_time(TIMESTAMP) +var tableDefinition2 = &cloudstorage.TableDefinition{ + Table: "message2", + Schema: "test_active", + Version: 1, + Columns: []cloudstorage.TableCol{ 
+ {Name: "id", Tp: "INT", IsPK: "true", Precision: "11"}, + {Name: "first_name", Tp: "VARCHAR", IsPK: "true", Precision: "255"}, + {Name: "last_name", Tp: "VARCHAR", Precision: "255"}, + {Name: "_tidb_origin_ts", Tp: "BIGINT", Precision: "20"}, + {Name: "_tidb_softdelete_time", Tp: "TIMESTAMP"}, + }, + TotalColumns: 5, +} + +func TestCanalJSONDecoder2(t *testing.T) { + records, err := decoder.Decode([]byte(DataContent2), buildColumnFieldTypes(t, tableDefinition2)) + require.NoError(t, err) + require.Len(t, records, 2) + for i, actualRecord := range records { + expectedRecord := ExpectedRecords2[i] + require.Equal(t, actualRecord.Pk, expectedRecord.Pk) + require.Equal(t, actualRecord.PkStr, expectedRecord.PkStr) + require.Equal(t, actualRecord.ColumnValues, expectedRecord.ColumnValues) + require.Equal(t, actualRecord.CdcVersion.CommitTs, expectedRecord.CdcVersion.CommitTs) + require.Equal(t, actualRecord.CdcVersion.OriginTs, expectedRecord.CdcVersion.OriginTs) + } +} + +// TestCanalJSONDecoderWithInvalidMessage verifies that when a malformed message appears in +// the data stream, it is skipped gracefully and subsequent valid messages are still decoded. +// This covers the fix where d.msg is cleared to nil on unmarshal failure to prevent stale +// message data from leaking into decodeNext. +func TestCanalJSONDecoderWithInvalidMessage(t *testing.T) { + // First line is invalid JSON, second line is a valid message. 
+ dataWithInvalidLine := `{invalid json}` + "\r\n" + + `{"id":0,"database":"test_active","table":"message2","pkNames":["id","first_name"],"isDdl":false,"type":"INSERT","es":1770344412751,"ts":1770344413749,"sql":"","sqlType":{"id":4,"first_name":12,"last_name":12,"_tidb_origin_ts":-5,"_tidb_softdelete_time":93},"mysqlType":{"id":"int","first_name":"varchar","last_name":"varchar","_tidb_origin_ts":"bigint","_tidb_softdelete_time":"timestamp"},"old":null,"data":[{"id":"100","first_name":"a","last_name":"A","_tidb_origin_ts":"464085165262503958","_tidb_softdelete_time":null}],"_tidb":{"commitTs":464085165736198159}}` + "\r\n" + + records, err := decoder.Decode([]byte(dataWithInvalidLine), buildColumnFieldTypes(t, tableDefinition2)) + require.NoError(t, err) + // The invalid line should be skipped, only the valid record should be returned. + require.Len(t, records, 1) + require.Equal(t, ExpectedRecords2[0].Pk, records[0].Pk) + require.Equal(t, ExpectedRecords2[0].PkStr, records[0].PkStr) + require.Equal(t, ExpectedRecords2[0].CdcVersion.CommitTs, records[0].CdcVersion.CommitTs) + require.Equal(t, ExpectedRecords2[0].CdcVersion.OriginTs, records[0].CdcVersion.OriginTs) +} + +// TestCanalJSONDecoderAllInvalidMessages verifies that when all messages are malformed, +// the decoder returns an empty result without errors. 
+func TestCanalJSONDecoderAllInvalidMessages(t *testing.T) { + allInvalid := `{broken}` + "\r\n" + `{also broken}` + "\r\n" + records, err := decoder.Decode([]byte(allInvalid), nil) + require.NoError(t, err) + require.Empty(t, records) +} + +func TestRecord_EqualReplicatedRecord(t *testing.T) { + tests := []struct { + name string + local *decoder.Record + replicated *decoder.Record + expectedEqual bool + }{ + { + name: "equal records", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + "col2": 42, + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + "col2": 42, + }, + }, + expectedEqual: true, + }, + { + name: "replicated is nil", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + }, + replicated: nil, + expectedEqual: false, + }, + { + name: "different CommitTs and OriginTs", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 200}, + Pk: "pk1", + }, + expectedEqual: false, + }, + { + name: "different primary keys", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk2", + }, + expectedEqual: false, + }, + { + name: "different column count", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + "col2": "value2", + }, + }, + expectedEqual: false, + }, 
+ { + name: "different column names", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col2": "value1", + }, + }, + expectedEqual: false, + }, + { + name: "different column values", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value1", + }, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{ + "col1": "value2", + }, + }, + expectedEqual: false, + }, + { + name: "empty column values", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: map[string]any{}, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: map[string]any{}, + }, + expectedEqual: true, + }, + { + name: "nil column values", + local: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 100, OriginTs: 0}, + Pk: "pk1", + ColumnValues: nil, + }, + replicated: &decoder.Record{ + CdcVersion: types.CdcVersion{CommitTs: 101, OriginTs: 100}, + Pk: "pk1", + ColumnValues: nil, + }, + expectedEqual: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.local.EqualReplicatedRecord(tt.replicated) + require.Equal(t, tt.expectedEqual, result) + }) + } +} diff --git a/cmd/multi-cluster-consistency-checker/decoder/value_to_datum_test.go b/cmd/multi-cluster-consistency-checker/decoder/value_to_datum_test.go new file mode 100644 index 0000000000..95bc8d43e9 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/decoder/value_to_datum_test.go @@ -0,0 +1,898 @@ +// Copyright 2026 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package decoder + +import ( + "math" + "testing" + + "github.com/pingcap/tidb/pkg/parser/mysql" + ptypes "github.com/pingcap/tidb/pkg/parser/types" + tiTypes "github.com/pingcap/tidb/pkg/types" + "github.com/stretchr/testify/require" +) + +func TestValueToDatum_NilValue(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + d := valueToDatum(nil, ft) + require.True(t, d.IsNull()) +} + +func TestValueToDatum_NonStringPanics(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + require.Panics(t, func() { + valueToDatum(123, ft) + }) +} + +func TestValueToDatum_SignedIntegers(t *testing.T) { + t.Parallel() + + intTypes := []struct { + name string + tp byte + }{ + {"TypeTiny", mysql.TypeTiny}, + {"TypeShort", mysql.TypeShort}, + {"TypeInt24", mysql.TypeInt24}, + {"TypeLong", mysql.TypeLong}, + {"TypeLonglong", mysql.TypeLonglong}, + } + + for _, it := range intTypes { + t.Run(it.name, func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(it.tp) + + t.Run("positive", func(t *testing.T) { + t.Parallel() + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(42), d.GetInt64()) + }) + + t.Run("zero", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(0), d.GetInt64()) + }) + + t.Run("negative", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-100", ft) + require.Equal(t, tiTypes.KindInt64, 
d.Kind()) + require.Equal(t, int64(-100), d.GetInt64()) + }) + + t.Run("max int64", func(t *testing.T) { + t.Parallel() + d := valueToDatum("9223372036854775807", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(math.MaxInt64), d.GetInt64()) + }) + + t.Run("min int64", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-9223372036854775808", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(math.MinInt64), d.GetInt64()) + }) + }) + } +} + +func TestValueToDatum_UnsignedIntegers(t *testing.T) { + t.Parallel() + + intTypes := []struct { + name string + tp byte + }{ + {"TypeTiny", mysql.TypeTiny}, + {"TypeShort", mysql.TypeShort}, + {"TypeInt24", mysql.TypeInt24}, + {"TypeLong", mysql.TypeLong}, + {"TypeLonglong", mysql.TypeLonglong}, + } + + for _, it := range intTypes { + t.Run(it.name, func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(it.tp) + ft.AddFlag(mysql.UnsignedFlag) + + t.Run("positive", func(t *testing.T) { + t.Parallel() + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(42), d.GetUint64()) + }) + + t.Run("zero", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(0), d.GetUint64()) + }) + + t.Run("max uint64", func(t *testing.T) { + t.Parallel() + d := valueToDatum("18446744073709551615", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(math.MaxUint64), d.GetUint64()) + }) + }) + } +} + +func TestValueToDatum_InvalidIntegerPanics(t *testing.T) { + t.Parallel() + + t.Run("signed invalid", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + require.Panics(t, func() { + valueToDatum("not_a_number", ft) + }) + }) + + t.Run("unsigned invalid", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeLong) + ft.AddFlag(mysql.UnsignedFlag) + require.Panics(t, func() { + 
valueToDatum("not_a_number", ft) + }) + }) +} + +func TestValueToDatum_Year(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeYear) + + t.Run("normal year", func(t *testing.T) { + t.Parallel() + d := valueToDatum("2026", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(2026), d.GetInt64()) + }) + + t.Run("zero year", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(0), d.GetInt64()) + }) + + t.Run("invalid year panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("abc", ft) + }) + }) +} + +func TestValueToDatum_Float(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeFloat) + + t.Run("positive float", func(t *testing.T) { + t.Parallel() + d := valueToDatum("3.14", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.InDelta(t, float32(3.14), d.GetFloat32(), 0.001) + }) + + t.Run("negative float", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-2.5", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.InDelta(t, float32(-2.5), d.GetFloat32(), 0.001) + }) + + t.Run("zero float", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.Equal(t, float32(0), d.GetFloat32()) + }) + + t.Run("invalid float panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("not_a_float", ft) + }) + }) +} + +func TestValueToDatum_Double(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeDouble) + + t.Run("positive double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("3.141592653589793", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, 3.141592653589793, d.GetFloat64(), 1e-15) + }) + + t.Run("negative double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("-1.23456789", ft) + require.Equal(t, 
tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, -1.23456789, d.GetFloat64(), 1e-9) + }) + + t.Run("zero double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.Equal(t, float64(0), d.GetFloat64()) + }) + + t.Run("very large double", func(t *testing.T) { + t.Parallel() + d := valueToDatum("1.7976931348623157e+308", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, math.MaxFloat64, d.GetFloat64(), 1e+293) + }) + + t.Run("invalid double panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("not_a_double", ft) + }) + }) +} + +func TestValueToDatum_StringTypes(t *testing.T) { + t.Parallel() + + stringTypes := []struct { + name string + tp byte + }{ + {"TypeVarString", mysql.TypeVarString}, + {"TypeVarchar", mysql.TypeVarchar}, + {"TypeString", mysql.TypeString}, + {"TypeBlob", mysql.TypeBlob}, + {"TypeTinyBlob", mysql.TypeTinyBlob}, + {"TypeMediumBlob", mysql.TypeMediumBlob}, + {"TypeLongBlob", mysql.TypeLongBlob}, + } + + for _, st := range stringTypes { + t.Run(st.name, func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(st.tp) + ft.SetCollate("utf8mb4_bin") + + t.Run("normal string", func(t *testing.T) { + t.Parallel() + d := valueToDatum("hello world", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "hello world", d.GetString()) + }) + + t.Run("empty string", func(t *testing.T) { + t.Parallel() + d := valueToDatum("", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "", d.GetString()) + }) + + t.Run("unicode string", func(t *testing.T) { + t.Parallel() + d := valueToDatum("你好世界🌍", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "你好世界🌍", d.GetString()) + }) + }) + } +} + +func TestValueToDatum_BinaryFlag(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeString) + ft.AddFlag(mysql.BinaryFlag) + ft.SetCharset("binary") + 
ft.SetCollate("binary") + + t.Run("ascii content", func(t *testing.T) { + t.Parallel() + d := valueToDatum("abc", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "abc", d.GetString()) + }) + + t.Run("empty binary", func(t *testing.T) { + t.Parallel() + d := valueToDatum("", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "", d.GetString()) + }) +} + +func TestValueToDatum_Decimal(t *testing.T) { + t.Parallel() + + t.Run("simple decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(2) + + d := valueToDatum("123.45", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "123.45", d.GetMysqlDecimal().String()) + require.Equal(t, 10, d.Length()) + require.Equal(t, 2, d.Frac()) + }) + + t.Run("negative decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(3) + + d := valueToDatum("-99.999", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "-99.999", d.GetMysqlDecimal().String()) + require.Equal(t, 3, d.Frac()) + }) + + t.Run("zero decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(0) + + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "0", d.GetMysqlDecimal().String()) + }) + + t.Run("large decimal", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(65) + ft.SetDecimal(30) + + d := valueToDatum("12345678901234567890.123456789012345678", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, 65, d.Length()) + require.Equal(t, 30, d.Frac()) + }) + + t.Run("unspecified decimal uses actual frac", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + 
ft.SetDecimal(tiTypes.UnspecifiedLength) + + d := valueToDatum("12.345", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, 3, d.Frac()) // actual digits frac from the value + }) + + t.Run("invalid decimal panics", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeNewDecimal) + ft.SetFlen(10) + ft.SetDecimal(2) + require.Panics(t, func() { + valueToDatum("not_decimal", ft) + }) + }) +} + +func TestValueToDatum_Date(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeDate) + ft.SetDecimal(0) + + t.Run("normal date", func(t *testing.T) { + t.Parallel() + d := valueToDatum("2026-02-11", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11", d.GetMysqlTime().String()) + }) + + t.Run("zero date", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0000-00-00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "0000-00-00", d.GetMysqlTime().String()) + }) +} + +func TestValueToDatum_Datetime(t *testing.T) { + t.Parallel() + + t.Run("datetime without fractional seconds", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDatetime) + ft.SetDecimal(0) + + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11 10:30:00", d.GetMysqlTime().String()) + }) + + t.Run("datetime with fractional seconds", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDatetime) + ft.SetDecimal(6) + + d := valueToDatum("2026-02-11 10:30:00.123456", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11 10:30:00.123456", d.GetMysqlTime().String()) + }) +} + +func TestValueToDatum_Timestamp(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeTimestamp) + ft.SetDecimal(0) + + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11 
10:30:00", d.GetMysqlTime().String()) +} + +func TestValueToDatum_Duration(t *testing.T) { + t.Parallel() + + t.Run("positive duration", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(0) + + d := valueToDatum("12:30:45", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "12:30:45", d.GetMysqlDuration().String()) + }) + + t.Run("negative duration", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(0) + + d := valueToDatum("-01:00:00", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "-01:00:00", d.GetMysqlDuration().String()) + }) + + t.Run("duration with fractional seconds", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(3) + + d := valueToDatum("10:20:30.123", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "10:20:30.123", d.GetMysqlDuration().String()) + }) + + t.Run("zero duration", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeDuration) + ft.SetDecimal(0) + + d := valueToDatum("00:00:00", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + require.Equal(t, "00:00:00", d.GetMysqlDuration().String()) + }) +} + +func TestValueToDatum_Enum(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeEnum) + ft.SetCharset("utf8mb4") + ft.SetCollate("utf8mb4_bin") + ft.SetElems([]string{"a", "b", "c"}) + + t.Run("valid enum value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("1", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(1), d.GetMysqlEnum().Value) + }) + + t.Run("enum value 2", func(t *testing.T) { + t.Parallel() + d := valueToDatum("2", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(2), d.GetMysqlEnum().Value) + }) + + t.Run("enum value 0", func(t *testing.T) { + t.Parallel() + d := 
valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(0), d.GetMysqlEnum().Value) + }) + + t.Run("invalid enum panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("abc", ft) + }) + }) +} + +func TestValueToDatum_Set(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeSet) + ft.SetCharset("utf8mb4") + ft.SetCollate("utf8mb4_bin") + ft.SetElems([]string{"a", "b", "c"}) + + t.Run("single set value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("1", ft) + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(1), d.GetMysqlSet().Value) + }) + + t.Run("combined set value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("3", ft) // a,b + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(3), d.GetMysqlSet().Value) + }) + + t.Run("zero set value", func(t *testing.T) { + t.Parallel() + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(0), d.GetMysqlSet().Value) + }) + + t.Run("invalid set panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("xyz", ft) + }) + }) +} + +func TestValueToDatum_Bit(t *testing.T) { + t.Parallel() + + t.Run("bit(1) value 1", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(1) + + d := valueToDatum("1", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("bit(8) value 255", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(8) + + d := valueToDatum("255", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("bit(64) large value", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(64) + + d := valueToDatum("18446744073709551615", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("bit zero", func(t *testing.T) { + 
t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(8) + + d := valueToDatum("0", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("invalid bit panics", func(t *testing.T) { + t.Parallel() + ft := ptypes.NewFieldType(mysql.TypeBit) + ft.SetFlen(8) + require.Panics(t, func() { + valueToDatum("not_a_bit", ft) + }) + }) +} + +func TestValueToDatum_JSON(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeJSON) + + t.Run("json object", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`{"key": "value"}`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + require.Contains(t, d.GetMysqlJSON().String(), "key") + require.Contains(t, d.GetMysqlJSON().String(), "value") + }) + + t.Run("json array", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`[1, 2, 3]`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json string", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`"hello"`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json number", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`42`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json null", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`null`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("json boolean", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`true`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("nested json", func(t *testing.T) { + t.Parallel() + d := valueToDatum(`{"a": [1, {"b": "c"}], "d": null}`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) + + t.Run("invalid json panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum(`{invalid`, ft) + }) + }) +} + +func TestValueToDatum_VectorFloat32(t *testing.T) { + t.Parallel() + + ft := ptypes.NewFieldType(mysql.TypeTiDBVectorFloat32) + + t.Run("simple vector", func(t *testing.T) { + 
t.Parallel() + d := valueToDatum("[1,2,3]", ft) + require.False(t, d.IsNull()) + }) + + t.Run("single element vector", func(t *testing.T) { + t.Parallel() + d := valueToDatum("[0.5]", ft) + require.False(t, d.IsNull()) + }) + + t.Run("invalid vector panics", func(t *testing.T) { + t.Parallel() + require.Panics(t, func() { + valueToDatum("not_a_vector", ft) + }) + }) +} + +func TestValueToDatum_UnknownType(t *testing.T) { + t.Parallel() + // Use a type that doesn't match any case in the switch (TypeGeometry). + // The default datum returned is a zero-value datum, which is null. + ft := ptypes.NewFieldType(mysql.TypeGeometry) + d := valueToDatum("some_value", ft) + require.True(t, d.IsNull()) +} + +func TestValueToDatum_ViaNewPKColumnFieldType(t *testing.T) { + t.Parallel() + // Test valueToDatum using FieldTypes produced by newPKColumnFieldTypeFromMysqlType, + // which is the real caller in production code. + + t.Run("int", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("int") + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(42), d.GetInt64()) + }) + + t.Run("int unsigned", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("int unsigned") + d := valueToDatum("42", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(42), d.GetUint64()) + }) + + t.Run("bigint", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("bigint") + d := valueToDatum("9223372036854775807", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(math.MaxInt64), d.GetInt64()) + }) + + t.Run("bigint unsigned", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("bigint unsigned") + d := valueToDatum("18446744073709551615", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(math.MaxUint64), d.GetUint64()) + }) + + t.Run("varchar", func(t *testing.T) { + 
t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("varchar") + d := valueToDatum("hello", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "hello", d.GetString()) + }) + + t.Run("char", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("char") + d := valueToDatum("x", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + require.Equal(t, "x", d.GetString()) + }) + + t.Run("decimal(10,2)", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("decimal(10,2)") + d := valueToDatum("123.45", ft) + require.Equal(t, tiTypes.KindMysqlDecimal, d.Kind()) + require.Equal(t, "123.45", d.GetMysqlDecimal().String()) + require.Equal(t, 10, d.Length()) + require.Equal(t, 2, d.Frac()) + }) + + t.Run("float", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("float") + d := valueToDatum("3.14", ft) + require.Equal(t, tiTypes.KindFloat32, d.Kind()) + require.InDelta(t, float32(3.14), d.GetFloat32(), 0.001) + }) + + t.Run("double", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("double") + d := valueToDatum("3.141592653589793", ft) + require.Equal(t, tiTypes.KindFloat64, d.Kind()) + require.InDelta(t, 3.141592653589793, d.GetFloat64(), 1e-15) + }) + + t.Run("binary", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("binary") + require.True(t, mysql.HasBinaryFlag(ft.GetFlag())) + d := valueToDatum("abc", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + }) + + t.Run("varbinary", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("varbinary") + require.True(t, mysql.HasBinaryFlag(ft.GetFlag())) + d := valueToDatum("abc", ft) + require.Equal(t, tiTypes.KindString, d.Kind()) + }) + + t.Run("tinyint", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("tinyint") + d := valueToDatum("127", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, 
int64(127), d.GetInt64()) + }) + + t.Run("smallint unsigned", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("smallint unsigned") + d := valueToDatum("65535", ft) + require.Equal(t, tiTypes.KindUint64, d.Kind()) + require.Equal(t, uint64(65535), d.GetUint64()) + }) + + t.Run("date", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("date") + d := valueToDatum("2026-02-11", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + require.Equal(t, "2026-02-11", d.GetMysqlTime().String()) + }) + + t.Run("datetime", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("datetime") + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + }) + + t.Run("timestamp", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("timestamp") + d := valueToDatum("2026-02-11 10:30:00", ft) + require.Equal(t, tiTypes.KindMysqlTime, d.Kind()) + }) + + t.Run("time", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("time") + d := valueToDatum("12:30:45", ft) + require.Equal(t, tiTypes.KindMysqlDuration, d.Kind()) + }) + + t.Run("year", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("year") + d := valueToDatum("2026", ft) + require.Equal(t, tiTypes.KindInt64, d.Kind()) + require.Equal(t, int64(2026), d.GetInt64()) + }) + + t.Run("enum('a','b','c')", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("enum('a','b','c')") + d := valueToDatum("2", ft) + require.Equal(t, tiTypes.KindMysqlEnum, d.Kind()) + require.Equal(t, uint64(2), d.GetMysqlEnum().Value) + }) + + t.Run("set('x','y','z')", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("set('x','y','z')") + d := valueToDatum("5", ft) + require.Equal(t, tiTypes.KindMysqlSet, d.Kind()) + require.Equal(t, uint64(5), d.GetMysqlSet().Value) + }) + + t.Run("bit(8)", func(t 
*testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("bit(8)") + d := valueToDatum("255", ft) + require.Equal(t, tiTypes.KindMysqlBit, d.Kind()) + }) + + t.Run("json", func(t *testing.T) { + t.Parallel() + ft := newPKColumnFieldTypeFromMysqlType("json") + d := valueToDatum(`{"key":"value"}`, ft) + require.Equal(t, tiTypes.KindMysqlJSON, d.Kind()) + }) +} diff --git a/cmd/multi-cluster-consistency-checker/integration/integration_test.go b/cmd/multi-cluster-consistency-checker/integration/integration_test.go new file mode 100644 index 0000000000..236ebb2568 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/integration/integration_test.go @@ -0,0 +1,732 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "context" + "fmt" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/advancer" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/checker" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/stretchr/testify/require" +) + +// schemaKey is the schema key for data stored via S3 path "test/t1/1/...". +// It equals QuoteSchema("test", "t1") = "`test`.`t1`". 
+var schemaKey = (&cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, +}).GetKey() + +// testEnv holds the initialized test environment. +type testEnv struct { + ctx context.Context + mc *MockMultiCluster + advancer *advancer.TimeWindowAdvancer + checker *checker.DataChecker +} + +// setupEnv creates a test environment with 2 clusters (c1, c2), both +// replicating to each other, with in-memory S3 storage and mock PD/watchers. +func setupEnv(t *testing.T) *testEnv { + t.Helper() + ctx := context.Background() + tables := map[string][]string{"test": {"t1"}} + + mc := NewMockMultiCluster( + []string{"c1", "c2"}, + tables, + 0, // pdBase: start physical time at 0ms + 100, // pdStep: 100ms per PD GetTS call + 100, // cpDelta: checkpoint = minCheckpointTs + 100 + 50, // s3Delta: s3 checkpoint = minCheckpointTs + 50 + ) + + require.NoError(t, mc.InitSchemaFiles(ctx)) + + twa, _, err := advancer.NewTimeWindowAdvancer( + ctx, mc.CPWatchers, mc.S3Watchers, mc.GetPDClients(), nil, + ) + require.NoError(t, err) + + clusterCfg := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + dc := checker.NewDataChecker(ctx, clusterCfg, nil, nil) + + return &testEnv{ctx: ctx, mc: mc, advancer: twa, checker: dc} +} + +// roundResult holds the output of a single round. +type roundResult struct { + report *recorder.Report + twData map[string]types.TimeWindowData +} + +// executeRound writes data to clusters' S3 storage, advances the time window, +// and runs the checker for one round. 
+func (e *testEnv) executeRound(t *testing.T, c1Content, c2Content []byte) roundResult { + t.Helper() + if c1Content != nil { + require.NoError(t, e.mc.WriteDMLFile(e.ctx, "c1", c1Content)) + } + if c2Content != nil { + require.NoError(t, e.mc.WriteDMLFile(e.ctx, "c2", c2Content)) + } + + twData, err := e.advancer.AdvanceTimeWindow(e.ctx) + require.NoError(t, err) + + report, err := e.checker.CheckInNextTimeWindow(twData) + require.NoError(t, err) + + return roundResult{report: report, twData: twData} +} + +// maxRightBoundary returns the maximum RightBoundary across all clusters. +func maxRightBoundary(twData map[string]types.TimeWindowData) uint64 { + maxRB := uint64(0) + for _, tw := range twData { + if tw.TimeWindow.RightBoundary > maxRB { + maxRB = tw.TimeWindow.RightBoundary + } + } + return maxRB +} + +// The test architecture simulates a 2-cluster active-active setup: +// +// c1 (locally-written records) ──CDC──> c2 (replicated records) +// c2 (locally-written records) ──CDC──> c1 (replicated records) +// +// Each cluster writes locally-written records (originTs=0) and receives replicated +// records from the other cluster (originTs>0). +// +// The checker needs 3 warm-up rounds before it starts checking (checkableRound >= 3). +// Data written in round 0 is tracked by the S3 consumer but not downloaded +// (skipDownloadData=true for the first round). From round 1 onwards, only +// NEW files (with higher indices) are downloaded. +// +// Data commitTs is set to prevMaxRightBoundary+1 to ensure records fall +// within the current time window (leftBoundary, rightBoundary]. +// +// TestIntegration_AllConsistent verifies that no errors are reported +// when all locally-written records have matching replicated records. 
+func TestIntegration_AllConsistent(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + // c1: locally-written records write (originTs=0) + c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + // c2: replicated records replicated from c1 (originTs = c1's commitTs) + c2 := MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: c1 TW=[%d, %d], c2 TW=[%d, %d], commitTs=%d", + round, + result.twData["c1"].TimeWindow.LeftBoundary, result.twData["c1"].TimeWindow.RightBoundary, + result.twData["c2"].TimeWindow.LeftBoundary, result.twData["c2"].TimeWindow.RightBoundary, + cts) + + if round >= 3 { + require.Len(t, result.report.ClusterReports, 2, "round %d", round) + require.False(t, result.report.NeedFlush(), + "round %d: all data should be consistent, no report needed", round) + for clusterID, cr := range result.report.ClusterReports { + require.Empty(t, cr.TableFailureItems, + "round %d, cluster %s: should have no failures", round, clusterID) + } + } + } +} + +// TestIntegration_AllConsistent_CrossRoundReplicatedRecords verifies that the checker +// treats data as consistent when a locally-written record's commitTs exceeds the +// round's checkpointTs, and the matching replicated records only appears in the next +// round. +// +// This occurs when locally-written records commitTs happen late in the time window, after +// the checkpoint has already been determined. For TW[2], records with +// commitTs > checkpointTs are deferred (skipped). In the next round they +// become TW[1], where the check condition is commitTs > checkpointTs (checked), +// and the replicated records are searched in TW[1] + TW[2] — finding the match in +// the current round's TW[2]. 
+func TestIntegration_AllConsistent_CrossRoundReplicatedRecords(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + + // Offset to place commitTs between checkpointTs and rightBoundary. + // With pdStep=100 and 2 clusters, each round's time window spans + // approximately ComposeTS(300, 0) = 78643200, and checkpointTs sits + // at roughly ComposeTS(200, 0) from leftBoundary. + // Using ComposeTS(250, 0) = 65536000 lands safely between them. + crossRoundOffset := uint64(250 << 18) // ComposeTS(250, 0) = 65536000 + + var lateLocallyWrittenRecordsCommitTs uint64 + + for round := 0; round < 7; round++ { + cts := prevMaxRB + 1 + + var c1, c2 []byte + + switch round { + case 4: + // Round N: c1 local has two records: + // pk=round+1 normal commitTs (checked in this round's TW[2]) + // pk=100 large commitTs > checkpointTs + // (deferred in TW[2], checked via TW[1] next round) + // c2 replicated only matches pk=round+1. + lateLocallyWrittenRecordsCommitTs = prevMaxRB + crossRoundOffset + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, lateLocallyWrittenRecordsCommitTs, 0, "late"), + ) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + + case 5: + // Round N+1: c2 now includes the replicated record for pk=100. + // The checker evaluates TW[1] (= round 4), finds pk=100 with + // commitTs > checkpointTs, and searches c2's TW[1] + TW[2]. + // pk=100's matching replicated record is in c2's TW[2] (this round). 
+ c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, cts+2, lateLocallyWrittenRecordsCommitTs, "late"), + ) + + default: + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: c1 TW=[%d, %d], cpTs=%v, commitTs=%d", + round, + result.twData["c1"].TimeWindow.LeftBoundary, + result.twData["c1"].TimeWindow.RightBoundary, + result.twData["c1"].TimeWindow.CheckpointTs, + cts) + + if round == 4 { + // Verify the late commitTs falls between checkpointTs and rightBoundary. + c1TW := result.twData["c1"].TimeWindow + cpTs := c1TW.CheckpointTs["c2"] + require.Greater(t, lateLocallyWrittenRecordsCommitTs, cpTs, + "lateLocallyWrittenRecordsCommitTs must be > checkpointTs for cross-round detection") + require.LessOrEqual(t, lateLocallyWrittenRecordsCommitTs, c1TW.RightBoundary, + "lateLocallyWrittenRecordsCommitTs must be <= rightBoundary to stay in this time window") + t.Logf("Round 4 verification: lateCommitTs=%d, checkpointTs=%d, rightBoundary=%d", + lateLocallyWrittenRecordsCommitTs, cpTs, c1TW.RightBoundary) + } + + if round >= 3 { + require.Len(t, result.report.ClusterReports, 2, "round %d", round) + require.False(t, result.report.NeedFlush(), + "round %d: data should be consistent (cross-round matching should work)", round) + for clusterID, cr := range result.report.ClusterReports { + require.Empty(t, cr.TableFailureItems, + "round %d, cluster %s: should have no failures", round, clusterID) + } + } + } +} + +// TestIntegration_AllConsistent_LWWSkippedReplicatedRecords verifies that no errors +// are reported when a replicated record is "LWW-skipped" during data-loss +// detection, combined with cross-time-window matching. 
+// +// pk=100: single-cluster overwrite (c1 writes old+new, c2 only has newer replicated records) +// +// Round N: c1 locally-written records pk=100 × 2 (commitTs=A, B; both > checkpointTs) +// c2 has NO replicated records for pk=100 +// Round N+1: c2 replicated records pk=100 (originTs=B, matches newer locally-written records only) +// → old locally-written records LWW-skipped (c2 replicated records compareTs=B >= A) +// +// pk=200: bidirectional write (c1 and c2 both write the same pk) +// +// Round N: c1 locally-written records pk=200 (commitTs=A, deferred) +// Round N+1: c1 locally-written records pk=200 (commitTs=E, newer), c1 replicated records pk=200 (originTs=D, from c2) +// c2 replicated records pk=200 (commitTs=D, D < E), c2 replicated records pk=200 (originTs=E, from c1) +// +// Key constraint: c1 local commitTs (E) > c2 local commitTs (D). +// This ensures that on c2, the replicated (compareTs=E) > local (compareTs=D), +// so the LWW violation checker sees monotonically increasing compareTs. +// +// c1 data loss for old pk=200 (commitTs=A): +// → c2 replicated has originTs=E, compareTs=E >= A → LWW-skipped ✓ +// c1 data loss for new pk=200 (commitTs=E): +// → c2 replicated has originTs=E → exact match ✓ +// c2 data loss for c2 local pk=200 (commitTs=D): +// → c1 replicated has originTs=D → exact match ✓ +func TestIntegration_AllConsistent_LWWSkippedReplicatedRecords(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + + // Place both commitTs values between checkpointTs and rightBoundary. + // With pdStep=100 and 2 clusters: + // window width ≈ ComposeTS(300, 0), checkpointTs ≈ leftBoundary + ComposeTS(200, 0) + // Using ComposeTS(250, 0) puts us safely in the gap. 
+ crossRoundOffset := uint64(250 << 18) // ComposeTS(250, 0) = 65536000 + + var oldCommitTs, newCommitTs uint64 + + for round := 0; round < 7; round++ { + cts := prevMaxRB + 1 + + var c1, c2 []byte + + switch round { + case 4: + // Round N: c1 local writes pk=100 twice + pk=200 once, all > checkpointTs. + // c2 has NO replicated record for pk=100 or pk=200; they arrive next round. + oldCommitTs = prevMaxRB + crossRoundOffset + newCommitTs = oldCommitTs + 5 + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, oldCommitTs, 0, "old_write"), + MakeCanalJSON(100, newCommitTs, 0, "new_write"), + MakeCanalJSON(200, oldCommitTs, 0, "old_write"), + ) + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + ) + + case 5: + // Round N+1: replicated data arrives for both pk=100 and pk=200. + // + // pk=200 bidirectional: c1 local at cts+5 (> c2 local at cts+2) + // ensures c2's LWW check sees increasing compareTs. + // c1: replicated(commitTs=cts+4, originTs=cts+2) then local(commitTs=cts+5) + // → compareTs order: cts+2 < cts+5 ✓ + // c2: local(commitTs=cts+2) then replicated(commitTs=cts+6, originTs=cts+5) + // → compareTs order: cts+2 < cts+5 ✓ + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(200, cts+4, cts+2, "pk200_c2"), // c1 replicated pk=200 from c2 + MakeCanalJSON(200, cts+5, 0, "pk200_c1"), // c1 local pk=200 (newer) + ) + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + MakeCanalJSON(100, cts+2, newCommitTs, "new_write"), + MakeCanalJSON(200, cts+2, 0, "pk200_c2"), // c2 local pk=200 + MakeCanalJSON(200, cts+6, cts+5, "pk200_c1"), // c2 replicated pk=200 from c1 + ) + + default: + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = 
maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round == 4 { + // Verify both commitTs fall between checkpointTs and rightBoundary. + c1TW := result.twData["c1"].TimeWindow + cpTs := c1TW.CheckpointTs["c2"] + require.Greater(t, oldCommitTs, cpTs, + "oldCommitTs must be > checkpointTs for cross-round deferral") + require.LessOrEqual(t, newCommitTs, c1TW.RightBoundary, + "newCommitTs must be <= rightBoundary to stay in this time window") + t.Logf("Round 4 verification: oldCommitTs=%d, newCommitTs=%d, checkpointTs=%d, rightBoundary=%d", + oldCommitTs, newCommitTs, cpTs, c1TW.RightBoundary) + } + + if round >= 3 { + require.Len(t, result.report.ClusterReports, 2, "round %d", round) + require.False(t, result.report.NeedFlush(), + "round %d: cross-round LWW-skipped replicated should not cause errors", round) + for clusterID, cr := range result.report.ClusterReports { + require.Empty(t, cr.TableFailureItems, + "round %d, cluster %s: should have no failures", round, clusterID) + } + } + } +} + +// TestIntegration_DataLoss verifies that the checker detects data loss +// when a locally-written record has no matching replicated record in the other cluster. 
+func TestIntegration_DataLoss(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + dataLossDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + // c1 always produces local data + c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + + var c2 []byte + if round == 4 { + // Round 4: c2 has NO matching replicated record → data loss expected + // (round 4's data is checked in the same round since checkableRound >= 3) + c2 = nil + } else { + // Normal: c2 has matching replicated record + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c1Report := result.report.ClusterReports["c1"] + if c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + if len(items.DataLossItems) > 0 { + t.Logf("Round %d: detected data loss: %+v", round, items.DataLossItems) + dataLossDetected = true + // Verify the data loss item + for _, item := range items.DataLossItems { + require.Equal(t, "c2", item.PeerClusterID) + } + } + } + } + } + } + + require.True(t, dataLossDetected, "data loss should have been detected") +} + +// TestIntegration_DataInconsistent verifies that the checker detects data +// inconsistency when a replicated record has different column values +// from the locally-written record. 
+func TestIntegration_DataInconsistent(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + inconsistentDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + + var c2 []byte + if round == 4 { + // Round 4: c2 has replicated record with WRONG column value + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, "WRONG_VALUE")) + } else { + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c1Report := result.report.ClusterReports["c1"] + if c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + for _, item := range items.DataInconsistentItems { + t.Logf("Round %d: detected data inconsistency: %+v", round, item) + inconsistentDetected = true + require.Equal(t, "c2", item.PeerClusterID) + } + } + } + } + } + + require.True(t, inconsistentDetected, "data inconsistency should have been detected") +} + +// TestIntegration_DataRedundant verifies that the checker detects redundant +// replicated data that has no matching locally-written record. +func TestIntegration_DataRedundant(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + redundantDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + c1 := MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 := MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + + if round == 4 { + // Round 4: c2 has an EXTRA replicated record (pk=999) with a fake + // originTs that doesn't match any c1 local commitTs. 
+ fakeOriginTs := cts - 5 // Doesn't match any c1 local commitTs + c2 = MakeContent( + MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)), + MakeCanalJSON(999, cts+2, fakeOriginTs, "extra"), + ) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c2Report := result.report.ClusterReports["c2"] + if c2Report != nil { + if items, ok := c2Report.TableFailureItems[schemaKey]; ok { + if len(items.DataRedundantItems) > 0 { + t.Logf("Round %d: detected data redundant: %+v", round, items.DataRedundantItems) + redundantDetected = true + } + } + } + } + } + + require.True(t, redundantDetected, "data redundancy should have been detected") +} + +// TestIntegration_LWWViolation verifies that the checker detects Last Write Wins +// violations when records for the same primary key have non-monotonic origin timestamps. +func TestIntegration_LWWViolation(t *testing.T) { + t.Parallel() + env := setupEnv(t) + defer env.mc.Close() + + prevMaxRB := uint64(0) + lwwViolationDetected := false + + for round := 0; round < 6; round++ { + cts := prevMaxRB + 1 + + var c1, c2 []byte + + if round == 4 { + // Round 4: inject LWW violation in c1. + // Record A: pk=5, commitTs=cts, originTs=0 → compareTs = cts + // Record B: pk=5, commitTs=cts+2, originTs=cts-10 → compareTs = cts-10 + // Since A's compareTs (cts) >= B's compareTs (cts-10) and A's commitTs < B's commitTs, + // this is a Last Write Wins violation. 
+ c1 = MakeContent( + MakeCanalJSON(5, cts, 0, "original"), + MakeCanalJSON(5, cts+2, cts-10, "replicated"), + ) + // c2: provide matching replicated record to avoid data loss noise + c2 = MakeContent( + MakeCanalJSON(5, cts+1, cts, "original"), + ) + } else { + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c1Report := result.report.ClusterReports["c1"] + if c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + if len(items.LWWViolationItems) > 0 { + t.Logf("Round %d: detected LWW violation: %+v", round, items.LWWViolationItems) + lwwViolationDetected = true + } + } + } + } + } + + require.True(t, lwwViolationDetected, "LWW violation should have been detected") +} + +// TestIntegration_LWWViolation_AcrossRounds verifies that the checker detects +// LWW violations when conflicting records for the same pk appear in rounds N +// and N+2, with no data for that pk in round N+1. +// +// The clusterViolationChecker keeps cache entries for up to 3 rounds +// (previous: 0 → 1 → 2). Since Check runs before UpdateCache, an entry +// created in round N (previous=0) is still available at previous=2 when +// round N+2 runs. +// +// Timeline: +// +// Round N: c1 local pk=50 (originTs=0, compareTs=A) → cached +// Round N+1: no pk=50 data → cache ages (prev 1→2) +// Round N+2: c1 replicated pk=50 (originTs=B < A) → cached old.compareTs=A >= new.compareTs + // → LWW violation across 2-round gap.
+ violatingOriginTs := firstRecordCommitTs - 10 + c1 = MakeContent( + MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)), + MakeCanalJSON(50, cts+2, violatingOriginTs, "second"), + ) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + + default: + c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round))) + c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round))) + } + + result := env.executeRound(t, c1, c2) + prevMaxRB = maxRightBoundary(result.twData) + + t.Logf("Round %d: NeedFlush=%v, commitTs=%d", round, result.report.NeedFlush(), cts) + + if round >= 3 && result.report.NeedFlush() { + c1Report := result.report.ClusterReports["c1"] + if c1Report != nil { + if items, ok := c1Report.TableFailureItems[schemaKey]; ok { + if len(items.LWWViolationItems) > 0 { + t.Logf("Round %d: LWW violation across rounds: %+v", + round, items.LWWViolationItems) + lwwViolationDetected = true + // Verify the violation details + item := items.LWWViolationItems[0] + require.Equal(t, uint64(0), item.ExistingOriginTS, + "existing record should be local (originTs=0)") + require.Equal(t, firstRecordCommitTs, item.ExistingCommitTS, + "existing record should be from round N") + } + } + } + } + } + + require.True(t, lwwViolationDetected, + "LWW violation across round N and N+2 should have been detected") +} + +// TestIntegration_MultipleErrorTypes verifies that the checker can detect +// multiple error types simultaneously across different clusters and rounds. 
+// TestIntegration_MultipleErrorTypes verifies that the checker can detect
+// multiple error types simultaneously across different clusters and rounds.
+func TestIntegration_MultipleErrorTypes(t *testing.T) {
+	t.Parallel()
+	env := setupEnv(t)
+	defer env.mc.Close()
+
+	// prevMaxRB carries the previous round's maximum right boundary so that
+	// each round's commitTs (cts) stays strictly increasing across rounds.
+	prevMaxRB := uint64(0)
+	dataLossDetected := false
+	redundantDetected := false
+
+	for round := 0; round < 7; round++ {
+		cts := prevMaxRB + 1
+
+		var c1, c2 []byte
+
+		switch round {
+		case 4:
+			// Data loss: c1 local pk=5, c2 has NO replicated record
+			c1 = MakeContent(MakeCanalJSON(5, cts, 0, "lost"))
+			c2 = nil
+		case 5:
+			// Data redundant: c2 has extra replicated pk=888
+			c1 = MakeContent(MakeCanalJSON(6, cts, 0, "normal"))
+			fakeOriginTs := cts - 3
+			c2 = MakeContent(
+				MakeCanalJSON(6, cts+1, cts, "normal"),
+				MakeCanalJSON(888, cts+2, fakeOriginTs, "ghost"),
+			)
+		default:
+			// Healthy round: c1 writes locally, c2 holds the replica.
+			c1 = MakeContent(MakeCanalJSON(round+1, cts, 0, fmt.Sprintf("v%d", round)))
+			c2 = MakeContent(MakeCanalJSON(round+1, cts+1, cts, fmt.Sprintf("v%d", round)))
+		}
+
+		result := env.executeRound(t, c1, c2)
+		prevMaxRB = maxRightBoundary(result.twData)
+
+		t.Logf("Round %d: NeedFlush=%v, commitTs=%d, ClusterReports=%d",
+			round, result.report.NeedFlush(), cts, len(result.report.ClusterReports))
+
+		if round >= 3 && result.report.NeedFlush() {
+			// Check c1 for data loss
+			if c1Report := result.report.ClusterReports["c1"]; c1Report != nil {
+				if items, ok := c1Report.TableFailureItems[schemaKey]; ok {
+					if len(items.DataLossItems) > 0 {
+						dataLossDetected = true
+						t.Logf("Round %d: data loss detected in c1: %d items",
+							round, len(items.DataLossItems))
+					}
+				}
+			}
+			// Check c2 for data redundant
+			if c2Report := result.report.ClusterReports["c2"]; c2Report != nil {
+				if items, ok := c2Report.TableFailureItems[schemaKey]; ok {
+					if len(items.DataRedundantItems) > 0 {
+						redundantDetected = true
+						t.Logf("Round %d: data redundant detected in c2: %d items",
+							round, len(items.DataRedundantItems))
+					}
+				}
+			}
+		}
+	}
+
+	require.True(t, dataLossDetected, "data loss should have been detected")
+	require.True(t, redundantDetected, "data redundancy should have been detected")
+}
diff --git a/cmd/multi-cluster-consistency-checker/integration/mock_cluster.go b/cmd/multi-cluster-consistency-checker/integration/mock_cluster.go
new file mode 100644
index 0000000000..96216e09b7
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/integration/mock_cluster.go
@@ -0,0 +1,205 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"sync/atomic"
+
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher"
+	"github.com/pingcap/tidb/br/pkg/storage"
+	pd "github.com/tikv/pd/client"
+)
+
+// mockPDClient simulates a PD server's TSO service.
+// Each GetTS call returns a monotonically increasing physical timestamp.
+// The embedded pd.Client makes the struct satisfy the full interface while
+// only GetTS and Close are actually overridden; calling any other method
+// panics on the nil embedded client.
+type mockPDClient struct {
+	pd.Client
+	seq  atomic.Int64
+	base int64 // base physical time in milliseconds
+	step int64 // increment per call in milliseconds
+}
+
+// GetTS returns a (physical, logical) TSO pair. The physical part advances
+// by step milliseconds on every call; the logical part is always 0.
+func (m *mockPDClient) GetTS(_ context.Context) (int64, int64, error) {
+	n := m.seq.Add(1)
+	return m.base + n*m.step, 0, nil
+}
+
+// Close is a no-op; the mock holds no resources.
+func (m *mockPDClient) Close() {}
+
+// ---------- Mock Checkpoint Watcher ----------
+
+// mockWatcher simulates a checkpoint watcher.
+// It always returns minCheckpointTs + delta, ensuring the result exceeds the minimum.
+type mockWatcher struct {
+	delta uint64
+}
+
+// AdvanceCheckpointTs reports a checkpoint that is always delta ahead of the
+// requested minimum, so the mock never blocks time-window advancement.
+func (m *mockWatcher) AdvanceCheckpointTs(_ context.Context, minCheckpointTs uint64) (uint64, error) {
+	return minCheckpointTs + m.delta, nil
+}
+
+// Close is a no-op; the mock holds no resources.
+func (m *mockWatcher) Close() {}
+
+// MockMultiCluster manages the mock infrastructure for simulating multiple
+// TiCDC clusters. It provides:
+//   - Mock PD clients for TSO generation
+//   - Mock checkpoint watchers for inter-cluster replication checkpoints
+//   - Mock S3 checkpoint watchers and S3 watchers for cloud storage
+//   - In-memory S3 storage for each cluster
+//   - Helpers to write canal-JSON formatted data files
+type MockMultiCluster struct {
+	ClusterIDs []string
+	Tables     map[string][]string // schema -> table names
+
+	S3Storages map[string]storage.ExternalStorage
+	pdClients  map[string]*mockPDClient
+	CPWatchers map[string]map[string]watcher.Watcher
+	S3Watchers map[string]*watcher.S3Watcher
+
+	// fileCounters tracks the next DML file index per cluster.
+	// Files are written with monotonically increasing indices so the
+	// S3Consumer discovers only new files in each round.
+	fileCounters map[string]uint64
+
+	date string // fixed date used in all DML file paths
+}
+
+// NewMockMultiCluster creates a new mock multi-cluster environment.
+//
+// Parameters:
+//   - clusterIDs: identifiers for the clusters (e.g. ["c1", "c2"])
+//   - tables: schema -> table names mapping (e.g. 
{"test": ["t1"]})
+//   - pdBase: base physical time (ms) for mock PD TSO generation
+//   - pdStep: physical time increment (ms) per GetTS call
+//   - cpDelta: checkpoint watcher returns minCheckpointTs + cpDelta
+//   - s3Delta: S3 checkpoint watcher returns minCheckpointTs + s3Delta
+func NewMockMultiCluster(
+	clusterIDs []string,
+	tables map[string][]string,
+	pdBase, pdStep int64,
+	cpDelta, s3Delta uint64,
+) *MockMultiCluster {
+	mc := &MockMultiCluster{
+		ClusterIDs:   clusterIDs,
+		Tables:       tables,
+		S3Storages:   make(map[string]storage.ExternalStorage),
+		pdClients:    make(map[string]*mockPDClient),
+		CPWatchers:   make(map[string]map[string]watcher.Watcher),
+		S3Watchers:   make(map[string]*watcher.S3Watcher),
+		fileCounters: make(map[string]uint64),
+		date:         "2026-02-11",
+	}
+
+	for _, id := range clusterIDs {
+		mc.S3Storages[id] = storage.NewMemStorage()
+		// All clusters share the same TSO parameters but keep independent
+		// sequence counters.
+		mc.pdClients[id] = &mockPDClient{base: pdBase, step: pdStep}
+
+		// Checkpoint watchers: one per replicated cluster
+		watchers := make(map[string]watcher.Watcher)
+		for _, other := range clusterIDs {
+			if other != id {
+				watchers[other] = &mockWatcher{delta: cpDelta}
+			}
+		}
+		mc.CPWatchers[id] = watchers
+
+		// S3 watcher: uses in-memory storage + mock checkpoint watcher
+		s3CpWatcher := &mockWatcher{delta: s3Delta}
+		mc.S3Watchers[id] = watcher.NewS3Watcher(
+			s3CpWatcher,
+			mc.S3Storages[id],
+			tables,
+		)
+	}
+
+	return mc
+}
+
+// InitSchemaFiles writes initial schema files for all tables in all clusters.
+// The schema file content is a placeholder ("{}") because the parser is nil
+// in the current implementation.
+func (mc *MockMultiCluster) InitSchemaFiles(ctx context.Context) error {
+	for _, s3 := range mc.S3Storages {
+		for schema, tableList := range mc.Tables {
+			for _, table := range tableList {
+				path := fmt.Sprintf("%s/%s/meta/schema_1_0000000000.json", schema, table)
+				if err := s3.WriteFile(ctx, path, []byte("{}")); err != nil {
+					return err
+				}
+			}
+		}
+	}
+	return nil
+}
+
+// WriteDMLFile writes a canal-JSON DML file to a cluster's S3 storage.
+// Each call increments the file index for that cluster, ensuring the
+// S3Consumer discovers it as a new file.
+// The same content is written for every configured table.
+func (mc *MockMultiCluster) WriteDMLFile(ctx context.Context, clusterID string, content []byte) error {
+	mc.fileCounters[clusterID]++
+	idx := mc.fileCounters[clusterID]
+	for schema, tableList := range mc.Tables {
+		for _, table := range tableList {
+			// Path layout: <schema>/<table>/<tableVersion>/<date>/CDC<index>.json
+			path := fmt.Sprintf("%s/%s/1/%s/CDC%020d.json", schema, table, mc.date, idx)
+			if err := mc.S3Storages[clusterID].WriteFile(ctx, path, content); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// GetPDClients returns mock PD clients as the pd.Client interface.
+func (mc *MockMultiCluster) GetPDClients() map[string]pd.Client {
+	clients := make(map[string]pd.Client)
+	for id, c := range mc.pdClients {
+		clients[id] = c
+	}
+	return clients
+}
+
+// Close closes all S3 watchers.
+func (mc *MockMultiCluster) Close() {
+	for _, sw := range mc.S3Watchers {
+		sw.Close()
+	}
+}
+
+// MakeCanalJSON builds a canal-JSON formatted record for testing.
+//
+// Parameters:
+//   - pkID: primary key value (int column "id")
+//   - commitTs: TiDB commit timestamp
+//   - originTs: origin timestamp (0 for locally-written records, non-zero for replicated records)
+//   - val: value for the "val" varchar column
+func MakeCanalJSON(pkID int, commitTs uint64, originTs uint64, val string) string {
+	// A zero originTs is encoded as JSON null (a locally-written row);
+	// a non-zero originTs is encoded as a quoted decimal string, matching
+	// how canal-json serializes the bigint _tidb_origin_ts column.
+	originTsVal := "null"
+	if originTs > 0 {
+		originTsVal = fmt.Sprintf(`"%d"`, originTs)
+	}
+	return fmt.Sprintf(
+		`{"id":0,"database":"test","table":"t1","pkNames":["id"],"isDdl":false,"type":"INSERT",`+
+			`"es":0,"ts":0,"sql":"","sqlType":{"id":4,"val":12,"_tidb_origin_ts":-5},`+
+			`"mysqlType":{"id":"int","val":"varchar","_tidb_origin_ts":"bigint"},`+
+			`"old":null,"data":[{"id":"%d","val":"%s","_tidb_origin_ts":%s}],`+
+			`"_tidb":{"commitTs":%d}}`,
+		pkID, val, originTsVal, commitTs)
+}
+
+// MakeContent combines canal-JSON records with CRLF terminator (matching codec config).
+func MakeContent(records ...string) []byte {
+	return []byte(strings.Join(records, "\r\n"))
+}
diff --git a/cmd/multi-cluster-consistency-checker/main.go b/cmd/multi-cluster-consistency-checker/main.go
new file mode 100644
index 0000000000..7f2390f99e
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/main.go
@@ -0,0 +1,160 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"os/signal"
+	"syscall"
+
+	"github.com/pingcap/log"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config"
+	"github.com/pingcap/ticdc/pkg/errors"
+	"github.com/pingcap/ticdc/pkg/logger"
+	"github.com/spf13/cobra"
+	"go.uber.org/zap"
+)
+
+// Command-line flag values, bound by cobra in main().
+var (
+	cfgPath string
+	dryRun  bool
+)
+
+// Exit codes for multi-cluster-consistency-checker.
+//
+// 0 – clean shutdown (normal exit or graceful signal handling)
+// 1 – transient error, safe to restart (network, I/O, temporary failures)
+// 2 – invalid configuration (missing required flags / fields)
+// 3 – configuration decode failure (malformed config file)
+// 4 – checkpoint corruption, requires manual intervention
+// 5 – unrecoverable internal error
+const (
+	ExitCodeTransient            = 1
+	ExitCodeInvalidConfig        = 2
+	ExitCodeDecodeConfigFailed   = 3
+	ExitCodeCheckpointCorruption = 4
+	ExitCodeUnrecoverable        = 5
+)
+
+// ExitError wraps an error with a process exit code so that callers higher in
+// the stack can translate domain errors into the correct exit status.
+type ExitError struct {
+	Code int
+	Err  error
+}
+
+// Error returns the wrapped error's message; the code is not included.
+func (e *ExitError) Error() string { return e.Err.Error() }
+
+// Unwrap exposes the wrapped error to errors.Is / errors.As chains.
+func (e *ExitError) Unwrap() error { return e.Err }
+
+// exitCodeFromError extracts the exit code from an error.
+// If the error is an *ExitError the embedded code is returned;
+// otherwise the fallback code is returned.
+func exitCodeFromError(err error, fallback int) int {
+	var ee *ExitError
+	// errors.As also finds an *ExitError buried inside wrapped chains.
+	if errors.As(err, &ee) {
+		return ee.Code
+	}
+	return fallback
+}
+
+// Flag names shared between registration and lookup.
+const (
+	FlagConfig = "config"
+	FlagDryRun = "dry-run"
+)
+
+// main wires the cobra root command and delegates all work to run.
+func main() {
+	rootCmd := &cobra.Command{
+		Use:   "multi-cluster-consistency-checker",
+		Short: "A tool to check consistency across multiple TiCDC clusters",
+		Long:  "A tool to check consistency across multiple TiCDC clusters by comparing data from different clusters' S3 sink locations",
+		Run:   run,
+	}
+
+	rootCmd.Flags().StringVarP(&cfgPath, FlagConfig, "c", "", "configuration file path (required)")
+	rootCmd.MarkFlagRequired(FlagConfig)
+	rootCmd.Flags().BoolVar(&dryRun, FlagDryRun, false, "validate config and connectivity without running the checker")
+
+	if err := rootCmd.Execute(); err != nil {
+		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+		os.Exit(ExitCodeUnrecoverable)
+	}
+}
+
+// run loads the configuration, initializes logging, and drives runTask with
+// graceful SIGINT/SIGTERM shutdown. It terminates the process via os.Exit
+// with one of the ExitCode* values on failure.
+func run(cmd *cobra.Command, args []string) {
+	// Defensive re-check: cobra already marks --config as required, but an
+	// empty value would otherwise surface later as a confusing load error.
+	if cfgPath == "" {
+		fmt.Fprintln(os.Stderr, "error: --config flag is required")
+		os.Exit(ExitCodeInvalidConfig)
+	}
+
+	cfg, err := config.LoadConfig(cfgPath)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err)
+		os.Exit(ExitCodeDecodeConfigFailed)
+	}
+
+	// Initialize logger with configured log level
+	logLevel := cfg.GlobalConfig.LogLevel
+	if logLevel == "" {
+		logLevel = "info" // default log level
+	}
+	loggerConfig := &logger.Config{
+		Level: logLevel,
+	}
+	err = logger.InitLogger(loggerConfig)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err)
+		os.Exit(ExitCodeUnrecoverable)
+	}
+	log.Info("Logger initialized", zap.String("level", logLevel))
+
+	// Create a context that can be cancelled by signals
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Set up signal handling for graceful shutdown
+	sigChan := make(chan os.Signal, 1)
+	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
+
+	// Start the task in a goroutine.
+	// errChan is buffered (size 1) so the goroutine can always deliver its
+	// result and exit even if nobody is receiving yet.
+	errChan := make(chan error, 1)
+	go func() {
+		err := runTask(ctx, cfg, dryRun)
+		if err != nil {
+			log.Error("task error", zap.Error(err))
+		}
+		errChan <- err
+	}()
+
+	// Wait for either a signal or task completion
+	select {
+	case sig := <-sigChan:
+		fmt.Fprintf(os.Stdout, "\nReceived signal: %v, shutting down gracefully...\n", sig)
+		cancel()
+		// Wait for the task to finish; a context.Canceled result is the
+		// expected outcome of the cancellation above, not a failure.
+		if err := <-errChan; err != nil && !errors.Is(err, context.Canceled) {
+			fmt.Fprintf(os.Stderr, "task error during shutdown: %v\n", err)
+			code := exitCodeFromError(err, ExitCodeTransient)
+			os.Exit(code)
+		}
+		fmt.Fprintf(os.Stdout, "Shutdown complete\n")
+	case err := <-errChan:
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "failed to run task: %v\n", err)
+			code := exitCodeFromError(err, ExitCodeTransient)
+			os.Exit(code)
+		}
+	}
+}
diff --git a/cmd/multi-cluster-consistency-checker/main_test.go b/cmd/multi-cluster-consistency-checker/main_test.go
new file mode 100644
index 0000000000..ee3fc71f23
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/main_test.go
@@ -0,0 +1,192 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/pingcap/ticdc/pkg/errors"
+	"github.com/stretchr/testify/require"
+)
+
+// TestExitError_Error verifies that ExitError.Error forwards the wrapped
+// error's message unchanged.
+func TestExitError_Error(t *testing.T) {
+	t.Parallel()
+	inner := fmt.Errorf("something went wrong")
+	ee := &ExitError{Code: ExitCodeTransient, Err: inner}
+	require.Equal(t, "something went wrong", ee.Error())
+}
+
+// TestExitError_Unwrap verifies that Unwrap exposes the wrapped error to
+// errors.Is.
+func TestExitError_Unwrap(t *testing.T) {
+	t.Parallel()
+	inner := fmt.Errorf("root cause")
+	ee := &ExitError{Code: ExitCodeCheckpointCorruption, Err: inner}
+	require.ErrorIs(t, ee, inner)
+	require.Equal(t, inner, ee.Unwrap())
+}
+
+// TestExitError_Unwrap_deep verifies that errors.Is traverses through an
+// ExitError into multi-level wrapped chains.
+func TestExitError_Unwrap_deep(t *testing.T) {
+	t.Parallel()
+	root := errors.New("root")
+	wrapped := fmt.Errorf("layer1: %w", root)
+	ee := &ExitError{Code: ExitCodeTransient, Err: wrapped}
+	require.ErrorIs(t, ee, root)
+}
+
+// TestValidateS3BucketPrefix covers matching and mismatching combinations of
+// changefeed sink URI vs. configured S3 URI (scheme, bucket, prefix, query
+// parameters, trailing slashes).
+func TestValidateS3BucketPrefix(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name          string
+		changefeedURI string
+		configURI     string
+		wantErr       bool
+		errContains   string
+	}{
+		{
+			name:          "exact match",
+			changefeedURI: "s3://my-bucket/cluster1/",
+			configURI:     "s3://my-bucket/cluster1/",
+			wantErr:       false,
+		},
+		{
+			name:          "match ignoring trailing slash",
+			changefeedURI: "s3://my-bucket/cluster1",
+			configURI:     "s3://my-bucket/cluster1/",
+			wantErr:       false,
+		},
+		{
+			name:          "match with query params in changefeed URI",
+			changefeedURI: "s3://my-bucket/prefix/?protocol=canal-json&date-separator=day",
+			configURI:     "s3://my-bucket/prefix/",
+			wantErr:       false,
+		},
+		{
+			name:          "bucket mismatch",
+			changefeedURI: "s3://bucket-a/prefix/",
+			configURI:     "s3://bucket-b/prefix/",
+			wantErr:       true,
+			errContains:   "bucket/prefix mismatch",
+		},
+		{
+			name:          "prefix mismatch",
+			changefeedURI: "s3://my-bucket/cluster1/",
+			configURI:     "s3://my-bucket/cluster2/",
+			wantErr:       true,
+			errContains:   "bucket/prefix mismatch",
+		},
+		{
+			name:          "scheme mismatch",
+			changefeedURI: "gcs://my-bucket/prefix/",
+			configURI:     "s3://my-bucket/prefix/",
+			wantErr:       true,
+			errContains:   "bucket/prefix mismatch",
+		},
+		{
+			name:          "deeper prefix mismatch",
+			changefeedURI: "s3://my-bucket/a/b/c/",
+			configURI:     "s3://my-bucket/a/b/d/",
+			wantErr:       true,
+			errContains:   "bucket/prefix mismatch",
+		},
+		{
+			name:          "empty config URI",
+			changefeedURI: "s3://my-bucket/prefix/",
+			configURI:     "",
+			wantErr:       true,
+			errContains:   "bucket/prefix mismatch",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			err := validateS3BucketPrefix(tt.changefeedURI, tt.configURI, "test-cluster", "cf-1")
+			if tt.wantErr {
+				require.Error(t, err)
+				require.Contains(t, err.Error(), tt.errContains)
+			} else {
+				require.NoError(t, err)
+			}
+		})
+	}
+}
+
+// TestExitCodeFromError covers extraction of codes from plain and wrapped
+// ExitError values, and the fallback path for non-ExitError errors.
+func TestExitCodeFromError(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		err      error
+		fallback int
+		expected int
+	}{
+		{
+			name:     "ExitError with transient code",
+			err:      &ExitError{Code: ExitCodeTransient, Err: fmt.Errorf("timeout")},
+			fallback: ExitCodeUnrecoverable,
+			expected: ExitCodeTransient,
+		},
+		{
+			name:     "ExitError with checkpoint corruption code",
+			err:      &ExitError{Code: ExitCodeCheckpointCorruption, Err: fmt.Errorf("bad checkpoint")},
+			fallback: ExitCodeTransient,
+			expected: ExitCodeCheckpointCorruption,
+		},
+		{
+			name:     "ExitError with unrecoverable code",
+			err:      &ExitError{Code: ExitCodeUnrecoverable, Err: fmt.Errorf("fatal")},
+			fallback: ExitCodeTransient,
+			expected: ExitCodeUnrecoverable,
+		},
+		{
+			name:     "ExitError with invalid config code",
+			err:      &ExitError{Code: ExitCodeInvalidConfig, Err: fmt.Errorf("missing field")},
+			fallback: ExitCodeTransient,
+			expected: ExitCodeInvalidConfig,
+		},
+		{
+			name:     "ExitError with decode config failed code",
+			err:      &ExitError{Code: ExitCodeDecodeConfigFailed, Err: fmt.Errorf("bad toml")},
+			fallback: ExitCodeTransient,
+			expected: ExitCodeDecodeConfigFailed,
+		},
+		{
+			name:     "plain error returns fallback",
+			err:      fmt.Errorf("some plain error"),
+			fallback: ExitCodeTransient,
+			expected: ExitCodeTransient,
+		},
+		{
+			name:     "plain error returns different fallback",
+			err:      fmt.Errorf("another plain error"),
+			fallback: ExitCodeUnrecoverable,
+			expected: ExitCodeUnrecoverable,
+		},
+		{
+			name:     "wrapped ExitError is still extracted",
+			err:      fmt.Errorf("outer: %w", &ExitError{Code: ExitCodeCheckpointCorruption, Err: fmt.Errorf("inner")}),
+			fallback: ExitCodeTransient,
+			expected: ExitCodeCheckpointCorruption,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			code := exitCodeFromError(tt.err, tt.fallback)
+			require.Equal(t, tt.expected, code)
+		})
+	}
+}
diff --git a/cmd/multi-cluster-consistency-checker/recorder/recorder.go b/cmd/multi-cluster-consistency-checker/recorder/recorder.go
new file mode 100644
index 0000000000..3affe7b7dd
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/recorder/recorder.go
@@ -0,0 +1,266 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package recorder
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sort"
+
+	"github.com/pingcap/log"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types"
+	"github.com/pingcap/ticdc/pkg/errors"
+	"go.uber.org/zap"
+)
+
+// ErrCheckpointCorruption is a sentinel error indicating that the persisted
+// checkpoint data is corrupted and requires manual intervention to fix.
+var ErrCheckpointCorruption = errors.New("checkpoint corruption")
+
+// Recorder persists checkpoints and consistency reports under a data
+// directory, keeping at most maxReportFiles rounds of report files on disk.
+type Recorder struct {
+	reportDir      string
+	checkpointDir  string
+	maxReportFiles int
+
+	// reportFiles caches the report file names in reportDir, ordered from
+	// oldest round to newest. Updated in-memory after flush and cleanup to
+	// avoid repeated os.ReadDir calls.
+	reportFiles []string
+
+	checkpoint *Checkpoint
+}
+
+// reportRound parses the round number out of a report file name of the form
+// "report-<round>.report" / "report-<round>.json". Names that do not match
+// map to round 0 so they sort first and are cleaned up earliest.
+func reportRound(name string) uint64 {
+	var round uint64
+	if _, err := fmt.Sscanf(name, "report-%d", &round); err != nil {
+		return 0
+	}
+	return round
+}
+
+// NewRecorder creates the report/checkpoint directories under dataDir, loads
+// any existing checkpoint, and validates it against the configured clusters.
+// A validation failure is reported as ErrCheckpointCorruption.
+func NewRecorder(dataDir string, clusters map[string]config.ClusterConfig, maxReportFiles int) (*Recorder, error) {
+	if err := os.MkdirAll(filepath.Join(dataDir, "report"), 0755); err != nil {
+		return nil, errors.Trace(err)
+	}
+	if err := os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755); err != nil {
+		return nil, errors.Trace(err)
+	}
+	if maxReportFiles <= 0 {
+		maxReportFiles = config.DefaultMaxReportFiles
+	}
+
+	// Read existing report files once at startup.
+	entries, err := os.ReadDir(filepath.Join(dataDir, "report"))
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	reportFiles := make([]string, 0, len(entries))
+	for _, entry := range entries {
+		if !entry.IsDir() {
+			reportFiles = append(reportFiles, entry.Name())
+		}
+	}
+	// Sort by round number, not lexicographically: a plain string sort puts
+	// "report-10.json" before "report-9.json", which would make
+	// cleanupOldReports delete the newest reports instead of the oldest
+	// after a restart. Ties within a round fall back to name order
+	// (".json" before ".report").
+	sort.Slice(reportFiles, func(i, j int) bool {
+		ri, rj := reportRound(reportFiles[i]), reportRound(reportFiles[j])
+		if ri != rj {
+			return ri < rj
+		}
+		return reportFiles[i] < reportFiles[j]
+	})
+
+	r := &Recorder{
+		reportDir:      filepath.Join(dataDir, "report"),
+		checkpointDir:  filepath.Join(dataDir, "checkpoint"),
+		maxReportFiles: maxReportFiles,
+		reportFiles:    reportFiles,
+
+		checkpoint: NewCheckpoint(),
+	}
+	if err := r.initializeCheckpoint(); err != nil {
+		return nil, errors.Trace(err)
+	}
+	// Validate the loaded checkpoint against the configured cluster set:
+	// every non-nil checkpoint item must carry exactly the same cluster IDs.
+	for _, item := range r.checkpoint.CheckpointItems {
+		if item == nil {
+			continue
+		}
+		if len(item.ClusterInfo) != len(clusters) {
+			return nil, errors.Annotatef(ErrCheckpointCorruption, "checkpoint item (round %d) cluster info length mismatch, expected %d, got %d", item.Round, len(clusters), len(item.ClusterInfo))
+		}
+		for clusterID := range clusters {
+			if _, ok := item.ClusterInfo[clusterID]; !ok {
+				return nil, errors.Annotatef(ErrCheckpointCorruption, "checkpoint item (round %d) cluster info missing for cluster %s", item.Round, clusterID)
+			}
+		}
+	}
+
+	return r, nil
+}
+
+// GetCheckpoint returns the in-memory checkpoint state.
+func (r *Recorder) GetCheckpoint() *Checkpoint {
+	return r.checkpoint
+}
+
+// initializeCheckpoint loads checkpoint.json, falling back to
+// checkpoint.json.bak when the primary file is missing (crash between the
+// two renames in flushCheckpoint). A fresh start leaves the zero-value
+// checkpoint in place.
+func (r *Recorder) initializeCheckpoint() error {
+	checkpointFile := filepath.Join(r.checkpointDir, "checkpoint.json")
+	bakFile := filepath.Join(r.checkpointDir, "checkpoint.json.bak")
+
+	// If checkpoint.json exists, use it directly.
+	if _, err := os.Stat(checkpointFile); err == nil {
+		data, err := os.ReadFile(checkpointFile)
+		if err != nil {
+			return errors.Trace(err) // transient I/O error
+		}
+		if err := json.Unmarshal(data, r.checkpoint); err != nil {
+			return errors.Annotatef(ErrCheckpointCorruption, "failed to unmarshal checkpoint.json: %v", err)
+		}
+		return nil
+	}
+
+	// checkpoint.json is missing — try recovering from the backup.
+	// This can happen when the process crashed after rename but before the
+	// new temp file was renamed into place.
+	if _, err := os.Stat(bakFile); err == nil {
+		log.Warn("checkpoint.json not found, recovering from checkpoint.json.bak")
+		data, err := os.ReadFile(bakFile)
+		if err != nil {
+			return errors.Trace(err) // transient I/O error
+		}
+		if err := json.Unmarshal(data, r.checkpoint); err != nil {
+			return errors.Annotatef(ErrCheckpointCorruption, "failed to unmarshal checkpoint.json.bak: %v", err)
+		}
+		// Restore the backup as the primary file
+		if err := os.Rename(bakFile, checkpointFile); err != nil {
+			return errors.Trace(err) // transient I/O error
+		}
+		return nil
+	}
+
+	// Neither file exists — fresh start.
+	return nil
+}
+
+// RecordTimeWindow logs the advanced time windows, flushes the report when it
+// contains failures (NeedFlush), prunes old report files, and persists the
+// checkpoint for this round.
+func (r *Recorder) RecordTimeWindow(timeWindowData map[string]types.TimeWindowData, report *Report) error {
+	for clusterID, timeWindow := range timeWindowData {
+		log.Info("time window advanced",
+			zap.Uint64("round", report.Round),
+			zap.String("clusterID", clusterID),
+			zap.Uint64("window left boundary", timeWindow.LeftBoundary),
+			zap.Uint64("window right boundary", timeWindow.RightBoundary),
+			zap.Any("checkpoint ts", timeWindow.CheckpointTs))
+	}
+	if report.NeedFlush() {
+		if err := r.flushReport(report); err != nil {
+			return errors.Trace(err)
+		}
+		r.cleanupOldReports()
+	}
+	if err := r.flushCheckpoint(report.Round, timeWindowData); err != nil {
+		return errors.Trace(err)
+	}
+	return nil
+}
+
+// flushReport atomically writes the human-readable .report file and the
+// machine-readable .json file for the round, then records them in the cache.
+func (r *Recorder) flushReport(report *Report) error {
+	reportName := fmt.Sprintf("report-%d.report", report.Round)
+	if err := atomicWriteFile(filepath.Join(r.reportDir, reportName), []byte(report.MarshalReport())); err != nil {
+		return errors.Trace(err)
+	}
+
+	jsonName := fmt.Sprintf("report-%d.json", report.Round)
+	dataBytes, err := json.Marshal(report)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if err := atomicWriteFile(filepath.Join(r.reportDir, jsonName), dataBytes); err != nil {
+		return errors.Trace(err)
+	}
+
+	// Append the new file names to the cache; they belong to the newest
+	// round so they go at the end. json before report matches the
+	// within-round tiebreak used by the startup sort.
+	r.reportFiles = append(r.reportFiles, jsonName, reportName)
+	return nil
+}
+
+// atomicWriteFile writes data to a temporary file, fsyncs it to ensure
+// durability, and then atomically renames it to the target path.
+// This prevents partial / corrupt files on crash.
+func atomicWriteFile(targetPath string, data []byte) error {
+	tempPath := targetPath + ".tmp"
+	if err := syncWriteFile(tempPath, data); err != nil {
+		return errors.Trace(err)
+	}
+	if err := os.Rename(tempPath, targetPath); err != nil {
+		return errors.Trace(err)
+	}
+	return nil
+}
+
+// syncWriteFile writes data to a file and fsyncs it before returning,
+// guaranteeing that the content is durable on disk.
+// The file is created with mode 0600 and truncated if it already exists.
+func syncWriteFile(path string, data []byte) error {
+	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if _, err := f.Write(data); err != nil {
+		f.Close()
+		return errors.Trace(err)
+	}
+	if err := f.Sync(); err != nil {
+		f.Close()
+		return errors.Trace(err)
+	}
+	return errors.Trace(f.Close())
+}
+
+// cleanupOldReports removes the oldest report files from the in-memory cache
+// when the total number exceeds maxReportFiles * 2 (each round produces .report + .json).
+// NOTE(review): correctness depends on r.reportFiles being ordered oldest
+// round first — verify the ordering established at startup and in flushReport.
+func (r *Recorder) cleanupOldReports() {
+	if len(r.reportFiles) <= r.maxReportFiles*2 {
+		return
+	}
+
+	// Deletion failures are logged but not fatal; the entry is still dropped
+	// from the cache so the process keeps making progress.
+	toDelete := len(r.reportFiles) - r.maxReportFiles*2
+	for i := 0; i < toDelete; i++ {
+		path := filepath.Join(r.reportDir, r.reportFiles[i])
+		if err := os.Remove(path); err != nil {
+			log.Warn("failed to remove old report file",
+				zap.String("path", path),
+				zap.Error(err))
+		} else {
+			log.Info("removed old report file", zap.String("path", path))
+		}
+	}
+	r.reportFiles = r.reportFiles[toDelete:]
+}
+
+// flushCheckpoint records the round's time windows into the checkpoint and
+// persists it crash-safely: temp-write + fsync, demote the old file to .bak,
+// promote the temp file, then drop the .bak. initializeCheckpoint undoes a
+// crash between steps 2 and 3 by restoring the backup.
+func (r *Recorder) flushCheckpoint(round uint64, timeWindowData map[string]types.TimeWindowData) error {
+	r.checkpoint.NewTimeWindowData(round, timeWindowData)
+
+	checkpointFile := filepath.Join(r.checkpointDir, "checkpoint.json")
+	bakFile := filepath.Join(r.checkpointDir, "checkpoint.json.bak")
+	tempFile := filepath.Join(r.checkpointDir, "checkpoint_temp.json")
+
+	data, err := json.Marshal(r.checkpoint)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// 1. Write the new content to a temp file first and fsync it.
+	if err := syncWriteFile(tempFile, data); err != nil {
+		return errors.Trace(err)
+	}
+
+	// 2. Rename the existing checkpoint to .bak (ignore error if it doesn't exist yet).
+	if err := os.Rename(checkpointFile, bakFile); err != nil && !os.IsNotExist(err) {
+		return errors.Trace(err)
+	}
+
+	// 3. Rename the temp file to be the new checkpoint.
+	if err := os.Rename(tempFile, checkpointFile); err != nil {
+		return errors.Trace(err)
+	}
+
+	// 4. Remove the backup — no longer needed.
+	_ = os.Remove(bakFile)
+
+	return nil
+}
diff --git a/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go b/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go
new file mode 100644
index 0000000000..af3a8297b3
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go
@@ -0,0 +1,561 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package recorder + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestNewRecorder(t *testing.T) { + t.Parallel() + + t.Run("creates directories", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + require.NotNil(t, r) + + // Verify directories exist + info, err := os.Stat(filepath.Join(dataDir, "report")) + require.NoError(t, err) + require.True(t, info.IsDir()) + + info, err = os.Stat(filepath.Join(dataDir, "checkpoint")) + require.NoError(t, err) + require.True(t, info.IsDir()) + }) + + t.Run("checkpoint is initialized empty", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + cp := r.GetCheckpoint() + require.NotNil(t, cp) + require.Nil(t, cp.CheckpointItems[0]) + require.Nil(t, cp.CheckpointItems[1]) + require.Nil(t, cp.CheckpointItems[2]) + }) + + t.Run("loads existing checkpoint on startup", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // First recorder: write a checkpoint + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + twData := map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "db", Table: "tbl"}: {Version: 1, VersionPath: "vp1", DataPath: "dp1"}, + }, + }, + } + report := NewReport(0) + err = r1.RecordTimeWindow(twData, report) + require.NoError(t, err) + + // Second recorder: should load the checkpoint + r2, err := NewRecorder(dataDir, 
map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + cp := r2.GetCheckpoint() + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(0), cp.CheckpointItems[2].Round) + info := cp.CheckpointItems[2].ClusterInfo["c1"] + require.Equal(t, uint64(1), info.TimeWindow.LeftBoundary) + require.Equal(t, uint64(10), info.TimeWindow.RightBoundary) + }) + + t.Run("cluster count mismatch rejects checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write a checkpoint with 2 clusters + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Try to load with only 1 cluster — should fail + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster info length mismatch") + }) + + t.Run("cluster ID missing rejects checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write a checkpoint with clusters c1 and c2 + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Try to load with c1 and c3 (same count, different ID) — should fail + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c3": {}}, 0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster info missing for cluster c3") + }) + + t.Run("matching 
clusters loads checkpoint successfully", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + clusters := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + + // Write checkpoint across 3 rounds so all 3 slots are filled + r1, err := NewRecorder(dataDir, clusters, 0) + require.NoError(t, err) + for i := range 3 { + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: uint64(i * 10), RightBoundary: uint64((i + 1) * 10)}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: uint64(i * 10), RightBoundary: uint64((i + 1) * 10)}}, + } + err = r1.RecordTimeWindow(twData, NewReport(uint64(i))) + require.NoError(t, err) + } + + // Reload with the same clusters — should succeed + r2, err := NewRecorder(dataDir, clusters, 0) + require.NoError(t, err) + cp := r2.GetCheckpoint() + require.NotNil(t, cp.CheckpointItems[0]) + require.NotNil(t, cp.CheckpointItems[1]) + require.NotNil(t, cp.CheckpointItems[2]) + }) + + t.Run("nil checkpoint items are skipped during validation", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write only 1 round — items[0] and items[1] stay nil + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Reload — should succeed even with a different cluster count since + // only the non-nil item[2] is validated + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + }) + + t.Run("no checkpoint file skips validation", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Fresh start with any cluster config — should always succeed + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}, "c3": {}}, 0) + require.NoError(t, err) + 
require.NotNil(t, r) + + cp := r.GetCheckpoint() + require.Nil(t, cp.CheckpointItems[0]) + require.Nil(t, cp.CheckpointItems[1]) + require.Nil(t, cp.CheckpointItems[2]) + }) +} + +func TestRecorder_RecordTimeWindow(t *testing.T) { + t.Parallel() + + t.Run("without report flush writes only checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}}, + } + report := NewReport(0) // needFlush = false + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + + // checkpoint.json should exist + _, err = os.Stat(filepath.Join(dataDir, "checkpoint", "checkpoint.json")) + require.NoError(t, err) + + // No report files + entries, err := os.ReadDir(filepath.Join(dataDir, "report")) + require.NoError(t, err) + require.Empty(t, entries) + }) + + t.Run("with report flush writes both checkpoint and report", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}}, + } + report := NewReport(5) + cr := NewClusterReport("c1", types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}) + cr.AddDataLossItem("d1", "test_table", map[string]any{"id": "1"}, `[id: 1]`, 200) + report.AddClusterReport("c1", cr) + require.True(t, report.NeedFlush()) + + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + + // checkpoint.json should exist + _, err = os.Stat(filepath.Join(dataDir, "checkpoint", "checkpoint.json")) + require.NoError(t, err) + + // Report files should exist + _, err = os.Stat(filepath.Join(dataDir, "report", "report-5.report")) + require.NoError(t, err) + _, err = 
os.Stat(filepath.Join(dataDir, "report", "report-5.json")) + require.NoError(t, err) + + // Verify report content + reportData, err := os.ReadFile(filepath.Join(dataDir, "report", "report-5.report")) + require.NoError(t, err) + require.Contains(t, string(reportData), "round: 5") + require.Contains(t, string(reportData), `[id: 1]`) + + // Verify json report is valid JSON + jsonData, err := os.ReadFile(filepath.Join(dataDir, "report", "report-5.json")) + require.NoError(t, err) + var parsed Report + err = json.Unmarshal(jsonData, &parsed) + require.NoError(t, err) + require.Equal(t, uint64(5), parsed.Round) + }) + + t.Run("multiple rounds advance checkpoint", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + for i := uint64(0); i < 4; i++ { + twData := map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + {Schema: "db", Table: "tbl"}: {Version: i + 1}, + }, + }, + } + report := NewReport(i) + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + } + + // After 4 rounds, checkpoint should have rounds 1, 2, 3 (oldest evicted) + cp := r.GetCheckpoint() + require.NotNil(t, cp.CheckpointItems[0]) + require.NotNil(t, cp.CheckpointItems[1]) + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(1), cp.CheckpointItems[0].Round) + require.Equal(t, uint64(2), cp.CheckpointItems[1].Round) + require.Equal(t, uint64(3), cp.CheckpointItems[2].Round) + }) +} + +func TestRecorder_CheckpointPersistence(t *testing.T) { + t.Parallel() + + t.Run("checkpoint survives restart", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + stk := types.SchemaTableKey{Schema: "db", Table: "tbl"} + + // Simulate 3 rounds + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + 
require.NoError(t, err) + for i := uint64(0); i < 3; i++ { + twData := map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}, + MaxVersion: map[types.SchemaTableKey]types.VersionKey{ + stk: {Version: i + 1, VersionPath: fmt.Sprintf("vp%d", i), DataPath: fmt.Sprintf("dp%d", i)}, + }, + }, + } + report := NewReport(i) + err = r1.RecordTimeWindow(twData, report) + require.NoError(t, err) + } + + // Restart: new recorder from same dir + r2, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + + cp := r2.GetCheckpoint() + require.Equal(t, uint64(0), cp.CheckpointItems[0].Round) + require.Equal(t, uint64(1), cp.CheckpointItems[1].Round) + require.Equal(t, uint64(2), cp.CheckpointItems[2].Round) + + // Verify ToScanRange works after restart + scanRange, err := cp.ToScanRange("c1") + require.NoError(t, err) + require.Len(t, scanRange, 1) + sr := scanRange[stk] + require.Equal(t, "vp0", sr.StartVersionKey) + require.Equal(t, "vp2", sr.EndVersionKey) + require.Equal(t, "dp0", sr.StartDataPath) + require.Equal(t, "dp2", sr.EndDataPath) + }) + + t.Run("old report files are cleaned up when exceeding max", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + maxReportFiles := 3 + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, maxReportFiles) + require.NoError(t, err) + + // Write 5 reports (each creates 2 files: .report and .json) + for i := uint64(0); i < 5; i++ { + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}}, + } + report := NewReport(i) + cr := NewClusterReport("c1", types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}) + cr.AddDataLossItem("d1", "test_table", map[string]any{"id": "1"}, `[id: 1]`, i+1) + report.AddClusterReport("c1", cr) + require.True(t, report.NeedFlush()) + + err = r.RecordTimeWindow(twData, report) 
+ require.NoError(t, err) + } + + // Should have at most maxReportFiles * 2 = 6 files (rounds 2, 3, 4) + entries, err := os.ReadDir(filepath.Join(dataDir, "report")) + require.NoError(t, err) + require.Equal(t, maxReportFiles*2, len(entries)) + + // Oldest files (round 0 and 1) should be deleted + _, err = os.Stat(filepath.Join(dataDir, "report", "report-0.report")) + require.True(t, os.IsNotExist(err)) + _, err = os.Stat(filepath.Join(dataDir, "report", "report-1.report")) + require.True(t, os.IsNotExist(err)) + + // Newest files should still exist + _, err = os.Stat(filepath.Join(dataDir, "report", "report-2.report")) + require.NoError(t, err) + _, err = os.Stat(filepath.Join(dataDir, "report", "report-3.report")) + require.NoError(t, err) + _, err = os.Stat(filepath.Join(dataDir, "report", "report-4.report")) + require.NoError(t, err) + }) + + t.Run("no cleanup when under max report files", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + maxReportFiles := 10 + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, maxReportFiles) + require.NoError(t, err) + + // Write 3 reports + for i := uint64(0); i < 3; i++ { + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}}, + } + report := NewReport(i) + cr := NewClusterReport("c1", types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}) + cr.AddDataLossItem("d1", "test_table", map[string]any{"id": "1"}, `[id: 1]`, i+1) + report.AddClusterReport("c1", cr) + + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + } + + // All 6 files should exist (3 rounds * 2 files each) + entries, err := os.ReadDir(filepath.Join(dataDir, "report")) + require.NoError(t, err) + require.Equal(t, 6, len(entries)) + }) + + t.Run("checkpoint json is valid", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + r, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + 
require.NoError(t, err) + + twData := map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{ + LeftBoundary: 100, + RightBoundary: 200, + CheckpointTs: map[string]uint64{"c2": 150}, + }, + }, + } + report := NewReport(0) + err = r.RecordTimeWindow(twData, report) + require.NoError(t, err) + + // Read and parse checkpoint.json + data, err := os.ReadFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json")) + require.NoError(t, err) + + var cp Checkpoint + err = json.Unmarshal(data, &cp) + require.NoError(t, err) + require.NotNil(t, cp.CheckpointItems[2]) + require.Equal(t, uint64(100), cp.CheckpointItems[2].ClusterInfo["c1"].TimeWindow.LeftBoundary) + require.Equal(t, uint64(200), cp.CheckpointItems[2].ClusterInfo["c1"].TimeWindow.RightBoundary) + }) +} + +func TestErrCheckpointCorruption(t *testing.T) { + t.Parallel() + + t.Run("corrupted checkpoint file returns ErrCheckpointCorruption", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Create report and checkpoint directories + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + + // Write invalid JSON to checkpoint.json + err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), []byte("{bad json"), 0600) + require.NoError(t, err) + + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.Error(t, err) + require.True(t, errors.Is(err, ErrCheckpointCorruption)) + }) + + t.Run("cluster count mismatch returns ErrCheckpointCorruption", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write a valid checkpoint with 2 clusters + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, 
RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Reload with 1 cluster — should be ErrCheckpointCorruption + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.Error(t, err) + require.True(t, errors.Is(err, ErrCheckpointCorruption)) + }) + + t.Run("missing cluster ID returns ErrCheckpointCorruption", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Write a valid checkpoint with clusters c1 and c2 + r1, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c2": {}}, 0) + require.NoError(t, err) + twData := map[string]types.TimeWindowData{ + "c1": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + "c2": {TimeWindow: types.TimeWindow{LeftBoundary: 0, RightBoundary: 10}}, + } + err = r1.RecordTimeWindow(twData, NewReport(0)) + require.NoError(t, err) + + // Reload with c1 and c3 — should be ErrCheckpointCorruption + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}, "c3": {}}, 0) + require.Error(t, err) + require.True(t, errors.Is(err, ErrCheckpointCorruption)) + }) + + t.Run("fresh start does not return ErrCheckpointCorruption", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + _, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.NoError(t, err) + }) + + t.Run("unreadable checkpoint file does not return ErrCheckpointCorruption", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Create directories + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + + // Create checkpoint.json as a directory so ReadFile fails with a non-corruption I/O error + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), 0755)) + + _, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.Error(t, 
err) + require.False(t, errors.Is(err, ErrCheckpointCorruption), + "I/O errors should NOT be classified as ErrCheckpointCorruption, got: %v", err) + }) + + t.Run("unreadable backup checkpoint does not return ErrCheckpointCorruption", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Create directories + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + + // Make checkpoint.json.bak a directory so ReadFile fails with an I/O error + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), 0755)) + + _, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.Error(t, err) + require.False(t, errors.Is(err, ErrCheckpointCorruption), + "I/O errors should NOT be classified as ErrCheckpointCorruption, got: %v", err) + }) + + t.Run("corrupted backup checkpoint returns ErrCheckpointCorruption", func(t *testing.T) { + t.Parallel() + dataDir := t.TempDir() + + // Create directories + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + + // Write invalid JSON to checkpoint.json.bak (simulate crash recovery with corrupted backup) + err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), []byte("not valid json"), 0600) + require.NoError(t, err) + + _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) + require.Error(t, err) + require.True(t, errors.Is(err, ErrCheckpointCorruption)) + }) +} diff --git a/cmd/multi-cluster-consistency-checker/recorder/types.go b/cmd/multi-cluster-consistency-checker/recorder/types.go new file mode 100644 index 0000000000..2c316fc965 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/recorder/types.go @@ -0,0 +1,410 @@ +// Copyright 2026 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package recorder + +import ( + "fmt" + "sort" + "strings" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" +) + +type DataLossItem struct { + PeerClusterID string `json:"peer_cluster_id"` + PK map[string]any `json:"pk"` + CommitTS uint64 `json:"commit_ts"` + + PKStr string `json:"-"` +} + +func (item *DataLossItem) String() string { + return fmt.Sprintf("peer cluster: %s, pk: %s, commit ts: %d", item.PeerClusterID, item.PKStr, item.CommitTS) +} + +type InconsistentColumn struct { + Column string `json:"column"` + Local any `json:"local"` + Replicated any `json:"replicated"` +} + +func (c *InconsistentColumn) String() string { + return fmt.Sprintf("column: %s, local: %v, replicated: %v", c.Column, c.Local, c.Replicated) +} + +type DataInconsistentItem struct { + PeerClusterID string `json:"peer_cluster_id"` + PK map[string]any `json:"pk"` + OriginTS uint64 `json:"origin_ts"` + LocalCommitTS uint64 `json:"local_commit_ts"` + ReplicatedCommitTS uint64 `json:"replicated_commit_ts"` + InconsistentColumns []InconsistentColumn `json:"inconsistent_columns,omitempty"` + + PKStr string `json:"-"` +} + +func (item *DataInconsistentItem) String() string { + var sb strings.Builder + fmt.Fprintf(&sb, "peer cluster: %s, pk: %s, origin ts: %d, local commit ts: %d, replicated commit ts: %d", + item.PeerClusterID, item.PKStr, item.OriginTS, item.LocalCommitTS, item.ReplicatedCommitTS) + if len(item.InconsistentColumns) > 0 { + 
sb.WriteString(", inconsistent columns: [") + for i, col := range item.InconsistentColumns { + if i > 0 { + sb.WriteString("; ") + } + sb.WriteString(col.String()) + } + sb.WriteString("]") + } + return sb.String() +} + +type DataRedundantItem struct { + PK map[string]any `json:"pk"` + OriginTS uint64 `json:"origin_ts"` + CommitTS uint64 `json:"commit_ts"` + + PKStr string `json:"-"` +} + +func (item *DataRedundantItem) String() string { + return fmt.Sprintf("pk: %s, origin ts: %d, commit ts: %d", item.PKStr, item.OriginTS, item.CommitTS) +} + +type LWWViolationItem struct { + PK map[string]any `json:"pk"` + ExistingOriginTS uint64 `json:"existing_origin_ts"` + ExistingCommitTS uint64 `json:"existing_commit_ts"` + OriginTS uint64 `json:"origin_ts"` + CommitTS uint64 `json:"commit_ts"` + + PKStr string `json:"-"` +} + +func (item *LWWViolationItem) String() string { + return fmt.Sprintf( + "pk: %s, existing origin ts: %d, existing commit ts: %d, origin ts: %d, commit ts: %d", + item.PKStr, item.ExistingOriginTS, item.ExistingCommitTS, item.OriginTS, item.CommitTS) +} + +type TableFailureItems struct { + DataLossItems []DataLossItem `json:"data_loss_items"` // data loss items + DataInconsistentItems []DataInconsistentItem `json:"data_inconsistent_items"` // data inconsistent items + DataRedundantItems []DataRedundantItem `json:"data_redundant_items"` // data redundant items + LWWViolationItems []LWWViolationItem `json:"lww_violation_items"` // lww violation items +} + +func NewTableFailureItems() *TableFailureItems { + return &TableFailureItems{ + DataLossItems: make([]DataLossItem, 0), + DataInconsistentItems: make([]DataInconsistentItem, 0), + DataRedundantItems: make([]DataRedundantItem, 0), + LWWViolationItems: make([]LWWViolationItem, 0), + } +} + +type ClusterReport struct { + ClusterID string `json:"cluster_id"` + + TimeWindow types.TimeWindow `json:"time_window"` + + TableFailureItems map[string]*TableFailureItems `json:"table_failure_items"` // table failure 
items + + needFlush bool `json:"-"` +} + +func NewClusterReport(clusterID string, timeWindow types.TimeWindow) *ClusterReport { + return &ClusterReport{ + ClusterID: clusterID, + TimeWindow: timeWindow, + TableFailureItems: make(map[string]*TableFailureItems), + needFlush: false, + } +} + +func (r *ClusterReport) AddDataLossItem( + peerClusterID, schemaKey string, + pk map[string]any, + pkStr string, + commitTS uint64, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.DataLossItems = append(tableFailureItems.DataLossItems, DataLossItem{ + PeerClusterID: peerClusterID, + PK: pk, + CommitTS: commitTS, + + PKStr: pkStr, + }) + r.needFlush = true +} + +func (r *ClusterReport) AddDataInconsistentItem( + peerClusterID, schemaKey string, + pk map[string]any, + pkStr string, + originTS, localCommitTS, replicatedCommitTS uint64, + inconsistentColumns []InconsistentColumn, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.DataInconsistentItems = append(tableFailureItems.DataInconsistentItems, DataInconsistentItem{ + PeerClusterID: peerClusterID, + PK: pk, + OriginTS: originTS, + LocalCommitTS: localCommitTS, + ReplicatedCommitTS: replicatedCommitTS, + InconsistentColumns: inconsistentColumns, + + PKStr: pkStr, + }) + r.needFlush = true +} + +func (r *ClusterReport) AddDataRedundantItem( + schemaKey string, + pk map[string]any, + pkStr string, + originTS, commitTS uint64, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.DataRedundantItems = append(tableFailureItems.DataRedundantItems, DataRedundantItem{ + PK: pk, + 
OriginTS: originTS, + CommitTS: commitTS, + + PKStr: pkStr, + }) + r.needFlush = true +} + +func (r *ClusterReport) AddLWWViolationItem( + schemaKey string, + pk map[string]any, + pkStr string, + existingOriginTS, existingCommitTS uint64, + originTS, commitTS uint64, +) { + tableFailureItems, exists := r.TableFailureItems[schemaKey] + if !exists { + tableFailureItems = NewTableFailureItems() + r.TableFailureItems[schemaKey] = tableFailureItems + } + tableFailureItems.LWWViolationItems = append(tableFailureItems.LWWViolationItems, LWWViolationItem{ + PK: pk, + ExistingOriginTS: existingOriginTS, + ExistingCommitTS: existingCommitTS, + OriginTS: originTS, + CommitTS: commitTS, + + PKStr: pkStr, + }) + r.needFlush = true +} + +type Report struct { + Round uint64 `json:"round"` + ClusterReports map[string]*ClusterReport `json:"cluster_reports"` + needFlush bool `json:"-"` +} + +func NewReport(round uint64) *Report { + return &Report{ + Round: round, + ClusterReports: make(map[string]*ClusterReport), + needFlush: false, + } +} + +func (r *Report) AddClusterReport(clusterID string, clusterReport *ClusterReport) { + r.ClusterReports[clusterID] = clusterReport + r.needFlush = r.needFlush || clusterReport.needFlush +} + +func (r *Report) MarshalReport() string { + var reportMsg strings.Builder + fmt.Fprintf(&reportMsg, "round: %d\n", r.Round) + + // Sort cluster IDs for deterministic output + clusterIDs := make([]string, 0, len(r.ClusterReports)) + for clusterID := range r.ClusterReports { + clusterIDs = append(clusterIDs, clusterID) + } + sort.Strings(clusterIDs) + + for _, clusterID := range clusterIDs { + clusterReport := r.ClusterReports[clusterID] + if !clusterReport.needFlush { + continue + } + fmt.Fprintf(&reportMsg, "\n[cluster: %s]\n", clusterID) + fmt.Fprintf(&reportMsg, "time window: %s\n", clusterReport.TimeWindow.String()) + + // Sort schema keys for deterministic output + schemaKeys := make([]string, 0, len(clusterReport.TableFailureItems)) + for schemaKey := 
range clusterReport.TableFailureItems { + schemaKeys = append(schemaKeys, schemaKey) + } + sort.Strings(schemaKeys) + + for _, schemaKey := range schemaKeys { + tableFailureItems := clusterReport.TableFailureItems[schemaKey] + fmt.Fprintf(&reportMsg, " - [table name: %s]\n", schemaKey) + if len(tableFailureItems.DataLossItems) > 0 { + fmt.Fprintf(&reportMsg, " - [data loss items: %d]\n", len(tableFailureItems.DataLossItems)) + for _, dataLossItem := range tableFailureItems.DataLossItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", dataLossItem.String()) + } + } + if len(tableFailureItems.DataInconsistentItems) > 0 { + fmt.Fprintf(&reportMsg, " - [data inconsistent items: %d]\n", len(tableFailureItems.DataInconsistentItems)) + for _, dataInconsistentItem := range tableFailureItems.DataInconsistentItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", dataInconsistentItem.String()) + } + } + if len(tableFailureItems.DataRedundantItems) > 0 { + fmt.Fprintf(&reportMsg, " - [data redundant items: %d]\n", len(tableFailureItems.DataRedundantItems)) + for _, dataRedundantItem := range tableFailureItems.DataRedundantItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", dataRedundantItem.String()) + } + } + if len(tableFailureItems.LWWViolationItems) > 0 { + fmt.Fprintf(&reportMsg, " - [lww violation items: %d]\n", len(tableFailureItems.LWWViolationItems)) + for _, lwwViolationItem := range tableFailureItems.LWWViolationItems { + fmt.Fprintf(&reportMsg, " - [%s]\n", lwwViolationItem.String()) + } + } + } + } + reportMsg.WriteString("\n") + return reportMsg.String() +} + +func (r *Report) NeedFlush() bool { + return r.needFlush +} + +type SchemaTableVersionKey struct { + types.SchemaTableKey + types.VersionKey +} + +func NewSchemaTableVersionKeyFromVersionKeyMap(versionKeyMap map[types.SchemaTableKey]types.VersionKey) []SchemaTableVersionKey { + result := make([]SchemaTableVersionKey, 0, len(versionKeyMap)) + for schemaTableKey, versionKey := range versionKeyMap { + result = append(result, 
SchemaTableVersionKey{ + SchemaTableKey: schemaTableKey, + VersionKey: versionKey, + }) + } + return result +} + +type CheckpointClusterInfo struct { + TimeWindow types.TimeWindow `json:"time_window"` + MaxVersion []SchemaTableVersionKey `json:"max_version"` +} + +type CheckpointItem struct { + Round uint64 `json:"round"` + ClusterInfo map[string]CheckpointClusterInfo `json:"cluster_info"` +} + +type Checkpoint struct { + CheckpointItems [3]*CheckpointItem `json:"checkpoint_items"` +} + +func NewCheckpoint() *Checkpoint { + return &Checkpoint{ + CheckpointItems: [3]*CheckpointItem{ + nil, + nil, + nil, + }, + } +} + +func (c *Checkpoint) NewTimeWindowData(round uint64, timeWindowData map[string]types.TimeWindowData) { + newCheckpointItem := CheckpointItem{ + Round: round, + ClusterInfo: make(map[string]CheckpointClusterInfo), + } + for clusterID, timeWindow := range timeWindowData { + newCheckpointItem.ClusterInfo[clusterID] = CheckpointClusterInfo{ + TimeWindow: timeWindow.TimeWindow, + MaxVersion: NewSchemaTableVersionKeyFromVersionKeyMap(timeWindow.MaxVersion), + } + } + c.CheckpointItems[0] = c.CheckpointItems[1] + c.CheckpointItems[1] = c.CheckpointItems[2] + c.CheckpointItems[2] = &newCheckpointItem +} + +type ScanRange struct { + StartVersionKey string + EndVersionKey string + StartDataPath string + EndDataPath string +} + +func (c *Checkpoint) ToScanRange(clusterID string) (map[types.SchemaTableKey]*ScanRange, error) { + result := make(map[types.SchemaTableKey]*ScanRange) + if c.CheckpointItems[2] == nil { + return result, nil + } + for _, versionKey := range c.CheckpointItems[2].ClusterInfo[clusterID].MaxVersion { + result[versionKey.SchemaTableKey] = &ScanRange{ + StartVersionKey: versionKey.VersionPath, + EndVersionKey: versionKey.VersionPath, + StartDataPath: versionKey.DataPath, + EndDataPath: versionKey.DataPath, + } + } + if c.CheckpointItems[1] == nil { + return result, nil + } + for _, versionKey := range 
c.CheckpointItems[1].ClusterInfo[clusterID].MaxVersion { + scanRange, ok := result[versionKey.SchemaTableKey] + if !ok { + return nil, errors.Errorf("schema table key %s.%s not found in result", versionKey.Schema, versionKey.Table) + } + scanRange.StartVersionKey = versionKey.VersionPath + scanRange.StartDataPath = versionKey.DataPath + } + if c.CheckpointItems[0] == nil { + return result, nil + } + for _, versionKey := range c.CheckpointItems[0].ClusterInfo[clusterID].MaxVersion { + scanRange, ok := result[versionKey.SchemaTableKey] + if !ok { + return nil, errors.Errorf("schema table key %s.%s not found in result", versionKey.Schema, versionKey.Table) + } + scanRange.StartVersionKey = versionKey.VersionPath + scanRange.StartDataPath = versionKey.DataPath + } + return result, nil +} diff --git a/cmd/multi-cluster-consistency-checker/recorder/types_test.go b/cmd/multi-cluster-consistency-checker/recorder/types_test.go new file mode 100644 index 0000000000..ec2df79d3c --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/recorder/types_test.go @@ -0,0 +1,639 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package recorder
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types"
+	"github.com/stretchr/testify/require"
+)
+
+// TestDataLossItem_String checks the textual rendering of a single data-loss item.
+func TestDataLossItem_String(t *testing.T) {
+	t.Parallel()
+	item := &DataLossItem{
+		PeerClusterID: "cluster-2",
+		PK:            map[string]any{"id": "1"},
+		CommitTS:      200,
+		PKStr:         `[id: 1]`,
+	}
+	s := item.String()
+	require.Equal(t, `peer cluster: cluster-2, pk: [id: 1], commit ts: 200`, s)
+}
+
+// TestDataInconsistentItem_String checks the rendering of inconsistent items,
+// with and without per-column diffs (including a nil replicated value).
+func TestDataInconsistentItem_String(t *testing.T) {
+	t.Parallel()
+
+	t.Run("without inconsistent columns", func(t *testing.T) {
+		t.Parallel()
+		item := &DataInconsistentItem{
+			PeerClusterID:      "cluster-3",
+			PK:                 map[string]any{"id": "2"},
+			OriginTS:           300,
+			LocalCommitTS:      400,
+			ReplicatedCommitTS: 410,
+			PKStr:              `[id: 2]`,
+		}
+		s := item.String()
+		require.Equal(t, `peer cluster: cluster-3, pk: [id: 2], origin ts: 300, local commit ts: 400, replicated commit ts: 410`, s)
+	})
+
+	t.Run("with inconsistent columns", func(t *testing.T) {
+		t.Parallel()
+		item := &DataInconsistentItem{
+			PeerClusterID:      "cluster-3",
+			PK:                 map[string]any{"id": "2"},
+			OriginTS:           300,
+			LocalCommitTS:      400,
+			ReplicatedCommitTS: 410,
+			PKStr:              `[id: 2]`,
+			InconsistentColumns: []InconsistentColumn{
+				{Column: "col1", Local: "val_a", Replicated: "val_b"},
+				{Column: "col2", Local: 100, Replicated: 200},
+			},
+		}
+		s := item.String()
+		require.Equal(t,
+			`peer cluster: cluster-3, pk: [id: 2], origin ts: 300, local commit ts: 400, replicated commit ts: 410, `+
+				"inconsistent columns: [column: col1, local: val_a, replicated: val_b; column: col2, local: 100, replicated: 200]",
+			s)
+	})
+
+	t.Run("with missing column in replicated", func(t *testing.T) {
+		t.Parallel()
+		item := &DataInconsistentItem{
+			PeerClusterID:      "cluster-3",
+			PK:                 map[string]any{"id": "2"},
+			OriginTS:           300,
+			LocalCommitTS:      400,
+			ReplicatedCommitTS: 410,
+			PKStr:              `[id: 2]`,
+			InconsistentColumns: []InconsistentColumn{
+				{Column: "col1", Local: "val_a", Replicated: nil},
+			},
+		}
+		s := item.String()
+		require.Equal(t,
+			`peer cluster: cluster-3, pk: [id: 2], origin ts: 300, local commit ts: 400, replicated commit ts: 410, `+
+				"inconsistent columns: [column: col1, local: val_a, replicated: ]",
+			s)
+	})
+}
+
+func TestDataRedundantItem_String(t *testing.T) {
+	t.Parallel()
+	item := &DataRedundantItem{PK: map[string]any{"id": "x"}, PKStr: `[id: x]`, OriginTS: 10, CommitTS: 20}
+	s := item.String()
+	require.Equal(t, `pk: [id: x], origin ts: 10, commit ts: 20`, s)
+}
+
+func TestLWWViolationItem_String(t *testing.T) {
+	t.Parallel()
+	item := &LWWViolationItem{
+		PK:               map[string]any{"id": "y"},
+		PKStr:            `[id: y]`,
+		ExistingOriginTS: 1,
+		ExistingCommitTS: 2,
+		OriginTS:         3,
+		CommitTS:         4,
+	}
+	s := item.String()
+	require.Equal(t, `pk: [id: y], existing origin ts: 1, existing commit ts: 2, origin ts: 3, commit ts: 4`, s)
+}
+
+// testSchemaKey is the schema/table key shared by the report tests below.
+const testSchemaKey = "test_table"
+
+// TestClusterReport verifies that adding any failure item flags the report
+// for flushing and stores the item under the right table key.
+func TestClusterReport(t *testing.T) {
+	t.Parallel()
+
+	t.Run("new cluster report is empty and does not need flush", func(t *testing.T) {
+		t.Parallel()
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		require.Equal(t, "c1", cr.ClusterID)
+		require.Empty(t, cr.TableFailureItems)
+		require.False(t, cr.needFlush)
+	})
+
+	t.Run("add data loss item sets needFlush", func(t *testing.T) {
+		t.Parallel()
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		cr.AddDataLossItem("peer-cluster-1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 200)
+		require.Len(t, cr.TableFailureItems, 1)
+		require.Contains(t, cr.TableFailureItems, testSchemaKey)
+		tableItems := cr.TableFailureItems[testSchemaKey]
+		require.Len(t, tableItems.DataLossItems, 1)
+		require.True(t, cr.needFlush)
+		require.Equal(t, "peer-cluster-1", tableItems.DataLossItems[0].PeerClusterID)
+		require.Equal(t, map[string]any{"id": "1"}, tableItems.DataLossItems[0].PK)
+		require.Equal(t, uint64(200), tableItems.DataLossItems[0].CommitTS)
+	})
+
+	t.Run("add data inconsistent item sets needFlush", func(t *testing.T) {
+		t.Parallel()
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		cols := []InconsistentColumn{
+			{Column: "val", Local: "a", Replicated: "b"},
+		}
+		cr.AddDataInconsistentItem("peer-cluster-2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 300, 400, 410, cols)
+		require.Len(t, cr.TableFailureItems, 1)
+		require.Contains(t, cr.TableFailureItems, testSchemaKey)
+		tableItems := cr.TableFailureItems[testSchemaKey]
+		require.Len(t, tableItems.DataInconsistentItems, 1)
+		require.True(t, cr.needFlush)
+		require.Equal(t, "peer-cluster-2", tableItems.DataInconsistentItems[0].PeerClusterID)
+		require.Equal(t, map[string]any{"id": "2"}, tableItems.DataInconsistentItems[0].PK)
+		require.Equal(t, uint64(300), tableItems.DataInconsistentItems[0].OriginTS)
+		require.Equal(t, uint64(400), tableItems.DataInconsistentItems[0].LocalCommitTS)
+		require.Equal(t, uint64(410), tableItems.DataInconsistentItems[0].ReplicatedCommitTS)
+		require.Len(t, tableItems.DataInconsistentItems[0].InconsistentColumns, 1)
+		require.Equal(t, "val", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Column)
+		require.Equal(t, "a", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Local)
+		require.Equal(t, "b", tableItems.DataInconsistentItems[0].InconsistentColumns[0].Replicated)
+	})
+
+	t.Run("add data redundant item sets needFlush", func(t *testing.T) {
+		t.Parallel()
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "2"}, `id: 2`, 300, 400)
+		require.Len(t, cr.TableFailureItems, 1)
+		tableItems := cr.TableFailureItems[testSchemaKey]
+		require.Len(t, tableItems.DataRedundantItems, 1)
+		require.True(t, cr.needFlush)
+	})
+
+	t.Run("add lww violation item sets needFlush", func(t *testing.T) {
+		t.Parallel()
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "3"}, `id: 3`, 1, 2, 3, 4)
+		require.Len(t, cr.TableFailureItems, 1)
+		tableItems := cr.TableFailureItems[testSchemaKey]
+		require.Len(t, tableItems.LWWViolationItems, 1)
+		require.True(t, cr.needFlush)
+		require.Equal(t, uint64(1), tableItems.LWWViolationItems[0].ExistingOriginTS)
+		require.Equal(t, uint64(2), tableItems.LWWViolationItems[0].ExistingCommitTS)
+		require.Equal(t, uint64(3), tableItems.LWWViolationItems[0].OriginTS)
+		require.Equal(t, uint64(4), tableItems.LWWViolationItems[0].CommitTS)
+	})
+
+	t.Run("add multiple items", func(t *testing.T) {
+		t.Parallel()
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		cr.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `id: 1`, 2)
+		cr.AddDataInconsistentItem("d2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 3, 4, 5, nil)
+		cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "3"}, `[id: 3]`, 5, 6)
+		cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "4"}, `[id: 4]`, 7, 8, 9, 10)
+		require.Len(t, cr.TableFailureItems, 1)
+		tableItems := cr.TableFailureItems[testSchemaKey]
+		require.Len(t, tableItems.DataLossItems, 1)
+		require.Len(t, tableItems.DataInconsistentItems, 1)
+		require.Len(t, tableItems.DataRedundantItems, 1)
+		require.Len(t, tableItems.LWWViolationItems, 1)
+	})
+}
+
+// TestReport verifies that a Report's needFlush flag is derived from its
+// contained cluster reports.
+func TestReport(t *testing.T) {
+	t.Parallel()
+
+	t.Run("new report does not need flush", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(1)
+		require.Equal(t, uint64(1), r.Round)
+		require.Empty(t, r.ClusterReports)
+		require.False(t, r.NeedFlush())
+	})
+
+	t.Run("add empty cluster report does not set needFlush", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(1)
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		r.AddClusterReport("c1", cr)
+		require.Len(t, r.ClusterReports, 1)
+		require.False(t, r.NeedFlush())
+	})
+
+	t.Run("add non-empty cluster report sets needFlush", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(1)
+		cr := NewClusterReport("c1", types.TimeWindow{})
+		cr.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 2)
+		r.AddClusterReport("c1", cr)
+		require.True(t, r.NeedFlush())
+	})
+
+	t.Run("needFlush propagates from any cluster report", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(1)
+		cr1 := NewClusterReport("c1", types.TimeWindow{})
+		cr2 := NewClusterReport("c2", types.TimeWindow{})
+		cr2.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 1, 2)
+		r.AddClusterReport("c1", cr1)
+		r.AddClusterReport("c2", cr2)
+		require.True(t, r.NeedFlush())
+	})
+}
+
+// TestReport_MarshalReport pins the exact text layout produced by
+// Report.MarshalReport for every item kind.
+func TestReport_MarshalReport(t *testing.T) {
+	t.Parallel()
+
+	tw := types.TimeWindow{LeftBoundary: 0, RightBoundary: 0}
+	twStr := tw.String()
+
+	t.Run("empty report", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(5)
+		s := r.MarshalReport()
+		require.Equal(t, "round: 5\n\n", s)
+	})
+
+	t.Run("report with data loss items", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(1)
+		cr := NewClusterReport("c1", tw)
+		cr.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 200)
+		r.AddClusterReport("c1", cr)
+		s := r.MarshalReport()
+		require.Equal(t, "round: 1\n\n"+
+			"[cluster: c1]\n"+
+			"time window: "+twStr+"\n"+
+			" - [table name: "+testSchemaKey+"]\n"+
+			" - [data loss items: 1]\n"+
+			` - [peer cluster: d1, pk: [id: 1], commit ts: 200]`+"\n\n",
+			s)
+	})
+
+	t.Run("report with data redundant items", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(2)
+		cr := NewClusterReport("c2", tw)
+		cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "r"}, `[id: r]`, 10, 20)
+		r.AddClusterReport("c2", cr)
+		s := r.MarshalReport()
+		require.Equal(t, "round: 2\n\n"+
+			"[cluster: c2]\n"+
+			"time window: "+twStr+"\n"+
+			" - [table name: "+testSchemaKey+"]\n"+
+			" - [data redundant items: 1]\n"+
+			` - [pk: [id: r], origin ts: 10, commit ts: 20]`+"\n\n",
+			s)
+	})
+
+	t.Run("report with lww violation items", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(3)
+		cr := NewClusterReport("c3", tw)
+		cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "v"}, `[id: v]`, 1, 2, 3, 4)
+		r.AddClusterReport("c3", cr)
+		s := r.MarshalReport()
+		require.Equal(t, "round: 3\n\n"+
+			"[cluster: c3]\n"+
+			"time window: "+twStr+"\n"+
+			" - [table name: "+testSchemaKey+"]\n"+
+			" - [lww violation items: 1]\n"+
+			` - [pk: [id: v], existing origin ts: 1, existing commit ts: 2, origin ts: 3, commit ts: 4]`+"\n\n",
+			s)
+	})
+
+	t.Run("skips cluster reports that do not need flush", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(1)
+		crEmpty := NewClusterReport("empty-cluster", tw)
+		crFull := NewClusterReport("full-cluster", tw)
+		crFull.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 2)
+		r.AddClusterReport("empty-cluster", crEmpty)
+		r.AddClusterReport("full-cluster", crFull)
+		s := r.MarshalReport()
+		require.Equal(t, "round: 1\n\n"+
+			"[cluster: full-cluster]\n"+
+			"time window: "+twStr+"\n"+
+			" - [table name: "+testSchemaKey+"]\n"+
+			" - [data loss items: 1]\n"+
+			` - [peer cluster: d1, pk: [id: 1], commit ts: 2]`+"\n\n",
+			s)
+	})
+
+	t.Run("report with mixed items", func(t *testing.T) {
+		t.Parallel()
+		r := NewReport(10)
+		cr := NewClusterReport("c1", tw)
+		cr.AddDataLossItem("d0", testSchemaKey, map[string]any{"id": "0"}, `[id: 0]`, 1)
+		cr.AddDataInconsistentItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 1, 2, 3, []InconsistentColumn{
+			{Column: "val", Local: "x", Replicated: "y"},
+		})
+		cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 3, 4)
+		cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "3"}, `[id: 3]`, 5, 6, 7, 8)
+		r.AddClusterReport("c1", cr)
+		s := r.MarshalReport()
+		require.Equal(t, "round: 10\n\n"+
+			"[cluster: c1]\n"+
+			"time window: "+twStr+"\n"+
+			" - [table name: "+testSchemaKey+"]\n"+
+			" - [data loss items: 1]\n"+
+			` - [peer cluster: d0, pk: [id: 0], commit ts: 1]`+"\n"+
+			" - [data inconsistent items: 1]\n"+
+			` - [peer cluster: d1, pk: [id: 1], origin ts: 1, local commit ts: 2, replicated commit ts: 3, inconsistent columns: [column: val, local: x, replicated: y]]`+"\n"+
+			" - [data redundant items: 1]\n"+
+			` - [pk: [id: 2], origin ts: 3, commit ts: 4]`+"\n"+
+			" - [lww violation items: 1]\n"+
+			` - [pk: [id: 3], existing origin ts: 5, existing commit ts: 6, origin ts: 7, commit ts: 8]`+"\n\n",
+			s)
+	})
+}
+
+func TestNewSchemaTableVersionKeyFromVersionKeyMap(t *testing.T) {
+	t.Parallel()
+
+	t.Run("empty map", func(t *testing.T) {
+		t.Parallel()
+		result := NewSchemaTableVersionKeyFromVersionKeyMap(nil)
+		require.Empty(t, result)
+	})
+
+	t.Run("single entry", func(t *testing.T) {
+		t.Parallel()
+		m := map[types.SchemaTableKey]types.VersionKey{
+			{Schema: "db", Table: "tbl"}: {Version: 1, VersionPath: "path1"},
+		}
+		result := NewSchemaTableVersionKeyFromVersionKeyMap(m)
+		require.Len(t, result, 1)
+		require.Equal(t, "db", result[0].Schema)
+		require.Equal(t, "tbl", result[0].Table)
+		require.Equal(t, uint64(1), result[0].Version)
+		require.Equal(t, "path1", result[0].VersionPath)
+	})
+
+	t.Run("multiple entries", func(t *testing.T) {
+		t.Parallel()
+		m := map[types.SchemaTableKey]types.VersionKey{
+			{Schema: "db1", Table: "t1"}: {Version: 1},
+			{Schema: "db2", Table: "t2"}: {Version: 2},
+		}
+		result := NewSchemaTableVersionKeyFromVersionKeyMap(m)
+		require.Len(t, result, 2)
+	})
+}
+
+// TestCheckpoint_NewTimeWindowData verifies the three-slot sliding window of
+// checkpoint items: new rounds fill from slot 2 backwards and the oldest
+// round is evicted once all slots are occupied.
+func TestCheckpoint_NewTimeWindowData(t *testing.T) {
+	t.Parallel()
+
+	t.Run("first call populates slot 2", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}},
+		})
+		require.Nil(t, cp.CheckpointItems[0])
+		require.Nil(t, cp.CheckpointItems[1])
+		require.NotNil(t, cp.CheckpointItems[2])
+		require.Equal(t, uint64(0), cp.CheckpointItems[2].Round)
+	})
+
+	t.Run("second call shifts slots", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10}},
+		})
+		cp.NewTimeWindowData(1, map[string]types.TimeWindowData{
+			"c1": {TimeWindow: types.TimeWindow{LeftBoundary: 10, RightBoundary: 20}},
+		})
+		require.Nil(t, cp.CheckpointItems[0])
+		require.NotNil(t, cp.CheckpointItems[1])
+		require.NotNil(t, cp.CheckpointItems[2])
+		require.Equal(t, uint64(0), cp.CheckpointItems[1].Round)
+		require.Equal(t, uint64(1), cp.CheckpointItems[2].Round)
+	})
+
+	t.Run("third call fills all slots", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		for i := uint64(0); i < 3; i++ {
+			cp.NewTimeWindowData(i, map[string]types.TimeWindowData{
+				"c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}},
+			})
+		}
+		require.NotNil(t, cp.CheckpointItems[0])
+		require.NotNil(t, cp.CheckpointItems[1])
+		require.NotNil(t, cp.CheckpointItems[2])
+		require.Equal(t, uint64(0), cp.CheckpointItems[0].Round)
+		require.Equal(t, uint64(1), cp.CheckpointItems[1].Round)
+		require.Equal(t, uint64(2), cp.CheckpointItems[2].Round)
+	})
+
+	t.Run("fourth call evicts oldest", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		for i := uint64(0); i < 4; i++ {
+			cp.NewTimeWindowData(i, map[string]types.TimeWindowData{
+				"c1": {TimeWindow: types.TimeWindow{LeftBoundary: i * 10, RightBoundary: (i + 1) * 10}},
+			})
+		}
+		require.Equal(t, uint64(1), cp.CheckpointItems[0].Round)
+		require.Equal(t, uint64(2), cp.CheckpointItems[1].Round)
+		require.Equal(t, uint64(3), cp.CheckpointItems[2].Round)
+	})
+
+	t.Run("stores max version from time window data", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {
+				TimeWindow: types.TimeWindow{LeftBoundary: 1, RightBoundary: 10},
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					{Schema: "db", Table: "tbl"}: {Version: 5, VersionPath: "vp", DataPath: "dp"},
+				},
+			},
+		})
+		info := cp.CheckpointItems[2].ClusterInfo["c1"]
+		require.Len(t, info.MaxVersion, 1)
+		require.Equal(t, uint64(5), info.MaxVersion[0].Version)
+		require.Equal(t, "vp", info.MaxVersion[0].VersionPath)
+		require.Equal(t, "dp", info.MaxVersion[0].DataPath)
+	})
+}
+
+// TestCheckpoint_ToScanRange verifies that the scan range for a cluster
+// spans from the oldest populated checkpoint slot to the newest one.
+func TestCheckpoint_ToScanRange(t *testing.T) {
+	t.Parallel()
+
+	stk := types.SchemaTableKey{Schema: "db", Table: "tbl"}
+
+	t.Run("empty checkpoint returns empty", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		result, err := cp.ToScanRange("c1")
+		require.NoError(t, err)
+		require.Empty(t, result)
+	})
+
+	t.Run("only item[2] set", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk: {Version: 2, VersionPath: "vp2", DataPath: "dp2"},
+				},
+			},
+		})
+		result, err := cp.ToScanRange("c1")
+		require.NoError(t, err)
+		require.Len(t, result, 1)
+		sr := result[stk]
+		// With only item[2], Start and End are both from item[2]
+		require.Equal(t, "vp2", sr.StartVersionKey)
+		require.Equal(t, "vp2", sr.EndVersionKey)
+		require.Equal(t, "dp2", sr.StartDataPath)
+		require.Equal(t, "dp2", sr.EndDataPath)
+	})
+
+	t.Run("items[1] and items[2] set", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk: {Version: 1, VersionPath: "vp1", DataPath: "dp1"},
+				},
+			},
+		})
+		cp.NewTimeWindowData(1, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk: {Version: 2, VersionPath: "vp2", DataPath: "dp2"},
+				},
+			},
+		})
+		result, err := cp.ToScanRange("c1")
+		require.NoError(t, err)
+		require.Len(t, result, 1)
+		sr := result[stk]
+		// End comes from item[2], Start overridden by item[1]
+		require.Equal(t, "vp1", sr.StartVersionKey)
+		require.Equal(t, "vp2", sr.EndVersionKey)
+		require.Equal(t, "dp1", sr.StartDataPath)
+		require.Equal(t, "dp2", sr.EndDataPath)
+	})
+
+	t.Run("all three items set", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		for i := uint64(0); i < 3; i++ {
+			cp.NewTimeWindowData(i, map[string]types.TimeWindowData{
+				"c1": {
+					MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+						stk: {
+							Version:     i + 1,
+							VersionPath: fmt.Sprintf("vp%d", i),
+							DataPath:    fmt.Sprintf("dp%d", i),
+						},
+					},
+				},
+			})
+		}
+		result, err := cp.ToScanRange("c1")
+		require.NoError(t, err)
+		require.Len(t, result, 1)
+		sr := result[stk]
+		// End from item[2], Start overridden by item[0] (oldest)
+		require.Equal(t, "vp0", sr.StartVersionKey)
+		require.Equal(t, "vp2", sr.EndVersionKey)
+		require.Equal(t, "dp0", sr.StartDataPath)
+		require.Equal(t, "dp2", sr.EndDataPath)
+	})
+
+	t.Run("missing key in item[1] returns error", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					{Schema: "other", Table: "other"}: {Version: 1, VersionPath: "vp1"},
+				},
+			},
+		})
+		cp.NewTimeWindowData(1, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk: {Version: 2, VersionPath: "vp2"},
+				},
+			},
+		})
+		_, err := cp.ToScanRange("c1")
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "not found")
+	})
+
+	t.Run("missing key in item[0] returns error", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					{Schema: "other", Table: "other"}: {Version: 1, VersionPath: "vp1"},
+				},
+			},
+		})
+		cp.NewTimeWindowData(1, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk: {Version: 2, VersionPath: "vp2"},
+				},
+			},
+		})
+		cp.NewTimeWindowData(2, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk: {Version: 3, VersionPath: "vp3"},
+				},
+			},
+		})
+		_, err := cp.ToScanRange("c1")
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "not found")
+	})
+
+	t.Run("unknown cluster returns empty", func(t *testing.T) {
+		t.Parallel()
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk: {Version: 1, VersionPath: "vp1"},
+				},
+			},
+		})
+		result, err := cp.ToScanRange("unknown-cluster")
+		require.NoError(t, err)
+		require.Empty(t, result)
+	})
+
+	t.Run("multiple tables", func(t *testing.T) {
+		t.Parallel()
+		stk2 := types.SchemaTableKey{Schema: "db2", Table: "tbl2"}
+		cp := NewCheckpoint()
+		cp.NewTimeWindowData(0, map[string]types.TimeWindowData{
+			"c1": {
+				MaxVersion: map[types.SchemaTableKey]types.VersionKey{
+					stk:  {Version: 1, VersionPath: "vp1-t1", DataPath: "dp1-t1"},
+					stk2: {Version: 1, VersionPath: "vp1-t2", DataPath: "dp1-t2"},
+				},
+			},
+		})
+		result, err := cp.ToScanRange("c1")
+		require.NoError(t, err)
+		require.Len(t, result, 2)
+		require.Contains(t, result, stk)
+		require.Contains(t, result, stk2)
+	})
+}
diff --git a/cmd/multi-cluster-consistency-checker/task.go b/cmd/multi-cluster-consistency-checker/task.go
new file mode 100644
index 0000000000..1bb8c35d86
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/task.go
@@ -0,0 +1,330 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"net/url"
+	"strings"
+	"time"
+
+	"github.com/pingcap/log"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/advancer"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/checker"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder"
+	"github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher"
+	"github.com/pingcap/ticdc/pkg/common"
+	cdcconfig "github.com/pingcap/ticdc/pkg/config"
+	"github.com/pingcap/ticdc/pkg/errors"
+	"github.com/pingcap/ticdc/pkg/etcd"
+	"github.com/pingcap/ticdc/pkg/security"
+	"github.com/pingcap/ticdc/pkg/util"
+	pd "github.com/tikv/pd/client"
+	pdopt "github.com/tikv/pd/client/opt"
+	"go.uber.org/zap"
+	"google.golang.org/grpc"
+)
+
+// runTask wires up all per-cluster clients, then loops: advance the shared
+// time window, run the consistency check inside it, and persist checkpoint
+// and report. It returns only on context cancellation or on an error, which
+// is wrapped in *ExitError so the caller can map it to an exit code.
+func runTask(ctx context.Context, cfg *config.Config, dryRun bool) error {
+	checkpointWatchers, s3Watchers, pdClients, etcdClients, err := initClients(ctx, cfg)
+	if err != nil {
+		// Client initialisation is typically a transient (network) failure.
+		return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)}
+	}
+	// Ensure cleanup happens even if there's an error
+	defer cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers)
+
+	if dryRun {
+		log.Info("Dry-run mode: config validation and connectivity check passed, exiting")
+		return nil
+	}
+
+	rec, err := recorder.NewRecorder(cfg.GlobalConfig.DataDir, cfg.Clusters, cfg.GlobalConfig.MaxReportFiles)
+	if err != nil {
+		if errors.Is(err, recorder.ErrCheckpointCorruption) {
+			return &ExitError{Code: ExitCodeCheckpointCorruption, Err: err}
+		}
+		// Other recorder init errors (e.g. mkdir, readdir) are transient.
+		return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)}
+	}
+	timeWindowAdvancer, checkpointDataMap, err := advancer.NewTimeWindowAdvancer(ctx, checkpointWatchers, s3Watchers, pdClients, rec.GetCheckpoint())
+	if err != nil {
+		return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)}
+	}
+	dataChecker := checker.NewDataChecker(ctx, cfg.Clusters, checkpointDataMap, rec.GetCheckpoint())
+
+	log.Info("Starting consistency checker task")
+	for {
+		// Check if context is cancelled before starting a new iteration
+		select {
+		case <-ctx.Done():
+			log.Info("Context cancelled, shutting down gracefully")
+			return ctx.Err()
+		default:
+		}
+
+		newTimeWindowData, err := timeWindowAdvancer.AdvanceTimeWindow(ctx)
+		if err != nil {
+			return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)}
+		}
+
+		report, err := dataChecker.CheckInNextTimeWindow(newTimeWindowData)
+		if err != nil {
+			return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)}
+		}
+
+		if err := rec.RecordTimeWindow(newTimeWindowData, report); err != nil {
+			return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)}
+		}
+	}
+}
+
+// initClients creates, for each configured cluster: a PD client, a CDC etcd
+// client, one checkpoint watcher per peer cluster, and an S3 watcher for the
+// cluster's S3 changefeed. On any failure it closes everything created so
+// far and returns the error.
+func initClients(ctx context.Context, cfg *config.Config) (
+	map[string]map[string]watcher.Watcher,
+	map[string]*watcher.S3Watcher,
+	map[string]pd.Client,
+	map[string]*etcd.CDCEtcdClientImpl,
+	error,
+) {
+	checkpointWatchers := make(map[string]map[string]watcher.Watcher)
+	s3Watchers := make(map[string]*watcher.S3Watcher)
+	pdClients := make(map[string]pd.Client)
+	etcdClients := make(map[string]*etcd.CDCEtcdClientImpl)
+
+	for clusterID, clusterConfig := range cfg.Clusters {
+		pdClient, etcdClient, err := newPDClient(ctx, clusterConfig.PDAddrs, &clusterConfig.SecurityConfig)
+		if err != nil {
+			// Clean up already created clients before returning error
+			cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers)
+			return nil, nil, nil, nil, errors.Trace(err)
+		}
+		// Register both clients in the maps immediately: cleanupClients only
+		// closes what is inside the maps, so registering the PD client at the
+		// end of the iteration (as before) leaked it whenever a later step of
+		// this same iteration failed.
+		pdClients[clusterID] = pdClient
+		etcdClients[clusterID] = etcdClient
+
+		clusterCheckpointWatchers := make(map[string]watcher.Watcher)
+		for peerClusterID, peerClusterChangefeedConfig := range clusterConfig.PeerClusterChangefeedConfig {
+			checkpointWatcher := watcher.NewCheckpointWatcher(ctx, clusterID, peerClusterID, peerClusterChangefeedConfig.ChangefeedID, etcdClient)
+			clusterCheckpointWatchers[peerClusterID] = checkpointWatcher
+		}
+		checkpointWatchers[clusterID] = clusterCheckpointWatchers
+
+		// Validate s3 changefeed sink config from etcd
+		if err := validateS3ChangefeedSinkConfig(ctx, etcdClient, clusterID, clusterConfig.S3ChangefeedID, clusterConfig.S3SinkURI); err != nil {
+			cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers)
+			return nil, nil, nil, nil, errors.Trace(err)
+		}
+
+		s3Storage, err := util.GetExternalStorageWithDefaultTimeout(ctx, clusterConfig.S3SinkURI)
+		if err != nil {
+			// Clean up already created clients before returning error
+			cleanupClients(pdClients, etcdClients, checkpointWatchers, s3Watchers)
+			return nil, nil, nil, nil, errors.Trace(err)
+		}
+		s3Watcher := watcher.NewS3Watcher(
+			watcher.NewCheckpointWatcher(ctx, clusterID, "s3", clusterConfig.S3ChangefeedID, etcdClient),
+			s3Storage,
+			cfg.GlobalConfig.Tables,
+		)
+		s3Watchers[clusterID] = s3Watcher
+	}
+
+	return checkpointWatchers, s3Watchers, pdClients, etcdClients, nil
+}
+
+// validateS3ChangefeedSinkConfig fetches the changefeed info from etcd and validates that:
+// 1. The changefeed SinkURI bucket/prefix matches the configured s3SinkURI
+// 2. The protocol must be canal-json
+// 3. The date separator must be "day"
+// 4. The file index width must be DefaultFileIndexWidth
+func validateS3ChangefeedSinkConfig(ctx context.Context, etcdClient *etcd.CDCEtcdClientImpl, clusterID string, s3ChangefeedID string, s3SinkURI string) error {
+	displayName := common.NewChangeFeedDisplayName(s3ChangefeedID, "default")
+	cfInfo, err := etcdClient.GetChangeFeedInfo(ctx, displayName)
+	if err != nil {
+		return errors.Annotate(err, fmt.Sprintf("failed to get changefeed info for s3 changefeed %s in cluster %s", s3ChangefeedID, clusterID))
+	}
+
+	// 1. Validate that the changefeed's SinkURI bucket/prefix matches the configured s3SinkURI.
+	// This prevents the checker from reading data that was written by a different changefeed.
+	if err := validateS3BucketPrefix(cfInfo.SinkURI, s3SinkURI, clusterID, s3ChangefeedID); err != nil {
+		return err
+	}
+
+	if cfInfo.Config == nil || cfInfo.Config.Sink == nil {
+		return fmt.Errorf("cluster %s: s3 changefeed %s has no sink configuration", clusterID, s3ChangefeedID)
+	}
+
+	sinkConfig := cfInfo.Config.Sink
+
+	// 2. Validate protocol must be canal-json
+	protocolStr := strings.ToLower(util.GetOrZero(sinkConfig.Protocol))
+	if protocolStr == "" {
+		return fmt.Errorf("cluster %s: s3 changefeed %s has no protocol configured in sink config", clusterID, s3ChangefeedID)
+	}
+	protocol, err := cdcconfig.ParseSinkProtocolFromString(protocolStr)
+	if err != nil {
+		return errors.Annotate(err, fmt.Sprintf("cluster %s: s3 changefeed %s has invalid protocol", clusterID, s3ChangefeedID))
+	}
+	if protocol != cdcconfig.ProtocolCanalJSON {
+		return fmt.Errorf("cluster %s: s3 changefeed %s protocol is %q, but only %q is supported",
+			clusterID, s3ChangefeedID, protocolStr, cdcconfig.ProtocolCanalJSON.String())
+	}
+
+	// 3. Validate date separator must be "day"
+	dateSeparatorStr := util.GetOrZero(sinkConfig.DateSeparator)
+	if dateSeparatorStr == "" {
+		dateSeparatorStr = cdcconfig.DateSeparatorNone.String()
+	}
+	var dateSep cdcconfig.DateSeparator
+	if err := dateSep.FromString(dateSeparatorStr); err != nil {
+		return errors.Annotate(err, fmt.Sprintf("cluster %s: s3 changefeed %s has invalid date-separator %q", clusterID, s3ChangefeedID, dateSeparatorStr))
+	}
+	if dateSep != cdcconfig.DateSeparatorDay {
+		return fmt.Errorf("cluster %s: s3 changefeed %s date-separator is %q, but only %q is supported",
+			clusterID, s3ChangefeedID, dateSep.String(), cdcconfig.DateSeparatorDay.String())
+	}
+
+	// 4. Validate file index width must be DefaultFileIndexWidth
+	fileIndexWidth := util.GetOrZero(sinkConfig.FileIndexWidth)
+	if fileIndexWidth != cdcconfig.DefaultFileIndexWidth {
+		return fmt.Errorf("cluster %s: s3 changefeed %s file-index-width is %d, but only %d is supported",
+			clusterID, s3ChangefeedID, fileIndexWidth, cdcconfig.DefaultFileIndexWidth)
+	}
+
+	log.Info("Validated s3 changefeed sink config from etcd",
+		zap.String("clusterID", clusterID),
+		zap.String("s3ChangefeedID", s3ChangefeedID),
+		zap.String("protocol", protocolStr),
+		zap.String("dateSeparator", dateSep.String()),
+		zap.Int("fileIndexWidth", fileIndexWidth),
+	)
+
+	return nil
+}
+
+// validateS3BucketPrefix checks that the changefeed's SinkURI and the
+// configured s3-sink-uri point to the same S3 bucket and prefix.
+// This is a critical sanity check — a mismatch means the checker would
+// read data from a different location than where the changefeed writes.
+func validateS3BucketPrefix(changefeedSinkURI, configS3SinkURI, clusterID, s3ChangefeedID string) error {
+	// Either URI failing to parse is reported as a configuration error.
+	cfURL, err := url.Parse(changefeedSinkURI)
+	if err != nil {
+		return fmt.Errorf("cluster %s: s3 changefeed %s has invalid sink URI %q: %v",
+			clusterID, s3ChangefeedID, changefeedSinkURI, err)
+	}
+	cfgURL, err := url.Parse(configS3SinkURI)
+	if err != nil {
+		return fmt.Errorf("cluster %s: configured s3-sink-uri %q is invalid: %v",
+			clusterID, configS3SinkURI, err)
+	}
+
+	// Compare scheme (s3, gcs, …), bucket (Host) and prefix (Path).
+	// Path is normalized by trimming trailing slashes so that
+	// "s3://bucket/prefix" and "s3://bucket/prefix/" are considered equal.
+	// Query parameters (credentials, endpoint, …) are intentionally ignored.
+	cfScheme := strings.ToLower(cfURL.Scheme)
+	cfgScheme := strings.ToLower(cfgURL.Scheme)
+	cfBucket := cfURL.Host
+	cfgBucket := cfgURL.Host
+	cfPrefix := strings.TrimRight(cfURL.Path, "/")
+	cfgPrefix := strings.TrimRight(cfgURL.Path, "/")
+
+	if cfScheme != cfgScheme || cfBucket != cfgBucket || cfPrefix != cfgPrefix {
+		// The error echoes both (un-normalized) paths to ease debugging.
+		return fmt.Errorf("cluster %s: s3 changefeed %s sink URI bucket/prefix mismatch: "+
+			"changefeed has %s://%s%s but config has %s://%s%s",
+			clusterID, s3ChangefeedID,
+			cfScheme, cfBucket, cfURL.Path,
+			cfgScheme, cfgBucket, cfgURL.Path)
+	}
+	return nil
+}
+
+// newPDClient dials PD at pdAddrs and builds both a PD client and a CDC etcd
+// client (backed by a raw etcd client against the same endpoints). Partially
+// created clients are closed before returning an error, so on failure the
+// caller owns nothing.
+func newPDClient(ctx context.Context, pdAddrs []string, securityConfig *security.Credential) (pd.Client, *etcd.CDCEtcdClientImpl, error) {
+	pdClient, err := pd.NewClientWithContext(
+		ctx, "consistency-checker", pdAddrs, securityConfig.PDSecurityOption(),
+		pdopt.WithCustomTimeoutOption(10*time.Second),
+	)
+	if err != nil {
+		return nil, nil, errors.Trace(err)
+	}
+
+	etcdCli, err := etcd.CreateRawEtcdClient(securityConfig, grpc.EmptyDialOption{}, pdAddrs...)
+	if err != nil {
+		// Clean up PD client if etcd client creation fails
+		if pdClient != nil {
+			pdClient.Close()
+		}
+		return nil, nil, errors.Trace(err)
+	}
+
+	cdcEtcdClient, err := etcd.NewCDCEtcdClient(ctx, etcdCli, "default")
+	if err != nil {
+		// Clean up resources if CDC etcd client creation fails
+		etcdCli.Close()
+		pdClient.Close()
+		return nil, nil, errors.Trace(err)
+	}
+
+	return pdClient, cdcEtcdClient, nil
+}
+
+// cleanupClients closes all PD and etcd clients gracefully
+// Note: it only closes what is present in the maps it is given; callers are
+// responsible for registering clients before any step that can fail.
+func cleanupClients(
+	pdClients map[string]pd.Client,
+	etcdClients map[string]*etcd.CDCEtcdClientImpl,
+	checkpointWatchers map[string]map[string]watcher.Watcher,
+	s3Watchers map[string]*watcher.S3Watcher,
+) {
+	log.Info("Cleaning up clients",
+		zap.Int("pdClients", len(pdClients)),
+		zap.Int("etcdClients", len(etcdClients)),
+		zap.Int("checkpointWatchers", len(checkpointWatchers)),
+		zap.Int("s3Watchers", len(s3Watchers)),
+	)
+
+	// Close PD clients
+	for clusterID, pdClient := range pdClients {
+		if pdClient != nil {
+			pdClient.Close()
+			log.Debug("PD client closed", zap.String("clusterID", clusterID))
+		}
+	}
+
+	// Close etcd clients
+	for clusterID, etcdClient := range etcdClients {
+		if etcdClient != nil {
+			if err := etcdClient.Close(); err != nil {
+				log.Warn("Failed to close etcd client",
+					zap.String("clusterID", clusterID),
+					zap.Error(err))
+			} else {
+				log.Debug("Etcd client closed", zap.String("clusterID", clusterID))
+			}
+		}
+	}
+
+	// Close checkpoint watchers
+	for _, clusterWatchers := range checkpointWatchers {
+		for _, watcher := range clusterWatchers {
+			watcher.Close()
+		}
+	}
+
+	// Close s3 watchers
+	for _, s3Watcher := range s3Watchers {
+		s3Watcher.Close()
+	}
+
+	log.Info("Client cleanup completed")
+}
diff --git a/cmd/multi-cluster-consistency-checker/types/types.go b/cmd/multi-cluster-consistency-checker/types/types.go
new file mode 100644
index 0000000000..cdb7e1760a
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/types/types.go
@@ -0,0
+1,98 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package types
+
+import (
+	"fmt"
+	"sort"
+	"strings"
+
+	"github.com/pingcap/ticdc/pkg/sink/cloudstorage"
+	ptypes "github.com/pingcap/tidb/pkg/parser/types"
+)
+
+// PkType is a distinct type for encoded primary key strings, making it clear
+// at the type level that the value has been serialized / encoded.
+type PkType string
+
+// CdcVersion carries the commit timestamp of a row change and an optional
+// origin timestamp; OriginTs == 0 means unset (see GetCompareTs).
+type CdcVersion struct {
+	CommitTs uint64
+	OriginTs uint64
+}
+
+// GetCompareTs returns OriginTs when it is non-zero, otherwise CommitTs.
+func (v *CdcVersion) GetCompareTs() uint64 {
+	if v.OriginTs > 0 {
+		return v.OriginTs
+	}
+	return v.CommitTs
+}
+
+// SchemaTableKey identifies a table by its schema and table name.
+type SchemaTableKey struct {
+	Schema string
+	Table  string
+}
+
+// VersionKey records a table version together with scan-position hints.
+type VersionKey struct {
+	Version uint64
+	// Version Path is a hint for the next version path to scan
+	VersionPath string
+	// Data Path is a hint for the next data path to scan
+	DataPath string
+}
+
+// TimeWindow is the time window of the cluster, including the left boundary, right boundary and checkpoint ts
+// Assert 1: LeftBoundary < CheckpointTs < RightBoundary
+// Assert 2: The other cluster's checkpoint timestamp of next time window should be larger than the PDTimestampAfterTimeWindow saved in this cluster's time window
+// Assert 3: CheckpointTs of this cluster should be larger than other clusters' RightBoundary of previous time window
+// Assert 4: RightBoundary of this cluster should be larger than other clusters' CheckpointTs of this time window
+type TimeWindow struct {
+	LeftBoundary  uint64 `json:"left_boundary"`
+	RightBoundary uint64 `json:"right_boundary"`
+	// CheckpointTs is the checkpoint timestamp for each local-to-replicated changefeed,
+	// mapping from replicated cluster ID to the checkpoint timestamp
+	CheckpointTs map[string]uint64 `json:"checkpoint_ts"`
+	// PDTimestampAfterTimeWindow is the max PD timestamp after the time window for each replicated cluster,
+	// mapping from local cluster ID to the max PD timestamp
+	PDTimestampAfterTimeWindow map[string]uint64 `json:"pd_timestamp_after_time_window"`
+	// NextMinLeftBoundary is the minimum left boundary of the next time window for the cluster
+	NextMinLeftBoundary uint64 `json:"next_min_left_boundary"`
+}
+
+// String renders the window boundary followed by the per-cluster checkpoint
+// timestamps, one per line, in sorted cluster-ID order for determinism.
+func (t *TimeWindow) String() string {
+	var builder strings.Builder
+	fmt.Fprintf(&builder, "time window boundary: (%d, %d]\n", t.LeftBoundary, t.RightBoundary)
+
+	// Sort cluster IDs for deterministic output
+	clusterIDs := make([]string, 0, len(t.CheckpointTs))
+	for id := range t.CheckpointTs {
+		clusterIDs = append(clusterIDs, id)
+	}
+	sort.Strings(clusterIDs)
+
+	for _, replicatedClusterID := range clusterIDs {
+		fmt.Fprintf(&builder, "checkpoint ts [replicated cluster: %s]: %d\n", replicatedClusterID, t.CheckpointTs[replicatedClusterID])
+	}
+	return builder.String()
+}
+
+// TimeWindowData bundles a TimeWindow with the sink data observed inside it:
+// incremental file content per DML path key, plus the max table versions.
+type TimeWindowData struct {
+	TimeWindow
+	Data       map[cloudstorage.DmlPathKey]IncrementalData
+	MaxVersion map[SchemaTableKey]VersionKey
+}
+
+// IncrementalData holds raw sink-file content slices keyed by file index,
+// together with the column field types keyed by column name.
+type IncrementalData struct {
+	DataContentSlices map[cloudstorage.FileIndexKey][][]byte
+	ColumnFieldTypes  map[string]*ptypes.FieldType
+}
diff --git a/cmd/multi-cluster-consistency-checker/types/types_test.go b/cmd/multi-cluster-consistency-checker/types/types_test.go
new file mode 100644
index 0000000000..7fe90a9739
--- /dev/null
+++ b/cmd/multi-cluster-consistency-checker/types/types_test.go
@@ -0,0 +1,69 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package types_test + +import ( + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/stretchr/testify/require" +) + +func TestCdcVersion_GetCompareTs(t *testing.T) { + tests := []struct { + name string + version types.CdcVersion + expected uint64 + }{ + { + name: "OriginTs is set", + version: types.CdcVersion{ + CommitTs: 100, + OriginTs: 200, + }, + expected: 200, + }, + { + name: "OriginTs is smaller than CommitTs", + version: types.CdcVersion{ + CommitTs: 200, + OriginTs: 100, + }, + expected: 100, + }, + { + name: "OriginTs is zero", + version: types.CdcVersion{ + CommitTs: 100, + OriginTs: 0, + }, + expected: 100, + }, + { + name: "Both are zero", + version: types.CdcVersion{ + CommitTs: 0, + OriginTs: 0, + }, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.version.GetCompareTs() + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go new file mode 100644 index 0000000000..6aa05c17f6 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go @@ -0,0 +1,300 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package watcher

import (
	"context"
	"sync"
	"time"

	"github.com/pingcap/log"
	"github.com/pingcap/ticdc/pkg/common"
	"github.com/pingcap/ticdc/pkg/config"
	"github.com/pingcap/ticdc/pkg/errors"
	"github.com/pingcap/ticdc/pkg/etcd"
	clientv3 "go.etcd.io/etcd/client/v3"
	"go.uber.org/zap"
)

// errChangefeedKeyDeleted is a sentinel error indicating that the changefeed
// status key has been deleted from etcd. This is a non-recoverable error
// that should not be retried.
var errChangefeedKeyDeleted = errors.New("changefeed status key is deleted")

const (
	// retryBackoffBase is the initial backoff duration for retries
	retryBackoffBase = 500 * time.Millisecond
	// retryBackoffMax is the maximum backoff duration for retries
	retryBackoffMax = 30 * time.Second
	// retryBackoffMultiplier is the multiplier for exponential backoff
	retryBackoffMultiplier = 2.0
)

// Watcher exposes blocking checkpoint advancement for a single changefeed.
type Watcher interface {
	// AdvanceCheckpointTs blocks until the checkpoint strictly exceeds
	// minCheckpointTs and returns the observed checkpoint.
	AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error)
	Close()
}

// waitCheckpointTask represents one blocked AdvanceCheckpointTs caller.
type waitCheckpointTask struct {
	// respCh is buffered with capacity 1 so the notifier can send without
	// blocking; it is closed (never sent to) when the watcher shuts down.
	respCh          chan uint64
	minCheckpointTs uint64
}

// CheckpointWatcher watches a changefeed's status key in etcd and wakes up
// callers once the checkpoint ts passes their requested minimum.
type CheckpointWatcher struct {
	localClusterID      string
	replicatedClusterID string
	changefeedID        common.ChangeFeedID
	etcdClient          etcd.CDCEtcdClient

	// ctx/cancel bound the lifetime of the background run() goroutine.
	ctx    context.Context
	cancel context.CancelFunc

	// mu guards everything below.
	mu               sync.Mutex
	latestCheckpoint uint64
	pendingTasks     []*waitCheckpointTask
	// watchErr, once set, is returned to every subsequent caller; it marks
	// the watcher as permanently failed (key deleted or context canceled).
	watchErr error
	closed   bool
}

// NewCheckpointWatcher creates a watcher and starts its background watch
// loop immediately. The changefeed is looked up in the "default" namespace
// (NOTE(review): hard-coded here — confirm multi-keyspace deployments).
func NewCheckpointWatcher(
	ctx context.Context,
	localClusterID, replicatedClusterID, changefeedID string,
	etcdClient etcd.CDCEtcdClient,
) *CheckpointWatcher {
	cctx, cancel := context.WithCancel(ctx)

	watcher := &CheckpointWatcher{
		localClusterID:      localClusterID,
		replicatedClusterID: replicatedClusterID,
		changefeedID:        common.NewChangeFeedIDWithName(changefeedID, "default"),
		etcdClient:          etcdClient,

		ctx:    cctx,
		cancel: cancel,
	}
	go watcher.run()
	return watcher
}

// AdvanceCheckpointTs waits for the checkpoint to exceed minCheckpointTs
// (strictly greater). It returns immediately if the cached checkpoint already
// qualifies, otherwise it registers a task and blocks until notified, the
// caller's ctx is canceled, or the watcher itself shuts down.
func (cw *CheckpointWatcher) AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) {
	cw.mu.Lock()

	// Check if watcher has encountered an error
	if cw.watchErr != nil {
		err := cw.watchErr
		cw.mu.Unlock()
		return 0, err
	}

	// Check if watcher is closed
	if cw.closed {
		cw.mu.Unlock()
		return 0, errors.Errorf("checkpoint watcher is closed")
	}

	// Check if current checkpoint already exceeds minCheckpointTs
	if cw.latestCheckpoint > minCheckpointTs {
		checkpoint := cw.latestCheckpoint
		cw.mu.Unlock()
		return checkpoint, nil
	}

	// Create a task and wait for the background goroutine to notify
	task := &waitCheckpointTask{
		respCh:          make(chan uint64, 1),
		minCheckpointTs: minCheckpointTs,
	}
	cw.pendingTasks = append(cw.pendingTasks, task)
	cw.mu.Unlock()

	// Wait for response or context cancellation
	select {
	case <-ctx.Done():
		// Remove the task from pending list so the notifier does not keep
		// a reference to an abandoned waiter.
		cw.mu.Lock()
		for i, t := range cw.pendingTasks {
			if t == task {
				cw.pendingTasks = append(cw.pendingTasks[:i], cw.pendingTasks[i+1:]...)
				break
			}
		}
		cw.mu.Unlock()
		return 0, errors.Annotate(ctx.Err(), "context canceled while waiting for checkpoint")
	case <-cw.ctx.Done():
		return 0, errors.Annotate(cw.ctx.Err(), "watcher context canceled")
	case checkpoint, ok := <-task.respCh:
		if !ok {
			// Channel closed by Close(): watcher is shutting down.
			return 0, errors.Errorf("checkpoint watcher is closed")
		}
		return checkpoint, nil
	}
}

// Close stops the watcher: it cancels the background goroutine and wakes all
// pending waiters by closing their response channels. Safe to call while
// AdvanceCheckpointTs callers are blocked.
func (cw *CheckpointWatcher) Close() {
	cw.cancel()
	cw.mu.Lock()
	cw.closed = true
	// Notify all pending tasks that watcher is closing
	for _, task := range cw.pendingTasks {
		close(task.respCh)
	}
	cw.pendingTasks = nil
	cw.mu.Unlock()
}

// run is the background loop: it re-runs watchOnce with exponential backoff
// until the context is canceled or a non-recoverable error occurs, recording
// the terminal error in watchErr for future callers.
func (cw *CheckpointWatcher) run() {
	backoff := retryBackoffBase
	for {
		select {
		case <-cw.ctx.Done():
			cw.mu.Lock()
			cw.watchErr = errors.Annotate(cw.ctx.Err(), "context canceled")
			cw.mu.Unlock()
			return
		default:
		}

		err := cw.watchOnce()
		if err == nil {
			// Normal exit (context canceled)
			return
		}

		// Check if this is a non-recoverable error
		if errors.Is(err, errChangefeedKeyDeleted) {
			cw.mu.Lock()
			cw.watchErr = err
			cw.mu.Unlock()
			return
		}

		// Log and retry with backoff
		log.Warn("checkpoint watcher encountered error, will retry",
			zap.String("changefeedID", cw.changefeedID.String()),
			zap.Duration("backoff", backoff),
			zap.Error(err))

		select {
		case <-cw.ctx.Done():
			cw.mu.Lock()
			cw.watchErr = errors.Annotate(cw.ctx.Err(), "context canceled")
			cw.mu.Unlock()
			return
		case <-time.After(backoff):
		}

		// Increase backoff for next retry (exponential backoff with cap)
		backoff = time.Duration(float64(backoff) * retryBackoffMultiplier)
		backoff = min(backoff, retryBackoffMax)
	}
}

// watchOnce performs one watch cycle: an initial point-in-time read of the
// status key followed by an etcd watch starting just after that revision (so
// no update between the read and the watch is missed). Returns nil if the
// context is canceled, or an error if the watch fails and should be retried.
func (cw *CheckpointWatcher) watchOnce() error {
	// First, get the current checkpoint status from etcd
	status, modRev, err := cw.etcdClient.GetChangeFeedStatus(cw.ctx, cw.changefeedID)
	if err != nil {
		// Check if context is canceled
		if cw.ctx.Err() != nil {
			return nil
		}
		return errors.Annotate(err, "failed to get changefeed status")
	}

	// Update latest checkpoint and wake any waiters it already satisfies.
	cw.mu.Lock()
	cw.latestCheckpoint = status.CheckpointTs
	cw.notifyPendingTasksLocked()
	cw.mu.Unlock()

	statusKey := etcd.GetEtcdKeyJob(cw.etcdClient.GetClusterID(), cw.changefeedID.DisplayName)

	log.Debug("Starting to watch checkpoint",
		zap.String("changefeed ID", cw.changefeedID.String()),
		zap.String("statusKey", statusKey),
		zap.String("local cluster ID", cw.localClusterID),
		zap.String("replicated cluster ID", cw.replicatedClusterID),
		zap.Uint64("checkpoint", status.CheckpointTs),
		zap.Int64("startRev", modRev+1))

	// Watch from modRev+1: the first revision not covered by the read above.
	watchCh := cw.etcdClient.GetEtcdClient().Watch(
		cw.ctx,
		statusKey,
		"checkpoint-watcher",
		clientv3.WithRev(modRev+1),
	)

	for {
		select {
		case <-cw.ctx.Done():
			return nil
		case watchResp, ok := <-watchCh:
			if !ok {
				return errors.Errorf("[changefeedID: %s] watch channel closed", cw.changefeedID.String())
			}

			if err := watchResp.Err(); err != nil {
				return errors.Annotatef(err, "[changefeedID: %s] watch error", cw.changefeedID.String())
			}

			for _, event := range watchResp.Events {
				if event.Type == clientv3.EventTypeDelete {
					// Non-recoverable: propagated via the sentinel error.
					return errors.Annotatef(errChangefeedKeyDeleted, "[changefeedID: %s]", cw.changefeedID.String())
				}

				// Parse the updated status
				newStatus := &config.ChangeFeedStatus{}
				if err := newStatus.Unmarshal(event.Kv.Value); err != nil {
					log.Warn("failed to unmarshal changefeed status, skipping",
						zap.String("changefeedID", cw.changefeedID.String()),
						zap.Error(err))
					continue
				}

				checkpointTs := newStatus.CheckpointTs
				log.Debug("Checkpoint updated",
					zap.String("changefeedID", cw.changefeedID.String()),
					zap.Uint64("checkpoint", checkpointTs))

				// Update latest checkpoint and notify pending tasks; ignore
				// non-advancing events so latestCheckpoint stays monotonic.
				cw.mu.Lock()
				if checkpointTs > cw.latestCheckpoint {
					cw.latestCheckpoint = checkpointTs
					cw.notifyPendingTasksLocked()
				}
				cw.mu.Unlock()
			}
		}
	}
}

// notifyPendingTasksLocked notifies pending tasks whose minCheckpointTs has
// been exceeded (strictly) and compacts the pending list in place.
// Must be called with mu locked.
func (cw *CheckpointWatcher) notifyPendingTasksLocked() {
	// In-place filter reusing the backing array.
	remaining := cw.pendingTasks[:0]
	for _, task := range cw.pendingTasks {
		if cw.latestCheckpoint > task.minCheckpointTs {
			// Non-blocking send since channel has buffer of 1
			select {
			case task.respCh <- cw.latestCheckpoint:
			default:
			}
		} else {
			remaining = append(remaining, task)
		}
	}
	cw.pendingTasks = remaining
}
+ +package watcher + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pingcap/errors" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/etcd" + "github.com/stretchr/testify/require" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" +) + +func TestCheckpointWatcher_AdvanceCheckpointTs_AlreadyExceeds(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + // Create a watch channel that won't send anything during this test + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Request checkpoint that's already exceeded + minCheckpointTs := uint64(500) + checkpoint, err := watcher.AdvanceCheckpointTs(t.Context(), minCheckpointTs) + require.NoError(t, err) + require.Equal(t, initialCheckpoint, checkpoint) +} + +func TestCheckpointWatcher_AdvanceCheckpointTs_WaitForUpdate(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + updatedCheckpoint := uint64(2000) + + // Setup 
mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Start waiting for checkpoint in a goroutine + var checkpoint uint64 + var advanceErr error + done := make(chan struct{}) + go func() { + checkpoint, advanceErr = watcher.AdvanceCheckpointTs(context.Background(), uint64(1500)) + close(done) + }() + + // Give some time for the task to be registered + time.Sleep(50 * time.Millisecond) + + // Simulate checkpoint update via watch channel + newStatus := &config.ChangeFeedStatus{CheckpointTs: updatedCheckpoint} + statusStr, err := newStatus.Marshal() + require.NoError(t, err) + + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypePut, + Kv: &mvccpb.KeyValue{ + Value: []byte(statusStr), + }, + }, + }, + } + + select { + case <-done: + require.NoError(t, advanceErr) + require.Equal(t, updatedCheckpoint, checkpoint) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for checkpoint advance") + } +} + +func TestCheckpointWatcher_AdvanceCheckpointTs_ContextCanceled(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + 
mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Create a context that will be canceled + advanceCtx, advanceCancel := context.WithCancel(t.Context()) + + var advanceErr error + done := make(chan struct{}) + go func() { + _, advanceErr = watcher.AdvanceCheckpointTs(advanceCtx, uint64(2000)) + close(done) + }() + + // Give some time for the task to be registered + time.Sleep(50 * time.Millisecond) + + // Cancel the context + advanceCancel() + + select { + case <-done: + require.Error(t, advanceErr) + require.Contains(t, advanceErr.Error(), "context canceled") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for context cancellation") + } +} + +func TestCheckpointWatcher_Close(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + 
mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Start waiting for checkpoint in a goroutine + var advanceErr error + done := make(chan struct{}) + go func() { + _, advanceErr = watcher.AdvanceCheckpointTs(context.Background(), uint64(2000)) + close(done) + }() + + // Give some time for the task to be registered + time.Sleep(50 * time.Millisecond) + + // Close the watcher + watcher.Close() + + select { + case <-done: + require.Error(t, advanceErr) + // Error can be "closed" or "canceled" depending on timing + errMsg := advanceErr.Error() + require.True(t, + strings.Contains(errMsg, "closed") || strings.Contains(errMsg, "canceled"), + "expected error to contain 'closed' or 'canceled', got: %s", errMsg) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for watcher close") + } + + // Verify that subsequent calls return error + _, err := watcher.AdvanceCheckpointTs(context.Background(), uint64(100)) + require.Error(t, err) + // Error can be "closed" or "canceled" depending on timing + errMsg := err.Error() + require.True(t, + strings.Contains(errMsg, "closed") || strings.Contains(errMsg, "canceled"), + "expected error to contain 'closed' or 'canceled', got: %s", errMsg) +} + +func TestCheckpointWatcher_MultiplePendingTasks(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + 
mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 10) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Start multiple goroutines waiting for different checkpoints + var wg sync.WaitGroup + results := make([]struct { + checkpoint uint64 + err error + }, 3) + + for i := range 3 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + minTs := uint64(1100 + idx*500) // 1100, 1600, 2100 + results[idx].checkpoint, results[idx].err = watcher.AdvanceCheckpointTs(context.Background(), minTs) + }(i) + } + + // Give some time for tasks to be registered + time.Sleep(50 * time.Millisecond) + + // Send checkpoint updates + checkpoints := []uint64{1500, 2000, 2500} + for _, cp := range checkpoints { + newStatus := &config.ChangeFeedStatus{CheckpointTs: cp} + statusStr, err := newStatus.Marshal() + require.NoError(t, err) + + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypePut, + Kv: &mvccpb.KeyValue{ + Value: []byte(statusStr), + }, + }, + }, + } + // Give some time between updates + time.Sleep(20 * time.Millisecond) + } + + // Wait for all goroutines to complete + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Verify results + for i := range 3 { + require.NoError(t, results[i].err, "task %d should not have error", i) + minTs := uint64(1100 + i*500) + require.Greater(t, results[i].checkpoint, minTs, "task %d checkpoint should exceed minTs", i) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for all 
tasks to complete") + } +} + +func TestCheckpointWatcher_NotifyPendingTasksLocked(t *testing.T) { + cw := &CheckpointWatcher{ + latestCheckpoint: 2000, + pendingTasks: []*waitCheckpointTask{ + {respCh: make(chan uint64, 1), minCheckpointTs: 1000}, + {respCh: make(chan uint64, 1), minCheckpointTs: 1500}, + {respCh: make(chan uint64, 1), minCheckpointTs: 2500}, + {respCh: make(chan uint64, 1), minCheckpointTs: 3000}, + }, + } + + cw.notifyPendingTasksLocked() + + // Tasks with minCheckpointTs < 2000 should be notified and removed + require.Len(t, cw.pendingTasks, 2) + require.Equal(t, uint64(2500), cw.pendingTasks[0].minCheckpointTs) + require.Equal(t, uint64(3000), cw.pendingTasks[1].minCheckpointTs) +} + +func TestCheckpointWatcher_InitialCheckpointNotifiesPendingTasks(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(5000) + + // Setup mock expectations - initial checkpoint is high enough + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize and get the initial checkpoint + time.Sleep(100 * time.Millisecond) + + // Request checkpoint that's already exceeded by initial checkpoint + checkpoint, err := watcher.AdvanceCheckpointTs(context.Background(), uint64(1000)) + require.NoError(t, err) + 
require.Equal(t, initialCheckpoint, checkpoint) +} + +func TestCheckpointWatcher_WatchErrorRetry(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + retryCheckpoint := uint64(2000) + + // First call succeeds, second call (retry) also succeeds with updated checkpoint + firstCall := mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: retryCheckpoint}, + int64(101), + nil, + ).After(firstCall) + + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + // First watch channel will be closed (simulating error), second watch channel will work + watchCh1 := make(chan clientv3.WatchResponse) + watchCh2 := make(chan clientv3.WatchResponse, 1) + firstWatch := mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh2).After(firstWatch) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Close the first watch channel to trigger retry + close(watchCh1) + + // Wait for retry to happen (backoff + processing time) + time.Sleep(700 * time.Millisecond) + + // After retry, checkpoint should be updated to retryCheckpoint + checkpoint, err := watcher.AdvanceCheckpointTs(t.Context(), uint64(1500)) + require.NoError(t, err) + require.Equal(t, retryCheckpoint, 
checkpoint) +} + +func TestCheckpointWatcher_GetStatusRetry(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + successCheckpoint := uint64(2000) + + // First call fails, second call succeeds + firstCall := mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + nil, + int64(0), + errors.Errorf("connection refused"), + ) + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: successCheckpoint}, + int64(100), + nil, + ).After(firstCall) + + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for retry to happen (backoff + processing time) + time.Sleep(700 * time.Millisecond) + + // After retry, checkpoint should be available + checkpoint, err := watcher.AdvanceCheckpointTs(t.Context(), uint64(1000)) + require.NoError(t, err) + require.Equal(t, successCheckpoint, checkpoint) +} + +func TestCheckpointWatcher_KeyDeleted(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + // Setup mock expectations + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + 
mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + // Wait for watcher to initialize + time.Sleep(50 * time.Millisecond) + + // Send delete event + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypeDelete, + }, + }, + } + + // Give time for the error to be processed + time.Sleep(50 * time.Millisecond) + + // Now trying to advance should return an error + _, err := watcher.AdvanceCheckpointTs(context.Background(), uint64(2000)) + require.Error(t, err) + require.Contains(t, err.Error(), "deleted") +} diff --git a/cmd/multi-cluster-consistency-checker/watcher/s3_watcher.go b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher.go new file mode 100644 index 0000000000..b6e7d92b62 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher.go @@ -0,0 +1,70 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package watcher + +import ( + "context" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/consumer" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/cloudstorage" + "github.com/pingcap/tidb/br/pkg/storage" +) + +type S3Watcher struct { + checkpointWatcher Watcher + consumer *consumer.S3Consumer +} + +func NewS3Watcher( + checkpointWatcher Watcher, + s3Storage storage.ExternalStorage, + tables map[string][]string, +) *S3Watcher { + consumer := consumer.NewS3Consumer(s3Storage, tables) + return &S3Watcher{ + checkpointWatcher: checkpointWatcher, + consumer: consumer, + } +} + +func (sw *S3Watcher) Close() { + sw.checkpointWatcher.Close() +} + +func (sw *S3Watcher) InitializeFromCheckpoint(ctx context.Context, clusterID string, checkpoint *recorder.Checkpoint) (map[cloudstorage.DmlPathKey]types.IncrementalData, error) { + return sw.consumer.InitializeFromCheckpoint(ctx, clusterID, checkpoint) +} + +func (sw *S3Watcher) AdvanceS3CheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + checkpointTs, err := sw.checkpointWatcher.AdvanceCheckpointTs(ctx, minCheckpointTs) + if err != nil { + return 0, errors.Annotate(err, "advance s3 checkpoint timestamp failed") + } + + return checkpointTs, nil +} + +func (sw *S3Watcher) ConsumeNewFiles( + ctx context.Context, +) (map[cloudstorage.DmlPathKey]types.IncrementalData, map[types.SchemaTableKey]types.VersionKey, error) { + // TODO: get the index updated from the s3 + newData, maxVersionMap, err := sw.consumer.ConsumeNewFiles(ctx) + if err != nil { + return nil, nil, errors.Annotate(err, "consume new files failed") + } + return newData, maxVersionMap, nil +} diff --git a/cmd/multi-cluster-consistency-checker/watcher/s3_watcher_test.go b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher_test.go new file mode 
100644 index 0000000000..b3ecfa16d6 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/watcher/s3_watcher_test.go @@ -0,0 +1,202 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package watcher + +import ( + "context" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/consumer" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/stretchr/testify/require" +) + +// mockWatcher is a mock implementation of the Watcher interface for testing. 
+type mockWatcher struct { + advanceCheckpointTsFn func(ctx context.Context, minCheckpointTs uint64) (uint64, error) + closeFn func() + closed bool +} + +func (m *mockWatcher) AdvanceCheckpointTs(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + if m.advanceCheckpointTsFn != nil { + return m.advanceCheckpointTsFn(ctx, minCheckpointTs) + } + return 0, nil +} + +func (m *mockWatcher) Close() { + m.closed = true + if m.closeFn != nil { + m.closeFn() + } +} + +func TestS3Watcher_Close(t *testing.T) { + t.Parallel() + + t.Run("close delegates to checkpoint watcher", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + sw.Close() + require.True(t, mock.closed) + }) + + t.Run("close calls custom close function", func(t *testing.T) { + t.Parallel() + closeCalled := false + mock := &mockWatcher{ + closeFn: func() { + closeCalled = true + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + sw.Close() + require.True(t, closeCalled) + }) +} + +func TestS3Watcher_AdvanceS3CheckpointTs(t *testing.T) { + t.Parallel() + + t.Run("advance checkpoint ts success", func(t *testing.T) { + t.Parallel() + expectedCheckpoint := uint64(5000) + mock := &mockWatcher{ + advanceCheckpointTsFn: func(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + require.Equal(t, uint64(3000), minCheckpointTs) + return expectedCheckpoint, nil + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint, err := sw.AdvanceS3CheckpointTs(context.Background(), uint64(3000)) + require.NoError(t, err) + require.Equal(t, expectedCheckpoint, checkpoint) + }) + + t.Run("advance checkpoint ts error", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{ + advanceCheckpointTsFn: func(ctx context.Context, 
minCheckpointTs uint64) (uint64, error) { + return 0, errors.Errorf("connection lost") + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint, err := sw.AdvanceS3CheckpointTs(context.Background(), uint64(3000)) + require.Error(t, err) + require.Equal(t, uint64(0), checkpoint) + require.Contains(t, err.Error(), "advance s3 checkpoint timestamp failed") + require.Contains(t, err.Error(), "connection lost") + }) + + t.Run("advance checkpoint ts with context canceled", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{ + advanceCheckpointTsFn: func(ctx context.Context, minCheckpointTs uint64) (uint64, error) { + return 0, context.Canceled + }, + } + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint, err := sw.AdvanceS3CheckpointTs(context.Background(), uint64(3000)) + require.Error(t, err) + require.Equal(t, uint64(0), checkpoint) + require.Contains(t, err.Error(), "advance s3 checkpoint timestamp failed") + }) +} + +func TestS3Watcher_InitializeFromCheckpoint(t *testing.T) { + t.Parallel() + + t.Run("nil checkpoint returns nil", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + result, err := sw.InitializeFromCheckpoint(context.Background(), "cluster1", nil) + require.NoError(t, err) + require.Nil(t, result) + }) + + t.Run("empty checkpoint returns nil", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + checkpoint := recorder.NewCheckpoint() + result, err := sw.InitializeFromCheckpoint(context.Background(), "cluster1", checkpoint) + require.NoError(t, err) + require.Nil(t, result) + }) +} + +func TestS3Watcher_ConsumeNewFiles(t 
*testing.T) { + t.Parallel() + + t.Run("consume new files with empty tables", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), map[string][]string{}), + } + + newData, maxVersionMap, err := sw.ConsumeNewFiles(context.Background()) + require.NoError(t, err) + require.Empty(t, newData) + require.Empty(t, maxVersionMap) + }) + + t.Run("consume new files with nil tables", func(t *testing.T) { + t.Parallel() + mock := &mockWatcher{} + sw := &S3Watcher{ + checkpointWatcher: mock, + consumer: consumer.NewS3Consumer(storage.NewMemStorage(), nil), + } + + newData, maxVersionMap, err := sw.ConsumeNewFiles(context.Background()) + require.NoError(t, err) + require.Empty(t, newData) + require.Empty(t, maxVersionMap) + }) +} From 079cfde6e16d57ce09403a3cb68ca7921dea94d5 Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Tue, 24 Feb 2026 14:47:44 +0800 Subject: [PATCH 02/10] make clean Signed-off-by: Jianjun Liao --- .../checker/checker.go | 48 +++++---- .../checker/checker_test.go | 58 +++++++++-- .../checker/failpoint.go | 98 +++++++++++++++++++ .../config/config_test.go | 24 ++--- .../integration/integration_test.go | 3 +- .../recorder/recorder.go | 12 +-- .../recorder/recorder_test.go | 24 ++--- cmd/multi-cluster-consistency-checker/task.go | 5 +- .../watcher/checkpoint_watcher.go | 28 +++++- .../watcher/checkpoint_watcher_test.go | 58 +++++++++++ pkg/sink/cloudstorage/path_key.go | 69 +++++++++++++ pkg/sink/cloudstorage/path_key_test.go | 41 +++++++- 12 files changed, 405 insertions(+), 63 deletions(-) create mode 100644 cmd/multi-cluster-consistency-checker/checker/failpoint.go diff --git a/cmd/multi-cluster-consistency-checker/checker/checker.go b/cmd/multi-cluster-consistency-checker/checker/checker.go index 1365b9142b..3f35372798 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker.go +++ 
b/cmd/multi-cluster-consistency-checker/checker/checker.go @@ -17,6 +17,7 @@ import ( "context" "sort" + "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/config" "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/decoder" @@ -409,9 +410,17 @@ func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( cd.checkedRecordsCount++ replicatedRecord, skipped := checker.FindClusterReplicatedData(replicatedClusterID, schemaKey, record.Pk, record.CommitTs) if skipped { + failpoint.Inject("multiClusterConsistencyCheckerLWWViolation", func() { + Write("multiClusterConsistencyCheckerLWWViolation", []RowRecord{ + { + CommitTs: record.CommitTs, + PrimaryKeys: record.PkMap, + }, + }) + }) log.Debug("replicated record skipped by LWW", - zap.String("local cluster ID", cd.clusterID), - zap.String("replicated cluster ID", replicatedClusterID), + zap.String("localClusterID", cd.clusterID), + zap.String("replicatedClusterID", replicatedClusterID), zap.String("schemaKey", schemaKey), zap.String("pk", record.PkStr), zap.Uint64("commitTs", record.CommitTs)) @@ -421,15 +430,15 @@ func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( if replicatedRecord == nil { // data loss detected log.Error("data loss detected", - zap.String("local cluster ID", cd.clusterID), - zap.String("replicated cluster ID", replicatedClusterID), + zap.String("localClusterID", cd.clusterID), + zap.String("replicatedClusterID", replicatedClusterID), zap.Any("record", record)) cd.report.AddDataLossItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, record.CommitTs) } else if !record.EqualReplicatedRecord(replicatedRecord) { // data inconsistent detected log.Error("data inconsistent detected", - zap.String("local cluster ID", cd.clusterID), - zap.String("replicated cluster ID", replicatedClusterID), + zap.String("localClusterID", cd.clusterID), + zap.String("replicatedClusterID", replicatedClusterID), zap.Any("record", 
record)) cd.report.AddDataInconsistentItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, replicatedRecord.OriginTs, record.CommitTs, replicatedRecord.CommitTs, diffColumns(record, replicatedRecord)) } @@ -450,7 +459,7 @@ func (cd *clusterDataChecker) dataRedundantDetection(checker *DataChecker) { if !checker.FindSourceLocalData(cd.clusterID, schemaKey, record.Pk, record.OriginTs) { // data redundant detected log.Error("data redundant detected", - zap.String("replicated cluster ID", cd.clusterID), + zap.String("replicatedClusterID", cd.clusterID), zap.Any("record", record)) cd.report.AddDataRedundantItem(schemaKey, record.PkMap, record.PkStr, record.OriginTs, record.CommitTs) } @@ -530,7 +539,7 @@ type DataChecker struct { clusterDataCheckers map[string]*clusterDataChecker } -func NewDataChecker(ctx context.Context, clusterConfig map[string]config.ClusterConfig, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) *DataChecker { +func NewDataChecker(ctx context.Context, clusterConfig map[string]config.ClusterConfig, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) (*DataChecker, error) { clusterDataChecker := make(map[string]*clusterDataChecker) for clusterID := range clusterConfig { clusterDataChecker[clusterID] = newClusterDataChecker(clusterID) @@ -540,22 +549,27 @@ func NewDataChecker(ctx context.Context, clusterConfig map[string]config.Cluster checkableRound: 0, clusterDataCheckers: clusterDataChecker, } - checker.initializeFromCheckpoint(ctx, checkpointDataMap, checkpoint) - return checker + if err := checker.initializeFromCheckpoint(ctx, checkpointDataMap, checkpoint); err != nil { + return nil, errors.Trace(err) + } + return checker, nil } -func (c *DataChecker) initializeFromCheckpoint(ctx context.Context, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) { 
+func (c *DataChecker) initializeFromCheckpoint(ctx context.Context, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint *recorder.Checkpoint) error { if checkpoint == nil { - return + return nil } if checkpoint.CheckpointItems[2] == nil { - return + return nil } c.round = checkpoint.CheckpointItems[2].Round + 1 c.checkableRound = checkpoint.CheckpointItems[2].Round + 1 for _, clusterDataChecker := range c.clusterDataCheckers { - clusterDataChecker.InitializeFromCheckpoint(ctx, checkpointDataMap[clusterDataChecker.clusterID], checkpoint) + if err := clusterDataChecker.InitializeFromCheckpoint(ctx, checkpointDataMap[clusterDataChecker.clusterID], checkpoint); err != nil { + return errors.Trace(err) + } } + return nil } // FindClusterReplicatedData checks whether the record is present in the replicated data @@ -611,9 +625,9 @@ func (c *DataChecker) CheckInNextTimeWindow(newTimeWindowData map[string]types.T zap.Uint64("round", c.round), zap.Bool("enableDataLoss", enableDataLoss), zap.Bool("enableDataRedundant", enableDataRedundant), - zap.Int("checked records count", clusterDataChecker.checkedRecordsCount), - zap.Int("new time window records count", clusterDataChecker.newTimeWindowRecordsCount), - zap.Int("lww skipped records count", clusterDataChecker.lwwSkippedRecordsCount)) + zap.Int("checkedRecordsCount", clusterDataChecker.checkedRecordsCount), + zap.Int("newTimeWindowRecordsCount", clusterDataChecker.newTimeWindowRecordsCount), + zap.Int("lwwSkippedRecordsCount", clusterDataChecker.lwwSkippedRecordsCount)) report.AddClusterReport(clusterID, clusterDataChecker.GetReport()) } diff --git a/cmd/multi-cluster-consistency-checker/checker/checker_test.go b/cmd/multi-cluster-consistency-checker/checker/checker_test.go index 18453f35c2..589fd52129 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker_test.go +++ b/cmd/multi-cluster-consistency-checker/checker/checker_test.go @@ -45,7 +45,8 @@ func TestNewDataChecker(t 
*testing.T) { }, } - checker := NewDataChecker(context.Background(), clusterConfig, nil, nil) + checker, initErr := NewDataChecker(context.Background(), clusterConfig, nil, nil) + require.NoError(t, initErr) require.NotNil(t, checker) require.Equal(t, uint64(0), checker.round) require.Len(t, checker.clusterDataCheckers, 2) @@ -54,6 +55,37 @@ func TestNewDataChecker(t *testing.T) { }) } +func TestNewDataCheckerInitializeFromCheckpointError(t *testing.T) { + t.Parallel() + + clusterConfig := map[string]config.ClusterConfig{ + "c1": {}, + } + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{ + LeftBoundary: 0, + RightBoundary: 100, + }, + }, + }) + invalidContent := []byte(`{"pkNames":["id"],"isDdl":false,"type":"INSERT","mysqlType":{"id":"int"},"data":[{"val":"x"}],"_tidb":{"commitTs":1}}`) + checkpointDataMap := map[string]map[cloudstorage.DmlPathKey]types.IncrementalData{ + "c1": { + {}: { + DataContentSlices: map[cloudstorage.FileIndexKey][][]byte{ + {}: {invalidContent}, + }, + }, + }, + } + + _, err := NewDataChecker(context.Background(), clusterConfig, checkpointDataMap, checkpoint) + require.Error(t, err) + require.Contains(t, err.Error(), "column value of column id not found") +} + func TestNewClusterDataChecker(t *testing.T) { t.Parallel() @@ -403,7 +435,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { t.Run("all consistent", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) base := makeBaseRounds() round2 := map[string]types.TimeWindowData{ @@ -435,7 +468,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { t.Run("data loss detected", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) 
base := makeBaseRounds() // Round 2: c1 has locally-written pk=3 but c2 has NO matching replicated data @@ -475,7 +509,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { t.Run("data inconsistent detected", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) base := makeBaseRounds() // Round 2: c2 has replicated data for pk=3 but with wrong column value @@ -518,7 +553,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { t.Run("data redundant detected", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) base := makeBaseRounds() round2 := map[string]types.TimeWindowData{ @@ -562,7 +598,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { t.Run("lww violation detected", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) base := makeBaseRounds() round2 := map[string]types.TimeWindowData{ @@ -617,7 +654,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { // so a violation introduced in round 1 data should surface immediately. t.Run("lww violation detected at round 1", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) // Round 0: [0, 100] — c1 writes pk=1 (commitTs=50, compareTs=50) round0 := map[string]types.TimeWindowData{ @@ -664,7 +702,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { // and if the replicated counterpart is missing, data loss is detected at round 2. 
t.Run("data loss detected at round 2", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) round0 := map[string]types.TimeWindowData{ "c1": makeTWData(0, 100, nil, nil), @@ -716,7 +755,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { // - Round 3: orphan in [2] and enableDataRedundant=true → flagged. t.Run("data redundant detected at round 3 not round 2", func(t *testing.T) { t.Parallel() - checker := NewDataChecker(ctx, clusterCfg, nil, nil) + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) round0 := map[string]types.TimeWindowData{ "c1": makeTWData(0, 100, nil, nil), diff --git a/cmd/multi-cluster-consistency-checker/checker/failpoint.go b/cmd/multi-cluster-consistency-checker/checker/failpoint.go new file mode 100644 index 0000000000..8dac48cee9 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/checker/failpoint.go @@ -0,0 +1,98 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package checker + +import ( + "encoding/json" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/log" + "go.uber.org/zap" +) + +// envKey is the environment variable that controls the output file path. +const envKey = "TICDC_MULTI_CLUSTER_CONSISTENCY_CHECKER_FAILPOINT_RECORD_FILE" + +// RowRecord captures the essential identity of a single affected row. 
+type RowRecord struct { + CommitTs uint64 `json:"commitTs"` + PrimaryKeys map[string]any `json:"primaryKeys"` +} + +// Record is one line written to the JSONL file. +type Record struct { + Time string `json:"time"` + Failpoint string `json:"failpoint"` + Rows []RowRecord `json:"rows"` +} + +var ( + initOnce sync.Once + mu sync.Mutex + file *os.File + disabled atomic.Bool +) + +func ensureFile() { + initOnce.Do(func() { + path := os.Getenv(envKey) + if path == "" { + disabled.Store(true) + return + } + var err error + file, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + log.Warn("failed to open failpoint record file, recording disabled", + zap.String("path", path), zap.Error(err)) + disabled.Store(true) + return + } + log.Info("failpoint record file opened", zap.String("path", path)) + }) +} + +// Write persists one failpoint event to the JSONL file. +// It is safe for concurrent use. +// If the env var is not set the call is a no-op (zero allocation). 
+func Write(failpoint string, rows []RowRecord) { + if disabled.Load() { + return + } + ensureFile() + if file == nil { + return + } + + rec := Record{ + Time: time.Now().UTC().Format(time.RFC3339Nano), + Failpoint: failpoint, + Rows: rows, + } + data, err := json.Marshal(rec) + if err != nil { + log.Warn("failed to marshal failpoint record", zap.Error(err)) + return + } + data = append(data, '\n') + + mu.Lock() + defer mu.Unlock() + if _, err := file.Write(data); err != nil { + log.Warn("failed to write failpoint record", zap.Error(err)) + } +} diff --git a/cmd/multi-cluster-consistency-checker/config/config_test.go b/cmd/multi-cluster-consistency-checker/config/config_test.go index e57f2e6a7b..a0027efb5f 100644 --- a/cmd/multi-cluster-consistency-checker/config/config_test.go +++ b/cmd/multi-cluster-consistency-checker/config/config_test.go @@ -51,7 +51,7 @@ data-dir = "/tmp/data" [clusters.cluster2.peer-cluster-changefeed-config] cluster1 = { changefeed-id = "cf-2-to-1" } ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -98,7 +98,7 @@ max-report-files = 50 [clusters.cluster2.peer-cluster-changefeed-config] cluster1 = { changefeed-id = "cf-2-to-1" } ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -119,7 +119,7 @@ max-report-files = 50 tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "config.toml") configContent := `invalid toml content [` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -151,7 +151,7 @@ log-level = "info" [clusters.cluster2.peer-cluster-changefeed-config] cluster1 = { changefeed-id = "cf-2-to-1" } ` - err := 
os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -184,7 +184,7 @@ data-dir = "/tmp/data" [clusters.cluster2.peer-cluster-changefeed-config] cluster1 = { changefeed-id = "cf-2-to-1" } ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -219,7 +219,7 @@ data-dir = "/tmp/data" [clusters.cluster2.peer-cluster-changefeed-config] cluster1 = { changefeed-id = "cf-2-to-1" } ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -239,7 +239,7 @@ data-dir = "/tmp/data" [global.tables] schema1 = ["table1"] ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -264,7 +264,7 @@ data-dir = "/tmp/data" s3-sink-uri = "s3://bucket/cluster1/" s3-changefeed-id = "s3-cf-1" ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -289,7 +289,7 @@ data-dir = "/tmp/data" pd-addrs = ["127.0.0.1:2379"] s3-changefeed-id = "s3-cf-1" ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -314,7 +314,7 @@ data-dir = "/tmp/data" pd-addrs = ["127.0.0.1:2379"] s3-sink-uri = "s3://bucket/cluster1/" ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -347,7 +347,7 
@@ data-dir = "/tmp/data" s3-sink-uri = "s3://bucket/cluster2/" s3-changefeed-id = "s3-cf-2" ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) @@ -382,7 +382,7 @@ data-dir = "/tmp/data" [clusters.cluster2.peer-cluster-changefeed-config] cluster1 = { changefeed-id = "cf-2-to-1" } ` - err := os.WriteFile(configPath, []byte(configContent), 0644) + err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) cfg, err := LoadConfig(configPath) diff --git a/cmd/multi-cluster-consistency-checker/integration/integration_test.go b/cmd/multi-cluster-consistency-checker/integration/integration_test.go index 236ebb2568..9a6dbc78be 100644 --- a/cmd/multi-cluster-consistency-checker/integration/integration_test.go +++ b/cmd/multi-cluster-consistency-checker/integration/integration_test.go @@ -65,7 +65,8 @@ func setupEnv(t *testing.T) *testEnv { require.NoError(t, err) clusterCfg := map[string]config.ClusterConfig{"c1": {}, "c2": {}} - dc := checker.NewDataChecker(ctx, clusterCfg, nil, nil) + dc, err := checker.NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, err) return &testEnv{ctx: ctx, mc: mc, advancer: twa, checker: dc} } diff --git a/cmd/multi-cluster-consistency-checker/recorder/recorder.go b/cmd/multi-cluster-consistency-checker/recorder/recorder.go index 3affe7b7dd..e46d07a88d 100644 --- a/cmd/multi-cluster-consistency-checker/recorder/recorder.go +++ b/cmd/multi-cluster-consistency-checker/recorder/recorder.go @@ -44,10 +44,10 @@ type Recorder struct { } func NewRecorder(dataDir string, clusters map[string]config.ClusterConfig, maxReportFiles int) (*Recorder, error) { - if err := os.MkdirAll(filepath.Join(dataDir, "report"), 0755); err != nil { + if err := os.MkdirAll(filepath.Join(dataDir, "report"), 0o755); err != nil { return nil, errors.Trace(err) } - if err := 
os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755); err != nil { + if err := os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755); err != nil { return nil, errors.Trace(err) } if maxReportFiles <= 0 { @@ -143,9 +143,9 @@ func (r *Recorder) RecordTimeWindow(timeWindowData map[string]types.TimeWindowDa log.Info("time window advanced", zap.Uint64("round", report.Round), zap.String("clusterID", clusterID), - zap.Uint64("window left boundary", timeWindow.LeftBoundary), - zap.Uint64("window right boundary", timeWindow.RightBoundary), - zap.Any("checkpoint ts", timeWindow.CheckpointTs)) + zap.Uint64("windowLeftBoundary", timeWindow.LeftBoundary), + zap.Uint64("windowRightBoundary", timeWindow.RightBoundary), + zap.Any("checkpointTs", timeWindow.CheckpointTs)) } if report.NeedFlush() { if err := r.flushReport(report); err != nil { @@ -196,7 +196,7 @@ func atomicWriteFile(targetPath string, data []byte) error { // syncWriteFile writes data to a file and fsyncs it before returning, // guaranteeing that the content is durable on disk. 
func syncWriteFile(path string, data []byte) error { - f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return errors.Trace(err) } diff --git a/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go b/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go index af3a8297b3..75f1209b9a 100644 --- a/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go +++ b/cmd/multi-cluster-consistency-checker/recorder/recorder_test.go @@ -448,11 +448,11 @@ func TestErrCheckpointCorruption(t *testing.T) { dataDir := t.TempDir() // Create report and checkpoint directories - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755)) // Write invalid JSON to checkpoint.json - err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), []byte("{bad json"), 0600) + err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), []byte("{bad json"), 0o600) require.NoError(t, err) _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) @@ -513,11 +513,11 @@ func TestErrCheckpointCorruption(t *testing.T) { dataDir := t.TempDir() // Create directories - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755)) // Create checkpoint.json as a directory so ReadFile fails with a non-corruption I/O error - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), 0755)) + require.NoError(t, 
os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json"), 0o755)) _, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) require.Error(t, err) @@ -530,11 +530,11 @@ func TestErrCheckpointCorruption(t *testing.T) { dataDir := t.TempDir() // Create directories - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755)) // Make checkpoint.json.bak a directory so ReadFile fails with an I/O error - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), 0o755)) _, err := NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) require.Error(t, err) @@ -547,11 +547,11 @@ func TestErrCheckpointCorruption(t *testing.T) { dataDir := t.TempDir() // Create directories - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0755)) - require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "report"), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(dataDir, "checkpoint"), 0o755)) // Write invalid JSON to checkpoint.json.bak (simulate crash recovery with corrupted backup) - err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), []byte("not valid json"), 0600) + err := os.WriteFile(filepath.Join(dataDir, "checkpoint", "checkpoint.json.bak"), []byte("not valid json"), 0o600) require.NoError(t, err) _, err = NewRecorder(dataDir, map[string]config.ClusterConfig{"c1": {}}, 0) diff --git a/cmd/multi-cluster-consistency-checker/task.go b/cmd/multi-cluster-consistency-checker/task.go index 1bb8c35d86..fa93ffcd55 100644 --- 
a/cmd/multi-cluster-consistency-checker/task.go +++ b/cmd/multi-cluster-consistency-checker/task.go @@ -64,7 +64,10 @@ func runTask(ctx context.Context, cfg *config.Config, dryRun bool) error { if err != nil { return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} } - dataChecker := checker.NewDataChecker(ctx, cfg.Clusters, checkpointDataMap, rec.GetCheckpoint()) + dataChecker, err := checker.NewDataChecker(ctx, cfg.Clusters, checkpointDataMap, rec.GetCheckpoint()) + if err != nil { + return &ExitError{Code: ExitCodeTransient, Err: errors.Trace(err)} + } log.Info("Starting consistency checker task") for { diff --git a/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go index 6aa05c17f6..1fb96771cd 100644 --- a/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go +++ b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher.go @@ -67,6 +67,15 @@ type CheckpointWatcher struct { closed bool } +// failPendingTasksLocked wakes all pending tasks after a terminal watcher failure. +// Must be called with mu locked. 
+func (cw *CheckpointWatcher) failPendingTasksLocked() { + for _, task := range cw.pendingTasks { + close(task.respCh) + } + cw.pendingTasks = nil +} + func NewCheckpointWatcher( ctx context.Context, localClusterID, replicatedClusterID, changefeedID string, @@ -135,7 +144,17 @@ func (cw *CheckpointWatcher) AdvanceCheckpointTs(ctx context.Context, minCheckpo return 0, errors.Annotate(cw.ctx.Err(), "watcher context canceled") case checkpoint, ok := <-task.respCh: if !ok { - return 0, errors.Errorf("checkpoint watcher is closed") + cw.mu.Lock() + err := cw.watchErr + closed := cw.closed + cw.mu.Unlock() + if err != nil { + return 0, err + } + if closed { + return 0, errors.Errorf("checkpoint watcher is closed") + } + return 0, errors.Errorf("checkpoint watcher failed") } return checkpoint, nil } @@ -176,6 +195,7 @@ func (cw *CheckpointWatcher) run() { if errors.Is(err, errChangefeedKeyDeleted) { cw.mu.Lock() cw.watchErr = err + cw.failPendingTasksLocked() cw.mu.Unlock() return } @@ -223,10 +243,10 @@ func (cw *CheckpointWatcher) watchOnce() error { statusKey := etcd.GetEtcdKeyJob(cw.etcdClient.GetClusterID(), cw.changefeedID.DisplayName) log.Debug("Starting to watch checkpoint", - zap.String("changefeed ID", cw.changefeedID.String()), + zap.String("changefeedID", cw.changefeedID.String()), zap.String("statusKey", statusKey), - zap.String("local cluster ID", cw.localClusterID), - zap.String("replicated cluster ID", cw.replicatedClusterID), + zap.String("localClusterID", cw.localClusterID), + zap.String("replicatedClusterID", cw.replicatedClusterID), zap.Uint64("checkpoint", status.CheckpointTs), zap.Int64("startRev", modRev+1)) diff --git a/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go index c2f0055ad0..4862ce2990 100644 --- a/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go +++ 
b/cmd/multi-cluster-consistency-checker/watcher/checkpoint_watcher_test.go @@ -541,3 +541,61 @@ func TestCheckpointWatcher_KeyDeleted(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "deleted") } + +func TestCheckpointWatcher_KeyDeleted_UnblocksPendingTasks(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEtcdClient := etcd.NewMockCDCEtcdClient(ctrl) + mockClient := etcd.NewMockClient(ctrl) + + initialCheckpoint := uint64(1000) + + mockEtcdClient.EXPECT().GetChangeFeedStatus(gomock.Any(), gomock.Any()).Return( + &config.ChangeFeedStatus{CheckpointTs: initialCheckpoint}, + int64(100), + nil, + ) + mockEtcdClient.EXPECT().GetClusterID().Return("test-cluster").AnyTimes() + mockEtcdClient.EXPECT().GetEtcdClient().Return(mockClient).AnyTimes() + + watchCh := make(chan clientv3.WatchResponse, 1) + mockClient.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(watchCh) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + watcher := NewCheckpointWatcher(ctx, "local-1", "replicated-1", "test-cf", mockEtcdClient) + defer watcher.Close() + + time.Sleep(50 * time.Millisecond) + + var ( + checkpoint uint64 + advanceErr error + ) + done := make(chan struct{}) + go func() { + checkpoint, advanceErr = watcher.AdvanceCheckpointTs(context.Background(), uint64(2000)) + close(done) + }() + + time.Sleep(50 * time.Millisecond) + + watchCh <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: clientv3.EventTypeDelete, + }, + }, + } + + select { + case <-done: + require.Zero(t, checkpoint) + require.Error(t, advanceErr) + require.Contains(t, advanceErr.Error(), "deleted") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for pending task to be unblocked") + } +} diff --git a/pkg/sink/cloudstorage/path_key.go b/pkg/sink/cloudstorage/path_key.go index d948af9873..3de0512400 100644 --- a/pkg/sink/cloudstorage/path_key.go +++ b/pkg/sink/cloudstorage/path_key.go @@ 
-162,3 +162,72 @@ func (d *DmlPathKey) ParseIndexFilePath(dateSeparator, path string) (string, err return matches[6], nil } + +// ParseDMLFilePath parses the dml file path and returns the max file index. +// DML file path pattern is as follows: +// {schema}/{table}/{table-version-separator}/{partition-separator}/{date-separator}/, where +// partition-separator and date-separator could be empty. +// DML file name pattern is as follows: CDC_{dispatcherID}_{num}.extension or CDC{num}.extension +func (d *DmlPathKey) ParseDMLFilePath(dateSeparator, path string) (*FileIndex, error) { + var partitionNum int64 + + str := `(\w+)\/(\w+)\/(\d+)\/(\d+)?\/*` + switch dateSeparator { + case config.DateSeparatorNone.String(): + str += `(\d{4})*` + case config.DateSeparatorYear.String(): + str += `(\d{4})\/` + case config.DateSeparatorMonth.String(): + str += `(\d{4}-\d{2})\/` + case config.DateSeparatorDay.String(): + str += `(\d{4}-\d{2}-\d{2})\/` + } + matchesLen := 8 + matchesFileIdx := 7 + // CDC[_{dispatcherID}_]{num}.extension + str += `CDC(?:_(\w+)_)?(\d+).\w+` + pathRE, err := regexp.Compile(str) + if err != nil { + return nil, err + } + + matches := pathRE.FindStringSubmatch(path) + if len(matches) != matchesLen { + return nil, fmt.Errorf("cannot match dml path pattern for %s", path) + } + + version, err := strconv.ParseUint(matches[3], 10, 64) + if err != nil { + return nil, err + } + + if len(matches[4]) > 0 { + partitionNum, err = strconv.ParseInt(matches[4], 10, 64) + if err != nil { + return nil, err + } + } + + fileIdx, err := strconv.ParseUint(strings.TrimLeft(matches[matchesFileIdx], "0"), 10, 64) + if err != nil { + return nil, err + } + + *d = DmlPathKey{ + SchemaPathKey: SchemaPathKey{ + Schema: matches[1], + Table: matches[2], + TableVersion: version, + }, + PartitionNum: partitionNum, + Date: matches[5], + } + + return &FileIndex{ + FileIndexKey: FileIndexKey{ + DispatcherID: matches[6], + EnableTableAcrossNodes: matches[6] != "", + }, + Idx: fileIdx, + }, 
nil +} diff --git a/pkg/sink/cloudstorage/path_key_test.go b/pkg/sink/cloudstorage/path_key_test.go index f768f7cf74..f59145df7b 100644 --- a/pkg/sink/cloudstorage/path_key_test.go +++ b/pkg/sink/cloudstorage/path_key_test.go @@ -59,7 +59,7 @@ func TestSchemaPathKey(t *testing.T) { } } -func TestDmlPathKey(t *testing.T) { +func TestIndexFileKey(t *testing.T) { t.Parallel() dispatcherID := common.NewDispatcherID() @@ -107,3 +107,42 @@ func TestDmlPathKey(t *testing.T) { require.Equal(t, tc.path, fileName) } } + +func TestDmlPathKey(t *testing.T) { + t.Parallel() + + dispatcherID := common.NewDispatcherID() + testCases := []struct { + index int + fileIndexWidth int + extension string + path string + dmlkey DmlPathKey + }{ + { + index: 10, + fileIndexWidth: 20, + extension: ".csv", + path: fmt.Sprintf("schema1/table1/123456/2023-05-09/CDC_%s_00000000000000000010.csv", dispatcherID.String()), + dmlkey: DmlPathKey{ + SchemaPathKey: SchemaPathKey{ + Schema: "schema1", + Table: "table1", + TableVersion: 123456, + }, + PartitionNum: 0, + Date: "2023-05-09", + }, + }, + } + + for _, tc := range testCases { + var dmlkey DmlPathKey + fileIndex, err := dmlkey.ParseDMLFilePath("day", tc.path) + require.NoError(t, err) + require.Equal(t, tc.dmlkey, dmlkey) + + fileName := dmlkey.GenerateDMLFilePath(fileIndex, tc.extension, tc.fileIndexWidth) + require.Equal(t, tc.path, fileName) + } +} From f7392af35930e9f28c855e1162141283b9ae9e1a Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Tue, 24 Feb 2026 14:55:08 +0800 Subject: [PATCH 03/10] make clean Signed-off-by: Jianjun Liao --- cmd/multi-cluster-consistency-checker/checker/failpoint.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/multi-cluster-consistency-checker/checker/failpoint.go b/cmd/multi-cluster-consistency-checker/checker/failpoint.go index 8dac48cee9..7104645d33 100644 --- a/cmd/multi-cluster-consistency-checker/checker/failpoint.go +++ b/cmd/multi-cluster-consistency-checker/checker/failpoint.go 
@@ -55,7 +55,7 @@ func ensureFile() { return } var err error - file, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + file, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { log.Warn("failed to open failpoint record file, recording disabled", zap.String("path", path), zap.Error(err)) From a7bb1359bd4714109b3de33ba5352e0466aaf151 Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Tue, 24 Feb 2026 15:16:23 +0800 Subject: [PATCH 04/10] make clean Signed-off-by: Jianjun Liao --- cmd/multi-cluster-consistency-checker/consumer/consumer.go | 1 + cmd/multi-cluster-consistency-checker/decoder/decoder.go | 2 +- cmd/multi-cluster-consistency-checker/recorder/recorder.go | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer.go b/cmd/multi-cluster-consistency-checker/consumer/consumer.go index 07736f7e67..2f695f0867 100644 --- a/cmd/multi-cluster-consistency-checker/consumer/consumer.go +++ b/cmd/multi-cluster-consistency-checker/consumer/consumer.go @@ -626,6 +626,7 @@ func (c *S3Consumer) downloadDMLFiles( fileContents := make(chan fileContent, len(tasks)) eg, egCtx := errgroup.WithContext(ctx) + eg.SetLimit(128) for _, task := range tasks { eg.Go(func() error { filePath := task.dmlPathKey.GenerateDMLFilePath( diff --git a/cmd/multi-cluster-consistency-checker/decoder/decoder.go b/cmd/multi-cluster-consistency-checker/decoder/decoder.go index af0f8d01d0..d03c86379f 100644 --- a/cmd/multi-cluster-consistency-checker/decoder/decoder.go +++ b/cmd/multi-cluster-consistency-checker/decoder/decoder.go @@ -188,7 +188,7 @@ func (d *columnValueDecoder) tryNext() (common.MessageType, bool) { } if err := json.Unmarshal(encodedData, msg); err != nil { - log.Error("canal-json decoder unmarshal data failed", + log.Error("canal json decoder unmarshal data failed", zap.Error(err), zap.ByteString("data", encodedData)) d.msg = nil return common.MessageTypeUnknown, true 
diff --git a/cmd/multi-cluster-consistency-checker/recorder/recorder.go b/cmd/multi-cluster-consistency-checker/recorder/recorder.go index e46d07a88d..06fb7c69bd 100644 --- a/cmd/multi-cluster-consistency-checker/recorder/recorder.go +++ b/cmd/multi-cluster-consistency-checker/recorder/recorder.go @@ -113,6 +113,8 @@ func (r *Recorder) initializeCheckpoint() error { return errors.Annotatef(ErrCheckpointCorruption, "failed to unmarshal checkpoint.json: %v", err) } return nil + } else if !os.IsNotExist(err) { + return errors.Annotatef(ErrCheckpointCorruption, "failed to stat checkpoint.json: %v", err) } // checkpoint.json is missing — try recovering from the backup. @@ -132,6 +134,8 @@ func (r *Recorder) initializeCheckpoint() error { return errors.Trace(err) // transient I/O error } return nil + } else if !os.IsNotExist(err) { + return errors.Annotatef(ErrCheckpointCorruption, "failed to stat checkpoint.json.bak: %v", err) } // Neither file exists — fresh start. From 1b51f1d63c416360dd9c6acb5c2bbceab788a0a9 Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Wed, 25 Feb 2026 17:26:31 +0800 Subject: [PATCH 05/10] update Signed-off-by: Jianjun Liao --- .../checker/checker.go | 28 +++++++---- .../checker/checker_test.go | 46 +++++++++++++++++-- .../checker/failpoint.go | 1 + .../recorder/types.go | 32 ++++++------- .../recorder/types_test.go | 37 +++++++-------- 5 files changed, 94 insertions(+), 50 deletions(-) diff --git a/cmd/multi-cluster-consistency-checker/checker/checker.go b/cmd/multi-cluster-consistency-checker/checker/checker.go index 3f35372798..b36e415aea 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker.go +++ b/cmd/multi-cluster-consistency-checker/checker/checker.go @@ -314,16 +314,24 @@ func (cd *clusterDataChecker) findClusterReplicatedDataInTimeWindow(timeWindowId if !exists { return nil, false } - records, exists := tableDataCache.replicatedDataCache[pk] - if !exists { - return nil, false - } - if record, exists := records[originTs]; 
exists { - return record, false + if records, exists := tableDataCache.replicatedDataCache[pk]; exists { + if record, exists := records[originTs]; exists { + return record, false + } + for _, record := range records { + if record.GetCompareTs() >= originTs { + return nil, true + } + } } - for _, record := range records { - if record.GetCompareTs() >= originTs { - return nil, true + // If no replicated record is found, a newer/equal local record with the + // same PK in the peer cluster also indicates this origin write can be + // considered overwritten by LWW semantics instead of hard data loss. + if records, exists := tableDataCache.localDataCache[pk]; exists { + for _, record := range records { + if record.GetCompareTs() >= originTs { + return nil, true + } } } return nil, false @@ -440,7 +448,7 @@ func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( zap.String("localClusterID", cd.clusterID), zap.String("replicatedClusterID", replicatedClusterID), zap.Any("record", record)) - cd.report.AddDataInconsistentItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, replicatedRecord.OriginTs, record.CommitTs, replicatedRecord.CommitTs, diffColumns(record, replicatedRecord)) + cd.report.AddDataInconsistentItem(replicatedClusterID, schemaKey, record.PkMap, record.PkStr, record.CommitTs, replicatedRecord.CommitTs, diffColumns(record, replicatedRecord)) } } } diff --git a/cmd/multi-cluster-consistency-checker/checker/checker_test.go b/cmd/multi-cluster-consistency-checker/checker/checker_test.go index 589fd52129..f1cc67aa49 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker_test.go +++ b/cmd/multi-cluster-consistency-checker/checker/checker_test.go @@ -501,12 +501,49 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { tableItems := c1Report.TableFailureItems[defaultSchemaKey] require.Len(t, tableItems.DataLossItems, 1) require.Equal(t, "c2", tableItems.DataLossItems[0].PeerClusterID) - require.Equal(t, uint64(250), 
tableItems.DataLossItems[0].CommitTS) + require.Equal(t, uint64(250), tableItems.DataLossItems[0].LocalCommitTS) // c2 should have no issues c2Report := lastReport.ClusterReports["c2"] require.Empty(t, c2Report.TableFailureItems) }) + t.Run("no data loss when peer local newer version exists for same pk", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + base := makeBaseRounds() + + // Round 2: c1 and c2 both write the same PK locally. + // c2 does not contain a replicated row with originTs=250, but has a newer + // local row (commitTs=270) for the same PK. This should be treated as LWW + // overwrite/skip instead of data loss. + round2 := map[string]types.TimeWindowData{ + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, + makeContent(makeCanalJSON(3, 250, 0, "c1-local"))), + "c2": makeTWData(200, 300, nil, + makeContent(makeCanalJSON(3, 270, 0, "c2-local"))), + } + round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(4, 350, 0, "d"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(4, 360, 350, "d"))), + } + + rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} + var lastReport *recorder.Report + for i, roundData := range rounds { + report, err := checker.CheckInNextTimeWindow(roundData) + require.NoError(t, err, "round %d", i) + lastReport = report + } + + c1Report := lastReport.ClusterReports["c1"] + if tableItems, ok := c1Report.TableFailureItems[defaultSchemaKey]; ok { + require.Empty(t, tableItems.DataLossItems) + } + }) + t.Run("data inconsistent detected", func(t *testing.T) { t.Parallel() checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) @@ -542,7 +579,6 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { require.Empty(t, tableItems.DataLossItems) require.Len(t, tableItems.DataInconsistentItems, 1) require.Equal(t, "c2", 
tableItems.DataInconsistentItems[0].PeerClusterID) - require.Equal(t, uint64(250), tableItems.DataInconsistentItems[0].OriginTS) require.Equal(t, uint64(250), tableItems.DataInconsistentItems[0].LocalCommitTS) require.Equal(t, uint64(260), tableItems.DataInconsistentItems[0].ReplicatedCommitTS) require.Len(t, tableItems.DataInconsistentItems[0].InconsistentColumns, 1) @@ -593,7 +629,7 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { tableItems := c2Report.TableFailureItems[defaultSchemaKey] require.Len(t, tableItems.DataRedundantItems, 1) require.Equal(t, uint64(330), tableItems.DataRedundantItems[0].OriginTS) - require.Equal(t, uint64(340), tableItems.DataRedundantItems[0].CommitTS) + require.Equal(t, uint64(340), tableItems.DataRedundantItems[0].ReplicatedCommitTS) }) t.Run("lww violation detected", func(t *testing.T) { @@ -740,7 +776,7 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { tableItems := c1Report.TableFailureItems[defaultSchemaKey] require.Len(t, tableItems.DataLossItems, 1) require.Equal(t, "c2", tableItems.DataLossItems[0].PeerClusterID) - require.Equal(t, uint64(150), tableItems.DataLossItems[0].CommitTS) + require.Equal(t, uint64(150), tableItems.DataLossItems[0].LocalCommitTS) }) // data redundant detected at round 3 (not round 2): @@ -815,6 +851,6 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { c2TableItems := c2Report.TableFailureItems[defaultSchemaKey] require.Len(t, c2TableItems.DataRedundantItems, 1) require.Equal(t, uint64(330), c2TableItems.DataRedundantItems[0].OriginTS) - require.Equal(t, uint64(340), c2TableItems.DataRedundantItems[0].CommitTS) + require.Equal(t, uint64(340), c2TableItems.DataRedundantItems[0].ReplicatedCommitTS) }) } diff --git a/cmd/multi-cluster-consistency-checker/checker/failpoint.go b/cmd/multi-cluster-consistency-checker/checker/failpoint.go index 7104645d33..0339f9c389 100644 --- a/cmd/multi-cluster-consistency-checker/checker/failpoint.go +++ 
b/cmd/multi-cluster-consistency-checker/checker/failpoint.go @@ -30,6 +30,7 @@ const envKey = "TICDC_MULTI_CLUSTER_CONSISTENCY_CHECKER_FAILPOINT_RECORD_FILE" // RowRecord captures the essential identity of a single affected row. type RowRecord struct { CommitTs uint64 `json:"commitTs"` + OriginTs uint64 `json:"originTs"` PrimaryKeys map[string]any `json:"primaryKeys"` } diff --git a/cmd/multi-cluster-consistency-checker/recorder/types.go b/cmd/multi-cluster-consistency-checker/recorder/types.go index 2c316fc965..a3c8cbd234 100644 --- a/cmd/multi-cluster-consistency-checker/recorder/types.go +++ b/cmd/multi-cluster-consistency-checker/recorder/types.go @@ -25,13 +25,13 @@ import ( type DataLossItem struct { PeerClusterID string `json:"peer_cluster_id"` PK map[string]any `json:"pk"` - CommitTS uint64 `json:"commit_ts"` + LocalCommitTS uint64 `json:"local_commit_ts"` PKStr string `json:"-"` } func (item *DataLossItem) String() string { - return fmt.Sprintf("peer cluster: %s, pk: %s, commit ts: %d", item.PeerClusterID, item.PKStr, item.CommitTS) + return fmt.Sprintf("peer cluster: %s, pk: %s, local commit ts: %d", item.PeerClusterID, item.PKStr, item.LocalCommitTS) } type InconsistentColumn struct { @@ -47,7 +47,6 @@ func (c *InconsistentColumn) String() string { type DataInconsistentItem struct { PeerClusterID string `json:"peer_cluster_id"` PK map[string]any `json:"pk"` - OriginTS uint64 `json:"origin_ts"` LocalCommitTS uint64 `json:"local_commit_ts"` ReplicatedCommitTS uint64 `json:"replicated_commit_ts"` InconsistentColumns []InconsistentColumn `json:"inconsistent_columns,omitempty"` @@ -57,8 +56,8 @@ type DataInconsistentItem struct { func (item *DataInconsistentItem) String() string { var sb strings.Builder - fmt.Fprintf(&sb, "peer cluster: %s, pk: %s, origin ts: %d, local commit ts: %d, replicated commit ts: %d", - item.PeerClusterID, item.PKStr, item.OriginTS, item.LocalCommitTS, item.ReplicatedCommitTS) + fmt.Fprintf(&sb, "peer cluster: %s, pk: %s, local 
commit ts: %d, replicated commit ts: %d", + item.PeerClusterID, item.PKStr, item.LocalCommitTS, item.ReplicatedCommitTS) if len(item.InconsistentColumns) > 0 { sb.WriteString(", inconsistent columns: [") for i, col := range item.InconsistentColumns { @@ -73,15 +72,15 @@ func (item *DataInconsistentItem) String() string { } type DataRedundantItem struct { - PK map[string]any `json:"pk"` - OriginTS uint64 `json:"origin_ts"` - CommitTS uint64 `json:"commit_ts"` + PK map[string]any `json:"pk"` + OriginTS uint64 `json:"origin_ts"` + ReplicatedCommitTS uint64 `json:"replicated_commit_ts"` PKStr string `json:"-"` } func (item *DataRedundantItem) String() string { - return fmt.Sprintf("pk: %s, origin ts: %d, commit ts: %d", item.PKStr, item.OriginTS, item.CommitTS) + return fmt.Sprintf("pk: %s, origin ts: %d, replicated commit ts: %d", item.PKStr, item.OriginTS, item.ReplicatedCommitTS) } type LWWViolationItem struct { @@ -139,7 +138,7 @@ func (r *ClusterReport) AddDataLossItem( peerClusterID, schemaKey string, pk map[string]any, pkStr string, - commitTS uint64, + localCommitTS uint64, ) { tableFailureItems, exists := r.TableFailureItems[schemaKey] if !exists { @@ -149,7 +148,7 @@ func (r *ClusterReport) AddDataLossItem( tableFailureItems.DataLossItems = append(tableFailureItems.DataLossItems, DataLossItem{ PeerClusterID: peerClusterID, PK: pk, - CommitTS: commitTS, + LocalCommitTS: localCommitTS, PKStr: pkStr, }) @@ -160,7 +159,7 @@ func (r *ClusterReport) AddDataInconsistentItem( peerClusterID, schemaKey string, pk map[string]any, pkStr string, - originTS, localCommitTS, replicatedCommitTS uint64, + localCommitTS, replicatedCommitTS uint64, inconsistentColumns []InconsistentColumn, ) { tableFailureItems, exists := r.TableFailureItems[schemaKey] @@ -171,7 +170,6 @@ func (r *ClusterReport) AddDataInconsistentItem( tableFailureItems.DataInconsistentItems = append(tableFailureItems.DataInconsistentItems, DataInconsistentItem{ PeerClusterID: peerClusterID, PK: pk, - OriginTS: 
originTS, LocalCommitTS: localCommitTS, ReplicatedCommitTS: replicatedCommitTS, InconsistentColumns: inconsistentColumns, @@ -185,7 +183,7 @@ func (r *ClusterReport) AddDataRedundantItem( schemaKey string, pk map[string]any, pkStr string, - originTS, commitTS uint64, + originTS, replicatedCommitTS uint64, ) { tableFailureItems, exists := r.TableFailureItems[schemaKey] if !exists { @@ -193,9 +191,9 @@ func (r *ClusterReport) AddDataRedundantItem( r.TableFailureItems[schemaKey] = tableFailureItems } tableFailureItems.DataRedundantItems = append(tableFailureItems.DataRedundantItems, DataRedundantItem{ - PK: pk, - OriginTS: originTS, - CommitTS: commitTS, + PK: pk, + OriginTS: originTS, + ReplicatedCommitTS: replicatedCommitTS, PKStr: pkStr, }) diff --git a/cmd/multi-cluster-consistency-checker/recorder/types_test.go b/cmd/multi-cluster-consistency-checker/recorder/types_test.go index ec2df79d3c..53887ef29d 100644 --- a/cmd/multi-cluster-consistency-checker/recorder/types_test.go +++ b/cmd/multi-cluster-consistency-checker/recorder/types_test.go @@ -26,11 +26,11 @@ func TestDataLossItem_String(t *testing.T) { item := &DataLossItem{ PeerClusterID: "cluster-2", PK: map[string]any{"id": "1"}, - CommitTS: 200, + LocalCommitTS: 200, PKStr: `[id: 1]`, } s := item.String() - require.Equal(t, `peer cluster: cluster-2, pk: [id: 1], commit ts: 200`, s) + require.Equal(t, `peer cluster: cluster-2, pk: [id: 1], local commit ts: 200`, s) } func TestDataInconsistentItem_String(t *testing.T) { @@ -41,13 +41,12 @@ func TestDataInconsistentItem_String(t *testing.T) { item := &DataInconsistentItem{ PeerClusterID: "cluster-3", PK: map[string]any{"id": "2"}, - OriginTS: 300, LocalCommitTS: 400, ReplicatedCommitTS: 410, PKStr: `[id: 2]`, } s := item.String() - require.Equal(t, `peer cluster: cluster-3, pk: [id: 2], origin ts: 300, local commit ts: 400, replicated commit ts: 410`, s) + require.Equal(t, `peer cluster: cluster-3, pk: [id: 2], local commit ts: 400, replicated commit ts: 410`, 
s) }) t.Run("with inconsistent columns", func(t *testing.T) { @@ -55,7 +54,6 @@ func TestDataInconsistentItem_String(t *testing.T) { item := &DataInconsistentItem{ PeerClusterID: "cluster-3", PK: map[string]any{"id": "2"}, - OriginTS: 300, LocalCommitTS: 400, ReplicatedCommitTS: 410, PKStr: `[id: 2]`, @@ -66,7 +64,7 @@ func TestDataInconsistentItem_String(t *testing.T) { } s := item.String() require.Equal(t, - `peer cluster: cluster-3, pk: [id: 2], origin ts: 300, local commit ts: 400, replicated commit ts: 410, `+ + `peer cluster: cluster-3, pk: [id: 2], local commit ts: 400, replicated commit ts: 410, `+ "inconsistent columns: [column: col1, local: val_a, replicated: val_b; column: col2, local: 100, replicated: 200]", s) }) @@ -76,7 +74,6 @@ func TestDataInconsistentItem_String(t *testing.T) { item := &DataInconsistentItem{ PeerClusterID: "cluster-3", PK: map[string]any{"id": "2"}, - OriginTS: 300, LocalCommitTS: 400, ReplicatedCommitTS: 410, PKStr: `[id: 2]`, @@ -86,7 +83,7 @@ func TestDataInconsistentItem_String(t *testing.T) { } s := item.String() require.Equal(t, - `peer cluster: cluster-3, pk: [id: 2], origin ts: 300, local commit ts: 400, replicated commit ts: 410, `+ + `peer cluster: cluster-3, pk: [id: 2], local commit ts: 400, replicated commit ts: 410, `+ "inconsistent columns: [column: col1, local: val_a, replicated: ]", s) }) @@ -94,9 +91,14 @@ func TestDataInconsistentItem_String(t *testing.T) { func TestDataRedundantItem_String(t *testing.T) { t.Parallel() - item := &DataRedundantItem{PK: map[string]any{"id": "x"}, PKStr: `[id: x]`, OriginTS: 10, CommitTS: 20} + item := &DataRedundantItem{ + PK: map[string]any{"id": "x"}, + PKStr: `[id: x]`, + OriginTS: 10, + ReplicatedCommitTS: 20, + } s := item.String() - require.Equal(t, `pk: [id: x], origin ts: 10, commit ts: 20`, s) + require.Equal(t, `pk: [id: x], origin ts: 10, replicated commit ts: 20`, s) } func TestLWWViolationItem_String(t *testing.T) { @@ -137,7 +139,7 @@ func TestClusterReport(t 
*testing.T) { require.True(t, cr.needFlush) require.Equal(t, "peer-cluster-1", tableItems.DataLossItems[0].PeerClusterID) require.Equal(t, map[string]any{"id": "1"}, tableItems.DataLossItems[0].PK) - require.Equal(t, uint64(200), tableItems.DataLossItems[0].CommitTS) + require.Equal(t, uint64(200), tableItems.DataLossItems[0].LocalCommitTS) }) t.Run("add data inconsistent item sets needFlush", func(t *testing.T) { @@ -146,7 +148,7 @@ func TestClusterReport(t *testing.T) { cols := []InconsistentColumn{ {Column: "val", Local: "a", Replicated: "b"}, } - cr.AddDataInconsistentItem("peer-cluster-2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 300, 400, 410, cols) + cr.AddDataInconsistentItem("peer-cluster-2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 400, 410, cols) require.Len(t, cr.TableFailureItems, 1) require.Contains(t, cr.TableFailureItems, testSchemaKey) tableItems := cr.TableFailureItems[testSchemaKey] @@ -154,7 +156,6 @@ func TestClusterReport(t *testing.T) { require.True(t, cr.needFlush) require.Equal(t, "peer-cluster-2", tableItems.DataInconsistentItems[0].PeerClusterID) require.Equal(t, map[string]any{"id": "2"}, tableItems.DataInconsistentItems[0].PK) - require.Equal(t, uint64(300), tableItems.DataInconsistentItems[0].OriginTS) require.Equal(t, uint64(400), tableItems.DataInconsistentItems[0].LocalCommitTS) require.Equal(t, uint64(410), tableItems.DataInconsistentItems[0].ReplicatedCommitTS) require.Len(t, tableItems.DataInconsistentItems[0].InconsistentColumns, 1) @@ -191,7 +192,7 @@ func TestClusterReport(t *testing.T) { t.Parallel() cr := NewClusterReport("c1", types.TimeWindow{}) cr.AddDataLossItem("d1", testSchemaKey, map[string]any{"id": "1"}, `id: 1`, 2) - cr.AddDataInconsistentItem("d2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 3, 4, 5, nil) + cr.AddDataInconsistentItem("d2", testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 4, 5, nil) cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "3"}, `[id: 3]`, 5, 
6) cr.AddLWWViolationItem(testSchemaKey, map[string]any{"id": "4"}, `[id: 4]`, 7, 8, 9, 10) require.Len(t, cr.TableFailureItems, 1) @@ -285,7 +286,7 @@ func TestReport_MarshalReport(t *testing.T) { "time window: "+twStr+"\n"+ " - [table name: "+testSchemaKey+"]\n"+ " - [data redundant items: 1]\n"+ - ` - [pk: [id: r], origin ts: 10, commit ts: 20]`+"\n\n", + ` - [pk: [id: r], origin ts: 10, replicated commit ts: 20]`+"\n\n", s) }) @@ -328,7 +329,7 @@ func TestReport_MarshalReport(t *testing.T) { r := NewReport(10) cr := NewClusterReport("c1", tw) cr.AddDataLossItem("d0", testSchemaKey, map[string]any{"id": "0"}, `[id: 0]`, 1) - cr.AddDataInconsistentItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 1, 2, 3, []InconsistentColumn{ + cr.AddDataInconsistentItem("d1", testSchemaKey, map[string]any{"id": "1"}, `[id: 1]`, 2, 3, []InconsistentColumn{ {Column: "val", Local: "x", Replicated: "y"}, }) cr.AddDataRedundantItem(testSchemaKey, map[string]any{"id": "2"}, `[id: 2]`, 3, 4) @@ -342,9 +343,9 @@ func TestReport_MarshalReport(t *testing.T) { " - [data loss items: 1]\n"+ ` - [peer cluster: d0, pk: [id: 0], commit ts: 1]`+"\n"+ " - [data inconsistent items: 1]\n"+ - ` - [peer cluster: d1, pk: [id: 1], origin ts: 1, local commit ts: 2, replicated commit ts: 3, inconsistent columns: [column: val, local: x, replicated: y]]`+"\n"+ + ` - [peer cluster: d1, pk: [id: 1], local commit ts: 2, replicated commit ts: 3, inconsistent columns: [column: val, local: x, replicated: y]]`+"\n"+ " - [data redundant items: 1]\n"+ - ` - [pk: [id: 2], origin ts: 3, commit ts: 4]`+"\n"+ + ` - [pk: [id: 2], origin ts: 3, replicated commit ts: 4]`+"\n"+ " - [lww violation items: 1]\n"+ ` - [pk: [id: 3], existing origin ts: 5, existing commit ts: 6, origin ts: 7, commit ts: 8]`+"\n\n", s) From 7e86f696327db58e2568a3130cb81ffa4970accd Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Wed, 25 Feb 2026 20:01:08 +0800 Subject: [PATCH 06/10] add failpoint validation Signed-off-by: Jianjun 
Liao --- .../integration/validation_test.go | 315 ++++++++++++++++++ .../recorder/types_test.go | 6 +- 2 files changed, 318 insertions(+), 3 deletions(-) create mode 100644 cmd/multi-cluster-consistency-checker/integration/validation_test.go diff --git a/cmd/multi-cluster-consistency-checker/integration/validation_test.go b/cmd/multi-cluster-consistency-checker/integration/validation_test.go new file mode 100644 index 0000000000..97a6ef94d8 --- /dev/null +++ b/cmd/multi-cluster-consistency-checker/integration/validation_test.go @@ -0,0 +1,315 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "testing" + + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/checker" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/stretchr/testify/require" +) + +const ( + envKeyCDCOutput = "ACTIVE_ACTIVE_FAILPOINT_CDC_OUTPUT" + + envKeyCheckerOutput = "ACTIVE_ACTIVE_FAILPOINT_CHECKER_OUTPUT" + + envKeyReportDir = "ACTIVE_ACTIVE_FAILPOINT_REPORT_DIR" +) + +func readRecordJSONL(path string) ([]checker.Record, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + var records []checker.Record + scanner := bufio.NewScanner(file) + // JSONL lines may contain large row sets; increase scanner limit to avoid + // "bufio.Scanner: token too long". 
+ scanner.Buffer(make([]byte, 0, 1024*1024), 64*1024*1024) + lineNo := 0 + for scanner.Scan() { + lineNo++ + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var record checker.Record + err := json.Unmarshal([]byte(line), &record) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal jsonl record in %s line %d: %w", path, lineNo, err) + } + records = append(records, record) + } + + err = scanner.Err() + if err != nil { + return nil, fmt.Errorf("failed to scan jsonl file %s: %w", path, err) + } + return records, nil +} + +func getCdcRecords(t *testing.T, cdcOutputPath string) []checker.Record { + records, err := readRecordJSONL(cdcOutputPath) + require.NoError(t, err) + return records +} + +func getCheckerRecords(t *testing.T, checkerOutputPath string) []checker.Record { + records, err := readRecordJSONL(checkerOutputPath) + require.NoError(t, err) + return records +} + +func readReports(t *testing.T, reportDir string) []recorder.Report { + entries, err := os.ReadDir(reportDir) + require.NoError(t, err) + + var reportFiles []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + if strings.HasSuffix(entry.Name(), ".json") { + reportFiles = append(reportFiles, filepath.Join(reportDir, entry.Name())) + } + } + sort.Strings(reportFiles) + + var reports []recorder.Report + for _, path := range reportFiles { + content, err := os.ReadFile(path) + require.NoError(t, err) + + var report recorder.Report + err = json.Unmarshal(content, &report) + require.NoError(t, err) + reports = append(reports, report) + } + return reports +} + +type dataLossKey struct { + PK string + LocalCommitTS uint64 +} + +type dataInconsistentKey struct { + PK string + LocalCommitTS uint64 +} + +type dataRedundantKey struct { + PK string + OriginTS uint64 +} + +type lwwViolationKey struct { + PK string + OriginTS uint64 +} + +func pkMapToKey(pk map[string]any) string { + keys := make([]string, 0, len(pk)) + for key := range pk { + keys = 
// normalizePKValue renders a primary-key value as a canonical string so that
// the same key encoded differently across sources (e.g. "007", json.Number,
// or a native integer) compares equal.
func normalizePKValue(v any) string {
	switch value := v.(type) {
	case string:
		// Numeric strings are canonicalized, dropping leading zeros and "+".
		if unsigned, err := strconv.ParseUint(value, 10, 64); err == nil {
			return strconv.FormatUint(unsigned, 10)
		}
		if signed, err := strconv.ParseInt(value, 10, 64); err == nil {
			return strconv.FormatInt(signed, 10)
		}
		return value
	case json.Number:
		return value.String()
	case float64:
		return strconv.FormatFloat(value, 'f', -1, 64)
	case float32:
		return strconv.FormatFloat(float64(value), 'f', -1, 32)
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		// %d on any integer kind prints the same digits FormatInt/FormatUint would.
		return fmt.Sprintf("%d", value)
	default:
		return fmt.Sprintf("%v", value)
	}
}
dataLossItem.LocalCommitTS, + } + dataLossItems[key] = struct{}{} + } + for _, dataInconsistentItem := range tableFailureItems.DataInconsistentItems { + key := dataInconsistentKey{ + PK: pkMapToKey(dataInconsistentItem.PK), + LocalCommitTS: dataInconsistentItem.LocalCommitTS, + } + dataInconsistentItems[key] = struct{}{} + } + for _, dataRedundantItem := range tableFailureItems.DataRedundantItems { + key := dataRedundantKey{ + PK: pkMapToKey(dataRedundantItem.PK), + OriginTS: dataRedundantItem.OriginTS, + } + dataRedundantItems[key] = struct{}{} + } + for _, lwwViolationItem := range tableFailureItems.LWWViolationItems { + key := lwwViolationKey{ + PK: pkMapToKey(lwwViolationItem.PK), + OriginTS: lwwViolationItem.OriginTS, + } + lwwViolationItems[key] = struct{}{} + } + } + } + } + skippedRecords := make(map[dataLossKey]struct{}) + for _, record := range checkerRecords { + for _, row := range record.Rows { + skippedRecords[dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.CommitTs}] = struct{}{} + } + } + for _, record := range cdcRecords { + for _, row := range record.Rows { + switch record.Failpoint { + case "cloudStorageSinkDropMessage": + if row.OriginTs > 0 { + key := dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.OriginTs} + if _, ok := skippedRecords[key]; !ok { + _, ok := dataLossItems[key] + require.True(t, ok) + delete(dataLossItems, key) + } + } else { + // the replicated record maybe skipped by LWW + key := dataRedundantKey{PK: pkMapToKey(row.PrimaryKeys), OriginTS: row.CommitTs} + _, ok := dataRedundantItems[key] + if !ok { + t.Log("replicated record maybe skipped by LWW", key) + } + delete(dataRedundantItems, key) + } + case "cloudStorageSinkMutateValue": + keyLoss := dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.CommitTs} + if _, skipped := skippedRecords[keyLoss]; !skipped { + key := dataInconsistentKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.CommitTs} + _, ok := dataInconsistentItems[key] 
+ require.True(t, ok) + delete(dataInconsistentItems, key) + } + case "cloudStorageSinkMutateValueTidbOriginTs": + keyLoss := dataLossKey{PK: pkMapToKey(row.PrimaryKeys), LocalCommitTS: row.OriginTs} + if _, ok := skippedRecords[keyLoss]; !ok { + _, ok := dataLossItems[keyLoss] + require.True(t, ok) + delete(dataLossItems, keyLoss) + } + + keyRedundant := dataRedundantKey{PK: pkMapToKey(row.PrimaryKeys), OriginTS: row.OriginTs + 1} + _, ok := dataRedundantItems[keyRedundant] + require.True(t, ok) + delete(dataRedundantItems, keyRedundant) + } + } + } + require.Empty(t, dataLossItems) + require.Empty(t, dataInconsistentItems) + require.Empty(t, dataRedundantItems) + require.Empty(t, lwwViolationItems) + require.Fail(t, "success") +} + +func TestValidation(t *testing.T) { + var ( + cdcOutputPath string + checkerOutputPath string + reportDir string + ) + + cdcOutputPath = os.Getenv(envKeyCDCOutput) + if cdcOutputPath == "" { + t.Log("skipped because ACTIVE_ACTIVE_FAILPOINT_CDC_OUTPUT is not set") + return + } + checkerOutputPath = os.Getenv(envKeyCheckerOutput) + if checkerOutputPath == "" { + t.Log("skipped because ACTIVE_ACTIVE_FAILPOINT_CHECKER_OUTPUT is not set") + return + } + reportDir = os.Getenv(envKeyReportDir) + if reportDir == "" { + t.Log("skipped because ACTIVE_ACTIVE_FAILPOINT_REPORT_DIR is not set") + return + } + + cdcRecords := getCdcRecords(t, cdcOutputPath) + checkerRecords := getCheckerRecords(t, checkerOutputPath) + reports := readReports(t, reportDir) + + validate(t, cdcRecords, checkerRecords, reports) +} diff --git a/cmd/multi-cluster-consistency-checker/recorder/types_test.go b/cmd/multi-cluster-consistency-checker/recorder/types_test.go index 53887ef29d..02c23908d6 100644 --- a/cmd/multi-cluster-consistency-checker/recorder/types_test.go +++ b/cmd/multi-cluster-consistency-checker/recorder/types_test.go @@ -270,7 +270,7 @@ func TestReport_MarshalReport(t *testing.T) { "time window: "+twStr+"\n"+ " - [table name: "+testSchemaKey+"]\n"+ " - 
[data loss items: 1]\n"+ - ` - [peer cluster: d1, pk: [id: 1], commit ts: 200]`+"\n\n", + ` - [peer cluster: d1, pk: [id: 1], local commit ts: 200]`+"\n\n", s) }) @@ -320,7 +320,7 @@ func TestReport_MarshalReport(t *testing.T) { "time window: "+twStr+"\n"+ " - [table name: "+testSchemaKey+"]\n"+ " - [data loss items: 1]\n"+ - ` - [peer cluster: d1, pk: [id: 1], commit ts: 2]`+"\n\n", + ` - [peer cluster: d1, pk: [id: 1], local commit ts: 2]`+"\n\n", s) }) @@ -341,7 +341,7 @@ func TestReport_MarshalReport(t *testing.T) { "time window: "+twStr+"\n"+ " - [table name: "+testSchemaKey+"]\n"+ " - [data loss items: 1]\n"+ - ` - [peer cluster: d0, pk: [id: 0], commit ts: 1]`+"\n"+ + ` - [peer cluster: d0, pk: [id: 0], local commit ts: 1]`+"\n"+ " - [data inconsistent items: 1]\n"+ ` - [peer cluster: d1, pk: [id: 1], local commit ts: 2, replicated commit ts: 3, inconsistent columns: [column: val, local: x, replicated: y]]`+"\n"+ " - [data redundant items: 1]\n"+ From 3ec31a77d0f04c55da0727647761d89d990443d9 Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Sat, 28 Feb 2026 10:47:43 +0800 Subject: [PATCH 07/10] update Signed-off-by: Jianjun Liao --- .../advancer/time_window_advancer.go | 18 ++- .../advancer/time_window_advancer_test.go | 32 ++++ .../checker/checker.go | 115 +++++++++----- .../checker/checker_test.go | 148 +++++++++++++++--- .../config/config.go | 7 +- .../config/config_test.go | 55 ++++++- .../decoder/decoder.go | 19 ++- .../decoder/decoder_test.go | 17 ++ 8 files changed, 347 insertions(+), 64 deletions(-) diff --git a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go index afa93bec49..f243403761 100644 --- a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go +++ b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer.go @@ -86,12 +86,24 @@ func (t *TimeWindowAdvancer) initializeFromCheckpoint( t.round = 
checkpoint.CheckpointItems[2].Round + 1 for clusterID := range t.timeWindowTriplet { newTimeWindows := [3]types.TimeWindow{} - newTimeWindows[2] = checkpoint.CheckpointItems[2].ClusterInfo[clusterID].TimeWindow + clusterInfo, exists := checkpoint.CheckpointItems[2].ClusterInfo[clusterID] + if !exists { + return nil, errors.Errorf("cluster %s not found in checkpoint item[2]", clusterID) + } + newTimeWindows[2] = clusterInfo.TimeWindow if checkpoint.CheckpointItems[1] != nil { - newTimeWindows[1] = checkpoint.CheckpointItems[1].ClusterInfo[clusterID].TimeWindow + clusterInfo, exists = checkpoint.CheckpointItems[1].ClusterInfo[clusterID] + if !exists { + return nil, errors.Errorf("cluster %s not found in checkpoint item[1]", clusterID) + } + newTimeWindows[1] = clusterInfo.TimeWindow } if checkpoint.CheckpointItems[0] != nil { - newTimeWindows[0] = checkpoint.CheckpointItems[0].ClusterInfo[clusterID].TimeWindow + clusterInfo, exists = checkpoint.CheckpointItems[0].ClusterInfo[clusterID] + if !exists { + return nil, errors.Errorf("cluster %s not found in checkpoint item[0]", clusterID) + } + newTimeWindows[0] = clusterInfo.TimeWindow } t.timeWindowTriplet[clusterID] = newTimeWindows } diff --git a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go index d327f85a7d..c0d5a68429 100644 --- a/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go +++ b/cmd/multi-cluster-consistency-checker/advancer/time_window_advancer_test.go @@ -20,6 +20,8 @@ import ( "sync/atomic" "testing" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" + "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/watcher" "github.com/pingcap/tidb/br/pkg/storage" "github.com/stretchr/testify/require" @@ -91,6 +93,36 @@ func TestNewTimeWindowAdvancer(t *testing.T) { 
require.Contains(t, advancer.timeWindowTriplet, "cluster2") } +func TestNewTimeWindowAdvancerInitializeFromCheckpointMissingClusterInfo(t *testing.T) { + t.Parallel() + checkpointWatchers := map[string]map[string]watcher.Watcher{ + "cluster1": {}, + "cluster2": {}, + } + s3Watchers := map[string]*watcher.S3Watcher{ + "cluster1": watcher.NewS3Watcher(&mockAdvancerWatcher{delta: 1}, storage.NewMemStorage(), nil), + "cluster2": watcher.NewS3Watcher(&mockAdvancerWatcher{delta: 1}, storage.NewMemStorage(), nil), + } + pdClients := map[string]pd.Client{ + "cluster1": &mockPDClient{}, + "cluster2": &mockPDClient{}, + } + + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "cluster1": { + TimeWindow: types.TimeWindow{ + LeftBoundary: 0, + RightBoundary: 100, + }, + }, + }) + + _, _, err := NewTimeWindowAdvancer(context.Background(), checkpointWatchers, s3Watchers, pdClients, checkpoint) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster cluster2 not found in checkpoint item[2]") +} + // TestTimeWindowAdvancer_AdvanceMultipleRounds simulates 4 rounds of AdvanceTimeWindow // with 2 clusters (c1, c2) performing bidirectional replication. 
// diff --git a/cmd/multi-cluster-consistency-checker/checker/checker.go b/cmd/multi-cluster-consistency-checker/checker/checker.go index b36e415aea..7e77dd1258 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker.go +++ b/cmd/multi-cluster-consistency-checker/checker/checker.go @@ -124,37 +124,55 @@ func (c *clusterViolationChecker) UpdateCache() { c.twoPreviousTimeWindowKeyVersionCache = newTwoPreviousTimeWindowKeyVersionCache } +type RecordsMapWithMaxCompareTs struct { + records map[uint64]*decoder.Record + maxCompareTs uint64 +} + type tableDataCache struct { // localDataCache is a map of primary key to a map of commit ts to a record - localDataCache map[types.PkType]map[uint64]*decoder.Record + localDataCache map[types.PkType]*RecordsMapWithMaxCompareTs // replicatedDataCache is a map of primary key to a map of origin ts to a record - replicatedDataCache map[types.PkType]map[uint64]*decoder.Record + replicatedDataCache map[types.PkType]*RecordsMapWithMaxCompareTs } func newTableDataCache() *tableDataCache { return &tableDataCache{ - localDataCache: make(map[types.PkType]map[uint64]*decoder.Record), - replicatedDataCache: make(map[types.PkType]map[uint64]*decoder.Record), + localDataCache: make(map[types.PkType]*RecordsMapWithMaxCompareTs), + replicatedDataCache: make(map[types.PkType]*RecordsMapWithMaxCompareTs), } } func (tdc *tableDataCache) newLocalRecord(record *decoder.Record) { recordsMap, exists := tdc.localDataCache[record.Pk] if !exists { - recordsMap = make(map[uint64]*decoder.Record) + recordsMap = &RecordsMapWithMaxCompareTs{ + records: make(map[uint64]*decoder.Record), + maxCompareTs: 0, + } tdc.localDataCache[record.Pk] = recordsMap } - recordsMap[record.CommitTs] = record + recordsMap.records[record.CommitTs] = record + if record.CommitTs > recordsMap.maxCompareTs { + recordsMap.maxCompareTs = record.CommitTs + } } func (tdc *tableDataCache) newReplicatedRecord(record *decoder.Record) { recordsMap, exists := 
tdc.replicatedDataCache[record.Pk] if !exists { - recordsMap = make(map[uint64]*decoder.Record) + recordsMap = &RecordsMapWithMaxCompareTs{ + records: make(map[uint64]*decoder.Record), + maxCompareTs: 0, + } tdc.replicatedDataCache[record.Pk] = recordsMap } - recordsMap[record.OriginTs] = record + recordsMap.records[record.OriginTs] = record + compareTs := record.GetCompareTs() + if compareTs > recordsMap.maxCompareTs { + recordsMap.maxCompareTs = compareTs + } } type timeWindowDataCache struct { @@ -193,6 +211,8 @@ func (twdc *timeWindowDataCache) NewRecord(schemaKey string, record *decoder.Rec type clusterDataChecker struct { clusterID string + // true if more than 2 clusters are involved in the check + multiCluster bool thisRoundTimeWindow types.TimeWindow @@ -211,9 +231,10 @@ type clusterDataChecker struct { newTimeWindowRecordsCount int } -func newClusterDataChecker(clusterID string) *clusterDataChecker { +func newClusterDataChecker(clusterID string, multiCluster bool) *clusterDataChecker { return &clusterDataChecker{ clusterID: clusterID, + multiCluster: multiCluster, timeWindowDataCaches: [3]timeWindowDataCache{}, rightBoundary: 0, overDataCaches: make(map[string][]*decoder.Record), @@ -232,12 +253,18 @@ func (cd *clusterDataChecker) InitializeFromCheckpoint( if checkpoint.CheckpointItems[2] == nil { return nil } - clusterInfo := checkpoint.CheckpointItems[2].ClusterInfo[cd.clusterID] + clusterInfo, exists := checkpoint.CheckpointItems[2].ClusterInfo[cd.clusterID] + if !exists { + return errors.Errorf("cluster %s not found in checkpoint item[2]", cd.clusterID) + } cd.rightBoundary = clusterInfo.TimeWindow.RightBoundary cd.timeWindowDataCaches[2] = newTimeWindowDataCache( clusterInfo.TimeWindow.LeftBoundary, clusterInfo.TimeWindow.RightBoundary, clusterInfo.TimeWindow.CheckpointTs) if checkpoint.CheckpointItems[1] != nil { - clusterInfo = checkpoint.CheckpointItems[1].ClusterInfo[cd.clusterID] + clusterInfo, exists = 
checkpoint.CheckpointItems[1].ClusterInfo[cd.clusterID] + if !exists { + return errors.Errorf("cluster %s not found in checkpoint item[1]", cd.clusterID) + } cd.timeWindowDataCaches[1] = newTimeWindowDataCache( clusterInfo.TimeWindow.LeftBoundary, clusterInfo.TimeWindow.RightBoundary, clusterInfo.TimeWindow.CheckpointTs) } @@ -314,24 +341,20 @@ func (cd *clusterDataChecker) findClusterReplicatedDataInTimeWindow(timeWindowId if !exists { return nil, false } - if records, exists := tableDataCache.replicatedDataCache[pk]; exists { - if record, exists := records[originTs]; exists { + if recordsMap, exists := tableDataCache.replicatedDataCache[pk]; exists { + if record, exists := recordsMap.records[originTs]; exists { return record, false } - for _, record := range records { - if record.GetCompareTs() >= originTs { - return nil, true - } + if recordsMap.maxCompareTs > originTs { + return nil, true } } // If no replicated record is found, a newer/equal local record with the // same PK in the peer cluster also indicates this origin write can be // considered overwritten by LWW semantics instead of hard data loss. 
- if records, exists := tableDataCache.localDataCache[pk]; exists { - for _, record := range records { - if record.GetCompareTs() >= originTs { - return nil, true - } + if recordsMap, exists := tableDataCache.localDataCache[pk]; exists { + if recordsMap.maxCompareTs > originTs { + return nil, true } } return nil, false @@ -342,11 +365,11 @@ func (cd *clusterDataChecker) findClusterLocalDataInTimeWindow(timeWindowIdx int if !exists { return false } - records, exists := tableDataCache.localDataCache[pk] + recordsMap, exists := tableDataCache.localDataCache[pk] if !exists { return false } - _, exists = records[commitTs] + _, exists = recordsMap.records[commitTs] return exists } @@ -410,13 +433,13 @@ func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( ) { for schemaKey, tableDataCache := range cd.timeWindowDataCaches[timeWindowIdx].tableDataCaches { for _, localDataCache := range tableDataCache.localDataCache { - for _, record := range localDataCache { + for _, record := range localDataCache.records { for replicatedClusterID, checkpointTs := range cd.timeWindowDataCaches[timeWindowIdx].checkpointTs { if shouldSkip(record.CommitTs, checkpointTs) { continue } cd.checkedRecordsCount++ - replicatedRecord, skipped := checker.FindClusterReplicatedData(replicatedClusterID, schemaKey, record.Pk, record.CommitTs) + replicatedRecord, skipped := checker.FindClusterReplicatedData(replicatedClusterID, schemaKey, record.Pk, record.CommitTs, cd.multiCluster) if skipped { failpoint.Inject("multiClusterConsistencyCheckerLWWViolation", func() { Write("multiClusterConsistencyCheckerLWWViolation", []RowRecord{ @@ -461,7 +484,7 @@ func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( func (cd *clusterDataChecker) dataRedundantDetection(checker *DataChecker) { for schemaKey, tableDataCache := range cd.timeWindowDataCaches[2].tableDataCaches { for _, replicatedDataCache := range tableDataCache.replicatedDataCache { - for _, record := range replicatedDataCache { + for _, record 
:= range replicatedDataCache.records { cd.checkedRecordsCount++ // For replicated records, OriginTs is the local commit ts if !checker.FindSourceLocalData(cd.clusterID, schemaKey, record.Pk, record.OriginTs) { @@ -481,12 +504,18 @@ func (cd *clusterDataChecker) lwwViolationDetection() { for schemaKey, tableDataCache := range cd.timeWindowDataCaches[2].tableDataCaches { for pk, localRecords := range tableDataCache.localDataCache { replicatedRecords := tableDataCache.replicatedDataCache[pk] - pkRecords := make([]*decoder.Record, 0, len(localRecords)+len(replicatedRecords)) - for _, localRecord := range localRecords { + replicatedCount := 0 + if replicatedRecords != nil { + replicatedCount = len(replicatedRecords.records) + } + pkRecords := make([]*decoder.Record, 0, len(localRecords.records)+replicatedCount) + for _, localRecord := range localRecords.records { pkRecords = append(pkRecords, localRecord) } - for _, replicatedRecord := range replicatedRecords { - pkRecords = append(pkRecords, replicatedRecord) + if replicatedRecords != nil { + for _, replicatedRecord := range replicatedRecords.records { + pkRecords = append(pkRecords, replicatedRecord) + } } sort.Slice(pkRecords, func(i, j int) bool { return pkRecords[i].CommitTs < pkRecords[j].CommitTs @@ -501,8 +530,8 @@ func (cd *clusterDataChecker) lwwViolationDetection() { if _, exists := tableDataCache.localDataCache[pk]; exists { continue } - pkRecords := make([]*decoder.Record, 0, len(replicatedRecords)) - for _, replicatedRecord := range replicatedRecords { + pkRecords := make([]*decoder.Record, 0, len(replicatedRecords.records)) + for _, replicatedRecord := range replicatedRecords.records { pkRecords = append(pkRecords, replicatedRecord) } sort.Slice(pkRecords, func(i, j int) bool { @@ -550,7 +579,7 @@ type DataChecker struct { func NewDataChecker(ctx context.Context, clusterConfig map[string]config.ClusterConfig, checkpointDataMap map[string]map[cloudstorage.DmlPathKey]types.IncrementalData, checkpoint 
*recorder.Checkpoint) (*DataChecker, error) { clusterDataChecker := make(map[string]*clusterDataChecker) for clusterID := range clusterConfig { - clusterDataChecker[clusterID] = newClusterDataChecker(clusterID) + clusterDataChecker[clusterID] = newClusterDataChecker(clusterID, len(clusterConfig) > 2) } checker := &DataChecker{ round: 0, @@ -582,11 +611,17 @@ func (c *DataChecker) initializeFromCheckpoint(ctx context.Context, checkpointDa // FindClusterReplicatedData checks whether the record is present in the replicated data // cache [1] or [2] or another new record is present in the replicated data cache [1] or [2]. -func (c *DataChecker) FindClusterReplicatedData(clusterID string, schemaKey string, pk types.PkType, originTs uint64) (*decoder.Record, bool) { +func (c *DataChecker) FindClusterReplicatedData(clusterID string, schemaKey string, pk types.PkType, originTs uint64, multiCluster bool) (*decoder.Record, bool) { clusterDataChecker, exists := c.clusterDataCheckers[clusterID] if !exists { return nil, false } + if multiCluster { + record, skipped := clusterDataChecker.findClusterReplicatedDataInTimeWindow(0, schemaKey, pk, originTs) + if skipped || record != nil { + return record, skipped + } + } record, skipped := clusterDataChecker.findClusterReplicatedDataInTimeWindow(1, schemaKey, pk, originTs) if skipped || record != nil { return record, skipped @@ -623,7 +658,7 @@ func (c *DataChecker) CheckInNextTimeWindow(newTimeWindowData map[string]types.T // Round 1+: LWW violation detection produces meaningful results. // Round 2+: data loss / inconsistency detection (needs [1] and [2]). // Round 3+: data redundant detection (needs [0], [1] and [2] with real data). 
- enableDataLoss := c.checkableRound >= 2 + enableDataLoss := c.checkableRound >= 3 || (len(c.clusterDataCheckers) > 2 && c.checkableRound >= 2) enableDataRedundant := c.checkableRound >= 3 for clusterID, clusterDataChecker := range c.clusterDataCheckers { @@ -655,6 +690,14 @@ func (c *DataChecker) decodeNewTimeWindowData(newTimeWindowData map[string]types if !exists { return errors.Errorf("cluster %s not found", clusterID) } + for replicatedClusterID := range timeWindowData.TimeWindow.CheckpointTs { + if replicatedClusterID == clusterID { + return errors.Errorf("cluster %s has invalid checkpoint ts target to itself", clusterID) + } + if _, ok := c.clusterDataCheckers[replicatedClusterID]; !ok { + return errors.Errorf("cluster %s has checkpoint ts target %s not found", clusterID, replicatedClusterID) + } + } clusterDataChecker.thisRoundTimeWindow = timeWindowData.TimeWindow if err := clusterDataChecker.PrepareNextTimeWindowData(timeWindowData.TimeWindow); err != nil { return errors.Trace(err) diff --git a/cmd/multi-cluster-consistency-checker/checker/checker_test.go b/cmd/multi-cluster-consistency-checker/checker/checker_test.go index f1cc67aa49..df52d191db 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker_test.go +++ b/cmd/multi-cluster-consistency-checker/checker/checker_test.go @@ -86,12 +86,72 @@ func TestNewDataCheckerInitializeFromCheckpointError(t *testing.T) { require.Contains(t, err.Error(), "column value of column id not found") } +func TestNewDataCheckerInitializeFromCheckpointMissingClusterInfo(t *testing.T) { + t.Parallel() + + t.Run("cluster missing in checkpoint item[2]", func(t *testing.T) { + t.Parallel() + clusterConfig := map[string]config.ClusterConfig{ + "c1": {}, + "c2": {}, + } + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{ + LeftBoundary: 0, + RightBoundary: 100, + }, + }, + }) + + _, err := 
NewDataChecker(context.Background(), clusterConfig, nil, checkpoint) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster c2 not found in checkpoint item[2]") + }) + + t.Run("cluster missing in checkpoint item[1]", func(t *testing.T) { + t.Parallel() + clusterConfig := map[string]config.ClusterConfig{ + "c1": {}, + "c2": {}, + } + checkpoint := recorder.NewCheckpoint() + checkpoint.NewTimeWindowData(0, map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{ + LeftBoundary: 0, + RightBoundary: 100, + }, + }, + }) + checkpoint.NewTimeWindowData(1, map[string]types.TimeWindowData{ + "c1": { + TimeWindow: types.TimeWindow{ + LeftBoundary: 100, + RightBoundary: 200, + }, + }, + "c2": { + TimeWindow: types.TimeWindow{ + LeftBoundary: 100, + RightBoundary: 200, + }, + }, + }) + + _, err := NewDataChecker(context.Background(), clusterConfig, nil, checkpoint) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster c2 not found in checkpoint item[1]") + }) +} + func TestNewClusterDataChecker(t *testing.T) { t.Parallel() t.Run("create cluster data checker", func(t *testing.T) { t.Parallel() - checker := newClusterDataChecker("cluster1") + checker := newClusterDataChecker("cluster1", false) require.NotNil(t, checker) require.Equal(t, "cluster1", checker.clusterID) require.Equal(t, uint64(0), checker.rightBoundary) @@ -279,7 +339,7 @@ func TestTimeWindowDataCache_NewRecord(t *testing.T) { cache.NewRecord(schemaKey, record) require.Contains(t, cache.tableDataCaches, schemaKey) require.Contains(t, cache.tableDataCaches[schemaKey].localDataCache, record.Pk) - require.Contains(t, cache.tableDataCaches[schemaKey].localDataCache[record.Pk], record.CommitTs) + require.Contains(t, cache.tableDataCaches[schemaKey].localDataCache[record.Pk].records, record.CommitTs) }) t.Run("add replicated record", func(t *testing.T) { @@ -297,7 +357,7 @@ func TestTimeWindowDataCache_NewRecord(t *testing.T) { cache.NewRecord(schemaKey, record) 
require.Contains(t, cache.tableDataCaches, schemaKey) require.Contains(t, cache.tableDataCaches[schemaKey].replicatedDataCache, record.Pk) - require.Contains(t, cache.tableDataCaches[schemaKey].replicatedDataCache[record.Pk], record.OriginTs) + require.Contains(t, cache.tableDataCaches[schemaKey].replicatedDataCache[record.Pk].records, record.OriginTs) }) t.Run("skip record before left boundary", func(t *testing.T) { @@ -322,7 +382,7 @@ func TestClusterDataChecker_PrepareNextTimeWindowData(t *testing.T) { t.Run("prepare next time window data", func(t *testing.T) { t.Parallel() - checker := newClusterDataChecker("cluster1") + checker := newClusterDataChecker("cluster1", false) checker.rightBoundary = 100 timeWindow := types.TimeWindow{ @@ -338,7 +398,7 @@ func TestClusterDataChecker_PrepareNextTimeWindowData(t *testing.T) { t.Run("mismatch left boundary", func(t *testing.T) { t.Parallel() - checker := newClusterDataChecker("cluster1") + checker := newClusterDataChecker("cluster1", false) checker.rightBoundary = 100 timeWindow := types.TimeWindow{ @@ -733,10 +793,11 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { require.Equal(t, uint64(180), c1TableItems.LWWViolationItems[0].CommitTS) }) - // data loss detected at round 2: Data loss detection is active from round 2. - // A record in round 1 whose commitTs > checkpointTs will enter [1] at round 2, - // and if the replicated counterpart is missing, data loss is detected at round 2. - t.Run("data loss detected at round 2", func(t *testing.T) { + // data loss detected at round 3 for 2-cluster mode: + // enableDataLoss is active from round 3 when only 2 clusters are involved. + // A record in round 2 whose commitTs > checkpointTs enters [1] at round 3, + // and if the replicated counterpart is missing, data loss is detected at round 3. 
+ t.Run("data loss detected at round 3", func(t *testing.T) { t.Parallel() checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) require.NoError(t, initErr) @@ -745,19 +806,27 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { "c1": makeTWData(0, 100, nil, nil), "c2": makeTWData(0, 100, nil, nil), } - // Round 1: c1 writes pk=1 (commitTs=150), checkpointTs["c2"]=140 - // Since 150 > 140, this record needs replication checking. + // Round 1: consistent data. round1 := map[string]types.TimeWindowData{ - "c1": makeTWData(100, 200, map[string]uint64{"c2": 140}, + "c1": makeTWData(100, 200, map[string]uint64{"c2": 180}, makeContent(makeCanalJSON(1, 150, 0, "a"))), - "c2": makeTWData(100, 200, nil, nil), // c2 has NO replicated data + "c2": makeTWData(100, 200, nil, + makeContent(makeCanalJSON(1, 160, 150, "a"))), } - // Round 2: round 1 data is now in [1], data loss detection enabled. + // Round 2: c1 writes pk=2 (commitTs=250), checkpointTs["c2"]=240. + // Since 250 > 240, this record requires replication checking. + // c2 has no matching replicated data in this round. round2 := map[string]types.TimeWindowData{ - "c1": makeTWData(200, 300, map[string]uint64{"c2": 280}, + "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, makeContent(makeCanalJSON(2, 250, 0, "b"))), - "c2": makeTWData(200, 300, nil, - makeContent(makeCanalJSON(2, 260, 250, "b"))), + "c2": makeTWData(200, 300, nil, nil), + } + // Round 3: data loss detection is enabled in 2-cluster mode. 
+ round3 := map[string]types.TimeWindowData{ + "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, + makeContent(makeCanalJSON(3, 350, 0, "c"))), + "c2": makeTWData(300, 400, nil, + makeContent(makeCanalJSON(3, 360, 350, "c"))), } report0, err := checker.CheckInNextTimeWindow(round0) @@ -770,13 +839,17 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { report2, err := checker.CheckInNextTimeWindow(round2) require.NoError(t, err) - require.True(t, report2.NeedFlush(), "round 2 should detect data loss") - c1Report := report2.ClusterReports["c1"] + require.False(t, report2.NeedFlush(), "round 2 should not detect data loss yet") + + report3, err := checker.CheckInNextTimeWindow(round3) + require.NoError(t, err) + require.True(t, report3.NeedFlush(), "round 3 should detect data loss") + c1Report := report3.ClusterReports["c1"] require.Contains(t, c1Report.TableFailureItems, defaultSchemaKey) tableItems := c1Report.TableFailureItems[defaultSchemaKey] require.Len(t, tableItems.DataLossItems, 1) require.Equal(t, "c2", tableItems.DataLossItems[0].PeerClusterID) - require.Equal(t, uint64(150), tableItems.DataLossItems[0].LocalCommitTS) + require.Equal(t, uint64(250), tableItems.DataLossItems[0].LocalCommitTS) }) // data redundant detected at round 3 (not round 2): @@ -854,3 +927,38 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { require.Equal(t, uint64(340), c2TableItems.DataRedundantItems[0].ReplicatedCommitTS) }) } + +func TestDataChecker_CheckInNextTimeWindowInvalidCheckpointTarget(t *testing.T) { + t.Parallel() + ctx := context.Background() + + clusterCfg := map[string]config.ClusterConfig{"c1": {}, "c2": {}} + + t.Run("unknown target cluster", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, map[string]uint64{"c3": 80}, nil), + "c2": makeTWData(0, 100, nil, nil), + } + _, err := 
checker.CheckInNextTimeWindow(round0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster c1 has checkpoint ts target c3 not found") + }) + + t.Run("self target cluster", func(t *testing.T) { + t.Parallel() + checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) + require.NoError(t, initErr) + + round0 := map[string]types.TimeWindowData{ + "c1": makeTWData(0, 100, map[string]uint64{"c1": 80}, nil), + "c2": makeTWData(0, 100, nil, nil), + } + _, err := checker.CheckInNextTimeWindow(round0) + require.Error(t, err) + require.Contains(t, err.Error(), "cluster c1 has invalid checkpoint ts target to itself") + }) +} diff --git a/cmd/multi-cluster-consistency-checker/config/config.go b/cmd/multi-cluster-consistency-checker/config/config.go index 86fac59af3..7875c4a6af 100644 --- a/cmd/multi-cluster-consistency-checker/config/config.go +++ b/cmd/multi-cluster-consistency-checker/config/config.go @@ -100,9 +100,10 @@ func LoadConfig(path string) (*Config, error) { } } - // Validate that at least one cluster is configured - if len(cfg.Clusters) == 0 { - return nil, fmt.Errorf("at least one cluster must be configured") + // Validate that at least two clusters are configured. + // Single-cluster mode cannot provide cross-cluster consistency checks. 
+ if len(cfg.Clusters) < 2 { + return nil, fmt.Errorf("at least two clusters must be configured") } // Validate cluster configurations diff --git a/cmd/multi-cluster-consistency-checker/config/config_test.go b/cmd/multi-cluster-consistency-checker/config/config_test.go index a0027efb5f..8a1e03c9b1 100644 --- a/cmd/multi-cluster-consistency-checker/config/config_test.go +++ b/cmd/multi-cluster-consistency-checker/config/config_test.go @@ -245,7 +245,33 @@ data-dir = "/tmp/data" cfg, err := LoadConfig(configPath) require.Error(t, err) require.Nil(t, cfg) - require.Contains(t, err.Error(), "at least one cluster must be configured") + require.Contains(t, err.Error(), "at least two clusters must be configured") + }) + + t.Run("single cluster is invalid", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.toml") + configContent := ` +[global] +log-level = "info" +data-dir = "/tmp/data" + [global.tables] + schema1 = ["table1"] + +[clusters] + [clusters.cluster1] + pd-addrs = ["127.0.0.1:2379"] + s3-sink-uri = "s3://bucket/cluster1/" + s3-changefeed-id = "s3-cf-1" +` + err := os.WriteFile(configPath, []byte(configContent), 0o644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.Error(t, err) + require.Nil(t, cfg) + require.Contains(t, err.Error(), "at least two clusters must be configured") }) t.Run("missing pd-addrs", func(t *testing.T) { @@ -263,6 +289,15 @@ data-dir = "/tmp/data" [clusters.cluster1] s3-sink-uri = "s3://bucket/cluster1/" s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } ` err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) @@ -288,6 +323,15 @@ 
data-dir = "/tmp/data" [clusters.cluster1] pd-addrs = ["127.0.0.1:2379"] s3-changefeed-id = "s3-cf-1" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } ` err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) @@ -313,6 +357,15 @@ data-dir = "/tmp/data" [clusters.cluster1] pd-addrs = ["127.0.0.1:2379"] s3-sink-uri = "s3://bucket/cluster1/" + [clusters.cluster1.peer-cluster-changefeed-config] + cluster2 = { changefeed-id = "cf-1-to-2" } + + [clusters.cluster2] + pd-addrs = ["127.0.0.1:2479"] + s3-sink-uri = "s3://bucket/cluster2/" + s3-changefeed-id = "s3-cf-2" + [clusters.cluster2.peer-cluster-changefeed-config] + cluster1 = { changefeed-id = "cf-2-to-1" } ` err := os.WriteFile(configPath, []byte(configContent), 0o644) require.NoError(t, err) diff --git a/cmd/multi-cluster-consistency-checker/decoder/decoder.go b/cmd/multi-cluster-consistency-checker/decoder/decoder.go index d03c86379f..a35ebeeebe 100644 --- a/cmd/multi-cluster-consistency-checker/decoder/decoder.go +++ b/cmd/multi-cluster-consistency-checker/decoder/decoder.go @@ -224,7 +224,14 @@ func (d *columnValueDecoder) decodeNext() (*Record, error) { log.Error("field type not found", zap.String("pkName", pkName), zap.Any("msg", d.msg)) return nil, errors.Errorf("field type of column %s not found", pkName) } - datum := valueToDatum(columnValue, ft) + datum, err := safeValueToDatum(columnValue, ft) + if err != nil { + log.Error("failed to convert primary key column value", + zap.String("pkName", pkName), + zap.Any("columnValue", columnValue), + zap.Error(err)) + return nil, errors.Annotatef(err, "failed to convert primary key column %s", pkName) + } if datum.IsNull() { log.Error("column value is null", 
zap.String("pkName", pkName), zap.Any("msg", d.msg)) return nil, errors.Errorf("column value of column %s is null", pkName) @@ -266,6 +273,16 @@ func (d *columnValueDecoder) decodeNext() (*Record, error) { }, nil } +func safeValueToDatum(value any, ft *ptypes.FieldType) (datum *tiTypes.Datum, err error) { + defer func() { + if r := recover(); r != nil { + err = errors.Errorf("value to datum conversion panic: %v", r) + datum = nil + } + }() + return valueToDatum(value, ft), nil +} + // getColumnFieldType returns the FieldType for a column. // It first looks up from the tableDefinition-based columnFieldTypes map, // then falls back to parsing the MySQLType string from the canal-json message. diff --git a/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go b/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go index 972ab377cd..ba04e4ef3d 100644 --- a/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go +++ b/cmd/multi-cluster-consistency-checker/decoder/decoder_test.go @@ -155,6 +155,23 @@ func TestCanalJSONDecoderAllInvalidMessages(t *testing.T) { require.Empty(t, records) } +func TestCanalJSONDecoderInvalidPrimaryKeyValueNoPanic(t *testing.T) { + // id is intentionally encoded as a JSON number (not string), which would + // trigger a panic inside valueToDatum without recover protection. 
+ invalidPKType := `{"id":0,"database":"test_active","table":"message","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1,"ts":1,"sql":"","sqlType":{"id":4},"mysqlType":{"id":"int"},"old":null,"data":[{"id":20}],"_tidb":{"commitTs":100}}` + "\r\n" + + var ( + records []*decoder.Record + err error + ) + require.NotPanics(t, func() { + records, err = decoder.Decode([]byte(invalidPKType), buildColumnFieldTypes(t, tableDefinition1)) + }) + require.Error(t, err) + require.Empty(t, records) + require.Contains(t, err.Error(), "failed to convert primary key column id") +} + func TestRecord_EqualReplicatedRecord(t *testing.T) { tests := []struct { name string From 1ca5a759fd4c1a7f633ada274458d5dcb63fd747 Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Sat, 28 Feb 2026 16:28:27 +0800 Subject: [PATCH 08/10] fix data redundant detection Signed-off-by: Jianjun Liao --- .../checker/checker.go | 6 +-- .../checker/checker_test.go | 41 +++++++++---------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/cmd/multi-cluster-consistency-checker/checker/checker.go b/cmd/multi-cluster-consistency-checker/checker/checker.go index 7e77dd1258..77972240ac 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker.go +++ b/cmd/multi-cluster-consistency-checker/checker/checker.go @@ -479,10 +479,10 @@ func (cd *clusterDataChecker) checkLocalRecordsForDataLoss( } } -// dataRedundantDetection iterates through the replicated data cache [2]. The record must be present -// in the local data cache [1] [2] or [3]. +// dataRedundantDetection iterates through the replicated data cache [1]. The record must be present +// in the source local data cache across recent windows. 
func (cd *clusterDataChecker) dataRedundantDetection(checker *DataChecker) { - for schemaKey, tableDataCache := range cd.timeWindowDataCaches[2].tableDataCaches { + for schemaKey, tableDataCache := range cd.timeWindowDataCaches[1].tableDataCaches { for _, replicatedDataCache := range tableDataCache.replicatedDataCache { for _, record := range replicatedDataCache.records { cd.checkedRecordsCount++ diff --git a/cmd/multi-cluster-consistency-checker/checker/checker_test.go b/cmd/multi-cluster-consistency-checker/checker/checker_test.go index df52d191db..0d940e5ee0 100644 --- a/cmd/multi-cluster-consistency-checker/checker/checker_test.go +++ b/cmd/multi-cluster-consistency-checker/checker/checker_test.go @@ -657,18 +657,18 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { "c1": makeTWData(200, 300, map[string]uint64{"c2": 240}, makeContent(makeCanalJSON(3, 250, 0, "c"))), "c2": makeTWData(200, 300, nil, - makeContent(makeCanalJSON(3, 260, 250, "c"))), + makeContent( + makeCanalJSON(3, 260, 250, "c"), + makeCanalJSON(99, 240, 230, "x"), + )), } - // Round 3: c2 has an extra replicated pk=99 (originTs=330) that doesn't match - // any locally-written record in c1 + // Round 3: keep normal data so data redundant detection at round 3 + // checks round2 ([1]) and catches the orphan from round2. 
round3 := map[string]types.TimeWindowData{ "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, makeContent(makeCanalJSON(4, 350, 0, "d"))), "c2": makeTWData(300, 400, nil, - makeContent( - makeCanalJSON(4, 360, 350, "d"), - makeCanalJSON(99, 340, 330, "x"), - )), + makeContent(makeCanalJSON(4, 360, 350, "d"))), } rounds := [4]map[string]types.TimeWindowData{base[0], base[1], round2, round3} @@ -688,8 +688,8 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { require.Contains(t, c2Report.TableFailureItems, defaultSchemaKey) tableItems := c2Report.TableFailureItems[defaultSchemaKey] require.Len(t, tableItems.DataRedundantItems, 1) - require.Equal(t, uint64(330), tableItems.DataRedundantItems[0].OriginTS) - require.Equal(t, uint64(340), tableItems.DataRedundantItems[0].ReplicatedCommitTS) + require.Equal(t, uint64(230), tableItems.DataRedundantItems[0].OriginTS) + require.Equal(t, uint64(240), tableItems.DataRedundantItems[0].ReplicatedCommitTS) }) t.Run("lww violation detected", func(t *testing.T) { @@ -853,15 +853,15 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { }) // data redundant detected at round 3 (not round 2): - // dataRedundantDetection checks timeWindowDataCaches[2] (latest round). + // dataRedundantDetection checks timeWindowDataCaches[1]. // At round 2 [0]=round 0 (empty) so FindSourceLocalData may miss data in // that window → enableDataRedundant is false to avoid false positives. // At round 3 [0]=round 1, [1]=round 2, [2]=round 3 are all populated - // with real data, so enableDataRedundant=true and an orphan in [2] is caught. + // with real data, so enableDataRedundant=true and an orphan in [1] is caught. // - // This test puts the SAME orphan pk=99 in both round 2 and round 3: + // This test puts an orphan pk=99 in round 2 only: // - Round 2: orphan in [2] but enableDataRedundant=false → NOT flagged. - // - Round 3: orphan in [2] and enableDataRedundant=true → flagged. 
+ // - Round 3: orphan moved to [1] and enableDataRedundant=true → flagged. t.Run("data redundant detected at round 3 not round 2", func(t *testing.T) { t.Parallel() checker, initErr := NewDataChecker(ctx, clusterCfg, nil, nil) @@ -889,16 +889,13 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { makeCanalJSON(99, 240, 230, "x"), // orphan replicated )), } - // Round 3: c2 has another orphan replicated pk=99 (originTs=330) in [2]. - // enableDataRedundant=true at round 3, so it IS caught. + // Round 3: no new orphan in [2]; enableDataRedundant=true at round 3 should + // catch the orphan that is now in [1] (from round 2). round3 := map[string]types.TimeWindowData{ "c1": makeTWData(300, 400, map[string]uint64{"c2": 380}, makeContent(makeCanalJSON(3, 350, 0, "c"))), "c2": makeTWData(300, 400, nil, - makeContent( - makeCanalJSON(3, 360, 350, "c"), - makeCanalJSON(99, 340, 330, "y"), // orphan replicated - )), + makeContent(makeCanalJSON(3, 360, 350, "c"))), } report0, err := checker.CheckInNextTimeWindow(round0) @@ -916,15 +913,15 @@ func TestDataChecker_FourRoundsCheck(t *testing.T) { report3, err := checker.CheckInNextTimeWindow(round3) require.NoError(t, err) - // Round 3: redundant detection is enabled; the orphan pk=99 in [2] (round 3) + // Round 3: redundant detection is enabled; the orphan pk=99 in [1] (round 2) // is now caught. 
require.True(t, report3.NeedFlush(), "round 3 should detect data redundant") c2Report := report3.ClusterReports["c2"] require.Contains(t, c2Report.TableFailureItems, defaultSchemaKey) c2TableItems := c2Report.TableFailureItems[defaultSchemaKey] require.Len(t, c2TableItems.DataRedundantItems, 1) - require.Equal(t, uint64(330), c2TableItems.DataRedundantItems[0].OriginTS) - require.Equal(t, uint64(340), c2TableItems.DataRedundantItems[0].ReplicatedCommitTS) + require.Equal(t, uint64(230), c2TableItems.DataRedundantItems[0].OriginTS) + require.Equal(t, uint64(240), c2TableItems.DataRedundantItems[0].ReplicatedCommitTS) }) } From 5a952dd3104b60ad6403aec6942eb88976ceb3fa Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Sun, 1 Mar 2026 00:14:11 +0800 Subject: [PATCH 09/10] limit download concurrency Signed-off-by: Jianjun Liao --- .../consumer/consumer.go | 34 ++++++- .../consumer/consumer_test.go | 98 +++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer.go b/cmd/multi-cluster-consistency-checker/consumer/consumer.go index 2f695f0867..650b08b440 100644 --- a/cmd/multi-cluster-consistency-checker/consumer/consumer.go +++ b/cmd/multi-cluster-consistency-checker/consumer/consumer.go @@ -44,6 +44,8 @@ type indexRange struct { end uint64 } +const defaultGlobalDownloadConcurrencyLimit = 128 + func updateTableDMLIdxMap( tableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap, dmlkey cloudstorage.DmlPathKey, @@ -223,6 +225,9 @@ type S3Consumer struct { dateSeparator string fileIndexWidth int tables map[string][]string + // downloadLimiter limits the total number of concurrent file downloads + // across all downloadDMLFiles calls on this consumer. 
+ downloadLimiter chan struct{} // skip the first round data download skipDownloadData bool @@ -242,6 +247,10 @@ func NewS3Consumer( dateSeparator: config.DateSeparatorDay.String(), fileIndexWidth: config.DefaultFileIndexWidth, tables: tables, + downloadLimiter: make( + chan struct{}, + defaultGlobalDownloadConcurrencyLimit, + ), skipDownloadData: true, @@ -251,6 +260,25 @@ func NewS3Consumer( } } +func (c *S3Consumer) acquireDownloadSlot(ctx context.Context) error { + if c.downloadLimiter == nil { + return nil + } + select { + case c.downloadLimiter <- struct{}{}: + return nil + case <-ctx.Done(): + return errors.Trace(ctx.Err()) + } +} + +func (c *S3Consumer) releaseDownloadSlot() { + if c.downloadLimiter == nil { + return + } + <-c.downloadLimiter +} + func (c *S3Consumer) InitializeFromCheckpoint( ctx context.Context, clusterID string, checkpoint *recorder.Checkpoint, ) (map[cloudstorage.DmlPathKey]types.IncrementalData, error) { @@ -626,9 +654,13 @@ func (c *S3Consumer) downloadDMLFiles( fileContents := make(chan fileContent, len(tasks)) eg, egCtx := errgroup.WithContext(ctx) - eg.SetLimit(128) for _, task := range tasks { eg.Go(func() error { + if err := c.acquireDownloadSlot(egCtx); err != nil { + return errors.Trace(err) + } + defer c.releaseDownloadSlot() + filePath := task.dmlPathKey.GenerateDMLFilePath( &task.fileIndex, c.fileExtension, diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go index b51bdb80d1..deccdba352 100644 --- a/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go +++ b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go @@ -19,7 +19,10 @@ import ( "path" "slices" "strings" + "sync" + "sync/atomic" "testing" + "time" "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/recorder" "github.com/pingcap/ticdc/cmd/multi-cluster-consistency-checker/types" @@ -386,16 +389,50 @@ type mockS3Storage struct { sortedFiles 
[]mockFile } +type trackingMockS3Storage struct { + *mockS3Storage + currentConcurrent int64 + maxConcurrent int64 + delay time.Duration +} + func NewMockS3Storage(sortedFiles []mockFile) *mockS3Storage { s3Storage := &mockS3Storage{} s3Storage.UpdateFiles(sortedFiles) return s3Storage } +func NewTrackingMockS3Storage(sortedFiles []mockFile, delay time.Duration) *trackingMockS3Storage { + return &trackingMockS3Storage{ + mockS3Storage: NewMockS3Storage(sortedFiles), + delay: delay, + } +} + func (m *mockS3Storage) ReadFile(ctx context.Context, name string) ([]byte, error) { return m.sortedFiles[m.fileOffset[name]].content, nil } +func (m *trackingMockS3Storage) ReadFile(ctx context.Context, name string) ([]byte, error) { + current := atomic.AddInt64(&m.currentConcurrent, 1) + for { + max := atomic.LoadInt64(&m.maxConcurrent) + if current <= max || atomic.CompareAndSwapInt64(&m.maxConcurrent, max, current) { + break + } + } + defer atomic.AddInt64(&m.currentConcurrent, -1) + + timer := time.NewTimer(m.delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + return m.mockS3Storage.ReadFile(ctx, name) +} + func (m *mockS3Storage) WalkDir(ctx context.Context, opt *storage.WalkOption, fn func(path string, size int64) error) error { filenamePrefix := path.Join(opt.SubDir, opt.ObjPrefix) for _, file := range m.sortedFiles { @@ -417,6 +454,67 @@ func (m *mockS3Storage) UpdateFiles(sortedFiles []mockFile) { m.sortedFiles = sortedFiles } +func (m *trackingMockS3Storage) MaxConcurrent() int64 { + return atomic.LoadInt64(&m.maxConcurrent) +} + +func TestDownloadDMLFilesGlobalConcurrencyLimit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + indexKey := cloudstorage.FileIndexKey{DispatcherID: "", EnableTableAcrossNodes: false} + dmlPathKey1 := cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + Date: "2026-01-01", + } + dmlPathKey2 := 
cloudstorage.DmlPathKey{ + SchemaPathKey: cloudstorage.SchemaPathKey{Schema: "test", Table: "t1", TableVersion: 1}, + Date: "2026-01-02", + } + + files := []mockFile{ + {name: "test/t1/1/2026-01-01/CDC00000000000000000001.json", content: []byte("f1")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000002.json", content: []byte("f2")}, + {name: "test/t1/1/2026-01-01/CDC00000000000000000003.json", content: []byte("f3")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000001.json", content: []byte("f4")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000002.json", content: []byte("f5")}, + {name: "test/t1/1/2026-01-02/CDC00000000000000000003.json", content: []byte("f6")}, + } + s3Storage := NewTrackingMockS3Storage(files, 40*time.Millisecond) + s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) + s3Consumer.skipDownloadData = false + s3Consumer.downloadLimiter = make(chan struct{}, 2) + + newFiles1 := map[cloudstorage.DmlPathKey]fileIndexRange{ + dmlPathKey1: {indexKey: {start: 1, end: 3}}, + } + newFiles2 := map[cloudstorage.DmlPathKey]fileIndexRange{ + dmlPathKey2: {indexKey: {start: 1, end: 3}}, + } + + var wg sync.WaitGroup + errCh := make(chan error, 2) + wg.Add(2) + go func() { + defer wg.Done() + _, err := s3Consumer.downloadDMLFiles(ctx, newFiles1) + errCh <- err + }() + go func() { + defer wg.Done() + _, err := s3Consumer.downloadDMLFiles(ctx, newFiles2) + errCh <- err + }() + wg.Wait() + close(errCh) + for err := range errCh { + require.NoError(t, err) + } + + require.Greater(t, s3Storage.MaxConcurrent(), int64(1)) + require.LessOrEqual(t, s3Storage.MaxConcurrent(), int64(2)) +} + func TestS3Consumer(t *testing.T) { t.Parallel() ctx := context.Background() From cd5e26132be27768b2383ff91181ad9554a94aab Mon Sep 17 00:00:00 2001 From: Jianjun Liao Date: Sun, 1 Mar 2026 00:47:53 +0800 Subject: [PATCH 10/10] limit download concurrency Signed-off-by: Jianjun Liao --- .../consumer/consumer.go | 280 +++++++++++------- 
.../consumer/consumer_test.go | 2 +- 2 files changed, 177 insertions(+), 105 deletions(-) diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer.go b/cmd/multi-cluster-consistency-checker/consumer/consumer.go index 650b08b440..458db8c225 100644 --- a/cmd/multi-cluster-consistency-checker/consumer/consumer.go +++ b/cmd/multi-cluster-consistency-checker/consumer/consumer.go @@ -44,7 +44,9 @@ type indexRange struct { end uint64 } -const defaultGlobalDownloadConcurrencyLimit = 128 +const defaultGlobalReadConcurrencyLimit = 128 +const defaultGlobalWalkConcurrencyLimit = 64 +const defaultTableWorkerConcurrencyLimit = 256 func updateTableDMLIdxMap( tableDMLIdxMap map[cloudstorage.DmlPathKey]fileIndexKeyMap, @@ -225,9 +227,12 @@ type S3Consumer struct { dateSeparator string fileIndexWidth int tables map[string][]string - // downloadLimiter limits the total number of concurrent file downloads - // across all downloadDMLFiles calls on this consumer. - downloadLimiter chan struct{} + // readLimiter limits the total number of concurrent ReadFile calls. + readLimiter chan struct{} + // walkLimiter limits the total number of concurrent WalkDir calls. + walkLimiter chan struct{} + // tableWorkerConcurrencyLimit limits table-level goroutines in top-level flows. 
+ tableWorkerConcurrencyLimit int // skip the first round data download skipDownloadData bool @@ -247,10 +252,15 @@ func NewS3Consumer( dateSeparator: config.DateSeparatorDay.String(), fileIndexWidth: config.DefaultFileIndexWidth, tables: tables, - downloadLimiter: make( + readLimiter: make( chan struct{}, - defaultGlobalDownloadConcurrencyLimit, + defaultGlobalReadConcurrencyLimit, ), + walkLimiter: make( + chan struct{}, + defaultGlobalWalkConcurrencyLimit, + ), + tableWorkerConcurrencyLimit: defaultTableWorkerConcurrencyLimit, skipDownloadData: true, @@ -260,23 +270,42 @@ func NewS3Consumer( } } -func (c *S3Consumer) acquireDownloadSlot(ctx context.Context) error { - if c.downloadLimiter == nil { +func (c *S3Consumer) acquireReadSlot(ctx context.Context) error { + if c.readLimiter == nil { return nil } select { - case c.downloadLimiter <- struct{}{}: + case c.readLimiter <- struct{}{}: return nil case <-ctx.Done(): return errors.Trace(ctx.Err()) } } -func (c *S3Consumer) releaseDownloadSlot() { - if c.downloadLimiter == nil { +func (c *S3Consumer) releaseReadSlot() { + if c.readLimiter == nil { return } - <-c.downloadLimiter + <-c.readLimiter +} + +func (c *S3Consumer) acquireWalkSlot(ctx context.Context) error { + if c.walkLimiter == nil { + return nil + } + select { + case c.walkLimiter <- struct{}{}: + return nil + case <-ctx.Done(): + return errors.Trace(ctx.Err()) + } +} + +func (c *S3Consumer) releaseWalkSlot() { + if c.walkLimiter == nil { + return + } + <-c.walkLimiter } func (c *S3Consumer) InitializeFromCheckpoint( @@ -297,6 +326,7 @@ func (c *S3Consumer) InitializeFromCheckpoint( // Combine DML data and schema data into result result := make(map[cloudstorage.DmlPathKey]types.IncrementalData) eg, egCtx := errgroup.WithContext(ctx) + eg.SetLimit(c.tableWorkerConcurrencyLimit) for schemaTableKey, scanRange := range scanRanges { eg.Go(func() error { scanVersions, err := c.downloadSchemaFilesWithScanRange( @@ -362,32 +392,38 @@ func (c *S3Consumer) 
downloadSchemaFilesWithScanRange( VersionPath: startVersionKey, }) newVersionPaths[startSchemaKey] = startVersionKey - if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { - if endVersionKey < filePath { - return ErrWalkDirEnd + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) } - if !cloudstorage.IsSchemaFile(filePath) { - return nil - } - var schemaKey cloudstorage.SchemaPathKey - _, err := schemaKey.ParseSchemaFilePath(filePath) - if err != nil { - log.Error("failed to parse schema file path, skipping", - zap.String("path", filePath), - zap.Error(err)) - return nil - } - if schemaKey.TableVersion > startSchemaKey.TableVersion { - if _, exists := newVersionPaths[schemaKey]; !exists { - scanVersions = append(scanVersions, types.VersionKey{ - Version: schemaKey.TableVersion, - VersionPath: filePath, - }) + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + if endVersionKey < filePath { + return ErrWalkDirEnd } - newVersionPaths[schemaKey] = filePath - } - return nil - }); err != nil && !errors.Is(err, ErrWalkDirEnd) { + if !cloudstorage.IsSchemaFile(filePath) { + return nil + } + var schemaKey cloudstorage.SchemaPathKey + _, err := schemaKey.ParseSchemaFilePath(filePath) + if err != nil { + log.Error("failed to parse schema file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + if schemaKey.TableVersion > startSchemaKey.TableVersion { + if _, exists := newVersionPaths[schemaKey]; !exists { + scanVersions = append(scanVersions, types.VersionKey{ + Version: schemaKey.TableVersion, + VersionPath: filePath, + }) + } + newVersionPaths[schemaKey] = filePath + } + return nil + }) + }(); err != nil && !errors.Is(err, ErrWalkDirEnd) { return nil, errors.Trace(err) } @@ -457,28 +493,34 @@ func (c *S3Consumer) getNewFilesForSchemaPathKeyWithEndPath( // TODO: StartAfter: startDataPath, } newTableDMLIdxMap 
:= make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) - if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { - if endDataPath < filePath { - return ErrWalkDirEnd + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) } - // Try to parse DML file path if it matches the expected extension - if strings.HasSuffix(filePath, c.fileExtension) { - var dmlkey cloudstorage.DmlPathKey - fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath) - if err != nil { - log.Error("failed to parse dml file path, skipping", - zap.String("path", filePath), - zap.Error(err)) - return nil + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + if endDataPath < filePath { + return ErrWalkDirEnd } - if filePath == startDataPath { - c.tableDMLIdx.UpdateDMLIdxMapByStartPath(dmlkey, fileIdx) - } else { - updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx) + // Try to parse DML file path if it matches the expected extension + if strings.HasSuffix(filePath, c.fileExtension) { + var dmlkey cloudstorage.DmlPathKey + fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath) + if err != nil { + log.Error("failed to parse dml file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + if filePath == startDataPath { + c.tableDMLIdx.UpdateDMLIdxMapByStartPath(dmlkey, fileIdx) + } else { + updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx) + } } - } - return nil - }); err != nil && !errors.Is(err, ErrWalkDirEnd) { + return nil + }) + }(); err != nil && !errors.Is(err, ErrWalkDirEnd) { return nil, errors.Trace(err) } return c.tableDMLIdx.DiffNewTableDMLIdxMap(newTableDMLIdxMap), nil @@ -494,16 +536,7 @@ func (c *S3Consumer) downloadSchemaFiles( log.Debug("starting concurrent schema file download", zap.Int("totalSchemas", len(newVersionPaths))) for schemaPathKey, filePath := range newVersionPaths { eg.Go(func() error 
{ - content, err := c.s3Storage.ReadFile(egCtx, filePath) - if err != nil { - return errors.Annotatef(err, "failed to read schema file: %s", filePath) - } - - tableDefinition := &cloudstorage.TableDefinition{} - if err := json.Unmarshal(content, tableDefinition); err != nil { - return errors.Annotatef(err, "failed to unmarshal schema file: %s", filePath) - } - if err := c.schemaDefinitions.SetSchemaDefinition(schemaPathKey, filePath, tableDefinition); err != nil { + if err := c.readAndParseSchemaFile(egCtx, schemaPathKey, filePath); err != nil { return errors.Trace(err) } return nil @@ -515,6 +548,32 @@ func (c *S3Consumer) downloadSchemaFiles( return nil } +// readAndParseSchemaFile reads a schema file with global read concurrency control. +func (c *S3Consumer) readAndParseSchemaFile( + ctx context.Context, + schemaPathKey cloudstorage.SchemaPathKey, + filePath string, +) error { + if err := c.acquireReadSlot(ctx); err != nil { + return errors.Trace(err) + } + defer c.releaseReadSlot() + + content, err := c.s3Storage.ReadFile(ctx, filePath) + if err != nil { + return errors.Annotatef(err, "failed to read schema file: %s", filePath) + } + + tableDefinition := &cloudstorage.TableDefinition{} + if err := json.Unmarshal(content, tableDefinition); err != nil { + return errors.Annotatef(err, "failed to unmarshal schema file: %s", filePath) + } + if err := c.schemaDefinitions.SetSchemaDefinition(schemaPathKey, filePath, tableDefinition); err != nil { + return errors.Trace(err) + } + return nil +} + func (c *S3Consumer) discoverAndDownloadNewTableVersions( ctx context.Context, schema, table string, @@ -529,30 +588,36 @@ func (c *S3Consumer) discoverAndDownloadNewTableVersions( var scanVersions []types.VersionKey newVersionPaths := make(map[cloudstorage.SchemaPathKey]string) - if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { - if !cloudstorage.IsSchemaFile(filePath) { - return nil - } - var schemaKey cloudstorage.SchemaPathKey - _, err := 
schemaKey.ParseSchemaFilePath(filePath) - if err != nil { - log.Error("failed to parse schema file path, skipping", - zap.String("path", filePath), - zap.Error(err)) - return nil + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) } - version := schemaKey.TableVersion - if version > currentVersion.Version { - if _, exists := newVersionPaths[schemaKey]; !exists { - scanVersions = append(scanVersions, types.VersionKey{ - Version: version, - VersionPath: filePath, - }) + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + if !cloudstorage.IsSchemaFile(filePath) { + return nil } - newVersionPaths[schemaKey] = filePath - } - return nil - }); err != nil { + var schemaKey cloudstorage.SchemaPathKey + _, err := schemaKey.ParseSchemaFilePath(filePath) + if err != nil { + log.Error("failed to parse schema file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + version := schemaKey.TableVersion + if version > currentVersion.Version { + if _, exists := newVersionPaths[schemaKey]; !exists { + scanVersions = append(scanVersions, types.VersionKey{ + Version: version, + VersionPath: filePath, + }) + } + newVersionPaths[schemaKey] = filePath + } + return nil + }) + }(); err != nil { return nil, errors.Trace(err) } @@ -580,22 +645,28 @@ func (c *S3Consumer) getNewFilesForSchemaPathKey( newTableDMLIdxMap := make(map[cloudstorage.DmlPathKey]fileIndexKeyMap) maxFilePath := "" - if err := c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { - // Try to parse DML file path if it matches the expected extension - if strings.HasSuffix(filePath, c.fileExtension) { - var dmlkey cloudstorage.DmlPathKey - fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath) - if err != nil { - log.Error("failed to parse dml file path, skipping", - zap.String("path", filePath), - zap.Error(err)) - return nil - } - 
updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx) - maxFilePath = filePath + if err := func() error { + if err := c.acquireWalkSlot(ctx); err != nil { + return errors.Trace(err) } - return nil - }); err != nil { + defer c.releaseWalkSlot() + return c.s3Storage.WalkDir(ctx, opt, func(filePath string, size int64) error { + // Try to parse DML file path if it matches the expected extension + if strings.HasSuffix(filePath, c.fileExtension) { + var dmlkey cloudstorage.DmlPathKey + fileIdx, err := dmlkey.ParseDMLFilePath(c.dateSeparator, filePath) + if err != nil { + log.Error("failed to parse dml file path, skipping", + zap.String("path", filePath), + zap.Error(err)) + return nil + } + updateTableDMLIdxMap(newTableDMLIdxMap, dmlkey, fileIdx) + maxFilePath = filePath + } + return nil + }) + }(); err != nil { return nil, errors.Trace(err) } @@ -656,10 +727,10 @@ func (c *S3Consumer) downloadDMLFiles( eg, egCtx := errgroup.WithContext(ctx) for _, task := range tasks { eg.Go(func() error { - if err := c.acquireDownloadSlot(egCtx); err != nil { + if err := c.acquireReadSlot(egCtx); err != nil { return errors.Trace(err) } - defer c.releaseDownloadSlot() + defer c.releaseReadSlot() filePath := task.dmlPathKey.GenerateDMLFilePath( &task.fileIndex, @@ -759,6 +830,7 @@ func (c *S3Consumer) ConsumeNewFiles( var versionMu sync.Mutex maxVersionMap := make(map[types.SchemaTableKey]types.VersionKey) eg, egCtx := errgroup.WithContext(ctx) + eg.SetLimit(c.tableWorkerConcurrencyLimit) for schema, tables := range c.tables { for _, table := range tables { eg.Go(func() error { diff --git a/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go index deccdba352..28328c7597 100644 --- a/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go +++ b/cmd/multi-cluster-consistency-checker/consumer/consumer_test.go @@ -483,7 +483,7 @@ func TestDownloadDMLFilesGlobalConcurrencyLimit(t *testing.T) { s3Storage 
:= NewTrackingMockS3Storage(files, 40*time.Millisecond) s3Consumer := NewS3Consumer(s3Storage, map[string][]string{"test": {"t1"}}) s3Consumer.skipDownloadData = false - s3Consumer.downloadLimiter = make(chan struct{}, 2) + s3Consumer.readLimiter = make(chan struct{}, 2) newFiles1 := map[cloudstorage.DmlPathKey]fileIndexRange{ dmlPathKey1: {indexKey: {start: 1, end: 3}},