package wsw

import (
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/websocket"
)

// testMessage is the payload type used to instantiate the generic
// WebSocketWrapper throughout these tests.
type testMessage struct {
	Kind  string `json:"kind"`
	Value int    `json:"value"`
}

// newDummyReqRes returns a fresh response recorder and a plain GET request.
// They are enough to construct a wrapper, but cannot be upgraded to a real
// websocket connection.
func newDummyReqRes() (*httptest.ResponseRecorder, *http.Request) {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
	return rec, req
}

// newOriginCheckedWrapper builds a wrapper on dummy request/response values
// with the given allow-list. The list is always passed as a non-nil pointer,
// so an empty slice exercises the "empty allow-list" path rather than the
// "no origin checking" path. The test fails immediately (instead of panicking
// on a nil func call) if the wrapper does not install a CheckOrigin hook.
func newOriginCheckedWrapper(t *testing.T, allowed []string) *WebSocketWrapper[testMessage] {
	t.Helper()
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
	if w == nil {
		t.Fatal("expected non-nil wrapper")
	}
	if w.upgrader.CheckOrigin == nil {
		t.Fatal("expected CheckOrigin to be set when allowedOrigins is non-nil")
	}
	return w
}

func TestNewWebSocketWrapper_NilOrigins_NoCheckOrigin(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	if w == nil {
		t.Fatal("expected non-nil wrapper")
	}
	if w.upgrader.CheckOrigin != nil {
		t.Errorf("expected CheckOrigin to be nil when allowedOrigins is nil")
	}
	if w.running {
		t.Errorf("expected running to be false initially")
	}
	if w.writer != rec {
		t.Errorf("expected writer to be the same passed in")
	}
	if w.request != req {
		t.Errorf("expected request to be the same passed in")
	}
}

func TestNewWebSocketWrapper_WithOrigins_SetsCheckOrigin(t *testing.T) {
	rec, req := newDummyReqRes()
	allowed := []string{"example.com"}
	w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
	if w == nil {
		t.Fatal("expected non-nil wrapper")
	}
	if w.upgrader.CheckOrigin == nil {
		t.Errorf("expected CheckOrigin to be set when allowedOrigins is non-nil")
	}
}

// makeReqWithOrigin returns a GET request carrying the given Origin header;
// an empty origin leaves the header unset.
func makeReqWithOrigin(origin string) *http.Request {
	r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
	if origin != "" {
		r.Header.Set("Origin", origin)
	}
	return r
}

func TestCheckOrigin_NoOriginHeader_AllowsRequest(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{"example.com"})
	if !w.upgrader.CheckOrigin(makeReqWithOrigin("")) {
		t.Errorf("expected CheckOrigin to return true when Origin header is missing")
	}
}

func TestCheckOrigin_InvalidURL_Rejects(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{"example.com"})
	// "http://[::1:bad" has an unterminated IPv6 literal ("missing ']' in
	// host"), so url.Parse cannot parse it.
	if w.upgrader.CheckOrigin(makeReqWithOrigin("http://[::1:bad")) {
		t.Errorf("expected CheckOrigin to return false for invalid Origin URL")
	}
}

func TestCheckOrigin_ExactMatch_Allowed(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{"example.com"})
	if !w.upgrader.CheckOrigin(makeReqWithOrigin("https://example.com")) {
		t.Errorf("expected CheckOrigin to allow exact host match")
	}
}

func TestCheckOrigin_NoMatch_Rejected(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{"example.com"})
	if w.upgrader.CheckOrigin(makeReqWithOrigin("https://other.com")) {
		t.Errorf("expected CheckOrigin to reject non-matching host")
	}
}

func TestCheckOrigin_WildcardMatch_Allowed(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{"*.example.com"})
	if !w.upgrader.CheckOrigin(makeReqWithOrigin("https://api.example.com")) {
		t.Errorf("expected CheckOrigin to allow wildcard host match")
	}
}

