nmdc-webfrontend/vendor/github.com/googollee/go-engine.io/server_conn.go

389 lines
8.6 KiB
Go

package engineio
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/googollee/go-engine.io/message"
"github.com/googollee/go-engine.io/parser"
"github.com/googollee/go-engine.io/transport"
)
// MessageType is the type of an engine.io message. It is defined on top of
// message.MessageType and converts to and from it directly.
type MessageType message.MessageType

const (
	// MessageBinary is the engineio counterpart of message.MessageBinary.
	MessageBinary MessageType = MessageType(message.MessageBinary)
	// MessageText is the engineio counterpart of message.MessageText.
	MessageText MessageType = MessageType(message.MessageText)
)
// Conn is the connection object of engine.io.
type Conn interface {
	// Id returns the session id of the connection.
	Id() string
	// Request returns the first HTTP request, i.e. the one that established the connection.
	Request() *http.Request
	// Close closes the connection.
	Close() error
	// NextReader returns the type of the next message together with a reader for it.
	// It blocks until a message has been received.
	NextReader() (MessageType, io.ReadCloser, error)
	// NextWriter returns a writer for the next message of the given message type.
	NextWriter(messageType MessageType) (io.WriteCloser, error)
}
// transportCreaters indexes transport creaters by transport name.
type transportCreaters map[string]transport.Creater

// Get looks up the creater stored under name. As with any map lookup, a name
// that is not present yields the zero-value transport.Creater.
func (c transportCreaters) Get(name string) transport.Creater {
	creater := c[name]
	return creater
}
// serverCallback is what a serverConn requires from its owner.
type serverCallback interface {
	// configure returns the configuration; serverConn reads PingInterval and
	// PingTimeout from it.
	configure() config
	// transports returns the transport creaters, keyed by transport name.
	transports() transportCreaters
	// onClose is invoked by serverConn.OnClose with the session id once the
	// current transport has closed.
	onClose(sid string)
}
// state is the lifecycle state of a serverConn.
type state int

