summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-sql-driver/mysql/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/connection.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/connection.go277
1 files changed, 194 insertions, 83 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go
index 04607296e..d82c728f3 100644
--- a/vendor/github.com/go-sql-driver/mysql/connection.go
+++ b/vendor/github.com/go-sql-driver/mysql/connection.go
@@ -9,10 +9,9 @@
package mysql
import (
- "crypto/tls"
"database/sql/driver"
- "errors"
"net"
+ "strconv"
"strings"
"time"
)
@@ -22,34 +21,20 @@ type mysqlConn struct {
netConn net.Conn
affectedRows uint64
insertId uint64
- cfg *config
- maxPacketAllowed int
+ cfg *Config
+ maxAllowedPacket int
maxWriteSize int
+ writeTimeout time.Duration
flags clientFlag
+ status statusFlag
sequence uint8
parseTime bool
strict bool
}
-type config struct {
- user string
- passwd string
- net string
- addr string
- dbname string
- params map[string]string
- loc *time.Location
- tls *tls.Config
- timeout time.Duration
- collation uint8
- allowAllFiles bool
- allowOldPasswords bool
- clientFoundRows bool
-}
-
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
- for param, val := range mc.cfg.params {
+ for param, val := range mc.cfg.Params {
switch param {
// Charset
case "charset":
@@ -65,27 +50,6 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}
- // time.Time parsing
- case "parseTime":
- var isBool bool
- mc.parseTime, isBool = readBool(val)
- if !isBool {
- return errors.New("Invalid Bool value: " + val)
- }
-
- // Strict mode
- case "strict":
- var isBool bool
- mc.strict, isBool = readBool(val)
- if !isBool {
- return errors.New("Invalid Bool value: " + val)
- }
-
- // Compression
- case "compress":
- err = errors.New("Compression not implemented yet")
- return
-
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
@@ -115,20 +79,29 @@ func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if mc.netConn != nil {
err = mc.writeCommandPacket(comQuit)
- if err == nil {
- err = mc.netConn.Close()
- } else {
- mc.netConn.Close()
- }
- mc.netConn = nil
}
- mc.cfg = nil
- mc.buf.rd = nil
+ mc.cleanup()
return
}
+// Closes the network connection and unsets internal variables. Do not call this
+// function after successfully authentication, call Close instead. This function
+// is called before auth or on auth failure because MySQL will have already
+// closed the network connection.
+func (mc *mysqlConn) cleanup() {
+ // Makes cleanup idempotent
+ if mc.netConn != nil {
+ if err := mc.netConn.Close(); err != nil {
+ errLog.Print(err)
+ }
+ mc.netConn = nil
+ }
+ mc.cfg = nil
+ mc.buf.nc = nil
+}
+
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
@@ -161,28 +134,156 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
return stmt, err
}
+func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
+ // Number of ? should be same to len(args)
+ if strings.Count(query, "?") != len(args) {
+ return "", driver.ErrSkip
+ }
+
+ buf := mc.buf.takeCompleteBuffer()
+ if buf == nil {
+ // can not take the buffer. Something must be wrong with the connection
+ errLog.Print(ErrBusyBuffer)
+ return "", driver.ErrBadConn
+ }
+ buf = buf[:0]
+ argPos := 0
+
+ for i := 0; i < len(query); i++ {
+ q := strings.IndexByte(query[i:], '?')
+ if q == -1 {
+ buf = append(buf, query[i:]...)
+ break
+ }
+ buf = append(buf, query[i:i+q]...)
+ i += q
+
+ arg := args[argPos]
+ argPos++
+
+ if arg == nil {
+ buf = append(buf, "NULL"...)
+ continue
+ }
+
+ switch v := arg.(type) {
+ case int64:
+ buf = strconv.AppendInt(buf, v, 10)
+ case float64:
+ buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+ case bool:
+ if v {
+ buf = append(buf, '1')
+ } else {
+ buf = append(buf, '0')
+ }
+ case time.Time:
+ if v.IsZero() {
+ buf = append(buf, "'0000-00-00'"...)
+ } else {
+ v := v.In(mc.cfg.Loc)
+ v = v.Add(time.Nanosecond * 500) // To round under microsecond
+ year := v.Year()
+ year100 := year / 100
+ year1 := year % 100
+ month := v.Month()
+ day := v.Day()
+ hour := v.Hour()
+ minute := v.Minute()
+ second := v.Second()
+ micro := v.Nanosecond() / 1000
+
+ buf = append(buf, []byte{
+ '\'',
+ digits10[year100], digits01[year100],
+ digits10[year1], digits01[year1],
+ '-',
+ digits10[month], digits01[month],
+ '-',
+ digits10[day], digits01[day],
+ ' ',
+ digits10[hour], digits01[hour],
+ ':',
+ digits10[minute], digits01[minute],
+ ':',
+ digits10[second], digits01[second],
+ }...)
+
+ if micro != 0 {
+ micro10000 := micro / 10000
+ micro100 := micro / 100 % 100
+ micro1 := micro % 100
+ buf = append(buf, []byte{
+ '.',
+ digits10[micro10000], digits01[micro10000],
+ digits10[micro100], digits01[micro100],
+ digits10[micro1], digits01[micro1],
+ }...)
+ }
+ buf = append(buf, '\'')
+ }
+ case []byte:
+ if v == nil {
+ buf = append(buf, "NULL"...)
+ } else {
+ buf = append(buf, "_binary'"...)
+ if mc.status&statusNoBackslashEscapes == 0 {
+ buf = escapeBytesBackslash(buf, v)
+ } else {
+ buf = escapeBytesQuotes(buf, v)
+ }
+ buf = append(buf, '\'')
+ }
+ case string:
+ buf = append(buf, '\'')
+ if mc.status&statusNoBackslashEscapes == 0 {
+ buf = escapeStringBackslash(buf, v)
+ } else {
+ buf = escapeStringQuotes(buf, v)
+ }
+ buf = append(buf, '\'')
+ default:
+ return "", driver.ErrSkip
+ }
+
+ if len(buf)+4 > mc.maxAllowedPacket {
+ return "", driver.ErrSkip
+ }
+ }
+ if argPos != len(args) {
+ return "", driver.ErrSkip
+ }
+ return string(buf), nil
+}
+
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
- if len(args) == 0 { // no args, fastpath
- mc.affectedRows = 0
- mc.insertId = 0
-
- err := mc.exec(query)
- if err == nil {
- return &mysqlResult{
- affectedRows: int64(mc.affectedRows),
- insertId: int64(mc.insertId),
- }, err
+ if len(args) != 0 {
+ if !mc.cfg.InterpolateParams {
+ return nil, driver.ErrSkip
}
- return nil, err
+ // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
+ prepared, err := mc.interpolateParams(query, args)
+ if err != nil {
+ return nil, err
+ }
+ query = prepared
+ args = nil
}
+ mc.affectedRows = 0
+ mc.insertId = 0
- // with args, must use prepared stmt
- return nil, driver.ErrSkip
-
+ err := mc.exec(query)
+ if err == nil {
+ return &mysqlResult{
+ affectedRows: int64(mc.affectedRows),
+ insertId: int64(mc.insertId),
+ }, err
+ }
+ return nil, err
}
// Internal function to execute commands
@@ -211,29 +312,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
- if len(args) == 0 { // no args, fastpath
- // Send command
- err := mc.writeCommandPacketStr(comQuery, query)
+ if len(args) != 0 {
+ if !mc.cfg.InterpolateParams {
+ return nil, driver.ErrSkip
+ }
+ // try client-side prepare to reduce roundtrip
+ prepared, err := mc.interpolateParams(query, args)
+ if err != nil {
+ return nil, err
+ }
+ query = prepared
+ args = nil
+ }
+ // Send command
+ err := mc.writeCommandPacketStr(comQuery, query)
+ if err == nil {
+ // Read Result
+ var resLen int
+ resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
- // Read Result
- var resLen int
- resLen, err = mc.readResultSetHeaderPacket()
- if err == nil {
- rows := new(textRows)
- rows.mc = mc
-
- if resLen > 0 {
- // Columns
- rows.columns, err = mc.readColumns(resLen)
- }
- return rows, err
+ rows := new(textRows)
+ rows.mc = mc
+
+ if resLen == 0 {
+ // no columns, no more data
+ return emptyRows{}, nil
}
+ // Columns
+ rows.columns, err = mc.readColumns(resLen)
+ return rows, err
}
- return nil, err
}
-
- // with args, must use prepared stmt
- return nil, driver.ErrSkip
+ return nil, err
}
// Gets the value of the given MySQL System Variable
@@ -249,6 +359,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
if err == nil {
rows := new(textRows)
rows.mc = mc
+ rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
if resLen > 0 {
// Columns