|
1 | 1 | package consul
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bufio" |
4 | 5 | "bytes"
|
| 6 | + "crypto/x509" |
5 | 7 | "encoding/binary"
|
6 | 8 | "errors"
|
| 9 | + "fmt" |
| 10 | + "io" |
| 11 | + "io/ioutil" |
7 | 12 | "math"
|
8 | 13 | "net"
|
9 | 14 | "os"
|
| 15 | + "path/filepath" |
10 | 16 | "strings"
|
11 | 17 | "sync"
|
12 | 18 | "testing"
|
13 | 19 | "time"
|
14 | 20 |
|
| 21 | + "github.com/hashicorp/go-hclog" |
| 22 | + "github.com/hashicorp/go-memdb" |
| 23 | + "github.com/hashicorp/go-msgpack/codec" |
| 24 | + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" |
| 25 | + "github.com/hashicorp/raft" |
| 26 | + "github.com/stretchr/testify/assert" |
| 27 | + "github.com/stretchr/testify/require" |
| 28 | + |
| 29 | + "github.com/hashicorp/consul/agent/connect" |
| 30 | + |
15 | 31 | "github.com/hashicorp/consul/acl"
|
16 | 32 | "github.com/hashicorp/consul/agent/consul/state"
|
17 | 33 | "github.com/hashicorp/consul/agent/pool"
|
18 | 34 | "github.com/hashicorp/consul/agent/structs"
|
19 | 35 | tokenStore "github.com/hashicorp/consul/agent/token"
|
20 | 36 | "github.com/hashicorp/consul/api"
|
| 37 | + "github.com/hashicorp/consul/sdk/testutil" |
21 | 38 | "github.com/hashicorp/consul/sdk/testutil/retry"
|
22 | 39 | "github.com/hashicorp/consul/testrpc"
|
23 |
| - "github.com/hashicorp/go-memdb" |
24 |
| - msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" |
25 |
| - "github.com/stretchr/testify/assert" |
26 |
| - "github.com/stretchr/testify/require" |
| 40 | + "github.com/hashicorp/consul/tlsutil" |
27 | 41 | )
|
28 | 42 |
|
29 | 43 | func TestRPC_NoLeader_Fail(t *testing.T) {
|
@@ -648,10 +662,10 @@ func TestRPC_RPCMaxConnsPerClient(t *testing.T) {
|
648 | 662 | magicByte pool.RPCType
|
649 | 663 | tlsEnabled bool
|
650 | 664 | }{
|
651 |
| - {"RPC", pool.RPCMultiplexV2, false}, |
652 |
| - {"RPC TLS", pool.RPCMultiplexV2, true}, |
653 |
| - {"Raft", pool.RPCRaft, false}, |
654 |
| - {"Raft TLS", pool.RPCRaft, true}, |
| 665 | + {"RPC v2", pool.RPCMultiplexV2, false}, |
| 666 | + {"RPC v2 TLS", pool.RPCMultiplexV2, true}, |
| 667 | + {"RPC", pool.RPCConsul, false}, |
| 668 | + {"RPC TLS", pool.RPCConsul, true}, |
655 | 669 | }
|
656 | 670 |
|
657 | 671 | for _, tc := range cases {
|
@@ -913,3 +927,262 @@ func TestRPC_LocalTokenStrippedOnForward(t *testing.T) {
|
913 | 927 | require.NoError(t, err)
|
914 | 928 | require.Equal(t, localToken2.SecretID, arg.WriteRequest.Token, "token should not be stripped")
|
915 | 929 | }
|
| 930 | + |
| 931 | +func TestRPC_AuthorizeRaftRPC(t *testing.T) { |
| 932 | + caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "consul"}) |
| 933 | + require.NoError(t, err) |
| 934 | + |
| 935 | + dir := testutil.TempDir(t, "certs") |
| 936 | + err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600) |
| 937 | + require.NoError(t, err) |
| 938 | + |
| 939 | + newCert := func(t *testing.T, caPEM, pk, node, name string) { |
| 940 | + t.Helper() |
| 941 | + |
| 942 | + signer, err := tlsutil.ParseSigner(pk) |
| 943 | + require.NoError(t, err) |
| 944 | + |
| 945 | + pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{ |
| 946 | + Signer: signer, |
| 947 | + CA: caPEM, |
| 948 | + Name: name, |
| 949 | + Days: 5, |
| 950 | + DNSNames: []string{node + "." + name, name, "localhost"}, |
| 951 | + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, |
| 952 | + }) |
| 953 | + require.NoError(t, err) |
| 954 | + |
| 955 | + err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600) |
| 956 | + require.NoError(t, err) |
| 957 | + err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600) |
| 958 | + require.NoError(t, err) |
| 959 | + } |
| 960 | + |
| 961 | + newCert(t, caPEM, pk, "srv1", "server.dc1.consul") |
| 962 | + |
| 963 | + _, connectCApk, err := connect.GeneratePrivateKey() |
| 964 | + require.NoError(t, err) |
| 965 | + |
| 966 | + _, srv := testServerWithConfig(t, func(c *Config) { |
| 967 | + c.Domain = "consul." // consul. is the default value in agent/config |
| 968 | + c.CAFile = filepath.Join(dir, "ca.pem") |
| 969 | + c.CertFile = filepath.Join(dir, "srv1-server.dc1.consul.pem") |
| 970 | + c.KeyFile = filepath.Join(dir, "srv1-server.dc1.consul.key") |
| 971 | + c.VerifyIncoming = true |
| 972 | + c.VerifyServerHostname = true |
| 973 | + // Enable Auto-Encrypt so that Conenct CA roots are added to the |
| 974 | + // tlsutil.Configurator. |
| 975 | + c.AutoEncryptAllowTLS = true |
| 976 | + c.CAConfig = &structs.CAConfiguration{ |
| 977 | + ClusterID: connect.TestClusterID, |
| 978 | + Provider: structs.ConsulCAProvider, |
| 979 | + Config: map[string]interface{}{"PrivateKey": connectCApk}, |
| 980 | + } |
| 981 | + }) |
| 982 | + defer srv.Shutdown() |
| 983 | + |
| 984 | + // Wait for ConnectCA initiation to complete. |
| 985 | + retry.Run(t, func(r *retry.R) { |
| 986 | + _, root := srv.caManager.getCAProvider() |
| 987 | + if root == nil { |
| 988 | + r.Fatal("ConnectCA root is still nil") |
| 989 | + } |
| 990 | + }) |
| 991 | + |
| 992 | + useTLSByte := func(t *testing.T, c *tlsutil.Configurator) net.Conn { |
| 993 | + wrapper := tlsutil.SpecificDC("dc1", c.OutgoingRPCWrapper()) |
| 994 | + tlsEnabled := func(_ raft.ServerAddress) bool { |
| 995 | + return true |
| 996 | + } |
| 997 | + |
| 998 | + rl := NewRaftLayer(nil, nil, wrapper, tlsEnabled) |
| 999 | + conn, err := rl.Dial(raft.ServerAddress(srv.Listener.Addr().String()), 100*time.Millisecond) |
| 1000 | + require.NoError(t, err) |
| 1001 | + return conn |
| 1002 | + } |
| 1003 | + |
| 1004 | + useNativeTLS := func(t *testing.T, c *tlsutil.Configurator) net.Conn { |
| 1005 | + wrapper := c.OutgoingALPNRPCWrapper() |
| 1006 | + dialer := &net.Dialer{Timeout: 100 * time.Millisecond} |
| 1007 | + |
| 1008 | + rawConn, err := dialer.Dial("tcp", srv.Listener.Addr().String()) |
| 1009 | + require.NoError(t, err) |
| 1010 | + |
| 1011 | + tlsConn, err := wrapper("dc1", "srv1", pool.ALPN_RPCRaft, rawConn) |
| 1012 | + require.NoError(t, err) |
| 1013 | + return tlsConn |
| 1014 | + } |
| 1015 | + |
| 1016 | + setupAgentTLSCert := func(name string) func(t *testing.T) string { |
| 1017 | + return func(t *testing.T) string { |
| 1018 | + newCert(t, caPEM, pk, "node1", name) |
| 1019 | + return filepath.Join(dir, "node1-"+name) |
| 1020 | + } |
| 1021 | + } |
| 1022 | + |
| 1023 | + setupConnectCACert := func(name string) func(t *testing.T) string { |
| 1024 | + return func(t *testing.T) string { |
| 1025 | + _, caRoot := srv.caManager.getCAProvider() |
| 1026 | + newCert(t, caRoot.RootCert, connectCApk, "node1", name) |
| 1027 | + return filepath.Join(dir, "node1-"+name) |
| 1028 | + } |
| 1029 | + } |
| 1030 | + |
| 1031 | + type testCase struct { |
| 1032 | + name string |
| 1033 | + conn func(t *testing.T, c *tlsutil.Configurator) net.Conn |
| 1034 | + setupCert func(t *testing.T) string |
| 1035 | + expectError bool |
| 1036 | + } |
| 1037 | + |
| 1038 | + run := func(t *testing.T, tc testCase) { |
| 1039 | + certPath := tc.setupCert(t) |
| 1040 | + |
| 1041 | + cfg := tlsutil.Config{ |
| 1042 | + VerifyOutgoing: true, |
| 1043 | + VerifyServerHostname: true, |
| 1044 | + CAFile: filepath.Join(dir, "ca.pem"), |
| 1045 | + CertFile: certPath + ".pem", |
| 1046 | + KeyFile: certPath + ".key", |
| 1047 | + Domain: "consul", |
| 1048 | + } |
| 1049 | + c, err := tlsutil.NewConfigurator(cfg, hclog.New(nil)) |
| 1050 | + require.NoError(t, err) |
| 1051 | + |
| 1052 | + _, err = doRaftRPC(tc.conn(t, c), srv.config.NodeName) |
| 1053 | + if tc.expectError { |
| 1054 | + if !isConnectionClosedError(err) { |
| 1055 | + t.Fatalf("expected a connection closed error, got: %v", err) |
| 1056 | + } |
| 1057 | + return |
| 1058 | + } |
| 1059 | + require.NoError(t, err) |
| 1060 | + } |
| 1061 | + |
| 1062 | + var testCases = []testCase{ |
| 1063 | + { |
| 1064 | + name: "TLS byte with client cert", |
| 1065 | + setupCert: setupAgentTLSCert("client.dc1.consul"), |
| 1066 | + conn: useTLSByte, |
| 1067 | + expectError: true, |
| 1068 | + }, |
| 1069 | + { |
| 1070 | + name: "TLS byte with server cert in different DC", |
| 1071 | + setupCert: setupAgentTLSCert("server.dc2.consul"), |
| 1072 | + conn: useTLSByte, |
| 1073 | + expectError: true, |
| 1074 | + }, |
| 1075 | + { |
| 1076 | + name: "TLS byte with server cert in same DC", |
| 1077 | + setupCert: setupAgentTLSCert("server.dc1.consul"), |
| 1078 | + conn: useTLSByte, |
| 1079 | + }, |
| 1080 | + { |
| 1081 | + name: "TLS byte with ConnectCA leaf cert", |
| 1082 | + setupCert: setupConnectCACert("server.dc1.consul"), |
| 1083 | + conn: useTLSByte, |
| 1084 | + expectError: true, |
| 1085 | + }, |
| 1086 | + { |
| 1087 | + name: "native TLS with client cert", |
| 1088 | + setupCert: setupAgentTLSCert("client.dc1.consul"), |
| 1089 | + conn: useNativeTLS, |
| 1090 | + expectError: true, |
| 1091 | + }, |
| 1092 | + { |
| 1093 | + name: "native TLS with server cert in different DC", |
| 1094 | + setupCert: setupAgentTLSCert("server.dc2.consul"), |
| 1095 | + conn: useNativeTLS, |
| 1096 | + expectError: true, |
| 1097 | + }, |
| 1098 | + { |
| 1099 | + name: "native TLS with server cert in same DC", |
| 1100 | + setupCert: setupAgentTLSCert("server.dc1.consul"), |
| 1101 | + conn: useNativeTLS, |
| 1102 | + }, |
| 1103 | + { |
| 1104 | + name: "native TLS with ConnectCA leaf cert", |
| 1105 | + setupCert: setupConnectCACert("server.dc1.consul"), |
| 1106 | + conn: useNativeTLS, |
| 1107 | + expectError: true, |
| 1108 | + }, |
| 1109 | + } |
| 1110 | + |
| 1111 | + for _, tc := range testCases { |
| 1112 | + t.Run(tc.name, func(t *testing.T) { |
| 1113 | + run(t, tc) |
| 1114 | + }) |
| 1115 | + } |
| 1116 | +} |
| 1117 | + |
| 1118 | +func doRaftRPC(conn net.Conn, leader string) (raft.AppendEntriesResponse, error) { |
| 1119 | + var resp raft.AppendEntriesResponse |
| 1120 | + |
| 1121 | + var term uint64 = 0xc |
| 1122 | + a := raft.AppendEntriesRequest{ |
| 1123 | + RPCHeader: raft.RPCHeader{ProtocolVersion: 3}, |
| 1124 | + Term: 0, |
| 1125 | + Leader: []byte(leader), |
| 1126 | + PrevLogEntry: 0, |
| 1127 | + PrevLogTerm: term, |
| 1128 | + LeaderCommitIndex: 50, |
| 1129 | + } |
| 1130 | + |
| 1131 | + if err := appendEntries(conn, a, &resp); err != nil { |
| 1132 | + return resp, err |
| 1133 | + } |
| 1134 | + return resp, nil |
| 1135 | +} |
| 1136 | + |
| 1137 | +func appendEntries(conn net.Conn, req raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) error { |
| 1138 | + w := bufio.NewWriter(conn) |
| 1139 | + enc := codec.NewEncoder(w, &codec.MsgpackHandle{}) |
| 1140 | + |
| 1141 | + const rpcAppendEntries = 0 |
| 1142 | + if err := w.WriteByte(rpcAppendEntries); err != nil { |
| 1143 | + return fmt.Errorf("failed to write raft-RPC byte: %w", err) |
| 1144 | + } |
| 1145 | + |
| 1146 | + if err := enc.Encode(req); err != nil { |
| 1147 | + return fmt.Errorf("failed to send append entries RPC: %w", err) |
| 1148 | + } |
| 1149 | + if err := w.Flush(); err != nil { |
| 1150 | + return fmt.Errorf("failed to flush RPC: %w", err) |
| 1151 | + } |
| 1152 | + |
| 1153 | + if err := decodeRaftRPCResponse(conn, resp); err != nil { |
| 1154 | + return fmt.Errorf("response error: %w", err) |
| 1155 | + } |
| 1156 | + return nil |
| 1157 | +} |
| 1158 | + |
| 1159 | +// copied and modified from raft/net_transport.go |
| 1160 | +func decodeRaftRPCResponse(conn net.Conn, resp *raft.AppendEntriesResponse) error { |
| 1161 | + r := bufio.NewReader(conn) |
| 1162 | + dec := codec.NewDecoder(r, &codec.MsgpackHandle{}) |
| 1163 | + |
| 1164 | + var rpcError string |
| 1165 | + if err := dec.Decode(&rpcError); err != nil { |
| 1166 | + return fmt.Errorf("failed to decode response error: %w", err) |
| 1167 | + } |
| 1168 | + if err := dec.Decode(resp); err != nil { |
| 1169 | + return fmt.Errorf("failed to decode response: %w", err) |
| 1170 | + } |
| 1171 | + if rpcError != "" { |
| 1172 | + return fmt.Errorf("rpc error: %v", rpcError) |
| 1173 | + } |
| 1174 | + return nil |
| 1175 | +} |
| 1176 | + |
| 1177 | +func isConnectionClosedError(err error) bool { |
| 1178 | + switch { |
| 1179 | + case err == nil: |
| 1180 | + return false |
| 1181 | + case errors.Is(err, io.EOF): |
| 1182 | + return true |
| 1183 | + case strings.Contains(err.Error(), "connection reset by peer"): |
| 1184 | + return true |
| 1185 | + default: |
| 1186 | + return false |
| 1187 | + } |
| 1188 | +} |
0 commit comments