From ff21a5c75f090d5c27446eeaa6fb3ff6f82c5ab4 Mon Sep 17 00:00:00 2001 From: =Corey Hulen Date: Sun, 12 Jul 2015 14:56:44 -0800 Subject: Fixes mm-1420 adding postgres support --- Godeps/_workspace/src/github.com/lib/pq/encode.go | 538 ++++++++++++++++++++++ 1 file changed, 538 insertions(+) create mode 100644 Godeps/_workspace/src/github.com/lib/pq/encode.go (limited to 'Godeps/_workspace/src/github.com/lib/pq/encode.go') diff --git a/Godeps/_workspace/src/github.com/lib/pq/encode.go b/Godeps/_workspace/src/github.com/lib/pq/encode.go new file mode 100644 index 000000000..88422eb3d --- /dev/null +++ b/Godeps/_workspace/src/github.com/lib/pq/encode.go @@ -0,0 +1,538 @@ +package pq + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "strconv" + "strings" + "sync" + "time" + + "github.com/lib/pq/oid" +) + +func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { + switch v := x.(type) { + case []byte: + return v + default: + return encode(parameterStatus, x, oid.T_unknown) + } + panic("not reached") +} + +func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { + switch v := x.(type) { + case int64: + return strconv.AppendInt(nil, v, 10) + case float64: + return strconv.AppendFloat(nil, v, 'f', -1, 64) + case []byte: + if pgtypOid == oid.T_bytea { + return encodeBytea(parameterStatus.serverVersion, v) + } + + return v + case string: + if pgtypOid == oid.T_bytea { + return encodeBytea(parameterStatus.serverVersion, []byte(v)) + } + + return []byte(v) + case bool: + return strconv.AppendBool(nil, v) + case time.Time: + return formatTs(v) + + default: + errorf("encode: unknown type for %T", v) + } + + panic("not reached") +} + +func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { + if f == formatBinary { + return binaryDecode(parameterStatus, s, typ) + } else { + return textDecode(parameterStatus, s, typ) + } +} + +func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { + switch typ { + case oid.T_bytea: + return s + case oid.T_int8: + return int64(binary.BigEndian.Uint64(s)) + case oid.T_int4: + return int64(int32(binary.BigEndian.Uint32(s))) + case oid.T_int2: + return int64(int16(binary.BigEndian.Uint16(s))) + + default: + errorf("don't know how to decode binary parameter of type %u", uint32(typ)) + } + + panic("not reached") +} + +func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { + switch typ { + case oid.T_bytea: + return parseBytea(s) + case oid.T_timestamptz: + return parseTs(parameterStatus.currentLocation, string(s)) + case oid.T_timestamp, oid.T_date: + return parseTs(nil, string(s)) + case oid.T_time: + return mustParse("15:04:05", typ, s) + case oid.T_timetz: + return mustParse("15:04:05-07", typ, s) + case oid.T_bool: + return s[0] == 't' + case oid.T_int8, oid.T_int4, oid.T_int2: + i, err := strconv.ParseInt(string(s), 10, 64) + if err != nil { + errorf("%s", err) + } + return i + case oid.T_float4, oid.T_float8: + bits := 64 + if typ == oid.T_float4 { + bits = 32 + } + f, err := strconv.ParseFloat(string(s), bits) + if err != nil { + errorf("%s", err) + } + return f + } + + return s +} + +// appendEncodedText encodes item in text format as required by COPY +// and appends to buf +func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte { + switch v := x.(type) { + case int64: + return strconv.AppendInt(buf, v, 10) + case float64: + return strconv.AppendFloat(buf, v, 'f', -1, 64) + case []byte: + encodedBytea := encodeBytea(parameterStatus.serverVersion, v) + return appendEscapedText(buf, string(encodedBytea)) + case string: + return appendEscapedText(buf, v) + case bool: + return strconv.AppendBool(buf, v) + case time.Time: + return append(buf, formatTs(v)...) + case nil: + return append(buf, "\\N"...) + default: + errorf("encode: unknown type for %T", v) + } + + panic("not reached") +} + +func appendEscapedText(buf []byte, text string) []byte { + escapeNeeded := false + startPos := 0 + var c byte + + // check if we need to escape + for i := 0; i < len(text); i++ { + c = text[i] + if c == '\\' || c == '\n' || c == '\r' || c == '\t' { + escapeNeeded = true + startPos = i + break + } + } + if !escapeNeeded { + return append(buf, text...) + } + + // copy till first char to escape, iterate the rest + result := append(buf, text[:startPos]...) + for i := startPos; i < len(text); i++ { + c = text[i] + switch c { + case '\\': + result = append(result, '\\', '\\') + case '\n': + result = append(result, '\\', 'n') + case '\r': + result = append(result, '\\', 'r') + case '\t': + result = append(result, '\\', 't') + default: + result = append(result, c) + } + } + return result +} + +func mustParse(f string, typ oid.Oid, s []byte) time.Time { + str := string(s) + + // check for a 30-minute-offset timezone + if (typ == oid.T_timestamptz || typ == oid.T_timetz) && + str[len(str)-3] == ':' { + f += ":00" + } + t, err := time.Parse(f, str) + if err != nil { + errorf("decode: %s", err) + } + return t +} + +func expect(str, char string, pos int) { + if c := str[pos : pos+1]; c != char { + errorf("expected '%v' at position %v; got '%v'", char, pos, c) + } +} + +func mustAtoi(str string) int { + result, err := strconv.Atoi(str) + if err != nil { + errorf("expected number; got '%v'", str) + } + return result +} + +// The location cache caches the time zones typically used by the client. +type locationCache struct { + cache map[int]*time.Location + lock sync.Mutex +} + +// All connections share the same list of timezones. Benchmarking shows that +// about 5% speed could be gained by putting the cache in the connection and +// losing the mutex, at the cost of a small amount of memory and a somewhat +// significant increase in code complexity. +var globalLocationCache *locationCache = newLocationCache() + +func newLocationCache() *locationCache { + return &locationCache{cache: make(map[int]*time.Location)} +} + +// Returns the cached timezone for the specified offset, creating and caching +// it if necessary. +func (c *locationCache) getLocation(offset int) *time.Location { + c.lock.Lock() + defer c.lock.Unlock() + + location, ok := c.cache[offset] + if !ok { + location = time.FixedZone("", offset) + c.cache[offset] = location + } + + return location +} + +var infinityTsEnabled = false +var infinityTsNegative time.Time +var infinityTsPositive time.Time + +const ( + infinityTsEnabledAlready = "pq: infinity timestamp enabled already" + infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" +) + +/* + * If EnableInfinityTs is not called, "-infinity" and "infinity" will return + * []byte("-infinity") and []byte("infinity") respectively, and potentially + * cause error "sql: Scan error on column index 0: unsupported driver -> Scan pair: []uint8 -> *time.Time", + * when scanning into a time.Time value. + * + * Once EnableInfinityTs has been called, all connections created using this + * driver will decode Postgres' "-infinity" and "infinity" for "timestamp", + * "timestamp with time zone" and "date" types to the predefined minimum and + * maximum times, respectively. When encoding time.Time values, any time which + * equals or preceeds the predefined minimum time will be encoded to + * "-infinity". Any values at or past the maximum time will similarly be + * encoded to "infinity". + * + * + * If EnableInfinityTs is called with negative >= positive, it will panic. + * Calling EnableInfinityTs after a connection has been established results in + * undefined behavior. If EnableInfinityTs is called more than once, it will + * panic. + */ +func EnableInfinityTs(negative time.Time, positive time.Time) { + if infinityTsEnabled { + panic(infinityTsEnabledAlready) + } + if !negative.Before(positive) { + panic(infinityTsNegativeMustBeSmaller) + } + infinityTsEnabled = true + infinityTsNegative = negative + infinityTsPositive = positive +} + +/* + * Testing might want to toggle infinityTsEnabled + */ +func disableInfinityTs() { + infinityTsEnabled = false +} + +// This is a time function specific to the Postgres default DateStyle +// setting ("ISO, MDY"), the only one we currently support. This +// accounts for the discrepancies between the parsing available with +// time.Parse and the Postgres date formatting quirks. +func parseTs(currentLocation *time.Location, str string) interface{} { + switch str { + case "-infinity": + if infinityTsEnabled { + return infinityTsNegative + } + return []byte(str) + case "infinity": + if infinityTsEnabled { + return infinityTsPositive + } + return []byte(str) + } + + monSep := strings.IndexRune(str, '-') + // this is Gregorian year, not ISO Year + // In Gregorian system, the year 1 BC is followed by AD 1 + year := mustAtoi(str[:monSep]) + daySep := monSep + 3 + month := mustAtoi(str[monSep+1 : daySep]) + expect(str, "-", daySep) + timeSep := daySep + 3 + day := mustAtoi(str[daySep+1 : timeSep]) + + var hour, minute, second int + if len(str) > monSep+len("01-01")+1 { + expect(str, " ", timeSep) + minSep := timeSep + 3 + expect(str, ":", minSep) + hour = mustAtoi(str[timeSep+1 : minSep]) + secSep := minSep + 3 + expect(str, ":", secSep) + minute = mustAtoi(str[minSep+1 : secSep]) + secEnd := secSep + 3 + second = mustAtoi(str[secSep+1 : secEnd]) + } + remainderIdx := monSep + len("01-01 00:00:00") + 1 + // Three optional (but ordered) sections follow: the + // fractional seconds, the time zone offset, and the BC + // designation. We set them up here and adjust the other + // offsets if the preceding sections exist. + + nanoSec := 0 + tzOff := 0 + + if remainderIdx < len(str) && str[remainderIdx:remainderIdx+1] == "." { + fracStart := remainderIdx + 1 + fracOff := strings.IndexAny(str[fracStart:], "-+ ") + if fracOff < 0 { + fracOff = len(str) - fracStart + } + fracSec := mustAtoi(str[fracStart : fracStart+fracOff]) + nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) + + remainderIdx += fracOff + 1 + } + if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart:tzStart+1] == "-" || str[tzStart:tzStart+1] == "+") { + // time zone separator is always '-' or '+' (UTC is +00) + var tzSign int + if c := str[tzStart : tzStart+1]; c == "-" { + tzSign = -1 + } else if c == "+" { + tzSign = +1 + } else { + errorf("expected '-' or '+' at position %v; got %v", tzStart, c) + } + tzHours := mustAtoi(str[tzStart+1 : tzStart+3]) + remainderIdx += 3 + var tzMin, tzSec int + if tzStart+3 < len(str) && str[tzStart+3:tzStart+4] == ":" { + tzMin = mustAtoi(str[tzStart+4 : tzStart+6]) + remainderIdx += 3 + } + if tzStart+6 < len(str) && str[tzStart+6:tzStart+7] == ":" { + tzSec = mustAtoi(str[tzStart+7 : tzStart+9]) + remainderIdx += 3 + } + tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) + } + var isoYear int + if remainderIdx < len(str) && str[remainderIdx:remainderIdx+3] == " BC" { + isoYear = 1 - year + remainderIdx += 3 + } else { + isoYear = year + } + if remainderIdx < len(str) { + errorf("expected end of input, got %v", str[remainderIdx:]) + } + t := time.Date(isoYear, time.Month(month), day, + hour, minute, second, nanoSec, + globalLocationCache.getLocation(tzOff)) + + if currentLocation != nil { + // Set the location of the returned Time based on the session's + // TimeZone value, but only if the local time zone database agrees with + // the remote database on the offset. + lt := t.In(currentLocation) + _, newOff := lt.Zone() + if newOff == tzOff { + t = lt + } + } + + return t +} + +// formatTs formats t into a format postgres understands. +func formatTs(t time.Time) (b []byte) { + if infinityTsEnabled { + // t <= -infinity : ! (t > -infinity) + if !t.After(infinityTsNegative) { + return []byte("-infinity") + } + // t >= infinity : ! (!t < infinity) + if !t.Before(infinityTsPositive) { + return []byte("infinity") + } + } + // Need to send dates before 0001 A.D. with " BC" suffix, instead of the + // minus sign preferred by Go. + // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on + bc := false + if t.Year() <= 0 { + // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" + t = t.AddDate((-t.Year())*2+1, 0, 0) + bc = true + } + b = []byte(t.Format(time.RFC3339Nano)) + + _, offset := t.Zone() + offset = offset % 60 + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) + } + + if bc { + b = append(b, " BC"...) + } + return b +} + +// Parse a bytea value received from the server. Both "hex" and the legacy +// "escape" format are supported. +func parseBytea(s []byte) (result []byte) { + if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { + // bytea_output = hex + s = s[2:] // trim off leading "\\x" + result = make([]byte, hex.DecodedLen(len(s))) + _, err := hex.Decode(result, s) + if err != nil { + errorf("%s", err) + } + } else { + // bytea_output = escape + for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } + + // '\\' followed by an octal number + if len(s) < 4 { + errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseInt(string(s[1:4]), 8, 9) + if err != nil { + errorf("could not parse bytea value: %s", err.Error()) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break + } + result = append(result, s[:i]...) + s = s[i:] + } + } + } + + return result +} + +func encodeBytea(serverVersion int, v []byte) (result []byte) { + if serverVersion >= 90000 { + // Use the hex format if we know that the server supports it + result = make([]byte, 2+hex.EncodedLen(len(v))) + result[0] = '\\' + result[1] = 'x' + hex.Encode(result[2:], v) + } else { + // .. or resort to "escape" + for _, b := range v { + if b == '\\' { + result = append(result, '\\', '\\') + } else if b < 0x20 || b > 0x7e { + result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) + } else { + result = append(result, b) + } + } + } + + return result +} + +// NullTime represents a time.Time that may be null. NullTime implements the +// sql.Scanner interface so it can be used as a scan destination, similar to +// sql.NullString. +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (nt *NullTime) Scan(value interface{}) error { + nt.Time, nt.Valid = value.(time.Time) + return nil +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} -- cgit v1.2.3-1-g7c22