// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "compress/flate" "errors" "io" "strings" "sync" ) var ( flateWriterPool = sync.Pool{} ) func decompressNoContextTakeover(r io.Reader) io.Reader { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + // Add final block to squelch unexpected EOF error from flate reader. "\x01\x00\x00\xff\xff" return flate.NewReader(io.MultiReader(r, strings.NewReader(tail))) } func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { tw := &truncWriter{w: w} i := flateWriterPool.Get() var fw *flate.Writer var err error if i == nil { fw, err = flate.NewWriter(tw, 3) } else { fw = i.(*flate.Writer) fw.Reset(tw) } return &flateWrapper{fw: fw, tw: tw}, err } // truncWriter is an io.Writer that writes all but the last four bytes of the // stream to another io.Writer. type truncWriter struct { w io.WriteCloser n int p [4]byte } func (w *truncWriter) Write(p []byte) (int, error) { n := 0 // fill buffer first for simplicity. if w.n < len(w.p) { n = copy(w.p[w.n:], p) p = p[n:] w.n += n if len(p) == 0 { return n, nil } } m := len(p) if m > len(w.p) { m = len(w.p) } if nn, err := w.w.Write(w.p[:m]); err != nil { return n + nn, err } copy(w.p[:], w.p[m:]) copy(w.p[len(w.p)-m:], p[len(p)-m:]) nn, err := w.w.Write(p[:len(p)-m]) return n + nn, err } type flateWrapper struct { fw *flate.Writer tw *truncWriter } func (w *flateWrapper) Write(p []byte) (int, error) { if w.fw == nil { return 0, errWriteClosed } return w.fw.Write(p) } func (w *flateWrapper) Close() error { if w.fw == nil { return errWriteClosed } err1 := w.fw.Flush() flateWriterPool.Put(w.fw) w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } err2 := w.tw.w.Close() if err1 != nil { return err1 } return err2 }