Skip to content

Commit 18a7c56

Browse files
committed
Add WithTLSServerCurvePreferences, WithTLSServerCipherSuites and WithTLSServerVerifyPeerCertificate TLSServerConfigOption
1 parent 9abe851 commit 18a7c56

File tree

4 files changed

+232
-5
lines changed

4 files changed

+232
-5
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module github.com/grepplabs/cert-source
22

33
go 1.21
44

5-
require github.com/stretchr/testify v1.8.4
5+
require github.com/stretchr/testify v1.10.0
66

77
require (
88
github.com/davecgh/go-spew v1.1.1 // indirect

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
22
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
33
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
44
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5-
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
6-
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
5+
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
6+
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
77
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
88
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
99
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

tls/server/config/config_test.go

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package config
22

33
import (
44
"crypto/tls"
5+
"crypto/x509"
6+
"errors"
57
"log/slog"
68
"testing"
79

@@ -24,11 +26,188 @@ func TestGetServerTLSConfig(t *testing.T) {
2426
ClientCAs: bundle.CACert.Name(),
2527
ClientCRL: bundle.ClientCRL.Name(),
2628
},
27-
}, tlsserver.WithTLSServerNextProtos([]string{"h2"}))
29+
})
2830
require.NoError(t, err)
2931
require.NotNil(t, tlsConfig.ClientCAs)
3032
require.Equal(t, tlsConfig.ClientAuth, tls.RequireAndVerifyClientCert)
3133
require.NotEmpty(t, tlsConfig.Certificates)
34+
// clientCRL verification
3235
require.NotNil(t, tlsConfig.VerifyPeerCertificate)
36+
require.Nil(t, tlsConfig.NextProtos)
37+
require.Nil(t, tlsConfig.CipherSuites)
38+
require.Nil(t, tlsConfig.CurvePreferences)
39+
}
40+
41+
func TestGetServerTLSOptionsConfig(t *testing.T) {
42+
bundle := testutil.NewCertsBundle()
43+
defer bundle.Close()
44+
45+
tlsConfig, err := GetServerTLSConfig(slog.Default(), &config.TLSServerConfig{
46+
Enable: true,
47+
Refresh: 0,
48+
File: config.TLSServerFiles{
49+
Key: bundle.ServerKey.Name(),
50+
Cert: bundle.ServerCert.Name(),
51+
},
52+
}, tlsserver.WithTLSServerNextProtos([]string{"h2"}),
53+
tlsserver.WithTLSServerCipherSuites([]uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}),
54+
tlsserver.WithTLSServerCurvePreferences([]tls.CurveID{tls.CurveP256, tls.CurveP384}),
55+
)
56+
require.NoError(t, err)
57+
require.Nil(t, tlsConfig.ClientCAs)
58+
require.Equal(t, tlsConfig.ClientAuth, tls.NoClientCert)
59+
require.NotEmpty(t, tlsConfig.Certificates)
60+
require.Nil(t, tlsConfig.VerifyPeerCertificate)
3361
require.Equal(t, tlsConfig.NextProtos, []string{"h2"})
62+
require.Equal(t, tlsConfig.CipherSuites, []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256})
63+
require.Equal(t, tlsConfig.CurvePreferences, []tls.CurveID{tls.CurveP256, tls.CurveP384})
64+
}
65+
66+
func TestGetServerTLSVerifyPeerCertificateConfig(t *testing.T) {
67+
bundle := testutil.NewCertsBundle()
68+
defer bundle.Close()
69+
70+
tests := []struct {
71+
name string
72+
clientCAs string
73+
verifyFuncs []tlsserver.VerifyPeerCertificateFunc
74+
verifyError error
75+
}{
76+
{
77+
name: "no peer verification",
78+
},
79+
{
80+
name: "default client CA/CLR verification",
81+
clientCAs: bundle.CACert.Name(),
82+
verifyError: nil, // CRLs are not set, verification is successful
83+
},
84+
{
85+
name: "client CA/CLR verify success, second verify success",
86+
clientCAs: bundle.CACert.Name(),
87+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
88+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
89+
return nil
90+
},
91+
},
92+
},
93+
{
94+
name: "client CA/CLR verify success, third verify success",
95+
clientCAs: bundle.CACert.Name(),
96+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
97+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
98+
return nil
99+
},
100+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
101+
return nil
102+
},
103+
},
104+
},
105+
{
106+
name: "client CA/CLR verify success, third verify failure",
107+
clientCAs: bundle.CACert.Name(),
108+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
109+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
110+
return nil
111+
},
112+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
113+
return errors.New("3 function failed")
114+
},
115+
},
116+
verifyError: errors.New("3 function failed"),
117+
},
118+
{
119+
name: "client CA/CLR verify success, second verify failure",
120+
clientCAs: bundle.CACert.Name(),
121+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
122+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
123+
return errors.New("2 function failed")
124+
},
125+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
126+
return errors.New("3 function would also fail")
127+
},
128+
},
129+
verifyError: errors.New("2 function failed"),
130+
},
131+
{
132+
name: "first verify success",
133+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
134+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
135+
return nil
136+
},
137+
},
138+
},
139+
{
140+
name: "second verify success",
141+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
142+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
143+
return nil
144+
},
145+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
146+
return nil
147+
},
148+
},
149+
},
150+
{
151+
name: "second verify failure",
152+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
153+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
154+
return nil
155+
},
156+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
157+
return errors.New("2 function failed")
158+
},
159+
},
160+
verifyError: errors.New("2 function failed"),
161+
},
162+
{
163+
name: "first verify failure",
164+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
165+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
166+
return errors.New("1 function failed")
167+
},
168+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
169+
return errors.New("2 function would also fail")
170+
},
171+
},
172+
verifyError: errors.New("1 function failed"),
173+
},
174+
{
175+
name: "unset verify function",
176+
verifyFuncs: []tlsserver.VerifyPeerCertificateFunc{
177+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
178+
return errors.New("1 function failed")
179+
},
180+
nil, // unset chain of verify functions
181+
func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
182+
return nil
183+
},
184+
},
185+
},
186+
}
187+
for _, tc := range tests {
188+
t.Run(tc.name, func(t *testing.T) {
189+
190+
opts := make([]tlsserver.TLSServerConfigOption, 0, len(tc.verifyFuncs))
191+
for _, f := range tc.verifyFuncs {
192+
opts = append(opts, tlsserver.WithTLSServerVerifyPeerCertificate(f))
193+
}
194+
tlsConfig, err := GetServerTLSConfig(slog.Default(), &config.TLSServerConfig{
195+
Enable: true,
196+
Refresh: 0,
197+
File: config.TLSServerFiles{
198+
Key: bundle.ServerKey.Name(),
199+
Cert: bundle.ServerCert.Name(),
200+
ClientCAs: tc.clientCAs,
201+
},
202+
}, opts...)
203+
require.NoError(t, err)
204+
if tc.clientCAs == "" && len(tc.verifyFuncs) == 0 {
205+
require.Nil(t, tlsConfig.VerifyPeerCertificate)
206+
} else {
207+
require.NotNil(t, tlsConfig.VerifyPeerCertificate)
208+
err = tlsConfig.VerifyPeerCertificate(nil, nil)
209+
require.Equal(t, tc.verifyError, err)
210+
}
211+
})
212+
}
34213
}

