From 42f28ab8e374137fe3f5d25424489d879d4724f8 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Wed, 21 Jun 2017 19:06:17 -0700 Subject: Updating server dependancies (#6712) --- vendor/github.com/go-sql-driver/mysql/packets.go | 77 +++++++++++++----------- 1 file changed, 41 insertions(+), 36 deletions(-) (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go') diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index cb21397a2..aafe9793e 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -486,24 +486,23 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) { plugin := string(data[1:pluginEndIndex]) cipher := data[pluginEndIndex+1 : len(data)-1] - switch plugin { - case "mysql_old_password": + if plugin == "mysql_old_password" { // using old_passwords return cipher, ErrOldPassword - case "mysql_clear_password": + } else if plugin == "mysql_clear_password" { // using clear text password return cipher, ErrCleartextPassword - case "mysql_native_password": + } else if plugin == "mysql_native_password" { // using mysql default authentication method return cipher, ErrNativePassword - default: + } else { return cipher, ErrUnknownPlugin } + } else { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, ErrOldPassword } - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest - return nil, ErrOldPassword - default: // Error otherwise return nil, mc.handleErrorPacket(data) } @@ -585,8 +584,8 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if mc.status&statusMoreResultsExists != 0 { - return nil + if err := mc.discardResults(); err != nil { + return err } // warning count [2 bytes] @@ -699,10 +698,6 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc - if rows.rs.done { - return io.EOF - } - data, err := mc.readPacket() if err != nil { return err @@ -712,11 +707,15 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - rows.rs.done = true - if !rows.HasNextResultSet() { - rows.mc = nil + err = rows.mc.discardResults() + if err == nil { + err = io.EOF + } else { + // connection unusable + rows.mc.Close() } - return io.EOF + rows.mc = nil + return err } if data[0] == iERR { rows.mc = nil @@ -737,7 +736,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { if !mc.parseTime { continue } else { - switch rows.rs.columns[i].fieldType { + switch rows.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( @@ -1098,6 +1097,8 @@ func (mc *mysqlConn) discardResults() error { if err := mc.readUntilEOF(); err != nil { return err } + } else { + mc.status &^= statusMoreResultsExists } } return nil @@ -1115,11 +1116,15 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) - rows.rs.done = true - if !rows.HasNextResultSet() { - rows.mc = nil + err = rows.mc.discardResults() + if err == nil { + err = io.EOF + } else { + // connection unusable + rows.mc.Close() } - return io.EOF + rows.mc = nil + return err } rows.mc = nil @@ -1140,14 +1145,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } // Convert to byte-coded string - switch rows.rs.columns[i].fieldType { + switch rows.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: - if rows.rs.columns[i].flags&flagUnsigned != 0 { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1156,7 +1161,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeShort, fieldTypeYear: - if rows.rs.columns[i].flags&flagUnsigned != 0 { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1165,7 +1170,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeInt24, fieldTypeLong: - if rows.rs.columns[i].flags&flagUnsigned != 0 { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1174,7 +1179,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeLongLong: - if rows.rs.columns[i].flags&flagUnsigned != 0 { + if rows.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1188,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeFloat: - dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) + dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) pos += 4 continue @@ -1228,10 +1233,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case isNull: dest[i] = nil continue - case rows.rs.columns[i].fieldType == fieldTypeTime: + case rows.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 - switch decimals := rows.rs.columns[i].decimals; decimals { + switch decimals := rows.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: @@ -1239,7 +1244,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.rs.columns[i].decimals, + rows.columns[i].decimals, ) } dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) @@ -1247,10 +1252,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 - if rows.rs.columns[i].fieldType == fieldTypeDate { + if rows.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { - switch decimals := rows.rs.columns[i].decimals; decimals { + switch decimals := rows.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: @@ -1258,7 +1263,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.rs.columns[i].decimals, + rows.columns[i].decimals, ) } } @@ -1274,7 +1279,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) } } -- cgit v1.2.3-1-g7c22