diff options
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/statement.go')
-rw-r--r-- | vendor/github.com/go-sql-driver/mysql/statement.go | 55 |
1 files changed, 25 insertions, 30 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index e5071276a..7f9b04585 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -11,7 +11,6 @@ package mysql import ( "database/sql/driver" "fmt" - "io" "reflect" "strconv" ) @@ -20,6 +19,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int + columns []mysqlField // cached from the first query } func (stmt *mysqlStmt) Close() error { @@ -62,30 +62,26 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err != nil { - return nil, err - } - - if resLen > 0 { - // Columns - if err = mc.readUntilEOF(); err != nil { - return nil, err + if err == nil { + if resLen > 0 { + // Columns + err = mc.readUntilEOF() + if err != nil { + return nil, err + } + + // Rows + err = mc.readUntilEOF() } - - // Rows - if err := mc.readUntilEOF(); err != nil { - return nil, err + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil } } - if err := mc.discardResults(); err != nil { - return nil, err - } - - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil + return nil, err } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { @@ -111,15 +107,14 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { if resLen > 0 { rows.mc = mc - rows.rs.columns, err = mc.readColumns(resLen) - } else { - rows.rs.done = true - - switch err := rows.NextResultSet(); err { - case nil, io.EOF: - return rows, nil - default: - return nil, err + // Columns + // If not cached, read them and cache them + if stmt.columns == nil { + rows.columns, err = mc.readColumns(resLen) + stmt.columns = rows.columns + } else { + rows.columns = stmt.columns + err = mc.readUntilEOF() } } |