refract gdb package, add complete unit test cases, almost there

This commit is contained in:
John
2018-12-15 15:50:39 +08:00
parent d5e46f2b42
commit e67aa63a50
36 changed files with 1530 additions and 745 deletions

View File

@ -4,6 +4,7 @@ go:
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- master
before_install:

View File

@ -35,6 +35,7 @@ Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
ICHINOSE Shogo <shogo82148 at gmail.com>
Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
@ -72,7 +73,9 @@ Shuode Li <elemount at qq.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
@ -88,3 +91,4 @@ Keybase Inc.
Percona LLC
Pivotal Inc.
Stripe Inc.
Multiplay Ltd.

View File

@ -58,7 +58,7 @@ _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface.
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
```go
import "database/sql"
import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
import _ "github.com/go-sql-driver/mysql"
db, err := sql.Open("mysql", "user:password@/dbname")
```
@ -328,11 +328,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci
```
Type: bool / string
Valid Values: true, false, skip-verify, <name>
Valid Values: true, false, skip-verify, preferred, <name>
Default: false
```
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use `preferred` to use TLS only when advertised by the server, this is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
##### `writeTimeout`
@ -431,7 +431,7 @@ See [context support in the database/sql package](https://golang.org/doc/go1.8#d
### `LOAD DATA LOCAL INFILE` support
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
```go
import "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
import "github.com/go-sql-driver/mysql"
```
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).

View File

