All checks were successful
Build Docker and Deploy / Run goext test-suite (push) Successful in 1m40s
178 lines
3.8 KiB
Go
178 lines
3.8 KiB
Go
package wsw
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"sync"
|
|
|
|
"git.blackforestbytes.com/BlackForestBytes/goext/exerr"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// TODO: move this to goext once it is finalized.
|
|
|
|
// WebSocketWrapper wraps one server-side websocket connection and exposes the
// traffic it receives as Go channels of TMessage.
//
// It is created with NewWebSocketWrapper; Start performs the upgrade and
// launches the reader goroutine that feeds MessageChan and ErrorChan, and
// Close delivers the close reason on CloseChan.
type WebSocketWrapper[TMessage any] struct {
	// lock is taken by Close, Running and SetFallbackDecoder around their
	// accesses to running and fallbackDecoder.
	lock sync.Mutex

	// running is set by Start after a successful upgrade and cleared once the
	// connection is closed.
	running bool
	// upgrader performs the HTTP-to-websocket upgrade; its CheckOrigin is
	// configured by NewWebSocketWrapper.
	upgrader websocket.Upgrader
	// writer and request are the HTTP exchange that Start upgrades.
	writer  http.ResponseWriter
	request *http.Request
	// socket is the upgraded connection, assigned by Start.
	socket *websocket.Conn

	// fallbackDecoder, when non-nil, is used for text frames that
	// json.Unmarshal cannot decode into TMessage (see SetFallbackDecoder).
	fallbackDecoder func([]byte) (TMessage, error)

	// MessageChan delivers every successfully decoded text frame.
	MessageChan chan TMessage
	// CloseChan delivers the close reason sent by Close, which then closes it.
	CloseChan chan string
	// ErrorChan delivers decode errors and rejections of binary frames.
	ErrorChan chan error
}
|
|
|
|
func NewWebSocketWrapper[TMessage any](w http.ResponseWriter, r *http.Request, allowedOrigins *[]string) *WebSocketWrapper[TMessage] {
|
|
var checkOrigin func(r *http.Request) bool = nil
|
|
if allowedOrigins != nil {
|
|
checkOrigin = func(r *http.Request) bool {
|
|
origin := r.Header["Origin"]
|
|
if len(origin) == 0 {
|
|
return true
|
|
}
|
|
|
|
u, err := url.Parse(origin[0])
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
for _, origPattern := range *allowedOrigins {
|
|
if ok, err := path.Match(origPattern, u.Host); err == nil && ok {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
}
|
|
return &WebSocketWrapper[TMessage]{
|
|
lock: sync.Mutex{},
|
|
running: false,
|
|
upgrader: websocket.Upgrader{CheckOrigin: checkOrigin},
|
|
writer: w,
|
|
request: r,
|
|
}
|
|
}
|
|
|
|
func (wsw *WebSocketWrapper[TMessage]) Start() error {
|
|
wsw.MessageChan = make(chan TMessage)
|
|
wsw.CloseChan = make(chan string)
|
|
wsw.ErrorChan = make(chan error)
|
|
|
|
var err error
|
|
|
|
wsw.socket, err = wsw.upgrader.Upgrade(wsw.writer, wsw.request, nil)
|
|
if err != nil {
|
|
return exerr.Wrap(err, "").Build()
|
|
}
|
|
|
|
wsw.running = true
|
|
go func() {
|
|
for wsw.running {
|
|
mt, message, err := wsw.socket.ReadMessage()
|
|
if err != nil {
|
|
wsw.Close("Failed to read message: " + err.Error())
|
|
continue
|
|
}
|
|
|
|
if mt == websocket.TextMessage {
|
|
msg, err := wsw.decode(message)
|
|
if err != nil {
|
|
wsw.ErrorChan <- err
|
|
continue
|
|
}
|
|
wsw.MessageChan <- msg
|
|
} else if mt == websocket.BinaryMessage {
|
|
wsw.ErrorChan <- exerr.New(exerr.TypeWebsocket, "Binary messages are not supported").Build()
|
|
} else if mt == websocket.CloseMessage {
|
|
if len(message) > 0 {
|
|
wsw.Close("Closed by client: " + string(message))
|
|
} else {
|
|
wsw.Close("Closed by client")
|
|
}
|
|
} else if mt == websocket.PingMessage {
|
|
_ = wsw.socket.WriteMessage(websocket.PongMessage, []byte{})
|
|
}
|
|
}
|
|
wsw.running = false
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (wsw *WebSocketWrapper[TMessage]) decode(message []byte) (TMessage, error) {
|
|
var msg TMessage
|
|
err := json.Unmarshal(message, &msg)
|
|
if err == nil {
|
|
return msg, nil
|
|
}
|
|
if wsw.fallbackDecoder != nil {
|
|
return wsw.fallbackDecoder(message)
|
|
}
|
|
return *new(TMessage), err
|
|
}
|
|
|
|
func (wsw *WebSocketWrapper[TMessage]) Close(reasons ...string) {
|
|
wsw.lock.Lock()
|
|
defer wsw.lock.Unlock()
|
|
|
|
if !wsw.running {
|
|
return // already closed
|
|
}
|
|
|
|
reason := "Manual close"
|
|
if len(reasons) > 0 {
|
|
reason = strings.Join(reasons, " ")
|
|
}
|
|
|
|
wsw.CloseChan <- reason
|
|
|
|
wsw.running = false
|
|
_ = wsw.socket.Close()
|
|
|
|
close(wsw.MessageChan)
|
|
close(wsw.CloseChan)
|
|
close(wsw.ErrorChan)
|
|
}
|
|
|
|
func (wsw *WebSocketWrapper[TMessage]) Send(data any) {
|
|
if !wsw.running {
|
|
exerr.New(exerr.TypeWebsocket, "Cannot send to websocket -- not running").Print()
|
|
return
|
|
}
|
|
|
|
b, err := json.Marshal(data)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
err = wsw.socket.WriteMessage(websocket.TextMessage, b)
|
|
if err != nil {
|
|
exerr.New(exerr.TypeWebsocket, "Failed to send to websocket").Print()
|
|
wsw.Close("Failed to send to websocket: " + err.Error())
|
|
return
|
|
}
|
|
}
|
|
|
|
func (wsw *WebSocketWrapper[TMessage]) Running() bool {
|
|
wsw.lock.Lock()
|
|
defer wsw.lock.Unlock()
|
|
|
|
return wsw.running
|
|
}
|
|
|
|
func (wsw *WebSocketWrapper[TMessage]) SetFallbackDecoder(dec func([]byte) (TMessage, error)) {
|
|
wsw.lock.Lock()
|
|
defer wsw.lock.Unlock()
|
|
|
|
wsw.fallbackDecoder = dec
|
|
}
|