Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions api/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ package v1

import (
"bytes"
"context"
"encoding/json"
stdErrors "errors"
"io"
"net/http"
"strconv"
"sync/atomic"

"github.com/gin-gonic/gin"
"github.com/pingcap/log"
Expand Down Expand Up @@ -192,25 +193,41 @@ func (o *OpenAPIV1) rebalanceTables(c *gin.Context) {
// drainCapture drains all tables from a capture.
// Usage:
// curl -X PUT http://127.0.0.1:8300/api/v1/captures/drain
// TODO: Implement this API in the future, currently it is a no-op.
//
// It is kept for API v1 compatibility. In the new architecture, `capture_id` is
// treated as a node ID, and the returned `current_table_count` represents the
// remaining drain work rather than a literal table count.
func (o *OpenAPIV1) drainCapture(c *gin.Context) {
var req drainCaptureRequest
if err := c.ShouldBindJSON(&req); err != nil {
_ = c.Error(errors.ErrAPIInvalidParam.Wrap(err))
return
}
drainCaptureCounter.Add(1)
if drainCaptureCounter.Load()%10 == 0 {
log.Info("api v1 drainCapture", zap.Any("captureID", req.CaptureID), zap.Int64("currentTableCount", drainCaptureCounter.Load()))
c.JSON(http.StatusAccepted, &drainCaptureResp{
CurrentTableCount: 10,
})
} else {
log.Info("api v1 drainCapture done", zap.Any("captureID", req.CaptureID), zap.Int64("currentTableCount", drainCaptureCounter.Load()))
c.JSON(http.StatusAccepted, &drainCaptureResp{
CurrentTableCount: 0,
})

coordinator, err := o.server.GetCoordinator()
if err != nil {
_ = c.Error(err)
return
}
drainable, ok := coordinator.(interface {
DrainNode(ctx context.Context, nodeID string) (int, error)
})
if !ok {
_ = c.Error(stdErrors.New("coordinator does not support node drain"))
return
}

remaining, err := drainable.DrainNode(c.Request.Context(), req.CaptureID)
if err != nil {
_ = c.Error(err)
return
}
log.Info("api v1 drainCapture",
zap.String("captureID", req.CaptureID),
zap.Int("remaining", remaining))
c.JSON(http.StatusAccepted, &drainCaptureResp{
CurrentTableCount: remaining,
})
}
Comment on lines 200 to 231

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The drainCapture API endpoint, now fully implemented and functional, lacks authenticateMiddleware in its route registration (RegisterOpenAPIV1Routes). This is a significant security vulnerability, as it allows any unauthenticated user to trigger a node drain operation, potentially leading to a Denial of Service (DoS) by draining all nodes. It is crucial to protect the captureGroup or the specific drain route with authenticateMiddleware. Furthermore, for consistency with other error handling and to provide more structured error information to clients, it's better to use a typed error from the pkg/errors package instead of a raw string error from the standard library.

        _ = c.Error(errors.ErrInternalServerError.WithMessage("coordinator does not support node drain"))


func getV2ChangefeedConfig(changefeedConfig changefeedConfig) *v2.ChangefeedConfig {
Expand Down Expand Up @@ -258,8 +275,6 @@ type drainCaptureRequest struct {
CaptureID string `json:"capture_id"`
}

var drainCaptureCounter atomic.Int64

// drainCaptureResp is response for manual `DrainCapture`
type drainCaptureResp struct {
CurrentTableCount int `json:"current_table_count"`
Expand Down
3 changes: 2 additions & 1 deletion api/v2/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ import (
// @Router /api/v2/health [get]
func (h *OpenAPIV2) ServerHealth(c *gin.Context) {
liveness := h.server.Liveness()
if liveness != api.LivenessCaptureAlive {
// Draining is a pre-offline state and should not be treated as unhealthy.
if liveness != api.LivenessCaptureAlive && liveness != api.LivenessCaptureDraining {
err := errors.ErrClusterIsUnhealthy.FastGenByArgs()
_ = c.Error(err)
return
Expand Down
44 changes: 42 additions & 2 deletions coordinator/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (

"github.com/pingcap/log"
"github.com/pingcap/ticdc/coordinator/changefeed"
"github.com/pingcap/ticdc/coordinator/drain"
"github.com/pingcap/ticdc/coordinator/nodeliveness"
"github.com/pingcap/ticdc/coordinator/operator"
coscheduler "github.com/pingcap/ticdc/coordinator/scheduler"
"github.com/pingcap/ticdc/heartbeatpb"
Expand Down Expand Up @@ -86,6 +88,9 @@ type Controller struct {

changefeedChangeCh chan []*changefeedChange
apiLock sync.RWMutex

livenessView *nodeliveness.View
drainController *drain.Controller
}

type changefeedChange struct {
Expand Down Expand Up @@ -116,7 +121,9 @@ func NewController(
balanceInterval time.Duration,
pdClient pd.Client,
) *Controller {
mc := appcontext.GetService[messaging.MessageCenter](appcontext.MessageCenter)
changefeedDB := changefeed.NewChangefeedDB(version)
livenessView := nodeliveness.NewView(30 * time.Second)

oc := operator.NewOperatorController(selfNode, changefeedDB, backend, batchSize)
c := &Controller{
Expand All @@ -129,27 +136,38 @@ func NewController(
batchSize,
oc,
changefeedDB,
livenessView,
),
scheduler.BalanceScheduler: coscheduler.NewBalanceScheduler(
selfNode.ID.String(),
batchSize,
oc,
changefeedDB,
balanceInterval,
livenessView,
),
scheduler.DrainScheduler: coscheduler.NewDrainScheduler(
selfNode.ID.String(),
batchSize,
oc,
changefeedDB,
livenessView,
),
}),
eventCh: eventCh,
operatorController: oc,
messageCenter: appcontext.GetService[messaging.MessageCenter](appcontext.MessageCenter),
messageCenter: mc,
changefeedDB: changefeedDB,
nodeManager: appcontext.GetService[*watcher.NodeManager](watcher.NodeManagerName),
taskScheduler: threadpool.NewThreadPoolDefault(),
backend: backend,
changefeedChangeCh: changefeedChangeCh,
pdClient: pdClient,
pdClock: appcontext.GetService[pdutil.Clock](appcontext.DefaultPDClock),
livenessView: livenessView,
}
c.nodeChanged.changed = false
c.drainController = drain.NewController(mc, livenessView, changefeedDB, oc)

c.bootstrapper = bootstrap.NewBootstrapper[heartbeatpb.CoordinatorBootstrapResponse](
bootstrapperID,
Expand Down Expand Up @@ -273,6 +291,16 @@ func (c *Controller) onMessage(ctx context.Context, msg *messaging.TargetMessage
req := msg.Message[0].(*heartbeatpb.MaintainerHeartbeat)
c.handleMaintainerStatus(msg.From, req.Statuses)
}
case messaging.TypeNodeHeartbeatRequest:
if c.livenessView != nil {
req := msg.Message[0].(*heartbeatpb.NodeHeartbeat)
c.livenessView.HandleNodeHeartbeat(msg.From, req, time.Now())
}
case messaging.TypeSetNodeLivenessResponse:
if c.livenessView != nil {
resp := msg.Message[0].(*heartbeatpb.SetNodeLivenessResponse)
c.livenessView.HandleSetNodeLivenessResponse(msg.From, resp, time.Now())
}
case messaging.TypeLogCoordinatorResolvedTsResponse:
c.onLogCoordinatorReportResolvedTs(msg)
default:
Expand Down Expand Up @@ -327,6 +355,17 @@ func (c *Controller) RequestResolvedTsFromLogCoordinator(ctx context.Context, ch
}
}

// DrainNode requests draining a node and returns the remaining drain work.
//
// It is exposed for API v1 compatibility and should be safe to call repeatedly.
func (c *Controller) DrainNode(nodeID node.ID) int {
	dc := c.drainController
	if dc == nil {
		// Drain support is not wired up; report one unit of outstanding
		// work so callers keep treating the drain as unfinished.
		return 1
	}
	dc.RequestDrain(nodeID)
	return dc.Remaining(nodeID)
}

func (c *Controller) onNodeChanged(ctx context.Context) {
addedNodes, removedNodes, requests, responses := c.bootstrapper.HandleNodesChange(c.nodeManager.GetAliveNodes())
log.Info("controller detects node changed",
Expand Down Expand Up @@ -584,7 +623,8 @@ func (c *Controller) finishBootstrap(ctx context.Context, runningChangefeeds map
defer c.taskHandlerMutex.Unlock()
c.taskHandlers = append(c.taskHandlers, c.scheduler.Start(c.taskScheduler)...)
operatorControllerHandle := c.taskScheduler.Submit(c.operatorController, time.Now())
c.taskHandlers = append(c.taskHandlers, operatorControllerHandle)
drainControllerHandle := c.taskScheduler.Submit(c.drainController, time.Now())
c.taskHandlers = append(c.taskHandlers, operatorControllerHandle, drainControllerHandle)
c.initialized.Store(true)
log.Info("coordinator bootstrapped", zap.Any("nodeID", c.selfNode.ID))
}
Expand Down
8 changes: 8 additions & 0 deletions coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,14 @@ func (c *coordinator) RequestResolvedTsFromLogCoordinator(ctx context.Context, c
c.controller.RequestResolvedTsFromLogCoordinator(ctx, changefeedDisplayName)
}

// DrainNode requests draining a node and returns the remaining drain work.
//
// This is used by API v1 compatibility handler. It is intentionally not part of
// pkg/server.Coordinator interface to avoid broad interface changes.
func (c *coordinator) DrainNode(_ context.Context, nodeID string) (int, error) {
	remaining := c.controller.DrainNode(node.ID(nodeID))
	return remaining, nil
}

func (c *coordinator) sendMessages(msgs []*messaging.TargetMessage) {
for _, msg := range msgs {
err := c.mc.SendCommand(msg)
Expand Down
68 changes: 37 additions & 31 deletions coordinator/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/pprof"
"strconv"
"sync"
Expand Down Expand Up @@ -290,16 +291,28 @@ func (m *mockMaintainerManager) sendHeartbeat() {
}
}

// newTestNodeWithListener returns a node.Info bound to a freshly opened
// loopback listener. Port 0 lets the OS pick an unused port, which avoids
// collisions when tests from different packages run in parallel; the
// listener is closed automatically via t.Cleanup.
func newTestNodeWithListener(t *testing.T) (*node.Info, net.Listener) {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = listener.Close()
	})

	return node.NewInfo(listener.Addr().String(), ""), listener
}

func TestCoordinatorScheduling(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
go func() {
t.Fatal(http.ListenAndServe(":38300", mux))
}()
pprofServer := httptest.NewServer(mux)
defer pprofServer.Close()

ctx := context.Background()
info := node.NewInfo("127.0.0.1:8700", "")
Expand Down Expand Up @@ -373,22 +386,21 @@ func TestCoordinatorScheduling(t *testing.T) {
}

func TestScaleNode(t *testing.T) {
ctx := context.Background()
info := node.NewInfo("127.0.0.1:28300", "")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

info, lis1 := newTestNodeWithListener(t)
etcdClient := newMockEtcdClient(string(info.ID))
nodeManager := watcher.NewNodeManager(nil, etcdClient)
appcontext.SetService(watcher.NodeManagerName, nodeManager)
nodeManager.GetAliveNodes()[info.ID] = info
cfg := config.NewDefaultMessageCenterConfig(info.AdvertiseAddr)
mc1 := messaging.NewMessageCenter(ctx, info.ID, cfg, nil)
mc1.Run(ctx)
defer func() {
mc1.Close()
log.Info("close message center 1")
}()

appcontext.SetService(appcontext.MessageCenter, mc1)
startMaintainerNode(ctx, info, mc1, nodeManager)
n1 := startMaintainerNode(ctx, info, lis1, mc1, nodeManager)
defer n1.stop()

serviceID := "default"

Expand Down Expand Up @@ -426,23 +438,17 @@ func TestScaleNode(t *testing.T) {
}, waitTime, time.Millisecond*5)

// add two nodes
info2 := node.NewInfo("127.0.0.1:28400", "")
info2, lis2 := newTestNodeWithListener(t)
mc2 := messaging.NewMessageCenter(ctx, info2.ID, config.NewDefaultMessageCenterConfig(info2.AdvertiseAddr), nil)
mc2.Run(ctx)
defer func() {
mc2.Close()
log.Info("close message center 2")
}()
startMaintainerNode(ctx, info2, mc2, nodeManager)
info3 := node.NewInfo("127.0.0.1:28500", "")
n2 := startMaintainerNode(ctx, info2, lis2, mc2, nodeManager)
defer n2.stop()

info3, lis3 := newTestNodeWithListener(t)
mc3 := messaging.NewMessageCenter(ctx, info3.ID, config.NewDefaultMessageCenterConfig(info3.AdvertiseAddr), nil)
mc3.Run(ctx)
defer func() {
mc3.Close()
log.Info("close message center 3")
}()

startMaintainerNode(ctx, info3, mc3, nodeManager)
n3 := startMaintainerNode(ctx, info3, lis3, mc3, nodeManager)
defer n3.stop()

log.Info("Start maintainer node",
zap.Stringer("id", info3.ID),
Expand Down Expand Up @@ -488,7 +494,8 @@ func TestScaleNode(t *testing.T) {
func TestBootstrapWithUnStoppedChangefeed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
info := node.NewInfo("127.0.0.1:28301", "")

info, lis := newTestNodeWithListener(t)
etcdClient := newMockEtcdClient(string(info.ID))
nodeManager := watcher.NewNodeManager(nil, etcdClient)
appcontext.SetService(watcher.NodeManagerName, nodeManager)
Expand All @@ -498,10 +505,10 @@ func TestBootstrapWithUnStoppedChangefeed(t *testing.T) {

mc1 := messaging.NewMessageCenter(ctx, info.ID, config.NewDefaultMessageCenterConfig(info.AdvertiseAddr), nil)
mc1.Run(ctx)
defer mc1.Close()

appcontext.SetService(appcontext.MessageCenter, mc1)
mNode := startMaintainerNode(ctx, info, mc1, nodeManager)
mNode := startMaintainerNode(ctx, info, lis, mc1, nodeManager)
defer mNode.stop()

removingCf1 := &changefeed.ChangefeedMetaWrapper{
Info: &config.ChangeFeedInfo{
Expand Down Expand Up @@ -718,7 +725,9 @@ func (d *maintainNode) stop() {
}

func startMaintainerNode(ctx context.Context,
node *node.Info, mc messaging.MessageCenter,
node *node.Info,
lis net.Listener,
mc messaging.MessageCenter,
nodeManager *watcher.NodeManager,
) *maintainNode {
nodeManager.RegisterNodeChangeHandler(node.ID, mc.OnNodeChanges)
Expand All @@ -729,15 +738,12 @@ func startMaintainerNode(ctx context.Context,
grpcServer := grpc.NewServer(opts...)
mcs := messaging.NewMessageCenterServer(mc)
proto.RegisterMessageServiceServer(grpcServer, mcs)
lis, err := net.Listen("tcp", node.AdvertiseAddr)
if err != nil {
panic(err)
}
go func() {
_ = grpcServer.Serve(lis)
}()
_ = maintainerM.Run(ctx)
grpcServer.Stop()
_ = lis.Close()
}()
return &maintainNode{
cancel: cancel,
Expand Down
Loading