tls/server/option.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package tlsserver
22

3-
import "crypto/tls"
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
)
47

58
type TLSServerConfigOption func(*tls.Config)
69

@@ -9,3 +12,48 @@ func WithTLSServerNextProtos(nextProto []string) TLSServerConfigOption {
912
c.NextProtos = nextProto
1013
}
1114
}
15+
16+
func WithTLSServerCurvePreferences(curvePreferences []tls.CurveID) TLSServerConfigOption {
17+
return func(c *tls.Config) {
18+
if len(curvePreferences) != 0 {
19+
c.CurvePreferences = curvePreferences
20+
} else {
21+
c.CurvePreferences = nil
22+
}
23+
}
24+
}
25+
26+
func WithTLSServerCipherSuites(cipherSuites []uint16) TLSServerConfigOption {
27+
return func(c *tls.Config) {
28+
if len(cipherSuites) != 0 {
29+
c.CipherSuites = cipherSuites
30+
} else {
31+
c.CipherSuites = nil
32+
}
33+
}
34+
}
35+
36+
type VerifyPeerCertificateFunc func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
37+
38+
// WithTLSServerVerifyPeerCertificate sets or chains a custom VerifyPeerCertificate function on a *tls.Config.
39+
// If a nil function is provided, it unsets the certificate verification function (including the standard verification).
40+
// If an existing verification function is present, the new function is chained so that it is invoked only if the existing one succeeds.
41+
func WithTLSServerVerifyPeerCertificate(verifyFunc VerifyPeerCertificateFunc) TLSServerConfigOption {
42+
return func(c *tls.Config) {
43+
if verifyFunc == nil {
44+
c.VerifyPeerCertificate = nil
45+
return
46+
}
47+
prevFunc := c.VerifyPeerCertificate
48+
if prevFunc == nil {
49+
c.VerifyPeerCertificate = verifyFunc
50+
} else {
51+
c.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
52+
if err := prevFunc(rawCerts, verifiedChains); err != nil {
53+
return err
54+
}
55+
return verifyFunc(rawCerts, verifiedChains)
56+
}
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)