Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/certificator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ func main() {
}
logger.Infof("checking certificate for %s", mainDomain)

needsRenewing, err := certificate.NeedsRenewing(cert, mainDomain, cfg.RenewBeforeDays, logger)
needsReissuing, err := certificate.NeedsReissuing(cert, allDomains, cfg.RenewBeforeDays, logger)
if err != nil {
failedDomains = append(failedDomains, mainDomain)
logger.Error(err)
continue
}

if needsRenewing {
if needsReissuing {
logger.Infof("obtaining certificate for %s", mainDomain)
err := certificate.ObtainCertificate(acmeClient, vaultClient, allDomains,
cfg.DNSAddress, cfg.Acme.DNSChallengeProvider, cfg.Acme.DNSPropagationRequirement)
Expand Down
40 changes: 36 additions & 4 deletions pkg/certificate/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,59 @@ func GetCertificate(domain string, vault *vault.VaultClient) (*x509.Certificate,
return nil, nil
}

// NeedsRenewing checks if certificate expiration date is earlier than configured in config.Cfg.RenewBeforeDays
func NeedsRenewing(certificate *x509.Certificate, domain string, days int, logger *logrus.Logger) (bool, error) {
// NeedsReissuing checks if certificate domains and required domains match
// and if certificate expiration date is earlier than configured in config.Cfg.RenewBeforeDays
func NeedsReissuing(certificate *x509.Certificate, domains []string, days int, logger *logrus.Logger) (bool, error) {
if certificate == nil {
return true, nil
}

if certificate.IsCA {
return true, fmt.Errorf("certificate bundle for %s starts with a CA certificate", domain)
return true, fmt.Errorf("certificate bundle for %s starts with a CA certificate", domains[0])
}

// Check if all domains are in certificate DNS names
if !arraysEqual(domains, certificate.DNSNames) {
logger.Printf("certificate %s domains changed, it needs reissuing", domains[0])
logger.Printf("certificate domains: %v", certificate.DNSNames)
logger.Printf("required domains: %v", domains)
return true, nil
}

notAfter := int(time.Until(certificate.NotAfter).Hours() / 24.0)
logger.Printf("certificate is valid for %v more days", notAfter)
if notAfter > days {
logger.Printf("certificate for %s does not need renewing", domain)
logger.Printf("certificate for %s does not need renewing", domains[0])

return false, nil
}

return true, nil
}

func arraysEqual(array1 []string, array2 []string) bool {
if len(array1) != len(array2) {
return false
}

for _, v := range array1 {
if !arrayContains(array2, v) {
return false
}
}

return true
}

func arrayContains(array []string, element string) bool {
for _, a := range array {
if a == element {
return true
}
}
return false
}

func vaultCertLocation(domain string) string {
return "certificates/" + domain
}
Expand Down
62 changes: 44 additions & 18 deletions pkg/certificate/certificate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,44 +12,70 @@ import (
"github.com/thanos-io/thanos/pkg/testutil"
)

func TestNeedsRenewing(t *testing.T) {
func TestNeedsReissuing(t *testing.T) {
template := &x509.Certificate{
IsCA: false,
SerialNumber: big.NewInt(1234),
NotBefore: time.Now(),
DNSNames: []string{"test.com", "www.test.com", "*.test.com"},
NotAfter: time.Now().AddDate(0 /* years */, 3 /* months */, 0 /* days */),
}
logger := logrus.New()

certificate := generateCert(t, template)

for _, tcase := range []struct {
tcaseName string
certificate *x509.Certificate
renewDays int
expectedResult bool
tcaseName string
requiredDomains []string
certificate *x509.Certificate
renewDays int
expectedResult bool
}{
{
tcaseName: "certificate expires after three months (90 days), renewDays = 30",
certificate: certificate,
renewDays: 30,
expectedResult: false,
tcaseName: "certificate expires after three months (90 days), renewDays = 30, required domains correct",
requiredDomains: []string{"test.com", "www.test.com", "*.test.com"},
certificate: certificate,
renewDays: 30,
expectedResult: false,
},
{
tcaseName: "certificate expires after three months (90 days), renewDays = 100",
certificate: certificate,
renewDays: 100,
expectedResult: true,
tcaseName: "certificate expires after three months (90 days), renewDays = 100, required domains correct",
requiredDomains: []string{"test.com", "www.test.com", "*.test.com"},
certificate: certificate,
renewDays: 100,
expectedResult: true,
},
{
tcaseName: "nil certificate, renew days 30",
certificate: nil,
renewDays: 30,
expectedResult: true,
tcaseName: "nil certificate, renew days 30, required domains correct",
requiredDomains: []string{"test.com", "www.test.com", "*.test.com"},
certificate: nil,
renewDays: 30,
expectedResult: true,
},
{
tcaseName: "certificate expires after three months (90 days), renewDays = 30, fewer required domains than certificate has",
requiredDomains: []string{"www.test.com", "*.test.com"},
certificate: certificate,
renewDays: 30,
expectedResult: true,
},
{
tcaseName: "certificate expires after three months (90 days), renewDays = 30, more required domains than certificate has",
requiredDomains: []string{"test.com", "www.test.com", "*.test.com", "additional.test.com"},
certificate: certificate,
renewDays: 30,
expectedResult: true,
},
{
tcaseName: "certificate expires after three months (90 days), renewDays = 30, different required domains than certificate has",
requiredDomains: []string{"test.com", "www.test.com", "different.test.com"},
certificate: certificate,
renewDays: 30,
expectedResult: true,
},
} {
t.Run(tcase.tcaseName, func(t *testing.T) {
result, err := NeedsRenewing(tcase.certificate, "test.com", tcase.renewDays, logger)
result, err := NeedsReissuing(tcase.certificate, tcase.requiredDomains, tcase.renewDays, logger)
testutil.Ok(t, err)
testutil.Equals(t, tcase.expectedResult, result)
})
Expand Down