5 changes: 4 additions & 1 deletion pkg/cli/debug_test.go
@@ -252,7 +252,10 @@ func TestRemoveDeadReplicas(t *testing.T) {
tc := testcluster.StartTestCluster(t, 3, clusterArgs)
defer tc.Stopper().Stop(ctx)

grpcConn, err := tc.Server(0).RPCContext().GRPCDial(tc.Server(0).ServingAddr()).Connect(ctx)
grpcConn, err := tc.Server(0).RPCContext().GRPCDialNode(
tc.Server(0).ServingAddr(),
tc.Server(0).NodeID(),
).Connect(ctx)
if err != nil {
t.Fatal(err)
}
4 changes: 3 additions & 1 deletion pkg/cli/start.go
@@ -1132,7 +1132,9 @@ func getClientGRPCConn(ctx context.Context) (*grpc.ClientConn, *hlc.Clock, func(
stopper.Stop(ctx)
return nil, nil, nil, err
}
conn, err := rpcContext.GRPCDial(addr).Connect(ctx)
// We use GRPCUnvalidatedDial() here because it does not matter
// which node we are talking to.
conn, err := rpcContext.GRPCUnvalidatedDial(addr).Connect(ctx)
if err != nil {
stopper.Stop(ctx)
return nil, nil, nil, err
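For context, here is a minimal sketch of the client-side rule these call sites follow after this change. The helper, its name, and the surrounding setup are illustrative, not part of this PR: a non-zero node ID selects the validated path, and heartbeats then enforce it.

```go
package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/rpc"
	"google.golang.org/grpc"
)

// dialMaybeValidated is a hypothetical helper, not part of this PR; it
// illustrates the split the call sites above rely on.
func dialMaybeValidated(
	ctx context.Context, rpcCtx *rpc.Context, addr string, id roachpb.NodeID,
) (*grpc.ClientConn, error) {
	if id != 0 {
		// Validated path: heartbeats on this connection fail if the
		// server at addr reports a node ID other than id.
		return rpcCtx.GRPCDialNode(addr, id).Connect(ctx)
	}
	// Node ID unknown (gossip bootstrap, CLI): any node will do.
	return rpcCtx.GRPCUnvalidatedDial(addr).Connect(ctx)
}
```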
2 changes: 1 addition & 1 deletion pkg/gossip/client.go
@@ -109,7 +109,7 @@ func (c *client) startLocked(
// asynchronous from the caller's perspective, so the only effect of
// `WithBlock` here is blocking shutdown - at the time of this writing,
// that ends up making `kv` tests take twice as long.
conn, err := rpcCtx.GRPCDial(c.addr.String()).Connect(ctx)
conn, err := rpcCtx.GRPCUnvalidatedDial(c.addr.String()).Connect(ctx)
if err != nil {
return err
}
1 change: 1 addition & 0 deletions pkg/gossip/client_test.go
@@ -66,6 +66,7 @@ func startGossipAtAddr(
registry *metric.Registry,
) *Gossip {
rpcContext := newInsecureRPCContext(stopper)
rpcContext.NodeID.Set(context.TODO(), nodeID)
server := rpc.NewServer(rpcContext)
g := NewTest(nodeID, rpcContext, server, stopper, registry)
ln, err := netutil.ListenAndServeGRPC(stopper, server, addr)
5 changes: 2 additions & 3 deletions pkg/gossip/gossip_test.go
@@ -61,8 +61,7 @@ func TestGossipInfoStore(t *testing.T) {
}

// TestGossipMoveNode verifies that if a node is moved to a new address, it
// gets properly updated in gossip (including that any other node that was
// previously at that address gets removed from the cluster).
// gets properly updated in gossip.
func TestGossipMoveNode(t *testing.T) {
defer leaktest.AfterTest(t)()
stopper := stop.NewStopper()
@@ -462,7 +461,7 @@ func TestGossipNoForwardSelf(t *testing.T) {
c := newClient(log.AmbientContext{Tracer: tracing.NewTracer()}, local.GetNodeAddr(), makeMetrics())

testutils.SucceedsSoon(t, func() error {
conn, err := peer.rpcContext.GRPCDial(c.addr.String()).Connect(ctx)
conn, err := peer.rpcContext.GRPCUnvalidatedDial(c.addr.String()).Connect(ctx)
if err != nil {
return err
}
10 changes: 10 additions & 0 deletions pkg/kv/send_test.go
@@ -66,6 +66,12 @@ func TestSendToOneClient(t *testing.T) {
stopper,
&cluster.MakeTestingClusterSettings().Version,
)

// This test uses the testing function sendBatch(), which does not
// support setting the node ID on GRPCDialNode(). Disable node ID
// checks to avoid a log.Fatal.
rpcContext.TestingAllowNamedRPCToAnonymousServer = true

s := rpc.NewServer(rpcContext)
roachpb.RegisterInternalServer(s, Node(0))
ln, err := netutil.ListenAndServeGRPC(rpcContext.Stopper, s, util.TestAddr)
@@ -136,6 +142,10 @@ func TestComplexScenarios(t *testing.T) {
stopper,
&cluster.MakeTestingClusterSettings().Version,
)

// We're going to serve multiple node IDs with this one
// context. Disable node ID checks.
nodeContext.TestingAllowNamedRPCToAnonymousServer = true
nodeDialer := nodedialer.New(nodeContext, nil)

// TODO(bdarnell): the retryable flag is no longer used for RPC errors.
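These two test hooks exist because of the guard on GRPCDialNode shown in pkg/rpc/context.go below: dialing through the validated path with node ID 0 is treated as a programming error. A hedged sketch of that interaction (the helper and addr are placeholders):

```go
package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/rpc"
)

// dialAnonymously is a hypothetical test helper; addr is a placeholder.
func dialAnonymously(ctx context.Context, rpcCtx *rpc.Context, addr string) error {
	// Without this flag, GRPCDialNode(addr, 0) would log.Fatal per the
	// guard added in pkg/rpc/context.go below.
	rpcCtx.TestingAllowNamedRPCToAnonymousServer = true
	_, err := rpcCtx.GRPCDialNode(addr, 0).Connect(ctx)
	return err
}
```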
121 changes: 92 additions & 29 deletions pkg/rpc/context.go
@@ -276,10 +276,12 @@ func NewServerWithInterceptor(

s := grpc.NewServer(opts...)
RegisterHeartbeatServer(s, &HeartbeatService{
clock: ctx.LocalClock,
remoteClockMonitor: ctx.RemoteClocks,
clusterID: &ctx.ClusterID,
version: ctx.version,
clock: ctx.LocalClock,
remoteClockMonitor: ctx.RemoteClocks,
clusterID: &ctx.ClusterID,
nodeID: &ctx.NodeID,
version: ctx.version,
testingAllowNamedRPCToAnonymousServer: ctx.TestingAllowNamedRPCToAnonymousServer,
})
return s
}
@@ -298,14 +300,20 @@ type Connection struct {
initialHeartbeatDone chan struct{} // closed after first heartbeat
stopper *stop.Stopper

// remoteNodeID, when non-zero, is the expected node ID of the
// remote node and is checked during heartbeats; zero means the
// remote node ID is unknown and goes unchecked. This is constant
// throughout the lifetime of a Connection object.
remoteNodeID roachpb.NodeID

initOnce sync.Once
validatedOnce sync.Once
}

func newConnection(stopper *stop.Stopper) *Connection {
func newConnectionToNodeID(stopper *stop.Stopper, remoteNodeID roachpb.NodeID) *Connection {
c := &Connection{
initialHeartbeatDone: make(chan struct{}),
stopper: stopper,
remoteNodeID: remoteNodeID,
}
c.heartbeatResult.Store(heartbeatResult{err: ErrNotHeartbeated})
return c
@@ -372,11 +380,23 @@ type Context struct {
stats StatsHandler

ClusterID base.ClusterIDContainer
NodeID base.NodeIDContainer
version *cluster.ExposedClusterVersion

// For unittesting.
BreakerFactory func() *circuit.Breaker
testingDialOpts []grpc.DialOption

// For testing. See the comment on the same field in HeartbeatService.
TestingAllowNamedRPCToAnonymousServer bool
}

// connKey is used as key in the Context.conns map. Different remote
// node IDs get different *Connection objects, to ensure that we don't
// mis-route RPC requests.
type connKey struct {
targetAddr string
nodeID roachpb.NodeID
}

// NewContext creates an rpc Context with the supplied values.
@@ -422,7 +442,7 @@ func NewContext(
conn.dialErr = &roachpb.NodeUnavailableError{}
}
})
ctx.removeConn(k.(string), conn)
ctx.removeConn(k.(connKey), conn)
return true
})
})
@@ -439,8 +459,10 @@ func (ctx *Context) GetStatsMap() *syncmap.Map {

// GetLocalInternalClientForAddr returns the context's internal batch client
// for target, if it exists.
func (ctx *Context) GetLocalInternalClientForAddr(target string) roachpb.InternalClient {
if target == ctx.AdvertiseAddr {
func (ctx *Context) GetLocalInternalClientForAddr(
target string, nodeID roachpb.NodeID,
) roachpb.InternalClient {
if target == ctx.AdvertiseAddr && nodeID == ctx.NodeID.Get() {
return ctx.localInternalClient
}
return nil
@@ -544,15 +566,15 @@ func (ctx *Context) SetLocalInternalServer(internalServer roachpb.InternalServer
ctx.localInternalClient = internalClientAdapter{internalServer}
}

func (ctx *Context) removeConn(key string, conn *Connection) {
func (ctx *Context) removeConn(key connKey, conn *Connection) {
ctx.conns.Delete(key)
if log.V(1) {
log.Infof(ctx.masterCtx, "closing %s", key)
log.Infof(ctx.masterCtx, "closing %+v", key)
}
if grpcConn := conn.grpcConn; grpcConn != nil {
if err := grpcConn.Close(); err != nil && !grpcutil.IsClosedConnection(err) {
if log.V(1) {
log.Errorf(ctx.masterCtx, "failed to close client connection: %s", err)
log.Errorf(ctx.masterCtx, "failed to close client connection: %v", err)
}
}
}
@@ -675,11 +697,43 @@ func (ctx *Context) GRPCDialRaw(target string) (*grpc.ClientConn, <-chan struct{
return conn, dialer.redialChan, err
}

// GRPCDial calls grpc.Dial with options appropriate for the context.
func (ctx *Context) GRPCDial(target string) *Connection {
value, ok := ctx.conns.Load(target)
// GRPCUnvalidatedDial dials like GRPCDialNode but disables validation
// of the node ID between client and server. This function should only
// be used by the gossip client and CLI commands, which can talk to any
// node.
func (ctx *Context) GRPCUnvalidatedDial(target string) *Connection {
return ctx.grpcDialNodeInternal(target, 0)
}

// GRPCDialNode calls grpc.Dial with options appropriate for the context.
//
// The remoteNodeID becomes a constraint on the expected node ID of
// the remote node; this is checked during heartbeats. The caller is
// responsible for ensuring the remote node ID is known prior to using
// this function.
func (ctx *Context) GRPCDialNode(target string, remoteNodeID roachpb.NodeID) *Connection {
if remoteNodeID == 0 && !ctx.TestingAllowNamedRPCToAnonymousServer {
log.Fatalf(context.TODO(), "invalid node ID 0 in GRPCDialNode()")
}
return ctx.grpcDialNodeInternal(target, remoteNodeID)
}

func (ctx *Context) grpcDialNodeInternal(target string, remoteNodeID roachpb.NodeID) *Connection {
thisConnKey := connKey{target, remoteNodeID}
value, ok := ctx.conns.Load(thisConnKey)
if !ok {
value, _ = ctx.conns.LoadOrStore(target, newConnection(ctx.Stopper))
value, _ = ctx.conns.LoadOrStore(thisConnKey, newConnectionToNodeID(ctx.Stopper, remoteNodeID))
if remoteNodeID != 0 {
// If the first connection established at a target address is
// for a specific node ID, then we want to reuse that connection
// also for other dials (e.g. for gossip) which don't require a
// specific node ID. (We do this as an optimization to reduce
// the number of TCP connections alive between nodes. This is
// not strictly required for correctness.) This LoadOrStore will
// ensure we're registering the connection we just created for
// future use by these other dials.
_, _ = ctx.conns.LoadOrStore(connKey{target, 0}, value)
}
}

conn := value.(*Connection)
@@ -694,11 +748,11 @@ func (ctx *Context) GRPCDial(target string) *Connection {
if err != nil && !grpcutil.IsClosedConnection(err) {
log.Errorf(masterCtx, "removing connection to %s due to error: %s", target, err)
}
ctx.removeConn(target, conn)
ctx.removeConn(thisConnKey, conn)
})
}); err != nil {
conn.dialErr = err
ctx.removeConn(target, conn)
ctx.removeConn(thisConnKey, conn)
}
}
})
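To make the keying and reuse logic above concrete, here is a minimal single-goroutine sketch of what grpcDialNodeInternal's map operations amount to. It uses a plain map instead of syncmap.Map, and the `conn` type stands in for *rpc.Connection; error handling and heartbeat setup are omitted.

```go
package example

import "github.com/cockroachdb/cockroach/pkg/roachpb"

// connKey mirrors the type above; conn stands in for *rpc.Connection.
type connKey struct {
	targetAddr string
	nodeID     roachpb.NodeID
}

type conn struct{ remoteNodeID roachpb.NodeID }

// lookupOrCreate restates the Load/LoadOrStore sequence above without
// concurrency: one connection per (address, node ID) pair, with a
// validated connection doubling as the anonymous entry for its address.
func lookupOrCreate(pool map[connKey]*conn, addr string, id roachpb.NodeID) *conn {
	key := connKey{addr, id}
	if c, ok := pool[key]; ok {
		return c
	}
	c := &conn{remoteNodeID: id}
	pool[key] = c
	if id != 0 {
		// Reuse optimization: later unvalidated dials (node ID 0) to
		// this address share the validated connection.
		if _, ok := pool[connKey{addr, 0}]; !ok {
			pool[connKey{addr, 0}] = c
		}
	}
	return c
}
```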
@@ -720,7 +774,7 @@ func (ctx *Context) NewBreaker(name string) *circuit.Breaker {
// the first heartbeat.
var ErrNotHeartbeated = errors.New("not yet heartbeated")

// ConnHealth returns nil if we have an open connection to the given
// TestingConnHealth returns nil if we have an open connection to the given
// target that succeeded on its most recent heartbeat. Otherwise, it
// kicks off a connection attempt (unless one is already in progress
// or we are in a backoff state) and returns an error (typically
@@ -729,27 +783,26 @@ var ErrNotHeartbeated = errors.New("not yet heartbeated")
// error will be returned. This method should therefore be used to
// prioritize among a list of candidate nodes, but not to filter out
// "unhealthy" nodes.
func (ctx *Context) ConnHealth(target string) error {
if ctx.GetLocalInternalClientForAddr(target) != nil {
//
// This is used in tests only; in clusters use (*Dialer).ConnHealth()
// instead, which automates the address resolution.
//
// TODO(knz): remove this altogether. Use the dialer in all cases.
func (ctx *Context) TestingConnHealth(target string, nodeID roachpb.NodeID) error {
if ctx.GetLocalInternalClientForAddr(target, nodeID) != nil {
// The local server is always considered healthy.
return nil
}
conn := ctx.GRPCDial(target)
conn := ctx.GRPCDialNode(target, nodeID)
return conn.Health()
}

func (ctx *Context) runHeartbeat(
conn *Connection, target string, redialChan <-chan struct{},
) error {
maxOffset := ctx.LocalClock.MaxOffset()
clusterID := ctx.ClusterID.Get()
maxOffsetNanos := maxOffset.Nanoseconds()

request := PingRequest{
Addr: ctx.Addr,
MaxOffsetNanos: maxOffset.Nanoseconds(),
ClusterID: &clusterID,
ServerVersion: ctx.version.ServerVersion,
}
heartbeatClient := NewHeartbeatClient(conn.grpcConn)

var heartbeatTimer timeutil.Timer
@@ -768,14 +821,24 @@ func (ctx *Context) runHeartbeat(
heartbeatTimer.Read = true
}

// We re-mint the PingRequest to pick up any asynchronous update to clusterID.
clusterID := ctx.ClusterID.Get()
request := &PingRequest{
Addr: ctx.Addr,
MaxOffsetNanos: maxOffsetNanos,
ClusterID: &clusterID,
NodeID: conn.remoteNodeID,
ServerVersion: ctx.version.ServerVersion,
}

var response *PingResponse
sendTime := ctx.LocalClock.PhysicalTime()
err := contextutil.RunWithTimeout(ctx.masterCtx, "rpc heartbeat", ctx.heartbeatTimeout,
func(goCtx context.Context) error {
// NB: We want the request to fail-fast (the default), otherwise we won't
// be notified of transport failures.
var err error
response, err = heartbeatClient.Ping(goCtx, &request)
response, err = heartbeatClient.Ping(goCtx, request)
return err
})

Expand Down
Loading