498 lines
13 KiB
Go
498 lines
13 KiB
Go
|
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package websocket
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net"
|
||
|
"reflect"
|
||
|
"testing"
|
||
|
"testing/iotest"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// Compile-time assertion that errWriteTimeout satisfies the net.Error interface.
var _ net.Error = errWriteTimeout
|
||
|
|
||
|
// fakeNetConn is an in-memory stand-in for a network connection used by the
// tests in this file. Reads and writes are delegated to the embedded
// io.Reader and io.Writer; either may be left nil when a test only uses one
// direction.
type fakeNetConn struct {
	io.Reader // source of incoming bytes
	io.Writer // sink for outgoing bytes
}
|
||
|
|
||
|
// Close is a no-op that always reports success.
func (fakeNetConn) Close() error {
	return nil
}
|
||
|
// LocalAddr returns the package-level localAddr value.
func (fakeNetConn) LocalAddr() net.Addr {
	return localAddr
}
|
||
|
// RemoteAddr returns the package-level remoteAddr value.
func (fakeNetConn) RemoteAddr() net.Addr {
	return remoteAddr
}
|
||
|
// SetDeadline ignores the requested deadline and always reports success.
func (fakeNetConn) SetDeadline(time.Time) error {
	return nil
}
|
||
|
// SetReadDeadline ignores the requested deadline and always reports success.
func (fakeNetConn) SetReadDeadline(time.Time) error {
	return nil
}
|
||
|
// SetWriteDeadline ignores the requested deadline and always reports success.
func (fakeNetConn) SetWriteDeadline(time.Time) error {
	return nil
}
|
||
|
|
||
|
// fakeAddr is an integer-backed address type used by the fake connection in
// these tests.
type fakeAddr int
|
||
|
|
||
|
var (
|
||
|
localAddr = fakeAddr(1)
|
||
|
remoteAddr = fakeAddr(2)
|
||
|
)
|
||
|
|
||
|
func (a fakeAddr) Network() string {
|
||
|
return "net"
|
||
|
}
|
||
|
|
||
|
func (a fakeAddr) String() string {
|
||
|
return "str"
|
||
|
}
|
||
|
|
||
|
func TestFraming(t *testing.T) {
|
||
|
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
|
||
|
var readChunkers = []struct {
|
||
|
name string
|
||
|
f func(io.Reader) io.Reader
|
||
|
}{
|
||
|
{"half", iotest.HalfReader},
|
||
|
{"one", iotest.OneByteReader},
|
||
|
{"asis", func(r io.Reader) io.Reader { return r }},
|
||
|
}
|
||
|
writeBuf := make([]byte, 65537)
|
||
|
for i := range writeBuf {
|
||
|
writeBuf[i] = byte(i)
|
||
|
}
|
||
|
var writers = []struct {
|
||
|
name string
|
||
|
f func(w io.Writer, n int) (int, error)
|
||
|
}{
|
||
|
{"iocopy", func(w io.Writer, n int) (int, error) {
|
||
|
nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
|
||
|
return int(nn), err
|
||
|
}},
|
||
|
{"write", func(w io.Writer, n int) (int, error) {
|
||
|
return w.Write(writeBuf[:n])
|
||
|
}},
|
||
|
{"string", func(w io.Writer, n int) (int, error) {
|
||
|
return io.WriteString(w, string(writeBuf[:n]))
|
||
|
}},
|
||
|
}
|
||
|
|
||
|
for _, compress := range []bool{false, true} {
|
||
|
for _, isServer := range []bool{true, false} {
|
||
|
for _, chunker := range readChunkers {
|
||
|
|
||
|
var connBuf bytes.Buffer
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
|
||
|
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
|
||
|
if compress {
|
||
|
wc.newCompressionWriter = compressNoContextTakeover
|
||
|
rc.newDecompressionReader = decompressNoContextTakeover
|
||
|
}
|
||
|
for _, n := range frameSizes {
|
||
|
for _, writer := range writers {
|
||
|
name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
|
||
|
|
||
|
w, err := wc.NextWriter(TextMessage)
|
||
|
if err != nil {
|
||
|
t.Errorf("%s: wc.NextWriter() returned %v", name, err)
|
||
|
continue
|
||
|
}
|
||
|
nn, err := writer.f(w, n)
|
||
|
if err != nil || nn != n {
|
||
|
t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
|
||
|
continue
|
||
|
}
|
||
|
err = w.Close()
|
||
|
if err != nil {
|
||
|
t.Errorf("%s: w.Close() returned %v", name, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
opCode, r, err := rc.NextReader()
|
||
|
if err != nil || opCode != TextMessage {
|
||
|
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
|
||
|
continue
|
||
|
}
|
||
|
rbuf, err := ioutil.ReadAll(r)
|
||
|
if err != nil {
|
||
|
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if len(rbuf) != n {
|
||
|
t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
for i, b := range rbuf {
|
||
|
if byte(i) != b {
|
||
|
t.Errorf("%s: bad byte at offset %d", name, i)
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestControl(t *testing.T) {
|
||
|
const message = "this is a ping/pong messsage"
|
||
|
for _, isServer := range []bool{true, false} {
|
||
|
for _, isWriteControl := range []bool{true, false} {
|
||
|
name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
|
||
|
var connBuf bytes.Buffer
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
|
||
|
rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
|
||
|
if isWriteControl {
|
||
|
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
||
|
} else {
|
||
|
w, err := wc.NextWriter(PongMessage)
|
||
|
if err != nil {
|
||
|
t.Errorf("%s: wc.NextWriter() returned %v", name, err)
|
||
|
continue
|
||
|
}
|
||
|
if _, err := w.Write([]byte(message)); err != nil {
|
||
|
t.Errorf("%s: w.Write() returned %v", name, err)
|
||
|
continue
|
||
|
}
|
||
|
if err := w.Close(); err != nil {
|
||
|
t.Errorf("%s: w.Close() returned %v", name, err)
|
||
|
continue
|
||
|
}
|
||
|
var actualMessage string
|
||
|
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
|
||
|
rc.NextReader()
|
||
|
if actualMessage != message {
|
||
|
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||
|
const bufSize = 512
|
||
|
|
||
|
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
||
|
|
||
|
var b1, b2 bytes.Buffer
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||
|
|
||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||
|
w.Write(make([]byte, bufSize+bufSize/2))
|
||
|
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
|
||
|
w.Close()
|
||
|
|
||
|
op, r, err := rc.NextReader()
|
||
|
if op != BinaryMessage || err != nil {
|
||
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||
|
}
|
||
|
_, err = io.Copy(ioutil.Discard, r)
|
||
|
if !reflect.DeepEqual(err, expectedErr) {
|
||
|
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
||
|
}
|
||
|
_, _, err = rc.NextReader()
|
||
|
if !reflect.DeepEqual(err, expectedErr) {
|
||
|
t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestEOFWithinFrame(t *testing.T) {
|
||
|
const bufSize = 64
|
||
|
|
||
|
for n := 0; ; n++ {
|
||
|
var b bytes.Buffer
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
|
||
|
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
|
||
|
|
||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||
|
w.Write(make([]byte, bufSize))
|
||
|
w.Close()
|
||
|
|
||
|
if n >= b.Len() {
|
||
|
break
|
||
|
}
|
||
|
b.Truncate(n)
|
||
|
|
||
|
op, r, err := rc.NextReader()
|
||
|
if err == errUnexpectedEOF {
|
||
|
continue
|
||
|
}
|
||
|
if op != BinaryMessage || err != nil {
|
||
|
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||
|
}
|
||
|
_, err = io.Copy(ioutil.Discard, r)
|
||
|
if err != errUnexpectedEOF {
|
||
|
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||
|
}
|
||
|
_, _, err = rc.NextReader()
|
||
|
if err != errUnexpectedEOF {
|
||
|
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestEOFBeforeFinalFrame(t *testing.T) {
|
||
|
const bufSize = 512
|
||
|
|
||
|
var b1, b2 bytes.Buffer
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||
|
|
||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||
|
w.Write(make([]byte, bufSize+bufSize/2))
|
||
|
|
||
|
op, r, err := rc.NextReader()
|
||
|
if op != BinaryMessage || err != nil {
|
||
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||
|
}
|
||
|
_, err = io.Copy(ioutil.Discard, r)
|
||
|
if err != errUnexpectedEOF {
|
||
|
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
||
|
}
|
||
|
_, _, err = rc.NextReader()
|
||
|
if err != errUnexpectedEOF {
|
||
|
t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
|
||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||
|
io.WriteString(w, "hello")
|
||
|
if err := w.Close(); err != nil {
|
||
|
t.Fatalf("unxpected error closing message writer, %v", err)
|
||
|
}
|
||
|
|
||
|
if _, err := io.WriteString(w, "world"); err == nil {
|
||
|
t.Fatalf("no error writing after close")
|
||
|
}
|
||
|
|
||
|
w, _ = wc.NextWriter(BinaryMessage)
|
||
|
io.WriteString(w, "hello")
|
||
|
|
||
|
// close w by getting next writer
|
||
|
_, err := wc.NextWriter(BinaryMessage)
|
||
|
if err != nil {
|
||
|
t.Fatalf("unexpected error getting next writer, %v", err)
|
||
|
}
|
||
|
|
||
|
if _, err := io.WriteString(w, "world"); err == nil {
|
||
|
t.Fatalf("no error writing after close")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestReadLimit(t *testing.T) {
|
||
|
|
||
|
const readLimit = 512
|
||
|
message := make([]byte, readLimit+1)
|
||
|
|
||
|
var b1, b2 bytes.Buffer
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
|
||
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||
|
rc.SetReadLimit(readLimit)
|
||
|
|
||
|
// Send message at the limit with interleaved pong.
|
||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||
|
w.Write(message[:readLimit-1])
|
||
|
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
||
|
w.Write(message[:1])
|
||
|
w.Close()
|
||
|
|
||
|
// Send message larger than the limit.
|
||
|
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
||
|
|
||
|
op, _, err := rc.NextReader()
|
||
|
if op != BinaryMessage || err != nil {
|
||
|
t.Fatalf("1: NextReader() returned %d, %v", op, err)
|
||
|
}
|
||
|
op, r, err := rc.NextReader()
|
||
|
if op != BinaryMessage || err != nil {
|
||
|
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||
|
}
|
||
|
_, err = io.Copy(ioutil.Discard, r)
|
||
|
if err != ErrReadLimit {
|
||
|
t.Fatalf("io.Copy() returned %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAddrs(t *testing.T) {
|
||
|
c := newConn(&fakeNetConn{}, true, 1024, 1024)
|
||
|
if c.LocalAddr() != localAddr {
|
||
|
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
||
|
}
|
||
|
if c.RemoteAddr() != remoteAddr {
|
||
|
t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestUnderlyingConn(t *testing.T) {
|
||
|
var b1, b2 bytes.Buffer
|
||
|
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||
|
c := newConn(fc, true, 1024, 1024)
|
||
|
ul := c.UnderlyingConn()
|
||
|
if ul != fc {
|
||
|
t.Fatalf("Underlying conn is not what it should be.")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestBufioReadBytes(t *testing.T) {
|
||
|
|
||
|
// Test calling bufio.ReadBytes for value longer than read buffer size.
|
||
|
|
||
|
m := make([]byte, 512)
|
||
|
m[len(m)-1] = '\n'
|
||
|
|
||
|
var b1, b2 bytes.Buffer
|
||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
|
||
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
|
||
|
|
||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||
|
w.Write(m)
|
||
|
w.Close()
|
||
|
|
||
|
op, r, err := rc.NextReader()
|
||
|
if op != BinaryMessage || err != nil {
|
||
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||
|
}
|
||
|
|
||
|
br := bufio.NewReader(r)
|
||
|
p, err := br.ReadBytes('\n')
|
||
|
if err != nil {
|
||
|
t.Fatalf("ReadBytes() returned %v", err)
|
||
|
}
|
||
|
if len(p) != len(m) {
|
||
|
t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var closeErrorTests = []struct {
|
||
|
err error
|
||
|
codes []int
|
||
|
ok bool
|
||
|
}{
|
||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
|
||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
|
||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
|
||
|
{errors.New("hello"), []int{CloseNormalClosure}, false},
|
||
|
}
|
||
|
|
||
|
func TestCloseError(t *testing.T) {
|
||
|
for _, tt := range closeErrorTests {
|
||
|
ok := IsCloseError(tt.err, tt.codes...)
|
||
|
if ok != tt.ok {
|
||
|
t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var unexpectedCloseErrorTests = []struct {
|
||
|
err error
|
||
|
codes []int
|
||
|
ok bool
|
||
|
}{
|
||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
|
||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
|
||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
|
||
|
{errors.New("hello"), []int{CloseNormalClosure}, false},
|
||
|
}
|
||
|
|
||
|
func TestUnexpectedCloseErrors(t *testing.T) {
|
||
|
for _, tt := range unexpectedCloseErrorTests {
|
||
|
ok := IsUnexpectedCloseError(tt.err, tt.codes...)
|
||
|
if ok != tt.ok {
|
||
|
t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// blockingWriter is a writer whose Write call can be observed and held open
// by a test through two signalling channels.
type blockingWriter struct {
	c1 chan struct{} // closed by Write to announce that a write has started
	c2 chan struct{} // Write blocks until this channel is closed or sent to
}
|
||
|
|
||
|
func (w blockingWriter) Write(p []byte) (int, error) {
|
||
|
// Allow main to continue
|
||
|
close(w.c1)
|
||
|
// Wait for panic in main
|
||
|
<-w.c2
|
||
|
return len(p), nil
|
||
|
}
|
||
|
|
||
|
func TestConcurrentWritePanic(t *testing.T) {
|
||
|
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||
|
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
||
|
go func() {
|
||
|
c.WriteMessage(TextMessage, []byte{})
|
||
|
}()
|
||
|
|
||
|
// wait for goroutine to block in write.
|
||
|
<-w.c1
|
||
|
|
||
|
defer func() {
|
||
|
close(w.c2)
|
||
|
if v := recover(); v != nil {
|
||
|
return
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
c.WriteMessage(TextMessage, []byte{})
|
||
|
t.Fatal("should not get here")
|
||
|
}
|
||
|
|
||
|
// failingReader is a stateless reader type used to simulate a connection
// whose reads never succeed.
type failingReader struct{}
|
||
|
|
||
|
func (r failingReader) Read(p []byte) (int, error) {
|
||
|
return 0, io.EOF
|
||
|
}
|
||
|
|
||
|
func TestFailedConnectionReadPanic(t *testing.T) {
|
||
|
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
|
||
|
|
||
|
defer func() {
|
||
|
if v := recover(); v != nil {
|
||
|
return
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
for i := 0; i < 20000; i++ {
|
||
|
c.ReadMessage()
|
||
|
}
|
||
|
t.Fatal("should not get here")
|
||
|
}
|
||
|
|
||
|
func TestBufioReuse(t *testing.T) {
|
||
|
brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
|
||
|
c := newConnBRW(nil, false, 0, 0, brw)
|
||
|
|
||
|
if c.br != brw.Reader {
|
||
|
t.Error("connection did not reuse bufio.Reader")
|
||
|
}
|
||
|
|
||
|
var wh writeHook
|
||
|
brw.Writer.Reset(&wh)
|
||
|
brw.WriteByte(0)
|
||
|
brw.Flush()
|
||
|
if &c.writeBuf[0] != &wh.p[0] {
|
||
|
t.Error("connection did not reuse bufio.Writer")
|
||
|
}
|
||
|
|
||
|
brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
|
||
|
c = newConnBRW(nil, false, 0, 0, brw)
|
||
|
|
||
|
if c.br == brw.Reader {
|
||
|
t.Error("connection used bufio.Reader with small size")
|
||
|
}
|
||
|
|
||
|
brw.Writer.Reset(&wh)
|
||
|
brw.WriteByte(0)
|
||
|
brw.Flush()
|
||
|
if &c.writeBuf[0] != &wh.p[0] {
|
||
|
t.Error("connection used bufio.Writer with small size")
|
||
|
}
|
||
|
|
||
|
}
|