parent
							
								
									280ebcbf7c
								
							
						
					
					
						commit
						9d4c1ddfa1
					
				|  | @ -294,7 +294,7 @@ | ||||||
| [[projects]] | [[projects]] | ||||||
|   name = "github.com/go-sql-driver/mysql" |   name = "github.com/go-sql-driver/mysql" | ||||||
|   packages = ["."] |   packages = ["."] | ||||||
|   revision = "ce924a41eea897745442daaa1739089b0f3f561d" |   revision = "d523deb1b23d913de5bdada721a6071e71283618" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   name = "github.com/go-xorm/builder" |   name = "github.com/go-xorm/builder" | ||||||
|  | @ -873,6 +873,6 @@ | ||||||
| [solve-meta] | [solve-meta] | ||||||
|   analyzer-name = "dep" |   analyzer-name = "dep" | ||||||
|   analyzer-version = 1 |   analyzer-version = 1 | ||||||
|   inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687" |   inputs-digest = "96c83a3502bd50c5ca8e4d9b4145172267630270e587c79b7253156725eeb9b8" | ||||||
|   solver-name = "gps-cdcl" |   solver-name = "gps-cdcl" | ||||||
|   solver-version = 1 |   solver-version = 1 | ||||||
|  |  | ||||||
|  | @ -40,6 +40,10 @@ ignored = ["google.golang.org/appengine*"] | ||||||
|   #version = "0.6.5" |   #version = "0.6.5" | ||||||
|   revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" |   revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | ||||||
| 
 | 
 | ||||||
|  | [[override]] | ||||||
|  |   name = "github.com/go-sql-driver/mysql" | ||||||
|  |   revision = "d523deb1b23d913de5bdada721a6071e71283618" | ||||||
|  | 
 | ||||||
| [[override]] | [[override]] | ||||||
|   name = "github.com/gorilla/mux" |   name = "github.com/gorilla/mux" | ||||||
|   revision = "757bef944d0f21880861c2dd9c871ca543023cba" |   revision = "757bef944d0f21880861c2dd9c871ca543023cba" | ||||||
|  |  | ||||||
|  | @ -12,34 +12,63 @@ | ||||||
| # Individual Persons | # Individual Persons | ||||||
| 
 | 
 | ||||||
| Aaron Hopkins <go-sql-driver at die.net> | Aaron Hopkins <go-sql-driver at die.net> | ||||||
|  | Achille Roussel <achille.roussel at gmail.com> | ||||||
|  | Alexey Palazhchenko <alexey.palazhchenko at gmail.com> | ||||||
|  | Andrew Reid <andrew.reid at tixtrack.com> | ||||||
| Arne Hormann <arnehormann at gmail.com> | Arne Hormann <arnehormann at gmail.com> | ||||||
|  | Asta Xie <xiemengjun at gmail.com> | ||||||
|  | Bulat Gaifullin <gaifullinbf at gmail.com> | ||||||
| Carlos Nieto <jose.carlos at menteslibres.net> | Carlos Nieto <jose.carlos at menteslibres.net> | ||||||
| Chris Moos <chris at tech9computers.com> | Chris Moos <chris at tech9computers.com> | ||||||
|  | Craig Wilson <craiggwilson at gmail.com> | ||||||
|  | Daniel Montoya <dsmontoyam at gmail.com> | ||||||
| Daniel Nichter <nil at codenode.com> | Daniel Nichter <nil at codenode.com> | ||||||
| Daniël van Eeden <git at myname.nl> | Daniël van Eeden <git at myname.nl> | ||||||
|  | Dave Protasowski <dprotaso at gmail.com> | ||||||
| DisposaBoy <disposaboy at dby.me> | DisposaBoy <disposaboy at dby.me> | ||||||
|  | Egor Smolyakov <egorsmkv at gmail.com> | ||||||
|  | Evan Shaw <evan at vendhq.com> | ||||||
| Frederick Mayle <frederickmayle at gmail.com> | Frederick Mayle <frederickmayle at gmail.com> | ||||||
| Gustavo Kristic <gkristic at gmail.com> | Gustavo Kristic <gkristic at gmail.com> | ||||||
|  | Hajime Nakagami <nakagami at gmail.com> | ||||||
| Hanno Braun <mail at hannobraun.com> | Hanno Braun <mail at hannobraun.com> | ||||||
| Henri Yandell <flamefew at gmail.com> | Henri Yandell <flamefew at gmail.com> | ||||||
| Hirotaka Yamamoto <ymmt2005 at gmail.com> | Hirotaka Yamamoto <ymmt2005 at gmail.com> | ||||||
|  | ICHINOSE Shogo <shogo82148 at gmail.com> | ||||||
| INADA Naoki <songofacandy at gmail.com> | INADA Naoki <songofacandy at gmail.com> | ||||||
|  | Jacek Szwec <szwec.jacek at gmail.com> | ||||||
| James Harr <james.harr at gmail.com> | James Harr <james.harr at gmail.com> | ||||||
|  | Jeff Hodges <jeff at somethingsimilar.com> | ||||||
|  | Jeffrey Charles <jeffreycharles at gmail.com> | ||||||
| Jian Zhen <zhenjl at gmail.com> | Jian Zhen <zhenjl at gmail.com> | ||||||
| Joshua Prunier <joshua.prunier at gmail.com> | Joshua Prunier <joshua.prunier at gmail.com> | ||||||
| Julien Lefevre <julien.lefevr at gmail.com> | Julien Lefevre <julien.lefevr at gmail.com> | ||||||
| Julien Schmidt <go-sql-driver at julienschmidt.com> | Julien Schmidt <go-sql-driver at julienschmidt.com> | ||||||
|  | Justin Li <jli at j-li.net> | ||||||
|  | Justin Nuß <nuss.justin at gmail.com> | ||||||
| Kamil Dziedzic <kamil at klecza.pl> | Kamil Dziedzic <kamil at klecza.pl> | ||||||
| Kevin Malachowski <kevin at chowski.com> | Kevin Malachowski <kevin at chowski.com> | ||||||
|  | Kieron Woodhouse <kieron.woodhouse at infosum.com> | ||||||
| Lennart Rudolph <lrudolph at hmc.edu> | Lennart Rudolph <lrudolph at hmc.edu> | ||||||
| Leonardo YongUk Kim <dalinaum at gmail.com> | Leonardo YongUk Kim <dalinaum at gmail.com> | ||||||
|  | Linh Tran Tuan <linhduonggnu at gmail.com> | ||||||
|  | Lion Yang <lion at aosc.xyz> | ||||||
| Luca Looz <luca.looz92 at gmail.com> | Luca Looz <luca.looz92 at gmail.com> | ||||||
| Lucas Liu <extrafliu at gmail.com> | Lucas Liu <extrafliu at gmail.com> | ||||||
| Luke Scott <luke at webconnex.com> | Luke Scott <luke at webconnex.com> | ||||||
|  | Maciej Zimnoch <maciej.zimnoch at codilime.com> | ||||||
| Michael Woolnough <michael.woolnough at gmail.com> | Michael Woolnough <michael.woolnough at gmail.com> | ||||||
| Nicola Peduzzi <thenikso at gmail.com> | Nicola Peduzzi <thenikso at gmail.com> | ||||||
|  | Olivier Mengué <dolmen at cpan.org> | ||||||
|  | oscarzhao <oscarzhaosl at gmail.com> | ||||||
| Paul Bonser <misterpib at gmail.com> | Paul Bonser <misterpib at gmail.com> | ||||||
|  | Peter Schultz <peter.schultz at classmarkets.com> | ||||||
|  | Rebecca Chin <rchin at pivotal.io> | ||||||
|  | Reed Allman <rdallman10 at gmail.com> | ||||||
|  | Richard Wilkes <wilkes at me.com> | ||||||
|  | Robert Russell <robert at rrbrussell.com> | ||||||
| Runrioter Wung <runrioter at gmail.com> | Runrioter Wung <runrioter at gmail.com> | ||||||
|  | Shuode Li <elemount at qq.com> | ||||||
| Soroush Pour <me at soroushjp.com> | Soroush Pour <me at soroushjp.com> | ||||||
| Stan Putrya <root.vagner at gmail.com> | Stan Putrya <root.vagner at gmail.com> | ||||||
| Stanley Gunawan <gunawan.stanley at gmail.com> | Stanley Gunawan <gunawan.stanley at gmail.com> | ||||||
|  | @ -51,5 +80,10 @@ Zhenye Xie <xiezhenye at gmail.com> | ||||||
| # Organizations | # Organizations | ||||||
| 
 | 
 | ||||||