@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
if err != nil {
return err
}
return mc.writeAuthSwitchPacket(enc, false)
return mc.writeAuthSwitchPacket(enc)
}
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
switch plugin {
case "caching_sha2_password":
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
return authResp, false, nil
return authResp, nil
case "mysql_old_password":
if !mc.cfg.AllowOldPasswords {
return nil, false, ErrOldPassword
return nil, ErrOldPassword
}
// Note: there are edge cases where this should work but doesn't;
// this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
return authResp, true, nil
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
return authResp, nil
case "mysql_clear_password":
if !mc.cfg.AllowCleartextPasswords {
return nil, false, ErrCleartextPassword
return nil, ErrCleartextPassword
}
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
return []byte(mc.cfg.Passwd), true, nil
return append([]byte(mc.cfg.Passwd), 0), nil
case "mysql_native_password":
if !mc.cfg.AllowNativePasswords {
return nil, false, ErrNativePassword
return nil, ErrNativePassword
}
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
// Native password authentication only need and will need 20-byte challenge.
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
return authResp, false, nil
return authResp, nil
case "sha256_password":
if len(mc.cfg.Passwd) == 0 {
return nil, true, nil
return []byte{0}, nil
}
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
return []byte(mc.cfg.Passwd), true, nil
return append([]byte(mc.cfg.Passwd), 0), nil
}
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
return []byte{1}, false, nil
return []byte{1}, nil
}
// encrypted password
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, false, err
return enc, err
default:
errLog.Print("unknown auth plugin:", plugin)
return nil, false, ErrUnknownPlugin
return nil, ErrUnknownPlugin
}
}
@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
plugin = newPlugin
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
return err
}
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
if err = mc.writeAuthSwitchPacket(authResp); err != nil {
return err
}
@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case cachingSha2PasswordPerformFullAuthentication:
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
if err != nil {
return err
}
@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
data := mc.buf.takeSmallBuffer(4 + 1)
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
return err
}
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)
// parse public key
data, err := mc.readPacket()
if err != nil {
if data, err = mc.readPacket(); err != nil {
return err
}

View File

@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) {
plugin := "mysql_clear_password"
// Send Client Authentication Packet
_, _, err := mc.auth(authData, plugin)
_, err := mc.auth(authData, plugin)
if err != ErrCleartextPassword {
t.Errorf("expected ErrCleartextPassword, got %v", err)
}
@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) {
plugin := "mysql_clear_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -367,8 +367,8 @@ func TestAuthFastCleartextPassword(t *testing.T) {
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116}
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
plugin := "mysql_clear_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -410,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
if writtenAuthRespLen != 0 {
t.Fatalf("unexpected written auth response (%d bytes): %v",
writtenAuthRespLen, writtenAuthResp)
expectedAuthResp := []byte{0}
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) {
plugin := "mysql_native_password"
// Send Client Authentication Packet
_, _, err := mc.auth(authData, plugin)
_, err := mc.auth(authData, plugin)
if err != ErrNativePassword {
t.Errorf("expected ErrNativePassword, got %v", err)
}
@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) {
plugin := "mysql_native_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
plugin := "mysql_native_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
plugin := "sha256_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -554,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
if writtenAuthRespLen != 0 {
expectedAuthResp := []byte{0}
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -587,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
plugin := "sha256_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -636,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
plugin := "sha256_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -669,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
plugin := "sha256_password"
// send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
@ -677,18 +678,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
// unset TLS config to prevent the actual establishment of a TLS wrapper
mc.cfg.tls = nil
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
// check written auth response
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
authRespEnd := authRespStart + 1 + len(authResp) + 1
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -1064,6 +1065,22 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) {
}
}
// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request.
func TestOldAuthSwitchNotAllowed(t *testing.T) {
conn, mc := newRWMockConn(2)
// OldAuthSwitch request
conn.data = []byte{1, 0, 0, 2, 0xfe}
conn.maxReads = 1
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
84, 96, 101, 92, 123, 121, 107}
plugin := "mysql_native_password"
err := mc.handleAuthResult(authData, plugin)
if err != ErrOldPassword {
t.Errorf("expected ErrOldPassword, got %v", err)
}
}
func TestAuthSwitchOldPassword(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
@ -1092,6 +1109,32 @@ func TestAuthSwitchOldPassword(t *testing.T) {
}
}
// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request.
func TestOldAuthSwitch(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
mc.cfg.Passwd = "secret"
// OldAuthSwitch request
conn.data = []byte{1, 0, 0, 2, 0xfe}
// auth response
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
conn.maxReads = 2
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
84, 96, 101, 92, 123, 121, 107}
plugin := "mysql_native_password"
if err := mc.handleAuthResult(authData, plugin); err != nil {
t.Errorf("got error: %v", err)
}
expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0}
if !bytes.Equal(conn.written, expectedReply) {
t.Errorf("got unexpected data: %v", conn.written)
}
}
func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
@ -1120,6 +1163,33 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
}
}
// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request.
func TestOldAuthSwitchPasswordEmpty(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
mc.cfg.Passwd = ""
// OldAuthSwitch request.
conn.data = []byte{1, 0, 0, 2, 0xfe}
// auth response
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
conn.maxReads = 2
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
84, 96, 101, 92, 123, 121, 107}
plugin := "mysql_native_password"
if err := mc.handleAuthResult(authData, plugin); err != nil {
t.Errorf("got error: %v", err)
}
expectedReply := []byte{1, 0, 0, 3, 0}
if !bytes.Equal(conn.written, expectedReply) {
t.Errorf("got unexpected data: %v", conn.written)
}
}
func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.Passwd = ""

View File

@ -22,17 +22,17 @@ const defaultBufSize = 4096
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
buf []byte // buf is a byte buffer who's length and capacity are equal.
nc net.Conn
idx int
length int
timeout time.Duration
}
// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
var b [defaultBufSize]byte
return buffer{
buf: b[:],
buf: make([]byte, defaultBufSize),
nc: nc,
}
}
@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
return b.buf[offset:b.idx], nil
}
// returns a buffer with the requested size.
// takeBuffer returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
if length <= cap(b.buf) {
return b.buf[:length], nil
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
return b.buf, nil
}
return make([]byte, length)
// buffer is larger than we want to store.
return make([]byte, length), nil
}
// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// takeSmallBuffer is shortcut which can be used if length is
// known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
return b.buf[:length]
return b.buf[:length], nil
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
return b.buf
return b.buf, nil
}
// store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error {
if b.length > 0 {
return ErrBusyBuffer
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
}
return nil
}