const (
	// stateUnknow is the zero value of state.
	stateUnknow state = iota
	// stateNormal is the state a connection is created in and the state it
	// returns to after an upgrade has completed.
	stateNormal
	// stateUpgrading is set by setUpgrading.
	stateUpgrading
	// stateClosing is set by Close after the current transport was closed.
	stateClosing
	// stateClosed is set by OnClose when the current transport has closed.
	stateClosed
)
// serverConn is the server-side state of a single engine.io session. It
// provides the methods of Conn and additionally serves HTTP requests and
// receives packet/close notifications (ServeHTTP, OnPacket, OnClose).
type serverConn struct {
	// id is the session id returned by Id.
	id string
	// request is the request handed to newServerConn; returned by Request.
	request *http.Request
	// callback gives access to configuration, transports and close notification.
	callback serverCallback
	// writerLocker serializes writers obtained from the transports.
	writerLocker sync.Mutex
	// transportLocker guards currentName, current, upgradingName and upgrading.
	transportLocker sync.RWMutex
	currentName string
	current transport.Server
	upgradingName string
	upgrading transport.Server
	// state is guarded by stateLocker; use getState/setState.
	state state
	stateLocker sync.RWMutex
	// readerChan hands message readers from OnPacket to NextReader; it is
	// closed by OnClose.
	readerChan chan *connReader
	// pingTimeout and pingInterval are copied from the configuration in newServerConn.
	pingTimeout time.Duration
	pingInterval time.Duration
	// pingChan signals pingLoop that a PING or PONG packet arrived; it is
	// closed by OnClose, which makes pingLoop return.
	pingChan chan bool
	// pingLocker serializes sends on pingChan with its close.
	pingLocker sync.Mutex
}
// InvalidError is returned by newServerConn when the "transport" query
// parameter of the request does not name a known transport.
var InvalidError = errors.New("invalid transport")
// newServerConn creates the connection for session id from the request that
// opens the session.
//
// The transport is chosen by the "transport" query parameter of r; if no
// creater is registered for it (the looked-up creater has an empty Name),
// InvalidError is returned. Otherwise the transport server is created, made
// the current transport, the OPEN packet is written via onOpen, and the ping
// loop is started in its own goroutine. Any error from creating the transport
// or from onOpen is returned as is, with a nil connection.
func newServerConn(id string, w http.ResponseWriter, r *http.Request, callback serverCallback) (*serverConn, error) {
	name := r.URL.Query().Get("transport")
	creater := callback.transports().Get(name)
	if creater.Name == "" {
		return nil, InvalidError
	}

	conn := new(serverConn)
	conn.id = id
	conn.request = r
	conn.callback = callback
	conn.state = stateNormal
	conn.readerChan = make(chan *connReader)
	conn.pingTimeout = callback.configure().PingTimeout
	conn.pingInterval = callback.configure().PingInterval
	conn.pingChan = make(chan bool)

	// The connection itself is passed to the transport as its callback.
	server, err := creater.Server(w, r, conn)
	if err != nil {
		return nil, err
	}
	conn.setCurrent(name, server)

	if err = conn.onOpen(); err != nil {
		return nil, err
	}

	go conn.pingLoop()
	return conn, nil
}
// Id returns the session id the connection was created with.
func (c *serverConn) Id() string {
	return c.id
}
// Request returns the HTTP request the connection was created with.
func (c *serverConn) Request() *http.Request {
	return c.request
}
// NextReader blocks until OnPacket delivers the next message and returns its
// type and a reader for it. Once the connection is closed (state is
// stateClosed, or readerChan has been closed and yields nil) it returns
// MessageBinary, a nil reader and io.EOF.
func (c *serverConn) NextReader() (MessageType, io.ReadCloser, error) {
	if c.getState() != stateClosed {
		if reader := <-c.readerChan; reader != nil {
			return MessageType(reader.MessageType()), reader, nil
		}
	}
	return MessageBinary, nil, io.EOF
}
// NextWriter returns a writer for a MESSAGE packet of type t on the current
// transport.
//
// If the connection is neither in stateNormal nor in stateUpgrading, io.EOF
// is returned. While an upgrade is in progress the call polls the state every
// 50ms, up to 30 times; if the upgrade is still pending afterwards it fails
// with an "upgrading" error.
//
// The writer lock is taken before asking the transport for a writer. On
// failure it is released here; on success it is handed to the returned
// connWriter.
func (c *serverConn) NextWriter(t MessageType) (io.WriteCloser, error) {
	initial := c.getState()
	if initial != stateUpgrading && initial != stateNormal {
		return nil, io.EOF
	}

	if initial == stateUpgrading {
		for attempt := 0; attempt < 30; attempt++ {
			time.Sleep(50 * time.Millisecond)
			if c.getState() != stateUpgrading {
				break
			}
		}
		if c.getState() == stateUpgrading {
			return nil, fmt.Errorf("upgrading")
		}
	}

	c.writerLocker.Lock()
	w, err := c.getCurrent().NextWriter(message.MessageType(t), parser.MESSAGE)
	if err != nil {
		c.writerLocker.Unlock()
		return w, err
	}
	return newConnWriter(w, &c.writerLocker), nil
}
// Close shuts the connection down. It is a no-op returning nil unless the
// connection is in stateNormal or stateUpgrading.
//
// A pending upgrade transport is closed first. Then a CLOSE packet is written
// on the current transport on a best-effort basis (a failure to obtain the
// writer is ignored), the current transport is closed, and the state becomes
// stateClosing. An error from closing the current transport is returned and
// leaves the state unchanged.
//
// NOTE(review): if a transport's Close invokes OnClose synchronously, the
// stateClosed set there is overwritten by stateClosing below - confirm
// against the transport implementations.
func (c *serverConn) Close() error {
	// Read the state once so both comparisons see the same value.
	if s := c.getState(); s != stateNormal && s != stateUpgrading {
		return nil
	}

	// Go through getUpgrade: c.upgrading is guarded by transportLocker and is
	// modified concurrently by setUpgrading and upgraded.
	if u := c.getUpgrade(); u != nil {
		u.Close()
	}

	c.writerLocker.Lock()
	if w, err := c.getCurrent().NextWriter(message.MessageText, parser.CLOSE); err == nil {
		// The connWriter is handed the writer lock.
		writer := newConnWriter(w, &c.writerLocker)
		writer.Close()
	} else {
		c.writerLocker.Unlock()
	}

	// getCurrent is deliberately called again: if an upgrade completed in the
	// meantime, it is the new current transport that has to be closed.
	if err := c.getCurrent().Close(); err != nil {
		return err
	}
	c.setState(stateClosing)
	return nil
}
// ServeHTTP handles a follow-up HTTP request of an established session.
//
// A request whose "transport" query parameter names the current transport is
// forwarded to that transport. A request naming a different transport starts
// an upgrade: a server for the new transport is created with this connection
// as its callback and registered through setUpgrading. An unknown transport
// name or a failure to create the transport server is answered with
// 400 Bad Request.
func (c *serverConn) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	transportName := r.URL.Query().Get("transport")

	// Snapshot the current transport under the lock: upgraded() swaps both
	// fields concurrently. The lock is released before serving, since the
	// request may be long-lived.
	c.transportLocker.RLock()
	currentName, current := c.currentName, c.current
	c.transportLocker.RUnlock()

	if currentName != transportName {
		creater := c.callback.transports().Get(transportName)
		if creater.Name == "" {
			http.Error(w, fmt.Sprintf("invalid transport %s", transportName), http.StatusBadRequest)
			return
		}
		u, err := creater.Server(w, r, c)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		c.setUpgrading(creater.Name, u)
		return
	}

	current.ServeHTTP(w, r)
}
// OnPacket dispatches on the type of the decoded packet r. Packets are
// ignored unless the connection is in stateNormal or stateUpgrading.
func (c *serverConn) OnPacket(r *parser.PacketDecoder) {
	if s := c.getState(); s != stateNormal && s != stateUpgrading {
		return
	}
	switch r.Type() {
	case parser.OPEN:
		// Nothing to do.
	case parser.CLOSE:
		c.getCurrent().Close()
	case parser.PING:
		// Answer with a PONG carrying the ping's payload. Writer errors are
		// ignored: the reply is best effort.
		c.writerLocker.Lock()
		t := c.getCurrent()
		u := c.getUpgrade()
		newWriter := t.NextWriter
		if u != nil {
			// An upgrade is pending: emit a NOOP on the current transport and
			// send the PONG on the upgrading transport instead.
			// NOTE(review): the NOOP presumably releases a pending request on
			// the old transport - confirm against the transport implementations.
			if w, _ := t.NextWriter(message.MessageText, parser.NOOP); w != nil {
				w.Close()
			}
			newWriter = u.NextWriter
		}
		if w, _ := newWriter(message.MessageText, parser.PONG); w != nil {
			io.Copy(w, r)
			w.Close()
		}
		c.writerLocker.Unlock()
		// A PING also counts as a sign of life, so continue with the PONG handling.
		fallthrough
	case parser.PONG:
		// pingLocker is held across the state check and the send so that
		// OnClose, which takes the same lock, cannot close pingChan in between.
		c.pingLocker.Lock()
		defer c.pingLocker.Unlock()
		if s := c.getState(); s != stateNormal && s != stateUpgrading {
			return
		}
		// pingChan is unbuffered: this blocks until pingLoop receives.
		c.pingChan <- true
	case parser.MESSAGE:
		// Hand the packet to NextReader and wait for a value on closeChan
		// before closing the decoder.
		// NOTE(review): closeChan is presumably signalled when the connReader
		// is closed by its consumer - confirm in connReader.
		closeChan := make(chan struct{})
		c.readerChan <- newConnReader(r, closeChan)
		<-closeChan
		close(closeChan)
		r.Close()
	case parser.UPGRADE:
		c.upgraded()
	case parser.NOOP:
		// Nothing to do.
	}
}
// OnClose is the close notification for the transport server.
//
// If server is the pending upgrade transport, only the upgrade is cleared and
// that transport closed. If server is neither the upgrade nor the current
// transport, the call is ignored. If server is the current transport, the
// whole connection is torn down.
func (c *serverConn) OnClose(server transport.Server) {
	if t := c.getUpgrade(); server == t {
		c.setUpgrading("", nil)
		t.Close()
		return
	}
	t := c.getCurrent()
	if server != t {
		return
	}
	t.Close()
	if t := c.getUpgrade(); t != nil {
		t.Close()
		c.setUpgrading("", nil)
	}
	c.setState(stateClosed)
	// Closing readerChan makes a blocked NextReader return io.EOF.
	close(c.readerChan)
	// pingChan is closed under pingLocker because OnPacket sends on it while
	// holding the same lock; the close makes pingLoop return.
	c.pingLocker.Lock()
	close(c.pingChan)
	c.pingLocker.Unlock()
	c.callback.onClose(c.id)
}
// onOpen writes the OPEN packet on the current transport. The packet payload
// is a JSON object holding the session id ("sid"), the names of all
// transports other than the current one ("upgrades"), and the ping interval
// and ping timeout in milliseconds ("pingInterval", "pingTimeout").
//
// It returns the first error encountered while obtaining the writer, encoding
// the payload, or closing the writer.
func (c *serverConn) onOpen() error {
	// Start from a non-nil slice so that an empty list encodes as [] rather
	// than null.
	upgrades := []string{}
	for name := range c.callback.transports() {
		if name != c.currentName {
			upgrades = append(upgrades, name)
		}
	}

	payload := struct {
		Sid          string        `json:"sid"`
		Upgrades     []string      `json:"upgrades"`
		PingInterval time.Duration `json:"pingInterval"`
		PingTimeout  time.Duration `json:"pingTimeout"`
	}{
		Sid:          c.Id(),
		Upgrades:     upgrades,
		PingInterval: c.callback.configure().PingInterval / time.Millisecond,
		PingTimeout:  c.callback.configure().PingTimeout / time.Millisecond,
	}

	w, err := c.getCurrent().NextWriter(message.MessageText, parser.OPEN)
	if err != nil {
		return err
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		return err
	}
	return w.Close()
}
// getCurrent returns the current transport, read under transportLocker.
func (c *serverConn) getCurrent() transport.Server {
	c.transportLocker.RLock()
	current := c.current
	c.transportLocker.RUnlock()
	return current
}
// getUpgrade returns the transport being upgraded to, read under
// transportLocker. It is nil when no upgrade is pending.
func (c *serverConn) getUpgrade() transport.Server {
	c.transportLocker.RLock()
	pending := c.upgrading
	c.transportLocker.RUnlock()
	return pending
}
// setCurrent installs s, registered under name, as the current transport
// while holding transportLocker.
func (c *serverConn) setCurrent(name string, s transport.Server) {
	c.transportLocker.Lock()
	c.currentName, c.current = name, s
	c.transportLocker.Unlock()
}
// setUpgrading records s, registered under name, as the transport being
// upgraded to, or clears the pending upgrade when s is nil.
//
// Registering a transport moves the connection to stateUpgrading. Clearing
// the upgrade moves a connection that is still in stateUpgrading back to
// stateNormal and leaves any other state alone. Previously clearing also
// forced stateUpgrading, so a failed upgrade left the connection stuck in
// that state and every later NextWriter call failed with "upgrading".
func (c *serverConn) setUpgrading(name string, s transport.Server) {
	c.transportLocker.Lock()
	defer c.transportLocker.Unlock()

	c.upgradingName = name
	c.upgrading = s

	// Check and update the state under a single hold of stateLocker so the
	// transition cannot interleave with another state change.
	c.stateLocker.Lock()
	defer c.stateLocker.Unlock()
	switch {
	case s != nil:
		c.state = stateUpgrading
	case c.state == stateUpgrading:
		c.state = stateNormal
	}
}
// upgraded completes a pending upgrade: the upgrading transport becomes the
// current one, the previous current transport is closed, and the state
// returns to stateNormal.
//
// If no upgrade is pending the call does nothing. Without this guard an
// UPGRADE packet received with no upgrade transport registered would replace
// the current transport with nil, and the next use of getCurrent() (for
// example in pingLoop) would panic.
func (c *serverConn) upgraded() {
	c.transportLocker.Lock()
	if c.upgrading == nil {
		c.transportLocker.Unlock()
		return
	}
	previous := c.current
	c.current = c.upgrading
	c.currentName = c.upgradingName
	c.upgrading = nil
	c.upgradingName = ""
	c.transportLocker.Unlock()

	// Close the old transport outside the lock; its close notification may
	// call back into this connection.
	previous.Close()
	c.setState(stateNormal)
}
// getState returns the connection state, read under stateLocker.
func (c *serverConn) getState() state {
	c.stateLocker.RLock()
	current := c.state
	c.stateLocker.RUnlock()
	return current
}
// setState stores the given connection state while holding stateLocker.
func (c *serverConn) setState(state state) {
	c.stateLocker.Lock()
	c.state = state
	c.stateLocker.Unlock()
}
// pingLoop runs in its own goroutine for the lifetime of the connection.
//
// Each iteration waits for whichever happens first:
//   - a value on pingChan: true (sent by OnPacket for PING and PONG packets)
//     resets both timers, false (the channel was closed) ends the loop;
//   - pingInterval elapsed since the last ping attempt: a PING packet is
//     written on the current transport, ignoring a failure to get a writer;
//   - pingTimeout elapsed since the last sign of life: the connection is
//     closed and the loop ends.
func (c *serverConn) pingLoop() {
	lastSeen := time.Now()
	lastAttempt := lastSeen

	for {
		now := time.Now()
		sinceSeen := now.Sub(lastSeen)
		sinceAttempt := now.Sub(lastAttempt)

		select {
		case alive := <-c.pingChan:
			if !alive {
				return
			}
			lastSeen = time.Now()
			lastAttempt = lastSeen

		case <-time.After(c.pingInterval - sinceAttempt):
			c.writerLocker.Lock()
			w, _ := c.getCurrent().NextWriter(message.MessageText, parser.PING)
			if w == nil {
				c.writerLocker.Unlock()
			} else {
				// The connWriter is handed the writer lock.
				writer := newConnWriter(w, &c.writerLocker)
				writer.Close()
			}
			lastAttempt = time.Now()

		case <-time.After(c.pingTimeout - sinceSeen):
			c.Close()
			return
		}
	}
}