package polling
import (
"bytes"
"html/template"
"io"
"net/http"
"sync"
"github.com/googollee/go-engine.io/message"
"github.com/googollee/go-engine.io/parser"
"github.com/googollee/go-engine.io/transport"
)
type state int
const (
stateUnknow state = iota
stateNormal
stateClosing
stateClosed
)
type Polling struct {
sendChan chan bool
encoder *parser.PayloadEncoder
callback transport.Callback
getLocker *Locker
postLocker *Locker
state state
stateLocker sync.Mutex
}
func NewServer(w http.ResponseWriter, r *http.Request, callback transport.Callback) (transport.Server, error) {
newEncoder := parser.NewBinaryPayloadEncoder
if r.URL.Query()["b64"] != nil {
newEncoder = parser.NewStringPayloadEncoder
}
ret := &Polling{
sendChan: MakeSendChan(),
encoder: newEncoder(),
callback: callback,
getLocker: NewLocker(),
postLocker: NewLocker(),
state: stateNormal,
}
return ret, nil
}
func (p *Polling) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "GET":
p.get(w, r)
case "POST":
p.post(w, r)
}
}
func (p *Polling) Close() error {
if p.getState() != stateNormal {
return nil
}
close(p.sendChan)
p.setState(stateClosing)
if p.getLocker.TryLock() {
if p.postLocker.TryLock() {
p.callback.OnClose(p)
p.setState(stateClosed)
p.postLocker.Unlock()
}
p.getLocker.Unlock()
}
return nil
}
func (p *Polling) NextWriter(msgType message.MessageType, packetType parser.PacketType) (io.WriteCloser, error) {
if p.getState() != stateNormal {
return nil, io.EOF
}
var ret io.WriteCloser
var err error
switch msgType {
case message.MessageText:
ret, err = p.encoder.NextString(packetType)
case message.MessageBinary:
ret, err = p.encoder.NextBinary(packetType)
}
if err != nil {
return nil, err
}
return NewWriter(ret, p), nil
}
func (p *Polling) get(w http.ResponseWriter, r *http.Request) {
if !p.getLocker.TryLock() {
http.Error(w, "overlay get", http.StatusBadRequest)
return
}
if p.getState() != stateNormal {
http.Error(w, "closed", http.StatusBadRequest)
return
}
defer func() {
if p.getState() == stateClosing {
if p.postLocker.TryLock() {
p.setState(stateClosed)
p.callback.OnClose(p)
p.postLocker.Unlock()
}
}
p.getLocker.Unlock()
}()
<-p.sendChan
if j := r.URL.Query().Get("j"); j != "" {
// JSONP Polling
w.Header().Set("Content-Type", "text/javascript; charset=UTF-8")
tmp := bytes.Buffer{}
p.encoder.EncodeTo(&tmp)
pl := template.JSEscapeString(tmp.String())
w.Write([]byte("___eio[" + j + "](\""))
w.Write([]byte(pl))
w.Write([]byte("\");"))
} else {
// XHR Polling
if p.encoder.IsString() {
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
} else {
w.Header().Set("Content-Type", "application/octet-stream")
}
p.encoder.EncodeTo(w)
}
}
func (p *Polling) post(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
if !p.postLocker.TryLock() {
http.Error(w, "overlay post", http.StatusBadRequest)
return
}
if p.getState() != stateNormal {
http.Error(w, "closed", http.StatusBadRequest)
return
}
defer func() {
if p.getState() == stateClosing {
if p.getLocker.TryLock() {
p.setState(stateClosed)
p.callback.OnClose(p)
p.getLocker.Unlock()
}
}
p.postLocker.Unlock()
}()
var decoder *parser.PayloadDecoder
if j := r.URL.Query().Get("j"); j != "" {
// JSONP Polling
d := r.FormValue("d")
decoder = parser.NewPayloadDecoder(bytes.NewBufferString(d))
} else {
// XHR Polling
decoder = parser.NewPayloadDecoder(r.Body)
}
for {
d, err := decoder.Next()
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
p.callback.OnPacket(d)
d.Close()
}
w.Write([]byte("ok"))
}
func (p *Polling) setState(s state) {
p.stateLocker.Lock()
defer p.stateLocker.Unlock()
p.state = s
}
func (p *Polling) getState() state {
p.stateLocker.Lock()
defer p.stateLocker.Unlock()
return p.state
}