summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-sql-driver/mysql/packets.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/packets.go77
1 files changed, 41 insertions, 36 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go
index cb21397a2..aafe9793e 100644
--- a/vendor/github.com/go-sql-driver/mysql/packets.go
+++ b/vendor/github.com/go-sql-driver/mysql/packets.go
@@ -486,24 +486,23 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
plugin := string(data[1:pluginEndIndex])
cipher := data[pluginEndIndex+1 : len(data)-1]
- switch plugin {
- case "mysql_old_password":
+ if plugin == "mysql_old_password" {
// using old_passwords
return cipher, ErrOldPassword
- case "mysql_clear_password":
+ } else if plugin == "mysql_clear_password" {
// using clear text password
return cipher, ErrCleartextPassword
- case "mysql_native_password":
+ } else if plugin == "mysql_native_password" {
// using mysql default authentication method
return cipher, ErrNativePassword
- default:
+ } else {
return cipher, ErrUnknownPlugin
}
+ } else {
+ // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
+ return nil, ErrOldPassword
}
- // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
- return nil, ErrOldPassword
-
default: // Error otherwise
return nil, mc.handleErrorPacket(data)
}
@@ -585,8 +584,8 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
- if mc.status&statusMoreResultsExists != 0 {
- return nil
+ if err := mc.discardResults(); err != nil {
+ return err
}
// warning count [2 bytes]
@@ -699,10 +698,6 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
- if rows.rs.done {
- return io.EOF
- }
-
data, err := mc.readPacket()
if err != nil {
return err
@@ -712,11 +707,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
- rows.rs.done = true
- if !rows.HasNextResultSet() {
- rows.mc = nil
+ err = rows.mc.discardResults()
+ if err == nil {
+ err = io.EOF
+ } else {
+ // connection unusable
+ rows.mc.Close()
}
- return io.EOF
+ rows.mc = nil
+ return err
}
if data[0] == iERR {
rows.mc = nil
@@ -737,7 +736,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
if !mc.parseTime {
continue
} else {
- switch rows.rs.columns[i].fieldType {
+ switch rows.columns[i].fieldType {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
@@ -1098,6 +1097,8 @@ func (mc *mysqlConn) discardResults() error {
if err := mc.readUntilEOF(); err != nil {
return err
}
+ } else {
+ mc.status &^= statusMoreResultsExists
}
}
return nil
@@ -1115,11 +1116,15 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
- rows.rs.done = true
- if !rows.HasNextResultSet() {
- rows.mc = nil
+ err = rows.mc.discardResults()
+ if err == nil {
+ err = io.EOF
+ } else {
+ // connection unusable
+ rows.mc.Close()
}
- return io.EOF
+ rows.mc = nil
+ return err
}
rows.mc = nil
@@ -1140,14 +1145,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
}
// Convert to byte-coded string
- switch rows.rs.columns[i].fieldType {
+ switch rows.columns[i].fieldType {
case fieldTypeNULL:
dest[i] = nil
continue
// Numeric Types
case fieldTypeTiny:
- if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(data[pos])
} else {
dest[i] = int64(int8(data[pos]))
@@ -1156,7 +1161,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue
case fieldTypeShort, fieldTypeYear:
- if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
} else {
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
@@ -1165,7 +1170,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue
case fieldTypeInt24, fieldTypeLong:
- if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
} else {
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
@@ -1174,7 +1179,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue
case fieldTypeLongLong:
- if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ if rows.columns[i].flags&flagUnsigned != 0 {
val := binary.LittleEndian.Uint64(data[pos : pos+8])
if val > math.MaxInt64 {
dest[i] = uint64ToString(val)
@@ -1188,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue
case fieldTypeFloat:
- dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
+ dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
pos += 4
continue
@@ -1228,10 +1233,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case isNull:
dest[i] = nil
continue
- case rows.rs.columns[i].fieldType == fieldTypeTime:
+ case rows.columns[i].fieldType == fieldTypeTime:
// database/sql does not support an equivalent to TIME, return a string
var dstlen uint8
- switch decimals := rows.rs.columns[i].decimals; decimals {
+ switch decimals := rows.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 8
case 1, 2, 3, 4, 5, 6:
@@ -1239,7 +1244,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
- rows.rs.columns[i].decimals,
+ rows.columns[i].decimals,
)
}
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
@@ -1247,10 +1252,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
default:
var dstlen uint8
- if rows.rs.columns[i].fieldType == fieldTypeDate {
+ if rows.columns[i].fieldType == fieldTypeDate {
dstlen = 10
} else {
- switch decimals := rows.rs.columns[i].decimals; decimals {
+ switch decimals := rows.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 19
case 1, 2, 3, 4, 5, 6:
@@ -1258,7 +1263,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
- rows.rs.columns[i].decimals,
+ rows.columns[i].decimals,
)
}
}
@@ -1274,7 +1279,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
// Please report if this happens!
default:
- return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
+ return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
}
}