package wsw

import (
	"encoding/json"
	"errors"
	"net/http"
	"net/url"
	"path"
	"strings"
	"sync"

	"git.blackforestbytes.com/BlackForestBytes/goext/exerr"
	"github.com/gorilla/websocket"
)

// TODO move me - once finalized - to goext

// WebSocketWrapper upgrades a single HTTP request to a websocket connection and
// exposes the incoming traffic as channels.
//
// After a successful Start:
//   - MessageChan receives every text message, decoded into TMessage (JSON
//     first, then the optional fallback decoder, see SetFallbackDecoder).
//   - ErrorChan receives decode errors and notices about unsupported binary
//     messages.
//   - CloseChan receives exactly one close reason and is then closed.
//
// MessageChan and ErrorChan are closed by the internal reader goroutine once it
// has stopped, which happens shortly after the connection is closed.
type WebSocketWrapper[TMessage any] struct {
	// lock guards running, done and fallbackDecoder.
	lock    sync.Mutex
	running bool

	// writeLock serializes writes: gorilla/websocket supports at most one
	// concurrent writer per connection.
	writeLock sync.Mutex

	// done is closed by Close so that a reader goroutine blocked on a channel
	// send can give up and exit instead of leaking.
	done chan struct{}

	upgrader websocket.Upgrader
	writer   http.ResponseWriter
	request  *http.Request
	socket   *websocket.Conn

	fallbackDecoder func([]byte) (TMessage, error)

	MessageChan chan TMessage
	CloseChan   chan string
	ErrorChan   chan error
}

// NewWebSocketWrapper prepares, but does not yet upgrade, a websocket
// connection for the given request. Call Start to perform the upgrade.
//
// If allowedOrigins is nil, the upgrader's default origin check is used.
// Otherwise a request is accepted when it carries no Origin header, or when the
// host of its Origin matches at least one path.Match pattern in allowedOrigins
// (e.g. "*.example.com").
func NewWebSocketWrapper[TMessage any](w http.ResponseWriter, r *http.Request, allowedOrigins *[]string) *WebSocketWrapper[TMessage] {
	var checkOrigin func(r *http.Request) bool
	if allowedOrigins != nil {
		checkOrigin = func(req *http.Request) bool {
			origin := req.Header["Origin"]
			if len(origin) == 0 {
				return true
			}
			u, err := url.Parse(origin[0])
			if err != nil {
				return false
			}
			// The slice is dereferenced per request, so later changes made by
			// the caller through the pointer are honoured.
			for _, origPattern := range *allowedOrigins {
				if ok, err := path.Match(origPattern, u.Host); err == nil && ok {
					return true
				}
			}
			return false
		}
	}

	return &WebSocketWrapper[TMessage]{
		running:  false,
		upgrader: websocket.Upgrader{CheckOrigin: checkOrigin},
		writer:   w,
		request:  r,
	}
}

// Start upgrades the HTTP connection to a websocket, (re)creates the public
// channels and launches the reader goroutine. It returns an error if the
// wrapper is already running or if the upgrade fails.
func (wsw *WebSocketWrapper[TMessage]) Start() error {
	wsw.lock.Lock()
	defer wsw.lock.Unlock()

	if wsw.running {
		return exerr.New(exerr.TypeWebsocket, "Websocket is already running").Build()
	}

	wsw.MessageChan = make(chan TMessage)
	// Buffered so Close can hand over its reason without blocking, even when
	// Close is called from the very goroutine that consumes CloseChan.
	wsw.CloseChan = make(chan string, 1)
	wsw.ErrorChan = make(chan error)
	wsw.done = make(chan struct{})

	socket, err := wsw.upgrader.Upgrade(wsw.writer, wsw.request, nil)
	if err != nil {
		return exerr.Wrap(err, "").Build()
	}
	wsw.socket = socket
	wsw.running = true

	go wsw.readLoop(socket, wsw.done, wsw.MessageChan, wsw.ErrorChan)
	return nil
}

