summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-sql-driver/mysql/statement.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/statement.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/statement.go45
1 files changed, 43 insertions, 2 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go
index 142ef5416..7f9b04585 100644
--- a/vendor/github.com/go-sql-driver/mysql/statement.go
+++ b/vendor/github.com/go-sql-driver/mysql/statement.go
@@ -10,6 +10,9 @@ package mysql
import (
"database/sql/driver"
+ "fmt"
+ "reflect"
+ "strconv"
)
type mysqlStmt struct {
@@ -21,7 +24,10 @@ type mysqlStmt struct {
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.netConn == nil {
- errLog.Print(ErrInvalidConn)
+ // driver.Stmt.Close can be called more than once, thus this function
+ // has to be idempotent.
+ // See also Issue #450 and golang/go#16019.
+ //errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
@@ -34,6 +40,10 @@ func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
+func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
+ return converter{}
+}
+
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
@@ -94,9 +104,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
}
rows := new(binaryRows)
- rows.mc = mc
if resLen > 0 {
+ rows.mc = mc
// Columns
// If not cached, read them and cache them
if stmt.columns == nil {
@@ -110,3 +120,34 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return rows, err
}
+
+type converter struct{}
+
+func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
+ if driver.IsValue(v) {
+ return v, nil
+ }
+
+ rv := reflect.ValueOf(v)
+ switch rv.Kind() {
+ case reflect.Ptr:
+ // indirect pointers
+ if rv.IsNil() {
+ return nil, nil
+ }
+ return c.ConvertValue(rv.Elem().Interface())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return rv.Int(), nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
+ return int64(rv.Uint()), nil
+ case reflect.Uint64:
+ u64 := rv.Uint()
+ if u64 >= 1<<63 {
+ return strconv.FormatUint(u64, 10), nil
+ }
+ return int64(u64), nil
+ case reflect.Float32, reflect.Float64:
+ return rv.Float(), nil
+ }
+ return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
+}