func TestCheckOrigin_WildcardMatch_DoesNotMatchTopLevel(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{"*.example.com"})
	// The pattern requires a literal "." immediately before "example.com",
	// so the bare apex host "example.com" must not match "*.example.com".
	// (Note: path.Match-style "*" matches any run of non-'/' characters,
	// including dots, so the rejection comes from that required "." and not
	// from "*" refusing to cross dots.)
	if w.upgrader.CheckOrigin(makeReqWithOrigin("https://example.com")) {
		t.Errorf("expected CheckOrigin to reject bare host against wildcard subdomain pattern")
	}
}

func TestCheckOrigin_EmptyAllowedList_Rejects(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{})
	if w.upgrader.CheckOrigin(makeReqWithOrigin("https://example.com")) {
		t.Errorf("expected CheckOrigin to reject when allowed list is empty and Origin is set")
	}
}

func TestCheckOrigin_PortIsPartOfHost(t *testing.T) {
	w := newOriginCheckedWrapper(t, []string{"example.com:8080"})
	if !w.upgrader.CheckOrigin(makeReqWithOrigin("http://example.com:8080")) {
		t.Errorf("expected host:port to match allowed host:port")
	}
	// Without port -> different host -> reject.
	if w.upgrader.CheckOrigin(makeReqWithOrigin("http://example.com")) {
		t.Errorf("expected host without port to not match allowed host:port")
	}
}

func TestRunning_BeforeStart_False(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	if w.Running() {
		t.Errorf("expected Running() to be false before Start")
	}
}

func TestDecode_ValidJSON(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	msg, err := w.decode([]byte(`{"kind":"hello","value":7}`))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if msg.Kind != "hello" || msg.Value != 7 {
		t.Errorf("unexpected decoded message: %+v", msg)
	}
}

func TestDecode_InvalidJSON_NoFallback_ReturnsError(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	if _, err := w.decode([]byte("not json")); err == nil {
		t.Fatalf("expected error when decoding invalid JSON without fallback")
	}
}

