diff options
Diffstat (limited to 'vendor/github.com/lib/pq/conn_go18.go')
-rw-r--r-- | vendor/github.com/lib/pq/conn_go18.go | 64 |
1 files changed, 41 insertions, 23 deletions
diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go index 43cc35f7b..fa3755d99 100644 --- a/vendor/github.com/lib/pq/conn_go18.go +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -6,6 +6,8 @@ import ( "context" "database/sql/driver" "errors" + "io" + "io/ioutil" ) // Implement the "QueryerContext" interface @@ -14,12 +16,12 @@ func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.Na for i, nv := range args { list[i] = nv.Value } - closed := cn.watchCancel(ctx) + finish := cn.watchCancel(ctx) r, err := cn.query(query, list) if err != nil { return nil, err } - r.closed = closed + r.finish = finish return r, nil } @@ -30,8 +32,8 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam list[i] = nv.Value } - if closed := cn.watchCancel(ctx); closed != nil { - defer close(closed) + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() } return cn.Exec(query, list) @@ -49,41 +51,57 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, if err != nil { return nil, err } - cn.txnClosed = cn.watchCancel(ctx) + cn.txnFinish = cn.watchCancel(ctx) return tx, nil } -func (cn *conn) watchCancel(ctx context.Context) chan<- struct{} { +func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { - closed := make(chan struct{}) + finished := make(chan struct{}) go func() { select { case <-done: - cn.cancel() - case <-closed: + _ = cn.cancel() + finished <- struct{}{} + case <-finished: } }() - return closed + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } } return nil } -func (cn *conn) cancel() { - var err error - can := &conn{} - can.c, err = dial(cn.dialer, cn.opts) +func (cn *conn) cancel() error { + c, err := dial(cn.dialer, cn.opts) if err != nil { - return + return err } - can.ssl(cn.opts) + defer c.Close() - defer can.errRecover(&err) + { + can := conn{ + c: c, + } + can.ssl(cn.opts) - w := can.writeBuf(0) - w.int32(80877102) // cancel request code - w.int32(cn.processID) - w.int32(cn.secretKey) + w := can.writeBuf(0) + w.int32(80877102) // cancel request code + w.int32(cn.processID) + w.int32(cn.secretKey) - can.sendStartupPacket(w) - _ = can.c.Close() + if err := can.sendStartupPacket(w); err != nil { + return err + } + } + + // Read until EOF to ensure that the server received the cancel. + { + _, err := io.Copy(ioutil.Discard, c) + return err + } } |