From 56e74239d6b34df8f30ef046f0b0ff4ff0866a71 Mon Sep 17 00:00:00 2001 From: =Corey Hulen Date: Sun, 14 Jun 2015 23:53:32 -0800 Subject: first commit --- .../src/github.com/go-gorp/gorp/.gitignore | 8 + .../src/github.com/go-gorp/gorp/.travis.yml | 23 + .../_workspace/src/github.com/go-gorp/gorp/LICENSE | 22 + .../src/github.com/go-gorp/gorp/Makefile | 6 + .../src/github.com/go-gorp/gorp/dialect.go | 696 +++++++ .../src/github.com/go-gorp/gorp/errors.go | 26 + .../_workspace/src/github.com/go-gorp/gorp/gorp.go | 2178 ++++++++++++++++++++ .../src/github.com/go-gorp/gorp/gorp_test.go | 2170 +++++++++++++++++++ .../src/github.com/go-gorp/gorp/test_all.sh | 22 + 9 files changed, 5151 insertions(+) create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/.gitignore create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/.travis.yml create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/LICENSE create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/Makefile create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/dialect.go create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/errors.go create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/gorp.go create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/gorp_test.go create mode 100644 Godeps/_workspace/src/github.com/go-gorp/gorp/test_all.sh (limited to 'Godeps/_workspace/src/github.com/go-gorp') diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/.gitignore b/Godeps/_workspace/src/github.com/go-gorp/gorp/.gitignore new file mode 100644 index 000000000..8a06adea5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/.gitignore @@ -0,0 +1,8 @@ +_test +_testmain.go +_obj +*~ +*.6 +6.out +gorptest.bin +tmp diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/.travis.yml b/Godeps/_workspace/src/github.com/go-gorp/gorp/.travis.yml new file mode 100644 index 000000000..6df5edf1c --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/.travis.yml @@ -0,0 +1,23 @@ +language: go +go: + - 1.2 + - 1.3 + - 1.4 + - tip + +services: + - mysql + - postgres + - sqlite3 + +before_script: + - mysql -e "CREATE DATABASE gorptest;" + - mysql -u root -e "GRANT ALL ON gorptest.* TO gorptest@localhost IDENTIFIED BY 'gorptest'" + - psql -c "CREATE DATABASE gorptest;" -U postgres + - psql -c "CREATE USER "gorptest" WITH SUPERUSER PASSWORD 'gorptest';" -U postgres + - go get github.com/lib/pq + - go get github.com/mattn/go-sqlite3 + - go get github.com/ziutek/mymysql/godrv + - go get github.com/go-sql-driver/mysql + +script: ./test_all.sh diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/LICENSE b/Godeps/_workspace/src/github.com/go-gorp/gorp/LICENSE new file mode 100644 index 000000000..b661111d0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/LICENSE @@ -0,0 +1,22 @@ +(The MIT License) + +Copyright (c) 2012 James Cooper + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +'Software'), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/Makefile b/Godeps/_workspace/src/github.com/go-gorp/gorp/Makefile new file mode 100644 index 000000000..3a27ae194 --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/Makefile @@ -0,0 +1,6 @@ +include $(GOROOT)/src/Make.inc + +TARG = github.com/go-gorp/gorp +GOFILES = gorp.go dialect.go + +include $(GOROOT)/src/Make.pkg \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/dialect.go b/Godeps/_workspace/src/github.com/go-gorp/gorp/dialect.go new file mode 100644 index 000000000..8277a965e --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/dialect.go @@ -0,0 +1,696 @@ +package gorp + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +// The Dialect interface encapsulates behaviors that differ across +// SQL databases. At present the Dialect is only used by CreateTables() +// but this could change in the future +type Dialect interface { + + // adds a suffix to any query, usually ";" + QuerySuffix() string + + // ToSqlType returns the SQL column type to use when creating a + // table of the given Go Type. maxsize can be used to switch based on + // size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB, + // or LONGBLOB depending on the maxsize + ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string + + // string to append to primary key column definitions + AutoIncrStr() string + + // string to bind autoincrement columns to. Empty string will + // remove reference to those columns in the INSERT statement. + AutoIncrBindValue() string + + AutoIncrInsertSuffix(col *ColumnMap) string + + // string to append to "create table" statement for vendor specific + // table attributes + CreateTableSuffix() string + + // string to truncate tables + TruncateClause() string + + // bind variable string to use when forming SQL statements + // in many dbs it is "?", but Postgres appears to use $1 + // + // i is a zero based index of the bind variable in this statement + // + BindVar(i int) string + + // Handles quoting of a field name to ensure that it doesn't raise any + // SQL parsing exceptions by using a reserved word as a field name. + QuoteField(field string) string + + // Handles building up of a schema.database string that is compatible with + // the given dialect + // + // schema - The schema that lives in + // table - The table name + QuotedTableForQuery(schema string, table string) string + + // Existance clause for table creation / deletion + IfSchemaNotExists(command, schema string) string + IfTableExists(command, schema, table string) string + IfTableNotExists(command, schema, table string) string +} + +// IntegerAutoIncrInserter is implemented by dialects that can perform +// inserts with automatically incremented integer primary keys. If +// the dialect can handle automatic assignment of more than just +// integers, see TargetedAutoIncrInserter. +type IntegerAutoIncrInserter interface { + InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) +} + +// TargetedAutoIncrInserter is implemented by dialects that can +// perform automatic assignment of any primary key type (i.e. strings +// for uuids, integers for serials, etc). +type TargetedAutoIncrInserter interface { + // InsertAutoIncrToTarget runs an insert operation and assigns the + // automatically generated primary key directly to the passed in + // target. The target should be a pointer to the primary key + // field of the value being inserted. + InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error +} + +func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + res, err := exec.Exec(insertSql, params...) + if err != nil { + return 0, err + } + return res.LastInsertId() +} + +/////////////////////////////////////////////////////// +// sqlite3 // +///////////// + +type SqliteDialect struct { + suffix string +} + +func (d SqliteDialect) QuerySuffix() string { return ";" } + +func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "integer" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return "integer" + case reflect.Float64, reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "blob" + } + } + + switch val.Name() { + case "NullInt64": + return "integer" + case "NullFloat64": + return "real" + case "NullBool": + return "integer" + case "Time": + return "datetime" + } + + if maxsize < 1 { + maxsize = 255 + } + return fmt.Sprintf("varchar(%d)", maxsize) +} + +// Returns autoincrement +func (d SqliteDialect) AutoIncrStr() string { + return "autoincrement" +} + +func (d SqliteDialect) AutoIncrBindValue() string { + return "null" +} + +func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns suffix +func (d SqliteDialect) CreateTableSuffix() string { + return d.suffix +} + +// With sqlite, there technically isn't a TRUNCATE statement, +// but a DELETE FROM uses a truncate optimization: +// http://www.sqlite.org/lang_delete.html +func (d SqliteDialect) TruncateClause() string { + return "delete from" +} + +// Returns "?" +func (d SqliteDialect) BindVar(i int) string { + return "?" +} + +func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + return standardInsertAutoIncr(exec, insertSql, params...) +} + +func (d SqliteDialect) QuoteField(f string) string { + return `"` + f + `"` +} + +// sqlite does not have schemas like PostgreSQL does, so just escape it like normal +func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string { + return d.QuoteField(table) +} + +func (d SqliteDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d SqliteDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d SqliteDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} + +/////////////////////////////////////////////////////// +// PostgreSQL // +//////////////// + +type PostgresDialect struct { + suffix string +} + +func (d PostgresDialect) QuerySuffix() string { return ";" } + +func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if isAutoIncr { + return "serial" + } + return "integer" + case reflect.Int64, reflect.Uint64: + if isAutoIncr { + return "bigserial" + } + return "bigint" + case reflect.Float64: + return "double precision" + case reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "bytea" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double precision" + case "NullBool": + return "boolean" + case "Time", "NullTime": + return "timestamp with time zone" + } + + if maxsize > 0 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } + +} + +// Returns empty string +func (d PostgresDialect) AutoIncrStr() string { + return "" +} + +func (d PostgresDialect) AutoIncrBindValue() string { + return "default" +} + +func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return " returning " + col.ColumnName +} + +// Returns suffix +func (d PostgresDialect) CreateTableSuffix() string { + return d.suffix +} + +func (d PostgresDialect) TruncateClause() string { + return "truncate" +} + +// Returns "$(i+1)" +func (d PostgresDialect) BindVar(i int) string { + return fmt.Sprintf("$%d", i+1) +} + +func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { + rows, err := exec.query(insertSql, params...) + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err()) + } + if err := rows.Scan(target); err != nil { + return err + } + if rows.Next() { + return fmt.Errorf("more than two serial value returned for insert: %s", insertSql) + } + return rows.Err() +} + +func (d PostgresDialect) QuoteField(f string) string { + return `"` + strings.ToLower(f) + `"` +} + +func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) +} + +func (d PostgresDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d PostgresDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d PostgresDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} + +/////////////////////////////////////////////////////// +// MySQL // +/////////// + +// Implementation of Dialect for MySQL databases. +type MySQLDialect struct { + + // Engine is the storage engine to use "InnoDB" vs "MyISAM" for example + Engine string + + // Encoding is the character encoding to use for created tables + Encoding string +} + +func (d MySQLDialect) QuerySuffix() string { return ";" } + +func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int8: + return "tinyint" + case reflect.Uint8: + return "tinyint unsigned" + case reflect.Int16: + return "smallint" + case reflect.Uint16: + return "smallint unsigned" + case reflect.Int, reflect.Int32: + return "int" + case reflect.Uint, reflect.Uint32: + return "int unsigned" + case reflect.Int64: + return "bigint" + case reflect.Uint64: + return "bigint unsigned" + case reflect.Float64, reflect.Float32: + return "double" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "mediumblob" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double" + case "NullBool": + return "tinyint" + case "Time": + return "datetime" + } + + if maxsize < 1 { + maxsize = 255 + } + return fmt.Sprintf("varchar(%d)", maxsize) +} + +// Returns auto_increment +func (d MySQLDialect) AutoIncrStr() string { + return "auto_increment" +} + +func (d MySQLDialect) AutoIncrBindValue() string { + return "null" +} + +func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns engine=%s charset=%s based on values stored on struct +func (d MySQLDialect) CreateTableSuffix() string { + if d.Engine == "" || d.Encoding == "" { + msg := "gorp - undefined" + + if d.Engine == "" { + msg += " MySQLDialect.Engine" + } + if d.Engine == "" && d.Encoding == "" { + msg += "," + } + if d.Encoding == "" { + msg += " MySQLDialect.Encoding" + } + msg += ". Check that your MySQLDialect was correctly initialized when declared." + panic(msg) + } + + return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding) +} + +func (d MySQLDialect) TruncateClause() string { + return "truncate" +} + +// Returns "?" +func (d MySQLDialect) BindVar(i int) string { + return "?" +} + +func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + return standardInsertAutoIncr(exec, insertSql, params...) +} + +func (d MySQLDialect) QuoteField(f string) string { + return "`" + f + "`" +} + +func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) +} + +func (d MySQLDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d MySQLDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d MySQLDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} + +/////////////////////////////////////////////////////// +// Sql Server // +//////////////// + +// Implementation of Dialect for Microsoft SQL Server databases. +// Tested on SQL Server 2008 with driver: github.com/denisenkom/go-mssqldb + +type SqlServerDialect struct { + suffix string +} + +func (d SqlServerDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "bit" + case reflect.Int8: + return "tinyint" + case reflect.Uint8: + return "smallint" + case reflect.Int16: + return "smallint" + case reflect.Uint16: + return "int" + case reflect.Int, reflect.Int32: + return "int" + case reflect.Uint, reflect.Uint32: + return "bigint" + case reflect.Int64: + return "bigint" + case reflect.Uint64: + return "bigint" + case reflect.Float32: + return "real" + case reflect.Float64: + return "float(53)" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "varbinary" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "float(53)" + case "NullBool": + return "tinyint" + case "Time": + return "datetime" + } + + if maxsize < 1 { + maxsize = 255 + } + return fmt.Sprintf("varchar(%d)", maxsize) +} + +// Returns auto_increment +func (d SqlServerDialect) AutoIncrStr() string { + return "identity(0,1)" +} + +// Empty string removes autoincrement columns from the INSERT statements. +func (d SqlServerDialect) AutoIncrBindValue() string { + return "" +} + +func (d SqlServerDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns suffix +func (d SqlServerDialect) CreateTableSuffix() string { + + return d.suffix +} + +func (d SqlServerDialect) TruncateClause() string { + return "delete from" +} + +// Returns "?" +func (d SqlServerDialect) BindVar(i int) string { + return "?" +} + +func (d SqlServerDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + return standardInsertAutoIncr(exec, insertSql, params...) +} + +func (d SqlServerDialect) QuoteField(f string) string { + return `"` + f + `"` +} + +func (d SqlServerDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return table + } + return schema + "." + table +} + +func (d SqlServerDialect) QuerySuffix() string { return ";" } + +func (d SqlServerDialect) IfSchemaNotExists(command, schema string) string { + s := fmt.Sprintf("if not exists (select name from sys.schemas where name = '%s') %s", schema, command) + return s +} + +func (d SqlServerDialect) IfTableExists(command, schema, table string) string { + var schema_clause string + if strings.TrimSpace(schema) != "" { + schema_clause = fmt.Sprintf("table_schema = '%s' and ", schema) + } + s := fmt.Sprintf("if exists (select * from information_schema.tables where %stable_name = '%s') %s", schema_clause, table, command) + return s +} + +func (d SqlServerDialect) IfTableNotExists(command, schema, table string) string { + var schema_clause string + if strings.TrimSpace(schema) != "" { + schema_clause = fmt.Sprintf("table_schema = '%s' and ", schema) + } + s := fmt.Sprintf("if not exists (select * from information_schema.tables where %stable_name = '%s') %s", schema_clause, table, command) + return s +} + +/////////////////////////////////////////////////////// +// Oracle // +/////////// + +// Implementation of Dialect for Oracle databases. +type OracleDialect struct{} + +func (d OracleDialect) QuerySuffix() string { return "" } + +func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if isAutoIncr { + return "serial" + } + return "integer" + case reflect.Int64, reflect.Uint64: + if isAutoIncr { + return "bigserial" + } + return "bigint" + case reflect.Float64: + return "double precision" + case reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "bytea" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double precision" + case "NullBool": + return "boolean" + case "NullTime", "Time": + return "timestamp with time zone" + } + + if maxsize > 0 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } + +} + +// Returns empty string +func (d OracleDialect) AutoIncrStr() string { + return "" +} + +func (d OracleDialect) AutoIncrBindValue() string { + return "default" +} + +func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return " returning " + col.ColumnName +} + +// Returns suffix +func (d OracleDialect) CreateTableSuffix() string { + return "" +} + +func (d OracleDialect) TruncateClause() string { + return "truncate" +} + +// Returns "$(i+1)" +func (d OracleDialect) BindVar(i int) string { + return fmt.Sprintf(":%d", i+1) +} + +func (d OracleDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + rows, err := exec.query(insertSql, params...) + if err != nil { + return 0, err + } + defer rows.Close() + + if rows.Next() { + var id int64 + err := rows.Scan(&id) + return id, err + } + + return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) +} + +func (d OracleDialect) QuoteField(f string) string { + return `"` + strings.ToUpper(f) + `"` +} + +func (d OracleDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) +} + +func (d OracleDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d OracleDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d OracleDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/errors.go b/Godeps/_workspace/src/github.com/go-gorp/gorp/errors.go new file mode 100644 index 000000000..356d68475 --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/errors.go @@ -0,0 +1,26 @@ +package gorp + +import ( + "fmt" +) + +// A non-fatal error, when a select query returns columns that do not exist +// as fields in the struct it is being mapped to +type NoFieldInTypeError struct { + TypeName string + MissingColNames []string +} + +func (err *NoFieldInTypeError) Error() string { + return fmt.Sprintf("gorp: No fields %+v in type %s", err.MissingColNames, err.TypeName) +} + +// returns true if the error is non-fatal (ie, we shouldn't immediately return) +func NonFatalError(err error) bool { + switch err.(type) { + case *NoFieldInTypeError: + return true + default: + return false + } +} diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/gorp.go b/Godeps/_workspace/src/github.com/go-gorp/gorp/gorp.go new file mode 100644 index 000000000..4c91b6f78 --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/gorp.go @@ -0,0 +1,2178 @@ +// Copyright 2012 James Cooper. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// Package gorp provides a simple way to marshal Go structs to and from +// SQL databases. It uses the database/sql package, and should work with any +// compliant database/sql driver. +// +// Source code and project home: +// https://github.com/go-gorp/gorp +// +package gorp + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "regexp" + "strings" + "time" +) + +// Oracle String (empty string is null) +type OracleString struct { + sql.NullString +} + +// Scan implements the Scanner interface. +func (os *OracleString) Scan(value interface{}) error { + if value == nil { + os.String, os.Valid = "", false + return nil + } + os.Valid = true + return os.NullString.Scan(value) +} + +// Value implements the driver Valuer interface. +func (os OracleString) Value() (driver.Value, error) { + if !os.Valid || os.String == "" { + return nil, nil + } + return os.String, nil +} + +// A nullable Time value +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (nt *NullTime) Scan(value interface{}) error { + switch t := value.(type) { + case time.Time: + nt.Time, nt.Valid = t, true + case []byte: + nt.Valid = false + for _, dtfmt := range []string{ + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + "2006-01-02 15:04", + "2006-01-02T15:04", + "2006-01-02", + "2006-01-02 15:04:05-07:00", + } { + var err error + if nt.Time, err = time.Parse(dtfmt, string(t)); err == nil { + nt.Valid = true + break + } + } + } + return nil +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + +var zeroVal reflect.Value +var versFieldConst = "[gorp_ver_field]" + +// OptimisticLockError is returned by Update() or Delete() if the +// struct being modified has a Version field and the value is not equal to +// the current value in the database +type OptimisticLockError struct { + // Table name where the lock error occurred + TableName string + + // Primary key values of the row being updated/deleted + Keys []interface{} + + // true if a row was found with those keys, indicating the + // LocalVersion is stale. false if no value was found with those + // keys, suggesting the row has been deleted since loaded, or + // was never inserted to begin with + RowExists bool + + // Version value on the struct passed to Update/Delete. This value is + // out of sync with the database. + LocalVersion int64 +} + +// Error returns a description of the cause of the lock error +func (e OptimisticLockError) Error() string { + if e.RowExists { + return fmt.Sprintf("gorp: OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion) + } + + return fmt.Sprintf("gorp: OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys) +} + +// The TypeConverter interface provides a way to map a value of one +// type to another type when persisting to, or loading from, a database. +// +// Example use cases: Implement type converter to convert bool types to "y"/"n" strings, +// or serialize a struct member as a JSON blob. +type TypeConverter interface { + // ToDb converts val to another type. Called before INSERT/UPDATE operations + ToDb(val interface{}) (interface{}, error) + + // FromDb returns a CustomScanner appropriate for this type. This will be used + // to hold values returned from SELECT queries. + // + // In particular the CustomScanner returned should implement a Binder + // function appropriate for the Go type you wish to convert the db value to + // + // If bool==false, then no custom scanner will be used for this field. + FromDb(target interface{}) (CustomScanner, bool) +} + +// CustomScanner binds a database column value to a Go type +type CustomScanner struct { + // After a row is scanned, Holder will contain the value from the database column. + // Initialize the CustomScanner with the concrete Go type you wish the database + // driver to scan the raw column into. + Holder interface{} + // Target typically holds a pointer to the target struct field to bind the Holder + // value to. + Target interface{} + // Binder is a custom function that converts the holder value to the target type + // and sets target accordingly. This function should return error if a problem + // occurs converting the holder to the target. + Binder func(holder interface{}, target interface{}) error +} + +// Bind is called automatically by gorp after Scan() +func (me CustomScanner) Bind() error { + return me.Binder(me.Holder, me.Target) +} + +// DbMap is the root gorp mapping object. Create one of these for each +// database schema you wish to map. Each DbMap contains a list of +// mapped tables. +// +// Example: +// +// dialect := gorp.MySQLDialect{"InnoDB", "UTF8"} +// dbmap := &gorp.DbMap{Db: db, Dialect: dialect} +// +type DbMap struct { + // Db handle to use with this map + Db *sql.DB + + // Dialect implementation to use with this map + Dialect Dialect + + TypeConverter TypeConverter + + tables []*TableMap + logger GorpLogger + logPrefix string +} + +// TableMap represents a mapping between a Go struct and a database table +// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these +type TableMap struct { + // Name of database table. + TableName string + SchemaName string + gotype reflect.Type + Columns []*ColumnMap + keys []*ColumnMap + uniqueTogether [][]string + version *ColumnMap + insertPlan bindPlan + updatePlan bindPlan + deletePlan bindPlan + getPlan bindPlan + dbmap *DbMap +} + +// ResetSql removes cached insert/update/select/delete SQL strings +// associated with this TableMap. Call this if you've modified +// any column names or the table name itself. +func (t *TableMap) ResetSql() { + t.insertPlan = bindPlan{} + t.updatePlan = bindPlan{} + t.deletePlan = bindPlan{} + t.getPlan = bindPlan{} +} + +// SetKeys lets you specify the fields on a struct that map to primary +// key columns on the table. If isAutoIncr is set, result.LastInsertId() +// will be used after INSERT to bind the generated id to the Go struct. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +// +// Panics if isAutoIncr is true, and fieldNames length != 1 +// +func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap { + if isAutoIncr && len(fieldNames) != 1 { + panic(fmt.Sprintf( + "gorp: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)", + len(fieldNames))) + } + t.keys = make([]*ColumnMap, 0) + for _, name := range fieldNames { + colmap := t.ColMap(name) + colmap.isPK = true + colmap.isAutoIncr = isAutoIncr + t.keys = append(t.keys, colmap) + } + t.ResetSql() + + return t +} + +// SetUniqueTogether lets you specify uniqueness constraints across multiple +// columns on the table. Each call adds an additional constraint for the +// specified columns. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +// +// Panics if fieldNames length < 2. +// +func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap { + if len(fieldNames) < 2 { + panic(fmt.Sprintf( + "gorp: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint.")) + } + + columns := make([]string, 0) + for _, name := range fieldNames { + columns = append(columns, name) + } + t.uniqueTogether = append(t.uniqueTogether, columns) + t.ResetSql() + + return t +} + +// ColMap returns the ColumnMap pointer matching the given struct field +// name. It panics if the struct does not contain a field matching this +// name. +func (t *TableMap) ColMap(field string) *ColumnMap { + col := colMapOrNil(t, field) + if col == nil { + e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s", + t.TableName, t.gotype.Name(), field) + + panic(e) + } + return col +} + +func colMapOrNil(t *TableMap, field string) *ColumnMap { + for _, col := range t.Columns { + if col.fieldName == field || col.ColumnName == field { + return col + } + } + return nil +} + +// SetVersionCol sets the column to use as the Version field. By default +// the "Version" field is used. Returns the column found, or panics +// if the struct does not contain a field matching this name. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +func (t *TableMap) SetVersionCol(field string) *ColumnMap { + c := t.ColMap(field) + t.version = c + t.ResetSql() + return c +} + +type bindPlan struct { + query string + argFields []string + keyFields []string + versField string + autoIncrIdx int + autoIncrFieldName string +} + +func (plan bindPlan) createBindInstance(elem reflect.Value, conv TypeConverter) (bindInstance, error) { + bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, autoIncrFieldName: plan.autoIncrFieldName, versField: plan.versField} + if plan.versField != "" { + bi.existingVersion = elem.FieldByName(plan.versField).Int() + } + + var err error + + for i := 0; i < len(plan.argFields); i++ { + k := plan.argFields[i] + if k == versFieldConst { + newVer := bi.existingVersion + 1 + bi.args = append(bi.args, newVer) + if bi.existingVersion == 0 { + elem.FieldByName(plan.versField).SetInt(int64(newVer)) + } + } else { + val := elem.FieldByName(k).Interface() + if conv != nil { + val, err = conv.ToDb(val) + if err != nil { + return bindInstance{}, err + } + } + bi.args = append(bi.args, val) + } + } + + for i := 0; i < len(plan.keyFields); i++ { + k := plan.keyFields[i] + val := elem.FieldByName(k).Interface() + if conv != nil { + val, err = conv.ToDb(val) + if err != nil { + return bindInstance{}, err + } + } + bi.keys = append(bi.keys, val) + } + + return bi, nil +} + +type bindInstance struct { + query string + args []interface{} + keys []interface{} + existingVersion int64 + versField string + autoIncrIdx int + autoIncrFieldName string +} + +func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { + plan := t.insertPlan + if plan.query == "" { + plan.autoIncrIdx = -1 + + s := bytes.Buffer{} + s2 := bytes.Buffer{} + s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) + + x := 0 + first := true + for y := range t.Columns { + col := t.Columns[y] + if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") { + if !col.Transient { + if !first { + s.WriteString(",") + s2.WriteString(",") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + + if col.isAutoIncr { + s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName + } else { + s2.WriteString(t.dbmap.Dialect.BindVar(x)) + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } + + x++ + } + first = false + } + } else { + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName + } + } + s.WriteString(") values (") + s.WriteString(s2.String()) + s.WriteString(")") + if plan.autoIncrIdx > -1 { + s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx])) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + plan.query = s.String() + t.insertPlan = plan + } + + return plan.createBindInstance(elem, t.dbmap.TypeConverter) +} + +func (t *TableMap) bindUpdate(elem reflect.Value) (bindInstance, error) { + plan := t.updatePlan + if plan.query == "" { + + s := bytes.Buffer{} + s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) + x := 0 + + for y := range t.Columns { + col := t.Columns[y] + if !col.isAutoIncr && !col.Transient { + if x > 0 { + s.WriteString(", ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } + x++ + } + } + + s.WriteString(" where ") + for y := range t.keys { + col := t.keys[y] + if y > 0 { + s.WriteString(" and ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + plan.argFields = append(plan.argFields, col.fieldName) + plan.keyFields = append(plan.keyFields, col.fieldName) + x++ + } + if plan.versField != "" { + s.WriteString(" and ") + s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + plan.argFields = append(plan.argFields, plan.versField) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + plan.query = s.String() + t.updatePlan = plan + } + + return plan.createBindInstance(elem, t.dbmap.TypeConverter) +} + +func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { + plan := t.deletePlan + if plan.query == "" { + + s := bytes.Buffer{} + s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) + + for y := range t.Columns { + col := t.Columns[y] + if !col.Transient { + if col == t.version { + plan.versField = col.fieldName + } + } + } + + s.WriteString(" where ") + for x := range t.keys { + k := t.keys[x] + if x > 0 { + s.WriteString(" and ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + plan.keyFields = append(plan.keyFields, k.fieldName) + plan.argFields = append(plan.argFields, k.fieldName) + } + if plan.versField != "" { + s.WriteString(" and ") + s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields))) + + plan.argFields = append(plan.argFields, plan.versField) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + plan.query = s.String() + t.deletePlan = plan + } + + return plan.createBindInstance(elem, t.dbmap.TypeConverter) +} + +func (t *TableMap) bindGet() bindPlan { + plan := t.getPlan + if plan.query == "" { + + s := bytes.Buffer{} + s.WriteString("select ") + + x := 0 + for _, col := range t.Columns { + if !col.Transient { + if x > 0 { + s.WriteString(",") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + plan.argFields = append(plan.argFields, col.fieldName) + x++ + } + } + s.WriteString(" from ") + s.WriteString(t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)) + s.WriteString(" where ") + for x := range t.keys { + col := t.keys[x] + if x > 0 { + s.WriteString(" and ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + plan.keyFields = append(plan.keyFields, col.fieldName) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + plan.query = s.String() + t.getPlan = plan + } + + return plan +} + +// ColumnMap represents a mapping between a Go struct field and a single +// column in a table. +// Unique and MaxSize only inform the +// CreateTables() function and are not used by Insert/Update/Delete/Get. +type ColumnMap struct { + // Column name in db table + ColumnName string + + // If true, this column is skipped in generated SQL statements + Transient bool + + // If true, " unique" is added to create table statements. + // Not used elsewhere + Unique bool + + // Passed to Dialect.ToSqlType() to assist in informing the + // correct column type to map to in CreateTables() + // Not used elsewhere + MaxSize int + + fieldName string + gotype reflect.Type + isPK bool + isAutoIncr bool + isNotNull bool +} + +// Rename allows you to specify the column name in the table +// +// Example: table.ColMap("Updated").Rename("date_updated") +// +func (c *ColumnMap) Rename(colname string) *ColumnMap { + c.ColumnName = colname + return c +} + +// SetTransient allows you to mark the column as transient. If true +// this column will be skipped when SQL statements are generated +func (c *ColumnMap) SetTransient(b bool) *ColumnMap { + c.Transient = b + return c +} + +// SetUnique adds "unique" to the create table statements for this +// column, if b is true. +func (c *ColumnMap) SetUnique(b bool) *ColumnMap { + c.Unique = b + return c +} + +// SetNotNull adds "not null" to the create table statements for this +// column, if nn is true. +func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap { + c.isNotNull = nn + return c +} + +// SetMaxSize specifies the max length of values of this column. This is +// passed to the dialect.ToSqlType() function, which can use the value +// to alter the generated type for "create table" statements +func (c *ColumnMap) SetMaxSize(size int) *ColumnMap { + c.MaxSize = size + return c +} + +// Transaction represents a database transaction. +// Insert/Update/Delete/Get/Exec operations will be run in the context +// of that transaction. Transactions should be terminated with +// a call to Commit() or Rollback() +type Transaction struct { + dbmap *DbMap + tx *sql.Tx + closed bool +} + +// Executor exposes the sql.DB and sql.Tx Exec function so that it can be used +// on internal functions that convert named parameters for the Exec function. +type executor interface { + Exec(query string, args ...interface{}) (sql.Result, error) +} + +// SqlExecutor exposes gorp operations that can be run from Pre/Post +// hooks. This hides whether the current operation that triggered the +// hook is in a transaction. +// +// See the DbMap function docs for each of the functions below for more +// information. +type SqlExecutor interface { + Get(i interface{}, keys ...interface{}) (interface{}, error) + Insert(list ...interface{}) error + Update(list ...interface{}) (int64, error) + Delete(list ...interface{}) (int64, error) + Exec(query string, args ...interface{}) (sql.Result, error) + Select(i interface{}, query string, + args ...interface{}) ([]interface{}, error) + SelectInt(query string, args ...interface{}) (int64, error) + SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) + SelectFloat(query string, args ...interface{}) (float64, error) + SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) + SelectStr(query string, args ...interface{}) (string, error) + SelectNullStr(query string, args ...interface{}) (sql.NullString, error) + SelectOne(holder interface{}, query string, args ...interface{}) error + query(query string, args ...interface{}) (*sql.Rows, error) + queryRow(query string, args ...interface{}) *sql.Row +} + +// Compile-time check that DbMap and Transaction implement the SqlExecutor +// interface. +var _, _ SqlExecutor = &DbMap{}, &Transaction{} + +type GorpLogger interface { + Printf(format string, v ...interface{}) +} + +// TraceOn turns on SQL statement logging for this DbMap. After this is +// called, all SQL statements will be sent to the logger. If prefix is +// a non-empty string, it will be written to the front of all logged +// strings, which can aid in filtering log lines. +// +// Use TraceOn if you want to spy on the SQL statements that gorp +// generates. +// +// Note that the base log.Logger type satisfies GorpLogger, but adapters can +// easily be written for other logging packages (e.g., the golang-sanctioned +// glog framework). +func (m *DbMap) TraceOn(prefix string, logger GorpLogger) { + m.logger = logger + if prefix == "" { + m.logPrefix = prefix + } else { + m.logPrefix = fmt.Sprintf("%s ", prefix) + } +} + +// TraceOff turns off tracing. It is idempotent. +func (m *DbMap) TraceOff() { + m.logger = nil + m.logPrefix = "" +} + +// AddTable registers the given interface type with gorp. The table name +// will be given the name of the TypeOf(i). You must call this function, +// or AddTableWithName, for any struct type you wish to persist with +// the given DbMap. +// +// This operation is idempotent. If i's type is already mapped, the +// existing *TableMap is returned +func (m *DbMap) AddTable(i interface{}) *TableMap { + return m.AddTableWithName(i, "") +} + +// AddTableWithName has the same behavior as AddTable, but sets +// table.TableName to name. +func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap { + return m.AddTableWithNameAndSchema(i, "", name) +} + +// AddTableWithNameAndSchema has the same behavior as AddTable, but sets +// table.TableName to name. +func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap { + t := reflect.TypeOf(i) + if name == "" { + name = t.Name() + } + + // check if we have a table for this type already + // if so, update the name and return the existing pointer + for i := range m.tables { + table := m.tables[i] + if table.gotype == t { + table.TableName = name + return table + } + } + + tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} + tmap.Columns = m.readStructColumns(t) + m.tables = append(m.tables, tmap) + + return tmap +} + +func (m *DbMap) readStructColumns(t reflect.Type) (cols []*ColumnMap) { + n := t.NumField() + for i := 0; i < n; i++ { + f := t.Field(i) + if f.Anonymous && f.Type.Kind() == reflect.Struct { + // Recursively add nested fields in embedded structs. + subcols := m.readStructColumns(f.Type) + // Don't append nested fields that have the same field + // name as an already-mapped field. + for _, subcol := range subcols { + shouldAppend := true + for _, col := range cols { + if !subcol.Transient && subcol.fieldName == col.fieldName { + shouldAppend = false + break + } + } + if shouldAppend { + cols = append(cols, subcol) + } + } + } else { + columnName := f.Tag.Get("db") + if columnName == "" { + columnName = f.Name + } + gotype := f.Type + if m.TypeConverter != nil { + // Make a new pointer to a value of type gotype and + // pass it to the TypeConverter's FromDb method to see + // if a different type should be used for the column + // type during table creation. + value := reflect.New(gotype).Interface() + scanner, useHolder := m.TypeConverter.FromDb(value) + if useHolder { + gotype = reflect.TypeOf(scanner.Holder) + } + } + cm := &ColumnMap{ + ColumnName: columnName, + Transient: columnName == "-", + fieldName: f.Name, + gotype: gotype, + } + // Check for nested fields of the same field name and + // override them. + shouldAppend := true + for index, col := range cols { + if !col.Transient && col.fieldName == cm.fieldName { + cols[index] = cm + shouldAppend = false + break + } + } + if shouldAppend { + cols = append(cols, cm) + } + } + } + return +} + +// CreateTables iterates through TableMaps registered to this DbMap and +// executes "create table" statements against the database for each. +// +// This is particularly useful in unit tests where you want to create +// and destroy the schema automatically. +func (m *DbMap) CreateTables() error { + return m.createTables(false) +} + +// CreateTablesIfNotExists is similar to CreateTables, but starts +// each statement with "create table if not exists" so that existing +// tables do not raise errors +func (m *DbMap) CreateTablesIfNotExists() error { + return m.createTables(true) +} + +func (m *DbMap) createTables(ifNotExists bool) error { + var err error + for i := range m.tables { + table := m.tables[i] + + s := bytes.Buffer{} + + if strings.TrimSpace(table.SchemaName) != "" { + schemaCreate := "create schema" + if ifNotExists { + s.WriteString(m.Dialect.IfSchemaNotExists(schemaCreate, table.SchemaName)) + } else { + s.WriteString(schemaCreate) + } + s.WriteString(fmt.Sprintf(" %s;", table.SchemaName)) + } + + tableCreate := "create table" + if ifNotExists { + s.WriteString(m.Dialect.IfTableNotExists(tableCreate, table.SchemaName, table.TableName)) + } else { + s.WriteString(tableCreate) + } + s.WriteString(fmt.Sprintf(" %s (", m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + + x := 0 + for _, col := range table.Columns { + if !col.Transient { + if x > 0 { + s.WriteString(", ") + } + stype := m.Dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr) + s.WriteString(fmt.Sprintf("%s %s", m.Dialect.QuoteField(col.ColumnName), stype)) + + if col.isPK || col.isNotNull { + s.WriteString(" not null") + } + if col.isPK && len(table.keys) == 1 { + s.WriteString(" primary key") + } + if col.Unique { + s.WriteString(" unique") + } + if col.isAutoIncr { + s.WriteString(fmt.Sprintf(" %s", m.Dialect.AutoIncrStr())) + } + + x++ + } + } + if len(table.keys) > 1 { + s.WriteString(", primary key (") + for x := range table.keys { + if x > 0 { + s.WriteString(", ") + } + s.WriteString(m.Dialect.QuoteField(table.keys[x].ColumnName)) + } + s.WriteString(")") + } + if len(table.uniqueTogether) > 0 { + for _, columns := range table.uniqueTogether { + s.WriteString(", unique (") + for i, column := range columns { + if i > 0 { + s.WriteString(", ") + } + s.WriteString(m.Dialect.QuoteField(column)) + } + s.WriteString(")") + } + } + s.WriteString(") ") + s.WriteString(m.Dialect.CreateTableSuffix()) + s.WriteString(m.Dialect.QuerySuffix()) + _, err = m.Exec(s.String()) + if err != nil { + break + } + } + return err +} + +// DropTable drops an individual table. Will throw an error +// if the table does not exist. +func (m *DbMap) DropTable(table interface{}) error { + t := reflect.TypeOf(table) + return m.dropTable(t, false) +} + +// DropTable drops an individual table. Will NOT throw an error +// if the table does not exist. +func (m *DbMap) DropTableIfExists(table interface{}) error { + t := reflect.TypeOf(table) + return m.dropTable(t, true) +} + +// DropTables iterates through TableMaps registered to this DbMap and +// executes "drop table" statements against the database for each. +func (m *DbMap) DropTables() error { + return m.dropTables(false) +} + +// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to +// avoid errors for tables that do not exist. +func (m *DbMap) DropTablesIfExists() error { + return m.dropTables(true) +} + +// Goes through all the registered tables, dropping them one by one. +// If an error is encountered, then it is returned and the rest of +// the tables are not dropped. +func (m *DbMap) dropTables(addIfExists bool) (err error) { + for _, table := range m.tables { + err = m.dropTableImpl(table, addIfExists) + if err != nil { + return + } + } + return err +} + +// Implementation of dropping a single table. +func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error { + table := tableOrNil(m, t) + if table == nil { + return errors.New(fmt.Sprintf("table %s was not registered!", table.TableName)) + } + + return m.dropTableImpl(table, addIfExists) +} + +func (m *DbMap) dropTableImpl(table *TableMap, ifExists bool) (err error) { + tableDrop := "drop table" + if ifExists { + tableDrop = m.Dialect.IfTableExists(tableDrop, table.SchemaName, table.TableName) + } + _, err = m.Exec(fmt.Sprintf("%s %s;", tableDrop, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + return err +} + +// TruncateTables iterates through TableMaps registered to this DbMap and +// executes "truncate table" statements against the database for each, or in the case of +// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization +// (http://www.sqlite.org/lang_delete.html) +func (m *DbMap) TruncateTables() error { + var err error + for i := range m.tables { + table := m.tables[i] + _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + if e != nil { + err = e + } + } + return err +} + +// Insert runs a SQL INSERT statement for each element in list. List +// items must be pointers. +// +// Any interface whose TableMap has an auto-increment primary key will +// have its last insert id bound to the PK field on the struct. +// +// The hook functions PreInsert() and/or PostInsert() will be executed +// before/after the INSERT statement if the interface defines them. +// +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Insert(list ...interface{}) error { + return insert(m, m, list...) +} + +// Update runs a SQL UPDATE statement for each element in list. List +// items must be pointers. +// +// The hook functions PreUpdate() and/or PostUpdate() will be executed +// before/after the UPDATE statement if the interface defines them. +// +// Returns the number of rows updated. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Update(list ...interface{}) (int64, error) { + return update(m, m, list...) +} + +// Delete runs a SQL DELETE statement for each element in list. List +// items must be pointers. +// +// The hook functions PreDelete() and/or PostDelete() will be executed +// before/after the DELETE statement if the interface defines them. +// +// Returns the number of rows deleted. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Delete(list ...interface{}) (int64, error) { + return delete(m, m, list...) +} + +// Get runs a SQL SELECT to fetch a single row from the table based on the +// primary key(s) +// +// i should be an empty value for the struct to load. keys should be +// the primary key value(s) for the row to load. If multiple keys +// exist on the table, the order should match the column order +// specified in SetKeys() when the table mapping was defined. +// +// The hook function PostGet() will be executed after the SELECT +// statement if the interface defines them. +// +// Returns a pointer to a struct that matches or nil if no row is found. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) { + return get(m, m, i, keys...) +} + +// Select runs an arbitrary SQL query, binding the columns in the result +// to fields on the struct specified by i. args represent the bind +// parameters for the SQL statement. +// +// Column names on the SELECT statement should be aliased to the field names +// on the struct i. Returns an error if one or more columns in the result +// do not match. It is OK if fields on i are not part of the SQL +// statement. +// +// The hook function PostGet() will be executed after the SELECT +// statement if the interface defines them. +// +// Values are returned in one of two ways: +// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to +// matching rows of type i. +// 2. If i is a pointer to a slice, the results will be appended to that slice +// and nil returned. +// +// i does NOT need to be registered with AddTable() +func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return hookedselect(m, m, i, query, args...) +} + +// Exec runs an arbitrary SQL statement. args represent the bind parameters. +// This is equivalent to running: Exec() using database/sql +func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) { + if m.logger != nil { + now := time.Now() + defer m.trace(now, query, args...) + } + return exec(m, query, args...) +} + +// SelectInt is a convenience wrapper around the gorp.SelectInt function +func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) { + return SelectInt(m, query, args...) +} + +// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function +func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { + return SelectNullInt(m, query, args...) +} + +// SelectFloat is a convenience wrapper around the gorp.SelectFlot function +func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) { + return SelectFloat(m, query, args...) +} + +// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function +func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { + return SelectNullFloat(m, query, args...) +} + +// SelectStr is a convenience wrapper around the gorp.SelectStr function +func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) { + return SelectStr(m, query, args...) +} + +// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function +func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { + return SelectNullStr(m, query, args...) +} + +// SelectOne is a convenience wrapper around the gorp.SelectOne function +func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error { + return SelectOne(m, m, holder, query, args...) +} + +// Begin starts a gorp Transaction +func (m *DbMap) Begin() (*Transaction, error) { + if m.logger != nil { + now := time.Now() + defer m.trace(now, "begin;") + } + tx, err := m.Db.Begin() + if err != nil { + return nil, err + } + return &Transaction{m, tx, false}, nil +} + +// TableFor returns the *TableMap corresponding to the given Go Type +// If no table is mapped to that type an error is returned. +// If checkPK is true and the mapped table has no registered PKs, an error is returned. +func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) { + table := tableOrNil(m, t) + if table == nil { + return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name())) + } + + if checkPK && len(table.keys) < 1 { + e := fmt.Sprintf("gorp: No keys defined for table: %s", + table.TableName) + return nil, errors.New(e) + } + + return table, nil +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the returned statement. +// This is equivalent to running: Prepare() using database/sql +func (m *DbMap) Prepare(query string) (*sql.Stmt, error) { + if m.logger != nil { + now := time.Now() + defer m.trace(now, query, nil) + } + return m.Db.Prepare(query) +} + +func tableOrNil(m *DbMap, t reflect.Type) *TableMap { + for i := range m.tables { + table := m.tables[i] + if table.gotype == t { + return table + } + } + return nil +} + +func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) { + ptrv := reflect.ValueOf(ptr) + if ptrv.Kind() != reflect.Ptr { + e := fmt.Sprintf("gorp: passed non-pointer: %v (kind=%v)", ptr, + ptrv.Kind()) + return nil, reflect.Value{}, errors.New(e) + } + elem := ptrv.Elem() + etype := reflect.TypeOf(elem.Interface()) + t, err := m.TableFor(etype, checkPK) + if err != nil { + return nil, reflect.Value{}, err + } + + return t, elem, nil +} + +func (m *DbMap) queryRow(query string, args ...interface{}) *sql.Row { + if m.logger != nil { + now := time.Now() + defer m.trace(now, query, args...) + } + return m.Db.QueryRow(query, args...) +} + +func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) { + if m.logger != nil { + now := time.Now() + defer m.trace(now, query, args...) + } + return m.Db.Query(query, args...) +} + +func (m *DbMap) trace(started time.Time, query string, args ...interface{}) { + if m.logger != nil { + var margs = argsString(args...) + m.logger.Printf("%s%s [%s] (%v)", m.logPrefix, query, margs, (time.Now().Sub(started))) + } +} + +func argsString(args ...interface{}) string { + var margs string + for i, a := range args { + var v interface{} = a + if x, ok := v.(driver.Valuer); ok { + y, err := x.Value() + if err == nil { + v = y + } + } + switch v.(type) { + case string: + v = fmt.Sprintf("%q", v) + default: + v = fmt.Sprintf("%v", v) + } + margs += fmt.Sprintf("%d:%s", i+1, v) + if i+1 < len(args) { + margs += " " + } + } + return margs +} + +/////////////// + +// Insert has the same behavior as DbMap.Insert(), but runs in a transaction. +func (t *Transaction) Insert(list ...interface{}) error { + return insert(t.dbmap, t, list...) +} + +// Update had the same behavior as DbMap.Update(), but runs in a transaction. +func (t *Transaction) Update(list ...interface{}) (int64, error) { + return update(t.dbmap, t, list...) +} + +// Delete has the same behavior as DbMap.Delete(), but runs in a transaction. +func (t *Transaction) Delete(list ...interface{}) (int64, error) { + return delete(t.dbmap, t, list...) +} + +// Get has the same behavior as DbMap.Get(), but runs in a transaction. +func (t *Transaction) Get(i interface{}, keys ...interface{}) (interface{}, error) { + return get(t.dbmap, t, i, keys...) +} + +// Select has the same behavior as DbMap.Select(), but runs in a transaction. +func (t *Transaction) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return hookedselect(t.dbmap, t, i, query, args...) +} + +// Exec has the same behavior as DbMap.Exec(), but runs in a transaction. +func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) { + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, args...) + } + return exec(t, query, args...) +} + +// SelectInt is a convenience wrapper around the gorp.SelectInt function. +func (t *Transaction) SelectInt(query string, args ...interface{}) (int64, error) { + return SelectInt(t, query, args...) +} + +// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function. +func (t *Transaction) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { + return SelectNullInt(t, query, args...) +} + +// SelectFloat is a convenience wrapper around the gorp.SelectFloat function. +func (t *Transaction) SelectFloat(query string, args ...interface{}) (float64, error) { + return SelectFloat(t, query, args...) +} + +// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function. +func (t *Transaction) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { + return SelectNullFloat(t, query, args...) +} + +// SelectStr is a convenience wrapper around the gorp.SelectStr function. +func (t *Transaction) SelectStr(query string, args ...interface{}) (string, error) { + return SelectStr(t, query, args...) +} + +// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function. +func (t *Transaction) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { + return SelectNullStr(t, query, args...) +} + +// SelectOne is a convenience wrapper around the gorp.SelectOne function. +func (t *Transaction) SelectOne(holder interface{}, query string, args ...interface{}) error { + return SelectOne(t.dbmap, t, holder, query, args...) +} + +// Commit commits the underlying database transaction. +func (t *Transaction) Commit() error { + if !t.closed { + t.closed = true + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, "commit;") + } + return t.tx.Commit() + } + + return sql.ErrTxDone +} + +// Rollback rolls back the underlying database transaction. +func (t *Transaction) Rollback() error { + if !t.closed { + t.closed = true + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, "rollback;") + } + return t.tx.Rollback() + } + + return sql.ErrTxDone +} + +// Savepoint creates a savepoint with the given name. The name is interpolated +// directly into the SQL SAVEPOINT statement, so you must sanitize it if it is +// derived from user input. +func (t *Transaction) Savepoint(name string) error { + query := "savepoint " + t.dbmap.Dialect.QuoteField(name) + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + _, err := t.tx.Exec(query) + return err +} + +// RollbackToSavepoint rolls back to the savepoint with the given name. The +// name is interpolated directly into the SQL SAVEPOINT statement, so you must +// sanitize it if it is derived from user input. +func (t *Transaction) RollbackToSavepoint(savepoint string) error { + query := "rollback to savepoint " + t.dbmap.Dialect.QuoteField(savepoint) + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + _, err := t.tx.Exec(query) + return err +} + +// ReleaseSavepint releases the savepoint with the given name. The name is +// interpolated directly into the SQL SAVEPOINT statement, so you must sanitize +// it if it is derived from user input. +func (t *Transaction) ReleaseSavepoint(savepoint string) error { + query := "release savepoint " + t.dbmap.Dialect.QuoteField(savepoint) + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + _, err := t.tx.Exec(query) + return err +} + +// Prepare has the same behavior as DbMap.Prepare(), but runs in a transaction. +func (t *Transaction) Prepare(query string) (*sql.Stmt, error) { + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + return t.tx.Prepare(query) +} + +func (t *Transaction) queryRow(query string, args ...interface{}) *sql.Row { + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, args...) + } + return t.tx.QueryRow(query, args...) +} + +func (t *Transaction) query(query string, args ...interface{}) (*sql.Rows, error) { + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, args...) + } + return t.tx.Query(query, args...) +} + +/////////////// + +// SelectInt executes the given query, which should be a SELECT statement for a single +// integer column, and returns the value of the first row returned. If no rows are +// found, zero is returned. +func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) { + var h int64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return 0, err + } + return h, nil +} + +// SelectNullInt executes the given query, which should be a SELECT statement for a single +// integer column, and returns the value of the first row returned. If no rows are +// found, the empty sql.NullInt64 value is returned. +func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) { + var h sql.NullInt64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return h, err + } + return h, nil +} + +// SelectFloat executes the given query, which should be a SELECT statement for a single +// float column, and returns the value of the first row returned. If no rows are +// found, zero is returned. +func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) { + var h float64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return 0, err + } + return h, nil +} + +// SelectNullFloat executes the given query, which should be a SELECT statement for a single +// float column, and returns the value of the first row returned. If no rows are +// found, the empty sql.NullInt64 value is returned. +func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) { + var h sql.NullFloat64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return h, err + } + return h, nil +} + +// SelectStr executes the given query, which should be a SELECT statement for a single +// char/varchar column, and returns the value of the first row returned. If no rows are +// found, an empty string is returned. +func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) { + var h string + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return "", err + } + return h, nil +} + +// SelectNullStr executes the given query, which should be a SELECT +// statement for a single char/varchar column, and returns the value +// of the first row returned. If no rows are found, the empty +// sql.NullString is returned. +func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) { + var h sql.NullString + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return h, err + } + return h, nil +} + +// SelectOne executes the given query (which should be a SELECT statement) +// and binds the result to holder, which must be a pointer. +// +// If no row is found, an error (sql.ErrNoRows specifically) will be returned +// +// If more than one row is found, an error will be returned. +// +func SelectOne(m *DbMap, e SqlExecutor, holder interface{}, query string, args ...interface{}) error { + t := reflect.TypeOf(holder) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } else { + return fmt.Errorf("gorp: SelectOne holder must be a pointer, but got: %t", holder) + } + + // Handle pointer to pointer + isptr := false + if t.Kind() == reflect.Ptr { + isptr = true + t = t.Elem() + } + + if t.Kind() == reflect.Struct { + var nonFatalErr error + + list, err := hookedselect(m, e, holder, query, args...) + if err != nil { + if !NonFatalError(err) { + return err + } + nonFatalErr = err + } + + dest := reflect.ValueOf(holder) + if isptr { + dest = dest.Elem() + } + + if list != nil && len(list) > 0 { + // check for multiple rows + if len(list) > 1 { + return fmt.Errorf("gorp: multiple rows returned for: %s - %v", query, args) + } + + // Initialize if nil + if dest.IsNil() { + dest.Set(reflect.New(t)) + } + + // only one row found + src := reflect.ValueOf(list[0]) + dest.Elem().Set(src.Elem()) + } else { + // No rows found, return a proper error. + return sql.ErrNoRows + } + + return nonFatalErr + } + + return selectVal(e, holder, query, args...) +} + +func selectVal(e SqlExecutor, holder interface{}, query string, args ...interface{}) error { + if len(args) == 1 { + switch m := e.(type) { + case *DbMap: + query, args = maybeExpandNamedQuery(m, query, args) + case *Transaction: + query, args = maybeExpandNamedQuery(m.dbmap, query, args) + } + } + rows, err := e.query(query, args...) + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + return sql.ErrNoRows + } + + return rows.Scan(holder) +} + +/////////////// + +func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string, + args ...interface{}) ([]interface{}, error) { + + var nonFatalErr error + + list, err := rawselect(m, exec, i, query, args...) + if err != nil { + if !NonFatalError(err) { + return nil, err + } + nonFatalErr = err + } + + // Determine where the results are: written to i, or returned in list + if t, _ := toSliceType(i); t == nil { + for _, v := range list { + if v, ok := v.(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } + } + } + } else { + resultsValue := reflect.Indirect(reflect.ValueOf(i)) + for i := 0; i < resultsValue.Len(); i++ { + if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } + } + } + } + return list, nonFatalErr +} + +func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string, + args ...interface{}) ([]interface{}, error) { + var ( + appendToSlice = false // Write results to i directly? + intoStruct = true // Selecting into a struct? + pointerElements = true // Are the slice elements pointers (vs values)? + ) + + var nonFatalErr error + + // get type for i, verifying it's a supported destination + t, err := toType(i) + if err != nil { + var err2 error + if t, err2 = toSliceType(i); t == nil { + if err2 != nil { + return nil, err2 + } + return nil, err + } + pointerElements = t.Kind() == reflect.Ptr + if pointerElements { + t = t.Elem() + } + appendToSlice = true + intoStruct = t.Kind() == reflect.Struct + } + + // If the caller supplied a single struct/map argument, assume a "named + // parameter" query. Extract the named arguments from the struct/map, create + // the flat arg slice, and rewrite the query to use the dialect's placeholder. + if len(args) == 1 { + query, args = maybeExpandNamedQuery(m, query, args) + } + + // Run the query + rows, err := exec.query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + // Fetch the column names as returned from db + cols, err := rows.Columns() + if err != nil { + return nil, err + } + + if !intoStruct && len(cols) > 1 { + return nil, fmt.Errorf("gorp: select into non-struct slice requires 1 column, got %d", len(cols)) + } + + var colToFieldIndex [][]int + if intoStruct { + if colToFieldIndex, err = columnToFieldIndex(m, t, cols); err != nil { + if !NonFatalError(err) { + return nil, err + } + nonFatalErr = err + } + } + + conv := m.TypeConverter + + // Add results to one of these two slices. + var ( + list = make([]interface{}, 0) + sliceValue = reflect.Indirect(reflect.ValueOf(i)) + ) + + for { + if !rows.Next() { + // if error occured return rawselect + if rows.Err() != nil { + return nil, rows.Err() + } + // time to exit from outer "for" loop + break + } + v := reflect.New(t) + dest := make([]interface{}, len(cols)) + + custScan := make([]CustomScanner, 0) + + for x := range cols { + f := v.Elem() + if intoStruct { + index := colToFieldIndex[x] + if index == nil { + // this field is not present in the struct, so create a dummy + // value for rows.Scan to scan into + var dummy sql.RawBytes + dest[x] = &dummy + continue + } + f = f.FieldByIndex(index) + } + target := f.Addr().Interface() + if conv != nil { + scanner, ok := conv.FromDb(target) + if ok { + target = scanner.Holder + custScan = append(custScan, scanner) + } + } + dest[x] = target + } + + err = rows.Scan(dest...) + if err != nil { + return nil, err + } + + for _, c := range custScan { + err = c.Bind() + if err != nil { + return nil, err + } + } + + if appendToSlice { + if !pointerElements { + v = v.Elem() + } + sliceValue.Set(reflect.Append(sliceValue, v)) + } else { + list = append(list, v.Interface()) + } + } + + if appendToSlice && sliceValue.IsNil() { + sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0)) + } + + return list, nonFatalErr +} + +// Calls the Exec function on the executor, but attempts to expand any eligible named +// query arguments first. +func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) { + var dbMap *DbMap + var executor executor + switch m := e.(type) { + case *DbMap: + executor = m.Db + dbMap = m + case *Transaction: + executor = m.tx + dbMap = m.dbmap + } + + if len(args) == 1 { + query, args = maybeExpandNamedQuery(dbMap, query, args) + } + + return executor.Exec(query, args...) +} + +// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used +// as input to a named query. If so, it rewrites the query to use +// dialect-dependent bindvars and instantiates the corresponding slice of +// parameters by extracting data from the map / struct. +// If not, returns the input values unchanged. +func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) { + var ( + arg = args[0] + argval = reflect.ValueOf(arg) + ) + if argval.Kind() == reflect.Ptr { + argval = argval.Elem() + } + + if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String { + return expandNamedQuery(m, query, func(key string) reflect.Value { + return argval.MapIndex(reflect.ValueOf(key)) + }) + } + if argval.Kind() != reflect.Struct { + return query, args + } + if _, ok := arg.(time.Time); ok { + // time.Time is driver.Value + return query, args + } + if _, ok := arg.(driver.Valuer); ok { + // driver.Valuer will be converted to driver.Value. + return query, args + } + + return expandNamedQuery(m, query, argval.FieldByName) +} + +var keyRegexp = regexp.MustCompile(`:[[:word:]]+`) + +// expandNamedQuery accepts a query with placeholders of the form ":key", and a +// single arg of Kind Struct or Map[string]. It returns the query with the +// dialect's placeholders, and a slice of args ready for positional insertion +// into the query. +func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) { + var ( + n int + args []interface{} + ) + return keyRegexp.ReplaceAllStringFunc(query, func(key string) string { + val := keyGetter(key[1:]) + if !val.IsValid() { + return key + } + args = append(args, val.Interface()) + newVar := m.Dialect.BindVar(n) + n++ + return newVar + }), args +} + +func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error) { + colToFieldIndex := make([][]int, len(cols)) + + // check if type t is a mapped table - if so we'll + // check the table for column aliasing below + tableMapped := false + table := tableOrNil(m, t) + if table != nil { + tableMapped = true + } + + // Loop over column names and find field in i to bind to + // based on column name. all returned columns must match + // a field in the i struct + missingColNames := []string{} + for x := range cols { + colName := strings.ToLower(cols[x]) + field, found := t.FieldByNameFunc(func(fieldName string) bool { + field, _ := t.FieldByName(fieldName) + fieldName = field.Tag.Get("db") + + if fieldName == "-" { + return false + } else if fieldName == "" { + fieldName = field.Name + } + if tableMapped { + colMap := colMapOrNil(table, fieldName) + if colMap != nil { + fieldName = colMap.ColumnName + } + } + return colName == strings.ToLower(fieldName) + }) + if found { + colToFieldIndex[x] = field.Index + } + if colToFieldIndex[x] == nil { + missingColNames = append(missingColNames, colName) + } + } + if len(missingColNames) > 0 { + return colToFieldIndex, &NoFieldInTypeError{ + TypeName: t.Name(), + MissingColNames: missingColNames, + } + } + return colToFieldIndex, nil +} + +func fieldByName(val reflect.Value, fieldName string) *reflect.Value { + // try to find field by exact match + f := val.FieldByName(fieldName) + + if f != zeroVal { + return &f + } + + // try to find by case insensitive match - only the Postgres driver + // seems to require this - in the case where columns are aliased in the sql + fieldNameL := strings.ToLower(fieldName) + fieldCount := val.NumField() + t := val.Type() + for i := 0; i < fieldCount; i++ { + sf := t.Field(i) + if strings.ToLower(sf.Name) == fieldNameL { + f := val.Field(i) + return &f + } + } + + return nil +} + +// toSliceType returns the element type of the given object, if the object is a +// "*[]*Element" or "*[]Element". If not, returns nil. +// err is returned if the user was trying to pass a pointer-to-slice but failed. +func toSliceType(i interface{}) (reflect.Type, error) { + t := reflect.TypeOf(i) + if t.Kind() != reflect.Ptr { + // If it's a slice, return a more helpful error message + if t.Kind() == reflect.Slice { + return nil, fmt.Errorf("gorp: Cannot SELECT into a non-pointer slice: %v", t) + } + return nil, nil + } + if t = t.Elem(); t.Kind() != reflect.Slice { + return nil, nil + } + return t.Elem(), nil +} + +func toType(i interface{}) (reflect.Type, error) { + t := reflect.TypeOf(i) + + // If a Pointer to a type, follow + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("gorp: Cannot SELECT into this type: %v", reflect.TypeOf(i)) + } + return t, nil +} + +func get(m *DbMap, exec SqlExecutor, i interface{}, + keys ...interface{}) (interface{}, error) { + + t, err := toType(i) + if err != nil { + return nil, err + } + + table, err := m.TableFor(t, true) + if err != nil { + return nil, err + } + + plan := table.bindGet() + + v := reflect.New(t) + dest := make([]interface{}, len(plan.argFields)) + + conv := m.TypeConverter + custScan := make([]CustomScanner, 0) + + for x, fieldName := range plan.argFields { + f := v.Elem().FieldByName(fieldName) + target := f.Addr().Interface() + if conv != nil { + scanner, ok := conv.FromDb(target) + if ok { + target = scanner.Holder + custScan = append(custScan, scanner) + } + } + dest[x] = target + } + + row := exec.queryRow(plan.query, keys...) + err = row.Scan(dest...) + if err != nil { + if err == sql.ErrNoRows { + err = nil + } + return nil, err + } + + for _, c := range custScan { + err = c.Bind() + if err != nil { + return nil, err + } + } + + if v, ok := v.Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } + } + + return v.Interface(), nil +} + +func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { + count := int64(0) + for _, ptr := range list { + table, elem, err := m.tableForPointer(ptr, true) + if err != nil { + return -1, err + } + + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreDelete); ok { + err = v.PreDelete(exec) + if err != nil { + return -1, err + } + } + + bi, err := table.bindDelete(elem) + if err != nil { + return -1, err + } + + res, err := exec.Exec(bi.query, bi.args...) + if err != nil { + return -1, err + } + rows, err := res.RowsAffected() + if err != nil { + return -1, err + } + + if rows == 0 && bi.existingVersion > 0 { + return lockError(m, exec, table.TableName, + bi.existingVersion, elem, bi.keys...) + } + + count += rows + + if v, ok := eval.(HasPostDelete); ok { + err := v.PostDelete(exec) + if err != nil { + return -1, err + } + } + } + + return count, nil +} + +func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { + count := int64(0) + for _, ptr := range list { + table, elem, err := m.tableForPointer(ptr, true) + if err != nil { + return -1, err + } + + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreUpdate); ok { + err = v.PreUpdate(exec) + if err != nil { + return -1, err + } + } + + bi, err := table.bindUpdate(elem) + if err != nil { + return -1, err + } + + res, err := exec.Exec(bi.query, bi.args...) + if err != nil { + return -1, err + } + + rows, err := res.RowsAffected() + if err != nil { + return -1, err + } + + if rows == 0 && bi.existingVersion > 0 { + return lockError(m, exec, table.TableName, + bi.existingVersion, elem, bi.keys...) + } + + if bi.versField != "" { + elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1) + } + + count += rows + + if v, ok := eval.(HasPostUpdate); ok { + err = v.PostUpdate(exec) + if err != nil { + return -1, err + } + } + } + return count, nil +} + +func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { + for _, ptr := range list { + table, elem, err := m.tableForPointer(ptr, false) + if err != nil { + return err + } + + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreInsert); ok { + err := v.PreInsert(exec) + if err != nil { + return err + } + } + + bi, err := table.bindInsert(elem) + if err != nil { + return err + } + + if bi.autoIncrIdx > -1 { + f := elem.FieldByName(bi.autoIncrFieldName) + switch inserter := m.Dialect.(type) { + case IntegerAutoIncrInserter: + id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...) + if err != nil { + return err + } + k := f.Kind() + if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { + f.SetInt(id) + } else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { + f.SetUint(uint64(id)) + } else { + return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) + } + case TargetedAutoIncrInserter: + err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...) + if err != nil { + return err + } + default: + return fmt.Errorf("gorp: Cannot use autoincrement fields on dialects that do not implement an autoincrementing interface") + } + } else { + _, err := exec.Exec(bi.query, bi.args...) + if err != nil { + return err + } + } + + if v, ok := eval.(HasPostInsert); ok { + err := v.PostInsert(exec) + if err != nil { + return err + } + } + } + return nil +} + +func lockError(m *DbMap, exec SqlExecutor, tableName string, + existingVer int64, elem reflect.Value, + keys ...interface{}) (int64, error) { + + existing, err := get(m, exec, elem.Interface(), keys...) + if err != nil { + return -1, err + } + + ole := OptimisticLockError{tableName, keys, true, existingVer} + if existing == nil { + ole.RowExists = false + } + return -1, ole +} + +// PostUpdate() will be executed after the GET statement. +type HasPostGet interface { + PostGet(SqlExecutor) error +} + +// PostUpdate() will be executed after the DELETE statement +type HasPostDelete interface { + PostDelete(SqlExecutor) error +} + +// PostUpdate() will be executed after the UPDATE statement +type HasPostUpdate interface { + PostUpdate(SqlExecutor) error +} + +// PostInsert() will be executed after the INSERT statement +type HasPostInsert interface { + PostInsert(SqlExecutor) error +} + +// PreDelete() will be executed before the DELETE statement. +type HasPreDelete interface { + PreDelete(SqlExecutor) error +} + +// PreUpdate() will be executed before UPDATE statement. +type HasPreUpdate interface { + PreUpdate(SqlExecutor) error +} + +// PreInsert() will be executed before INSERT statement. +type HasPreInsert interface { + PreInsert(SqlExecutor) error +} diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/gorp_test.go b/Godeps/_workspace/src/github.com/go-gorp/gorp/gorp_test.go new file mode 100644 index 000000000..6e5618c1f --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/gorp_test.go @@ -0,0 +1,2170 @@ +package gorp + +import ( + "bytes" + "database/sql" + "encoding/json" + "errors" + "fmt" + "log" + "math/rand" + "os" + "reflect" + "strings" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + _ "github.com/ziutek/mymysql/godrv" +) + +// verify interface compliance +var _ Dialect = SqliteDialect{} +var _ Dialect = PostgresDialect{} +var _ Dialect = MySQLDialect{} +var _ Dialect = SqlServerDialect{} +var _ Dialect = OracleDialect{} + +type testable interface { + GetId() int64 + Rand() +} + +type Invoice struct { + Id int64 + Created int64 + Updated int64 + Memo string + PersonId int64 + IsPaid bool +} + +func (me *Invoice) GetId() int64 { return me.Id } +func (me *Invoice) Rand() { + me.Memo = fmt.Sprintf("random %d", rand.Int63()) + me.Created = rand.Int63() + me.Updated = rand.Int63() +} + +type InvoiceTag struct { + Id int64 `db:"myid"` + Created int64 `db:"myCreated"` + Updated int64 `db:"date_updated"` + Memo string + PersonId int64 `db:"person_id"` + IsPaid bool `db:"is_Paid"` +} + +func (me *InvoiceTag) GetId() int64 { return me.Id } +func (me *InvoiceTag) Rand() { + me.Memo = fmt.Sprintf("random %d", rand.Int63()) + me.Created = rand.Int63() + me.Updated = rand.Int63() +} + +// See: https://github.com/go-gorp/gorp/issues/175 +type AliasTransientField struct { + Id int64 `db:"id"` + Bar int64 `db:"-"` + BarStr string `db:"bar"` +} + +func (me *AliasTransientField) GetId() int64 { return me.Id } +func (me *AliasTransientField) Rand() { + me.BarStr = fmt.Sprintf("random %d", rand.Int63()) +} + +type OverriddenInvoice struct { + Invoice + Id string +} + +type Person struct { + Id int64 + Created int64 + Updated int64 + FName string + LName string + Version int64 +} + +type FNameOnly struct { + FName string +} + +type InvoicePersonView struct { + InvoiceId int64 + PersonId int64 + Memo string + FName string + LegacyVersion int64 +} + +type TableWithNull struct { + Id int64 + Str sql.NullString + Int64 sql.NullInt64 + Float64 sql.NullFloat64 + Bool sql.NullBool + Bytes []byte +} + +type WithIgnoredColumn struct { + internal int64 `db:"-"` + Id int64 + Created int64 +} + +type IdCreated struct { + Id int64 + Created int64 +} + +type IdCreatedExternal struct { + IdCreated + External int64 +} + +type WithStringPk struct { + Id string + Name string +} + +type CustomStringType string + +type TypeConversionExample struct { + Id int64 + PersonJSON Person + Name CustomStringType +} + +type PersonUInt32 struct { + Id uint32 + Name string +} + +type PersonUInt64 struct { + Id uint64 + Name string +} + +type PersonUInt16 struct { + Id uint16 + Name string +} + +type WithEmbeddedStruct struct { + Id int64 + Names +} + +type WithEmbeddedStructBeforeAutoincrField struct { + Names + Id int64 +} + +type WithEmbeddedAutoincr struct { + WithEmbeddedStruct + MiddleName string +} + +type Names struct { + FirstName string + LastName string +} + +type UniqueColumns struct { + FirstName string + LastName string + City string + ZipCode int64 +} + +type SingleColumnTable struct { + SomeId string +} + +type CustomDate struct { + time.Time +} + +type WithCustomDate struct { + Id int64 + Added CustomDate +} + +type WithNullTime struct { + Id int64 + Time NullTime +} + +type testTypeConverter struct{} + +func (me testTypeConverter) ToDb(val interface{}) (interface{}, error) { + + switch t := val.(type) { + case Person: + b, err := json.Marshal(t) + if err != nil { + return "", err + } + return string(b), nil + case CustomStringType: + return string(t), nil + case CustomDate: + return t.Time, nil + } + + return val, nil +} + +func (me testTypeConverter) FromDb(target interface{}) (CustomScanner, bool) { + switch target.(type) { + case *Person: + binder := func(holder, target interface{}) error { + s, ok := holder.(*string) + if !ok { + return errors.New("FromDb: Unable to convert Person to *string") + } + b := []byte(*s) + return json.Unmarshal(b, target) + } + return CustomScanner{new(string), target, binder}, true + case *CustomStringType: + binder := func(holder, target interface{}) error { + s, ok := holder.(*string) + if !ok { + return errors.New("FromDb: Unable to convert CustomStringType to *string") + } + st, ok := target.(*CustomStringType) + if !ok { + return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomStringType: ", reflect.TypeOf(target))) + } + *st = CustomStringType(*s) + return nil + } + return CustomScanner{new(string), target, binder}, true + case *CustomDate: + binder := func(holder, target interface{}) error { + t, ok := holder.(*time.Time) + if !ok { + return errors.New("FromDb: Unable to convert CustomDate to *time.Time") + } + dateTarget, ok := target.(*CustomDate) + if !ok { + return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomDate: ", reflect.TypeOf(target))) + } + dateTarget.Time = *t + return nil + } + return CustomScanner{new(time.Time), target, binder}, true + } + + return CustomScanner{}, false +} + +func (p *Person) PreInsert(s SqlExecutor) error { + p.Created = time.Now().UnixNano() + p.Updated = p.Created + if p.FName == "badname" { + return fmt.Errorf("Invalid name: %s", p.FName) + } + return nil +} + +func (p *Person) PostInsert(s SqlExecutor) error { + p.LName = "postinsert" + return nil +} + +func (p *Person) PreUpdate(s SqlExecutor) error { + p.FName = "preupdate" + return nil +} + +func (p *Person) PostUpdate(s SqlExecutor) error { + p.LName = "postupdate" + return nil +} + +func (p *Person) PreDelete(s SqlExecutor) error { + p.FName = "predelete" + return nil +} + +func (p *Person) PostDelete(s SqlExecutor) error { + p.LName = "postdelete" + return nil +} + +func (p *Person) PostGet(s SqlExecutor) error { + p.LName = "postget" + return nil +} + +type PersistentUser struct { + Key int32 + Id string + PassedTraining bool +} + +func TestCreateTablesIfNotExists(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + err := dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error(err) + } +} + +func TestTruncateTables(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + err := dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error(err) + } + + // Insert some data + p1 := &Person{0, 0, 0, "Bob", "Smith", 0} + dbmap.Insert(p1) + inv := &Invoice{0, 0, 1, "my invoice", 0, true} + dbmap.Insert(inv) + + err = dbmap.TruncateTables() + if err != nil { + t.Error(err) + } + + // Make sure all rows are deleted + rows, _ := dbmap.Select(Person{}, "SELECT * FROM person_test") + if len(rows) != 0 { + t.Errorf("Expected 0 person rows, got %d", len(rows)) + } + rows, _ = dbmap.Select(Invoice{}, "SELECT * FROM invoice_test") + if len(rows) != 0 { + t.Errorf("Expected 0 invoice rows, got %d", len(rows)) + } +} + +func TestCustomDateType(t *testing.T) { + dbmap := newDbMap() + dbmap.TypeConverter = testTypeConverter{} + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTable(WithCustomDate{}).SetKeys(true, "Id") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + test1 := &WithCustomDate{Added: CustomDate{Time: time.Now().Truncate(time.Second)}} + err = dbmap.Insert(test1) + if err != nil { + t.Errorf("Could not insert struct with custom date field: %s", err) + t.FailNow() + } + // Unfortunately, the mysql driver doesn't handle time.Time + // values properly during Get(). I can't find a way to work + // around that problem - every other type that I've tried is just + // silently converted. time.Time is the only type that causes + // the issue that this test checks for. As such, if the driver is + // mysql, we'll just skip the rest of this test. + if _, driver := dialectAndDriver(); driver == "mysql" { + t.Skip("TestCustomDateType can't run Get() with the mysql driver; skipping the rest of this test...") + } + result, err := dbmap.Get(new(WithCustomDate), test1.Id) + if err != nil { + t.Errorf("Could not get struct with custom date field: %s", err) + t.FailNow() + } + test2 := result.(*WithCustomDate) + if test2.Added.UTC() != test1.Added.UTC() { + t.Errorf("Custom dates do not match: %v != %v", test2.Added.UTC(), test1.Added.UTC()) + } +} + +func TestUIntPrimaryKey(t *testing.T) { + dbmap := newDbMap() + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTable(PersonUInt64{}).SetKeys(true, "Id") + dbmap.AddTable(PersonUInt32{}).SetKeys(true, "Id") + dbmap.AddTable(PersonUInt16{}).SetKeys(true, "Id") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + p1 := &PersonUInt64{0, "name1"} + p2 := &PersonUInt32{0, "name2"} + p3 := &PersonUInt16{0, "name3"} + err = dbmap.Insert(p1, p2, p3) + if err != nil { + t.Error(err) + } + if p1.Id != 1 { + t.Errorf("%d != 1", p1.Id) + } + if p2.Id != 1 { + t.Errorf("%d != 1", p2.Id) + } + if p3.Id != 1 { + t.Errorf("%d != 1", p3.Id) + } +} + +func TestSetUniqueTogether(t *testing.T) { + dbmap := newDbMap() + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTable(UniqueColumns{}).SetUniqueTogether("FirstName", "LastName").SetUniqueTogether("City", "ZipCode") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + n1 := &UniqueColumns{"Steve", "Jobs", "Cupertino", 95014} + err = dbmap.Insert(n1) + if err != nil { + t.Error(err) + } + + // Should fail because of the first constraint + n2 := &UniqueColumns{"Steve", "Jobs", "Sunnyvale", 94085} + err = dbmap.Insert(n2) + if err == nil { + t.Error(err) + } + // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL + errLower := strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { + t.Error(err) + } + + // Should also fail because of the second unique-together + n3 := &UniqueColumns{"Steve", "Wozniak", "Cupertino", 95014} + err = dbmap.Insert(n3) + if err == nil { + t.Error(err) + } + // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL + errLower = strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { + t.Error(err) + } + + // This one should finally succeed + n4 := &UniqueColumns{"Steve", "Wozniak", "Sunnyvale", 94085} + err = dbmap.Insert(n4) + if err != nil { + t.Error(err) + } +} + +func TestPersistentUser(t *testing.T) { + dbmap := newDbMap() + dbmap.Exec("drop table if exists PersistentUser") + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") + table.ColMap("Key").Rename("mykey") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + pu := &PersistentUser{43, "33r", false} + err = dbmap.Insert(pu) + if err != nil { + panic(err) + } + + // prove we can pass a pointer into Get + pu2, err := dbmap.Get(pu, pu.Key) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(pu, pu2) { + t.Errorf("%v!=%v", pu, pu2) + } + + arr, err := dbmap.Select(pu, "select * from PersistentUser") + if err != nil { + panic(err) + } + if !reflect.DeepEqual(pu, arr[0]) { + t.Errorf("%v!=%v", pu, arr[0]) + } + + // prove we can get the results back in a slice + var puArr []*PersistentUser + _, err = dbmap.Select(&puArr, "select * from PersistentUser") + if err != nil { + panic(err) + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu, puArr[0]) { + t.Errorf("%v!=%v", pu, puArr[0]) + } + + // prove we can get the results back in a non-pointer slice + var puValues []PersistentUser + _, err = dbmap.Select(&puValues, "select * from PersistentUser") + if err != nil { + panic(err) + } + if len(puValues) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(*pu, puValues[0]) { + t.Errorf("%v!=%v", *pu, puValues[0]) + } + + // prove we can get the results back in a string slice + var idArr []*string + _, err = dbmap.Select(&idArr, "select Id from PersistentUser") + if err != nil { + panic(err) + } + if len(idArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.Id, *idArr[0]) { + t.Errorf("%v!=%v", pu.Id, *idArr[0]) + } + + // prove we can get the results back in an int slice + var keyArr []*int32 + _, err = dbmap.Select(&keyArr, "select mykey from PersistentUser") + if err != nil { + panic(err) + } + if len(keyArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.Key, *keyArr[0]) { + t.Errorf("%v!=%v", pu.Key, *keyArr[0]) + } + + // prove we can get the results back in a bool slice + var passedArr []*bool + _, err = dbmap.Select(&passedArr, "select PassedTraining from PersistentUser") + if err != nil { + panic(err) + } + if len(passedArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.PassedTraining, *passedArr[0]) { + t.Errorf("%v!=%v", pu.PassedTraining, *passedArr[0]) + } + + // prove we can get the results back in a non-pointer slice + var stringArr []string + _, err = dbmap.Select(&stringArr, "select Id from PersistentUser") + if err != nil { + panic(err) + } + if len(stringArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.Id, stringArr[0]) { + t.Errorf("%v!=%v", pu.Id, stringArr[0]) + } +} + +func TestNamedQueryMap(t *testing.T) { + dbmap := newDbMap() + dbmap.Exec("drop table if exists PersistentUser") + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") + table.ColMap("Key").Rename("mykey") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + pu := &PersistentUser{43, "33r", false} + pu2 := &PersistentUser{500, "abc", false} + err = dbmap.Insert(pu, pu2) + if err != nil { + panic(err) + } + + // Test simple case + var puArr []*PersistentUser + _, err = dbmap.Select(&puArr, "select * from PersistentUser where mykey = :Key", map[string]interface{}{ + "Key": 43, + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu, puArr[0]) { + t.Errorf("%v!=%v", pu, puArr[0]) + } + + // Test more specific map value type is ok + puArr = nil + _, err = dbmap.Select(&puArr, "select * from PersistentUser where mykey = :Key", map[string]int{ + "Key": 43, + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + + // Test multiple parameters set. + puArr = nil + _, err = dbmap.Select(&puArr, ` +select * from PersistentUser + where mykey = :Key + and PassedTraining = :PassedTraining + and Id = :Id`, map[string]interface{}{ + "Key": 43, + "PassedTraining": false, + "Id": "33r", + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + + // Test colon within a non-key string + // Test having extra, unused properties in the map. + puArr = nil + _, err = dbmap.Select(&puArr, ` +select * from PersistentUser + where mykey = :Key + and Id != 'abc:def'`, map[string]interface{}{ + "Key": 43, + "PassedTraining": false, + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + + // Test to delete with Exec and named params. + result, err := dbmap.Exec("delete from PersistentUser where mykey = :Key", map[string]interface{}{ + "Key": 43, + }) + count, err := result.RowsAffected() + if err != nil { + t.Errorf("Failed to exec: %s", err) + t.FailNow() + } + if count != 1 { + t.Errorf("Expected 1 persistentuser to be deleted, but %d deleted", count) + } +} + +func TestNamedQueryStruct(t *testing.T) { + dbmap := newDbMap() + dbmap.Exec("drop table if exists PersistentUser") + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") + table.ColMap("Key").Rename("mykey") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + pu := &PersistentUser{43, "33r", false} + pu2 := &PersistentUser{500, "abc", false} + err = dbmap.Insert(pu, pu2) + if err != nil { + panic(err) + } + + // Test select self + var puArr []*PersistentUser + _, err = dbmap.Select(&puArr, ` +select * from PersistentUser + where mykey = :Key + and PassedTraining = :PassedTraining + and Id = :Id`, pu) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu, puArr[0]) { + t.Errorf("%v!=%v", pu, puArr[0]) + } + + // Test delete self. + result, err := dbmap.Exec(` +delete from PersistentUser + where mykey = :Key + and PassedTraining = :PassedTraining + and Id = :Id`, pu) + count, err := result.RowsAffected() + if err != nil { + t.Errorf("Failed to exec: %s", err) + t.FailNow() + } + if count != 1 { + t.Errorf("Expected 1 persistentuser to be deleted, but %d deleted", count) + } +} + +// Ensure that the slices containing SQL results are non-nil when the result set is empty. +func TestReturnsNonNilSlice(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + noResultsSQL := "select * from invoice_test where id=99999" + var r1 []*Invoice + _rawselect(dbmap, &r1, noResultsSQL) + if r1 == nil { + t.Errorf("r1==nil") + } + + r2 := _rawselect(dbmap, Invoice{}, noResultsSQL) + if r2 == nil { + t.Errorf("r2==nil") + } +} + +func TestOverrideVersionCol(t *testing.T) { + dbmap := newDbMap() + t1 := dbmap.AddTable(InvoicePersonView{}).SetKeys(false, "InvoiceId", "PersonId") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + c1 := t1.SetVersionCol("LegacyVersion") + if c1.ColumnName != "LegacyVersion" { + t.Errorf("Wrong col returned: %v", c1) + } + + ipv := &InvoicePersonView{1, 2, "memo", "fname", 0} + _update(dbmap, ipv) + if ipv.LegacyVersion != 1 { + t.Errorf("LegacyVersion not updated: %d", ipv.LegacyVersion) + } +} + +func TestOptimisticLocking(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "Bob", "Smith", 0} + dbmap.Insert(p1) // Version is now 1 + if p1.Version != 1 { + t.Errorf("Insert didn't incr Version: %d != %d", 1, p1.Version) + return + } + if p1.Id == 0 { + t.Errorf("Insert didn't return a generated PK") + return + } + + obj, err := dbmap.Get(Person{}, p1.Id) + if err != nil { + panic(err) + } + p2 := obj.(*Person) + p2.LName = "Edwards" + dbmap.Update(p2) // Version is now 2 + if p2.Version != 2 { + t.Errorf("Update didn't incr Version: %d != %d", 2, p2.Version) + } + + p1.LName = "Howard" + count, err := dbmap.Update(p1) + if _, ok := err.(OptimisticLockError); !ok { + t.Errorf("update - Expected OptimisticLockError, got: %v", err) + } + if count != -1 { + t.Errorf("update - Expected -1 count, got: %d", count) + } + + count, err = dbmap.Delete(p1) + if _, ok := err.(OptimisticLockError); !ok { + t.Errorf("delete - Expected OptimisticLockError, got: %v", err) + } + if count != -1 { + t.Errorf("delete - Expected -1 count, got: %d", count) + } +} + +// what happens if a legacy table has a null value? +func TestDoubleAddTable(t *testing.T) { + dbmap := newDbMap() + t1 := dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id") + t2 := dbmap.AddTable(TableWithNull{}) + if t1 != t2 { + t.Errorf("%v != %v", t1, t2) + } +} + +// what happens if a legacy table has a null value? +func TestNullValues(t *testing.T) { + dbmap := initDbMapNulls() + defer dropAndClose(dbmap) + + // insert a row directly + _rawexec(dbmap, "insert into TableWithNull values (10, null, "+ + "null, null, null, null)") + + // try to load it + expected := &TableWithNull{Id: 10} + obj := _get(dbmap, TableWithNull{}, 10) + t1 := obj.(*TableWithNull) + if !reflect.DeepEqual(expected, t1) { + t.Errorf("%v != %v", expected, t1) + } + + // update it + t1.Str = sql.NullString{"hi", true} + expected.Str = t1.Str + t1.Int64 = sql.NullInt64{999, true} + expected.Int64 = t1.Int64 + t1.Float64 = sql.NullFloat64{53.33, true} + expected.Float64 = t1.Float64 + t1.Bool = sql.NullBool{true, true} + expected.Bool = t1.Bool + t1.Bytes = []byte{1, 30, 31, 33} + expected.Bytes = t1.Bytes + _update(dbmap, t1) + + obj = _get(dbmap, TableWithNull{}, 10) + t1 = obj.(*TableWithNull) + if t1.Str.String != "hi" { + t.Errorf("%s != hi", t1.Str.String) + } + if !reflect.DeepEqual(expected, t1) { + t.Errorf("%v != %v", expected, t1) + } +} + +func TestColumnProps(t *testing.T) { + dbmap := newDbMap() + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id") + t1.ColMap("Created").Rename("date_created") + t1.ColMap("Updated").SetTransient(true) + t1.ColMap("Memo").SetMaxSize(10) + t1.ColMap("PersonId").SetUnique(true) + + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + // test transient + inv := &Invoice{0, 0, 1, "my invoice", 0, true} + _insert(dbmap, inv) + obj := _get(dbmap, Invoice{}, inv.Id) + inv = obj.(*Invoice) + if inv.Updated != 0 { + t.Errorf("Saved transient column 'Updated'") + } + + // test max size + inv.Memo = "this memo is too long" + err = dbmap.Insert(inv) + if err == nil { + t.Errorf("max size exceeded, but Insert did not fail.") + } + + // test unique - same person id + inv = &Invoice{0, 0, 1, "my invoice2", 0, false} + err = dbmap.Insert(inv) + if err == nil { + t.Errorf("same PersonId inserted, but Insert did not fail.") + } +} + +func TestRawSelect(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + _insert(dbmap, p1) + + inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, true} + _insert(dbmap, inv1) + + expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0} + + query := "select i.Id InvoiceId, p.Id PersonId, i.Memo, p.FName " + + "from invoice_test i, person_test p " + + "where i.PersonId = p.Id" + list := _rawselect(dbmap, InvoicePersonView{}, query) + if len(list) != 1 { + t.Errorf("len(list) != 1: %d", len(list)) + } else if !reflect.DeepEqual(expected, list[0]) { + t.Errorf("%v != %v", expected, list[0]) + } +} + +func TestHooks(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + _insert(dbmap, p1) + if p1.Created == 0 || p1.Updated == 0 { + t.Errorf("p1.PreInsert() didn't run: %v", p1) + } else if p1.LName != "postinsert" { + t.Errorf("p1.PostInsert() didn't run: %v", p1) + } + + obj := _get(dbmap, Person{}, p1.Id) + p1 = obj.(*Person) + if p1.LName != "postget" { + t.Errorf("p1.PostGet() didn't run: %v", p1) + } + + _update(dbmap, p1) + if p1.FName != "preupdate" { + t.Errorf("p1.PreUpdate() didn't run: %v", p1) + } else if p1.LName != "postupdate" { + t.Errorf("p1.PostUpdate() didn't run: %v", p1) + } + + var persons []*Person + bindVar := dbmap.Dialect.BindVar(0) + _rawselect(dbmap, &persons, "select * from person_test where id = "+bindVar, p1.Id) + if persons[0].LName != "postget" { + t.Errorf("p1.PostGet() didn't run after select: %v", p1) + } + + _del(dbmap, p1) + if p1.FName != "predelete" { + t.Errorf("p1.PreDelete() didn't run: %v", p1) + } else if p1.LName != "postdelete" { + t.Errorf("p1.PostDelete() didn't run: %v", p1) + } + + // Test error case + p2 := &Person{0, 0, 0, "badname", "", 0} + err := dbmap.Insert(p2) + if err == nil { + t.Errorf("p2.PreInsert() didn't return an error") + } +} + +func TestTransaction(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "t1", 0, true} + inv2 := &Invoice{0, 100, 200, "t2", 0, false} + + trans, err := dbmap.Begin() + if err != nil { + panic(err) + } + trans.Insert(inv1, inv2) + err = trans.Commit() + if err != nil { + panic(err) + } + + obj, err := dbmap.Get(Invoice{}, inv1.Id) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(inv1, obj) { + t.Errorf("%v != %v", inv1, obj) + } + obj, err = dbmap.Get(Invoice{}, inv2.Id) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(inv2, obj) { + t.Errorf("%v != %v", inv2, obj) + } +} + +func TestSavepoint(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "unpaid", 0, false} + + trans, err := dbmap.Begin() + if err != nil { + panic(err) + } + trans.Insert(inv1) + + var checkMemo = func(want string) { + memo, err := trans.SelectStr("select memo from invoice_test") + if err != nil { + panic(err) + } + if memo != want { + t.Errorf("%q != %q", want, memo) + } + } + checkMemo("unpaid") + + err = trans.Savepoint("foo") + if err != nil { + panic(err) + } + checkMemo("unpaid") + + inv1.Memo = "paid" + _, err = trans.Update(inv1) + if err != nil { + panic(err) + } + checkMemo("paid") + + err = trans.RollbackToSavepoint("foo") + if err != nil { + panic(err) + } + checkMemo("unpaid") + + err = trans.Rollback() + if err != nil { + panic(err) + } +} + +func TestMultiple(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "a", 0, false} + inv2 := &Invoice{0, 100, 200, "b", 0, true} + _insert(dbmap, inv1, inv2) + + inv1.Memo = "c" + inv2.Memo = "d" + _update(dbmap, inv1, inv2) + + count := _del(dbmap, inv1, inv2) + if count != 2 { + t.Errorf("%d != 2", count) + } +} + +func TestCrud(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + inv := &Invoice{0, 100, 200, "first order", 0, true} + testCrudInternal(t, dbmap, inv) + + invtag := &InvoiceTag{0, 300, 400, "some order", 33, false} + testCrudInternal(t, dbmap, invtag) + + foo := &AliasTransientField{BarStr: "some bar"} + testCrudInternal(t, dbmap, foo) +} + +func testCrudInternal(t *testing.T, dbmap *DbMap, val testable) { + table, _, err := dbmap.tableForPointer(val, false) + if err != nil { + t.Errorf("couldn't call TableFor: val=%v err=%v", val, err) + } + + _, err = dbmap.Exec("delete from " + table.TableName) + if err != nil { + t.Errorf("couldn't delete rows from: val=%v err=%v", val, err) + } + + // INSERT row + _insert(dbmap, val) + if val.GetId() == 0 { + t.Errorf("val.GetId() was not set on INSERT") + return + } + + // SELECT row + val2 := _get(dbmap, val, val.GetId()) + if !reflect.DeepEqual(val, val2) { + t.Errorf("%v != %v", val, val2) + } + + // UPDATE row and SELECT + val.Rand() + count := _update(dbmap, val) + if count != 1 { + t.Errorf("update 1 != %d", count) + } + val2 = _get(dbmap, val, val.GetId()) + if !reflect.DeepEqual(val, val2) { + t.Errorf("%v != %v", val, val2) + } + + // Select * + rows, err := dbmap.Select(val, "select * from "+table.TableName) + if err != nil { + t.Errorf("couldn't select * from %s err=%v", table.TableName, err) + } else if len(rows) != 1 { + t.Errorf("unexpected row count in %s: %d", table.TableName, len(rows)) + } else if !reflect.DeepEqual(val, rows[0]) { + t.Errorf("select * result: %v != %v", val, rows[0]) + } + + // DELETE row + deleted := _del(dbmap, val) + if deleted != 1 { + t.Errorf("Did not delete row with Id: %d", val.GetId()) + return + } + + // VERIFY deleted + val2 = _get(dbmap, val, val.GetId()) + if val2 != nil { + t.Errorf("Found invoice with id: %d after Delete()", val.GetId()) + } +} + +func TestWithIgnoredColumn(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + ic := &WithIgnoredColumn{-1, 0, 1} + _insert(dbmap, ic) + expected := &WithIgnoredColumn{0, 1, 1} + ic2 := _get(dbmap, WithIgnoredColumn{}, ic.Id).(*WithIgnoredColumn) + + if !reflect.DeepEqual(expected, ic2) { + t.Errorf("%v != %v", expected, ic2) + } + if _del(dbmap, ic) != 1 { + t.Errorf("Did not delete row with Id: %d", ic.Id) + return + } + if _get(dbmap, WithIgnoredColumn{}, ic.Id) != nil { + t.Errorf("Found id: %d after Delete()", ic.Id) + } +} + +func TestTypeConversionExample(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p := Person{FName: "Bob", LName: "Smith"} + tc := &TypeConversionExample{-1, p, CustomStringType("hi")} + _insert(dbmap, tc) + + expected := &TypeConversionExample{1, p, CustomStringType("hi")} + tc2 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) + if !reflect.DeepEqual(expected, tc2) { + t.Errorf("tc2 %v != %v", expected, tc2) + } + + tc2.Name = CustomStringType("hi2") + tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"} + _update(dbmap, tc2) + + expected = &TypeConversionExample{1, tc2.PersonJSON, CustomStringType("hi2")} + tc3 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) + if !reflect.DeepEqual(expected, tc3) { + t.Errorf("tc3 %v != %v", expected, tc3) + } + + if _del(dbmap, tc) != 1 { + t.Errorf("Did not delete row with Id: %d", tc.Id) + } + +} + +func TestWithEmbeddedStruct(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + es := &WithEmbeddedStruct{-1, Names{FirstName: "Alice", LastName: "Smith"}} + _insert(dbmap, es) + expected := &WithEmbeddedStruct{1, Names{FirstName: "Alice", LastName: "Smith"}} + es2 := _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + es2.FirstName = "Bob" + expected.FirstName = "Bob" + _update(dbmap, es2) + es2 = _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + ess := _rawselect(dbmap, WithEmbeddedStruct{}, "select * from embedded_struct_test") + if !reflect.DeepEqual(es2, ess[0]) { + t.Errorf("%v != %v", es2, ess[0]) + } +} + +func TestWithEmbeddedStructBeforeAutoincr(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + esba := &WithEmbeddedStructBeforeAutoincrField{Names: Names{FirstName: "Alice", LastName: "Smith"}} + _insert(dbmap, esba) + var expectedAutoincrId int64 = 1 + if esba.Id != expectedAutoincrId { + t.Errorf("%d != %d", expectedAutoincrId, esba.Id) + } +} + +func TestWithEmbeddedAutoincr(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + esa := &WithEmbeddedAutoincr{ + WithEmbeddedStruct: WithEmbeddedStruct{Names: Names{FirstName: "Alice", LastName: "Smith"}}, + MiddleName: "Rose", + } + _insert(dbmap, esa) + var expectedAutoincrId int64 = 1 + if esa.Id != expectedAutoincrId { + t.Errorf("%d != %d", expectedAutoincrId, esa.Id) + } +} + +func TestSelectVal(t *testing.T) { + dbmap := initDbMapNulls() + defer dropAndClose(dbmap) + + bindVar := dbmap.Dialect.BindVar(0) + + t1 := TableWithNull{Str: sql.NullString{"abc", true}, + Int64: sql.NullInt64{78, true}, + Float64: sql.NullFloat64{32.2, true}, + Bool: sql.NullBool{true, true}, + Bytes: []byte("hi")} + _insert(dbmap, &t1) + + // SelectInt + i64 := selectInt(dbmap, "select Int64 from TableWithNull where Str='abc'") + if i64 != 78 { + t.Errorf("int64 %d != 78", i64) + } + i64 = selectInt(dbmap, "select count(*) from TableWithNull") + if i64 != 1 { + t.Errorf("int64 count %d != 1", i64) + } + i64 = selectInt(dbmap, "select count(*) from TableWithNull where Str="+bindVar, "asdfasdf") + if i64 != 0 { + t.Errorf("int64 no rows %d != 0", i64) + } + + // SelectNullInt + n := selectNullInt(dbmap, "select Int64 from TableWithNull where Str='notfound'") + if !reflect.DeepEqual(n, sql.NullInt64{0, false}) { + t.Errorf("nullint %v != 0,false", n) + } + + n = selectNullInt(dbmap, "select Int64 from TableWithNull where Str='abc'") + if !reflect.DeepEqual(n, sql.NullInt64{78, true}) { + t.Errorf("nullint %v != 78, true", n) + } + + // SelectFloat + f64 := selectFloat(dbmap, "select Float64 from TableWithNull where Str='abc'") + if f64 != 32.2 { + t.Errorf("float64 %d != 32.2", f64) + } + f64 = selectFloat(dbmap, "select min(Float64) from TableWithNull") + if f64 != 32.2 { + t.Errorf("float64 min %d != 32.2", f64) + } + f64 = selectFloat(dbmap, "select count(*) from TableWithNull where Str="+bindVar, "asdfasdf") + if f64 != 0 { + t.Errorf("float64 no rows %d != 0", f64) + } + + // SelectNullFloat + nf := selectNullFloat(dbmap, "select Float64 from TableWithNull where Str='notfound'") + if !reflect.DeepEqual(nf, sql.NullFloat64{0, false}) { + t.Errorf("nullfloat %v != 0,false", nf) + } + + nf = selectNullFloat(dbmap, "select Float64 from TableWithNull where Str='abc'") + if !reflect.DeepEqual(nf, sql.NullFloat64{32.2, true}) { + t.Errorf("nullfloat %v != 32.2, true", nf) + } + + // SelectStr + s := selectStr(dbmap, "select Str from TableWithNull where Int64="+bindVar, 78) + if s != "abc" { + t.Errorf("s %s != abc", s) + } + s = selectStr(dbmap, "select Str from TableWithNull where Str='asdfasdf'") + if s != "" { + t.Errorf("s no rows %s != ''", s) + } + + // SelectNullStr + ns := selectNullStr(dbmap, "select Str from TableWithNull where Int64="+bindVar, 78) + if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) { + t.Errorf("nullstr %v != abc,true", ns) + } + ns = selectNullStr(dbmap, "select Str from TableWithNull where Str='asdfasdf'") + if !reflect.DeepEqual(ns, sql.NullString{"", false}) { + t.Errorf("nullstr no rows %v != '',false", ns) + } + + // SelectInt/Str with named parameters + i64 = selectInt(dbmap, "select Int64 from TableWithNull where Str=:abc", map[string]string{"abc": "abc"}) + if i64 != 78 { + t.Errorf("int64 %d != 78", i64) + } + ns = selectNullStr(dbmap, "select Str from TableWithNull where Int64=:num", map[string]int{"num": 78}) + if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) { + t.Errorf("nullstr %v != abc,true", ns) + } +} + +func TestVersionMultipleRows(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + persons := []*Person{ + &Person{0, 0, 0, "Bob", "Smith", 0}, + &Person{0, 0, 0, "Jane", "Smith", 0}, + &Person{0, 0, 0, "Mike", "Smith", 0}, + } + + _insert(dbmap, persons[0], persons[1], persons[2]) + + for x, p := range persons { + if p.Version != 1 { + t.Errorf("person[%d].Version != 1: %d", x, p.Version) + } + } +} + +func TestWithStringPk(t *testing.T) { + dbmap := newDbMap() + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTableWithName(WithStringPk{}, "string_pk_test").SetKeys(true, "Id") + _, err := dbmap.Exec("create table string_pk_test (Id varchar(255), Name varchar(255));") + if err != nil { + t.Errorf("couldn't create string_pk_test: %v", err) + } + defer dropAndClose(dbmap) + + row := &WithStringPk{"1", "foo"} + err = dbmap.Insert(row) + if err == nil { + t.Errorf("Expected error when inserting into table w/non Int PK and autoincr set true") + } +} + +// TestSqlExecutorInterfaceSelects ensures that all DbMap methods starting with Select... +// are also exposed in the SqlExecutor interface. Select... functions can always +// run on Pre/Post hooks. +func TestSqlExecutorInterfaceSelects(t *testing.T) { + dbMapType := reflect.TypeOf(&DbMap{}) + sqlExecutorType := reflect.TypeOf((*SqlExecutor)(nil)).Elem() + numDbMapMethods := dbMapType.NumMethod() + for i := 0; i < numDbMapMethods; i += 1 { + dbMapMethod := dbMapType.Method(i) + if !strings.HasPrefix(dbMapMethod.Name, "Select") { + continue + } + if _, found := sqlExecutorType.MethodByName(dbMapMethod.Name); !found { + t.Errorf("Method %s is defined on DbMap but not implemented in SqlExecutor", + dbMapMethod.Name) + } + } +} + +func TestNullTime(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + // if time is null + ent := &WithNullTime{ + Id: 0, + Time: NullTime{ + Valid: false, + }} + err := dbmap.Insert(ent) + if err != nil { + t.Error("failed insert on %s", err.Error()) + } + err = dbmap.SelectOne(ent, `select * from nulltime_test where Id=:Id`, map[string]interface{}{ + "Id": ent.Id, + }) + if err != nil { + t.Error("failed select on %s", err.Error()) + } + if ent.Time.Valid { + t.Error("NullTime returns valid but expected null.") + } + + // if time is not null + ts, err := time.Parse(time.Stamp, "Jan 2 15:04:05") + ent = &WithNullTime{ + Id: 1, + Time: NullTime{ + Valid: true, + Time: ts, + }} + err = dbmap.Insert(ent) + if err != nil { + t.Error("failed insert on %s", err.Error()) + } + err = dbmap.SelectOne(ent, `select * from nulltime_test where Id=:Id`, map[string]interface{}{ + "Id": ent.Id, + }) + if err != nil { + t.Error("failed select on %s", err.Error()) + } + if !ent.Time.Valid { + t.Error("NullTime returns invalid but expected valid.") + } + if ent.Time.Time.UTC() != ts.UTC() { + t.Errorf("expect %v but got %v.", ts, ent.Time.Time) + } + + return +} + +type WithTime struct { + Id int64 + Time time.Time +} + +type Times struct { + One time.Time + Two time.Time +} + +type EmbeddedTime struct { + Id string + Times +} + +func parseTimeOrPanic(format, date string) time.Time { + t1, err := time.Parse(format, date) + if err != nil { + panic(err) + } + return t1 +} + +// TODO: re-enable next two tests when this is merged: +// https://github.com/ziutek/mymysql/pull/77 +// +// This test currently fails w/MySQL b/c tz info is lost +func testWithTime(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + t1 := parseTimeOrPanic("2006-01-02 15:04:05 -0700 MST", + "2013-08-09 21:30:43 +0800 CST") + w1 := WithTime{1, t1} + _insert(dbmap, &w1) + + obj := _get(dbmap, WithTime{}, w1.Id) + w2 := obj.(*WithTime) + if w1.Time.UnixNano() != w2.Time.UnixNano() { + t.Errorf("%v != %v", w1, w2) + } +} + +// See: https://github.com/go-gorp/gorp/issues/86 +func testEmbeddedTime(t *testing.T) { + dbmap := newDbMap() + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTable(EmbeddedTime{}).SetKeys(false, "Id") + defer dropAndClose(dbmap) + err := dbmap.CreateTables() + if err != nil { + t.Fatal(err) + } + + time1 := parseTimeOrPanic("2006-01-02 15:04:05", "2013-08-09 21:30:43") + + t1 := &EmbeddedTime{Id: "abc", Times: Times{One: time1, Two: time1.Add(10 * time.Second)}} + _insert(dbmap, t1) + + x := _get(dbmap, EmbeddedTime{}, t1.Id) + t2, _ := x.(*EmbeddedTime) + if t1.One.UnixNano() != t2.One.UnixNano() || t1.Two.UnixNano() != t2.Two.UnixNano() { + t.Errorf("%v != %v", t1, t2) + } +} + +func TestWithTimeSelect(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + halfhourago := time.Now().UTC().Add(-30 * time.Minute) + + w1 := WithTime{1, halfhourago.Add(time.Minute * -1)} + w2 := WithTime{2, halfhourago.Add(time.Second)} + _insert(dbmap, &w1, &w2) + + var caseIds []int64 + _, err := dbmap.Select(&caseIds, "SELECT id FROM time_test WHERE Time < "+dbmap.Dialect.BindVar(0), halfhourago) + + if err != nil { + t.Error(err) + } + if len(caseIds) != 1 { + t.Errorf("%d != 1", len(caseIds)) + } + if caseIds[0] != w1.Id { + t.Errorf("%d != %d", caseIds[0], w1.Id) + } +} + +func TestInvoicePersonView(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + // Create some rows + p1 := &Person{0, 0, 0, "bob", "smith", 0} + dbmap.Insert(p1) + + // notice how we can wire up p1.Id to the invoice easily + inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, false} + dbmap.Insert(inv1) + + // Run your query + query := "select i.Id InvoiceId, p.Id PersonId, i.Memo, p.FName " + + "from invoice_test i, person_test p " + + "where i.PersonId = p.Id" + + // pass a slice of pointers to Select() + // this avoids the need to type assert after the query is run + var list []*InvoicePersonView + _, err := dbmap.Select(&list, query) + if err != nil { + panic(err) + } + + // this should test true + expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0} + if !reflect.DeepEqual(list[0], expected) { + t.Errorf("%v != %v", list[0], expected) + } +} + +func TestQuoteTableNames(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + quotedTableName := dbmap.Dialect.QuoteField("person_test") + + // Use a buffer to hold the log to check generated queries + logBuffer := &bytes.Buffer{} + dbmap.TraceOn("", log.New(logBuffer, "gorptest:", log.Lmicroseconds)) + + // Create some rows + p1 := &Person{0, 0, 0, "bob", "smith", 0} + errorTemplate := "Expected quoted table name %v in query but didn't find it" + + // Check if Insert quotes the table name + id := dbmap.Insert(p1) + if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) { + t.Errorf(errorTemplate, quotedTableName) + } + logBuffer.Reset() + + // Check if Get quotes the table name + dbmap.Get(Person{}, id) + if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) { + t.Errorf(errorTemplate, quotedTableName) + } + logBuffer.Reset() +} + +func TestSelectTooManyCols(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + p2 := &Person{0, 0, 0, "jane", "doe", 0} + _insert(dbmap, p1) + _insert(dbmap, p2) + + obj := _get(dbmap, Person{}, p1.Id) + p1 = obj.(*Person) + obj = _get(dbmap, Person{}, p2.Id) + p2 = obj.(*Person) + + params := map[string]interface{}{ + "Id": p1.Id, + } + + var p3 FNameOnly + err := dbmap.SelectOne(&p3, "select * from person_test where Id=:Id", params) + if err != nil { + if !NonFatalError(err) { + t.Error(err) + } + } else { + t.Errorf("Non-fatal error expected") + } + + if p1.FName != p3.FName { + t.Errorf("%v != %v", p1.FName, p3.FName) + } + + var pSlice []FNameOnly + _, err = dbmap.Select(&pSlice, "select * from person_test order by fname asc") + if err != nil { + if !NonFatalError(err) { + t.Error(err) + } + } else { + t.Errorf("Non-fatal error expected") + } + + if p1.FName != pSlice[0].FName { + t.Errorf("%v != %v", p1.FName, pSlice[0].FName) + } + if p2.FName != pSlice[1].FName { + t.Errorf("%v != %v", p2.FName, pSlice[1].FName) + } +} + +func TestSelectSingleVal(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + _insert(dbmap, p1) + + obj := _get(dbmap, Person{}, p1.Id) + p1 = obj.(*Person) + + params := map[string]interface{}{ + "Id": p1.Id, + } + + var p2 Person + err := dbmap.SelectOne(&p2, "select * from person_test where Id=:Id", params) + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(p1, &p2) { + t.Errorf("%v != %v", p1, &p2) + } + + // verify SelectOne allows non-struct holders + var s string + err = dbmap.SelectOne(&s, "select FName from person_test where Id=:Id", params) + if err != nil { + t.Error(err) + } + if s != "bob" { + t.Error("Expected bob but got: " + s) + } + + // verify SelectOne requires pointer receiver + err = dbmap.SelectOne(s, "select FName from person_test where Id=:Id", params) + if err == nil { + t.Error("SelectOne should have returned error for non-pointer holder") + } + + // verify SelectOne works with uninitialized pointers + var p3 *Person + err = dbmap.SelectOne(&p3, "select * from person_test where Id=:Id", params) + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(p1, p3) { + t.Errorf("%v != %v", p1, p3) + } + + // verify that the receiver is still nil if nothing was found + var p4 *Person + dbmap.SelectOne(&p3, "select * from person_test where 2<1 AND Id=:Id", params) + if p4 != nil { + t.Error("SelectOne should not have changed a nil receiver when no rows were found") + } + + // verify that the error is set to sql.ErrNoRows if not found + err = dbmap.SelectOne(&p2, "select * from person_test where Id=:Id", map[string]interface{}{ + "Id": -2222, + }) + if err == nil || err != sql.ErrNoRows { + t.Error("SelectOne should have returned an sql.ErrNoRows") + } + + _insert(dbmap, &Person{0, 0, 0, "bob", "smith", 0}) + err = dbmap.SelectOne(&p2, "select * from person_test where Fname='bob'") + if err == nil { + t.Error("Expected error when two rows found") + } + + // tests for #150 + var tInt int64 + var tStr string + var tBool bool + var tFloat float64 + primVals := []interface{}{tInt, tStr, tBool, tFloat} + for _, prim := range primVals { + err = dbmap.SelectOne(&prim, "select * from person_test where Id=-123") + if err == nil || err != sql.ErrNoRows { + t.Error("primVals: SelectOne should have returned sql.ErrNoRows") + } + } +} + +func TestSelectAlias(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &IdCreatedExternal{IdCreated: IdCreated{Id: 1, Created: 3}, External: 2} + + // Insert using embedded IdCreated, which reflects the structure of the table + _insert(dbmap, &p1.IdCreated) + + // Select into IdCreatedExternal type, which includes some fields not present + // in id_created_test + var p2 IdCreatedExternal + err := dbmap.SelectOne(&p2, "select * from id_created_test where Id=1") + if err != nil { + t.Error(err) + } + if p2.Id != 1 || p2.Created != 3 || p2.External != 0 { + t.Error("Expected ignored field defaults to not set") + } + + // Prove that we can supply an aliased value in the select, and that it will + // automatically map to IdCreatedExternal.External + err = dbmap.SelectOne(&p2, "SELECT *, 1 AS external FROM id_created_test") + if err != nil { + t.Error(err) + } + if p2.External != 1 { + t.Error("Expected select as can map to exported field.") + } + + var rows *sql.Rows + var cols []string + rows, err = dbmap.Db.Query("SELECT * FROM id_created_test") + cols, err = rows.Columns() + if err != nil || len(cols) != 2 { + t.Error("Expected ignored column not created") + } +} + +func TestMysqlPanicIfDialectNotInitialized(t *testing.T) { + _, driver := dialectAndDriver() + // this test only applies to MySQL + if os.Getenv("GORP_TEST_DIALECT") != "mysql" { + return + } + + // The expected behaviour is to catch a panic. + // Here is the deferred function which will check if a panic has indeed occurred : + defer func() { + r := recover() + if r == nil { + t.Error("db.CreateTables() should panic if db is initialized with an incorrect MySQLDialect") + } + }() + + // invalid MySQLDialect : does not contain Engine or Encoding specification + dialect := MySQLDialect{} + db := &DbMap{Db: connect(driver), Dialect: dialect} + db.AddTableWithName(Invoice{}, "invoice") + // the following call should panic : + db.CreateTables() +} + +func TestSingleColumnKeyDbReturnsZeroRowsUpdatedOnPKChange(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + dbmap.AddTableWithName(SingleColumnTable{}, "single_column_table").SetKeys(false, "SomeId") + err := dbmap.DropTablesIfExists() + if err != nil { + t.Error("Drop tables failed") + } + err = dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error("Create tables failed") + } + err = dbmap.TruncateTables() + if err != nil { + t.Error("Truncate tables failed") + } + + sct := SingleColumnTable{ + SomeId: "A Unique Id String", + } + + count, err := dbmap.Update(&sct) + if err != nil { + t.Error(err) + } + if count != 0 { + t.Errorf("Expected 0 updated rows, got %d", count) + } + +} + +func TestPrepare(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "prepare-foo", 0, false} + inv2 := &Invoice{0, 100, 200, "prepare-bar", 0, false} + _insert(dbmap, inv1, inv2) + + bindVar0 := dbmap.Dialect.BindVar(0) + bindVar1 := dbmap.Dialect.BindVar(1) + stmt, err := dbmap.Prepare(fmt.Sprintf("UPDATE invoice_test SET Memo=%s WHERE Id=%s", bindVar0, bindVar1)) + if err != nil { + t.Error(err) + } + defer stmt.Close() + _, err = stmt.Exec("prepare-baz", inv1.Id) + if err != nil { + t.Error(err) + } + err = dbmap.SelectOne(inv1, "SELECT * from invoice_test WHERE Memo='prepare-baz'") + if err != nil { + t.Error(err) + } + + trans, err := dbmap.Begin() + if err != nil { + t.Error(err) + } + transStmt, err := trans.Prepare(fmt.Sprintf("UPDATE invoice_test SET IsPaid=%s WHERE Id=%s", bindVar0, bindVar1)) + if err != nil { + t.Error(err) + } + defer transStmt.Close() + _, err = transStmt.Exec(true, inv2.Id) + if err != nil { + t.Error(err) + } + err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE IsPaid=%s", bindVar0), true) + if err == nil || err != sql.ErrNoRows { + t.Error("SelectOne should have returned an sql.ErrNoRows") + } + err = trans.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE IsPaid=%s", bindVar0), true) + if err != nil { + t.Error(err) + } + err = trans.Commit() + if err != nil { + t.Error(err) + } + err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE IsPaid=%s", bindVar0), true) + if err != nil { + t.Error(err) + } +} + +func BenchmarkNativeCrud(b *testing.B) { + b.StopTimer() + dbmap := initDbMapBench() + defer dropAndClose(dbmap) + b.StartTimer() + + insert := "insert into invoice_test (Created, Updated, Memo, PersonId) values (?, ?, ?, ?)" + sel := "select Id, Created, Updated, Memo, PersonId from invoice_test where Id=?" + update := "update invoice_test set Created=?, Updated=?, Memo=?, PersonId=? where Id=?" + delete := "delete from invoice_test where Id=?" + + inv := &Invoice{0, 100, 200, "my memo", 0, false} + + for i := 0; i < b.N; i++ { + res, err := dbmap.Db.Exec(insert, inv.Created, inv.Updated, + inv.Memo, inv.PersonId) + if err != nil { + panic(err) + } + + newid, err := res.LastInsertId() + if err != nil { + panic(err) + } + inv.Id = newid + + row := dbmap.Db.QueryRow(sel, inv.Id) + err = row.Scan(&inv.Id, &inv.Created, &inv.Updated, &inv.Memo, + &inv.PersonId) + if err != nil { + panic(err) + } + + inv.Created = 1000 + inv.Updated = 2000 + inv.Memo = "my memo 2" + inv.PersonId = 3000 + + _, err = dbmap.Db.Exec(update, inv.Created, inv.Updated, inv.Memo, + inv.PersonId, inv.Id) + if err != nil { + panic(err) + } + + _, err = dbmap.Db.Exec(delete, inv.Id) + if err != nil { + panic(err) + } + } + +} + +func BenchmarkGorpCrud(b *testing.B) { + b.StopTimer() + dbmap := initDbMapBench() + defer dropAndClose(dbmap) + b.StartTimer() + + inv := &Invoice{0, 100, 200, "my memo", 0, true} + for i := 0; i < b.N; i++ { + err := dbmap.Insert(inv) + if err != nil { + panic(err) + } + + obj, err := dbmap.Get(Invoice{}, inv.Id) + if err != nil { + panic(err) + } + + inv2, ok := obj.(*Invoice) + if !ok { + panic(fmt.Sprintf("expected *Invoice, got: %v", obj)) + } + + inv2.Created = 1000 + inv2.Updated = 2000 + inv2.Memo = "my memo 2" + inv2.PersonId = 3000 + _, err = dbmap.Update(inv2) + if err != nil { + panic(err) + } + + _, err = dbmap.Delete(inv2) + if err != nil { + panic(err) + } + + } +} + +func initDbMapBench() *DbMap { + dbmap := newDbMap() + dbmap.Db.Exec("drop table if exists invoice_test") + dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + return dbmap +} + +func initDbMap() *DbMap { + dbmap := newDbMap() + dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") + dbmap.AddTableWithName(InvoiceTag{}, "invoice_tag_test").SetKeys(true, "myid") + dbmap.AddTableWithName(AliasTransientField{}, "alias_trans_field_test").SetKeys(true, "id") + dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id") + dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id").SetVersionCol("Version") + dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id") + dbmap.AddTableWithName(IdCreated{}, "id_created_test").SetKeys(true, "Id") + dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithEmbeddedStructBeforeAutoincrField{}, "embedded_struct_before_autoincr_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithEmbeddedAutoincr{}, "embedded_autoincr_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithTime{}, "time_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithNullTime{}, "nulltime_test").SetKeys(false, "Id") + dbmap.TypeConverter = testTypeConverter{} + err := dbmap.DropTablesIfExists() + if err != nil { + panic(err) + } + err = dbmap.CreateTables() + if err != nil { + panic(err) + } + + // See #146 and TestSelectAlias - this type is mapped to the same + // table as IdCreated, but includes an extra field that isn't in the table + dbmap.AddTableWithName(IdCreatedExternal{}, "id_created_test").SetKeys(true, "Id") + + return dbmap +} + +func initDbMapNulls() *DbMap { + dbmap := newDbMap() + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + return dbmap +} + +func newDbMap() *DbMap { + dialect, driver := dialectAndDriver() + dbmap := &DbMap{Db: connect(driver), Dialect: dialect} + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + return dbmap +} + +func dropAndClose(dbmap *DbMap) { + dbmap.DropTablesIfExists() + dbmap.Db.Close() +} + +func connect(driver string) *sql.DB { + dsn := os.Getenv("GORP_TEST_DSN") + if dsn == "" { + panic("GORP_TEST_DSN env variable is not set. Please see README.md") + } + + db, err := sql.Open(driver, dsn) + if err != nil { + panic("Error connecting to db: " + err.Error()) + } + return db +} + +func dialectAndDriver() (Dialect, string) { + switch os.Getenv("GORP_TEST_DIALECT") { + case "mysql": + return MySQLDialect{"InnoDB", "UTF8"}, "mymysql" + case "gomysql": + return MySQLDialect{"InnoDB", "UTF8"}, "mysql" + case "postgres": + return PostgresDialect{}, "postgres" + case "sqlite": + return SqliteDialect{}, "sqlite3" + } + panic("GORP_TEST_DIALECT env variable is not set or is invalid. Please see README.md") +} + +func _insert(dbmap *DbMap, list ...interface{}) { + err := dbmap.Insert(list...) + if err != nil { + panic(err) + } +} + +func _update(dbmap *DbMap, list ...interface{}) int64 { + count, err := dbmap.Update(list...) + if err != nil { + panic(err) + } + return count +} + +func _del(dbmap *DbMap, list ...interface{}) int64 { + count, err := dbmap.Delete(list...) + if err != nil { + panic(err) + } + + return count +} + +func _get(dbmap *DbMap, i interface{}, keys ...interface{}) interface{} { + obj, err := dbmap.Get(i, keys...) + if err != nil { + panic(err) + } + + return obj +} + +func selectInt(dbmap *DbMap, query string, args ...interface{}) int64 { + i64, err := SelectInt(dbmap, query, args...) + if err != nil { + panic(err) + } + + return i64 +} + +func selectNullInt(dbmap *DbMap, query string, args ...interface{}) sql.NullInt64 { + i64, err := SelectNullInt(dbmap, query, args...) + if err != nil { + panic(err) + } + + return i64 +} + +func selectFloat(dbmap *DbMap, query string, args ...interface{}) float64 { + f64, err := SelectFloat(dbmap, query, args...) + if err != nil { + panic(err) + } + + return f64 +} + +func selectNullFloat(dbmap *DbMap, query string, args ...interface{}) sql.NullFloat64 { + f64, err := SelectNullFloat(dbmap, query, args...) + if err != nil { + panic(err) + } + + return f64 +} + +func selectStr(dbmap *DbMap, query string, args ...interface{}) string { + s, err := SelectStr(dbmap, query, args...) + if err != nil { + panic(err) + } + + return s +} + +func selectNullStr(dbmap *DbMap, query string, args ...interface{}) sql.NullString { + s, err := SelectNullStr(dbmap, query, args...) + if err != nil { + panic(err) + } + + return s +} + +func _rawexec(dbmap *DbMap, query string, args ...interface{}) sql.Result { + res, err := dbmap.Exec(query, args...) + if err != nil { + panic(err) + } + return res +} + +func _rawselect(dbmap *DbMap, i interface{}, query string, args ...interface{}) []interface{} { + list, err := dbmap.Select(i, query, args...) + if err != nil { + panic(err) + } + return list +} diff --git a/Godeps/_workspace/src/github.com/go-gorp/gorp/test_all.sh b/Godeps/_workspace/src/github.com/go-gorp/gorp/test_all.sh new file mode 100644 index 000000000..f870b39a3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-gorp/gorp/test_all.sh @@ -0,0 +1,22 @@ +#!/bin/sh + +# on macs, you may need to: +# export GOBUILDFLAG=-ldflags -linkmode=external + +set -e + +export GORP_TEST_DSN=gorptest/gorptest/gorptest +export GORP_TEST_DIALECT=mysql +go test $GOBUILDFLAG . + +export GORP_TEST_DSN=gorptest:gorptest@/gorptest +export GORP_TEST_DIALECT=gomysql +go test $GOBUILDFLAG . + +export GORP_TEST_DSN="user=gorptest password=gorptest dbname=gorptest sslmode=disable" +export GORP_TEST_DIALECT=postgres +go test $GOBUILDFLAG . + +export GORP_TEST_DSN=/tmp/gorptest.bin +export GORP_TEST_DIALECT=sqlite +go test $GOBUILDFLAG . -- cgit v1.2.3-1-g7c22