[🤖] Add Unit-Tests
Build Docker and Deploy / Run goext test-suite (push) Successful in 1m34s

This commit is contained in:
2026-04-27 10:46:08 +02:00
parent dad0e3240d
commit 02d6894ec6
116 changed files with 18795 additions and 1 deletions
+652
View File
@@ -0,0 +1,652 @@
package wsw
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
type testMessage struct {
Kind string `json:"kind"`
Value int `json:"value"`
}
func newDummyReqRes() (*httptest.ResponseRecorder, *http.Request) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
return rec, req
}
func TestNewWebSocketWrapper_NilOrigins_NoCheckOrigin(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
if w == nil {
t.Fatal("expected non-nil wrapper")
}
if w.upgrader.CheckOrigin != nil {
t.Errorf("expected CheckOrigin to be nil when allowedOrigins is nil")
}
if w.running {
t.Errorf("expected running to be false initially")
}
if w.writer != rec {
t.Errorf("expected writer to be the same passed in")
}
if w.request != req {
t.Errorf("expected request to be the same passed in")
}
}
func TestNewWebSocketWrapper_WithOrigins_SetsCheckOrigin(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"example.com"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
if w == nil {
t.Fatal("expected non-nil wrapper")
}
if w.upgrader.CheckOrigin == nil {
t.Errorf("expected CheckOrigin to be set when allowedOrigins is non-nil")
}
}
func makeReqWithOrigin(origin string) *http.Request {
r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
if origin != "" {
r.Header.Set("Origin", origin)
}
return r
}
func TestCheckOrigin_NoOriginHeader_AllowsRequest(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"example.com"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
r := makeReqWithOrigin("")
if !w.upgrader.CheckOrigin(r) {
t.Errorf("expected CheckOrigin to return true when Origin header is missing")
}
}
func TestCheckOrigin_InvalidURL_Rejects(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"example.com"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
// Construct a request whose Origin cannot be parsed by url.Parse.
r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
r.Header["Origin"] = []string{"http://[::1:bad"}
if w.upgrader.CheckOrigin(r) {
t.Errorf("expected CheckOrigin to return false for invalid Origin URL")
}
}
func TestCheckOrigin_ExactMatch_Allowed(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"example.com"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
r := makeReqWithOrigin("https://example.com")
if !w.upgrader.CheckOrigin(r) {
t.Errorf("expected CheckOrigin to allow exact host match")
}
}
func TestCheckOrigin_NoMatch_Rejected(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"example.com"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
r := makeReqWithOrigin("https://other.com")
if w.upgrader.CheckOrigin(r) {
t.Errorf("expected CheckOrigin to reject non-matching host")
}
}
func TestCheckOrigin_WildcardMatch_Allowed(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"*.example.com"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
r := makeReqWithOrigin("https://api.example.com")
if !w.upgrader.CheckOrigin(r) {
t.Errorf("expected CheckOrigin to allow wildcard host match")
}
}
func TestCheckOrigin_WildcardMatch_DoesNotMatchTopLevel(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"*.example.com"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
// path.Match's "*" does not cross dots, so "example.com" must not match "*.example.com".
r := makeReqWithOrigin("https://example.com")
if w.upgrader.CheckOrigin(r) {
t.Errorf("expected CheckOrigin to reject bare host against wildcard subdomain pattern")
}
}
func TestCheckOrigin_EmptyAllowedList_Rejects(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
r := makeReqWithOrigin("https://example.com")
if w.upgrader.CheckOrigin(r) {
t.Errorf("expected CheckOrigin to reject when allowed list is empty and Origin is set")
}
}
func TestCheckOrigin_PortIsPartOfHost(t *testing.T) {
rec, req := newDummyReqRes()
allowed := []string{"example.com:8080"}
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
rOK := makeReqWithOrigin("http://example.com:8080")
if !w.upgrader.CheckOrigin(rOK) {
t.Errorf("expected host:port to match allowed host:port")
}
// Without port -> different host -> reject.
rNo := makeReqWithOrigin("http://example.com")
if w.upgrader.CheckOrigin(rNo) {
t.Errorf("expected host without port to not match allowed host:port")
}
}
func TestRunning_BeforeStart_False(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
if w.Running() {
t.Errorf("expected Running() to be false before Start")
}
}
func TestDecode_ValidJSON(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
msg, err := w.decode([]byte(`{"kind":"hello","value":7}`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if msg.Kind != "hello" || msg.Value != 7 {
t.Errorf("unexpected decoded message: %+v", msg)
}
}
func TestDecode_InvalidJSON_NoFallback_ReturnsError(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
_, err := w.decode([]byte("not json"))
if err == nil {
t.Fatalf("expected error when decoding invalid JSON without fallback")
}
}
func TestDecode_InvalidJSON_WithFallback_UsesFallback(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
called := false
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
called = true
return testMessage{Kind: "fallback:" + string(b), Value: 42}, nil
})
msg, err := w.decode([]byte("not json"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !called {
t.Errorf("expected fallback decoder to be called")
}
if msg.Kind != "fallback:not json" || msg.Value != 42 {
t.Errorf("unexpected fallback result: %+v", msg)
}
}
func TestDecode_InvalidJSON_FallbackError_Propagates(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
sentinel := errors.New("nope")
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
return testMessage{}, sentinel
})
_, err := w.decode([]byte("not json"))
if !errors.Is(err, sentinel) {
t.Errorf("expected sentinel error from fallback decoder, got %v", err)
}
}
func TestDecode_ValidJSON_DoesNotCallFallback(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
called := false
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
called = true
return testMessage{}, nil
})
_, err := w.decode([]byte(`{"kind":"x","value":1}`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if called {
t.Errorf("did not expect fallback decoder to be called when JSON is valid")
}
}
func TestSetFallbackDecoder_StoresFunction(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
if w.fallbackDecoder != nil {
t.Fatalf("expected fallbackDecoder to be nil initially")
}
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
return testMessage{}, nil
})
if w.fallbackDecoder == nil {
t.Errorf("expected fallbackDecoder to be set after SetFallbackDecoder")
}
}
// ---------------------------------------------------------------------------
// Integration tests (using httptest + gorilla/websocket loopback).
// ---------------------------------------------------------------------------
// startTestServer wires up a httptest server that, on a websocket upgrade
// request, builds a WebSocketWrapper, starts it, and returns it via the
// supplied channel for the test to interact with. The returned cleanup
// function should be deferred by the caller.
func startTestServer(t *testing.T, allowedOrigins *[]string) (*httptest.Server, <-chan *WebSocketWrapper[testMessage]) {
t.Helper()
wrapperCh := make(chan *WebSocketWrapper[testMessage], 1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := NewWebSocketWrapper[testMessage](w, r, allowedOrigins)
if err := ww.Start(); err != nil {
t.Errorf("Start failed: %v", err)
wrapperCh <- nil
return
}
wrapperCh <- ww
}))
t.Cleanup(srv.Close)
return srv, wrapperCh
}
// httpToWS rewrites an http(s) URL into ws(s).
func httpToWS(u string) string {
if rest, ok := strings.CutPrefix(u, "https://"); ok {
return "wss://" + rest
}
rest, _ := strings.CutPrefix(u, "http://")
return "ws://" + rest
}
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
t.Helper()
wsURL := httpToWS(server.URL)
c, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial: %v", err)
}
return c
}
func TestIntegration_StartReceivesValidMessage(t *testing.T) {
srv, wrapperCh := startTestServer(t, nil)
c := dialWS(t, srv)
defer c.Close()
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
if !ww.Running() {
t.Errorf("expected Running() to be true after Start")
}
payload := testMessage{Kind: "ping", Value: 9}
b, _ := json.Marshal(payload)
if err := c.WriteMessage(websocket.TextMessage, b); err != nil {
t.Fatalf("client write: %v", err)
}
select {
case got := <-ww.MessageChan:
if got.Kind != "ping" || got.Value != 9 {
t.Errorf("unexpected message: %+v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for message")
}
// Drain CloseChan in a goroutine because Close() sends to it synchronously.
closeReason := make(chan string, 1)
go func() {
if r, ok := <-ww.CloseChan; ok {
closeReason <- r
} else {
closeReason <- ""
}
}()
ww.Close("test-done")
select {
case r := <-closeReason:
if r != "test-done" {
t.Errorf("expected close reason 'test-done', got %q", r)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for close reason")
}
if ww.Running() {
t.Errorf("expected Running() to be false after Close")
}
}
func TestIntegration_InvalidJSON_SendsErrorOnErrorChan(t *testing.T) {
srv, wrapperCh := startTestServer(t, nil)
c := dialWS(t, srv)
defer c.Close()
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
if err := c.WriteMessage(websocket.TextMessage, []byte("garbage")); err != nil {
t.Fatalf("client write: %v", err)
}
select {
case err := <-ww.ErrorChan:
if err == nil {
t.Errorf("expected non-nil error on ErrorChan")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for error on ErrorChan")
}
// Drain CloseChan and close cleanly.
go func() { <-ww.CloseChan }()
ww.Close()
}
func TestIntegration_InvalidJSON_FallbackDecoder(t *testing.T) {
wrapperCh := make(chan *WebSocketWrapper[testMessage], 1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := NewWebSocketWrapper[testMessage](w, r, nil)
ww.SetFallbackDecoder(func(b []byte) (testMessage, error) {
return testMessage{Kind: "raw:" + string(b), Value: 1}, nil
})
if err := ww.Start(); err != nil {
t.Errorf("Start failed: %v", err)
wrapperCh <- nil
return
}
wrapperCh <- ww
}))
defer srv.Close()
c := dialWS(t, srv)
defer c.Close()
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
if err := c.WriteMessage(websocket.TextMessage, []byte("not-json-but-ok")); err != nil {
t.Fatalf("client write: %v", err)
}
select {
case got := <-ww.MessageChan:
if got.Kind != "raw:not-json-but-ok" {
t.Errorf("unexpected fallback-decoded message: %+v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for fallback-decoded message")
}
go func() { <-ww.CloseChan }()
ww.Close()
}
func TestIntegration_BinaryMessage_SendsErrorOnErrorChan(t *testing.T) {
srv, wrapperCh := startTestServer(t, nil)
c := dialWS(t, srv)
defer c.Close()
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
if err := c.WriteMessage(websocket.BinaryMessage, []byte{0x01, 0x02, 0x03}); err != nil {
t.Fatalf("client write: %v", err)
}
select {
case err := <-ww.ErrorChan:
if err == nil {
t.Errorf("expected non-nil error for binary message")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for binary-message error")
}
go func() { <-ww.CloseChan }()
ww.Close()
}
func TestIntegration_Send_DeliversToClient(t *testing.T) {
srv, wrapperCh := startTestServer(t, nil)
c := dialWS(t, srv)
defer c.Close()
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
go func() {
ww.Send(testMessage{Kind: "outbound", Value: 123})
}()
mt, raw, err := c.ReadMessage()
if err != nil {
t.Fatalf("client read: %v", err)
}
if mt != websocket.TextMessage {
t.Errorf("expected text message, got type %d", mt)
}
var got testMessage
if err := json.Unmarshal(raw, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.Kind != "outbound" || got.Value != 123 {
t.Errorf("unexpected payload: %+v", got)
}
go func() { <-ww.CloseChan }()
ww.Close()
}
func TestIntegration_Send_NotRunning_NoPanic(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
// Should not panic; should just print an internal error and return.
defer func() {
if r := recover(); r != nil {
t.Errorf("Send panicked when not running: %v", r)
}
}()
w.Send(testMessage{Kind: "x", Value: 1})
}
func TestIntegration_Close_BeforeStart_NoPanic(t *testing.T) {
rec, req := newDummyReqRes()
w := NewWebSocketWrapper[testMessage](rec, req, nil)
defer func() {
if r := recover(); r != nil {
t.Errorf("Close panicked when not running: %v", r)
}
}()
w.Close("never started")
}
func TestIntegration_Close_IsIdempotent(t *testing.T) {
srv, wrapperCh := startTestServer(t, nil)
c := dialWS(t, srv)
defer c.Close()
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
// First close: drain CloseChan in goroutine, since send is synchronous.
done := make(chan struct{})
go func() {
<-ww.CloseChan
close(done)
}()
ww.Close("first")
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first close")
}
// Second close: must be a no-op (channels are already closed).
defer func() {
if r := recover(); r != nil {
t.Errorf("second Close() panicked: %v", r)
}
}()
ww.Close("second")
if ww.Running() {
t.Errorf("expected Running() to remain false after repeated Close")
}
}
func TestIntegration_ClientClose_TriggersCloseChan(t *testing.T) {
srv, wrapperCh := startTestServer(t, nil)
c := dialWS(t, srv)
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
// Reader for CloseChan.
gotClose := make(chan string, 1)
go func() {
if r, ok := <-ww.CloseChan; ok {
gotClose <- r
} else {
gotClose <- ""
}
}()
// Client closes the connection abruptly.
_ = c.Close()
select {
case r := <-gotClose:
if !strings.Contains(strings.ToLower(r), "fail") && !strings.Contains(strings.ToLower(r), "close") {
t.Errorf("expected close reason mentioning failure/close, got %q", r)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for server-side close on client disconnect")
}
if ww.Running() {
t.Errorf("expected Running() to be false after client close propagated")
}
}
func TestIntegration_OriginCheck_RejectsDisallowedOrigin(t *testing.T) {
allowed := []string{"allowed.example"}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := NewWebSocketWrapper[testMessage](w, r, &allowed)
_ = ww.Start() // expected to fail with 403, returns error - we ignore it.
}))
defer srv.Close()
wsURL := httpToWS(srv.URL)
hdr := http.Header{}
hdr.Set("Origin", "http://disallowed.example")
c, resp, err := websocket.DefaultDialer.Dial(wsURL, hdr)
if err == nil {
_ = c.Close()
t.Fatalf("expected dial to fail due to origin rejection")
}
if resp == nil {
t.Fatalf("expected an HTTP response on origin rejection")
}
if resp.StatusCode != http.StatusForbidden {
t.Errorf("expected status 403, got %d", resp.StatusCode)
}
}
func TestIntegration_OriginCheck_AcceptsAllowedOrigin(t *testing.T) {
allowed := []string{"allowed.example"}
wrapperCh := make(chan *WebSocketWrapper[testMessage], 1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := NewWebSocketWrapper[testMessage](w, r, &allowed)
if err := ww.Start(); err != nil {
wrapperCh <- nil
return
}
wrapperCh <- ww
}))
defer srv.Close()
wsURL := httpToWS(srv.URL)
hdr := http.Header{}
hdr.Set("Origin", "http://allowed.example")
c, _, err := websocket.DefaultDialer.Dial(wsURL, hdr)
if err != nil {
t.Fatalf("dial with allowed origin failed: %v", err)
}
defer c.Close()
ww := <-wrapperCh
if ww == nil {
t.Fatal("wrapper not initialized")
}
if !ww.Running() {
t.Errorf("expected Running() true after successful upgrade")
}
go func() { <-ww.CloseChan }()
ww.Close()
}