summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/lib
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/lib')
-rw-r--r--vendor/github.com/lib/pq/conn.go32
-rw-r--r--vendor/github.com/lib/pq/conn_go18.go64
-rw-r--r--vendor/github.com/lib/pq/go18_test.go54
3 files changed, 109 insertions, 41 deletions
diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go
index 4b2fb4462..c005cf9c4 100644
--- a/vendor/github.com/lib/pq/conn.go
+++ b/vendor/github.com/lib/pq/conn.go
@@ -98,7 +98,7 @@ type conn struct {
namei int
scratch [512]byte
txnStatus transactionStatus
- txnClosed chan<- struct{}
+ txnFinish func()
// Save connection arguments to use during CancelRequest.
dialer Dialer
@@ -528,9 +528,8 @@ func (cn *conn) Begin() (_ driver.Tx, err error) {
}
func (cn *conn) closeTxn() {
- if cn.txnClosed != nil {
- close(cn.txnClosed)
- cn.txnClosed = nil
+ if finish := cn.txnFinish; finish != nil {
+ finish()
}
}
@@ -893,16 +892,9 @@ func (cn *conn) send(m *writeBuf) {
}
}
-func (cn *conn) sendStartupPacket(m *writeBuf) {
- // sanity check
- if m.buf[0] != 0 {
- panic("oops")
- }
-
+func (cn *conn) sendStartupPacket(m *writeBuf) error {
_, err := cn.c.Write((m.wrap())[1:])
- if err != nil {
- panic(err)
- }
+ return err
}
// Send a message of type typ to the server on the other end of cn. The
@@ -1024,7 +1016,9 @@ func (cn *conn) ssl(o values) {
w := cn.writeBuf(0)
w.int32(80877103)
- cn.sendStartupPacket(w)
+ if err := cn.sendStartupPacket(w); err != nil {
+ panic(err)
+ }
b := cn.scratch[:1]
_, err := io.ReadFull(cn.c, b)
@@ -1085,7 +1079,9 @@ func (cn *conn) startup(o values) {
w.string(v)
}
w.string("")
- cn.sendStartupPacket(w)
+ if err := cn.sendStartupPacket(w); err != nil {
+ panic(err)
+ }
for {
t, r := cn.recv()
@@ -1319,7 +1315,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
type rows struct {
cn *conn
- closed chan<- struct{}
+ finish func()
colNames []string
colTyps []oid.Oid
colFmts []format
@@ -1330,8 +1326,8 @@ type rows struct {
}
func (rs *rows) Close() error {
- if rs.closed != nil {
- defer close(rs.closed)
+ if finish := rs.finish; finish != nil {
+ defer finish()
}
// no need to look at cn.bad as Next() will
for {
diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go
index 43cc35f7b..fa3755d99 100644
--- a/vendor/github.com/lib/pq/conn_go18.go
+++ b/vendor/github.com/lib/pq/conn_go18.go
@@ -6,6 +6,8 @@ import (
"context"
"database/sql/driver"
"errors"
+ "io"
+ "io/ioutil"
)
// Implement the "QueryerContext" interface
@@ -14,12 +16,12 @@ func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.Na
for i, nv := range args {
list[i] = nv.Value
}
- closed := cn.watchCancel(ctx)
+ finish := cn.watchCancel(ctx)
r, err := cn.query(query, list)
if err != nil {
return nil, err
}
- r.closed = closed
+ r.finish = finish
return r, nil
}
@@ -30,8 +32,8 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam
list[i] = nv.Value
}
- if closed := cn.watchCancel(ctx); closed != nil {
- defer close(closed)
+ if finish := cn.watchCancel(ctx); finish != nil {
+ defer finish()
}
return cn.Exec(query, list)
@@ -49,41 +51,57 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
if err != nil {
return nil, err
}
- cn.txnClosed = cn.watchCancel(ctx)
+ cn.txnFinish = cn.watchCancel(ctx)
return tx, nil
}
-func (cn *conn) watchCancel(ctx context.Context) chan<- struct{} {
+func (cn *conn) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
- closed := make(chan struct{})
+ finished := make(chan struct{})
go func() {
select {
case <-done:
- cn.cancel()
- case <-closed:
+ _ = cn.cancel()
+ finished <- struct{}{}
+ case <-finished:
}
}()
- return closed
+ return func() {
+ select {
+ case <-finished:
+ case finished <- struct{}{}:
+ }
+ }
}
return nil
}
-func (cn *conn) cancel() {
- var err error
- can := &conn{}
- can.c, err = dial(cn.dialer, cn.opts)
+func (cn *conn) cancel() error {
+ c, err := dial(cn.dialer, cn.opts)
if err != nil {
- return
+ return err
}
- can.ssl(cn.opts)
+ defer c.Close()
- defer can.errRecover(&err)
+ {
+ can := conn{
+ c: c,
+ }
+ can.ssl(cn.opts)
- w := can.writeBuf(0)
- w.int32(80877102) // cancel request code
- w.int32(cn.processID)
- w.int32(cn.secretKey)
+ w := can.writeBuf(0)
+ w.int32(80877102) // cancel request code
+ w.int32(cn.processID)
+ w.int32(cn.secretKey)
- can.sendStartupPacket(w)
- _ = can.c.Close()
+ if err := can.sendStartupPacket(w); err != nil {
+ return err
+ }
+ }
+
+ // Read until EOF to ensure that the server received the cancel.
+ {
+ _, err := io.Copy(ioutil.Discard, c)
+ return err
+ }
}
diff --git a/vendor/github.com/lib/pq/go18_test.go b/vendor/github.com/lib/pq/go18_test.go
index 5d17e4d92..cddbfb6a4 100644
--- a/vendor/github.com/lib/pq/go18_test.go
+++ b/vendor/github.com/lib/pq/go18_test.go
@@ -72,6 +72,8 @@ func TestMultipleSimpleQuery(t *testing.T) {
}
}
+const contextRaceIterations = 100
+
func TestContextCancelExec(t *testing.T) {
db := openTestConn(t)
defer db.Close()
@@ -94,6 +96,20 @@ func TestContextCancelExec(t *testing.T) {
} else if err.Error() != "context canceled" {
t.Fatalf("unexpected error: %s", err)
}
+
+ for i := 0; i < contextRaceIterations; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ if _, err := db.ExecContext(ctx, "select 1"); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ if _, err := db.Exec("select 1"); err != nil {
+ t.Fatal(err)
+ }
+ }
}
func TestContextCancelQuery(t *testing.T) {
@@ -118,6 +134,25 @@ func TestContextCancelQuery(t *testing.T) {
} else if err.Error() != "context canceled" {
t.Fatalf("unexpected error: %s", err)
}
+
+ for i := 0; i < contextRaceIterations; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ rows, err := db.QueryContext(ctx, "select 1")
+ cancel()
+ if err != nil {
+ t.Fatal(err)
+ } else if err := rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ if rows, err := db.Query("select 1"); err != nil {
+ t.Fatal(err)
+ } else if err := rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }
}
func TestContextCancelBegin(t *testing.T) {
@@ -153,4 +188,23 @@ func TestContextCancelBegin(t *testing.T) {
} else if err.Error() != "context canceled" {
t.Fatalf("unexpected error: %s", err)
}
+
+ for i := 0; i < contextRaceIterations; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ tx, err := db.BeginTx(ctx, nil)
+ cancel()
+ if err != nil {
+ t.Fatal(err)
+ } else if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
+ t.Fatal(err)
+ }
+ }()
+
+ if tx, err := db.Begin(); err != nil {
+ t.Fatal(err)
+ } else if err := tx.Rollback(); err != nil {
+ t.Fatal(err)
+ }
+ }
}