summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/lib/pq/conn_go18.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/lib/pq/conn_go18.go')
-rw-r--r--vendor/github.com/lib/pq/conn_go18.go64
1 files changed, 41 insertions, 23 deletions
diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go
index 43cc35f7b..fa3755d99 100644
--- a/vendor/github.com/lib/pq/conn_go18.go
+++ b/vendor/github.com/lib/pq/conn_go18.go
@@ -6,6 +6,8 @@ import (
"context"
"database/sql/driver"
"errors"
+ "io"
+ "io/ioutil"
)
// Implement the "QueryerContext" interface
@@ -14,12 +16,12 @@ func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.Na
for i, nv := range args {
list[i] = nv.Value
}
- closed := cn.watchCancel(ctx)
+ finish := cn.watchCancel(ctx)
r, err := cn.query(query, list)
if err != nil {
return nil, err
}
- r.closed = closed
+ r.finish = finish
return r, nil
}
@@ -30,8 +32,8 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam
list[i] = nv.Value
}
- if closed := cn.watchCancel(ctx); closed != nil {
- defer close(closed)
+ if finish := cn.watchCancel(ctx); finish != nil {
+ defer finish()
}
return cn.Exec(query, list)
@@ -49,41 +51,57 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
if err != nil {
return nil, err
}
- cn.txnClosed = cn.watchCancel(ctx)
+ cn.txnFinish = cn.watchCancel(ctx)
return tx, nil
}
-func (cn *conn) watchCancel(ctx context.Context) chan<- struct{} {
+func (cn *conn) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
- closed := make(chan struct{})
+ finished := make(chan struct{})
go func() {
select {
case <-done:
- cn.cancel()
- case <-closed:
+ _ = cn.cancel()
+ finished <- struct{}{}
+ case <-finished:
}
}()
- return closed
+ return func() {
+ select {
+ case <-finished:
+ case finished <- struct{}{}:
+ }
+ }
}
return nil
}
-func (cn *conn) cancel() {
- var err error
- can := &conn{}
- can.c, err = dial(cn.dialer, cn.opts)
+func (cn *conn) cancel() error {
+ c, err := dial(cn.dialer, cn.opts)
if err != nil {
- return
+ return err
}
- can.ssl(cn.opts)
+ defer c.Close()
- defer can.errRecover(&err)
+ {
+ can := conn{
+ c: c,
+ }
+ can.ssl(cn.opts)
- w := can.writeBuf(0)
- w.int32(80877102) // cancel request code
- w.int32(cn.processID)
- w.int32(cn.secretKey)
+ w := can.writeBuf(0)
+ w.int32(80877102) // cancel request code
+ w.int32(cn.processID)
+ w.int32(cn.secretKey)
- can.sendStartupPacket(w)
- _ = can.c.Close()
+ if err := can.sendStartupPacket(w); err != nil {
+ return err
+ }
+ }
+
+ // Read until EOF to ensure that the server received the cancel.
+ {
+ _, err := io.Copy(ioutil.Discard, c)
+ return err
+ }
}