func TestDecode_InvalidJSON_WithFallback_UsesFallback(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	called := false
	w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
		called = true
		return testMessage{Kind: "fallback:" + string(b), Value: 42}, nil
	})
	msg, err := w.decode([]byte("not json"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !called {
		t.Errorf("expected fallback decoder to be called")
	}
	if msg.Kind != "fallback:not json" || msg.Value != 42 {
		t.Errorf("unexpected fallback result: %+v", msg)
	}
}

func TestDecode_InvalidJSON_FallbackError_Propagates(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	sentinel := errors.New("nope")
	w.SetFallbackDecoder(func(_ []byte) (testMessage, error) {
		return testMessage{}, sentinel
	})
	_, err := w.decode([]byte("not json"))
	if !errors.Is(err, sentinel) {
		t.Errorf("expected sentinel error from fallback decoder, got %v", err)
	}
}

func TestDecode_ValidJSON_DoesNotCallFallback(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	called := false
	w.SetFallbackDecoder(func(_ []byte) (testMessage, error) {
		called = true
		return testMessage{}, nil
	})
	if _, err := w.decode([]byte(`{"kind":"x","value":1}`)); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if called {
		t.Errorf("did not expect fallback decoder to be called when JSON is valid")
	}
}

func TestSetFallbackDecoder_StoresFunction(t *testing.T) {
	rec, req := newDummyReqRes()
	w := NewWebSocketWrapper[testMessage](rec, req, nil)
	if w.fallbackDecoder != nil {
		t.Fatalf("expected fallbackDecoder to be nil initially")
	}
	w.SetFallbackDecoder(func(_ []byte) (testMessage, error) {
		return testMessage{}, nil
	})
	if w.fallbackDecoder == nil {
		t.Errorf("expected fallbackDecoder to be set after SetFallbackDecoder")
	}
}

// ---------------------------------------------------------------------------
// Integration tests (using httptest + gorilla/websocket loopback).
// --------------------------------------------------------------------------- // startTestServer wires up a httptest server that, on a websocket upgrade // request, builds a WebSocketWrapper, starts it, and returns it via the // supplied channel for the test to interact with. The returned cleanup // function should be deferred by the caller. func startTestServer(t *testing.T, allowedOrigins *[]string) (*httptest.Server, <-chan *WebSocketWrapper[testMessage]) { t.Helper() wrapperCh := make(chan *WebSocketWrapper[testMessage], 1) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ww := NewWebSocketWrapper[testMessage](w, r, allowedOrigins) if err := ww.Start(); err != nil { t.Errorf("Start failed: %v", err) wrapperCh <- nil return } wrapperCh <- ww })) t.Cleanup(srv.Close) return srv, wrapperCh } // httpToWS rewrites an http(s) URL into ws(s). func httpToWS(u string) string { if rest, ok := strings.CutPrefix(u, "https://"); ok { return "wss://" + rest } rest, _ := strings.CutPrefix(u, "http://") return "ws://" + rest } func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn { t.Helper() wsURL := httpToWS(server.URL) c, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { t.Fatalf("dial: %v", err) } return c } func TestIntegration_StartReceivesValidMessage(t *testing.T) { srv, wrapperCh := startTestServer(t, nil) c := dialWS(t, srv) defer c.Close() ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } if !ww.Running() { t.Errorf("expected Running() to be true after Start") } payload := testMessage{Kind: "ping", Value: 9} b, _ := json.Marshal(payload) if err := c.WriteMessage(websocket.TextMessage, b); err != nil { t.Fatalf("client write: %v", err) } select { case got := <-ww.MessageChan: if got.Kind != "ping" || got.Value != 9 { t.Errorf("unexpected message: %+v", got) } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for message") } // Drain CloseChan in a goroutine 
because Close() sends to it synchronously. closeReason := make(chan string, 1) go func() { if r, ok := <-ww.CloseChan; ok { closeReason <- r } else { closeReason <- "" } }() ww.Close("test-done") select { case r := <-closeReason: if r != "test-done" { t.Errorf("expected close reason 'test-done', got %q", r) } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for close reason") } if ww.Running() { t.Errorf("expected Running() to be false after Close") } } func TestIntegration_InvalidJSON_SendsErrorOnErrorChan(t *testing.T) { srv, wrapperCh := startTestServer(t, nil) c := dialWS(t, srv) defer c.Close() ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } if err := c.WriteMessage(websocket.TextMessage, []byte("garbage")); err != nil { t.Fatalf("client write: %v", err) } select { case err := <-ww.ErrorChan: if err == nil { t.Errorf("expected non-nil error on ErrorChan") } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for error on ErrorChan") } // Drain CloseChan and close cleanly. 
go func() { <-ww.CloseChan }() ww.Close() } func TestIntegration_InvalidJSON_FallbackDecoder(t *testing.T) { wrapperCh := make(chan *WebSocketWrapper[testMessage], 1) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ww := NewWebSocketWrapper[testMessage](w, r, nil) ww.SetFallbackDecoder(func(b []byte) (testMessage, error) { return testMessage{Kind: "raw:" + string(b), Value: 1}, nil }) if err := ww.Start(); err != nil { t.Errorf("Start failed: %v", err) wrapperCh <- nil return } wrapperCh <- ww })) defer srv.Close() c := dialWS(t, srv) defer c.Close() ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } if err := c.WriteMessage(websocket.TextMessage, []byte("not-json-but-ok")); err != nil { t.Fatalf("client write: %v", err) } select { case got := <-ww.MessageChan: if got.Kind != "raw:not-json-but-ok" { t.Errorf("unexpected fallback-decoded message: %+v", got) } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for fallback-decoded message") } go func() { <-ww.CloseChan }() ww.Close() } func TestIntegration_BinaryMessage_SendsErrorOnErrorChan(t *testing.T) { srv, wrapperCh := startTestServer(t, nil) c := dialWS(t, srv) defer c.Close() ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } if err := c.WriteMessage(websocket.BinaryMessage, []byte{0x01, 0x02, 0x03}); err != nil { t.Fatalf("client write: %v", err) } select { case err := <-ww.ErrorChan: if err == nil { t.Errorf("expected non-nil error for binary message") } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for binary-message error") } go func() { <-ww.CloseChan }() ww.Close() } func TestIntegration_Send_DeliversToClient(t *testing.T) { srv, wrapperCh := startTestServer(t, nil) c := dialWS(t, srv) defer c.Close() ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } go func() { ww.Send(testMessage{Kind: "outbound", Value: 123}) }() mt, raw, err := c.ReadMessage() if err != nil { 
t.Fatalf("client read: %v", err) } if mt != websocket.TextMessage { t.Errorf("expected text message, got type %d", mt) } var got testMessage if err := json.Unmarshal(raw, &got); err != nil { t.Fatalf("unmarshal: %v", err) } if got.Kind != "outbound" || got.Value != 123 { t.Errorf("unexpected payload: %+v", got) } go func() { <-ww.CloseChan }() ww.Close() } func TestIntegration_Send_NotRunning_NoPanic(t *testing.T) { rec, req := newDummyReqRes() w := NewWebSocketWrapper[testMessage](rec, req, nil) // Should not panic; should just print an internal error and return. defer func() { if r := recover(); r != nil { t.Errorf("Send panicked when not running: %v", r) } }() w.Send(testMessage{Kind: "x", Value: 1}) } func TestIntegration_Close_BeforeStart_NoPanic(t *testing.T) { rec, req := newDummyReqRes() w := NewWebSocketWrapper[testMessage](rec, req, nil) defer func() { if r := recover(); r != nil { t.Errorf("Close panicked when not running: %v", r) } }() w.Close("never started") } func TestIntegration_Close_IsIdempotent(t *testing.T) { srv, wrapperCh := startTestServer(t, nil) c := dialWS(t, srv) defer c.Close() ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } // First close: drain CloseChan in goroutine, since send is synchronous. done := make(chan struct{}) go func() { <-ww.CloseChan close(done) }() ww.Close("first") select { case <-done: case <-time.After(2 * time.Second): t.Fatal("timed out waiting for first close") } // Second close: must be a no-op (channels are already closed). defer func() { if r := recover(); r != nil { t.Errorf("second Close() panicked: %v", r) } }() ww.Close("second") if ww.Running() { t.Errorf("expected Running() to remain false after repeated Close") } } func TestIntegration_ClientClose_TriggersCloseChan(t *testing.T) { srv, wrapperCh := startTestServer(t, nil) c := dialWS(t, srv) ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } // Reader for CloseChan. 
gotClose := make(chan string, 1) go func() { if r, ok := <-ww.CloseChan; ok { gotClose <- r } else { gotClose <- "" } }() // Client closes the connection abruptly. _ = c.Close() select { case r := <-gotClose: if !strings.Contains(strings.ToLower(r), "fail") && !strings.Contains(strings.ToLower(r), "close") { t.Errorf("expected close reason mentioning failure/close, got %q", r) } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for server-side close on client disconnect") } if ww.Running() { t.Errorf("expected Running() to be false after client close propagated") } } func TestIntegration_OriginCheck_RejectsDisallowedOrigin(t *testing.T) { allowed := []string{"allowed.example"} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ww := NewWebSocketWrapper[testMessage](w, r, &allowed) _ = ww.Start() // expected to fail with 403, returns error - we ignore it. })) defer srv.Close() wsURL := httpToWS(srv.URL) hdr := http.Header{} hdr.Set("Origin", "http://disallowed.example") c, resp, err := websocket.DefaultDialer.Dial(wsURL, hdr) if err == nil { _ = c.Close() t.Fatalf("expected dial to fail due to origin rejection") } if resp == nil { t.Fatalf("expected an HTTP response on origin rejection") } if resp.StatusCode != http.StatusForbidden { t.Errorf("expected status 403, got %d", resp.StatusCode) } } func TestIntegration_OriginCheck_AcceptsAllowedOrigin(t *testing.T) { allowed := []string{"allowed.example"} wrapperCh := make(chan *WebSocketWrapper[testMessage], 1) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ww := NewWebSocketWrapper[testMessage](w, r, &allowed) if err := ww.Start(); err != nil { wrapperCh <- nil return } wrapperCh <- ww })) defer srv.Close() wsURL := httpToWS(srv.URL) hdr := http.Header{} hdr.Set("Origin", "http://allowed.example") c, _, err := websocket.DefaultDialer.Dial(wsURL, hdr) if err != nil { t.Fatalf("dial with allowed origin failed: %v", err) } defer 
c.Close() ww := <-wrapperCh if ww == nil { t.Fatal("wrapper not initialized") } if !ww.Running() { t.Errorf("expected Running() true after successful upgrade") } go func() { <-ww.CloseChan }() ww.Close() }