summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/lib/pq/go18_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/lib/pq/go18_test.go')
-rw-r--r--vendor/github.com/lib/pq/go18_test.go54
1 files changed, 54 insertions, 0 deletions
diff --git a/vendor/github.com/lib/pq/go18_test.go b/vendor/github.com/lib/pq/go18_test.go
index 5d17e4d92..cddbfb6a4 100644
--- a/vendor/github.com/lib/pq/go18_test.go
+++ b/vendor/github.com/lib/pq/go18_test.go
@@ -72,6 +72,8 @@ func TestMultipleSimpleQuery(t *testing.T) {
}
}
+const contextRaceIterations = 100
+
func TestContextCancelExec(t *testing.T) {
db := openTestConn(t)
defer db.Close()
@@ -94,6 +96,20 @@ func TestContextCancelExec(t *testing.T) {
} else if err.Error() != "context canceled" {
t.Fatalf("unexpected error: %s", err)
}
+
+ for i := 0; i < contextRaceIterations; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ if _, err := db.ExecContext(ctx, "select 1"); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ if _, err := db.Exec("select 1"); err != nil {
+ t.Fatal(err)
+ }
+ }
}
func TestContextCancelQuery(t *testing.T) {
@@ -118,6 +134,25 @@ func TestContextCancelQuery(t *testing.T) {
} else if err.Error() != "context canceled" {
t.Fatalf("unexpected error: %s", err)
}
+
+ for i := 0; i < contextRaceIterations; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ rows, err := db.QueryContext(ctx, "select 1")
+ cancel()
+ if err != nil {
+ t.Fatal(err)
+ } else if err := rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ if rows, err := db.Query("select 1"); err != nil {
+ t.Fatal(err)
+ } else if err := rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }
}
func TestContextCancelBegin(t *testing.T) {
@@ -153,4 +188,23 @@ func TestContextCancelBegin(t *testing.T) {
} else if err.Error() != "context canceled" {
t.Fatalf("unexpected error: %s", err)
}
+
+ for i := 0; i < contextRaceIterations; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ tx, err := db.BeginTx(ctx, nil)
+ cancel()
+ if err != nil {
+ t.Fatal(err)
+ } else if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
+ t.Fatal(err)
+ }
+ }()
+
+ if tx, err := db.Begin(); err != nil {
+ t.Fatal(err)
+ } else if err := tx.Rollback(); err != nil {
+ t.Fatal(err)
+ }
+ }
}