View File

@ -19,16 +19,6 @@ import (
"time"
)
// a copy of context.Context for Go 1.7 and earlier
type mysqlContext interface {
Done() <-chan struct{}
Err() error
// defined in context.Context, but not used in this driver:
// Deadline() (deadline time.Time, ok bool)
// Value(key interface{}) interface{}
}
type mysqlConn struct {
buf buffer
netConn net.Conn
@ -45,7 +35,7 @@ type mysqlConn struct {
// for context support (Go 1.8+)
watching bool
watcher chan<- mysqlContext
watcher chan<- context.Context
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
@ -192,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}
buf := mc.buf.takeCompleteBuffer()
if buf == nil {
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return "", ErrInvalidConn
}
buf = buf[:0]
@ -475,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return
return mc.markBadConn(err)
}
return mc.readResultOK()
@ -595,33 +585,32 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
mc.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}
mc.watching = true
select {
default:
case <-ctx.Done():
return ctx.Err()
}
// When watcher is not alive, can't watch it.
if mc.watcher == nil {
return nil
}
mc.watching = true
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan mysqlContext, 1)
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx mysqlContext
var ctx context.Context
select {
case ctx = <-watcher:
case <-mc.closech:

View File

@ -9,7 +9,10 @@
package mysql
import (
"context"
"database/sql/driver"
"errors"
"net"
"testing"
)
@ -79,3 +82,76 @@ func TestCheckNamedValue(t *testing.T) {
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
}
}
// TestCleanCancel tests passed context is cancelled at start.
// No packet should be sent. Connection should keep current status.
func TestCleanCancel(t *testing.T) {
mc := &mysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
defer mc.cleanup()
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := 0; i < 3; i++ { // Repeat same behavior
err := mc.Ping(ctx)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %#v", err)
}
if mc.closed.IsSet() {
t.Error("expected mc is not closed, closed actually")
}
if mc.watching {
t.Error("expected watching is false, but true")
}
}
}
func TestPingMarkBadConnection(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
}
err := ms.Ping(context.Background())
if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
}
}
func TestPingErrInvalidConn(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
}
err := ms.Ping(context.Background())
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %#v", err)
}
}
type badConnection struct {
n int
err error
net.Conn
}
func (bc badConnection) Write(b []byte) (n int, err error) {
return bc.n, bc.err
}
func (bc badConnection) Close() error {
return nil
}

View File

@ -9,7 +9,7 @@
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
// the DSN string is formatted
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
errLog.Print("net.Error from Dial()': ", nerr.Error())
return nil, driver.ErrBadConn
}
return nil, err
}
@ -110,18 +114,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
}
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, addNUL, err = mc.auth(authData, plugin)
authResp, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
mc.cleanup()
return nil, err
}

View File

