Skip to content

Cassandra - Add username customization #10906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions helper/testhelpers/cassandra/cassandrahelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"context"
"errors"
"fmt"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/testhelpers/docker"
"os"
"testing"
"time"

"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/testhelpers/docker"
)

func PrepareTestContainer(t *testing.T, version string) (func(), string) {
Expand Down
47 changes: 38 additions & 9 deletions plugins/database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"fmt"
"strings"

"github.com/hashicorp/vault/sdk/helper/template"

"github.com/gocql/gocql"
multierror "github.com/hashicorp/go-multierror"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
)
Expand All @@ -18,13 +19,17 @@ const (
defaultUserDeletionCQL = `DROP USER '{{username}}';`
defaultChangePasswordCQL = `ALTER USER {{username}} WITH PASSWORD '{{password}}';`
cassandraTypeName = "cassandra"

defaultUserNameTemplate = `{{ printf "v_%s_%s_%s_%s" (.DisplayName | truncate 15) (.RoleName | truncate 15) (random 20) (unix_time) | truncate 100 | replace "-" "_" | lowercase }}`
)

var _ dbplugin.Database = &Cassandra{}

// Cassandra is an implementation of Database interface
type Cassandra struct {
*cassandraConnectionProducer

usernameProducer template.StringTemplate
}

// New returns a new Cassandra instance
Expand Down Expand Up @@ -58,6 +63,37 @@ func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
return session.(*gocql.Session), nil
}

func (c *Cassandra) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
usernameTemplate, err := strutil.GetString(req.Config, "username_template")
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err)
}
if usernameTemplate == "" {
usernameTemplate = defaultUserNameTemplate
}

up, err := template.NewTemplate(template.Template(usernameTemplate))
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err)
}
c.usernameProducer = up

_, err = c.usernameProducer.Generate(dbplugin.UsernameMetadata{})
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err)
}

err = c.cassandraConnectionProducer.Initialize(ctx, req)
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to initialize: %w", err)
}

resp := dbplugin.InitializeResponse{
Config: req.Config,
}
return resp, nil
}

// NewUser generates the username/password on the underlying Cassandra secret backend as instructed by
// the statements provided.
func (c *Cassandra) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
Expand All @@ -79,17 +115,10 @@ func (c *Cassandra) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (d
rollbackCQL = []string{defaultUserDeletionCQL}
}

username, err := credsutil.GenerateUsername(
credsutil.DisplayName(req.UsernameConfig.DisplayName, 15),
credsutil.RoleName(req.UsernameConfig.RoleName, 15),
credsutil.Separator("_"),
credsutil.MaxLength(100),
credsutil.ToLower(),
)
username, err := c.usernameProducer.Generate(req.UsernameConfig)
if err != nil {
return dbplugin.NewUserResponse{}, err
}
username = strings.ReplaceAll(username, "-", "_")

for _, stmt := range creationCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
Expand Down
112 changes: 91 additions & 21 deletions plugins/database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package cassandra