// readLoop is the only reader of socket and the only sender on messages and
// errs. It closes both channels when it returns, so no other goroutine can
// ever send on a closed channel. It stops after the first read error, closing
// the wrapper with a descriptive reason. It also stops as soon as done is
// closed while it is waiting for a consumer.
//
// Control frames never show up here: gorilla/websocket answers pings through
// its default ping handler and reports a received close frame as a
// *websocket.CloseError from ReadMessage.
func (wsw *WebSocketWrapper[TMessage]) readLoop(socket *websocket.Conn, done <-chan struct{}, messages chan<- TMessage, errs chan<- error) {
	defer close(errs)
	defer close(messages)

	for {
		mt, data, err := socket.ReadMessage()
		if err != nil {
			reason := "Failed to read message: " + err.Error()
			var closeErr *websocket.CloseError
			if errors.As(err, &closeErr) {
				reason = "Closed by client"
				if closeErr.Text != "" {
					reason += ": " + closeErr.Text
				}
			}
			// No-op if the connection was already closed via Close.
			wsw.Close(reason)
			return
		}

		switch mt {
		case websocket.TextMessage:
			msg, decErr := wsw.decode(data)
			if decErr != nil {
				select {
				case errs <- decErr:
				case <-done:
					return
				}
				continue
			}
			select {
			case messages <- msg:
			case <-done:
				return
			}
		case websocket.BinaryMessage:
			select {
			case errs <- exerr.New(exerr.TypeWebsocket, "Binary messages are not supported").Build():
			case <-done:
				return
			}
		}
	}
}

// decode converts the payload of a text message into a TMessage. It tries
// encoding/json first. If that fails and a fallback decoder is set, the
// fallback's result is returned instead. Otherwise the JSON error is returned
// together with a zero TMessage.
func (wsw *WebSocketWrapper[TMessage]) decode(message []byte) (TMessage, error) {
	var msg TMessage
	err := json.Unmarshal(message, &msg)
	if err == nil {
		return msg, nil
	}

	wsw.lock.Lock()
	fallback := wsw.fallbackDecoder
	wsw.lock.Unlock()

	if fallback != nil {
		return fallback(message)
	}

	var zero TMessage
	return zero, err
}

// Close shuts the connection down. It is safe to call multiple times and from
// any goroutine, including the one consuming CloseChan; only the first call
// after Start has an effect.
//
// The reasons are joined with spaces ("Manual close" if none are given) and
// delivered on CloseChan, which is then closed. MessageChan and ErrorChan are
// closed by the reader goroutine once it has noticed the shutdown.
func (wsw *WebSocketWrapper[TMessage]) Close(reasons ...string) {
	wsw.lock.Lock()
	defer wsw.lock.Unlock()

	if !wsw.running {
		return // already closed (or never started)
	}

	reason := "Manual close"
	if len(reasons) > 0 {
		reason = strings.Join(reasons, " ")
	}

	wsw.running = false
	close(wsw.done)         // release a reader blocked on a channel send
	_ = wsw.socket.Close()  // best effort; also unblocks a pending ReadMessage
	wsw.CloseChan <- reason // capacity 1 and sent at most once: never blocks
	close(wsw.CloseChan)
}

// Send marshals data as JSON and writes it as a single text message. Failures
// are printed. A failed write additionally closes the connection. Send is safe
// for concurrent use.
func (wsw *WebSocketWrapper[TMessage]) Send(data any) {
	if !wsw.Running() {
		exerr.New(exerr.TypeWebsocket, "Cannot send to websocket -- not running").Print()
		return
	}

	b, err := json.Marshal(data)
	if err != nil {
		// A value that cannot be marshalled only affects this message; the
		// connection itself is still healthy, so report instead of panicking.
		exerr.Wrap(err, "Failed to marshal websocket message").Print()
		return
	}

	wsw.writeLock.Lock()
	err = wsw.socket.WriteMessage(websocket.TextMessage, b)
	wsw.writeLock.Unlock()

	if err != nil {
		exerr.New(exerr.TypeWebsocket, "Failed to send to websocket").Print()
		wsw.Close("Failed to send to websocket: " + err.Error())
		return
	}
}

// Running reports whether the connection has been started and not yet closed.
func (wsw *WebSocketWrapper[TMessage]) Running() bool {
	wsw.lock.Lock()
	defer wsw.lock.Unlock()

	return wsw.running
}

// SetFallbackDecoder installs a decoder that is used for text messages that
// encoding/json fails to unmarshal into TMessage. Passing nil removes it.
func (wsw *WebSocketWrapper[TMessage]) SetFallbackDecoder(dec func([]byte) (TMessage, error)) {
	wsw.lock.Lock()
	defer wsw.lock.Unlock()

	wsw.fallbackDecoder = dec
}