summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-sql-driver/mysql/infile.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/infile.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/infile.go56
1 files changed, 38 insertions, 18 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/go-sql-driver/mysql/infile.go
index 121a04c71..547357cfa 100644
--- a/vendor/github.com/go-sql-driver/mysql/infile.go
+++ b/vendor/github.com/go-sql-driver/mysql/infile.go
@@ -13,11 +13,14 @@ import (
"io"
"os"
"strings"
+ "sync"
)
var (
- fileRegister map[string]bool
- readerRegister map[string]func() io.Reader
+ fileRegister map[string]bool
+ fileRegisterLock sync.RWMutex
+ readerRegister map[string]func() io.Reader
+ readerRegisterLock sync.RWMutex
)
// RegisterLocalFile adds the given file to the file whitelist,
@@ -32,17 +35,21 @@ var (
// ...
//
func RegisterLocalFile(filePath string) {
+ fileRegisterLock.Lock()
// lazy map init
if fileRegister == nil {
fileRegister = make(map[string]bool)
}
fileRegister[strings.Trim(filePath, `"`)] = true
+ fileRegisterLock.Unlock()
}
// DeregisterLocalFile removes the given filepath from the whitelist.
func DeregisterLocalFile(filePath string) {
+ fileRegisterLock.Lock()
delete(fileRegister, strings.Trim(filePath, `"`))
+ fileRegisterLock.Unlock()
}
// RegisterReaderHandler registers a handler function which is used
@@ -61,18 +68,22 @@ func DeregisterLocalFile(filePath string) {
// ...
//
func RegisterReaderHandler(name string, handler func() io.Reader) {
+ readerRegisterLock.Lock()
// lazy map init
if readerRegister == nil {
readerRegister = make(map[string]func() io.Reader)
}
readerRegister[name] = handler
+ readerRegisterLock.Unlock()
}
// DeregisterReaderHandler removes the ReaderHandler function with
// the given name from the registry.
func DeregisterReaderHandler(name string) {
+ readerRegisterLock.Lock()
delete(readerRegister, name)
+ readerRegisterLock.Unlock()
}
func deferredClose(err *error, closer io.Closer) {
@@ -85,14 +96,22 @@ func deferredClose(err *error, closer io.Closer) {
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
+ packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
+ if mc.maxWriteSize < packetSize {
+ packetSize = mc.maxWriteSize
+ }
+
+ if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
+ // The server might return an an absolute path. See issue #355.
+ name = name[idx+8:]
+
+ readerRegisterLock.RLock()
+ handler, inMap := readerRegister[name]
+ readerRegisterLock.RUnlock()
- if strings.HasPrefix(name, "Reader::") { // io.Reader
- name = name[8:]
- if handler, inMap := readerRegister[name]; inMap {
+ if inMap {
rdr = handler()
if rdr != nil {
- data = make([]byte, 4+mc.maxWriteSize)
-
if cl, ok := rdr.(io.Closer); ok {
defer deferredClose(&err, cl)
}
@@ -104,7 +123,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
}
} else { // File
name = strings.Trim(name, `"`)
- if mc.cfg.allowAllFiles || fileRegister[name] {
+ fileRegisterLock.RLock()
+ fr := fileRegister[name]
+ fileRegisterLock.RUnlock()
+ if mc.cfg.AllowAllFiles || fr {
var file *os.File
var fi os.FileInfo
@@ -114,22 +136,19 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
// get file size
if fi, err = file.Stat(); err == nil {
rdr = file
- if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
- data = make([]byte, 4+fileSize)
- } else if fileSize <= mc.maxPacketAllowed {
- data = make([]byte, 4+mc.maxWriteSize)
- } else {
- err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed)
+ if fileSize := int(fi.Size()); fileSize < packetSize {
+ packetSize = fileSize
}
}
}
} else {
- err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
+ err = fmt.Errorf("local file '%s' is not registered", name)
}
}
// send content packets
if err == nil {
+ data := make([]byte, 4+packetSize)
var n int
for err == nil {
n, err = rdr.Read(data[4:])
@@ -154,9 +173,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
// read OK packet
if err == nil {
- return mc.readResultOK()
- } else {
- mc.readPacket()
+ _, err = mc.readResultOK()
+ return err
}
+
+ mc.readPacket()
return err
}