| Barracuda Networks, Inc. | Barracuda Networks, Inc. | ||||||
|  | Counting Ltd. | ||||||
| Google Inc. | Google Inc. | ||||||
|  | InfoSum Ltd. | ||||||
|  | Keybase Inc. | ||||||
|  | Percona LLC | ||||||
|  | Pivotal Inc. | ||||||
| Stripe Inc. | Stripe Inc. | ||||||
|  |  | ||||||
|  | @ -11,7 +11,7 @@ | ||||||
| package mysql | package mysql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"appengine/cloudsql" | 	"google.golang.org/appengine/cloudsql" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func init() { | func init() { | ||||||
|  |  | ||||||
|  | @ -0,0 +1,420 @@ | ||||||
|  | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // This Source Code Form is subject to the terms of the Mozilla Public
 | ||||||
|  | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | ||||||
|  | // You can obtain one at http://mozilla.org/MPL/2.0/.
 | ||||||
|  | 
 | ||||||
|  | package mysql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"crypto/rsa" | ||||||
|  | 	"crypto/sha1" | ||||||
|  | 	"crypto/sha256" | ||||||
|  | 	"crypto/x509" | ||||||
|  | 	"encoding/pem" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // server pub keys registry
 | ||||||
|  | var ( | ||||||
|  | 	serverPubKeyLock     sync.RWMutex | ||||||
|  | 	serverPubKeyRegistry map[string]*rsa.PublicKey | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // RegisterServerPubKey registers a server RSA public key which can be used to
 | ||||||
|  | // send data in a secure manner to the server without receiving the public key
 | ||||||
|  | // in a potentially insecure way from the server first.
 | ||||||
|  | // Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
 | ||||||
|  | //
 | ||||||
|  | // Note: The provided rsa.PublicKey instance is exclusively owned by the driver
 | ||||||
|  | // after registering it and may not be modified.
 | ||||||
|  | //
 | ||||||
|  | //  data, err := ioutil.ReadFile("mykey.pem")
 | ||||||
|  | //  if err != nil {
 | ||||||
|  | //  	log.Fatal(err)
 | ||||||
|  | //  }
 | ||||||
|  | //
 | ||||||
|  | //  block, _ := pem.Decode(data)
 | ||||||
|  | //  if block == nil || block.Type != "PUBLIC KEY" {
 | ||||||
|  | //  	log.Fatal("failed to decode PEM block containing public key")
 | ||||||
|  | //  }
 | ||||||
|  | //
 | ||||||
|  | //  pub, err := x509.ParsePKIXPublicKey(block.Bytes)
 | ||||||
|  | //  if err != nil {
 | ||||||
|  | //  	log.Fatal(err)
 | ||||||
|  | //  }
 | ||||||
|  | //
 | ||||||
|  | //  if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
 | ||||||
|  | //  	mysql.RegisterServerPubKey("mykey", rsaPubKey)
 | ||||||
|  | //  } else {
 | ||||||
|  | //  	log.Fatal("not a RSA public key")
 | ||||||
|  | //  }
 | ||||||
|  | //
 | ||||||
|  | func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { | ||||||
|  | 	serverPubKeyLock.Lock() | ||||||
|  | 	if serverPubKeyRegistry == nil { | ||||||
|  | 		serverPubKeyRegistry = make(map[string]*rsa.PublicKey) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	serverPubKeyRegistry[name] = pubKey | ||||||
|  | 	serverPubKeyLock.Unlock() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DeregisterServerPubKey removes the public key registered with the given name.
 | ||||||
|  | func DeregisterServerPubKey(name string) { | ||||||
|  | 	serverPubKeyLock.Lock() | ||||||
|  | 	if serverPubKeyRegistry != nil { | ||||||
|  | 		delete(serverPubKeyRegistry, name) | ||||||
|  | 	} | ||||||
|  | 	serverPubKeyLock.Unlock() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getServerPubKey(name string) (pubKey *rsa.PublicKey) { | ||||||
|  | 	serverPubKeyLock.RLock() | ||||||
|  | 	if v, ok := serverPubKeyRegistry[name]; ok { | ||||||
|  | 		pubKey = v | ||||||
|  | 	} | ||||||
|  | 	serverPubKeyLock.RUnlock() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Hash password using pre 4.1 (old password) method
 | ||||||
|  | // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
 | ||||||
|  | type myRnd struct { | ||||||
|  | 	seed1, seed2 uint32 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | const myRndMaxVal = 0x3FFFFFFF | ||||||
|  | 
 | ||||||
|  | // Pseudo random number generator
 | ||||||
|  | func newMyRnd(seed1, seed2 uint32) *myRnd { | ||||||
|  | 	return &myRnd{ | ||||||
|  | 		seed1: seed1 % myRndMaxVal, | ||||||
|  | 		seed2: seed2 % myRndMaxVal, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Tested to be equivalent to MariaDB's floating point variant
 | ||||||
|  | // http://play.golang.org/p/QHvhd4qved
 | ||||||
|  | // http://play.golang.org/p/RG0q4ElWDx
 | ||||||
|  | func (r *myRnd) NextByte() byte { | ||||||
|  | 	r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal | ||||||
|  | 	r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal | ||||||
|  | 
 | ||||||
|  | 	return byte(uint64(r.seed1) * 31 / myRndMaxVal) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Generate binary hash from byte string using insecure pre 4.1 method
 | ||||||
|  | func pwHash(password []byte) (result [2]uint32) { | ||||||
|  | 	var add uint32 = 7 | ||||||
|  | 	var tmp uint32 | ||||||
|  | 
 | ||||||
|  | 	result[0] = 1345345333 | ||||||
|  | 	result[1] = 0x12345671 | ||||||
|  | 
 | ||||||
|  | 	for _, c := range password { | ||||||
|  | 		// skip spaces and tabs in password
 | ||||||
|  | 		if c == ' ' || c == '\t' { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		tmp = uint32(c) | ||||||
|  | 		result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) | ||||||
|  | 		result[1] += (result[1] << 8) ^ result[0] | ||||||
|  | 		add += tmp | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Remove sign bit (1<<31)-1)
 | ||||||
|  | 	result[0] &= 0x7FFFFFFF | ||||||
|  | 	result[1] &= 0x7FFFFFFF | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Hash password using insecure pre 4.1 method
 | ||||||
|  | func scrambleOldPassword(scramble []byte, password string) []byte { | ||||||
|  | 	if len(password) == 0 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	scramble = scramble[:8] | ||||||
|  | 
 | ||||||
|  | 	hashPw := pwHash([]byte(password)) | ||||||
|  | 	hashSc := pwHash(scramble) | ||||||
|  | 
 | ||||||
|  | 	r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) | ||||||
|  | 
 | ||||||
|  | 	var out [8]byte | ||||||
|  | 	for i := range out { | ||||||
|  | 		out[i] = r.NextByte() + 64 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	mask := r.NextByte() | ||||||
|  | 	for i := range out { | ||||||
|  | 		out[i] ^= mask | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return out[:] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Hash password using 4.1+ method (SHA1)
 | ||||||
|  | func scramblePassword(scramble []byte, password string) []byte { | ||||||
|  | 	if len(password) == 0 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// stage1Hash = SHA1(password)
 | ||||||
|  | 	crypt := sha1.New() | ||||||
|  | 	crypt.Write([]byte(password)) | ||||||
|  | 	stage1 := crypt.Sum(nil) | ||||||
|  | 
 | ||||||
|  | 	// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
 | ||||||
|  | 	// inner Hash
 | ||||||
|  | 	crypt.Reset() | ||||||
|  | 	crypt.Write(stage1) | ||||||
|  | 	hash := crypt.Sum(nil) | ||||||
|  | 
 | ||||||
|  | 	// outer Hash
 | ||||||
|  | 	crypt.Reset() | ||||||
|  | 	crypt.Write(scramble) | ||||||
|  | 	crypt.Write(hash) | ||||||
|  | 	scramble = crypt.Sum(nil) | ||||||
|  | 
 | ||||||
|  | 	// token = scrambleHash XOR stage1Hash
 | ||||||
|  | 	for i := range scramble { | ||||||
|  | 		scramble[i] ^= stage1[i] | ||||||
|  | 	} | ||||||
|  | 	return scramble | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Hash password using MySQL 8+ method (SHA256)
 | ||||||
|  | func scrambleSHA256Password(scramble []byte, password string) []byte { | ||||||
|  | 	if len(password) == 0 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
 | ||||||
|  | 
 | ||||||
|  | 	crypt := sha256.New() | ||||||
|  | 	crypt.Write([]byte(password)) | ||||||
|  | 	message1 := crypt.Sum(nil) | ||||||
|  | 
 | ||||||
|  | 	crypt.Reset() | ||||||
|  | 	crypt.Write(message1) | ||||||
|  | 	message1Hash := crypt.Sum(nil) | ||||||
|  | 
 | ||||||
|  | 	crypt.Reset() | ||||||
|  | 	crypt.Write(message1Hash) | ||||||
|  | 	crypt.Write(scramble) | ||||||
|  | 	message2 := crypt.Sum(nil) | ||||||
|  | 
 | ||||||
|  | 	for i := range message1 { | ||||||
|  | 		message1[i] ^= message2[i] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return message1 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { | ||||||
|  | 	plain := make([]byte, len(password)+1) | ||||||
|  | 	copy(plain, password) | ||||||
|  | 	for i := range plain { | ||||||
|  | 		j := i % len(seed) | ||||||
|  | 		plain[i] ^= seed[j] | ||||||
|  | 	} | ||||||
|  | 	sha1 := sha1.New() | ||||||
|  | 	return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { | ||||||
|  | 	enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return mc.writeAuthSwitchPacket(enc, false) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { | ||||||
|  | 	switch plugin { | ||||||
|  | 	case "caching_sha2_password": | ||||||
|  | 		authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) | ||||||
|  | 		return authResp, (authResp == nil), nil | ||||||
|  | 
 | ||||||
|  | 	case "mysql_old_password": | ||||||
|  | 		if !mc.cfg.AllowOldPasswords { | ||||||
|  | 			return nil, false, ErrOldPassword | ||||||
|  | 		} | ||||||
|  | 		// Note: there are edge cases where this should work but doesn't;
 | ||||||
|  | 		// this is currently "wontfix":
 | ||||||
|  | 		// https://github.com/go-sql-driver/mysql/issues/184
 | ||||||
|  | 		authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) | ||||||
|  | 		return authResp, true, nil | ||||||
|  | 
 | ||||||
|  | 	case "mysql_clear_password": | ||||||
|  | 		if !mc.cfg.AllowCleartextPasswords { | ||||||
|  | 			return nil, false, ErrCleartextPassword | ||||||
|  | 		} | ||||||
|  | 		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
 | ||||||
|  | 		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
 | ||||||
|  | 		return []byte(mc.cfg.Passwd), true, nil | ||||||
|  | 
 | ||||||
|  | 	case "mysql_native_password": | ||||||
|  | 		if !mc.cfg.AllowNativePasswords { | ||||||
|  | 			return nil, false, ErrNativePassword | ||||||
|  | 		} | ||||||
|  | 		// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
 | ||||||
|  | 		// Native password authentication only need and will need 20-byte challenge.
 | ||||||
|  | 		authResp := scramblePassword(authData[:20], mc.cfg.Passwd) | ||||||
|  | 		return authResp, false, nil | ||||||
|  | 
 | ||||||
|  | 	case "sha256_password": | ||||||
|  | 		if len(mc.cfg.Passwd) == 0 { | ||||||
|  | 			return nil, true, nil | ||||||
|  | 		} | ||||||
|  | 		if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | ||||||
|  | 			// write cleartext auth packet
 | ||||||
|  | 			return []byte(mc.cfg.Passwd), true, nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		pubKey := mc.cfg.pubKey | ||||||
|  | 		if pubKey == nil { | ||||||
|  | 			// request public key from server
 | ||||||
|  | 			return []byte{1}, false, nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// encrypted password
 | ||||||
|  | 		enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) | ||||||
|  | 		return enc, false, err | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		errLog.Print("unknown auth plugin:", plugin) | ||||||
|  | 		return nil, false, ErrUnknownPlugin | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { | ||||||
|  | 	// Read Result Packet
 | ||||||
|  | 	authData, newPlugin, err := mc.readAuthResult() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// handle auth plugin switch, if requested
 | ||||||
|  | 	if newPlugin != "" { | ||||||
|  | 		// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
 | ||||||
|  | 		// sent and we have to keep using the cipher sent in the init packet.
 | ||||||
|  | 		if authData == nil { | ||||||
|  | 			authData = oldAuthData | ||||||
|  | 		} else { | ||||||
|  | 			// copy data from read buffer to owned slice
 | ||||||
|  | 			copy(oldAuthData, authData) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		plugin = newPlugin | ||||||
|  | 
 | ||||||
|  | 		authResp, addNUL, err := mc.auth(authData, plugin) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Read Result Packet
 | ||||||
|  | 		authData, newPlugin, err = mc.readAuthResult() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Do not allow to change the auth plugin more than once
 | ||||||
|  | 		if newPlugin != "" { | ||||||
|  | 			return ErrMalformPkt | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch plugin { | ||||||
|  | 
 | ||||||
|  | 	// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
 | ||||||
|  | 	case "caching_sha2_password": | ||||||
|  | 		switch len(authData) { | ||||||
|  | 		case 0: | ||||||
|  | 			return nil // auth successful
 | ||||||
|  | 		case 1: | ||||||
|  | 			switch authData[0] { | ||||||
|  | 			case cachingSha2PasswordFastAuthSuccess: | ||||||
|  | 				if err = mc.readResultOK(); err == nil { | ||||||
|  | 					return nil // auth successful
 | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 			case cachingSha2PasswordPerformFullAuthentication: | ||||||
|  | 				if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | ||||||
|  | 					// write cleartext auth packet
 | ||||||
|  | 					err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) | ||||||
|  | 					if err != nil { | ||||||
|  | 						return err | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					pubKey := mc.cfg.pubKey | ||||||
|  | 					if pubKey == nil { | ||||||
|  | 						// request public key from server
 | ||||||
|  | 						data := mc.buf.takeSmallBuffer(4 + 1) | ||||||
|  | 						data[4] = cachingSha2PasswordRequestPublicKey | ||||||
|  | 						mc.writePacket(data) | ||||||
|  | 
 | ||||||
|  | 						// parse public key
 | ||||||
|  | 						data, err := mc.readPacket() | ||||||
|  | 						if err != nil { | ||||||
|  | 							return err | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						block, _ := pem.Decode(data[1:]) | ||||||
|  | 						pkix, err := x509.ParsePKIXPublicKey(block.Bytes) | ||||||
|  | 						if err != nil { | ||||||
|  | 							return err | ||||||
|  | 						} | ||||||
|  | 						pubKey = pkix.(*rsa.PublicKey) | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					// send encrypted password
 | ||||||
|  | 					err = mc.sendEncryptedPassword(oldAuthData, pubKey) | ||||||
|  | 					if err != nil { | ||||||
|  | 						return err | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				return mc.readResultOK() | ||||||
|  | 
 | ||||||
|  | 			default: | ||||||
|  | 				return ErrMalformPkt | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			return ErrMalformPkt | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	case "sha256_password": | ||||||
|  | 		switch len(authData) { | ||||||
|  | 		case 0: | ||||||
|  | 			return nil // auth successful
 | ||||||
|  | 		default: | ||||||
|  | 			block, _ := pem.Decode(authData) | ||||||
|  | 			pub, err := x509.ParsePKIXPublicKey(block.Bytes) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// send encrypted password
 | ||||||
|  | 			err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			return mc.readResultOK() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return nil // auth successful
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | @ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte { | ||||||
| // smaller than defaultBufSize
 | // smaller than defaultBufSize
 | ||||||
| // Only one buffer (total) can be used at a time.
 | // Only one buffer (total) can be used at a time.
 | ||||||
| func (b *buffer) takeSmallBuffer(length int) []byte { | func (b *buffer) takeSmallBuffer(length int) []byte { | ||||||
| 	if b.length == 0 { | 	if b.length > 0 { | ||||||
| 		return b.buf[:length] | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return b.buf[:length] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // takeCompleteBuffer returns the complete existing buffer.
 | // takeCompleteBuffer returns the complete existing buffer.
 | ||||||
| // This can be used if the necessary buffer size is unknown.
 | // This can be used if the necessary buffer size is unknown.
 | ||||||
| // Only one buffer (total) can be used at a time.
 | // Only one buffer (total) can be used at a time.
 | ||||||
| func (b *buffer) takeCompleteBuffer() []byte { | func (b *buffer) takeCompleteBuffer() []byte { | ||||||
| 	if b.length == 0 { | 	if b.length > 0 { | ||||||
| 		return b.buf | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return b.buf | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -9,6 +9,7 @@ | ||||||
| package mysql | package mysql | ||||||
| 
 | 
 | ||||||
| const defaultCollation = "utf8_general_ci" | const defaultCollation = "utf8_general_ci" | ||||||
|  | const binaryCollation = "binary" | ||||||
| 
 | 
 | ||||||
| // A list of available collations mapped to the internal ID.
 | // A list of available collations mapped to the internal ID.
 | ||||||
| // To update this map use the following MySQL query:
 | // To update this map use the following MySQL query:
 | ||||||
|  |  | ||||||
|  | @ -10,12 +10,23 @@ package mysql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
|  | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // a copy of context.Context for Go 1.7 and earlier
 | ||||||
|  | type mysqlContext interface { | ||||||
|  | 	Done() <-chan struct{} | ||||||
|  | 	Err() error | ||||||
|  | 
 | ||||||
|  | 	// defined in context.Context, but not used in this driver:
 | ||||||
|  | 	// Deadline() (deadline time.Time, ok bool)
 | ||||||
|  | 	// Value(key interface{}) interface{}
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type mysqlConn struct { | type mysqlConn struct { | ||||||
| 	buf              buffer | 	buf              buffer | ||||||
| 	netConn          net.Conn | 	netConn          net.Conn | ||||||
|  | @ -29,7 +40,14 @@ type mysqlConn struct { | ||||||
| 	status           statusFlag | 	status           statusFlag | ||||||
| 	sequence         uint8 | 	sequence         uint8 | ||||||
| 	parseTime        bool | 	parseTime        bool | ||||||
| 	strict           bool | 
 | ||||||
|  | 	// for context support (Go 1.8+)
 | ||||||
|  | 	watching bool | ||||||
|  | 	watcher  chan<- mysqlContext | ||||||
|  | 	closech  chan struct{} | ||||||
|  | 	finished chan<- struct{} | ||||||
|  | 	canceled atomicError // set non-nil if conn is canceled
 | ||||||
|  | 	closed   atomicBool  // set when conn is closed, before closech is closed
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Handles parameters set in DSN after the connection is established
 | // Handles parameters set in DSN after the connection is established
 | ||||||
|  | @ -62,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) { | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (mc *mysqlConn) markBadConn(err error) error { | ||||||
|  | 	if mc == nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err != errBadConnNoWrite { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return driver.ErrBadConn | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (mc *mysqlConn) Begin() (driver.Tx, error) { | func (mc *mysqlConn) Begin() (driver.Tx, error) { | ||||||
| 	if mc.netConn == nil { | 	return mc.begin(false) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { | ||||||
|  | 	if mc.closed.IsSet() { | ||||||
| 		errLog.Print(ErrInvalidConn) | 		errLog.Print(ErrInvalidConn) | ||||||
| 		return nil, driver.ErrBadConn | 		return nil, driver.ErrBadConn | ||||||
| 	} | 	} | ||||||
| 	err := mc.exec("START TRANSACTION") | 	var q string | ||||||
|  | 	if readOnly { | ||||||
|  | 		q = "START TRANSACTION READ ONLY" | ||||||
|  | 	} else { | ||||||
|  | 		q = "START TRANSACTION" | ||||||
|  | 	} | ||||||
|  | 	err := mc.exec(q) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		return &mysqlTx{mc}, err | 		return &mysqlTx{mc}, err | ||||||
| 	} | 	} | ||||||
| 
 | 	return nil, mc.markBadConn(err) | ||||||
| 	return nil, err |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (mc *mysqlConn) Close() (err error) { | func (mc *mysqlConn) Close() (err error) { | ||||||
| 	// Makes Close idempotent
 | 	// Makes Close idempotent
 | ||||||
| 	if mc.netConn != nil { | 	if !mc.closed.IsSet() { | ||||||
| 		err = mc.writeCommandPacket(comQuit) | 		err = mc.writeCommandPacket(comQuit) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -91,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) { | ||||||
| // is called before auth or on auth failure because MySQL will have already
 | // is called before auth or on auth failure because MySQL will have already
 | ||||||
| // closed the network connection.
 | // closed the network connection.
 | ||||||
| func (mc *mysqlConn) cleanup() { | func (mc *mysqlConn) cleanup() { | ||||||
| 	// Makes cleanup idempotent
 | 	if !mc.closed.TrySet(true) { | ||||||
| 	if mc.netConn != nil { | 		return | ||||||
| 		if err := mc.netConn.Close(); err != nil { |  | ||||||
| 			errLog.Print(err) |  | ||||||
| 		} |  | ||||||
| 		mc.netConn = nil |  | ||||||
| 	} | 	} | ||||||
| 	mc.cfg = nil | 
 | ||||||
| 	mc.buf.nc = nil | 	// Makes cleanup idempotent
 | ||||||
|  | 	close(mc.closech) | ||||||
|  | 	if mc.netConn == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if err := mc.netConn.Close(); err != nil { | ||||||
|  | 		errLog.Print(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) error() error { | ||||||
|  | 	if mc.closed.IsSet() { | ||||||
|  | 		if err := mc.canceled.Value(); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return ErrInvalidConn | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { | func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { | ||||||
| 	if mc.netConn == nil { | 	if mc.closed.IsSet() { | ||||||
| 		errLog.Print(ErrInvalidConn) | 		errLog.Print(ErrInvalidConn) | ||||||
| 		return nil, driver.ErrBadConn | 		return nil, driver.ErrBadConn | ||||||
| 	} | 	} | ||||||
| 	// Send command
 | 	// Send command
 | ||||||
| 	err := mc.writeCommandPacketStr(comStmtPrepare, query) | 	err := mc.writeCommandPacketStr(comStmtPrepare, query) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, mc.markBadConn(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	stmt := &mysqlStmt{ | 	stmt := &mysqlStmt{ | ||||||
|  | @ -144,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | ||||||
| 	if buf == nil { | 	if buf == nil { | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 | 		// can not take the buffer. Something must be wrong with the connection
 | ||||||
| 		errLog.Print(ErrBusyBuffer) | 		errLog.Print(ErrBusyBuffer) | ||||||
| 		return "", driver.ErrBadConn | 		return "", ErrInvalidConn | ||||||
| 	} | 	} | ||||||
| 	buf = buf[:0] | 	buf = buf[:0] | ||||||
| 	argPos := 0 | 	argPos := 0 | ||||||
|  | @ -257,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { | func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||||||
| 	if mc.netConn == nil { | 	if mc.closed.IsSet() { | ||||||
| 		errLog.Print(ErrInvalidConn) | 		errLog.Print(ErrInvalidConn) | ||||||
| 		return nil, driver.ErrBadConn | 		return nil, driver.ErrBadConn | ||||||
| 	} | 	} | ||||||
|  | @ -271,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		query = prepared | 		query = prepared | ||||||
| 		args = nil |  | ||||||
| 	} | 	} | ||||||
| 	mc.affectedRows = 0 | 	mc.affectedRows = 0 | ||||||
| 	mc.insertId = 0 | 	mc.insertId = 0 | ||||||
|  | @ -283,32 +332,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | ||||||
| 			insertId:     int64(mc.insertId), | 			insertId:     int64(mc.insertId), | ||||||
| 		}, err | 		}, err | ||||||
| 	} | 	} | ||||||
| 	return nil, err | 	return nil, mc.markBadConn(err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Internal function to execute commands
 | // Internal function to execute commands
 | ||||||
| func (mc *mysqlConn) exec(query string) error { | func (mc *mysqlConn) exec(query string) error { | ||||||
| 	// Send command
 | 	// Send command
 | ||||||
| 	err := mc.writeCommandPacketStr(comQuery, query) | 	if err := mc.writeCommandPacketStr(comQuery, query); err != nil { | ||||||
| 	if err != nil { | 		return mc.markBadConn(err) | ||||||
| 		return err |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Read Result
 | 	// Read Result
 | ||||||
| 	resLen, err := mc.readResultSetHeaderPacket() | 	resLen, err := mc.readResultSetHeaderPacket() | ||||||
| 	if err == nil && resLen > 0 { | 	if err != nil { | ||||||
| 		if err = mc.readUntilEOF(); err != nil { | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if resLen > 0 { | ||||||
|  | 		// columns
 | ||||||
|  | 		if err := mc.readUntilEOF(); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		err = mc.readUntilEOF() | 		// rows
 | ||||||
|  | 		if err := mc.readUntilEOF(); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return err | 	return mc.discardResults() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { | func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||||||
| 	if mc.netConn == nil { | 	return mc.query(query, args) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { | ||||||
|  | 	if mc.closed.IsSet() { | ||||||
| 		errLog.Print(ErrInvalidConn) | 		errLog.Print(ErrInvalidConn) | ||||||
| 		return nil, driver.ErrBadConn | 		return nil, driver.ErrBadConn | ||||||
| 	} | 	} | ||||||
|  | @ -322,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		query = prepared | 		query = prepared | ||||||
| 		args = nil |  | ||||||
| 	} | 	} | ||||||
| 	// Send command
 | 	// Send command
 | ||||||
| 	err := mc.writeCommandPacketStr(comQuery, query) | 	err := mc.writeCommandPacketStr(comQuery, query) | ||||||
|  | @ -335,15 +394,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | ||||||
| 			rows.mc = mc | 			rows.mc = mc | ||||||
| 
 | 
 | ||||||
| 			if resLen == 0 { | 			if resLen == 0 { | ||||||
| 				// no columns, no more data
 | 				rows.rs.done = true | ||||||
| 				return emptyRows{}, nil | 
 | ||||||
|  | 				switch err := rows.NextResultSet(); err { | ||||||
|  | 				case nil, io.EOF: | ||||||
|  | 					return rows, nil | ||||||
|  | 				default: | ||||||
|  | 					return nil, err | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
|  | 
 | ||||||
| 			// Columns
 | 			// Columns
 | ||||||
| 			rows.columns, err = mc.readColumns(resLen) | 			rows.rs.columns, err = mc.readColumns(resLen) | ||||||
| 			return rows, err | 			return rows, err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return nil, err | 	return nil, mc.markBadConn(err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Gets the value of the given MySQL System Variable
 | // Gets the value of the given MySQL System Variable
 | ||||||
|  | @ -359,7 +425,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		rows := new(textRows) | 		rows := new(textRows) | ||||||
| 		rows.mc = mc | 		rows.mc = mc | ||||||
| 		rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | 		rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | ||||||
| 
 | 
 | ||||||
| 		if resLen > 0 { | 		if resLen > 0 { | ||||||
| 			// Columns
 | 			// Columns
 | ||||||
|  | @ -375,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | ||||||
| 	} | 	} | ||||||
| 	return nil, err | 	return nil, err | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // finish is called when the query has canceled.
 | ||||||
|  | func (mc *mysqlConn) cancel(err error) { | ||||||
|  | 	mc.canceled.Set(err) | ||||||
|  | 	mc.cleanup() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // finish is called when the query has succeeded.
 | ||||||
|  | func (mc *mysqlConn) finish() { | ||||||
|  | 	if !mc.watching || mc.finished == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	select { | ||||||
|  | 	case mc.finished <- struct{}{}: | ||||||
|  | 		mc.watching = false | ||||||
|  | 	case <-mc.closech: | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,208 @@ | ||||||
|  | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // This Source Code Form is subject to the terms of the Mozilla Public
 | ||||||
|  | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | ||||||
|  | // You can obtain one at http://mozilla.org/MPL/2.0/.
 | ||||||
|  | 
 | ||||||
|  | // +build go1.8
 | ||||||
|  | 
 | ||||||
|  | package mysql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Ping implements driver.Pinger interface
 | ||||||
|  | func (mc *mysqlConn) Ping(ctx context.Context) (err error) { | ||||||
|  | 	if mc.closed.IsSet() { | ||||||
|  | 		errLog.Print(ErrInvalidConn) | ||||||
|  | 		return driver.ErrBadConn | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err = mc.watchCancel(ctx); err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	defer mc.finish() | ||||||
|  | 
 | ||||||
|  | 	if err = mc.writeCommandPacket(comPing); err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return mc.readResultOK() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // BeginTx implements driver.ConnBeginTx interface
 | ||||||
|  | func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | ||||||
|  | 	if err := mc.watchCancel(ctx); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer mc.finish() | ||||||
|  | 
 | ||||||
|  | 	if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { | ||||||
|  | 		level, err := mapIsolationLevel(opts.Isolation) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return mc.begin(opts.ReadOnly) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	dargs, err := namedValueToValue(args) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := mc.watchCancel(ctx); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rows, err := mc.query(query, dargs) | ||||||
|  | 	if err != nil { | ||||||
|  | 		mc.finish() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	rows.finish = mc.finish | ||||||
|  | 	return rows, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	dargs, err := namedValueToValue(args) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := mc.watchCancel(ctx); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer mc.finish() | ||||||
|  | 
 | ||||||
|  | 	return mc.Exec(query, dargs) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | ||||||
|  | 	if err := mc.watchCancel(ctx); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	stmt, err := mc.Prepare(query) | ||||||
|  | 	mc.finish() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	select { | ||||||
|  | 	default: | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		stmt.Close() | ||||||
|  | 		return nil, ctx.Err() | ||||||
|  | 	} | ||||||
|  | 	return stmt, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	dargs, err := namedValueToValue(args) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := stmt.mc.watchCancel(ctx); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rows, err := stmt.query(dargs) | ||||||
|  | 	if err != nil { | ||||||
|  | 		stmt.mc.finish() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	rows.finish = stmt.mc.finish | ||||||
|  | 	return rows, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	dargs, err := namedValueToValue(args) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := stmt.mc.watchCancel(ctx); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer stmt.mc.finish() | ||||||
|  | 
 | ||||||
|  | 	return stmt.Exec(dargs) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) watchCancel(ctx context.Context) error { | ||||||
|  | 	if mc.watching { | ||||||
|  | 		// Reach here if canceled,
 | ||||||
|  | 		// so the connection is already invalid
 | ||||||
|  | 		mc.cleanup() | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	if ctx.Done() == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	mc.watching = true | ||||||
|  | 	select { | ||||||
|  | 	default: | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		return ctx.Err() | ||||||
|  | 	} | ||||||
|  | 	if mc.watcher == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	mc.watcher <- ctx | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) startWatcher() { | ||||||
|  | 	watcher := make(chan mysqlContext, 1) | ||||||
|  | 	mc.watcher = watcher | ||||||
|  | 	finished := make(chan struct{}) | ||||||
|  | 	mc.finished = finished | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			var ctx mysqlContext | ||||||
|  | 			select { | ||||||
|  | 			case ctx = <-watcher: | ||||||
|  | 			case <-mc.closech: | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			select { | ||||||
|  | 			case <-ctx.Done(): | ||||||
|  | 				mc.cancel(ctx.Err()) | ||||||
|  | 			case <-finished: | ||||||
|  | 			case <-mc.closech: | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { | ||||||
|  | 	nv.Value, err = converter{}.ConvertValue(nv.Value) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ResetSession implements driver.SessionResetter.
 | ||||||
|  | // (From Go 1.10)
 | ||||||
|  | func (mc *mysqlConn) ResetSession(ctx context.Context) error { | ||||||
|  | 	if mc.closed.IsSet() { | ||||||
|  | 		return driver.ErrBadConn | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | @ -9,7 +9,9 @@ | ||||||
| package mysql | package mysql | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	minProtocolVersion byte = 10 | 	defaultAuthPlugin       = "mysql_native_password" | ||||||
|  | 	defaultMaxAllowedPacket = 4 << 20 // 4 MiB
 | ||||||
|  | 	minProtocolVersion      = 10 | ||||||
| 	maxPacketSize           = 1<<24 - 1 | 	maxPacketSize           = 1<<24 - 1 | ||||||
| 	timeFormat              = "2006-01-02 15:04:05.999999" | 	timeFormat              = "2006-01-02 15:04:05.999999" | ||||||
| ) | ) | ||||||
|  | @ -18,10 +20,11 @@ const ( | ||||||
| // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
 | // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
 | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	iOK          byte = 0x00 | 	iOK           byte = 0x00 | ||||||
| 	iLocalInFile byte = 0xfb | 	iAuthMoreData byte = 0x01 | ||||||
| 	iEOF         byte = 0xfe | 	iLocalInFile  byte = 0xfb | ||||||
| 	iERR         byte = 0xff | 	iEOF          byte = 0xfe | ||||||
|  | 	iERR          byte = 0xff | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
 | // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
 | ||||||
|  | @ -87,8 +90,10 @@ const ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
 | // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
 | ||||||
|  | type fieldType byte | ||||||
|  | 
 | ||||||
| const ( | const ( | ||||||
| 	fieldTypeDecimal byte = iota | 	fieldTypeDecimal fieldType = iota | ||||||
| 	fieldTypeTiny | 	fieldTypeTiny | ||||||
| 	fieldTypeShort | 	fieldTypeShort | ||||||
| 	fieldTypeLong | 	fieldTypeLong | ||||||
|  | @ -107,7 +112,7 @@ const ( | ||||||
| 	fieldTypeBit | 	fieldTypeBit | ||||||
| ) | ) | ||||||
| const ( | const ( | ||||||
| 	fieldTypeJSON byte = iota + 0xf5 | 	fieldTypeJSON fieldType = iota + 0xf5 | ||||||
| 	fieldTypeNewDecimal | 	fieldTypeNewDecimal | ||||||
| 	fieldTypeEnum | 	fieldTypeEnum | ||||||
| 	fieldTypeSet | 	fieldTypeSet | ||||||
|  | @ -161,3 +166,9 @@ const ( | ||||||
| 	statusInTransReadonly | 	statusInTransReadonly | ||||||
| 	statusSessionStateChanged | 	statusSessionStateChanged | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	cachingSha2PasswordRequestPublicKey          = 2 | ||||||
|  | 	cachingSha2PasswordFastAuthSuccess           = 3 | ||||||
|  | 	cachingSha2PasswordPerformFullAuthentication = 4 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ | ||||||
| // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | ||||||
| // You can obtain one at http://mozilla.org/MPL/2.0/.
 | // You can obtain one at http://mozilla.org/MPL/2.0/.
 | ||||||
| 
 | 
 | ||||||
| // Package mysql provides a MySQL driver for Go's database/sql package
 | // Package mysql provides a MySQL driver for Go's database/sql package.
 | ||||||
| //
 | //
 | ||||||
| // The driver should be used via the database/sql package:
 | // The driver should be used via the database/sql package:
 | ||||||
| //
 | //
 | ||||||
|  | @ -20,8 +20,14 @@ import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"sync" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // watcher interface is used for context support (From Go 1.8)
 | ||||||
|  | type watcher interface { | ||||||
|  | 	startWatcher() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // MySQLDriver is exported to make the driver directly accessible.
 | // MySQLDriver is exported to make the driver directly accessible.
 | ||||||
| // In general the driver is used via the database/sql package.
 | // In general the driver is used via the database/sql package.
 | ||||||
| type MySQLDriver struct{} | type MySQLDriver struct{} | ||||||
|  | @ -30,12 +36,17 @@ type MySQLDriver struct{} | ||||||
| // Custom dial functions must be registered with RegisterDial
 | // Custom dial functions must be registered with RegisterDial
 | ||||||
| type DialFunc func(addr string) (net.Conn, error) | type DialFunc func(addr string) (net.Conn, error) | ||||||
| 
 | 
 | ||||||
| var dials map[string]DialFunc | var ( | ||||||
|  | 	dialsLock sync.RWMutex | ||||||
|  | 	dials     map[string]DialFunc | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| // RegisterDial registers a custom dial function. It can then be used by the
 | // RegisterDial registers a custom dial function. It can then be used by the
 | ||||||
| // network address mynet(addr), where mynet is the registered new network.
 | // network address mynet(addr), where mynet is the registered new network.
 | ||||||
| // addr is passed as a parameter to the dial function.
 | // addr is passed as a parameter to the dial function.
 | ||||||
| func RegisterDial(net string, dial DialFunc) { | func RegisterDial(net string, dial DialFunc) { | ||||||
|  | 	dialsLock.Lock() | ||||||
|  | 	defer dialsLock.Unlock() | ||||||
| 	if dials == nil { | 	if dials == nil { | ||||||
| 		dials = make(map[string]DialFunc) | 		dials = make(map[string]DialFunc) | ||||||
| 	} | 	} | ||||||
|  | @ -52,16 +63,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | ||||||
| 	mc := &mysqlConn{ | 	mc := &mysqlConn{ | ||||||
| 		maxAllowedPacket: maxPacketSize, | 		maxAllowedPacket: maxPacketSize, | ||||||
| 		maxWriteSize:     maxPacketSize - 1, | 		maxWriteSize:     maxPacketSize - 1, | ||||||
|  | 		closech:          make(chan struct{}), | ||||||
| 	} | 	} | ||||||
| 	mc.cfg, err = ParseDSN(dsn) | 	mc.cfg, err = ParseDSN(dsn) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	mc.parseTime = mc.cfg.ParseTime | 	mc.parseTime = mc.cfg.ParseTime | ||||||
| 	mc.strict = mc.cfg.Strict |  | ||||||
| 
 | 
 | ||||||
| 	// Connect to Server
 | 	// Connect to Server
 | ||||||
| 	if dial, ok := dials[mc.cfg.Net]; ok { | 	dialsLock.RLock() | ||||||
|  | 	dial, ok := dials[mc.cfg.Net] | ||||||
|  | 	dialsLock.RUnlock() | ||||||
|  | 	if ok { | ||||||
| 		mc.netConn, err = dial(mc.cfg.Addr) | 		mc.netConn, err = dial(mc.cfg.Addr) | ||||||
| 	} else { | 	} else { | ||||||
| 		nd := net.Dialer{Timeout: mc.cfg.Timeout} | 		nd := net.Dialer{Timeout: mc.cfg.Timeout} | ||||||
|  | @ -81,6 +95,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Call startWatcher for context support (From Go 1.8)
 | ||||||
|  | 	if s, ok := interface{}(mc).(watcher); ok { | ||||||
|  | 		s.startWatcher() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	mc.buf = newBuffer(mc.netConn) | 	mc.buf = newBuffer(mc.netConn) | ||||||
| 
 | 
 | ||||||
| 	// Set I/O timeouts
 | 	// Set I/O timeouts
 | ||||||
|  | @ -88,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | ||||||
| 	mc.writeTimeout = mc.cfg.WriteTimeout | 	mc.writeTimeout = mc.cfg.WriteTimeout | ||||||
| 
 | 
 | ||||||
| 	// Reading Handshake Initialization Packet
 | 	// Reading Handshake Initialization Packet
 | ||||||
| 	cipher, err := mc.readInitPacket() | 	authData, plugin, err := mc.readHandshakePacket() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		mc.cleanup() | 		mc.cleanup() | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Send Client Authentication Packet
 | 	// Send Client Authentication Packet
 | ||||||
| 	if err = mc.writeAuthPacket(cipher); err != nil { | 	authResp, addNUL, err := mc.auth(authData, plugin) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// try the default auth plugin, if using the requested plugin failed
 | ||||||
|  | 		errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) | ||||||
|  | 		plugin = defaultAuthPlugin | ||||||
|  | 		authResp, addNUL, err = mc.auth(authData, plugin) | ||||||
|  | 		if err != nil { | ||||||
|  | 			mc.cleanup() | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { | ||||||
| 		mc.cleanup() | 		mc.cleanup() | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Handle response to auth packet, switch methods if possible
 | 	// Handle response to auth packet, switch methods if possible
 | ||||||
| 	if err = handleAuthResult(mc); err != nil { | 	if err = mc.handleAuthResult(authData, plugin); err != nil { | ||||||
| 		// Authentication failed and MySQL has already closed the connection
 | 		// Authentication failed and MySQL has already closed the connection
 | ||||||
| 		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
 | 		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
 | ||||||
| 		// Do not send COM_QUIT, just cleanup and return the error.
 | 		// Do not send COM_QUIT, just cleanup and return the error.
 | ||||||
|  | @ -134,43 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | ||||||
| 	return mc, nil | 	return mc, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func handleAuthResult(mc *mysqlConn) error { |  | ||||||
| 	// Read Result Packet
 |  | ||||||
| 	cipher, err := mc.readResultOK() |  | ||||||
| 	if err == nil { |  | ||||||
| 		return nil // auth successful
 |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if mc.cfg == nil { |  | ||||||
| 		return err // auth failed and retry not possible
 |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Retry auth if configured to do so.
 |  | ||||||
| 	if mc.cfg.AllowOldPasswords && err == ErrOldPassword { |  | ||||||
| 		// Retry with old authentication method. Note: there are edge cases
 |  | ||||||
| 		// where this should work but doesn't; this is currently "wontfix":
 |  | ||||||
| 		// https://github.com/go-sql-driver/mysql/issues/184
 |  | ||||||
| 		if err = mc.writeOldAuthPacket(cipher); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		_, err = mc.readResultOK() |  | ||||||
| 	} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { |  | ||||||
| 		// Retry with clear text password for
 |  | ||||||
| 		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
 |  | ||||||
| 		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
 |  | ||||||
| 		if err = mc.writeClearAuthPacket(); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		_, err = mc.readResultOK() |  | ||||||
| 	} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { |  | ||||||
| 		if err = mc.writeNativeAuthPacket(cipher); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		_, err = mc.readResultOK() |  | ||||||
| 	} |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func init() { | func init() { | ||||||
| 	sql.Register("mysql", &MySQLDriver{}) | 	sql.Register("mysql", &MySQLDriver{}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -10,11 +10,13 @@ package mysql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"crypto/rsa" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -27,7 +29,9 @@ var ( | ||||||
| 	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") | 	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Config is a configuration parsed from a DSN string
 | // Config is a configuration parsed from a DSN string.
 | ||||||
|  | // If a new Config is created instead of being parsed from a DSN string,
 | ||||||
|  | // the NewConfig function should be used, which sets default values.
 | ||||||
| type Config struct { | type Config struct { | ||||||
| 	User             string            // Username
 | 	User             string            // Username
 | ||||||
| 	Passwd           string            // Password (requires User)
 | 	Passwd           string            // Password (requires User)
 | ||||||
|  | @ -38,6 +42,8 @@ type Config struct { | ||||||
| 	Collation        string            // Connection collation
 | 	Collation        string            // Connection collation
 | ||||||
| 	Loc              *time.Location    // Location for time.Time values
 | 	Loc              *time.Location    // Location for time.Time values
 | ||||||
| 	MaxAllowedPacket int               // Max packet size allowed
 | 	MaxAllowedPacket int               // Max packet size allowed
 | ||||||
|  | 	ServerPubKey     string            // Server public key name
 | ||||||
|  | 	pubKey           *rsa.PublicKey    // Server public key
 | ||||||
| 	TLSConfig        string            // TLS configuration name
 | 	TLSConfig        string            // TLS configuration name
 | ||||||
| 	tls              *tls.Config       // TLS configuration
 | 	tls              *tls.Config       // TLS configuration
 | ||||||
| 	Timeout          time.Duration     // Dial timeout
 | 	Timeout          time.Duration     // Dial timeout
 | ||||||
|  | @ -53,7 +59,54 @@ type Config struct { | ||||||
| 	InterpolateParams       bool // Interpolate placeholders into query string
 | 	InterpolateParams       bool // Interpolate placeholders into query string
 | ||||||
| 	MultiStatements         bool // Allow multiple statements in one query
 | 	MultiStatements         bool // Allow multiple statements in one query
 | ||||||
| 	ParseTime               bool // Parse time values to time.Time
 | 	ParseTime               bool // Parse time values to time.Time
 | ||||||
| 	Strict                  bool // Return warnings as errors
 | 	RejectReadOnly          bool // Reject read-only connections
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewConfig creates a new Config and sets default values.
 | ||||||
|  | func NewConfig() *Config { | ||||||
|  | 	return &Config{ | ||||||
|  | 		Collation:            defaultCollation, | ||||||
|  | 		Loc:                  time.UTC, | ||||||
|  | 		MaxAllowedPacket:     defaultMaxAllowedPacket, | ||||||
|  | 		AllowNativePasswords: true, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (cfg *Config) normalize() error { | ||||||
|  | 	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | ||||||
|  | 		return errInvalidDSNUnsafeCollation | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Set default network if empty
 | ||||||
|  | 	if cfg.Net == "" { | ||||||
|  | 		cfg.Net = "tcp" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Set default address if empty
 | ||||||
|  | 	if cfg.Addr == "" { | ||||||
|  | 		switch cfg.Net { | ||||||
|  | 		case "tcp": | ||||||
|  | 			cfg.Addr = "127.0.0.1:3306" | ||||||
|  | 		case "unix": | ||||||
|  | 			cfg.Addr = "/tmp/mysql.sock" | ||||||
|  | 		default: | ||||||
|  | 			return errors.New("default addr for network '" + cfg.Net + "' unknown") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	} else if cfg.Net == "tcp" { | ||||||
|  | 		cfg.Addr = ensureHavePort(cfg.Addr) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if cfg.tls != nil { | ||||||
|  | 		if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { | ||||||
|  | 			host, _, err := net.SplitHostPort(cfg.Addr) | ||||||
|  | 			if err == nil { | ||||||
|  | 				cfg.tls.ServerName = host | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // FormatDSN formats the given Config into a DSN string which can be passed to
 | // FormatDSN formats the given Config into a DSN string which can be passed to
 | ||||||
|  | @ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if cfg.AllowNativePasswords { | 	if !cfg.AllowNativePasswords { | ||||||
| 		if hasParam { | 		if hasParam { | ||||||
| 			buf.WriteString("&allowNativePasswords=true") | 			buf.WriteString("&allowNativePasswords=false") | ||||||
| 		} else { | 		} else { | ||||||
| 			hasParam = true | 			hasParam = true | ||||||
| 			buf.WriteString("?allowNativePasswords=true") | 			buf.WriteString("?allowNativePasswords=false") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string { | ||||||
| 		buf.WriteString(cfg.ReadTimeout.String()) | 		buf.WriteString(cfg.ReadTimeout.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if cfg.Strict { | 	if cfg.RejectReadOnly { | ||||||
| 		if hasParam { | 		if hasParam { | ||||||
| 			buf.WriteString("&strict=true") | 			buf.WriteString("&rejectReadOnly=true") | ||||||
| 		} else { | 		} else { | ||||||
| 			hasParam = true | 			hasParam = true | ||||||
| 			buf.WriteString("?strict=true") | 			buf.WriteString("?rejectReadOnly=true") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if len(cfg.ServerPubKey) > 0 { | ||||||
|  | 		if hasParam { | ||||||
|  | 			buf.WriteString("&serverPubKey=") | ||||||
|  | 		} else { | ||||||
|  | 			hasParam = true | ||||||
|  | 			buf.WriteString("?serverPubKey=") | ||||||
|  | 		} | ||||||
|  | 		buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if cfg.Timeout > 0 { | 	if cfg.Timeout > 0 { | ||||||
| 		if hasParam { | 		if hasParam { | ||||||
| 			buf.WriteString("&timeout=") | 			buf.WriteString("&timeout=") | ||||||
|  | @ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string { | ||||||
| 		buf.WriteString(cfg.WriteTimeout.String()) | 		buf.WriteString(cfg.WriteTimeout.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if cfg.MaxAllowedPacket > 0 { | 	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { | ||||||
| 		if hasParam { | 		if hasParam { | ||||||
| 			buf.WriteString("&maxAllowedPacket=") | 			buf.WriteString("&maxAllowedPacket=") | ||||||
| 		} else { | 		} else { | ||||||
|  | @ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string { | ||||||
| 
 | 
 | ||||||
| 	// other params
 | 	// other params
 | ||||||
| 	if cfg.Params != nil { | 	if cfg.Params != nil { | ||||||
| 		for param, value := range cfg.Params { | 		var params []string | ||||||
|  | 		for param := range cfg.Params { | ||||||
|  | 			params = append(params, param) | ||||||
|  | 		} | ||||||
|  | 		sort.Strings(params) | ||||||
|  | 		for _, param := range params { | ||||||
| 			if hasParam { | 			if hasParam { | ||||||
| 				buf.WriteByte('&') | 				buf.WriteByte('&') | ||||||
| 			} else { | 			} else { | ||||||
|  | @ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string { | ||||||
| 
 | 
 | ||||||
| 			buf.WriteString(param) | 			buf.WriteString(param) | ||||||
| 			buf.WriteByte('=') | 			buf.WriteByte('=') | ||||||
| 			buf.WriteString(url.QueryEscape(value)) | 			buf.WriteString(url.QueryEscape(cfg.Params[param])) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string { | ||||||
| // ParseDSN parses the DSN string to a Config
 | // ParseDSN parses the DSN string to a Config
 | ||||||
| func ParseDSN(dsn string) (cfg *Config, err error) { | func ParseDSN(dsn string) (cfg *Config, err error) { | ||||||
| 	// New config with some default values
 | 	// New config with some default values
 | ||||||
| 	cfg = &Config{ | 	cfg = NewConfig() | ||||||
| 		Loc:       time.UTC, |  | ||||||
| 		Collation: defaultCollation, |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
 | 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
 | ||||||
| 	// Find the last '/' (since the password or the net addr might contain a '/')
 | 	// Find the last '/' (since the password or the net addr might contain a '/')
 | ||||||
|  | @ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { | ||||||
| 		return nil, errInvalidDSNNoSlash | 		return nil, errInvalidDSNNoSlash | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | 	if err = cfg.normalize(); err != nil { | ||||||
| 		return nil, errInvalidDSNUnsafeCollation | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	// Set default network if empty
 |  | ||||||
| 	if cfg.Net == "" { |  | ||||||
| 		cfg.Net = "tcp" |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Set default address if empty
 |  | ||||||
| 	if cfg.Addr == "" { |  | ||||||
| 		switch cfg.Net { |  | ||||||
| 		case "tcp": |  | ||||||
| 			cfg.Addr = "127.0.0.1:3306" |  | ||||||
| 		case "unix": |  | ||||||
| 			cfg.Addr = "/tmp/mysql.sock" |  | ||||||
| 		default: |  | ||||||
| 			return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { | ||||||
| 
 | 
 | ||||||
| 		// cfg params
 | 		// cfg params
 | ||||||
| 		switch value := param[1]; param[0] { | 		switch value := param[1]; param[0] { | ||||||
| 
 |  | ||||||
| 		// Disable INFILE whitelist / enable all files
 | 		// Disable INFILE whitelist / enable all files
 | ||||||
| 		case "allowAllFiles": | 		case "allowAllFiles": | ||||||
| 			var isBool bool | 			var isBool bool | ||||||
|  | @ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 		// Strict mode
 | 		// Reject read-only connections
 | ||||||
| 		case "strict": | 		case "rejectReadOnly": | ||||||
| 			var isBool bool | 			var isBool bool | ||||||
| 			cfg.Strict, isBool = readBool(value) | 			cfg.RejectReadOnly, isBool = readBool(value) | ||||||
| 			if !isBool { | 			if !isBool { | ||||||
| 				return errors.New("invalid bool value: " + value) | 				return errors.New("invalid bool value: " + value) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | 		// Server public key
 | ||||||
|  | 		case "serverPubKey": | ||||||
|  | 			name, err := url.QueryUnescape(value) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return fmt.Errorf("invalid value for server pub key name: %v", err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if pubKey := getServerPubKey(name); pubKey != nil { | ||||||
|  | 				cfg.ServerPubKey = name | ||||||
|  | 				cfg.pubKey = pubKey | ||||||
|  | 			} else { | ||||||
|  | 				return errors.New("invalid value / unknown server pub key name: " + name) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 		// Strict mode
 | ||||||
|  | 		case "strict": | ||||||
|  | 			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") | ||||||
|  | 
 | ||||||
| 		// Dial Timeout
 | 		// Dial Timeout
 | ||||||
| 		case "timeout": | 		case "timeout": | ||||||
| 			cfg.Timeout, err = time.ParseDuration(value) | 			cfg.Timeout, err = time.ParseDuration(value) | ||||||
|  | @ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { | ||||||
| 					return fmt.Errorf("invalid value for TLS config name: %v", err) | 					return fmt.Errorf("invalid value for TLS config name: %v", err) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				if tlsConfig, ok := tlsConfigRegister[name]; ok { | 				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { | ||||||
| 					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { |  | ||||||
| 						host, _, err := net.SplitHostPort(cfg.Addr) |  | ||||||
| 						if err == nil { |  | ||||||
| 							tlsConfig.ServerName = host |  | ||||||
| 						} |  | ||||||
| 					} |  | ||||||
| 
 |  | ||||||
| 					cfg.TLSConfig = name | 					cfg.TLSConfig = name | ||||||
| 					cfg.tls = tlsConfig | 					cfg.tls = tlsConfig | ||||||
| 				} else { | 				} else { | ||||||
|  | @ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { | ||||||
| 
 | 
 | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func ensureHavePort(addr string) string { | ||||||
|  | 	if _, _, err := net.SplitHostPort(addr); err != nil { | ||||||
|  | 		return net.JoinHostPort(addr, "3306") | ||||||
|  | 	} | ||||||
|  | 	return addr | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -9,10 +9,8 @@ | ||||||
| package mysql | package mysql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" |  | ||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| ) | ) | ||||||
|  | @ -31,6 +29,12 @@ var ( | ||||||
| 	ErrPktSyncMul        = errors.New("commands out of sync. Did you run multiple statements at once?") | 	ErrPktSyncMul        = errors.New("commands out of sync. Did you run multiple statements at once?") | ||||||
| 	ErrPktTooLarge       = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") | 	ErrPktTooLarge       = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") | ||||||
| 	ErrBusyBuffer        = errors.New("busy buffer") | 	ErrBusyBuffer        = errors.New("busy buffer") | ||||||
|  | 
 | ||||||
|  | 	// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
 | ||||||
|  | 	// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
 | ||||||
|  | 	// to trigger a resend.
 | ||||||
|  | 	// See https://github.com/go-sql-driver/mysql/pull/302
 | ||||||
|  | 	errBadConnNoWrite = errors.New("bad connection") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) | var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) | ||||||
|  | @ -59,74 +63,3 @@ type MySQLError struct { | ||||||
| func (me *MySQLError) Error() string { | func (me *MySQLError) Error() string { | ||||||
| 	return fmt.Sprintf("Error %d: %s", me.Number, me.Message) | 	return fmt.Sprintf("Error %d: %s", me.Number, me.Message) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // MySQLWarnings is an error type which represents a group of one or more MySQL
 |  | ||||||
| // warnings
 |  | ||||||
| type MySQLWarnings []MySQLWarning |  | ||||||
| 
 |  | ||||||
| func (mws MySQLWarnings) Error() string { |  | ||||||
| 	var msg string |  | ||||||
| 	for i, warning := range mws { |  | ||||||
| 		if i > 0 { |  | ||||||
| 			msg += "\r\n" |  | ||||||
| 		} |  | ||||||
| 		msg += fmt.Sprintf( |  | ||||||
| 			"%s %s: %s", |  | ||||||
| 			warning.Level, |  | ||||||
| 			warning.Code, |  | ||||||
| 			warning.Message, |  | ||||||
| 		) |  | ||||||
| 	} |  | ||||||
| 	return msg |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // MySQLWarning is an error type which represents a single MySQL warning.
 |  | ||||||
| // Warnings are returned in groups only. See MySQLWarnings
 |  | ||||||
| type MySQLWarning struct { |  | ||||||
| 	Level   string |  | ||||||
| 	Code    string |  | ||||||
| 	Message string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (mc *mysqlConn) getWarnings() (err error) { |  | ||||||
| 	rows, err := mc.Query("SHOW WARNINGS", nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var warnings = MySQLWarnings{} |  | ||||||
| 	var values = make([]driver.Value, 3) |  | ||||||
| 
 |  | ||||||
| 	for { |  | ||||||
| 		err = rows.Next(values) |  | ||||||
| 		switch err { |  | ||||||
| 		case nil: |  | ||||||
| 			warning := MySQLWarning{} |  | ||||||
| 
 |  | ||||||
| 			if raw, ok := values[0].([]byte); ok { |  | ||||||
| 				warning.Level = string(raw) |  | ||||||
| 			} else { |  | ||||||
| 				warning.Level = fmt.Sprintf("%s", values[0]) |  | ||||||
| 			} |  | ||||||
| 			if raw, ok := values[1].([]byte); ok { |  | ||||||
| 				warning.Code = string(raw) |  | ||||||
| 			} else { |  | ||||||
| 				warning.Code = fmt.Sprintf("%s", values[1]) |  | ||||||
| 			} |  | ||||||
| 			if raw, ok := values[2].([]byte); ok { |  | ||||||
| 				warning.Message = string(raw) |  | ||||||
| 			} else { |  | ||||||
| 				warning.Message = fmt.Sprintf("%s", values[0]) |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			warnings = append(warnings, warning) |  | ||||||
| 
 |  | ||||||
| 		case io.EOF: |  | ||||||
| 			return warnings |  | ||||||
| 
 |  | ||||||
| 		default: |  | ||||||
| 			rows.Close() |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -0,0 +1,194 @@ | ||||||
|  | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // This Source Code Form is subject to the terms of the Mozilla Public
 | ||||||
|  | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | ||||||
|  | // You can obtain one at http://mozilla.org/MPL/2.0/.
 | ||||||
|  | 
 | ||||||
|  | package mysql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"reflect" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (mf *mysqlField) typeDatabaseName() string { | ||||||
|  | 	switch mf.fieldType { | ||||||
|  | 	case fieldTypeBit: | ||||||
|  | 		return "BIT" | ||||||
|  | 	case fieldTypeBLOB: | ||||||
|  | 		if mf.charSet != collations[binaryCollation] { | ||||||
|  | 			return "TEXT" | ||||||
|  | 		} | ||||||
|  | 		return "BLOB" | ||||||
|  | 	case fieldTypeDate: | ||||||
|  | 		return "DATE" | ||||||
|  | 	case fieldTypeDateTime: | ||||||
|  | 		return "DATETIME" | ||||||
|  | 	case fieldTypeDecimal: | ||||||
|  | 		return "DECIMAL" | ||||||
|  | 	case fieldTypeDouble: | ||||||
|  | 		return "DOUBLE" | ||||||
|  | 	case fieldTypeEnum: | ||||||
|  | 		return "ENUM" | ||||||
|  | 	case fieldTypeFloat: | ||||||
|  | 		return "FLOAT" | ||||||
|  | 	case fieldTypeGeometry: | ||||||
|  | 		return "GEOMETRY" | ||||||
|  | 	case fieldTypeInt24: | ||||||
|  | 		return "MEDIUMINT" | ||||||
|  | 	case fieldTypeJSON: | ||||||
|  | 		return "JSON" | ||||||
|  | 	case fieldTypeLong: | ||||||
|  | 		return "INT" | ||||||
|  | 	case fieldTypeLongBLOB: | ||||||
|  | 		if mf.charSet != collations[binaryCollation] { | ||||||
|  | 			return "LONGTEXT" | ||||||
|  | 		} | ||||||
|  | 		return "LONGBLOB" | ||||||
|  | 	case fieldTypeLongLong: | ||||||
|  | 		return "BIGINT" | ||||||
|  | 	case fieldTypeMediumBLOB: | ||||||
|  | 		if mf.charSet != collations[binaryCollation] { | ||||||
|  | 			return "MEDIUMTEXT" | ||||||
|  | 		} | ||||||
|  | 		return "MEDIUMBLOB" | ||||||
|  | 	case fieldTypeNewDate: | ||||||
|  | 		return "DATE" | ||||||
|  | 	case fieldTypeNewDecimal: | ||||||
|  | 		return "DECIMAL" | ||||||
|  | 	case fieldTypeNULL: | ||||||
|  | 		return "NULL" | ||||||
|  | 	case fieldTypeSet: | ||||||
|  | 		return "SET" | ||||||
|  | 	case fieldTypeShort: | ||||||
|  | 		return "SMALLINT" | ||||||
|  | 	case fieldTypeString: | ||||||
|  | 		if mf.charSet == collations[binaryCollation] { | ||||||
|  | 			return "BINARY" | ||||||
|  | 		} | ||||||
|  | 		return "CHAR" | ||||||
|  | 	case fieldTypeTime: | ||||||
|  | 		return "TIME" | ||||||
|  | 	case fieldTypeTimestamp: | ||||||
|  | 		return "TIMESTAMP" | ||||||
|  | 	case fieldTypeTiny: | ||||||
|  | 		return "TINYINT" | ||||||
|  | 	case fieldTypeTinyBLOB: | ||||||
|  | 		if mf.charSet != collations[binaryCollation] { | ||||||
|  | 			return "TINYTEXT" | ||||||
|  | 		} | ||||||
|  | 		return "TINYBLOB" | ||||||
|  | 	case fieldTypeVarChar: | ||||||
|  | 		if mf.charSet == collations[binaryCollation] { | ||||||
|  | 			return "VARBINARY" | ||||||
|  | 		} | ||||||
|  | 		return "VARCHAR" | ||||||
|  | 	case fieldTypeVarString: | ||||||
|  | 		if mf.charSet == collations[binaryCollation] { | ||||||
|  | 			return "VARBINARY" | ||||||
|  | 		} | ||||||
|  | 		return "VARCHAR" | ||||||
|  | 	case fieldTypeYear: | ||||||
|  | 		return "YEAR" | ||||||
|  | 	default: | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	scanTypeFloat32   = reflect.TypeOf(float32(0)) | ||||||
|  | 	scanTypeFloat64   = reflect.TypeOf(float64(0)) | ||||||
|  | 	scanTypeInt8      = reflect.TypeOf(int8(0)) | ||||||
|  | 	scanTypeInt16     = reflect.TypeOf(int16(0)) | ||||||
|  | 	scanTypeInt32     = reflect.TypeOf(int32(0)) | ||||||
|  | 	scanTypeInt64     = reflect.TypeOf(int64(0)) | ||||||
|  | 	scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) | ||||||
|  | 	scanTypeNullInt   = reflect.TypeOf(sql.NullInt64{}) | ||||||
|  | 	scanTypeNullTime  = reflect.TypeOf(NullTime{}) | ||||||
|  | 	scanTypeUint8     = reflect.TypeOf(uint8(0)) | ||||||
|  | 	scanTypeUint16    = reflect.TypeOf(uint16(0)) | ||||||
|  | 	scanTypeUint32    = reflect.TypeOf(uint32(0)) | ||||||
|  | 	scanTypeUint64    = reflect.TypeOf(uint64(0)) | ||||||
|  | 	scanTypeRawBytes  = reflect.TypeOf(sql.RawBytes{}) | ||||||
|  | 	scanTypeUnknown   = reflect.TypeOf(new(interface{})) | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type mysqlField struct { | ||||||
|  | 	tableName string | ||||||
|  | 	name      string | ||||||
|  | 	length    uint32 | ||||||
|  | 	flags     fieldFlag | ||||||
|  | 	fieldType fieldType | ||||||
|  | 	decimals  byte | ||||||
|  | 	charSet   uint8 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (mf *mysqlField) scanType() reflect.Type { | ||||||
|  | 	switch mf.fieldType { | ||||||
|  | 	case fieldTypeTiny: | ||||||
|  | 		if mf.flags&flagNotNULL != 0 { | ||||||
|  | 			if mf.flags&flagUnsigned != 0 { | ||||||
|  | 				return scanTypeUint8 | ||||||
|  | 			} | ||||||
|  | 			return scanTypeInt8 | ||||||
|  | 		} | ||||||
|  | 		return scanTypeNullInt | ||||||
|  | 
 | ||||||
|  | 	case fieldTypeShort, fieldTypeYear: | ||||||
|  | 		if mf.flags&flagNotNULL != 0 { | ||||||
|  | 			if mf.flags&flagUnsigned != 0 { | ||||||
|  | 				return scanTypeUint16 | ||||||
|  | 			} | ||||||
|  | 			return scanTypeInt16 | ||||||
|  | 		} | ||||||
|  | 		return scanTypeNullInt | ||||||
|  | 
 | ||||||
|  | 	case fieldTypeInt24, fieldTypeLong: | ||||||
|  | 		if mf.flags&flagNotNULL != 0 { | ||||||
|  | 			if mf.flags&flagUnsigned != 0 { | ||||||
|  | 				return scanTypeUint32 | ||||||
|  | 			} | ||||||
|  | 			return scanTypeInt32 | ||||||
|  | 		} | ||||||
|  | 		return scanTypeNullInt | ||||||
|  | 
 | ||||||
|  | 	case fieldTypeLongLong: | ||||||
|  | 		if mf.flags&flagNotNULL != 0 { | ||||||
|  | 			if mf.flags&flagUnsigned != 0 { | ||||||
|  | 				return scanTypeUint64 | ||||||
|  | 			} | ||||||
|  | 			return scanTypeInt64 | ||||||
|  | 		} | ||||||
|  | 		return scanTypeNullInt | ||||||
|  | 
 | ||||||
|  | 	case fieldTypeFloat: | ||||||
|  | 		if mf.flags&flagNotNULL != 0 { | ||||||
|  | 			return scanTypeFloat32 | ||||||
|  | 		} | ||||||
|  | 		return scanTypeNullFloat | ||||||
|  | 
 | ||||||
|  | 	case fieldTypeDouble: | ||||||
|  | 		if mf.flags&flagNotNULL != 0 { | ||||||
|  | 			return scanTypeFloat64 | ||||||
|  | 		} | ||||||
|  | 		return scanTypeNullFloat | ||||||
|  | 
 | ||||||
|  | 	case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, | ||||||
|  | 		fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, | ||||||
|  | 		fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, | ||||||
|  | 		fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, | ||||||
|  | 		fieldTypeTime: | ||||||
|  | 		return scanTypeRawBytes | ||||||
|  | 
 | ||||||
|  | 	case fieldTypeDate, fieldTypeNewDate, | ||||||
|  | 		fieldTypeTimestamp, fieldTypeDateTime: | ||||||
|  | 		// NullTime is always returned for more consistent behavior as it can
 | ||||||
|  | 		// handle both cases of parseTime regardless if the field is nullable.
 | ||||||
|  | 		return scanTypeNullTime | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return scanTypeUnknown | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// send content packets
 | 	// send content packets
 | ||||||
| 	if err == nil { | 	// if packetSize == 0, the Reader contains no data
 | ||||||
|  | 	if err == nil && packetSize > 0 { | ||||||
| 		data := make([]byte, 4+packetSize) | 		data := make([]byte, 4+packetSize) | ||||||
| 		var n int | 		var n int | ||||||
| 		for err == nil { | 		for err == nil { | ||||||
|  | @ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | ||||||
| 
 | 
 | ||||||
| 	// read OK packet
 | 	// read OK packet
 | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		_, err = mc.readResultOK() | 		return mc.readResultOK() | ||||||
| 		return err |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	mc.readPacket() | 	mc.readPacket() | ||||||
|  |  | ||||||
|  | @ -25,26 +25,23 @@ import ( | ||||||
| 
 | 
 | ||||||
| // Read packet to buffer 'data'
 | // Read packet to buffer 'data'
 | ||||||
| func (mc *mysqlConn) readPacket() ([]byte, error) { | func (mc *mysqlConn) readPacket() ([]byte, error) { | ||||||
| 	var payload []byte | 	var prevData []byte | ||||||
| 	for { | 	for { | ||||||
| 		// Read packet header
 | 		// read packet header
 | ||||||
| 		data, err := mc.buf.readNext(4) | 		data, err := mc.buf.readNext(4) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			if cerr := mc.canceled.Value(); cerr != nil { | ||||||
|  | 				return nil, cerr | ||||||
|  | 			} | ||||||
| 			errLog.Print(err) | 			errLog.Print(err) | ||||||
| 			mc.Close() | 			mc.Close() | ||||||
| 			return nil, driver.ErrBadConn | 			return nil, ErrInvalidConn | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Packet Length [24 bit]
 | 		// packet length [24 bit]
 | ||||||
| 		pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) | 		pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) | ||||||
| 
 | 
 | ||||||
| 		if pktLen < 1 { | 		// check packet sync [8 bit]
 | ||||||
| 			errLog.Print(ErrMalformPkt) |  | ||||||
| 			mc.Close() |  | ||||||
| 			return nil, driver.ErrBadConn |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Check Packet Sync [8 bit]
 |  | ||||||
| 		if data[3] != mc.sequence { | 		if data[3] != mc.sequence { | ||||||
| 			if data[3] > mc.sequence { | 			if data[3] > mc.sequence { | ||||||
| 				return nil, ErrPktSyncMul | 				return nil, ErrPktSyncMul | ||||||
|  | @ -53,26 +50,41 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { | ||||||
| 		} | 		} | ||||||
| 		mc.sequence++ | 		mc.sequence++ | ||||||
| 
 | 
 | ||||||
| 		// Read packet body [pktLen bytes]
 | 		// packets with length 0 terminate a previous packet which is a
 | ||||||
|  | 		// multiple of (2^24)−1 bytes long
 | ||||||
|  | 		if pktLen == 0 { | ||||||
|  | 			// there was no previous packet
 | ||||||
|  | 			if prevData == nil { | ||||||
|  | 				errLog.Print(ErrMalformPkt) | ||||||
|  | 				mc.Close() | ||||||
|  | 				return nil, ErrInvalidConn | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			return prevData, nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// read packet body [pktLen bytes]
 | ||||||
| 		data, err = mc.buf.readNext(pktLen) | 		data, err = mc.buf.readNext(pktLen) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			if cerr := mc.canceled.Value(); cerr != nil { | ||||||
|  | 				return nil, cerr | ||||||
|  | 			} | ||||||
| 			errLog.Print(err) | 			errLog.Print(err) | ||||||
| 			mc.Close() | 			mc.Close() | ||||||
| 			return nil, driver.ErrBadConn | 			return nil, ErrInvalidConn | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		isLastPacket := (pktLen < maxPacketSize) | 		// return data if this was the last packet
 | ||||||
|  | 		if pktLen < maxPacketSize { | ||||||
|  | 			// zero allocations for non-split packets
 | ||||||
|  | 			if prevData == nil { | ||||||
|  | 				return data, nil | ||||||
|  | 			} | ||||||
| 
 | 
 | ||||||
| 		// Zero allocations for non-splitting packets
 | 			return append(prevData, data...), nil | ||||||
| 		if isLastPacket && payload == nil { |  | ||||||
| 			return data, nil |  | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		payload = append(payload, data...) | 		prevData = append(prevData, data...) | ||||||
| 
 |  | ||||||
| 		if isLastPacket { |  | ||||||
| 			return payload, nil |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -119,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error { | ||||||
| 
 | 
 | ||||||
| 		// Handle error
 | 		// Handle error
 | ||||||
| 		if err == nil { // n != len(data)
 | 		if err == nil { // n != len(data)
 | ||||||
|  | 			mc.cleanup() | ||||||
| 			errLog.Print(ErrMalformPkt) | 			errLog.Print(ErrMalformPkt) | ||||||
| 		} else { | 		} else { | ||||||
|  | 			if cerr := mc.canceled.Value(); cerr != nil { | ||||||
|  | 				return cerr | ||||||
|  | 			} | ||||||
|  | 			if n == 0 && pktLen == len(data)-4 { | ||||||
|  | 				// only for the first loop iteration when nothing was written yet
 | ||||||
|  | 				return errBadConnNoWrite | ||||||
|  | 			} | ||||||
|  | 			mc.cleanup() | ||||||
| 			errLog.Print(err) | 			errLog.Print(err) | ||||||
| 		} | 		} | ||||||
| 		return driver.ErrBadConn | 		return ErrInvalidConn | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /****************************************************************************** | /****************************************************************************** | ||||||
| *                           Initialisation Process                            * | *                           Initialization Process                            * | ||||||
| ******************************************************************************/ | ******************************************************************************/ | ||||||
| 
 | 
 | ||||||
| // Handshake Initialization Packet
 | // Handshake Initialization Packet
 | ||||||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
 | // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
 | ||||||
| func (mc *mysqlConn) readInitPacket() ([]byte, error) { | func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { | ||||||
| 	data, err := mc.readPacket() | 	data, err := mc.readPacket() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
 | ||||||
|  | 		// in connection initialization we don't risk retrying non-idempotent actions.
 | ||||||
|  | 		if err == ErrInvalidConn { | ||||||
|  | 			return nil, "", driver.ErrBadConn | ||||||
|  | 		} | ||||||
|  | 		return nil, "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if data[0] == iERR { | 	if data[0] == iERR { | ||||||
| 		return nil, mc.handleErrorPacket(data) | 		return nil, "", mc.handleErrorPacket(data) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// protocol version [1 byte]
 | 	// protocol version [1 byte]
 | ||||||
| 	if data[0] < minProtocolVersion { | 	if data[0] < minProtocolVersion { | ||||||
| 		return nil, fmt.Errorf( | 		return nil, "", fmt.Errorf( | ||||||
| 			"unsupported protocol version %d. Version %d or higher is required", | 			"unsupported protocol version %d. Version %d or higher is required", | ||||||
| 			data[0], | 			data[0], | ||||||
| 			minProtocolVersion, | 			minProtocolVersion, | ||||||
|  | @ -157,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | ||||||
| 	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 | 	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 | ||||||
| 
 | 
 | ||||||
| 	// first part of the password cipher [8 bytes]
 | 	// first part of the password cipher [8 bytes]
 | ||||||
| 	cipher := data[pos : pos+8] | 	authData := data[pos : pos+8] | ||||||
| 
 | 
 | ||||||
| 	// (filler) always 0x00 [1 byte]
 | 	// (filler) always 0x00 [1 byte]
 | ||||||
| 	pos += 8 + 1 | 	pos += 8 + 1 | ||||||
|  | @ -165,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | ||||||
| 	// capability flags (lower 2 bytes) [2 bytes]
 | 	// capability flags (lower 2 bytes) [2 bytes]
 | ||||||
| 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) | 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) | ||||||
| 	if mc.flags&clientProtocol41 == 0 { | 	if mc.flags&clientProtocol41 == 0 { | ||||||
| 		return nil, ErrOldProtocol | 		return nil, "", ErrOldProtocol | ||||||
| 	} | 	} | ||||||
| 	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { | 	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { | ||||||
| 		return nil, ErrNoTLS | 		return nil, "", ErrNoTLS | ||||||
| 	} | 	} | ||||||
| 	pos += 2 | 	pos += 2 | ||||||
| 
 | 
 | ||||||
|  | 	plugin := "" | ||||||
| 	if len(data) > pos { | 	if len(data) > pos { | ||||||
| 		// character set [1 byte]
 | 		// character set [1 byte]
 | ||||||
| 		// status flags [2 bytes]
 | 		// status flags [2 bytes]
 | ||||||
|  | @ -192,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | ||||||
| 		//
 | 		//
 | ||||||
| 		// The official Python library uses the fixed length 12
 | 		// The official Python library uses the fixed length 12
 | ||||||
| 		// which seems to work but technically could have a hidden bug.
 | 		// which seems to work but technically could have a hidden bug.
 | ||||||
| 		cipher = append(cipher, data[pos:pos+12]...) | 		authData = append(authData, data[pos:pos+12]...) | ||||||
|  | 		pos += 13 | ||||||
| 
 | 
 | ||||||
| 		// TODO: Verify string termination
 |  | ||||||
| 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
 | 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
 | ||||||
| 		// \NUL otherwise
 | 		// \NUL otherwise
 | ||||||
| 		//
 | 		if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { | ||||||
| 		//if data[len(data)-1] == 0 {
 | 			plugin = string(data[pos : pos+end]) | ||||||
| 		//	return
 | 		} else { | ||||||
| 		//}
 | 			plugin = string(data[pos:]) | ||||||
| 		//return ErrMalformPkt
 | 		} | ||||||
| 
 | 
 | ||||||
| 		// make a memory safe copy of the cipher slice
 | 		// make a memory safe copy of the cipher slice
 | ||||||
| 		var b [20]byte | 		var b [20]byte | ||||||
| 		copy(b[:], cipher) | 		copy(b[:], authData) | ||||||
| 		return b[:], nil | 		return b[:], plugin, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	plugin = defaultAuthPlugin | ||||||
|  | 
 | ||||||
| 	// make a memory safe copy of the cipher slice
 | 	// make a memory safe copy of the cipher slice
 | ||||||
| 	var b [8]byte | 	var b [8]byte | ||||||
| 	copy(b[:], cipher) | 	copy(b[:], authData) | ||||||
| 	return b[:], nil | 	return b[:], plugin, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Client Authentication Packet
 | // Client Authentication Packet
 | ||||||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
 | // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
 | ||||||
| func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { | ||||||
| 	// Adjust client flags based on server support
 | 	// Adjust client flags based on server support
 | ||||||
| 	clientFlags := clientProtocol41 | | 	clientFlags := clientProtocol41 | | ||||||
| 		clientSecureConn | | 		clientSecureConn | | ||||||
|  | @ -241,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | ||||||
| 		clientFlags |= clientMultiStatements | 		clientFlags |= clientMultiStatements | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// User Password
 | 	// encode length of the auth plugin data
 | ||||||
| 	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) | 	var authRespLEIBuf [9]byte | ||||||
|  | 	authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) | ||||||
|  | 	if len(authRespLEI) > 1 { | ||||||
|  | 		// if the length can not be written in 1 byte, it must be written as a
 | ||||||
|  | 		// length encoded integer
 | ||||||
|  | 		clientFlags |= clientPluginAuthLenEncClientData | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 | 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 | ||||||
|  | 	if addNUL { | ||||||
|  | 		pktLen++ | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// To specify a db name
 | 	// To specify a db name
 | ||||||
| 	if n := len(mc.cfg.DBName); n > 0 { | 	if n := len(mc.cfg.DBName); n > 0 { | ||||||
|  | @ -255,9 +293,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | ||||||
| 	// Calculate packet length and get buffer with that size
 | 	// Calculate packet length and get buffer with that size
 | ||||||
| 	data := mc.buf.takeSmallBuffer(pktLen + 4) | 	data := mc.buf.takeSmallBuffer(pktLen + 4) | ||||||
| 	if data == nil { | 	if data == nil { | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 | 		// cannot take the buffer. Something must be wrong with the connection
 | ||||||
| 		errLog.Print(ErrBusyBuffer) | 		errLog.Print(ErrBusyBuffer) | ||||||
| 		return driver.ErrBadConn | 		return errBadConnNoWrite | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// ClientFlags [32 bit]
 | 	// ClientFlags [32 bit]
 | ||||||
|  | @ -312,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | ||||||
| 	data[pos] = 0x00 | 	data[pos] = 0x00 | ||||||
| 	pos++ | 	pos++ | ||||||
| 
 | 
 | ||||||
| 	// ScrambleBuffer [length encoded integer]
 | 	// Auth Data [length encoded integer]
 | ||||||
| 	data[pos] = byte(len(scrambleBuff)) | 	pos += copy(data[pos:], authRespLEI) | ||||||
| 	pos += 1 + copy(data[pos+1:], scrambleBuff) | 	pos += copy(data[pos:], authResp) | ||||||
|  | 	if addNUL { | ||||||
|  | 		data[pos] = 0x00 | ||||||
|  | 		pos++ | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Databasename [null terminated string]
 | 	// Databasename [null terminated string]
 | ||||||
| 	if len(mc.cfg.DBName) > 0 { | 	if len(mc.cfg.DBName) > 0 { | ||||||
|  | @ -323,72 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | ||||||
| 		pos++ | 		pos++ | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Assume native client during response
 | 	pos += copy(data[pos:], plugin) | ||||||
| 	pos += copy(data[pos:], "mysql_native_password") |  | ||||||
| 	data[pos] = 0x00 | 	data[pos] = 0x00 | ||||||
| 
 | 
 | ||||||
| 	// Send Auth packet
 | 	// Send Auth packet
 | ||||||
| 	return mc.writePacket(data) | 	return mc.writePacket(data) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //  Client old authentication packet
 |  | ||||||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
 | // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
 | ||||||
| func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { | func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { | ||||||
| 	// User password
 | 	pktLen := 4 + len(authData) | ||||||
| 	scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) | 	if addNUL { | ||||||
| 
 | 		pktLen++ | ||||||
| 	// Calculate the packet length and add a tailing 0
 | 	} | ||||||
| 	pktLen := len(scrambleBuff) + 1 | 	data := mc.buf.takeSmallBuffer(pktLen) | ||||||
| 	data := mc.buf.takeSmallBuffer(4 + pktLen) |  | ||||||
| 	if data == nil { | 	if data == nil { | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 | 		// cannot take the buffer. Something must be wrong with the connection
 | ||||||
| 		errLog.Print(ErrBusyBuffer) | 		errLog.Print(ErrBusyBuffer) | ||||||
| 		return driver.ErrBadConn | 		return errBadConnNoWrite | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Add the scrambled password [null terminated string]
 | 	// Add the auth data [EOF]
 | ||||||
| 	copy(data[4:], scrambleBuff) | 	copy(data[4:], authData) | ||||||
| 	data[4+pktLen-1] = 0x00 | 	if addNUL { | ||||||
| 
 | 		data[pktLen-1] = 0x00 | ||||||
| 	return mc.writePacket(data) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| //  Client clear text authentication packet
 |  | ||||||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
 |  | ||||||
| func (mc *mysqlConn) writeClearAuthPacket() error { |  | ||||||
| 	// Calculate the packet length and add a tailing 0
 |  | ||||||
| 	pktLen := len(mc.cfg.Passwd) + 1 |  | ||||||
| 	data := mc.buf.takeSmallBuffer(4 + pktLen) |  | ||||||
| 	if data == nil { |  | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 |  | ||||||
| 		errLog.Print(ErrBusyBuffer) |  | ||||||
| 		return driver.ErrBadConn |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Add the clear password [null terminated string]
 |  | ||||||
| 	copy(data[4:], mc.cfg.Passwd) |  | ||||||
| 	data[4+pktLen-1] = 0x00 |  | ||||||
| 
 |  | ||||||
| 	return mc.writePacket(data) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| //  Native password authentication method
 |  | ||||||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
 |  | ||||||
| func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { |  | ||||||
| 	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) |  | ||||||
| 
 |  | ||||||
| 	// Calculate the packet length and add a tailing 0
 |  | ||||||
| 	pktLen := len(scrambleBuff) |  | ||||||
| 	data := mc.buf.takeSmallBuffer(4 + pktLen) |  | ||||||
| 	if data == nil { |  | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 |  | ||||||
| 		errLog.Print(ErrBusyBuffer) |  | ||||||
| 		return driver.ErrBadConn |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Add the scramble
 |  | ||||||
| 	copy(data[4:], scrambleBuff) |  | ||||||
| 
 |  | ||||||
| 	return mc.writePacket(data) | 	return mc.writePacket(data) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -402,9 +404,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { | ||||||
| 
 | 
 | ||||||
| 	data := mc.buf.takeSmallBuffer(4 + 1) | 	data := mc.buf.takeSmallBuffer(4 + 1) | ||||||
| 	if data == nil { | 	if data == nil { | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 | 		// cannot take the buffer. Something must be wrong with the connection
 | ||||||
| 		errLog.Print(ErrBusyBuffer) | 		errLog.Print(ErrBusyBuffer) | ||||||
| 		return driver.ErrBadConn | 		return errBadConnNoWrite | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Add command byte
 | 	// Add command byte
 | ||||||
|  | @ -421,9 +423,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { | ||||||
| 	pktLen := 1 + len(arg) | 	pktLen := 1 + len(arg) | ||||||
| 	data := mc.buf.takeBuffer(pktLen + 4) | 	data := mc.buf.takeBuffer(pktLen + 4) | ||||||
| 	if data == nil { | 	if data == nil { | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 | 		// cannot take the buffer. Something must be wrong with the connection
 | ||||||
| 		errLog.Print(ErrBusyBuffer) | 		errLog.Print(ErrBusyBuffer) | ||||||
| 		return driver.ErrBadConn | 		return errBadConnNoWrite | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Add command byte
 | 	// Add command byte
 | ||||||
|  | @ -442,9 +444,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | ||||||
| 
 | 
 | ||||||
| 	data := mc.buf.takeSmallBuffer(4 + 1 + 4) | 	data := mc.buf.takeSmallBuffer(4 + 1 + 4) | ||||||
| 	if data == nil { | 	if data == nil { | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 | 		// cannot take the buffer. Something must be wrong with the connection
 | ||||||
| 		errLog.Print(ErrBusyBuffer) | 		errLog.Print(ErrBusyBuffer) | ||||||
| 		return driver.ErrBadConn | 		return errBadConnNoWrite | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Add command byte
 | 	// Add command byte
 | ||||||
|  | @ -464,43 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | ||||||
| *                              Result Packets                                 * | *                              Result Packets                                 * | ||||||
| ******************************************************************************/ | ******************************************************************************/ | ||||||
| 
 | 
 | ||||||
| // Returns error if Packet is not an 'Result OK'-Packet
 | func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { | ||||||
| func (mc *mysqlConn) readResultOK() ([]byte, error) { |  | ||||||
| 	data, err := mc.readPacket() | 	data, err := mc.readPacket() | ||||||
| 	if err == nil { | 	if err != nil { | ||||||
| 		// packet indicator
 | 		return nil, "", err | ||||||
| 		switch data[0] { |  | ||||||
| 
 |  | ||||||
| 		case iOK: |  | ||||||
| 			return nil, mc.handleOkPacket(data) |  | ||||||
| 
 |  | ||||||
| 		case iEOF: |  | ||||||
| 			if len(data) > 1 { |  | ||||||
| 				pluginEndIndex := bytes.IndexByte(data, 0x00) |  | ||||||
| 				plugin := string(data[1:pluginEndIndex]) |  | ||||||
| 				cipher := data[pluginEndIndex+1 : len(data)-1] |  | ||||||
| 
 |  | ||||||
| 				if plugin == "mysql_old_password" { |  | ||||||
| 					// using old_passwords
 |  | ||||||
| 					return cipher, ErrOldPassword |  | ||||||
| 				} else if plugin == "mysql_clear_password" { |  | ||||||
| 					// using clear text password
 |  | ||||||
| 					return cipher, ErrCleartextPassword |  | ||||||
| 				} else if plugin == "mysql_native_password" { |  | ||||||
| 					// using mysql default authentication method
 |  | ||||||
| 					return cipher, ErrNativePassword |  | ||||||
| 				} else { |  | ||||||
| 					return cipher, ErrUnknownPlugin |  | ||||||
| 				} |  | ||||||
| 			} else { |  | ||||||
| 				return nil, ErrOldPassword |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 		default: // Error otherwise
 |  | ||||||
| 			return nil, mc.handleErrorPacket(data) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	return nil, err | 
 | ||||||
|  | 	// packet indicator
 | ||||||
|  | 	switch data[0] { | ||||||
|  | 
 | ||||||
|  | 	case iOK: | ||||||
|  | 		return nil, "", mc.handleOkPacket(data) | ||||||
|  | 
 | ||||||
|  | 	case iAuthMoreData: | ||||||
|  | 		return data[1:], "", err | ||||||
|  | 
 | ||||||
|  | 	case iEOF: | ||||||
|  | 		if len(data) < 1 { | ||||||
|  | 			// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
 | ||||||
|  | 			return nil, "mysql_old_password", nil | ||||||
|  | 		} | ||||||
|  | 		pluginEndIndex := bytes.IndexByte(data, 0x00) | ||||||
|  | 		if pluginEndIndex < 0 { | ||||||
|  | 			return nil, "", ErrMalformPkt | ||||||
|  | 		} | ||||||
|  | 		plugin := string(data[1:pluginEndIndex]) | ||||||
|  | 		authData := data[pluginEndIndex+1:] | ||||||
|  | 		return authData, plugin, nil | ||||||
|  | 
 | ||||||
|  | 	default: // Error otherwise
 | ||||||
|  | 		return nil, "", mc.handleErrorPacket(data) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Returns error if Packet is not an 'Result OK'-Packet
 | ||||||
|  | func (mc *mysqlConn) readResultOK() error { | ||||||
|  | 	data, err := mc.readPacket() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if data[0] == iOK { | ||||||
|  | 		return mc.handleOkPacket(data) | ||||||
|  | 	} | ||||||
|  | 	return mc.handleErrorPacket(data) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Result Set Header Packet
 | // Result Set Header Packet
 | ||||||
|  | @ -543,6 +552,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { | ||||||
| 	// Error Number [16 bit uint]
 | 	// Error Number [16 bit uint]
 | ||||||
| 	errno := binary.LittleEndian.Uint16(data[1:3]) | 	errno := binary.LittleEndian.Uint16(data[1:3]) | ||||||
| 
 | 
 | ||||||
|  | 	// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
 | ||||||
|  | 	// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
 | ||||||
|  | 	if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { | ||||||
|  | 		// Oops; we are connected to a read-only connection, and won't be able
 | ||||||
|  | 		// to issue any write statements. Since RejectReadOnly is configured,
 | ||||||
|  | 		// we throw away this connection hoping this one would have write
 | ||||||
|  | 		// permission. This is specifically for a possible race condition
 | ||||||
|  | 		// during failover (e.g. on AWS Aurora). See README.md for more.
 | ||||||
|  | 		//
 | ||||||
|  | 		// We explicitly close the connection before returning
 | ||||||
|  | 		// driver.ErrBadConn to ensure that `database/sql` purges this
 | ||||||
|  | 		// connection and initiates a new one for next statement next time.
 | ||||||
|  | 		mc.Close() | ||||||
|  | 		return driver.ErrBadConn | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	pos := 3 | 	pos := 3 | ||||||
| 
 | 
 | ||||||
| 	// SQL State [optional: # + 5bytes string]
 | 	// SQL State [optional: # + 5bytes string]
 | ||||||
|  | @ -577,19 +602,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { | ||||||
| 
 | 
 | ||||||
| 	// server_status [2 bytes]
 | 	// server_status [2 bytes]
 | ||||||
| 	mc.status = readStatus(data[1+n+m : 1+n+m+2]) | 	mc.status = readStatus(data[1+n+m : 1+n+m+2]) | ||||||
| 	if err := mc.discardResults(); err != nil { | 	if mc.status&statusMoreResultsExists != 0 { | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// warning count [2 bytes]
 |  | ||||||
| 	if !mc.strict { |  | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	pos := 1 + n + m + 2 | 	// warning count [2 bytes]
 | ||||||
| 	if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { | 
 | ||||||
| 		return mc.getWarnings() |  | ||||||
| 	} |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -661,14 +679,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | 		pos += n | ||||||
| 
 | 
 | ||||||
| 		// Filler [uint8]
 | 		// Filler [uint8]
 | ||||||
|  | 		pos++ | ||||||
|  | 
 | ||||||
| 		// Charset [charset, collation uint8]
 | 		// Charset [charset, collation uint8]
 | ||||||
|  | 		columns[i].charSet = data[pos] | ||||||
|  | 		pos += 2 | ||||||
|  | 
 | ||||||
| 		// Length [uint32]
 | 		// Length [uint32]
 | ||||||
| 		pos += n + 1 + 2 + 4 | 		columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) | ||||||
|  | 		pos += 4 | ||||||
| 
 | 
 | ||||||
| 		// Field type [uint8]
 | 		// Field type [uint8]
 | ||||||
| 		columns[i].fieldType = data[pos] | 		columns[i].fieldType = fieldType(data[pos]) | ||||||
| 		pos++ | 		pos++ | ||||||
| 
 | 
 | ||||||
| 		// Flags [uint16]
 | 		// Flags [uint16]
 | ||||||
|  | @ -691,6 +716,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | ||||||
| func (rows *textRows) readRow(dest []driver.Value) error { | func (rows *textRows) readRow(dest []driver.Value) error { | ||||||
| 	mc := rows.mc | 	mc := rows.mc | ||||||
| 
 | 
 | ||||||
|  | 	if rows.rs.done { | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	data, err := mc.readPacket() | 	data, err := mc.readPacket() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
|  | @ -700,10 +729,10 @@ func (rows *textRows) readRow(dest []driver.Value) error { | ||||||
| 	if data[0] == iEOF && len(data) == 5 { | 	if data[0] == iEOF && len(data) == 5 { | ||||||
| 		// server_status [2 bytes]
 | 		// server_status [2 bytes]
 | ||||||
| 		rows.mc.status = readStatus(data[3:]) | 		rows.mc.status = readStatus(data[3:]) | ||||||
| 		if err := rows.mc.discardResults(); err != nil { | 		rows.rs.done = true | ||||||
| 			return err | 		if !rows.HasNextResultSet() { | ||||||
|  | 			rows.mc = nil | ||||||
| 		} | 		} | ||||||
| 		rows.mc = nil |  | ||||||
| 		return io.EOF | 		return io.EOF | ||||||
| 	} | 	} | ||||||
| 	if data[0] == iERR { | 	if data[0] == iERR { | ||||||
|  | @ -725,7 +754,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { | ||||||
| 				if !mc.parseTime { | 				if !mc.parseTime { | ||||||
| 					continue | 					continue | ||||||
| 				} else { | 				} else { | ||||||
| 					switch rows.columns[i].fieldType { | 					switch rows.rs.columns[i].fieldType { | ||||||
| 					case fieldTypeTimestamp, fieldTypeDateTime, | 					case fieldTypeTimestamp, fieldTypeDateTime, | ||||||
| 						fieldTypeDate, fieldTypeNewDate: | 						fieldTypeDate, fieldTypeNewDate: | ||||||
| 						dest[i], err = parseDateTime( | 						dest[i], err = parseDateTime( | ||||||
|  | @ -797,14 +826,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { | ||||||
| 		// Reserved [8 bit]
 | 		// Reserved [8 bit]
 | ||||||
| 
 | 
 | ||||||
| 		// Warning count [16 bit uint]
 | 		// Warning count [16 bit uint]
 | ||||||
| 		if !stmt.mc.strict { |  | ||||||
| 			return columnCount, nil |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		// Check for warnings count > 0, only available in MySQL > 4.1
 |  | ||||||
| 		if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { |  | ||||||
| 			return columnCount, stmt.mc.getWarnings() |  | ||||||
| 		} |  | ||||||
| 		return columnCount, nil | 		return columnCount, nil | ||||||
| 	} | 	} | ||||||
| 	return 0, err | 	return 0, err | ||||||
|  | @ -821,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { | ||||||
| 	// 2 bytes paramID
 | 	// 2 bytes paramID
 | ||||||
| 	const dataOffset = 1 + 4 + 2 | 	const dataOffset = 1 + 4 + 2 | ||||||
| 
 | 
 | ||||||
| 	// Can not use the write buffer since
 | 	// Cannot use the write buffer since
 | ||||||
| 	// a) the buffer is too small
 | 	// a) the buffer is too small
 | ||||||
| 	// b) it is in use
 | 	// b) it is in use
 | ||||||
| 	data := make([]byte, 4+1+4+2+len(arg)) | 	data := make([]byte, 4+1+4+2+len(arg)) | ||||||
|  | @ -876,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 	const minPktLen = 4 + 1 + 4 + 1 + 4 | 	const minPktLen = 4 + 1 + 4 + 1 + 4 | ||||||
| 	mc := stmt.mc | 	mc := stmt.mc | ||||||
| 
 | 
 | ||||||
|  | 	// Determine threshould dynamically to avoid packet size shortage.
 | ||||||
|  | 	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) | ||||||
|  | 	if longDataSize < 64 { | ||||||
|  | 		longDataSize = 64 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Reset packet-sequence
 | 	// Reset packet-sequence
 | ||||||
| 	mc.sequence = 0 | 	mc.sequence = 0 | ||||||
| 
 | 
 | ||||||
|  | @ -887,9 +915,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 		data = mc.buf.takeCompleteBuffer() | 		data = mc.buf.takeCompleteBuffer() | ||||||
| 	} | 	} | ||||||
| 	if data == nil { | 	if data == nil { | ||||||
| 		// can not take the buffer. Something must be wrong with the connection
 | 		// cannot take the buffer. Something must be wrong with the connection
 | ||||||
| 		errLog.Print(ErrBusyBuffer) | 		errLog.Print(ErrBusyBuffer) | ||||||
| 		return driver.ErrBadConn | 		return errBadConnNoWrite | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// command [1 byte]
 | 	// command [1 byte]
 | ||||||
|  | @ -948,7 +976,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 			// build NULL-bitmap
 | 			// build NULL-bitmap
 | ||||||
| 			if arg == nil { | 			if arg == nil { | ||||||
| 				nullMask[i/8] |= 1 << (uint(i) & 7) | 				nullMask[i/8] |= 1 << (uint(i) & 7) | ||||||
| 				paramTypes[i+i] = fieldTypeNULL | 				paramTypes[i+i] = byte(fieldTypeNULL) | ||||||
| 				paramTypes[i+i+1] = 0x00 | 				paramTypes[i+i+1] = 0x00 | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  | @ -956,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 			// cache types and values
 | 			// cache types and values
 | ||||||
| 			switch v := arg.(type) { | 			switch v := arg.(type) { | ||||||
| 			case int64: | 			case int64: | ||||||
| 				paramTypes[i+i] = fieldTypeLongLong | 				paramTypes[i+i] = byte(fieldTypeLongLong) | ||||||
| 				paramTypes[i+i+1] = 0x00 | 				paramTypes[i+i+1] = 0x00 | ||||||
| 
 | 
 | ||||||
| 				if cap(paramValues)-len(paramValues)-8 >= 0 { | 				if cap(paramValues)-len(paramValues)-8 >= 0 { | ||||||
|  | @ -972,7 +1000,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 			case float64: | 			case float64: | ||||||
| 				paramTypes[i+i] = fieldTypeDouble | 				paramTypes[i+i] = byte(fieldTypeDouble) | ||||||
| 				paramTypes[i+i+1] = 0x00 | 				paramTypes[i+i+1] = 0x00 | ||||||
| 
 | 
 | ||||||
| 				if cap(paramValues)-len(paramValues)-8 >= 0 { | 				if cap(paramValues)-len(paramValues)-8 >= 0 { | ||||||
|  | @ -988,7 +1016,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 			case bool: | 			case bool: | ||||||
| 				paramTypes[i+i] = fieldTypeTiny | 				paramTypes[i+i] = byte(fieldTypeTiny) | ||||||
| 				paramTypes[i+i+1] = 0x00 | 				paramTypes[i+i+1] = 0x00 | ||||||
| 
 | 
 | ||||||
| 				if v { | 				if v { | ||||||
|  | @ -1000,10 +1028,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 			case []byte: | 			case []byte: | ||||||
| 				// Common case (non-nil value) first
 | 				// Common case (non-nil value) first
 | ||||||
| 				if v != nil { | 				if v != nil { | ||||||
| 					paramTypes[i+i] = fieldTypeString | 					paramTypes[i+i] = byte(fieldTypeString) | ||||||
| 					paramTypes[i+i+1] = 0x00 | 					paramTypes[i+i+1] = 0x00 | ||||||
| 
 | 
 | ||||||
| 					if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | 					if len(v) < longDataSize { | ||||||
| 						paramValues = appendLengthEncodedInteger(paramValues, | 						paramValues = appendLengthEncodedInteger(paramValues, | ||||||
| 							uint64(len(v)), | 							uint64(len(v)), | ||||||
| 						) | 						) | ||||||
|  | @ -1018,14 +1046,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 
 | 
 | ||||||
| 				// Handle []byte(nil) as a NULL value
 | 				// Handle []byte(nil) as a NULL value
 | ||||||
| 				nullMask[i/8] |= 1 << (uint(i) & 7) | 				nullMask[i/8] |= 1 << (uint(i) & 7) | ||||||
| 				paramTypes[i+i] = fieldTypeNULL | 				paramTypes[i+i] = byte(fieldTypeNULL) | ||||||
| 				paramTypes[i+i+1] = 0x00 | 				paramTypes[i+i+1] = 0x00 | ||||||
| 
 | 
 | ||||||
| 			case string: | 			case string: | ||||||
| 				paramTypes[i+i] = fieldTypeString | 				paramTypes[i+i] = byte(fieldTypeString) | ||||||
| 				paramTypes[i+i+1] = 0x00 | 				paramTypes[i+i+1] = 0x00 | ||||||
| 
 | 
 | ||||||
| 				if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | 				if len(v) < longDataSize { | ||||||
| 					paramValues = appendLengthEncodedInteger(paramValues, | 					paramValues = appendLengthEncodedInteger(paramValues, | ||||||
| 						uint64(len(v)), | 						uint64(len(v)), | ||||||
| 					) | 					) | ||||||
|  | @ -1037,23 +1065,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 			case time.Time: | 			case time.Time: | ||||||
| 				paramTypes[i+i] = fieldTypeString | 				paramTypes[i+i] = byte(fieldTypeString) | ||||||
| 				paramTypes[i+i+1] = 0x00 | 				paramTypes[i+i+1] = 0x00 | ||||||
| 
 | 
 | ||||||
| 				var val []byte | 				var a [64]byte | ||||||
|  | 				var b = a[:0] | ||||||
|  | 
 | ||||||
| 				if v.IsZero() { | 				if v.IsZero() { | ||||||
| 					val = []byte("0000-00-00") | 					b = append(b, "0000-00-00"...) | ||||||
| 				} else { | 				} else { | ||||||
| 					val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) | 					b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				paramValues = appendLengthEncodedInteger(paramValues, | 				paramValues = appendLengthEncodedInteger(paramValues, | ||||||
| 					uint64(len(val)), | 					uint64(len(b)), | ||||||
| 				) | 				) | ||||||
| 				paramValues = append(paramValues, val...) | 				paramValues = append(paramValues, b...) | ||||||
| 
 | 
 | ||||||
| 			default: | 			default: | ||||||
| 				return fmt.Errorf("can not convert type: %T", arg) | 				return fmt.Errorf("cannot convert type: %T", arg) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -1086,8 +1116,6 @@ func (mc *mysqlConn) discardResults() error { | ||||||
| 			if err := mc.readUntilEOF(); err != nil { | 			if err := mc.readUntilEOF(); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} else { |  | ||||||
| 			mc.status &^= statusMoreResultsExists |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -1105,16 +1133,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 		// EOF Packet
 | 		// EOF Packet
 | ||||||
| 		if data[0] == iEOF && len(data) == 5 { | 		if data[0] == iEOF && len(data) == 5 { | ||||||
| 			rows.mc.status = readStatus(data[3:]) | 			rows.mc.status = readStatus(data[3:]) | ||||||
| 			if err := rows.mc.discardResults(); err != nil { | 			rows.rs.done = true | ||||||
| 				return err | 			if !rows.HasNextResultSet() { | ||||||
|  | 				rows.mc = nil | ||||||
| 			} | 			} | ||||||
| 			rows.mc = nil |  | ||||||
| 			return io.EOF | 			return io.EOF | ||||||
| 		} | 		} | ||||||
|  | 		mc := rows.mc | ||||||
| 		rows.mc = nil | 		rows.mc = nil | ||||||
| 
 | 
 | ||||||
| 		// Error otherwise
 | 		// Error otherwise
 | ||||||
| 		return rows.mc.handleErrorPacket(data) | 		return mc.handleErrorPacket(data) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
 | 	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
 | ||||||
|  | @ -1130,14 +1159,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Convert to byte-coded string
 | 		// Convert to byte-coded string
 | ||||||
| 		switch rows.columns[i].fieldType { | 		switch rows.rs.columns[i].fieldType { | ||||||
| 		case fieldTypeNULL: | 		case fieldTypeNULL: | ||||||
| 			dest[i] = nil | 			dest[i] = nil | ||||||
| 			continue | 			continue | ||||||
| 
 | 
 | ||||||
| 		// Numeric Types
 | 		// Numeric Types
 | ||||||
| 		case fieldTypeTiny: | 		case fieldTypeTiny: | ||||||
| 			if rows.columns[i].flags&flagUnsigned != 0 { | 			if rows.rs.columns[i].flags&flagUnsigned != 0 { | ||||||
| 				dest[i] = int64(data[pos]) | 				dest[i] = int64(data[pos]) | ||||||
| 			} else { | 			} else { | ||||||
| 				dest[i] = int64(int8(data[pos])) | 				dest[i] = int64(int8(data[pos])) | ||||||
|  | @ -1146,7 +1175,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 			continue | 			continue | ||||||
| 
 | 
 | ||||||
| 		case fieldTypeShort, fieldTypeYear: | 		case fieldTypeShort, fieldTypeYear: | ||||||
| 			if rows.columns[i].flags&flagUnsigned != 0 { | 			if rows.rs.columns[i].flags&flagUnsigned != 0 { | ||||||
| 				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) | 				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) | ||||||
| 			} else { | 			} else { | ||||||
| 				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) | 				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) | ||||||
|  | @ -1155,7 +1184,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 			continue | 			continue | ||||||
| 
 | 
 | ||||||
| 		case fieldTypeInt24, fieldTypeLong: | 		case fieldTypeInt24, fieldTypeLong: | ||||||
| 			if rows.columns[i].flags&flagUnsigned != 0 { | 			if rows.rs.columns[i].flags&flagUnsigned != 0 { | ||||||
| 				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) | 				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) | ||||||
| 			} else { | 			} else { | ||||||
| 				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) | 				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) | ||||||
|  | @ -1164,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 			continue | 			continue | ||||||
| 
 | 
 | ||||||
| 		case fieldTypeLongLong: | 		case fieldTypeLongLong: | ||||||
| 			if rows.columns[i].flags&flagUnsigned != 0 { | 			if rows.rs.columns[i].flags&flagUnsigned != 0 { | ||||||
| 				val := binary.LittleEndian.Uint64(data[pos : pos+8]) | 				val := binary.LittleEndian.Uint64(data[pos : pos+8]) | ||||||
| 				if val > math.MaxInt64 { | 				if val > math.MaxInt64 { | ||||||
| 					dest[i] = uint64ToString(val) | 					dest[i] = uint64ToString(val) | ||||||
|  | @ -1178,7 +1207,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 			continue | 			continue | ||||||
| 
 | 
 | ||||||
| 		case fieldTypeFloat: | 		case fieldTypeFloat: | ||||||
| 			dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) | 			dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) | ||||||
| 			pos += 4 | 			pos += 4 | ||||||
| 			continue | 			continue | ||||||
| 
 | 
 | ||||||
|  | @ -1218,10 +1247,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 			case isNull: | 			case isNull: | ||||||
| 				dest[i] = nil | 				dest[i] = nil | ||||||
| 				continue | 				continue | ||||||
| 			case rows.columns[i].fieldType == fieldTypeTime: | 			case rows.rs.columns[i].fieldType == fieldTypeTime: | ||||||
| 				// database/sql does not support an equivalent to TIME, return a string
 | 				// database/sql does not support an equivalent to TIME, return a string
 | ||||||
| 				var dstlen uint8 | 				var dstlen uint8 | ||||||
| 				switch decimals := rows.columns[i].decimals; decimals { | 				switch decimals := rows.rs.columns[i].decimals; decimals { | ||||||
| 				case 0x00, 0x1f: | 				case 0x00, 0x1f: | ||||||
| 					dstlen = 8 | 					dstlen = 8 | ||||||
| 				case 1, 2, 3, 4, 5, 6: | 				case 1, 2, 3, 4, 5, 6: | ||||||
|  | @ -1229,7 +1258,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 				default: | 				default: | ||||||
| 					return fmt.Errorf( | 					return fmt.Errorf( | ||||||
| 						"protocol error, illegal decimals value %d", | 						"protocol error, illegal decimals value %d", | ||||||
| 						rows.columns[i].decimals, | 						rows.rs.columns[i].decimals, | ||||||
| 					) | 					) | ||||||
| 				} | 				} | ||||||
| 				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) | 				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) | ||||||
|  | @ -1237,10 +1266,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) | 				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) | ||||||
| 			default: | 			default: | ||||||
| 				var dstlen uint8 | 				var dstlen uint8 | ||||||
| 				if rows.columns[i].fieldType == fieldTypeDate { | 				if rows.rs.columns[i].fieldType == fieldTypeDate { | ||||||
| 					dstlen = 10 | 					dstlen = 10 | ||||||
| 				} else { | 				} else { | ||||||
| 					switch decimals := rows.columns[i].decimals; decimals { | 					switch decimals := rows.rs.columns[i].decimals; decimals { | ||||||
| 					case 0x00, 0x1f: | 					case 0x00, 0x1f: | ||||||
| 						dstlen = 19 | 						dstlen = 19 | ||||||
| 					case 1, 2, 3, 4, 5, 6: | 					case 1, 2, 3, 4, 5, 6: | ||||||
|  | @ -1248,7 +1277,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 					default: | 					default: | ||||||
| 						return fmt.Errorf( | 						return fmt.Errorf( | ||||||
| 							"protocol error, illegal decimals value %d", | 							"protocol error, illegal decimals value %d", | ||||||
| 							rows.columns[i].decimals, | 							rows.rs.columns[i].decimals, | ||||||
| 						) | 						) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|  | @ -1264,7 +1293,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | ||||||
| 
 | 
 | ||||||
| 		// Please report if this happens!
 | 		// Please report if this happens!
 | ||||||
| 		default: | 		default: | ||||||
| 			return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) | 			return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -11,19 +11,20 @@ package mysql | ||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"io" | 	"io" | ||||||
|  | 	"math" | ||||||
|  | 	"reflect" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type mysqlField struct { | type resultSet struct { | ||||||
| 	tableName string | 	columns     []mysqlField | ||||||
| 	name      string | 	columnNames []string | ||||||
| 	flags     fieldFlag | 	done        bool | ||||||
| 	fieldType byte |  | ||||||
| 	decimals  byte |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type mysqlRows struct { | type mysqlRows struct { | ||||||
| 	mc      *mysqlConn | 	mc     *mysqlConn | ||||||
| 	columns []mysqlField | 	rs     resultSet | ||||||
|  | 	finish func() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type binaryRows struct { | type binaryRows struct { | ||||||
|  | @ -34,37 +35,86 @@ type textRows struct { | ||||||
| 	mysqlRows | 	mysqlRows | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type emptyRows struct{} |  | ||||||
| 
 |  | ||||||
| func (rows *mysqlRows) Columns() []string { | func (rows *mysqlRows) Columns() []string { | ||||||
| 	columns := make([]string, len(rows.columns)) | 	if rows.rs.columnNames != nil { | ||||||
|  | 		return rows.rs.columnNames | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	columns := make([]string, len(rows.rs.columns)) | ||||||
| 	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { | 	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { | ||||||
| 		for i := range columns { | 		for i := range columns { | ||||||
| 			if tableName := rows.columns[i].tableName; len(tableName) > 0 { | 			if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { | ||||||
| 				columns[i] = tableName + "." + rows.columns[i].name | 				columns[i] = tableName + "." + rows.rs.columns[i].name | ||||||
| 			} else { | 			} else { | ||||||
| 				columns[i] = rows.columns[i].name | 				columns[i] = rows.rs.columns[i].name | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		for i := range columns { | 		for i := range columns { | ||||||
| 			columns[i] = rows.columns[i].name | 			columns[i] = rows.rs.columns[i].name | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	rows.rs.columnNames = columns | ||||||
| 	return columns | 	return columns | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (rows *mysqlRows) Close() error { | func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { | ||||||
|  | 	return rows.rs.columns[i].typeDatabaseName() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
 | ||||||
|  | // 	return int64(rows.rs.columns[i].length), true
 | ||||||
|  | // }
 | ||||||
|  | 
 | ||||||
|  | func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { | ||||||
|  | 	return rows.rs.columns[i].flags&flagNotNULL == 0, true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { | ||||||
|  | 	column := rows.rs.columns[i] | ||||||
|  | 	decimals := int64(column.decimals) | ||||||
|  | 
 | ||||||
|  | 	switch column.fieldType { | ||||||
|  | 	case fieldTypeDecimal, fieldTypeNewDecimal: | ||||||
|  | 		if decimals > 0 { | ||||||
|  | 			return int64(column.length) - 2, decimals, true | ||||||
|  | 		} | ||||||
|  | 		return int64(column.length) - 1, decimals, true | ||||||
|  | 	case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: | ||||||
|  | 		return decimals, decimals, true | ||||||
|  | 	case fieldTypeFloat, fieldTypeDouble: | ||||||
|  | 		if decimals == 0x1f { | ||||||
|  | 			return math.MaxInt64, math.MaxInt64, true | ||||||
|  | 		} | ||||||
|  | 		return math.MaxInt64, decimals, true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return 0, 0, false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { | ||||||
|  | 	return rows.rs.columns[i].scanType() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rows *mysqlRows) Close() (err error) { | ||||||
|  | 	if f := rows.finish; f != nil { | ||||||
|  | 		f() | ||||||
|  | 		rows.finish = nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	mc := rows.mc | 	mc := rows.mc | ||||||
| 	if mc == nil { | 	if mc == nil { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	if mc.netConn == nil { | 	if err := mc.error(); err != nil { | ||||||
| 		return ErrInvalidConn | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Remove unread packets from stream
 | 	// Remove unread packets from stream
 | ||||||
| 	err := mc.readUntilEOF() | 	if !rows.rs.done { | ||||||
|  | 		err = mc.readUntilEOF() | ||||||
|  | 	} | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		if err = mc.discardResults(); err != nil { | 		if err = mc.discardResults(); err != nil { | ||||||
| 			return err | 			return err | ||||||
|  | @ -75,10 +125,66 @@ func (rows *mysqlRows) Close() error { | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (rows *mysqlRows) HasNextResultSet() (b bool) { | ||||||
|  | 	if rows.mc == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return rows.mc.status&statusMoreResultsExists != 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rows *mysqlRows) nextResultSet() (int, error) { | ||||||
|  | 	if rows.mc == nil { | ||||||
|  | 		return 0, io.EOF | ||||||
|  | 	} | ||||||
|  | 	if err := rows.mc.error(); err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Remove unread packets from stream
 | ||||||
|  | 	if !rows.rs.done { | ||||||
|  | 		if err := rows.mc.readUntilEOF(); err != nil { | ||||||
|  | 			return 0, err | ||||||
|  | 		} | ||||||
|  | 		rows.rs.done = true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !rows.HasNextResultSet() { | ||||||
|  | 		rows.mc = nil | ||||||
|  | 		return 0, io.EOF | ||||||
|  | 	} | ||||||
|  | 	rows.rs = resultSet{} | ||||||
|  | 	return rows.mc.readResultSetHeaderPacket() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { | ||||||
|  | 	for { | ||||||
|  | 		resLen, err := rows.nextResultSet() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return 0, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if resLen > 0 { | ||||||
|  | 			return resLen, nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		rows.rs.done = true | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rows *binaryRows) NextResultSet() error { | ||||||
|  | 	resLen, err := rows.nextNotEmptyResultSet() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rows.rs.columns, err = rows.mc.readColumns(resLen) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (rows *binaryRows) Next(dest []driver.Value) error { | func (rows *binaryRows) Next(dest []driver.Value) error { | ||||||
| 	if mc := rows.mc; mc != nil { | 	if mc := rows.mc; mc != nil { | ||||||
| 		if mc.netConn == nil { | 		if err := mc.error(); err != nil { | ||||||
| 			return ErrInvalidConn | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Fetch next row from stream
 | 		// Fetch next row from stream
 | ||||||
|  | @ -87,10 +193,20 @@ func (rows *binaryRows) Next(dest []driver.Value) error { | ||||||
| 	return io.EOF | 	return io.EOF | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (rows *textRows) NextResultSet() (err error) { | ||||||
|  | 	resLen, err := rows.nextNotEmptyResultSet() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rows.rs.columns, err = rows.mc.readColumns(resLen) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (rows *textRows) Next(dest []driver.Value) error { | func (rows *textRows) Next(dest []driver.Value) error { | ||||||
| 	if mc := rows.mc; mc != nil { | 	if mc := rows.mc; mc != nil { | ||||||
| 		if mc.netConn == nil { | 		if err := mc.error(); err != nil { | ||||||
| 			return ErrInvalidConn | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Fetch next row from stream
 | 		// Fetch next row from stream
 | ||||||
|  | @ -98,15 +214,3 @@ func (rows *textRows) Next(dest []driver.Value) error { | ||||||
| 	} | 	} | ||||||
| 	return io.EOF | 	return io.EOF | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (rows emptyRows) Columns() []string { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (rows emptyRows) Close() error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (rows emptyRows) Next(dest []driver.Value) error { |  | ||||||
| 	return io.EOF |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -11,6 +11,7 @@ package mysql | ||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| ) | ) | ||||||
|  | @ -19,12 +20,14 @@ type mysqlStmt struct { | ||||||
| 	mc         *mysqlConn | 	mc         *mysqlConn | ||||||
| 	id         uint32 | 	id         uint32 | ||||||
| 	paramCount int | 	paramCount int | ||||||
| 	columns    []mysqlField // cached from the first query
 |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (stmt *mysqlStmt) Close() error { | func (stmt *mysqlStmt) Close() error { | ||||||
| 	if stmt.mc == nil || stmt.mc.netConn == nil { | 	if stmt.mc == nil || stmt.mc.closed.IsSet() { | ||||||
| 		errLog.Print(ErrInvalidConn) | 		// driver.Stmt.Close can be called more than once, thus this function
 | ||||||
|  | 		// has to be idempotent.
 | ||||||
|  | 		// See also Issue #450 and golang/go#16019.
 | ||||||
|  | 		//errLog.Print(ErrInvalidConn)
 | ||||||
| 		return driver.ErrBadConn | 		return driver.ErrBadConn | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -42,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | ||||||
| 	if stmt.mc.netConn == nil { | 	if stmt.mc.closed.IsSet() { | ||||||
| 		errLog.Print(ErrInvalidConn) | 		errLog.Print(ErrInvalidConn) | ||||||
| 		return nil, driver.ErrBadConn | 		return nil, driver.ErrBadConn | ||||||
| 	} | 	} | ||||||
| 	// Send command
 | 	// Send command
 | ||||||
| 	err := stmt.writeExecutePacket(args) | 	err := stmt.writeExecutePacket(args) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, stmt.mc.markBadConn(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	mc := stmt.mc | 	mc := stmt.mc | ||||||
|  | @ -59,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | ||||||
| 
 | 
 | ||||||
| 	// Read Result
 | 	// Read Result
 | ||||||
| 	resLen, err := mc.readResultSetHeaderPacket() | 	resLen, err := mc.readResultSetHeaderPacket() | ||||||
| 	if err == nil { | 	if err != nil { | ||||||
| 		if resLen > 0 { | 		return nil, err | ||||||
| 			// Columns
 | 	} | ||||||
| 			err = mc.readUntilEOF() |  | ||||||
| 			if err != nil { |  | ||||||
| 				return nil, err |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			// Rows
 | 	if resLen > 0 { | ||||||
| 			err = mc.readUntilEOF() | 		// Columns
 | ||||||
|  | 		if err = mc.readUntilEOF(); err != nil { | ||||||
|  | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		if err == nil { | 
 | ||||||
| 			return &mysqlResult{ | 		// Rows
 | ||||||
| 				affectedRows: int64(mc.affectedRows), | 		if err := mc.readUntilEOF(); err != nil { | ||||||
| 				insertId:     int64(mc.insertId), | 			return nil, err | ||||||
| 			}, nil |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil, err | 	if err := mc.discardResults(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &mysqlResult{ | ||||||
|  | 		affectedRows: int64(mc.affectedRows), | ||||||
|  | 		insertId:     int64(mc.insertId), | ||||||
|  | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||||
| 	if stmt.mc.netConn == nil { | 	return stmt.query(args) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { | ||||||
|  | 	if stmt.mc.closed.IsSet() { | ||||||
| 		errLog.Print(ErrInvalidConn) | 		errLog.Print(ErrInvalidConn) | ||||||
| 		return nil, driver.ErrBadConn | 		return nil, driver.ErrBadConn | ||||||
| 	} | 	} | ||||||
| 	// Send command
 | 	// Send command
 | ||||||
| 	err := stmt.writeExecutePacket(args) | 	err := stmt.writeExecutePacket(args) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, stmt.mc.markBadConn(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	mc := stmt.mc | 	mc := stmt.mc | ||||||
|  | @ -104,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||||
| 
 | 
 | ||||||
| 	if resLen > 0 { | 	if resLen > 0 { | ||||||
| 		rows.mc = mc | 		rows.mc = mc | ||||||
| 		// Columns
 | 		rows.rs.columns, err = mc.readColumns(resLen) | ||||||
| 		// If not cached, read them and cache them
 | 	} else { | ||||||
| 		if stmt.columns == nil { | 		rows.rs.done = true | ||||||
| 			rows.columns, err = mc.readColumns(resLen) | 
 | ||||||
| 			stmt.columns = rows.columns | 		switch err := rows.NextResultSet(); err { | ||||||
| 		} else { | 		case nil, io.EOF: | ||||||
| 			rows.columns = stmt.columns | 			return rows, nil | ||||||
| 			err = mc.readUntilEOF() | 		default: | ||||||
|  | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -120,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||||
| 
 | 
 | ||||||
| type converter struct{} | type converter struct{} | ||||||
| 
 | 
 | ||||||
|  | // ConvertValue mirrors the reference/default converter in database/sql/driver
 | ||||||
|  | // with _one_ exception.  We support uint64 with their high bit and the default
 | ||||||
|  | // implementation does not.  This function should be kept in sync with
 | ||||||
|  | // database/sql/driver defaultConverter.ConvertValue() except for that
 | ||||||
|  | // deliberate difference.
 | ||||||
| func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | ||||||
| 	if driver.IsValue(v) { | 	if driver.IsValue(v) { | ||||||
| 		return v, nil | 		return v, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if vr, ok := v.(driver.Valuer); ok { | ||||||
|  | 		sv, err := callValuerValue(vr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		if !driver.IsValue(sv) { | ||||||
|  | 			return nil, fmt.Errorf("non-Value type %T returned from Value", sv) | ||||||
|  | 		} | ||||||
|  | 		return sv, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	rv := reflect.ValueOf(v) | 	rv := reflect.ValueOf(v) | ||||||
| 	switch rv.Kind() { | 	switch rv.Kind() { | ||||||
| 	case reflect.Ptr: | 	case reflect.Ptr: | ||||||
| 		// indirect pointers
 | 		// indirect pointers
 | ||||||
| 		if rv.IsNil() { | 		if rv.IsNil() { | ||||||
| 			return nil, nil | 			return nil, nil | ||||||
|  | 		} else { | ||||||
|  | 			return c.ConvertValue(rv.Elem().Interface()) | ||||||
| 		} | 		} | ||||||
| 		return c.ConvertValue(rv.Elem().Interface()) |  | ||||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||||
| 		return rv.Int(), nil | 		return rv.Int(), nil | ||||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: | 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: | ||||||
|  | @ -145,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | ||||||
| 		return int64(u64), nil | 		return int64(u64), nil | ||||||
| 	case reflect.Float32, reflect.Float64: | 	case reflect.Float32, reflect.Float64: | ||||||
| 		return rv.Float(), nil | 		return rv.Float(), nil | ||||||
|  | 	case reflect.Bool: | ||||||
|  | 		return rv.Bool(), nil | ||||||
|  | 	case reflect.Slice: | ||||||
|  | 		ek := rv.Type().Elem().Kind() | ||||||
|  | 		if ek == reflect.Uint8 { | ||||||
|  | 			return rv.Bytes(), nil | ||||||
|  | 		} | ||||||
|  | 		return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) | ||||||
|  | 	case reflect.String: | ||||||
|  | 		return rv.String(), nil | ||||||
| 	} | 	} | ||||||
| 	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) | 	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() | ||||||
|  | 
 | ||||||
|  | // callValuerValue returns vr.Value(), with one exception:
 | ||||||
|  | // If vr.Value is an auto-generated method on a pointer type and the
 | ||||||
|  | // pointer is nil, it would panic at runtime in the panicwrap
 | ||||||
|  | // method. Treat it like nil instead.
 | ||||||
|  | //
 | ||||||
|  | // This is so people can implement driver.Value on value types and
 | ||||||
|  | // still use nil pointers to those types to mean nil/NULL, just like
 | ||||||
|  | // string/*string.
 | ||||||
|  | //
 | ||||||
|  | // This is an exact copy of the same-named unexported function from the
 | ||||||
|  | // database/sql package.
 | ||||||
|  | func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { | ||||||
|  | 	if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && | ||||||
|  | 		rv.IsNil() && | ||||||
|  | 		rv.Type().Elem().Implements(valuerReflectType) { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | 	return vr.Value() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -13,7 +13,7 @@ type mysqlTx struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tx *mysqlTx) Commit() (err error) { | func (tx *mysqlTx) Commit() (err error) { | ||||||
| 	if tx.mc == nil || tx.mc.netConn == nil { | 	if tx.mc == nil || tx.mc.closed.IsSet() { | ||||||
| 		return ErrInvalidConn | 		return ErrInvalidConn | ||||||
| 	} | 	} | ||||||
| 	err = tx.mc.exec("COMMIT") | 	err = tx.mc.exec("COMMIT") | ||||||
|  | @ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tx *mysqlTx) Rollback() (err error) { | func (tx *mysqlTx) Rollback() (err error) { | ||||||
| 	if tx.mc == nil || tx.mc.netConn == nil { | 	if tx.mc == nil || tx.mc.closed.IsSet() { | ||||||
| 		return ErrInvalidConn | 		return ErrInvalidConn | ||||||
| 	} | 	} | ||||||
| 	err = tx.mc.exec("ROLLBACK") | 	err = tx.mc.exec("ROLLBACK") | ||||||
|  |  | ||||||
|  | @ -9,23 +9,29 @@ | ||||||
| package mysql | package mysql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/sha1" |  | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"encoding/binary" | 	"encoding/binary" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // Registry for custom tls.Configs
 | ||||||
| var ( | var ( | ||||||
| 	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
 | 	tlsConfigLock     sync.RWMutex | ||||||
|  | 	tlsConfigRegistry map[string]*tls.Config | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
 | // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
 | ||||||
| // Use the key as a value in the DSN where tls=value.
 | // Use the key as a value in the DSN where tls=value.
 | ||||||
| //
 | //
 | ||||||
|  | // Note: The provided tls.Config is exclusively owned by the driver after
 | ||||||
|  | // registering it.
 | ||||||
|  | //
 | ||||||
| //  rootCertPool := x509.NewCertPool()
 | //  rootCertPool := x509.NewCertPool()
 | ||||||
| //  pem, err := ioutil.ReadFile("/path/ca-cert.pem")
 | //  pem, err := ioutil.ReadFile("/path/ca-cert.pem")
 | ||||||
| //  if err != nil {
 | //  if err != nil {
 | ||||||
|  | @ -51,19 +57,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error { | ||||||
| 		return fmt.Errorf("key '%s' is reserved", key) | 		return fmt.Errorf("key '%s' is reserved", key) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if tlsConfigRegister == nil { | 	tlsConfigLock.Lock() | ||||||
| 		tlsConfigRegister = make(map[string]*tls.Config) | 	if tlsConfigRegistry == nil { | ||||||
|  | 		tlsConfigRegistry = make(map[string]*tls.Config) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tlsConfigRegister[key] = config | 	tlsConfigRegistry[key] = config | ||||||
|  | 	tlsConfigLock.Unlock() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // DeregisterTLSConfig removes the tls.Config associated with key.
 | // DeregisterTLSConfig removes the tls.Config associated with key.
 | ||||||
| func DeregisterTLSConfig(key string) { | func DeregisterTLSConfig(key string) { | ||||||
| 	if tlsConfigRegister != nil { | 	tlsConfigLock.Lock() | ||||||
| 		delete(tlsConfigRegister, key) | 	if tlsConfigRegistry != nil { | ||||||
|  | 		delete(tlsConfigRegistry, key) | ||||||
| 	} | 	} | ||||||
|  | 	tlsConfigLock.Unlock() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getTLSConfigClone(key string) (config *tls.Config) { | ||||||
|  | 	tlsConfigLock.RLock() | ||||||
|  | 	if v, ok := tlsConfigRegistry[key]; ok { | ||||||
|  | 		config = cloneTLSConfig(v) | ||||||
|  | 	} | ||||||
|  | 	tlsConfigLock.RUnlock() | ||||||
|  | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Returns the bool value of the input.
 | // Returns the bool value of the input.
 | ||||||
|  | @ -80,119 +99,6 @@ func readBool(input string) (value bool, valid bool) { | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /****************************************************************************** |  | ||||||
| *                             Authentication                                  * |  | ||||||
| ******************************************************************************/ |  | ||||||
| 
 |  | ||||||
| // Encrypt password using 4.1+ method
 |  | ||||||
| func scramblePassword(scramble, password []byte) []byte { |  | ||||||
| 	if len(password) == 0 { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// stage1Hash = SHA1(password)
 |  | ||||||
| 	crypt := sha1.New() |  | ||||||
| 	crypt.Write(password) |  | ||||||
| 	stage1 := crypt.Sum(nil) |  | ||||||
| 
 |  | ||||||
| 	// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
 |  | ||||||
| 	// inner Hash
 |  | ||||||
| 	crypt.Reset() |  | ||||||
| 	crypt.Write(stage1) |  | ||||||
| 	hash := crypt.Sum(nil) |  | ||||||
| 
 |  | ||||||
| 	// outer Hash
 |  | ||||||
| 	crypt.Reset() |  | ||||||
| 	crypt.Write(scramble) |  | ||||||
| 	crypt.Write(hash) |  | ||||||
| 	scramble = crypt.Sum(nil) |  | ||||||
| 
 |  | ||||||
| 	// token = scrambleHash XOR stage1Hash
 |  | ||||||
| 	for i := range scramble { |  | ||||||
| 		scramble[i] ^= stage1[i] |  | ||||||
| 	} |  | ||||||
| 	return scramble |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Encrypt password using pre 4.1 (old password) method
 |  | ||||||
| // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
 |  | ||||||
| type myRnd struct { |  | ||||||
| 	seed1, seed2 uint32 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| const myRndMaxVal = 0x3FFFFFFF |  | ||||||
| 
 |  | ||||||
| // Pseudo random number generator
 |  | ||||||
| func newMyRnd(seed1, seed2 uint32) *myRnd { |  | ||||||
| 	return &myRnd{ |  | ||||||
| 		seed1: seed1 % myRndMaxVal, |  | ||||||
| 		seed2: seed2 % myRndMaxVal, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Tested to be equivalent to MariaDB's floating point variant
 |  | ||||||
| // http://play.golang.org/p/QHvhd4qved
 |  | ||||||
| // http://play.golang.org/p/RG0q4ElWDx
 |  | ||||||
| func (r *myRnd) NextByte() byte { |  | ||||||
| 	r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal |  | ||||||
| 	r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal |  | ||||||
| 
 |  | ||||||
| 	return byte(uint64(r.seed1) * 31 / myRndMaxVal) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Generate binary hash from byte string using insecure pre 4.1 method
 |  | ||||||
| func pwHash(password []byte) (result [2]uint32) { |  | ||||||
| 	var add uint32 = 7 |  | ||||||
| 	var tmp uint32 |  | ||||||
| 
 |  | ||||||
| 	result[0] = 1345345333 |  | ||||||
| 	result[1] = 0x12345671 |  | ||||||
| 
 |  | ||||||
| 	for _, c := range password { |  | ||||||
| 		// skip spaces and tabs in password
 |  | ||||||
| 		if c == ' ' || c == '\t' { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		tmp = uint32(c) |  | ||||||
| 		result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) |  | ||||||
| 		result[1] += (result[1] << 8) ^ result[0] |  | ||||||
| 		add += tmp |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Remove sign bit (1<<31)-1)
 |  | ||||||
| 	result[0] &= 0x7FFFFFFF |  | ||||||
| 	result[1] &= 0x7FFFFFFF |  | ||||||
| 
 |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Encrypt password using insecure pre 4.1 method
 |  | ||||||
| func scrambleOldPassword(scramble, password []byte) []byte { |  | ||||||
| 	if len(password) == 0 { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	scramble = scramble[:8] |  | ||||||
| 
 |  | ||||||
| 	hashPw := pwHash(password) |  | ||||||
| 	hashSc := pwHash(scramble) |  | ||||||
| 
 |  | ||||||
| 	r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) |  | ||||||
| 
 |  | ||||||
| 	var out [8]byte |  | ||||||
| 	for i := range out { |  | ||||||
| 		out[i] = r.NextByte() + 64 |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	mask := r.NextByte() |  | ||||||
| 	for i := range out { |  | ||||||
| 		out[i] ^= mask |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return out[:] |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /****************************************************************************** | /****************************************************************************** | ||||||
| *                           Time related utils                                * | *                           Time related utils                                * | ||||||
| ******************************************************************************/ | ******************************************************************************/ | ||||||
|  | @ -519,7 +425,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { | ||||||
| 
 | 
 | ||||||
| 	// Check data length
 | 	// Check data length
 | ||||||
| 	if len(b) >= n { | 	if len(b) >= n { | ||||||
| 		return b[n-int(num) : n], false, n, nil | 		return b[n-int(num) : n : n], false, n, nil | ||||||
| 	} | 	} | ||||||
| 	return nil, false, n, io.EOF | 	return nil, false, n, io.EOF | ||||||
| } | } | ||||||
|  | @ -548,8 +454,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { | ||||||
| 	if len(b) == 0 { | 	if len(b) == 0 { | ||||||
| 		return 0, true, 1 | 		return 0, true, 1 | ||||||
| 	} | 	} | ||||||
| 	switch b[0] { |  | ||||||
| 
 | 
 | ||||||
|  | 	switch b[0] { | ||||||
| 	// 251: NULL
 | 	// 251: NULL
 | ||||||
| 	case 0xfb: | 	case 0xfb: | ||||||
| 		return 0, true, 1 | 		return 0, true, 1 | ||||||
|  | @ -738,3 +644,67 @@ func escapeStringQuotes(buf []byte, v string) []byte { | ||||||
| 
 | 
 | ||||||
| 	return buf[:pos] | 	return buf[:pos] | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | /****************************************************************************** | ||||||
|  | *                               Sync utils                                    * | ||||||
|  | ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | // noCopy may be embedded into structs which must not be copied
 | ||||||
|  | // after the first use.
 | ||||||
|  | //
 | ||||||
|  | // See https://github.com/golang/go/issues/8005#issuecomment-190753527
 | ||||||
|  | // for details.
 | ||||||
|  | type noCopy struct{} | ||||||
|  | 
 | ||||||
|  | // Lock is a no-op used by -copylocks checker from `go vet`.
 | ||||||
|  | func (*noCopy) Lock() {} | ||||||
|  | 
 | ||||||
|  | // atomicBool is a wrapper around uint32 for usage as a boolean value with
 | ||||||
|  | // atomic access.
 | ||||||
|  | type atomicBool struct { | ||||||
|  | 	_noCopy noCopy | ||||||
|  | 	value   uint32 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsSet returns wether the current boolean value is true
 | ||||||
|  | func (ab *atomicBool) IsSet() bool { | ||||||
|  | 	return atomic.LoadUint32(&ab.value) > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Set sets the value of the bool regardless of the previous value
 | ||||||
|  | func (ab *atomicBool) Set(value bool) { | ||||||
|  | 	if value { | ||||||
|  | 		atomic.StoreUint32(&ab.value, 1) | ||||||
|  | 	} else { | ||||||
|  | 		atomic.StoreUint32(&ab.value, 0) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TrySet sets the value of the bool and returns wether the value changed
 | ||||||
|  | func (ab *atomicBool) TrySet(value bool) bool { | ||||||
|  | 	if value { | ||||||
|  | 		return atomic.SwapUint32(&ab.value, 1) == 0 | ||||||
|  | 	} | ||||||
|  | 	return atomic.SwapUint32(&ab.value, 0) > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // atomicError is a wrapper for atomically accessed error values
 | ||||||
|  | type atomicError struct { | ||||||
|  | 	_noCopy noCopy | ||||||
|  | 	value   atomic.Value | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Set sets the error value regardless of the previous value.
 | ||||||
|  | // The value must not be nil
 | ||||||
|  | func (ae *atomicError) Set(value error) { | ||||||
|  | 	ae.value.Store(value) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Value returns the current error value
 | ||||||
|  | func (ae *atomicError) Value() error { | ||||||
|  | 	if v := ae.value.Load(); v != nil { | ||||||
|  | 		// this will panic if the value doesn't implement the error interface
 | ||||||
|  | 		return v.(error) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,40 @@ | ||||||
|  | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // This Source Code Form is subject to the terms of the Mozilla Public
 | ||||||
|  | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | ||||||
|  | // You can obtain one at http://mozilla.org/MPL/2.0/.
 | ||||||
|  | 
 | ||||||
|  | // +build go1.7
 | ||||||
|  | // +build !go1.8
 | ||||||
|  | 
 | ||||||
|  | package mysql | ||||||
|  | 
 | ||||||
|  | import "crypto/tls" | ||||||
|  | 
 | ||||||
|  | func cloneTLSConfig(c *tls.Config) *tls.Config { | ||||||
|  | 	return &tls.Config{ | ||||||
|  | 		Rand:                        c.Rand, | ||||||
|  | 		Time:                        c.Time, | ||||||
|  | 		Certificates:                c.Certificates, | ||||||
|  | 		NameToCertificate:           c.NameToCertificate, | ||||||
|  | 		GetCertificate:              c.GetCertificate, | ||||||
|  | 		RootCAs:                     c.RootCAs, | ||||||
|  | 		NextProtos:                  c.NextProtos, | ||||||
|  | 		ServerName:                  c.ServerName, | ||||||
|  | 		ClientAuth:                  c.ClientAuth, | ||||||
|  | 		ClientCAs:                   c.ClientCAs, | ||||||
|  | 		InsecureSkipVerify:          c.InsecureSkipVerify, | ||||||
|  | 		CipherSuites:                c.CipherSuites, | ||||||
|  | 		PreferServerCipherSuites:    c.PreferServerCipherSuites, | ||||||
|  | 		SessionTicketsDisabled:      c.SessionTicketsDisabled, | ||||||
|  | 		SessionTicketKey:            c.SessionTicketKey, | ||||||
|  | 		ClientSessionCache:          c.ClientSessionCache, | ||||||
|  | 		MinVersion:                  c.MinVersion, | ||||||
|  | 		MaxVersion:                  c.MaxVersion, | ||||||
|  | 		CurvePreferences:            c.CurvePreferences, | ||||||
|  | 		DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, | ||||||
|  | 		Renegotiation:               c.Renegotiation, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,50 @@ | ||||||
|  | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // This Source Code Form is subject to the terms of the Mozilla Public
 | ||||||
|  | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | ||||||
|  | // You can obtain one at http://mozilla.org/MPL/2.0/.
 | ||||||
|  | 
 | ||||||
|  | // +build go1.8
 | ||||||
|  | 
 | ||||||
|  | package mysql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func cloneTLSConfig(c *tls.Config) *tls.Config { | ||||||
|  | 	return c.Clone() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | ||||||
|  | 	dargs := make([]driver.Value, len(named)) | ||||||
|  | 	for n, param := range named { | ||||||
|  | 		if len(param.Name) > 0 { | ||||||
|  | 			// TODO: support the use of Named Parameters #561
 | ||||||
|  | 			return nil, errors.New("mysql: driver does not support the use of Named Parameters") | ||||||
|  | 		} | ||||||
|  | 		dargs[n] = param.Value | ||||||
|  | 	} | ||||||
|  | 	return dargs, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func mapIsolationLevel(level driver.IsolationLevel) (string, error) { | ||||||
|  | 	switch sql.IsolationLevel(level) { | ||||||
|  | 	case sql.LevelRepeatableRead: | ||||||
|  | 		return "REPEATABLE READ", nil | ||||||
|  | 	case sql.LevelReadCommitted: | ||||||
|  | 		return "READ COMMITTED", nil | ||||||
|  | 	case sql.LevelReadUncommitted: | ||||||
|  | 		return "READ UNCOMMITTED", nil | ||||||
|  | 	case sql.LevelSerializable: | ||||||
|  | 		return "SERIALIZABLE", nil | ||||||
|  | 	default: | ||||||
|  | 		return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue