summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-sql-driver/mysql/statement.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/statement.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/statement.go55
1 files changed, 25 insertions, 30 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go
index e5071276a..7f9b04585 100644
--- a/vendor/github.com/go-sql-driver/mysql/statement.go
+++ b/vendor/github.com/go-sql-driver/mysql/statement.go
@@ -11,7 +11,6 @@ package mysql
import (
"database/sql/driver"
"fmt"
- "io"
"reflect"
"strconv"
)
@@ -20,6 +19,7 @@ type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
+ columns []mysqlField // cached from the first query
}
func (stmt *mysqlStmt) Close() error {
@@ -62,30 +62,26 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
- if err != nil {
- return nil, err
- }
-
- if resLen > 0 {
- // Columns
- if err = mc.readUntilEOF(); err != nil {
- return nil, err
+ if err == nil {
+ if resLen > 0 {
+ // Columns
+ err = mc.readUntilEOF()
+ if err != nil {
+ return nil, err
+ }
+
+ // Rows
+ err = mc.readUntilEOF()
}
-
- // Rows
- if err := mc.readUntilEOF(); err != nil {
- return nil, err
+ if err == nil {
+ return &mysqlResult{
+ affectedRows: int64(mc.affectedRows),
+ insertId: int64(mc.insertId),
+ }, nil
}
}
- if err := mc.discardResults(); err != nil {
- return nil, err
- }
-
- return &mysqlResult{
- affectedRows: int64(mc.affectedRows),
- insertId: int64(mc.insertId),
- }, nil
+ return nil, err
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
@@ -111,15 +107,14 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
if resLen > 0 {
rows.mc = mc
- rows.rs.columns, err = mc.readColumns(resLen)
- } else {
- rows.rs.done = true
-
- switch err := rows.NextResultSet(); err {
- case nil, io.EOF:
- return rows, nil
- default:
- return nil, err
+ // Columns
+ // If not cached, read them and cache them
+ if stmt.columns == nil {
+ rows.columns, err = mc.readColumns(resLen)
+ stmt.columns = rows.columns
+ } else {
+ rows.columns = stmt.columns
+ err = mc.readUntilEOF()
}
}