nmdc-webfrontend/vendor/github.com/googollee/go-engine.io/server.go

189 lines
4.8 KiB
Go

package engineio
import (
"bytes"
"crypto/md5"
"encoding/base64"
"fmt"
"net/http"
"sync/atomic"
"time"
"github.com/googollee/go-engine.io/polling"
"github.com/googollee/go-engine.io/websocket"
)
// config holds the tunable settings of a Server. NewServer fills in the
// defaults and the Set* methods on Server overwrite individual fields.
type config struct {
	// PingTimeout is the ping timeout; per the SetPingTimeout documentation the
	// server closes the connection when it elapses. NewServer default: 60s.
	PingTimeout time.Duration

	// PingInterval is the interval between pings. NewServer default: 25s.
	PingInterval time.Duration

	// MaxConnection is the limit ServeHTTP checks the connection counter
	// against; a new session that would exceed it is refused with 503.
	// NewServer default: 1000.
	MaxConnection int

	// AllowRequest is called by ServeHTTP before a new session is created; a
	// non-nil error refuses the request with 400 and the error text.
	// NewServer default: allow every request.
	AllowRequest func(*http.Request) error

	// AllowUpgrades reports whether transport upgrades are permitted. It is not
	// read in this file. NewServer default: true.
	AllowUpgrades bool

	// Cookie is the name of the cookie ServeHTTP sets, with the session id as
	// its value. NewServer default: "io".
	Cookie string

	// NewId generates the id of a new session. NewServer default: newId.
	NewId func(r *http.Request) string
}
// Server is an engine.io server. It serves HTTP requests through ServeHTTP
// (so it can be mounted as an http.Handler), and new client connections are
// retrieved with Accept.
//
// NOTE(review): the Set* methods write the fields below without
// synchronization; presumably they are meant to be called before serving.
type Server struct {
	// config holds the current settings; configure returns a copy of it.
	config config

	// socketChan is created unbuffered by NewServer; ServeHTTP sends each new
	// connection on it and Accept receives from it.
	socketChan chan Conn

	// serverSessions maps session ids to connections (Get/Set/Remove). It can
	// be replaced with SetSessionManager.
	serverSessions Sessions

	// creaters holds the transport creators registered by NewServer, keyed by
	// transport name.
	creaters transportCreaters

	// currentConnection counts active connections. It is only accessed through
	// sync/atomic: incremented by ServeHTTP for each new session and
	// decremented on rejection or in onClose.
	currentConnection int32
}
// NewServer returns a server that supports the given transports. The
// recognized transport names are "polling" and "websocket"; any other name
// makes NewServer return InvalidError. If transports is empty (including nil),
// the server uses ["polling", "websocket"] as the default.
//
// The returned server starts with these defaults: ping timeout 60s, ping
// interval 25s, at most 1000 concurrent connections, every request allowed,
// transport upgrades allowed, cookie name "io", and newId as the session id
// generator.
func NewServer(transports []string) (*Server, error) {
	// An empty, non-nil list would leave the server with no transports at all,
	// so treat it the same as nil.
	if len(transports) == 0 {
		transports = []string{"polling", "websocket"}
	}
	creaters := make(transportCreaters)
	for _, t := range transports {
		switch t {
		case "polling":
			creaters[t] = polling.Creater
		case "websocket":
			creaters[t] = websocket.Creater
		default:
			return nil, InvalidError
		}
	}
	return &Server{
		config: config{
			PingTimeout:   60000 * time.Millisecond,
			PingInterval:  25000 * time.Millisecond,
			MaxConnection: 1000,
			AllowRequest:  func(*http.Request) error { return nil },
			AllowUpgrades: true,
			Cookie:        "io",
			NewId:         newId,
		},
		socketChan:     make(chan Conn),
		serverSessions: newServerSessions(),
		creaters:       creaters,
	}, nil
}
// SetPingTimeout sets the ping timeout. When the timeout elapses, the server
// closes the connection. The default is 60s.
func (s *Server) SetPingTimeout(t time.Duration) {
	cfg := &s.config
	cfg.PingTimeout = t
}
// SetPingInterval sets the interval between pings. The default is 25s.
func (s *Server) SetPingInterval(t time.Duration) {
	cfg := &s.config
	cfg.PingInterval = t
}
// SetMaxConnection sets the maximum number of concurrent connections. A new
// session that would exceed the limit is refused by ServeHTTP with
// 503 Service Unavailable. The default is 1000.
func (s *Server) SetMaxConnection(n int) {
	cfg := &s.config
	cfg.MaxConnection = n
}
// GetMaxConnection returns the configured maximum number of concurrent
// connections.
func (s *Server) GetMaxConnection() int {
	limit := s.config.MaxConnection
	return limit
}
// Count returns the number of connections the server currently counts as
// active. The counter is read atomically, so Count is safe to call while the
// server is handling requests.
func (s *Server) Count() int {
	active := atomic.LoadInt32(&s.currentConnection)
	return int(active)
}
// SetAllowRequest sets the hook that ServeHTTP calls before establishing a
// new connection. If the hook returns a non-nil error, the connection is not
// established and the request is answered with 400 Bad Request and the error
// text. Passing nil restores the default, which allows every request.
func (s *Server) SetAllowRequest(f func(*http.Request) error) {
	if f == nil {
		// ServeHTTP calls the hook unconditionally; a nil func would panic there.
		f = func(*http.Request) error { return nil }
	}
	s.config.AllowRequest = f
}
// SetAllowUpgrades sets whether the server allows clients to upgrade their
// transport. The default is true.
func (s *Server) SetAllowUpgrades(allow bool) {
	cfg := &s.config
	cfg.AllowUpgrades = allow
}
// SetCookie sets the name of the cookie that engine.io uses. ServeHTTP sets a
// cookie with this name, whose value is the session id, on each request it
// serves. The default is "io".
func (s *Server) SetCookie(prefix string) {
	cfg := &s.config
	cfg.Cookie = prefix
}
// SetNewId sets the callback used to generate the id of a new connection. By
// default, the id is generated from the remote address plus the current time
// stamp (see newId). Passing nil restores that default.
func (s *Server) SetNewId(f func(*http.Request) string) {
	if f == nil {
		// ServeHTTP calls the generator unconditionally for every new session; a
		// nil func would panic there.
		f = newId
	}
	s.config.NewId = f
}
// SetSessionManager sets the session manager that maps session ids to
// connections. The default is a single-process manager; a custom manager can
// be supplied, for example to support load balancing. Passing nil restores a
// fresh default manager.
//
// NOTE(review): connections registered with the previous manager are not
// carried over; presumably this should be called before the server starts
// serving.
func (s *Server) SetSessionManager(sessions Sessions) {
	if sessions == nil {
		// ServeHTTP and onClose call the manager unconditionally; a nil manager
		// would panic there.
		sessions = newServerSessions()
	}
	s.serverSessions = sessions
}
// ServeHTTP handles an engine.io HTTP request.
//
// A request whose "sid" query parameter is known to the session manager is
// dispatched to that connection. A request with a non-empty but unknown sid is
// rejected with 400. A request without a sid starts a new session:
//   - it must pass the AllowRequest hook (400 otherwise);
//   - it must fit under MaxConnection (503 otherwise);
//   - it then gets a new id, and the connection is registered with the session
//     manager and handed to Accept.
//
// In every non-error case, the session cookie is set before the connection
// serves the request.
//
// NOTE: socketChan is unbuffered, so a new session's request blocks until
// Accept receives the connection.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()
	sid := r.URL.Query().Get("sid")
	conn := s.serverSessions.Get(sid)
	if conn == nil {
		if sid != "" {
			http.Error(w, "invalid sid", http.StatusBadRequest)
			return
		}
		if err := s.config.AllowRequest(r); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Reserve a slot before creating the connection so that concurrent
		// handshakes cannot overshoot the limit.
		n := atomic.AddInt32(&s.currentConnection, 1)
		if int(n) > s.config.MaxConnection {
			atomic.AddInt32(&s.currentConnection, -1)
			http.Error(w, "too many connections", http.StatusServiceUnavailable)
			return
		}
		sid = s.config.NewId(r)
		var err error
		conn, err = newServerConn(sid, w, r, s)
		if err != nil {
			// The reserved slot is normally released by onClose. onClose is not
			// reached for a connection that was never created, so release the
			// slot here. Otherwise every failed handshake would permanently
			// consume MaxConnection capacity.
			atomic.AddInt32(&s.currentConnection, -1)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		s.serverSessions.Set(sid, conn)
		s.socketChan <- conn
	}
	http.SetCookie(w, &http.Cookie{
		Name:  s.config.Cookie,
		Value: sid,
	})
	conn.(*serverConn).ServeHTTP(w, r)
}
// Accept blocks until ServeHTTP creates a new client connection and then
// returns it. The returned error is always nil.
func (s *Server) Accept() (Conn, error) {
	conn := <-s.socketChan
	return conn, nil
}
// configure returns a copy of the server's current configuration.
func (s *Server) configure() config {
	cfg := s.config
	return cfg
}
// transports returns the transport creators registered by NewServer, keyed by
// transport name.
func (s *Server) transports() transportCreaters {
	registered := s.creaters
	return registered
}
// onClose removes the session with the given id from the session manager and
// then releases its slot in the connection counter. It is presumably invoked
// by the connection when it closes — confirm against serverConn.
func (s *Server) onClose(id string) {
	sessions := s.serverSessions
	sessions.Remove(id)
	atomic.AddInt32(&s.currentConnection, -1)
}
// newId is the default session id generator. It hashes the request's remote
// address together with the current time using MD5, encodes the digest with
// URL-safe base64, and returns the first 20 characters.
//
// NOTE(review): the inputs contain no secret randomness, so ids are
// predictable to anyone who knows the client address and an approximate
// time. Consider a crypto/rand based generator via SetNewId where session
// hijacking matters.
func newId(r *http.Request) string {
	seed := fmt.Sprintf("%s %s", r.RemoteAddr, time.Now())
	digest := md5.Sum([]byte(seed))

	var out bytes.Buffer
	enc := base64.NewEncoder(base64.URLEncoding, &out)
	// Writes to a bytes.Buffer cannot fail, so the errors are not checked.
	enc.Write(digest[:])
	enc.Close()

	// 16 digest bytes encode to 24 characters; the first 20 contain no padding.
	return out.String()[:20]
}