diff options
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/infile.go')
-rw-r--r-- | vendor/github.com/go-sql-driver/mysql/infile.go | 56 |
1 files changed, 38 insertions, 18 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/go-sql-driver/mysql/infile.go index 121a04c71..547357cfa 100644 --- a/vendor/github.com/go-sql-driver/mysql/infile.go +++ b/vendor/github.com/go-sql-driver/mysql/infile.go @@ -13,11 +13,14 @@ import ( "io" "os" "strings" + "sync" ) var ( - fileRegister map[string]bool - readerRegister map[string]func() io.Reader + fileRegister map[string]bool + fileRegisterLock sync.RWMutex + readerRegister map[string]func() io.Reader + readerRegisterLock sync.RWMutex ) // RegisterLocalFile adds the given file to the file whitelist, @@ -32,17 +35,21 @@ var ( // ... // func RegisterLocalFile(filePath string) { + fileRegisterLock.Lock() // lazy map init if fileRegister == nil { fileRegister = make(map[string]bool) } fileRegister[strings.Trim(filePath, `"`)] = true + fileRegisterLock.Unlock() } // DeregisterLocalFile removes the given filepath from the whitelist. func DeregisterLocalFile(filePath string) { + fileRegisterLock.Lock() delete(fileRegister, strings.Trim(filePath, `"`)) + fileRegisterLock.Unlock() } // RegisterReaderHandler registers a handler function which is used @@ -61,18 +68,22 @@ func DeregisterLocalFile(filePath string) { // ... // func RegisterReaderHandler(name string, handler func() io.Reader) { + readerRegisterLock.Lock() // lazy map init if readerRegister == nil { readerRegister = make(map[string]func() io.Reader) } readerRegister[name] = handler + readerRegisterLock.Unlock() } // DeregisterReaderHandler removes the ReaderHandler function with // the given name from the registry. func DeregisterReaderHandler(name string) { + readerRegisterLock.Lock() delete(readerRegister, name) + readerRegisterLock.Unlock() } func deferredClose(err *error, closer io.Closer) { @@ -85,14 +96,22 @@ func deferredClose(err *error, closer io.Closer) { func (mc *mysqlConn) handleInFileRequest(name string) (err error) { var rdr io.Reader var data []byte + packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + if mc.maxWriteSize < packetSize { + packetSize = mc.maxWriteSize + } + + if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader + // The server might return an an absolute path. See issue #355. + name = name[idx+8:] + + readerRegisterLock.RLock() + handler, inMap := readerRegister[name] + readerRegisterLock.RUnlock() - if strings.HasPrefix(name, "Reader::") { // io.Reader - name = name[8:] - if handler, inMap := readerRegister[name]; inMap { + if inMap { rdr = handler() if rdr != nil { - data = make([]byte, 4+mc.maxWriteSize) - if cl, ok := rdr.(io.Closer); ok { defer deferredClose(&err, cl) } @@ -104,7 +123,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { } } else { // File name = strings.Trim(name, `"`) - if mc.cfg.allowAllFiles || fileRegister[name] { + fileRegisterLock.RLock() + fr := fileRegister[name] + fileRegisterLock.RUnlock() + if mc.cfg.AllowAllFiles || fr { var file *os.File var fi os.FileInfo @@ -114,22 +136,19 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // get file size if fi, err = file.Stat(); err == nil { rdr = file - if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize { - data = make([]byte, 4+fileSize) - } else if fileSize <= mc.maxPacketAllowed { - data = make([]byte, 4+mc.maxWriteSize) - } else { - err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed) + if fileSize := int(fi.Size()); fileSize < packetSize { + packetSize = fileSize } } } } else { - err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name) + err = fmt.Errorf("local file '%s' is not registered", name) } } // send content packets if err == nil { + data := make([]byte, 4+packetSize) var n int for err == nil { n, err = rdr.Read(data[4:]) @@ -154,9 +173,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // read OK packet if err == nil { - return mc.readResultOK() - } else { - mc.readPacket() + _, err = mc.readResultOK() + return err } + + mc.readPacket() return err } |