@ -85,6 +85,23 @@ type DBTest struct {
db *sql.DB
}
type netErrorMock struct {
temporary bool
timeout bool
}
func (e netErrorMock) Temporary() bool {
return e.temporary
}
func (e netErrorMock) Timeout() bool {
return e.timeout
}
func (e netErrorMock) Error() string {
return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
}
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
@ -1287,7 +1304,7 @@ func TestFoundRows(t *testing.T) {
}
func TestTLS(t *testing.T) {
tlsTest := func(dbt *DBTest) {
tlsTestReq := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
if err == ErrNoTLS {
dbt.Skip("server does not support TLS")
@ -1304,19 +1321,27 @@ func TestTLS(t *testing.T) {
dbt.Fatal(err.Error())
}
if value == nil {
dbt.Fatal("no Cipher")
if (*value == nil) || (len(*value) == 0) {
dbt.Fatalf("no Cipher")
} else {
dbt.Logf("Cipher: %s", *value)
}
}
}
tlsTestOpt := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
dbt.Fatalf("error on Ping: %s", err.Error())
}
}
runTests(t, dsn+"&tls=skip-verify", tlsTest)
runTests(t, dsn+"&tls=preferred", tlsTestOpt)
runTests(t, dsn+"&tls=skip-verify", tlsTestReq)
// Verify that registering / using a custom cfg works
RegisterTLSConfig("custom-skip-verify", &tls.Config{
InsecureSkipVerify: true,
})
runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq)
}
func TestReuseClosedConnection(t *testing.T) {
@ -1801,6 +1826,38 @@ func TestConcurrent(t *testing.T) {
})
}
func testDialError(t *testing.T, dialErr error, expectErr error) {
RegisterDial("mydial", func(addr string) (net.Conn, error) {
return nil, dialErr
})
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
_, err = db.Exec("DO 1")
if err != expectErr {
t.Fatalf("was expecting %s. Got: %s", dialErr, err)
}
}
func TestDialUnknownError(t *testing.T) {
testErr := fmt.Errorf("test")
testDialError(t, testErr, testErr)
}
func TestDialNonRetryableNetErr(t *testing.T) {
testErr := netErrorMock{}
testDialError(t, testErr, testErr)
}
func TestDialTemporaryNetErr(t *testing.T) {
testErr := netErrorMock{temporary: true}
testDialError(t, testErr, driver.ErrBadConn)
}
// Tests custom dial functions
func TestCustomDial(t *testing.T) {
if !available {

View File

@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
} else {
cfg.TLSConfig = "false"
}
} else if vl := strings.ToLower(value); vl == "skip-verify" {
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else {

View File

@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
mc.sequence++
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)1 bytes long
// multiple of (2^24)-1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
@ -194,7 +194,11 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return nil, "", ErrNoTLS
if mc.cfg.TLSConfig == "preferred" {
mc.cfg.tls = nil
} else {
return nil, "", ErrNoTLS
}
}
pos += 2
@ -243,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
@ -269,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
authRespLen := len(authResp)
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
if len(authRespLEI) > 1 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
@ -277,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
}
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
if addNUL {
pktLen++
}
// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
@ -288,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
}
// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -350,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
// Auth Data [length encoded integer]
pos += copy(data[pos:], authRespLEI)
pos += copy(data[pos:], authResp)
if addNUL {
data[pos] = 0x00
pos++
}
// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
@ -364,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
// Send Auth packet
return mc.writePacket(data)
return mc.writePacket(data[:pos])
}
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
if addNUL {
pktLen++
}
data := mc.buf.takeSmallBuffer(pktLen)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
// Add the auth data [EOF]
copy(data[4:], authData)
if addNUL {
data[pktLen-1] = 0x00
}
return mc.writePacket(data)
}
@ -399,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -418,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.sequence = 0
pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -439,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -479,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
return data[1:], "", err
case iEOF:
if len(data) < 1 {
if len(data) == 1 {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
return nil, "mysql_old_password", nil
}
@ -895,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
// Determine threshould dynamically to avoid packet size shortage.
// Determine threshold dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
@ -905,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
mc.sequence = 0
var data []byte
var err error
if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
data, err = mc.buf.takeBuffer(minPktLen)
} else {
data = mc.buf.takeCompleteBuffer()
data, err = mc.buf.takeCompleteBuffer()
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if data == nil {
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -939,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
pos := minPktLen
var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
@ -948,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
// No need to clean nullMask as make ensures that.
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
for i := 0; i < maskLen; i++ {
for i := range nullMask {
nullMask[i] = 0
}
pos += maskLen
@ -1088,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data
if err = mc.buf.store(data); err != nil {
errLog.Print(err)
return errBadConnNoWrite
}
}
pos += len(paramValues)