nmdc-webfrontend/vendor/github.com/googollee/go-engine.io/websocket/websocket_test.go

459 lines
9.0 KiB
Go

package websocket
import (
"encoding/hex"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"time"
"github.com/googollee/go-engine.io/transport"
"github.com/gorilla/websocket"
"github.com/googollee/go-engine.io/message"
"github.com/googollee/go-engine.io/parser"
. "github.com/smartystreets/goconvey/convey"
)
// TestWebsocket drives the websocket transport through a real httptest server.
//
// Each Convey section starts its own server whose handler upgrades the request
// with NewServer, and a client created with NewClient. Where both sides need to
// proceed in lock-step they rendezvous over an unbuffered channel named sync:
// one side sends, the other receives, then the roles swap.
//
// NOTE(review): the handlers call t.Fatal from the HTTP server goroutine rather
// than from the goroutine running the test; the testing package documents
// FailNow as test-goroutine only. Left as is to keep the handshake unchanged.
func TestWebsocket(t *testing.T) {
	Convey("Creater", t, func() {
		So(Creater.Name, ShouldEqual, "websocket")
		So(Creater.Server, ShouldEqual, NewServer)
		So(Creater.Client, ShouldEqual, NewClient)
	})

	// Assertions live in the handler; the client side only feeds it traffic.
	Convey("Normal work, server part", t, func() {
		sync := make(chan int)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			f := newFakeCallback()
			s, err := NewServer(w, r, f)
			if err != nil {
				t.Fatal(err)
			}
			defer s.Close()

			// A plain HTTP request against the already-upgraded server is
			// expected to be answered with 400.
			{
				req, err := http.NewRequest("GET", "/", nil)
				if err != nil {
					t.Fatal(err)
				}
				recoder := httptest.NewRecorder()
				s.ServeHTTP(recoder, req)
				if recoder.Code != http.StatusBadRequest {
					t.Fatal(recoder.Code, "!=", http.StatusBadRequest)
				}
			}

			// Send an empty OPEN packet as text.
			{
				w, err := s.NextWriter(message.MessageText, parser.OPEN)
				if err != nil {
					t.Fatal(err)
				}
				err = w.Close()
				if err != nil {
					t.Fatal(err)
				}
			}

			// First packet from the client: binary MESSAGE.
			{
				<-f.onPacket
				if f.messageType != message.MessageBinary {
					t.Fatal(f.messageType, "!=", message.MessageBinary)
				}
				if f.packetType != parser.MESSAGE {
					t.Fatal(f.packetType, "!=", parser.MESSAGE)
				}
				if f.err != nil {
					// Report the callback's read error, not the (nil) NewServer error.
					t.Fatal(f.err)
				}
				if body := string(f.body); body != "测试" {
					t.Fatal(body, "!=", "测试")
				}
			}

			<-sync
			sync <- 1

			<-sync
			sync <- 1

			// Send an empty NOOP packet as binary.
			{
				w, err := s.NextWriter(message.MessageBinary, parser.NOOP)
				if err != nil {
					t.Fatal(err)
				}
				err = w.Close()
				if err != nil {
					t.Fatal(err)
				}
			}

			<-sync
			sync <- 1

			// Second packet from the client: text MESSAGE.
			{
				<-f.onPacket
				if f.messageType != message.MessageText {
					t.Fatal(f.messageType, "!=", message.MessageText)
				}
				if f.packetType != parser.MESSAGE {
					t.Fatal(f.packetType, "!=", parser.MESSAGE)
				}
				if f.err != nil {
					// Report the callback's read error, not the (nil) NewServer error.
					t.Fatal(f.err)
				}
				if body := hex.EncodeToString(f.body); body != "e697a5e69cace8aa9e" {
					t.Fatal(body, "!=", "e697a5e69cace8aa9e")
				}
			}

			<-sync
			sync <- 1
		}))
		defer server.Close()

		u, _ := url.Parse(server.URL)
		u.Scheme = "ws"
		req, err := http.NewRequest("GET", u.String(), nil)
		So(err, ShouldBeNil)

		// Fail here instead of dereferencing a nil client in the deferred Close.
		c, err := NewClient(req)
		So(err, ShouldBeNil)
		defer c.Close()

		{
			w, _ := c.NextWriter(message.MessageBinary, parser.MESSAGE)
			w.Write([]byte("测试"))
			w.Close()
		}

		sync <- 1
		<-sync

		{
			decoder, _ := c.NextReader()
			defer decoder.Close()
			ioutil.ReadAll(decoder)
		}

		sync <- 1
		<-sync

		{
			decoder, _ := c.NextReader()
			defer decoder.Close()
			ioutil.ReadAll(decoder)
		}

		sync <- 1
		<-sync

		{
			w, _ := c.NextWriter(message.MessageText, parser.MESSAGE)
			w.Write([]byte("日本語"))
			w.Close()
		}

		sync <- 1
		<-sync
	})

	// Mirror of the previous section: the handler only feeds traffic and the
	// assertions are made on the client side.
	Convey("Normal work, client part", t, func() {
		sync := make(chan int)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			f := newFakeCallback()
			// The query string of the client request must reach the handler.
			if v := r.URL.Query().Get("key"); v != "value" {
				t.Fatal(v, "!=", "value")
			}
			s, _ := NewServer(w, r, f)
			defer s.Close()

			{
				w, _ := s.NextWriter(message.MessageText, parser.OPEN)
				w.Close()
			}

			{
				<-f.onPacket
			}

			<-sync
			sync <- 1

			<-sync
			sync <- 1

			{
				w, _ := s.NextWriter(message.MessageBinary, parser.NOOP)
				w.Close()
			}

			<-sync
			sync <- 1

			{
				<-f.onPacket
			}

			<-sync
			sync <- 1
		}))
		defer server.Close()

		u, err := url.Parse(server.URL)
		So(err, ShouldBeNil)
		u.Scheme = "ws"
		req, err := http.NewRequest("GET", u.String()+"/?key=value", nil)
		So(err, ShouldBeNil)

		c, err := NewClient(req)
		So(err, ShouldBeNil)
		defer c.Close()

		So(c.Response(), ShouldNotBeNil)
		So(c.Response().StatusCode, ShouldEqual, http.StatusSwitchingProtocols)

		{
			w, err := c.NextWriter(message.MessageBinary, parser.MESSAGE)
			So(err, ShouldBeNil)
			_, err = w.Write([]byte("测试"))
			So(err, ShouldBeNil)
			err = w.Close()
			So(err, ShouldBeNil)
		}

		sync <- 1
		<-sync

		{
			decoder, err := c.NextReader()
			So(err, ShouldBeNil)
			defer decoder.Close()
			So(decoder.MessageType(), ShouldEqual, message.MessageText)
			So(decoder.Type(), ShouldEqual, parser.OPEN)
			b, err := ioutil.ReadAll(decoder)
			So(err, ShouldBeNil)
			So(string(b), ShouldEqual, "")
		}

		sync <- 1
		<-sync

		{
			decoder, err := c.NextReader()
			So(err, ShouldBeNil)
			defer decoder.Close()
			So(decoder.MessageType(), ShouldEqual, message.MessageBinary)
			So(decoder.Type(), ShouldEqual, parser.NOOP)
			b, err := ioutil.ReadAll(decoder)
			So(err, ShouldBeNil)
			So(string(b), ShouldEqual, "")
		}

		sync <- 1
		<-sync

		{
			w, err := c.NextWriter(message.MessageText, parser.MESSAGE)
			So(err, ShouldBeNil)
			_, err = w.Write([]byte("日本語"))
			So(err, ShouldBeNil)
			err = w.Close()
			So(err, ShouldBeNil)
		}

		sync <- 1
		<-sync
	})

	// Reads the raw frame from the underlying connection to pin the wire
	// encoding of a text MESSAGE packet.
	Convey("Packet content", t, func() {
		sync := make(chan int)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			f := newFakeCallback()
			s, _ := NewServer(w, r, f)
			defer s.Close()

			{
				w, _ := s.NextWriter(message.MessageText, parser.MESSAGE)
				w.Write([]byte("日本語"))
				w.Close()
			}

			sync <- 1
			<-sync
		}))
		defer server.Close()

		u, err := url.Parse(server.URL)
		So(err, ShouldBeNil)
		u.Scheme = "ws"
		req, err := http.NewRequest("GET", u.String(), nil)
		So(err, ShouldBeNil)

		c, err := NewClient(req)
		So(err, ShouldBeNil)
		defer c.Close()

		{
			client := c.(*client)
			// Named msgType so the *testing.T parameter t is not shadowed.
			msgType, r, err := client.conn.NextReader()
			So(err, ShouldBeNil)
			So(msgType, ShouldEqual, websocket.TextMessage)
			b, err := ioutil.ReadAll(r)
			So(err, ShouldBeNil)
			So(string(b), ShouldEqual, "4日本語")
			So(hex.EncodeToString(b), ShouldEqual, "34e697a5e69cace8aa9e")
		}

		<-sync
		sync <- 1
	})

	// Closing the server several times must notify the callback exactly once.
	Convey("Close", t, func() {
		f := newFakeCallback()
		var s transport.Server
		sync := make(chan int)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			s, _ = NewServer(w, r, f)
			s.Close()
			s.Close()
			s.Close()
			sync <- 1
		}))
		defer server.Close()

		u, err := url.Parse(server.URL)
		So(err, ShouldBeNil)
		u.Scheme = "ws"
		req, err := http.NewRequest("GET", u.String(), nil)
		So(err, ShouldBeNil)

		c, err := NewClient(req)
		So(err, ShouldBeNil)
		defer c.Close()

		<-sync

		waitForClose(f)
		So(f.ClosedCount(), ShouldEqual, 1)
		So(f.closeServer, ShouldEqual, s)
	})

	// Closing the underlying connection directly must also reach OnClose once.
	Convey("Closing by disconnected", t, func() {
		f := newFakeCallback()
		sync := make(chan int)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			s, _ := NewServer(w, r, f)
			server := s.(*Server)
			server.conn.Close()
			sync <- 1
		}))
		defer server.Close()

		u, err := url.Parse(server.URL)
		So(err, ShouldBeNil)
		u.Scheme = "ws"
		req, err := http.NewRequest("GET", u.String(), nil)
		So(err, ShouldBeNil)

		c, err := NewClient(req)
		So(err, ShouldBeNil)
		defer c.Close()

		<-sync

		waitForClose(f)
		So(f.ClosedCount(), ShouldEqual, 1)
	})

	// A writer obtained before Close must report an error when closed afterwards.
	Convey("Closing writer after closed", t, func() {
		f := newFakeCallback()
		sync := make(chan int)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			s, err := NewServer(w, r, f)
			if err != nil {
				t.Fatal(err)
			}
			writer, err := s.NextWriter(message.MessageText, parser.MESSAGE)
			if err != nil {
				t.Fatal(err)
			}
			err = s.Close()
			if err != nil {
				t.Fatal(err)
			}
			err = writer.Close()
			if err == nil {
				t.Fatal("err should not be nil")
			}
			sync <- 1
		}))
		defer server.Close()

		u, err := url.Parse(server.URL)
		So(err, ShouldBeNil)
		u.Scheme = "ws"
		req, err := http.NewRequest("GET", u.String(), nil)
		So(err, ShouldBeNil)

		// Fail here instead of dereferencing a nil client in the deferred Close.
		c, err := NewClient(req)
		So(err, ShouldBeNil)
		defer c.Close()

		<-sync
	})
}
// waitForClose blocks until f.closedChan is closed or five seconds elapse,
// then asserts (via So) that the close was actually observed.
func waitForClose(f *fakeCallback) {
	closed := false
	select {
	case <-f.closedChan:
		closed = true
	case <-time.After(5 * time.Second):
		// Timed out; closed stays false so the assertion below fails.
	}
	So(closed, ShouldBeTrue)
}
type fakeCallback struct {
onPacket chan bool
messageType message.MessageType
packetType parser.PacketType
body []byte
err error
closedCount int
countLocker sync.Mutex
closeServer transport.Server
closedChan chan struct{}
}
// newFakeCallback returns a fakeCallback whose two signalling channels are
// allocated (both unbuffered); every other field keeps its zero value.
func newFakeCallback() *fakeCallback {
	f := new(fakeCallback)
	f.onPacket = make(chan bool)
	f.closedChan = make(chan struct{})
	return f
}
// OnPacket records the packet type, message type and fully-read body (plus any
// read error) of r, then signals on f.onPacket. The send blocks until a test
// receives from the channel.
func (f *fakeCallback) OnPacket(r *parser.PacketDecoder) {
	pktType := r.Type()
	msgType := r.MessageType()
	payload, readErr := ioutil.ReadAll(r)

	f.packetType = pktType
	f.messageType = msgType
	f.body = payload
	f.err = readErr

	f.onPacket <- true
}
// OnClose counts the close notification and remembers which server reported
// it. Only the first notification closes f.closedChan, so repeated calls do not
// close the channel twice.
func (f *fakeCallback) OnClose(s transport.Server) {
	f.countLocker.Lock()
	defer f.countLocker.Unlock()

	f.closeServer = s
	f.closedCount++
	if f.closedCount != 1 {
		return
	}
	close(f.closedChan)
}
// ClosedCount reports how many times OnClose has been invoked, reading the
// counter while holding countLocker.
func (f *fakeCallback) ClosedCount() int {
	f.countLocker.Lock()
	n := f.closedCount
	f.countLocker.Unlock()
	return n
}