
Commit d999e24

server,rpc: validate node IDs in RPC heartbeats
Prior to this patch, it was possible for an RPC client to dial a node ID and get a connection to another node instead. This is because the mapping of node ID -> address may be stale, and a different node could take over the address of the intended node from "under" the dialer. (See the previous commit for a scenario.)

This happened to be "safe" in many cases where it matters, because:

- RPC requests for distSQL are OK with being served on a different node than intended (with a potential performance drop);
- RPC requests to the KV layer are OK with being served on a different node than intended (they would be routed underneath);
- RPC requests to the storage layer are rejected by the remote node because the store ID in the request would not match.

However, this safety is largely accidental, and we should not work under the assumption that every RPC request is safe to be mis-routed. (In fact, we have not audited all the RPC endpoints and cannot establish that this safety holds throughout.)

This patch works to prevent these mis-routings by checking the intended node ID during RPC heartbeats (including the initial heartbeat), whenever the intended node ID is known. A new API `GRPCDialNode()` is introduced to establish such validated connections.

Release note (bug fix): CockroachDB now performs fewer attempts to communicate with the wrong node when a node is restarted with another node's address.
1 parent 295b6ae commit d999e24
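
For orientation before the diff: a minimal sketch of how a call site might use the new API. The helper `dialValidated` is hypothetical and not part of this patch; `GRPCDialNode` and `Connect` are the functions introduced/used below.

	package example

	import (
		"context"

		"github.com/cockroachdb/cockroach/pkg/roachpb"
		"github.com/cockroachdb/cockroach/pkg/rpc"
		"google.golang.org/grpc"
	)

	// dialValidated dials addr and insists that the process answering
	// there really is the node we expect. Connect blocks on the initial
	// heartbeat, which now carries the expected node ID; the server
	// rejects the heartbeat on a mismatch, so a stale address mapping
	// surfaces as a connection error instead of a mis-routed RPC.
	func dialValidated(
		ctx context.Context, rpcCtx *rpc.Context, addr string, nodeID roachpb.NodeID,
	) (*grpc.ClientConn, error) {
		return rpcCtx.GRPCDialNode(addr, nodeID).Connect(ctx)
	}

A plain rpcCtx.GRPCDial(addr) keeps the old, unvalidated behavior (node ID 0), which remains necessary for dials where the remote node ID is not yet known, such as initial gossip.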


18 files changed: +326 −78 lines


pkg/gossip/client_test.go

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ func startGossipAtAddr(
 	registry *metric.Registry,
 ) *Gossip {
 	rpcContext := newInsecureRPCContext(stopper)
+	rpcContext.NodeID.Set(context.TODO(), nodeID)
 	server := rpc.NewServer(rpcContext)
 	g := NewTest(nodeID, rpcContext, server, stopper, registry)
 	ln, err := netutil.ListenAndServeGRPC(stopper, server, addr)

pkg/gossip/gossip_test.go

Lines changed: 1 addition & 2 deletions
@@ -61,8 +61,7 @@ func TestGossipInfoStore(t *testing.T) {
 }

 // TestGossipMoveNode verifies that if a node is moved to a new address, it
-// gets properly updated in gossip (including that any other node that was
-// previously at that address gets removed from the cluster).
+// gets properly updated in gossip.
 func TestGossipMoveNode(t *testing.T) {
 	defer leaktest.AfterTest(t)()
 	stopper := stop.NewStopper()

pkg/rpc/context.go

Lines changed: 64 additions & 17 deletions
@@ -253,6 +253,7 @@ func NewServerWithInterceptor(
 		clock:              ctx.LocalClock,
 		remoteClockMonitor: ctx.RemoteClocks,
 		clusterID:          &ctx.ClusterID,
+		nodeID:             &ctx.NodeID,
 		version:            ctx.version,
 	})
 	return s
@@ -272,14 +273,20 @@ type Connection struct {
 	initialHeartbeatDone chan struct{} // closed after first heartbeat
 	stopper              *stop.Stopper

+	// remoteNodeID implies checking the remote node ID. 0 when unknown,
+	// non-zero to check with remote node. This is constant throughout
+	// the lifetime of a Connection object.
+	remoteNodeID roachpb.NodeID
+
 	initOnce      sync.Once
 	validatedOnce sync.Once
 }

-func newConnection(stopper *stop.Stopper) *Connection {
+func newConnectionToNodeID(stopper *stop.Stopper, remoteNodeID roachpb.NodeID) *Connection {
 	c := &Connection{
 		initialHeartbeatDone: make(chan struct{}),
 		stopper:              stopper,
+		remoteNodeID:         remoteNodeID,
 	}
 	c.heartbeatResult.Store(heartbeatResult{err: ErrNotHeartbeated})
 	return c
@@ -346,13 +353,22 @@ type Context struct {
 	stats StatsHandler

 	ClusterID base.ClusterIDContainer
+	NodeID    base.NodeIDContainer
 	version   *cluster.ExposedClusterVersion

 	// For unittesting.
 	BreakerFactory  func() *circuit.Breaker
 	testingDialOpts []grpc.DialOption
 }

+// connKey is used as key in the Context.conns map. Different remote
+// node IDs get different *Connection objects, to ensure that we don't
+// mis-route RPC requests.
+type connKey struct {
+	targetAddr string
+	nodeID     roachpb.NodeID
+}
+
 // NewContext creates an rpc Context with the supplied values.
 func NewContext(
 	ambient log.AmbientContext,
@@ -396,7 +412,7 @@ func NewContext(
 				conn.dialErr = &roachpb.NodeUnavailableError{}
 			}
 		})
-		ctx.removeConn(k.(string), conn)
+		ctx.removeConn(k.(connKey), conn)
 		return true
 	})
 })
@@ -518,15 +534,15 @@ func (ctx *Context) SetLocalInternalServer(internalServer roachpb.InternalServer
 	ctx.localInternalClient = internalClientAdapter{internalServer}
 }

-func (ctx *Context) removeConn(key string, conn *Connection) {
+func (ctx *Context) removeConn(key connKey, conn *Connection) {
 	ctx.conns.Delete(key)
 	if log.V(1) {
-		log.Infof(ctx.masterCtx, "closing %s", key)
+		log.Infof(ctx.masterCtx, "closing %+v", key)
 	}
 	if grpcConn := conn.grpcConn; grpcConn != nil {
 		if err := grpcConn.Close(); err != nil && !grpcutil.IsClosedConnection(err) {
 			if log.V(1) {
-				log.Errorf(ctx.masterCtx, "failed to close client connection: %s", err)
+				log.Errorf(ctx.masterCtx, "failed to close client connection: %v", err)
 			}
 		}
 	}
@@ -650,10 +666,35 @@ func (ctx *Context) GRPCDialRaw(target string) (*grpc.ClientConn, <-chan struct{
 }

 // GRPCDial calls grpc.Dial with options appropriate for the context.
+//
+// It does not require validation of the node ID between client and server:
+// if a connection existed already with some node ID requirement, that
+// requirement will remain; if no connection existed yet,
+// a new one is created without a node ID requirement.
 func (ctx *Context) GRPCDial(target string) *Connection {
-	value, ok := ctx.conns.Load(target)
+	return ctx.GRPCDialNode(target, 0)
+}
+
+// GRPCDialNode calls grpc.Dial with options appropriate for the context.
+//
+// The remoteNodeID, if non-zero, becomes a constraint on the expected
+// node ID of the remote node; this is checked during heartbeats.
+func (ctx *Context) GRPCDialNode(target string, remoteNodeID roachpb.NodeID) *Connection {
+	thisConnKey := connKey{target, remoteNodeID}
+	value, ok := ctx.conns.Load(thisConnKey)
 	if !ok {
-		value, _ = ctx.conns.LoadOrStore(target, newConnection(ctx.Stopper))
+		value, _ = ctx.conns.LoadOrStore(thisConnKey, newConnectionToNodeID(ctx.Stopper, remoteNodeID))
+		if remoteNodeID != 0 {
+			// If the first connection established at a target address is
+			// for a specific node ID, then we want to reuse that connection
+			// also for other dials (eg for gossip) which don't require a
+			// specific node ID. (We do this as an optimization to reduce
+			// the number of TCP connections alive between nodes. This is
+			// not strictly required for correctness.) This LoadOrStore will
+			// ensure we're registering the connection we just created for
+			// future use by these other dials.
+			_, _ = ctx.conns.LoadOrStore(connKey{target, 0}, value)
+		}
 	}

 	conn := value.(*Connection)
@@ -668,11 +709,11 @@ func (ctx *Context) GRPCDial(target string) *Connection {
 			if err != nil && !grpcutil.IsClosedConnection(err) {
 				log.Errorf(masterCtx, "removing connection to %s due to error: %s", target, err)
 			}
-			ctx.removeConn(target, conn)
+			ctx.removeConn(thisConnKey, conn)
 		})
 	}); err != nil {
 		conn.dialErr = err
-		ctx.removeConn(target, conn)
+		ctx.removeConn(thisConnKey, conn)
 	}
 }
})
@@ -703,6 +744,9 @@ var ErrNotHeartbeated = errors.New("not yet heartbeated")
 // error will be returned. This method should therefore be used to
 // prioritize among a list of candidate nodes, but not to filter out
 // "unhealthy" nodes.
+//
+// This is used in tests only; in clusters use (*Dialer).ConnHealth()
+// instead which validates the node ID.
 func (ctx *Context) ConnHealth(target string) error {
 	if ctx.GetLocalInternalClientForAddr(target) != nil {
 		// The local server is always considered healthy.
@@ -716,14 +760,8 @@ func (ctx *Context) runHeartbeat(
 	conn *Connection, target string, redialChan <-chan struct{},
 ) error {
 	maxOffset := ctx.LocalClock.MaxOffset()
-	clusterID := ctx.ClusterID.Get()
+	maxOffsetNanos := maxOffset.Nanoseconds()

-	request := PingRequest{
-		Addr:           ctx.Addr,
-		MaxOffsetNanos: maxOffset.Nanoseconds(),
-		ClusterID:      &clusterID,
-		ServerVersion:  ctx.version.ServerVersion,
-	}
 	heartbeatClient := NewHeartbeatClient(conn.grpcConn)

 	var heartbeatTimer timeutil.Timer
@@ -748,9 +786,18 @@ func (ctx *Context) runHeartbeat(
 			goCtx, cancel = context.WithTimeout(goCtx, hbTimeout)
 		}
 		sendTime := ctx.LocalClock.PhysicalTime()
+		// We re-mint the PingRequest to pick up any asynchronous update to clusterID.
+		clusterID := ctx.ClusterID.Get()
+		request := &PingRequest{
+			Addr:           ctx.Addr,
+			MaxOffsetNanos: maxOffsetNanos,
+			ClusterID:      &clusterID,
+			NodeID:         conn.remoteNodeID,
+			ServerVersion:  ctx.version.ServerVersion,
+		}
 		// NB: We want the request to fail-fast (the default), otherwise we won't
 		// be notified of transport failures.
-		response, err := heartbeatClient.Ping(goCtx, &request)
+		response, err := heartbeatClient.Ping(goCtx, request)
 		if cancel != nil {
 			cancel()
 		}
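
A note on the cache semantics that GRPCDialNode introduces above. This is a sketch, not part of the diff; it assumes both dials go through the same rpc.Context and target address:

	c1 := rpcCtx.GRPCDialNode(addr, 5) // stored under {addr, 5} and, per the optimization above, also under {addr, 0}
	c2 := rpcCtx.GRPCDial(addr)        // finds {addr, 0} and returns the same *Connection as c1

In the reverse order the dials end up on distinct connections: GRPCDial first stores an unvalidated connection under {addr, 0}; a later GRPCDialNode(addr, 5) misses on {addr, 5}, creates a validated connection there, and its LoadOrStore on {addr, 0} leaves the existing unvalidated entry in place.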

pkg/rpc/context_test.go

Lines changed: 51 additions & 0 deletions
@@ -93,6 +93,7 @@ func TestHeartbeatCB(t *testing.T) {
 		clock:              clock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})

@@ -341,6 +342,7 @@ func TestHeartbeatHealthTransport(t *testing.T) {
 		clock:              clock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})

@@ -515,6 +517,7 @@ func TestOffsetMeasurement(t *testing.T) {
 		clock:              serverClock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})

@@ -682,6 +685,7 @@ func TestRemoteOffsetUnhealthy(t *testing.T) {
 		clock:              clock,
 		remoteClockMonitor: nodeCtxs[i].ctx.RemoteClocks,
 		clusterID:          &nodeCtxs[i].ctx.ClusterID,
+		nodeID:             &nodeCtxs[i].ctx.NodeID,
 		version:            nodeCtxs[i].ctx.version,
 	})
 	ln, err := netutil.ListenAndServeGRPC(nodeCtxs[i].ctx.Stopper, s, util.TestAddr)
@@ -829,6 +833,7 @@ func TestGRPCKeepaliveFailureFailsInflightRPCs(t *testing.T) {
 			clock:              clock,
 			remoteClockMonitor: serverCtx.RemoteClocks,
 			clusterID:          &serverCtx.ClusterID,
+			nodeID:             &serverCtx.NodeID,
 			version:            serverCtx.version,
 		},
 		interval: msgInterval,
@@ -1011,6 +1016,50 @@ func TestClusterIDMismatch(t *testing.T) {
 	wg.Wait()
 }

+func TestNodeIDMismatch(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	stopper := stop.NewStopper()
+	defer stopper.Stop(context.TODO())
+
+	clock := hlc.NewClock(timeutil.Unix(0, 20).UnixNano, time.Nanosecond)
+	serverCtx := newTestContext(clock, stopper)
+	uuid1 := uuid.MakeV4()
+	serverCtx.ClusterID.Set(context.TODO(), uuid1)
+	serverCtx.NodeID.Set(context.TODO(), 1)
+	s := newTestServer(t, serverCtx)
+	RegisterHeartbeatServer(s, &HeartbeatService{
+		clock:              clock,
+		remoteClockMonitor: serverCtx.RemoteClocks,
+		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
+		version:            serverCtx.version,
+	})
+
+	ln, err := netutil.ListenAndServeGRPC(serverCtx.Stopper, s, util.TestAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	remoteAddr := ln.Addr().String()
+
+	clientCtx := newTestContext(clock, stopper)
+	clientCtx.ClusterID.Set(context.TODO(), uuid1)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			_, err := clientCtx.GRPCDialNode(remoteAddr, 2).Connect(context.Background())
+			expected := "initial connection heartbeat failed.*doesn't match server node ID"
+			if !testutils.IsError(err, expected) {
+				t.Errorf("expected %s error, got %v", expected, err)
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+}
+
 func setVersion(c *Context, v roachpb.Version) error {
 	settings := cluster.MakeClusterSettings(v, v)
 	cv := cluster.ClusterVersion{Version: v}
@@ -1048,6 +1097,7 @@ func TestVersionCheckBidirectional(t *testing.T) {
 	clock := hlc.NewClock(timeutil.Unix(0, 20).UnixNano, time.Nanosecond)
 	serverCtx := newTestContext(clock, stopper)
 	serverCtx.ClusterID.Set(context.TODO(), uuid.MakeV4())
+	serverCtx.NodeID.Set(context.TODO(), 1)
 	if err := setVersion(serverCtx, td.serverVersion); err != nil {
 		t.Fatal(err)
 	}
@@ -1056,6 +1106,7 @@ func TestVersionCheckBidirectional(t *testing.T) {
 		clock:              clock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})

pkg/rpc/heartbeat.go

Lines changed: 36 additions & 0 deletions
@@ -23,6 +23,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
 	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
 	"github.com/cockroachdb/cockroach/pkg/util/hlc"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/cockroachdb/cockroach/pkg/util/uuid"
 	"github.com/pkg/errors"
@@ -48,6 +49,7 @@ type HeartbeatService struct {
 	// shared by rpc clients, to keep track of remote clock measurements.
 	remoteClockMonitor *RemoteClockMonitor
 	clusterID          *base.ClusterIDContainer
+	nodeID             *base.NodeIDContainer
 	version            *cluster.ExposedClusterVersion
 }

@@ -74,18 +76,52 @@ func checkVersion(
 	return nil
 }

+// TestingAllowNamedRPCToAnonymousServer, when called (in tests),
+// disables errors in case a heartbeat requests a specific node ID but
+// the remote node doesn't have a node ID yet. This testing knob is
+// currently used by the multiTestContext which does not suitably
+// populate separate node IDs for each heartbeat service.
+// The returned callback should be called to cancel the effect.
+func TestingAllowNamedRPCToAnonymousServer() func() {
+	old := testingAllowNamedRPCToAnonymousServer
+	testingAllowNamedRPCToAnonymousServer = true
+	return func() { testingAllowNamedRPCToAnonymousServer = old }
+}
+
+var testingAllowNamedRPCToAnonymousServer bool
+
 // Ping echos the contents of the request to the response, and returns the
 // server's current clock value, allowing the requester to measure its clock.
 // The requester should also estimate its offset from this server along
 // with the requester's address.
 func (hs *HeartbeatService) Ping(ctx context.Context, args *PingRequest) (*PingResponse, error) {
+	if log.V(2) {
+		log.Infof(ctx, "received heartbeat: %+v vs local cluster %+v node %+v", args, hs.clusterID, hs.nodeID)
+	}
 	// Check that cluster IDs match.
 	clusterID := hs.clusterID.Get()
 	if args.ClusterID != nil && *args.ClusterID != uuid.Nil && clusterID != uuid.Nil &&
 		*args.ClusterID != clusterID {
 		return nil, errors.Errorf(
 			"client cluster ID %q doesn't match server cluster ID %q", args.ClusterID, clusterID)
 	}
+	// Check that node IDs match.
+	var nodeID roachpb.NodeID
+	if hs.nodeID != nil {
+		nodeID = hs.nodeID.Get()
+	}
+	if args.NodeID != 0 && (!testingAllowNamedRPCToAnonymousServer || nodeID != 0) && args.NodeID != nodeID {
+		// If nodeID != 0, the situation is clear (we are checking that
+		// the other side is talking to the right node).
+		//
+		// If nodeID == 0 this means that this node (serving the
+		// heartbeat) doesn't have a node ID yet. Then we can't serve
+		// connections for other nodes that want a specific node ID,
+		// however we can still serve connections that don't need a node
+		// ID, e.g. during initial gossip.
+		return nil, errors.Errorf(
+			"client requested node ID %d doesn't match server node ID %d", args.NodeID, nodeID)
+	}

 	// Check version compatibility.
 	if err := checkVersion(hs.version, args.ServerVersion); err != nil {
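
To make the acceptance rule above easier to scan, here is a standalone restatement. The function `checkNodeID` is not in the patch; it mirrors the predicate in Ping, with the testing knob passed in as a plain bool:

	// checkNodeID restates the validation in Ping: a client that requested
	// no node ID (0) is always accepted; a named request is accepted only
	// if it matches the server's node ID, or if the server has no node ID
	// yet and the testing knob allows named-to-anonymous connections.
	func checkNodeID(requested, serverNodeID roachpb.NodeID, allowAnonymous bool) error {
		if requested == 0 {
			return nil // unvalidated dial, e.g. initial gossip
		}
		if serverNodeID == 0 && allowAnonymous {
			return nil // testing-only escape hatch
		}
		if requested != serverNodeID {
			return errors.Errorf(
				"client requested node ID %d doesn't match server node ID %d",
				requested, serverNodeID)
		}
		return nil
	}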
