Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
; TestConcurrent fails if max_connections is too large
max_connections=50
local_infile=1
performance_schema=on
- name: setup database
run: |
mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;'
Expand Down
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,15 @@ Default: 0

I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.

##### `connectionAttributes`

```
Type: comma-delimited string of user-defined "key:value" pairs
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
Default: none
```

[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.

##### System Variables

Expand Down
12 changes: 12 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,24 @@

package mysql

import "runtime"

const (
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"

// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
connAttrClientNameValue = "GO-MySQL-Driver"
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
)

// MySQL constants documentation:
Expand Down
47 changes: 47 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3209,3 +3209,50 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) {
t.Errorf("connection not closed")
}
}

func TestConnectionAttributes(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "foo"
value2 := "boo"
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
}

dbt := &DBTest{t, db}

var attrValue string
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}
38 changes: 22 additions & 16 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,23 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down Expand Up @@ -554,6 +555,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}

// Connection attributes
case "connectionAttributes":
cfg.ConnectionAttributes = value

default:
// lazy init
if cfg.Params == nil {
Expand Down
35 changes: 35 additions & 0 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
"fmt"
"io"
"math"
"os"
"strconv"
"strings"
"time"
)

Expand Down Expand Up @@ -285,6 +288,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -318,6 +322,32 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

connAttrsBuf := make([]byte, 0, 100)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))

// user-defined connection attributes
for _, connAttr := range strings.Split(mc.cfg.ConnectionAttributes, ",") {
attr := strings.Split(connAttr, ":")
if len(attr) != 2 {
continue
}
for _, v := range attr {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
}

// 1 byte to store length of all key-values
pktLen += len(connAttrsBuf) + 1

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
Expand Down Expand Up @@ -394,6 +424,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00
pos++

// Connection Attributes
data[pos] = byte(len(connAttrsBuf))
pos++
pos += copy(data[pos:], connAttrsBuf)

// Send Auth packet
return mc.writePacket(data[:pos])
}
Expand Down
5 changes: 5 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}

func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
}

// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
// If cap(buf) is not enough, reallocate new buffer.
func reserveBuffer(buf []byte, appendSize int) []byte {
Expand Down