import (
"context"
"reflect"
"regexp"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"

backoff "github.com/cenkalti/backoff/v3"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/testhelpers/cassandra"
Expand Down Expand Up @@ -65,31 +67,99 @@ func TestCassandra_Initialize(t *testing.T) {
}

func TestCassandra_CreateUser(t *testing.T) {
db, cleanup := getCassandra(t, 4)
defer cleanup()
type testCase struct {
// Config will have the hosts & port added to it during the test
config map[string]interface{}
newUserReq dbplugin.NewUserRequest
expectErr bool
expectedUsernameRegex string
assertCreds func(t testing.TB, address string, port int, username, password string, timeout time.Duration)
}

password := "myreallysecurepassword"
createReq := dbplugin.NewUserRequest{
UsernameConfig: dbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
tests := map[string]testCase{
"default username_template": {
config: map[string]interface{}{
"username": "cassandra",
"password": "cassandra",
"protocol_version": "4",
"connect_timeout": "20s",
},
newUserReq: dbplugin.NewUserRequest{
UsernameConfig: dbplugin.UsernameMetadata{
DisplayName: "token",
RoleName: "mylongrolenamewithmanycharacters",
},
Statements: dbplugin.Statements{
Commands: []string{createUserStatements},
},
Password: "bfn985wjAHIh6t",
Expiration: time.Now().Add(1 * time.Minute),
},
expectErr: false,
expectedUsernameRegex: `^v_token_mylongrolenamew_[a-z0-9]{20}_[0-9]{10}$`,
assertCreds: assertCreds,
},
Statements: dbplugin.Statements{
Commands: []string{createUserStatements},
"custom username_template": {
config: map[string]interface{}{
"username": "cassandra",
"password": "cassandra",
"protocol_version": "4",
"connect_timeout": "20s",
"username_template": `foo_{{random 20}}_{{.RoleName | replace "e" "3"}}_{{unix_time}}`,
},
newUserReq: dbplugin.NewUserRequest{
UsernameConfig: dbplugin.UsernameMetadata{
DisplayName: "token",
RoleName: "mylongrolenamewithmanycharacters",
},
Statements: dbplugin.Statements{
Commands: []string{createUserStatements},
},
Password: "bfn985wjAHIh6t",
Expiration: time.Now().Add(1 * time.Minute),
},
expectErr: false,
expectedUsernameRegex: `^foo_[a-zA-Z0-9]{20}_mylongrol3nam3withmanycharact3rs_[0-9]{10}$`,
assertCreds: assertCreds,
},
Password: password,
Expiration: time.Now().Add(1 * time.Minute),
}

createResp := dbtesting.AssertNewUser(t, db, createReq)

expectedRegex := "^v_test_test_[a-zA-Z0-9]{20}_[0-9]{10}$"
re := regexp.MustCompile(expectedRegex)
if !re.MatchString(createResp.Username) {
t.Fatalf("Generated username %q did not match regexp %q", createResp.Username, expectedRegex)
for name, test := range tests {
t.Run(name, func(t *testing.T) {
cleanup, connURL := cassandra.PrepareTestContainer(t, "latest")
pieces := strings.Split(connURL, ":")
defer cleanup()

db := new()

config := test.config
config["hosts"] = connURL
config["port"] = pieces[1]

initReq := dbplugin.InitializeRequest{
Config: config,
VerifyConnection: true,
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
dbtesting.AssertInitialize(t, db, initReq)

require.True(t, db.Initialized, "Database is not initialized")

ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
newUserResp, err := db.NewUser(ctx, test.newUserReq)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
require.Regexp(t, test.expectedUsernameRegex, newUserResp.Username)
test.assertCreds(t, db.Hosts, db.Port, newUserResp.Username, test.newUserReq.Password, 5*time.Second)
})
}

assertCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
}

func TestMyCassandra_UpdateUserPassword(t *testing.T) {
Expand Down Expand Up @@ -217,4 +287,4 @@ func assertNoCreds(t testing.TB, address string, port int, username, password st
}

const createUserStatements = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;
GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};`
GRANT ALL PERMISSIONS ON ALL KEYSPACES TO '{{username}}';`
46 changes: 21 additions & 25 deletions plugins/database/cassandra/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import (
"sync"
"time"

"github.com/gocql/gocql"
"github.com/hashicorp/errwrap"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil"

"github.com/gocql/gocql"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/mitchellh/mapstructure"
)

Expand Down Expand Up @@ -51,40 +51,40 @@ type cassandraConnectionProducer struct {
sync.Mutex
}

func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplugin.InitializeRequest) error {
c.Lock()
defer c.Unlock()

c.rawConfig = req.Config

err := mapstructure.WeakDecode(req.Config, c)
if err != nil {
return dbplugin.InitializeResponse{}, err
return err
}

if c.ConnectTimeoutRaw == nil {
c.ConnectTimeoutRaw = "0s"
}
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
if err != nil {
return dbplugin.InitializeResponse{}, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
return fmt.Errorf("invalid connect_timeout: %w", err)
}

if c.SocketKeepAliveRaw == nil {
c.SocketKeepAliveRaw = "0s"
}
c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw)
if err != nil {
return dbplugin.InitializeResponse{}, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err)
return fmt.Errorf("invalid socket_keep_alive: %w", err)
}

switch {
case len(c.Hosts) == 0:
return dbplugin.InitializeResponse{}, fmt.Errorf("hosts cannot be empty")
return fmt.Errorf("hosts cannot be empty")
case len(c.Username) == 0:
return dbplugin.InitializeResponse{}, fmt.Errorf("username cannot be empty")
return fmt.Errorf("username cannot be empty")
case len(c.Password) == 0:
return dbplugin.InitializeResponse{}, fmt.Errorf("password cannot be empty")
return fmt.Errorf("password cannot be empty")
}

var certBundle *certutil.CertBundle
Expand All @@ -93,11 +93,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
case len(c.PemJSON) != 0:
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
if err != nil {
return dbplugin.InitializeResponse{}, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return dbplugin.InitializeResponse{}, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
return fmt.Errorf("error marshaling PEM information: %w", err)
}
c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey
Expand All @@ -107,11 +107,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
case len(c.PemBundle) != 0:
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
if err != nil {
return dbplugin.InitializeResponse{}, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
return fmt.Errorf("error parsing the given PEM information: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return dbplugin.InitializeResponse{}, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
return fmt.Errorf("error marshaling PEM information: %w", err)
}
c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey
Expand All @@ -125,15 +125,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug

if req.VerifyConnection {
if _, err := c.Connection(ctx); err != nil {
return dbplugin.InitializeResponse{}, errwrap.Wrapf("error verifying connection: {{err}}", err)
return fmt.Errorf("error verifying connection: %w", err)
}
}

resp := dbplugin.InitializeResponse{
Config: req.Config,
}

return resp, nil
return nil
}

func (c *cassandraConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
Expand Down Expand Up @@ -207,12 +203,12 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql

parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}

tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = c.InsecureTLS

Expand Down Expand Up @@ -240,7 +236,7 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql

session, err := clusterConfig.CreateSession()
if err != nil {
return nil, errwrap.Wrapf("error creating session: {{err}}", err)
return nil, fmt.Errorf("error creating session: %w", err)
}

if c.Consistency != "" {
Expand All @@ -262,11 +258,11 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql

if rowNum < 1 {
session.Close()
return nil, errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err)
return nil, fmt.Errorf("error validating connection info: No role create permissions found, previous error: %w", err)
}
} else if err != nil {
session.Close()
return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
return nil, fmt.Errorf("error validating connection info: %w", err)
}
}

Expand Down