mirror of
https://gitee.com/johng/gf
synced 2026-06-06 16:21:40 +08:00
refract gdb package, add complete unit test cases, almost there
This commit is contained in:
@ -4,6 +4,7 @@ go:
|
||||
- 1.8.x
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- master
|
||||
|
||||
before_install:
|
||||
|
||||
@ -35,6 +35,7 @@ Hanno Braun <mail at hannobraun.com>
|
||||
Henri Yandell <flamefew at gmail.com>
|
||||
Hirotaka Yamamoto <ymmt2005 at gmail.com>
|
||||
ICHINOSE Shogo <shogo82148 at gmail.com>
|
||||
Ilia Cimpoes <ichimpoesh at gmail.com>
|
||||
INADA Naoki <songofacandy at gmail.com>
|
||||
Jacek Szwec <szwec.jacek at gmail.com>
|
||||
James Harr <james.harr at gmail.com>
|
||||
@ -72,7 +73,9 @@ Shuode Li <elemount at qq.com>
|
||||
Soroush Pour <me at soroushjp.com>
|
||||
Stan Putrya <root.vagner at gmail.com>
|
||||
Stanley Gunawan <gunawan.stanley at gmail.com>
|
||||
Steven Hartland <steven.hartland at multiplay.co.uk>
|
||||
Thomas Wodarek <wodarekwebpage at gmail.com>
|
||||
Tom Jenkinson <tom at tjenkinson.me>
|
||||
Xiangyu Hu <xiangyu.hu at outlook.com>
|
||||
Xiaobing Jiang <s7v7nislands at gmail.com>
|
||||
Xiuming Chen <cc at cxm.cc>
|
||||
@ -88,3 +91,4 @@ Keybase Inc.
|
||||
Percona LLC
|
||||
Pivotal Inc.
|
||||
Stripe Inc.
|
||||
Multiplay Ltd.
|
||||
|
||||
@ -58,7 +58,7 @@ _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface.
|
||||
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
|
||||
```go
|
||||
import "database/sql"
|
||||
import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
|
||||
db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
```
|
||||
@ -328,11 +328,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci
|
||||
|
||||
```
|
||||
Type: bool / string
|
||||
Valid Values: true, false, skip-verify, <name>
|
||||
Valid Values: true, false, skip-verify, preferred, <name>
|
||||
Default: false
|
||||
```
|
||||
|
||||
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
|
||||
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use `preferred` to use TLS only when advertised by the server, this is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
|
||||
|
||||
|
||||
##### `writeTimeout`
|
||||
@ -431,7 +431,7 @@ See [context support in the database/sql package](https://golang.org/doc/go1.8#d
|
||||
### `LOAD DATA LOCAL INFILE` support
|
||||
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
|
||||
```go
|
||||
import "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
import "github.com/go-sql-driver/mysql"
|
||||
```
|
||||
|
||||
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
|
||||
|
||||
@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.writeAuthSwitchPacket(enc, false)
|
||||
return mc.writeAuthSwitchPacket(enc)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
|
||||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
|
||||
switch plugin {
|
||||
case "caching_sha2_password":
|
||||
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
|
||||
return authResp, false, nil
|
||||
return authResp, nil
|
||||
|
||||
case "mysql_old_password":
|
||||
if !mc.cfg.AllowOldPasswords {
|
||||
return nil, false, ErrOldPassword
|
||||
return nil, ErrOldPassword
|
||||
}
|
||||
// Note: there are edge cases where this should work but doesn't;
|
||||
// this is currently "wontfix":
|
||||
// https://github.com/go-sql-driver/mysql/issues/184
|
||||
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
|
||||
return authResp, true, nil
|
||||
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
|
||||
return authResp, nil
|
||||
|
||||
case "mysql_clear_password":
|
||||
if !mc.cfg.AllowCleartextPasswords {
|
||||
return nil, false, ErrCleartextPassword
|
||||
return nil, ErrCleartextPassword
|
||||
}
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
|
||||
return []byte(mc.cfg.Passwd), true, nil
|
||||
return append([]byte(mc.cfg.Passwd), 0), nil
|
||||
|
||||
case "mysql_native_password":
|
||||
if !mc.cfg.AllowNativePasswords {
|
||||
return nil, false, ErrNativePassword
|
||||
return nil, ErrNativePassword
|
||||
}
|
||||
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
||||
// Native password authentication only need and will need 20-byte challenge.
|
||||
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
|
||||
return authResp, false, nil
|
||||
return authResp, nil
|
||||
|
||||
case "sha256_password":
|
||||
if len(mc.cfg.Passwd) == 0 {
|
||||
return nil, true, nil
|
||||
return []byte{0}, nil
|
||||
}
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
return []byte(mc.cfg.Passwd), true, nil
|
||||
return append([]byte(mc.cfg.Passwd), 0), nil
|
||||
}
|
||||
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
return []byte{1}, false, nil
|
||||
return []byte{1}, nil
|
||||
}
|
||||
|
||||
// encrypted password
|
||||
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
|
||||
return enc, false, err
|
||||
return enc, err
|
||||
|
||||
default:
|
||||
errLog.Print("unknown auth plugin:", plugin)
|
||||
return nil, false, ErrUnknownPlugin
|
||||
return nil, ErrUnknownPlugin
|
||||
}
|
||||
}
|
||||
|
||||
@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
|
||||
plugin = newPlugin
|
||||
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
|
||||
if err = mc.writeAuthSwitchPacket(authResp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
case cachingSha2PasswordPerformFullAuthentication:
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
|
||||
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||
data, err := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[4] = cachingSha2PasswordRequestPublicKey
|
||||
mc.writePacket(data)
|
||||
|
||||
// parse public key
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
if data, err = mc.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) {
|
||||
plugin := "mysql_clear_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
_, _, err := mc.auth(authData, plugin)
|
||||
_, err := mc.auth(authData, plugin)
|
||||
if err != ErrCleartextPassword {
|
||||
t.Errorf("expected ErrCleartextPassword, got %v", err)
|
||||
}
|
||||
@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) {
|
||||
plugin := "mysql_clear_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -367,8 +367,8 @@ func TestAuthFastCleartextPassword(t *testing.T) {
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116}
|
||||
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
|
||||
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
|
||||
plugin := "mysql_clear_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -410,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
if writtenAuthRespLen != 0 {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v",
|
||||
writtenAuthRespLen, writtenAuthResp)
|
||||
expectedAuthResp := []byte{0}
|
||||
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
|
||||
@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) {
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
_, _, err := mc.auth(authData, plugin)
|
||||
_, err := mc.auth(authData, plugin)
|
||||
if err != ErrNativePassword {
|
||||
t.Errorf("expected ErrNativePassword, got %v", err)
|
||||
}
|
||||
@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) {
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -554,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
if writtenAuthRespLen != 0 {
|
||||
expectedAuthResp := []byte{0}
|
||||
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
@ -587,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -636,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -669,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -677,18 +678,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
|
||||
// unset TLS config to prevent the actual establishment of a TLS wrapper
|
||||
mc.cfg.tls = nil
|
||||
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check written auth response
|
||||
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
|
||||
authRespEnd := authRespStart + 1 + len(authResp) + 1
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
|
||||
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
@ -1064,6 +1065,22 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request.
|
||||
func TestOldAuthSwitchNotAllowed(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
|
||||
// OldAuthSwitch request
|
||||
conn.data = []byte{1, 0, 0, 2, 0xfe}
|
||||
conn.maxReads = 1
|
||||
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
|
||||
84, 96, 101, 92, 123, 121, 107}
|
||||
plugin := "mysql_native_password"
|
||||
err := mc.handleAuthResult(authData, plugin)
|
||||
if err != ErrOldPassword {
|
||||
t.Errorf("expected ErrOldPassword, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSwitchOldPassword(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
@ -1092,6 +1109,32 @@ func TestAuthSwitchOldPassword(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request.
|
||||
func TestOldAuthSwitch(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
mc.cfg.Passwd = "secret"
|
||||
|
||||
// OldAuthSwitch request
|
||||
conn.data = []byte{1, 0, 0, 2, 0xfe}
|
||||
|
||||
// auth response
|
||||
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
|
||||
conn.maxReads = 2
|
||||
|
||||
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
|
||||
84, 96, 101, 92, 123, 121, 107}
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
if err := mc.handleAuthResult(authData, plugin); err != nil {
|
||||
t.Errorf("got error: %v", err)
|
||||
}
|
||||
|
||||
expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0}
|
||||
if !bytes.Equal(conn.written, expectedReply) {
|
||||
t.Errorf("got unexpected data: %v", conn.written)
|
||||
}
|
||||
}
|
||||
func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
@ -1120,6 +1163,33 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request.
|
||||
func TestOldAuthSwitchPasswordEmpty(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
mc.cfg.Passwd = ""
|
||||
|
||||
// OldAuthSwitch request.
|
||||
conn.data = []byte{1, 0, 0, 2, 0xfe}
|
||||
|
||||
// auth response
|
||||
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
|
||||
conn.maxReads = 2
|
||||
|
||||
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
|
||||
84, 96, 101, 92, 123, 121, 107}
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
if err := mc.handleAuthResult(authData, plugin); err != nil {
|
||||
t.Errorf("got error: %v", err)
|
||||
}
|
||||
|
||||
expectedReply := []byte{1, 0, 0, 3, 0}
|
||||
if !bytes.Equal(conn.written, expectedReply) {
|
||||
t.Errorf("got unexpected data: %v", conn.written)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.Passwd = ""
|
||||
|
||||
@ -22,17 +22,17 @@ const defaultBufSize = 4096
|
||||
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
|
||||
// Also highly optimized for this particular use case.
|
||||
type buffer struct {
|
||||
buf []byte
|
||||
buf []byte // buf is a byte buffer who's length and capacity are equal.
|
||||
nc net.Conn
|
||||
idx int
|
||||
length int
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// newBuffer allocates and returns a new buffer.
|
||||
func newBuffer(nc net.Conn) buffer {
|
||||
var b [defaultBufSize]byte
|
||||
return buffer{
|
||||
buf: b[:],
|
||||
buf: make([]byte, defaultBufSize),
|
||||
nc: nc,
|
||||
}
|
||||
}
|
||||
@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
|
||||
return b.buf[offset:b.idx], nil
|
||||
}
|
||||
|
||||
// returns a buffer with the requested size.
|
||||
// takeBuffer returns a buffer with the requested size.
|
||||
// If possible, a slice from the existing buffer is returned.
|
||||
// Otherwise a bigger buffer is made.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeBuffer(length int) []byte {
|
||||
func (b *buffer) takeBuffer(length int) ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
|
||||
// test (cheap) general case first
|
||||
if length <= defaultBufSize || length <= cap(b.buf) {
|
||||
return b.buf[:length]
|
||||
if length <= cap(b.buf) {
|
||||
return b.buf[:length], nil
|
||||
}
|
||||
|
||||
if length < maxPacketSize {
|
||||
b.buf = make([]byte, length)
|
||||
return b.buf
|
||||
return b.buf, nil
|
||||
}
|
||||
return make([]byte, length)
|
||||
|
||||
// buffer is larger than we want to store.
|
||||
return make([]byte, length), nil
|
||||
}
|
||||
|
||||
// shortcut which can be used if the requested buffer is guaranteed to be
|
||||
// smaller than defaultBufSize
|
||||
// takeSmallBuffer is shortcut which can be used if length is
|
||||
// known to be smaller than defaultBufSize.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeSmallBuffer(length int) []byte {
|
||||
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
return b.buf[:length]
|
||||
return b.buf[:length], nil
|
||||
}
|
||||
|
||||
// takeCompleteBuffer returns the complete existing buffer.
|
||||
// This can be used if the necessary buffer size is unknown.
|
||||
// cap and len of the returned buffer will be equal.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeCompleteBuffer() []byte {
|
||||
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
return b.buf
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
// store stores buf, an updated buffer, if its suitable to do so.
|
||||
func (b *buffer) store(buf []byte) error {
|
||||
if b.length > 0 {
|
||||
return ErrBusyBuffer
|
||||
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
|
||||
b.buf = buf[:cap(buf)]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -19,16 +19,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// a copy of context.Context for Go 1.7 and earlier
|
||||
type mysqlContext interface {
|
||||
Done() <-chan struct{}
|
||||
Err() error
|
||||
|
||||
// defined in context.Context, but not used in this driver:
|
||||
// Deadline() (deadline time.Time, ok bool)
|
||||
// Value(key interface{}) interface{}
|
||||
}
|
||||
|
||||
type mysqlConn struct {
|
||||
buf buffer
|
||||
netConn net.Conn
|
||||
@ -45,7 +35,7 @@ type mysqlConn struct {
|
||||
|
||||
// for context support (Go 1.8+)
|
||||
watching bool
|
||||
watcher chan<- mysqlContext
|
||||
watcher chan<- context.Context
|
||||
closech chan struct{}
|
||||
finished chan<- struct{}
|
||||
canceled atomicError // set non-nil if conn is canceled
|
||||
@ -192,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
|
||||
buf := mc.buf.takeCompleteBuffer()
|
||||
if buf == nil {
|
||||
buf, err := mc.buf.takeCompleteBuffer()
|
||||
if err != nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return "", ErrInvalidConn
|
||||
}
|
||||
buf = buf[:0]
|
||||
@ -475,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
||||
defer mc.finish()
|
||||
|
||||
if err = mc.writeCommandPacket(comPing); err != nil {
|
||||
return
|
||||
return mc.markBadConn(err)
|
||||
}
|
||||
|
||||
return mc.readResultOK()
|
||||
@ -595,33 +585,32 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
||||
mc.cleanup()
|
||||
return nil
|
||||
}
|
||||
// When ctx is already cancelled, don't watch it.
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
// When ctx is not cancellable, don't watch it.
|
||||
if ctx.Done() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
// When watcher is not alive, can't watch it.
|
||||
if mc.watcher == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
mc.watcher <- ctx
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) startWatcher() {
|
||||
watcher := make(chan mysqlContext, 1)
|
||||
watcher := make(chan context.Context, 1)
|
||||
mc.watcher = watcher
|
||||
finished := make(chan struct{})
|
||||
mc.finished = finished
|
||||
go func() {
|
||||
for {
|
||||
var ctx mysqlContext
|
||||
var ctx context.Context
|
||||
select {
|
||||
case ctx = <-watcher:
|
||||
case <-mc.closech:
|
||||
|
||||
@ -9,7 +9,10 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -79,3 +82,76 @@ func TestCheckNamedValue(t *testing.T) {
|
||||
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanCancel tests passed context is cancelled at start.
|
||||
// No packet should be sent. Connection should keep current status.
|
||||
func TestCleanCancel(t *testing.T) {
|
||||
mc := &mysqlConn{
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
mc.startWatcher()
|
||||
defer mc.cleanup()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
for i := 0; i < 3; i++ { // Repeat same behavior
|
||||
err := mc.Ping(ctx)
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled, got %#v", err)
|
||||
}
|
||||
|
||||
if mc.closed.IsSet() {
|
||||
t.Error("expected mc is not closed, closed actually")
|
||||
}
|
||||
|
||||
if mc.watching {
|
||||
t.Error("expected watching is false, but true")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingMarkBadConnection(t *testing.T) {
|
||||
nc := badConnection{err: errors.New("boom")}
|
||||
ms := &mysqlConn{
|
||||
netConn: nc,
|
||||
buf: newBuffer(nc),
|
||||
maxAllowedPacket: defaultMaxAllowedPacket,
|
||||
}
|
||||
|
||||
err := ms.Ping(context.Background())
|
||||
|
||||
if err != driver.ErrBadConn {
|
||||
t.Errorf("expected driver.ErrBadConn, got %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingErrInvalidConn(t *testing.T) {
|
||||
nc := badConnection{err: errors.New("failed to write"), n: 10}
|
||||
ms := &mysqlConn{
|
||||
netConn: nc,
|
||||
buf: newBuffer(nc),
|
||||
maxAllowedPacket: defaultMaxAllowedPacket,
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
|
||||
err := ms.Ping(context.Background())
|
||||
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type badConnection struct {
|
||||
n int
|
||||
err error
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (bc badConnection) Write(b []byte) (n int, err error) {
|
||||
return bc.n, bc.err
|
||||
}
|
||||
|
||||
func (bc badConnection) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
// The driver should be used via the database/sql package:
|
||||
//
|
||||
// import "database/sql"
|
||||
// import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
//
|
||||
// db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
//
|
||||
@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
|
||||
|
||||
// Open new Connection.
|
||||
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
|
||||
// the DSN string is formated
|
||||
// the DSN string is formatted
|
||||
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
var err error
|
||||
|
||||
@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
|
||||
}
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
errLog.Print("net.Error from Dial()': ", nerr.Error())
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -110,18 +114,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
}
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
// try the default auth plugin, if using the requested plugin failed
|
||||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
|
||||
plugin = defaultAuthPlugin
|
||||
authResp, addNUL, err = mc.auth(authData, plugin)
|
||||
authResp, err = mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -85,6 +85,23 @@ type DBTest struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type netErrorMock struct {
|
||||
temporary bool
|
||||
timeout bool
|
||||
}
|
||||
|
||||
func (e netErrorMock) Temporary() bool {
|
||||
return e.temporary
|
||||
}
|
||||
|
||||
func (e netErrorMock) Timeout() bool {
|
||||
return e.timeout
|
||||
}
|
||||
|
||||
func (e netErrorMock) Error() string {
|
||||
return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
|
||||
}
|
||||
|
||||
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
|
||||
if !available {
|
||||
t.Skipf("MySQL server not running on %s", netAddr)
|
||||
@ -1287,7 +1304,7 @@ func TestFoundRows(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
tlsTest := func(dbt *DBTest) {
|
||||
tlsTestReq := func(dbt *DBTest) {
|
||||
if err := dbt.db.Ping(); err != nil {
|
||||
if err == ErrNoTLS {
|
||||
dbt.Skip("server does not support TLS")
|
||||
@ -1304,19 +1321,27 @@ func TestTLS(t *testing.T) {
|
||||
dbt.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
dbt.Fatal("no Cipher")
|
||||
if (*value == nil) || (len(*value) == 0) {
|
||||
dbt.Fatalf("no Cipher")
|
||||
} else {
|
||||
dbt.Logf("Cipher: %s", *value)
|
||||
}
|
||||
}
|
||||
}
|
||||
tlsTestOpt := func(dbt *DBTest) {
|
||||
if err := dbt.db.Ping(); err != nil {
|
||||
dbt.Fatalf("error on Ping: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
runTests(t, dsn+"&tls=skip-verify", tlsTest)
|
||||
runTests(t, dsn+"&tls=preferred", tlsTestOpt)
|
||||
runTests(t, dsn+"&tls=skip-verify", tlsTestReq)
|
||||
|
||||
// Verify that registering / using a custom cfg works
|
||||
RegisterTLSConfig("custom-skip-verify", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
|
||||
runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq)
|
||||
}
|
||||
|
||||
func TestReuseClosedConnection(t *testing.T) {
|
||||
@ -1801,6 +1826,38 @@ func TestConcurrent(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func testDialError(t *testing.T, dialErr error, expectErr error) {
|
||||
RegisterDial("mydial", func(addr string) (net.Conn, error) {
|
||||
return nil, dialErr
|
||||
})
|
||||
|
||||
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting: %s", err.Error())
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("DO 1")
|
||||
if err != expectErr {
|
||||
t.Fatalf("was expecting %s. Got: %s", dialErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialUnknownError(t *testing.T) {
|
||||
testErr := fmt.Errorf("test")
|
||||
testDialError(t, testErr, testErr)
|
||||
}
|
||||
|
||||
func TestDialNonRetryableNetErr(t *testing.T) {
|
||||
testErr := netErrorMock{}
|
||||
testDialError(t, testErr, testErr)
|
||||
}
|
||||
|
||||
func TestDialTemporaryNetErr(t *testing.T) {
|
||||
testErr := netErrorMock{temporary: true}
|
||||
testDialError(t, testErr, driver.ErrBadConn)
|
||||
}
|
||||
|
||||
// Tests custom dial functions
|
||||
func TestCustomDial(t *testing.T) {
|
||||
if !available {
|
||||
|
||||
@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
||||
} else {
|
||||
cfg.TLSConfig = "false"
|
||||
}
|
||||
} else if vl := strings.ToLower(value); vl == "skip-verify" {
|
||||
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
|
||||
cfg.TLSConfig = vl
|
||||
cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
||||
} else {
|
||||
|
||||
@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
||||
mc.sequence++
|
||||
|
||||
// packets with length 0 terminate a previous packet which is a
|
||||
// multiple of (2^24)−1 bytes long
|
||||
// multiple of (2^24)-1 bytes long
|
||||
if pktLen == 0 {
|
||||
// there was no previous packet
|
||||
if prevData == nil {
|
||||
@ -194,7 +194,11 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
|
||||
return nil, "", ErrOldProtocol
|
||||
}
|
||||
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
|
||||
return nil, "", ErrNoTLS
|
||||
if mc.cfg.TLSConfig == "preferred" {
|
||||
mc.cfg.tls = nil
|
||||
} else {
|
||||
return nil, "", ErrNoTLS
|
||||
}
|
||||
}
|
||||
pos += 2
|
||||
|
||||
@ -243,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
|
||||
|
||||
// Client Authentication Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
|
||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
|
||||
// Adjust client flags based on server support
|
||||
clientFlags := clientProtocol41 |
|
||||
clientSecureConn |
|
||||
@ -269,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
|
||||
// encode length of the auth plugin data
|
||||
var authRespLEIBuf [9]byte
|
||||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
|
||||
authRespLen := len(authResp)
|
||||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
|
||||
if len(authRespLEI) > 1 {
|
||||
// if the length can not be written in 1 byte, it must be written as a
|
||||
// length encoded integer
|
||||
@ -277,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
}
|
||||
|
||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
|
||||
if addNUL {
|
||||
pktLen++
|
||||
}
|
||||
|
||||
// To specify a db name
|
||||
if n := len(mc.cfg.DBName); n > 0 {
|
||||
@ -288,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
}
|
||||
|
||||
// Calculate packet length and get buffer with that size
|
||||
data := mc.buf.takeSmallBuffer(pktLen + 4)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -350,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
// Auth Data [length encoded integer]
|
||||
pos += copy(data[pos:], authRespLEI)
|
||||
pos += copy(data[pos:], authResp)
|
||||
if addNUL {
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
}
|
||||
|
||||
// Databasename [null terminated string]
|
||||
if len(mc.cfg.DBName) > 0 {
|
||||
@ -364,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
|
||||
pos += copy(data[pos:], plugin)
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
|
||||
// Send Auth packet
|
||||
return mc.writePacket(data)
|
||||
return mc.writePacket(data[:pos])
|
||||
}
|
||||
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
||||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
|
||||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
|
||||
pktLen := 4 + len(authData)
|
||||
if addNUL {
|
||||
pktLen++
|
||||
}
|
||||
data := mc.buf.takeSmallBuffer(pktLen)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(pktLen)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// Add the auth data [EOF]
|
||||
copy(data[4:], authData)
|
||||
if addNUL {
|
||||
data[pktLen-1] = 0x00
|
||||
}
|
||||
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
@ -399,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
|
||||
// Reset Packet Sequence
|
||||
mc.sequence = 0
|
||||
|
||||
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -418,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
|
||||
mc.sequence = 0
|
||||
|
||||
pktLen := 1 + len(arg)
|
||||
data := mc.buf.takeBuffer(pktLen + 4)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeBuffer(pktLen + 4)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -439,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
||||
// Reset Packet Sequence
|
||||
mc.sequence = 0
|
||||
|
||||
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -479,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
|
||||
return data[1:], "", err
|
||||
|
||||
case iEOF:
|
||||
if len(data) < 1 {
|
||||
if len(data) == 1 {
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
||||
return nil, "mysql_old_password", nil
|
||||
}
|
||||
@ -895,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
const minPktLen = 4 + 1 + 4 + 1 + 4
|
||||
mc := stmt.mc
|
||||
|
||||
// Determine threshould dynamically to avoid packet size shortage.
|
||||
// Determine threshold dynamically to avoid packet size shortage.
|
||||
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
|
||||
if longDataSize < 64 {
|
||||
longDataSize = 64
|
||||
@ -905,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
mc.sequence = 0
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
if len(args) == 0 {
|
||||
data = mc.buf.takeBuffer(minPktLen)
|
||||
data, err = mc.buf.takeBuffer(minPktLen)
|
||||
} else {
|
||||
data = mc.buf.takeCompleteBuffer()
|
||||
data, err = mc.buf.takeCompleteBuffer()
|
||||
// In this case the len(data) == cap(data) which is used to optimise the flow below.
|
||||
}
|
||||
if data == nil {
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -939,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
pos := minPktLen
|
||||
|
||||
var nullMask []byte
|
||||
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
|
||||
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
|
||||
// buffer has to be extended but we don't know by how much so
|
||||
// we depend on append after all data with known sizes fit.
|
||||
// We stop at that because we deal with a lot of columns here
|
||||
@ -948,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
copy(tmp[:pos], data[:pos])
|
||||
data = tmp
|
||||
nullMask = data[pos : pos+maskLen]
|
||||
// No need to clean nullMask as make ensures that.
|
||||
pos += maskLen
|
||||
} else {
|
||||
nullMask = data[pos : pos+maskLen]
|
||||
for i := 0; i < maskLen; i++ {
|
||||
for i := range nullMask {
|
||||
nullMask[i] = 0
|
||||
}
|
||||
pos += maskLen
|
||||
@ -1088,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
// In that case we must build the data packet with the new values buffer
|
||||
if valuesCap != cap(paramValues) {
|
||||
data = append(data[:pos], paramValues...)
|
||||
mc.buf.buf = data
|
||||
if err = mc.buf.store(data); err != nil {
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
}
|
||||
|
||||
pos += len(paramValues)
|
||||
|
||||
Reference in New Issue
Block a user