summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-sql-driver/mysql/packets.go
diff options
context:
space:
mode:
authorChristopher Speller <crspeller@gmail.com>2016-09-23 10:17:51 -0400
committerGitHub <noreply@github.com>2016-09-23 10:17:51 -0400
commit2ca0e8f9a0f9863555a26e984cde15efff9ef8f8 (patch)
treedaae1ee67b14a3d0a84424f2a304885d9e75ce2b /vendor/github.com/go-sql-driver/mysql/packets.go
parent6d62d65b2dc85855aabea036cbd44f6059e19d13 (diff)
downloadchat-2ca0e8f9a0f9863555a26e984cde15efff9ef8f8.tar.gz
chat-2ca0e8f9a0f9863555a26e984cde15efff9ef8f8.tar.bz2
chat-2ca0e8f9a0f9863555a26e984cde15efff9ef8f8.zip
Updating golang dependancies (#4075)
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/packets.go337
1 files changed, 125 insertions, 212 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go
index 8d9166578..618098146 100644
--- a/vendor/github.com/go-sql-driver/mysql/packets.go
+++ b/vendor/github.com/go-sql-driver/mysql/packets.go
@@ -13,7 +13,6 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
- "errors"
"fmt"
"io"
"math"
@@ -48,8 +47,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
+ } else {
+ return nil, ErrPktSync
}
- return nil, ErrPktSync
}
mc.sequence++
@@ -100,12 +100,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
data[3] = mc.sequence
// Write packet
- if mc.writeTimeout > 0 {
- if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
- return err
- }
- }
-
n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
@@ -146,7 +140,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, fmt.Errorf(
- "unsupported protocol version %d. Version %d or higher is required",
+ "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
data[0],
minProtocolVersion,
)
@@ -202,11 +196,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// return
//}
//return ErrMalformPkt
-
- // make a memory safe copy of the cipher slice
- var b [20]byte
- copy(b[:], cipher)
- return b[:], nil
+ return cipher, nil
}
// make a memory safe copy of the cipher slice
@@ -224,11 +214,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientLongPassword |
clientTransactions |
clientLocalFiles |
- clientPluginAuth |
- clientMultiResults |
mc.flags&clientLongFlag
- if mc.cfg.ClientFoundRows {
+ if mc.cfg.clientFoundRows {
clientFlags |= clientFoundRows
}
@@ -237,17 +225,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientFlags |= clientSSL
}
- if mc.cfg.MultiStatements {
- clientFlags |= clientMultiStatements
- }
-
// User Password
- scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
+ scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))
- pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
+ pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
// To specify a db name
- if n := len(mc.cfg.DBName); n > 0 {
+ if n := len(mc.cfg.dbname); n > 0 {
clientFlags |= clientConnectWithDB
pktLen += n + 1
}
@@ -273,14 +257,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
data[11] = 0x00
// Charset [1 byte]
- var found bool
- data[12], found = collations[mc.cfg.Collation]
- if !found {
- // Note possibility for false negatives:
- // could be triggered although the collation is valid if the
- // collations map does not contain entries the server supports.
- return errors.New("unknown collation")
- }
+ data[12] = mc.cfg.collation
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
@@ -296,18 +273,15 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
return err
}
mc.netConn = tlsConn
- mc.buf.nc = tlsConn
+ mc.buf.rd = tlsConn
}
// Filler [23 bytes] (all 0x00)
- pos := 13
- for ; pos < 13+23; pos++ {
- data[pos] = 0
- }
+ pos := 13 + 23
// User [null terminated string]
- if len(mc.cfg.User) > 0 {
- pos += copy(data[pos:], mc.cfg.User)
+ if len(mc.cfg.user) > 0 {
+ pos += copy(data[pos:], mc.cfg.user)
}
data[pos] = 0x00
pos++
@@ -317,16 +291,11 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
pos += 1 + copy(data[pos+1:], scrambleBuff)
// Databasename [null terminated string]
- if len(mc.cfg.DBName) > 0 {
- pos += copy(data[pos:], mc.cfg.DBName)
+ if len(mc.cfg.dbname) > 0 {
+ pos += copy(data[pos:], mc.cfg.dbname)
data[pos] = 0x00
- pos++
}
- // Assume native client during response
- pos += copy(data[pos:], "mysql_native_password")
- data[pos] = 0x00
-
// Send Auth packet
return mc.writePacket(data)
}
@@ -335,9 +304,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
// User password
- scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))
+ scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
- // Calculate the packet length and add a tailing 0
+ // Calculate the packet lenght and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
@@ -353,25 +322,6 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
return mc.writePacket(data)
}
-// Client clear text authentication packet
-// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeClearAuthPacket() error {
- // Calculate the packet length and add a tailing 0
- pktLen := len(mc.cfg.Passwd) + 1
- data := mc.buf.takeSmallBuffer(4 + pktLen)
- if data == nil {
- // can not take the buffer. Something must be wrong with the connection
- errLog.Print(ErrBusyBuffer)
- return driver.ErrBadConn
- }
-
- // Add the clear password [null terminated string]
- copy(data[4:], mc.cfg.Passwd)
- data[4+pktLen-1] = 0x00
-
- return mc.writePacket(data)
-}
-
/******************************************************************************
* Command Packets *
******************************************************************************/
@@ -455,20 +405,8 @@ func (mc *mysqlConn) readResultOK() error {
return mc.handleOkPacket(data)
case iEOF:
- if len(data) > 1 {
- plugin := string(data[1:bytes.IndexByte(data, 0x00)])
- if plugin == "mysql_old_password" {
- // using old_passwords
- return ErrOldPassword
- } else if plugin == "mysql_clear_password" {
- // using clear text password
- return ErrCleartextPassword
- } else {
- return ErrUnknownPlugin
- }
- } else {
- return ErrOldPassword
- }
+ // someone is using old_passwords
+ return ErrOldPassword
default: // Error otherwise
return mc.handleErrorPacket(data)
@@ -532,10 +470,6 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
}
}
-func readStatus(b []byte) statusFlag {
- return statusFlag(b[0]) | statusFlag(b[1])<<8
-}
-
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
@@ -550,21 +484,17 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
// server_status [2 bytes]
- mc.status = readStatus(data[1+n+m : 1+n+m+2])
- if err := mc.discardResults(); err != nil {
- return err
- }
// warning count [2 bytes]
if !mc.strict {
return nil
+ } else {
+ pos := 1 + n + m + 2
+ if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+ return mc.getWarnings()
+ }
+ return nil
}
-
- pos := 1 + n + m + 2
- if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
- return mc.getWarnings()
- }
- return nil
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -583,7 +513,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
if i == count {
return columns, nil
}
- return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
+ return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
}
// Catalog
@@ -600,20 +530,11 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
pos += n
// Table [len coded string]
- if mc.cfg.ColumnsWithAlias {
- tableName, _, n, err := readLengthEncodedString(data[pos:])
- if err != nil {
- return nil, err
- }
- pos += n
- columns[i].tableName = string(tableName)
- } else {
- n, err = skipLengthEncodedString(data[pos:])
- if err != nil {
- return nil, err
- }
- pos += n
+ n, err = skipLengthEncodedString(data[pos:])
+ if err != nil {
+ return nil, err
}
+ pos += n
// Original table [len coded string]
n, err = skipLengthEncodedString(data[pos:])
@@ -636,21 +557,20 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
return nil, err
}
- // Filler [uint8]
- // Charset [charset, collation uint8]
- // Length [uint32]
+ // Filler [1 byte]
+ // Charset [16 bit uint]
+ // Length [32 bit uint]
pos += n + 1 + 2 + 4
- // Field type [uint8]
+ // Field type [byte]
columns[i].fieldType = data[pos]
pos++
- // Flags [uint16]
+ // Flags [16 bit uint]
columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
- pos += 2
+ //pos += 2
- // Decimals [uint8]
- columns[i].decimals = data[pos]
+ // Decimals [8 bit uint]
//pos++
// Default value [len coded binary]
@@ -672,18 +592,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
- // server_status [2 bytes]
- rows.mc.status = readStatus(data[3:])
- if err := rows.mc.discardResults(); err != nil {
- return err
- }
- rows.mc = nil
return io.EOF
}
- if data[0] == iERR {
- rows.mc = nil
- return mc.handleErrorPacket(data)
- }
// RowSet Packet
var n int
@@ -704,7 +614,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
- mc.cfg.Loc,
+ mc.cfg.loc,
)
if err == nil {
continue
@@ -734,10 +644,6 @@ func (mc *mysqlConn) readUntilEOF() error {
if err == nil && data[0] != iEOF {
continue
}
- if err == nil && data[0] == iEOF && len(data) == 5 {
- mc.status = readStatus(data[3:])
- }
-
return err // Err or EOF
}
}
@@ -770,13 +676,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
// Warning count [16 bit uint]
if !stmt.mc.strict {
return columnCount, nil
+ } else {
+ // Check for warnings count > 0, only available in MySQL > 4.1
+ if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
+ return columnCount, stmt.mc.getWarnings()
+ }
+ return columnCount, nil
}
-
- // Check for warnings count > 0, only available in MySQL > 4.1
- if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
- return columnCount, stmt.mc.getWarnings()
- }
- return columnCount, nil
}
return 0, err
}
@@ -838,7 +744,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount {
return fmt.Errorf(
- "argument count mismatch (got: %d; has: %d)",
+ "Arguments count mismatch (Got: %d Has: %d)",
len(args),
stmt.paramCount,
)
@@ -1015,7 +921,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() {
val = []byte("0000-00-00")
} else {
- val = []byte(v.In(mc.cfg.Loc).Format(timeFormat))
+ val = []byte(v.In(mc.cfg.loc).Format(timeFormat))
}
paramValues = appendLengthEncodedInteger(paramValues,
@@ -1024,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
paramValues = append(paramValues, val...)
default:
- return fmt.Errorf("can not convert type: %T", arg)
+ return fmt.Errorf("Can't convert type: %T", arg)
}
}
@@ -1042,28 +948,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
return mc.writePacket(data)
}
-func (mc *mysqlConn) discardResults() error {
- for mc.status&statusMoreResultsExists != 0 {
- resLen, err := mc.readResultSetHeaderPacket()
- if err != nil {
- return err
- }
- if resLen > 0 {
- // columns
- if err := mc.readUntilEOF(); err != nil {
- return err
- }
- // rows
- if err := mc.readUntilEOF(); err != nil {
- return err
- }
- } else {
- mc.status &^= statusMoreResultsExists
- }
- }
- return nil
-}
-
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
func (rows *binaryRows) readRow(dest []driver.Value) error {
data, err := rows.mc.readPacket()
@@ -1075,14 +959,8 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
- rows.mc.status = readStatus(data[3:])
- if err := rows.mc.discardResults(); err != nil {
- return err
- }
- rows.mc = nil
return io.EOF
}
- rows.mc = nil
// Error otherwise
return rows.mc.handleErrorPacket(data)
@@ -1149,7 +1027,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue
case fieldTypeFloat:
- dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
+ dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
pos += 4
continue
@@ -1162,7 +1040,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
- fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
+ fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
var isNull bool
var n int
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
@@ -1177,53 +1055,88 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
}
return err
- case
- fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
- fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
- fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
-
+ // Date YYYY-MM-DD
+ case fieldTypeDate, fieldTypeNewDate:
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
- switch {
- case isNull:
+ if isNull {
dest[i] = nil
continue
- case rows.columns[i].fieldType == fieldTypeTime:
- // database/sql does not support an equivalent to TIME, return a string
- var dstlen uint8
- switch decimals := rows.columns[i].decimals; decimals {
- case 0x00, 0x1f:
- dstlen = 8
- case 1, 2, 3, 4, 5, 6:
- dstlen = 8 + 1 + decimals
- default:
- return fmt.Errorf(
- "protocol error, illegal decimals value %d",
- rows.columns[i].decimals,
- )
- }
- dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
- case rows.mc.parseTime:
- dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
- default:
- var dstlen uint8
- if rows.columns[i].fieldType == fieldTypeDate {
- dstlen = 10
+ }
+
+ if rows.mc.parseTime {
+ dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
+ } else {
+ dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], false)
+ }
+
+ if err == nil {
+ pos += int(num)
+ continue
+ } else {
+ return err
+ }
+
+ // Time [-][H]HH:MM:SS[.fractal]
+ case fieldTypeTime:
+ num, isNull, n := readLengthEncodedInteger(data[pos:])
+ pos += n
+
+ if num == 0 {
+ if isNull {
+ dest[i] = nil
+ continue
} else {
- switch decimals := rows.columns[i].decimals; decimals {
- case 0x00, 0x1f:
- dstlen = 19
- case 1, 2, 3, 4, 5, 6:
- dstlen = 19 + 1 + decimals
- default:
- return fmt.Errorf(
- "protocol error, illegal decimals value %d",
- rows.columns[i].decimals,
- )
- }
+ dest[i] = []byte("00:00:00")
+ continue
}
- dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false)
+ }
+
+ var sign string
+ if data[pos] == 1 {
+ sign = "-"
+ }
+
+ switch num {
+ case 8:
+ dest[i] = []byte(fmt.Sprintf(
+ sign+"%02d:%02d:%02d",
+ uint16(data[pos+1])*24+uint16(data[pos+5]),
+ data[pos+6],
+ data[pos+7],
+ ))
+ pos += 8
+ continue
+ case 12:
+ dest[i] = []byte(fmt.Sprintf(
+ sign+"%02d:%02d:%02d.%06d",
+ uint16(data[pos+1])*24+uint16(data[pos+5]),
+ data[pos+6],
+ data[pos+7],
+ binary.LittleEndian.Uint32(data[pos+8:pos+12]),
+ ))
+ pos += 12
+ continue
+ default:
+ return fmt.Errorf("Invalid TIME-packet length %d", num)
+ }
+
+ // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
+ case fieldTypeTimestamp, fieldTypeDateTime:
+ num, isNull, n := readLengthEncodedInteger(data[pos:])
+
+ pos += n
+
+ if isNull {
+ dest[i] = nil
+ continue
+ }
+
+ if rows.mc.parseTime {
+ dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
+ } else {
+ dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], true)
}
if err == nil {
@@ -1235,7 +1148,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
// Please report if this happens!
default:
- return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
+ return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType)
}
}