This commit is contained in:
@@ -0,0 +1,652 @@
|
||||
package wsw
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type testMessage struct {
|
||||
Kind string `json:"kind"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
func newDummyReqRes() (*httptest.ResponseRecorder, *http.Request) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
|
||||
return rec, req
|
||||
}
|
||||
|
||||
func TestNewWebSocketWrapper_NilOrigins_NoCheckOrigin(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
if w == nil {
|
||||
t.Fatal("expected non-nil wrapper")
|
||||
}
|
||||
if w.upgrader.CheckOrigin != nil {
|
||||
t.Errorf("expected CheckOrigin to be nil when allowedOrigins is nil")
|
||||
}
|
||||
if w.running {
|
||||
t.Errorf("expected running to be false initially")
|
||||
}
|
||||
if w.writer != rec {
|
||||
t.Errorf("expected writer to be the same passed in")
|
||||
}
|
||||
if w.request != req {
|
||||
t.Errorf("expected request to be the same passed in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSocketWrapper_WithOrigins_SetsCheckOrigin(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"example.com"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
if w == nil {
|
||||
t.Fatal("expected non-nil wrapper")
|
||||
}
|
||||
if w.upgrader.CheckOrigin == nil {
|
||||
t.Errorf("expected CheckOrigin to be set when allowedOrigins is non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
func makeReqWithOrigin(origin string) *http.Request {
|
||||
r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
|
||||
if origin != "" {
|
||||
r.Header.Set("Origin", origin)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCheckOrigin_NoOriginHeader_AllowsRequest(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"example.com"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
r := makeReqWithOrigin("")
|
||||
if !w.upgrader.CheckOrigin(r) {
|
||||
t.Errorf("expected CheckOrigin to return true when Origin header is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOrigin_InvalidURL_Rejects(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"example.com"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
// Construct a request whose Origin cannot be parsed by url.Parse.
|
||||
r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
|
||||
r.Header["Origin"] = []string{"http://[::1:bad"}
|
||||
if w.upgrader.CheckOrigin(r) {
|
||||
t.Errorf("expected CheckOrigin to return false for invalid Origin URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOrigin_ExactMatch_Allowed(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"example.com"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
r := makeReqWithOrigin("https://example.com")
|
||||
if !w.upgrader.CheckOrigin(r) {
|
||||
t.Errorf("expected CheckOrigin to allow exact host match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOrigin_NoMatch_Rejected(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"example.com"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
r := makeReqWithOrigin("https://other.com")
|
||||
if w.upgrader.CheckOrigin(r) {
|
||||
t.Errorf("expected CheckOrigin to reject non-matching host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOrigin_WildcardMatch_Allowed(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"*.example.com"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
r := makeReqWithOrigin("https://api.example.com")
|
||||
if !w.upgrader.CheckOrigin(r) {
|
||||
t.Errorf("expected CheckOrigin to allow wildcard host match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOrigin_WildcardMatch_DoesNotMatchTopLevel(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"*.example.com"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
// path.Match's "*" does not cross dots, so "example.com" must not match "*.example.com".
|
||||
r := makeReqWithOrigin("https://example.com")
|
||||
if w.upgrader.CheckOrigin(r) {
|
||||
t.Errorf("expected CheckOrigin to reject bare host against wildcard subdomain pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOrigin_EmptyAllowedList_Rejects(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
r := makeReqWithOrigin("https://example.com")
|
||||
if w.upgrader.CheckOrigin(r) {
|
||||
t.Errorf("expected CheckOrigin to reject when allowed list is empty and Origin is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOrigin_PortIsPartOfHost(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
allowed := []string{"example.com:8080"}
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, &allowed)
|
||||
|
||||
rOK := makeReqWithOrigin("http://example.com:8080")
|
||||
if !w.upgrader.CheckOrigin(rOK) {
|
||||
t.Errorf("expected host:port to match allowed host:port")
|
||||
}
|
||||
|
||||
// Without port -> different host -> reject.
|
||||
rNo := makeReqWithOrigin("http://example.com")
|
||||
if w.upgrader.CheckOrigin(rNo) {
|
||||
t.Errorf("expected host without port to not match allowed host:port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunning_BeforeStart_False(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
if w.Running() {
|
||||
t.Errorf("expected Running() to be false before Start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_ValidJSON(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
msg, err := w.decode([]byte(`{"kind":"hello","value":7}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if msg.Kind != "hello" || msg.Value != 7 {
|
||||
t.Errorf("unexpected decoded message: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_InvalidJSON_NoFallback_ReturnsError(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
_, err := w.decode([]byte("not json"))
|
||||
if err == nil {
|
||||
t.Fatalf("expected error when decoding invalid JSON without fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_InvalidJSON_WithFallback_UsesFallback(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
called := false
|
||||
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
|
||||
called = true
|
||||
return testMessage{Kind: "fallback:" + string(b), Value: 42}, nil
|
||||
})
|
||||
|
||||
msg, err := w.decode([]byte("not json"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Errorf("expected fallback decoder to be called")
|
||||
}
|
||||
if msg.Kind != "fallback:not json" || msg.Value != 42 {
|
||||
t.Errorf("unexpected fallback result: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_InvalidJSON_FallbackError_Propagates(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
sentinel := errors.New("nope")
|
||||
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
|
||||
return testMessage{}, sentinel
|
||||
})
|
||||
|
||||
_, err := w.decode([]byte("not json"))
|
||||
if !errors.Is(err, sentinel) {
|
||||
t.Errorf("expected sentinel error from fallback decoder, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_ValidJSON_DoesNotCallFallback(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
called := false
|
||||
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
|
||||
called = true
|
||||
return testMessage{}, nil
|
||||
})
|
||||
|
||||
_, err := w.decode([]byte(`{"kind":"x","value":1}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if called {
|
||||
t.Errorf("did not expect fallback decoder to be called when JSON is valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetFallbackDecoder_StoresFunction(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
if w.fallbackDecoder != nil {
|
||||
t.Fatalf("expected fallbackDecoder to be nil initially")
|
||||
}
|
||||
w.SetFallbackDecoder(func(b []byte) (testMessage, error) {
|
||||
return testMessage{}, nil
|
||||
})
|
||||
if w.fallbackDecoder == nil {
|
||||
t.Errorf("expected fallbackDecoder to be set after SetFallbackDecoder")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Integration tests (using httptest + gorilla/websocket loopback).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// startTestServer wires up a httptest server that, on a websocket upgrade
|
||||
// request, builds a WebSocketWrapper, starts it, and returns it via the
|
||||
// supplied channel for the test to interact with. The returned cleanup
|
||||
// function should be deferred by the caller.
|
||||
func startTestServer(t *testing.T, allowedOrigins *[]string) (*httptest.Server, <-chan *WebSocketWrapper[testMessage]) {
|
||||
t.Helper()
|
||||
wrapperCh := make(chan *WebSocketWrapper[testMessage], 1)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := NewWebSocketWrapper[testMessage](w, r, allowedOrigins)
|
||||
if err := ww.Start(); err != nil {
|
||||
t.Errorf("Start failed: %v", err)
|
||||
wrapperCh <- nil
|
||||
return
|
||||
}
|
||||
wrapperCh <- ww
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv, wrapperCh
|
||||
}
|
||||
|
||||
// httpToWS rewrites an http(s) URL into ws(s).
|
||||
func httpToWS(u string) string {
|
||||
if rest, ok := strings.CutPrefix(u, "https://"); ok {
|
||||
return "wss://" + rest
|
||||
}
|
||||
rest, _ := strings.CutPrefix(u, "http://")
|
||||
return "ws://" + rest
|
||||
}
|
||||
|
||||
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
|
||||
t.Helper()
|
||||
wsURL := httpToWS(server.URL)
|
||||
c, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func TestIntegration_StartReceivesValidMessage(t *testing.T) {
|
||||
srv, wrapperCh := startTestServer(t, nil)
|
||||
c := dialWS(t, srv)
|
||||
defer c.Close()
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
|
||||
if !ww.Running() {
|
||||
t.Errorf("expected Running() to be true after Start")
|
||||
}
|
||||
|
||||
payload := testMessage{Kind: "ping", Value: 9}
|
||||
b, _ := json.Marshal(payload)
|
||||
if err := c.WriteMessage(websocket.TextMessage, b); err != nil {
|
||||
t.Fatalf("client write: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-ww.MessageChan:
|
||||
if got.Kind != "ping" || got.Value != 9 {
|
||||
t.Errorf("unexpected message: %+v", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for message")
|
||||
}
|
||||
|
||||
// Drain CloseChan in a goroutine because Close() sends to it synchronously.
|
||||
closeReason := make(chan string, 1)
|
||||
go func() {
|
||||
if r, ok := <-ww.CloseChan; ok {
|
||||
closeReason <- r
|
||||
} else {
|
||||
closeReason <- ""
|
||||
}
|
||||
}()
|
||||
|
||||
ww.Close("test-done")
|
||||
|
||||
select {
|
||||
case r := <-closeReason:
|
||||
if r != "test-done" {
|
||||
t.Errorf("expected close reason 'test-done', got %q", r)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for close reason")
|
||||
}
|
||||
|
||||
if ww.Running() {
|
||||
t.Errorf("expected Running() to be false after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_InvalidJSON_SendsErrorOnErrorChan(t *testing.T) {
|
||||
srv, wrapperCh := startTestServer(t, nil)
|
||||
c := dialWS(t, srv)
|
||||
defer c.Close()
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
|
||||
if err := c.WriteMessage(websocket.TextMessage, []byte("garbage")); err != nil {
|
||||
t.Fatalf("client write: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-ww.ErrorChan:
|
||||
if err == nil {
|
||||
t.Errorf("expected non-nil error on ErrorChan")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for error on ErrorChan")
|
||||
}
|
||||
|
||||
// Drain CloseChan and close cleanly.
|
||||
go func() { <-ww.CloseChan }()
|
||||
ww.Close()
|
||||
}
|
||||
|
||||
func TestIntegration_InvalidJSON_FallbackDecoder(t *testing.T) {
|
||||
wrapperCh := make(chan *WebSocketWrapper[testMessage], 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := NewWebSocketWrapper[testMessage](w, r, nil)
|
||||
ww.SetFallbackDecoder(func(b []byte) (testMessage, error) {
|
||||
return testMessage{Kind: "raw:" + string(b), Value: 1}, nil
|
||||
})
|
||||
if err := ww.Start(); err != nil {
|
||||
t.Errorf("Start failed: %v", err)
|
||||
wrapperCh <- nil
|
||||
return
|
||||
}
|
||||
wrapperCh <- ww
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := dialWS(t, srv)
|
||||
defer c.Close()
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
|
||||
if err := c.WriteMessage(websocket.TextMessage, []byte("not-json-but-ok")); err != nil {
|
||||
t.Fatalf("client write: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-ww.MessageChan:
|
||||
if got.Kind != "raw:not-json-but-ok" {
|
||||
t.Errorf("unexpected fallback-decoded message: %+v", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for fallback-decoded message")
|
||||
}
|
||||
|
||||
go func() { <-ww.CloseChan }()
|
||||
ww.Close()
|
||||
}
|
||||
|
||||
func TestIntegration_BinaryMessage_SendsErrorOnErrorChan(t *testing.T) {
|
||||
srv, wrapperCh := startTestServer(t, nil)
|
||||
c := dialWS(t, srv)
|
||||
defer c.Close()
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
|
||||
if err := c.WriteMessage(websocket.BinaryMessage, []byte{0x01, 0x02, 0x03}); err != nil {
|
||||
t.Fatalf("client write: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-ww.ErrorChan:
|
||||
if err == nil {
|
||||
t.Errorf("expected non-nil error for binary message")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for binary-message error")
|
||||
}
|
||||
|
||||
go func() { <-ww.CloseChan }()
|
||||
ww.Close()
|
||||
}
|
||||
|
||||
func TestIntegration_Send_DeliversToClient(t *testing.T) {
|
||||
srv, wrapperCh := startTestServer(t, nil)
|
||||
c := dialWS(t, srv)
|
||||
defer c.Close()
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
|
||||
go func() {
|
||||
ww.Send(testMessage{Kind: "outbound", Value: 123})
|
||||
}()
|
||||
|
||||
mt, raw, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("client read: %v", err)
|
||||
}
|
||||
if mt != websocket.TextMessage {
|
||||
t.Errorf("expected text message, got type %d", mt)
|
||||
}
|
||||
|
||||
var got testMessage
|
||||
if err := json.Unmarshal(raw, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.Kind != "outbound" || got.Value != 123 {
|
||||
t.Errorf("unexpected payload: %+v", got)
|
||||
}
|
||||
|
||||
go func() { <-ww.CloseChan }()
|
||||
ww.Close()
|
||||
}
|
||||
|
||||
func TestIntegration_Send_NotRunning_NoPanic(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
// Should not panic; should just print an internal error and return.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Send panicked when not running: %v", r)
|
||||
}
|
||||
}()
|
||||
w.Send(testMessage{Kind: "x", Value: 1})
|
||||
}
|
||||
|
||||
func TestIntegration_Close_BeforeStart_NoPanic(t *testing.T) {
|
||||
rec, req := newDummyReqRes()
|
||||
w := NewWebSocketWrapper[testMessage](rec, req, nil)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Close panicked when not running: %v", r)
|
||||
}
|
||||
}()
|
||||
w.Close("never started")
|
||||
}
|
||||
|
||||
func TestIntegration_Close_IsIdempotent(t *testing.T) {
|
||||
srv, wrapperCh := startTestServer(t, nil)
|
||||
c := dialWS(t, srv)
|
||||
defer c.Close()
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
|
||||
// First close: drain CloseChan in goroutine, since send is synchronous.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
<-ww.CloseChan
|
||||
close(done)
|
||||
}()
|
||||
ww.Close("first")
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first close")
|
||||
}
|
||||
|
||||
// Second close: must be a no-op (channels are already closed).
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("second Close() panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
ww.Close("second")
|
||||
|
||||
if ww.Running() {
|
||||
t.Errorf("expected Running() to remain false after repeated Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_ClientClose_TriggersCloseChan(t *testing.T) {
|
||||
srv, wrapperCh := startTestServer(t, nil)
|
||||
c := dialWS(t, srv)
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
|
||||
// Reader for CloseChan.
|
||||
gotClose := make(chan string, 1)
|
||||
go func() {
|
||||
if r, ok := <-ww.CloseChan; ok {
|
||||
gotClose <- r
|
||||
} else {
|
||||
gotClose <- ""
|
||||
}
|
||||
}()
|
||||
|
||||
// Client closes the connection abruptly.
|
||||
_ = c.Close()
|
||||
|
||||
select {
|
||||
case r := <-gotClose:
|
||||
if !strings.Contains(strings.ToLower(r), "fail") && !strings.Contains(strings.ToLower(r), "close") {
|
||||
t.Errorf("expected close reason mentioning failure/close, got %q", r)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for server-side close on client disconnect")
|
||||
}
|
||||
|
||||
if ww.Running() {
|
||||
t.Errorf("expected Running() to be false after client close propagated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_OriginCheck_RejectsDisallowedOrigin(t *testing.T) {
|
||||
allowed := []string{"allowed.example"}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := NewWebSocketWrapper[testMessage](w, r, &allowed)
|
||||
_ = ww.Start() // expected to fail with 403, returns error - we ignore it.
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := httpToWS(srv.URL)
|
||||
hdr := http.Header{}
|
||||
hdr.Set("Origin", "http://disallowed.example")
|
||||
c, resp, err := websocket.DefaultDialer.Dial(wsURL, hdr)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
t.Fatalf("expected dial to fail due to origin rejection")
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("expected an HTTP response on origin rejection")
|
||||
}
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status 403, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_OriginCheck_AcceptsAllowedOrigin(t *testing.T) {
|
||||
allowed := []string{"allowed.example"}
|
||||
wrapperCh := make(chan *WebSocketWrapper[testMessage], 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := NewWebSocketWrapper[testMessage](w, r, &allowed)
|
||||
if err := ww.Start(); err != nil {
|
||||
wrapperCh <- nil
|
||||
return
|
||||
}
|
||||
wrapperCh <- ww
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := httpToWS(srv.URL)
|
||||
hdr := http.Header{}
|
||||
hdr.Set("Origin", "http://allowed.example")
|
||||
c, _, err := websocket.DefaultDialer.Dial(wsURL, hdr)
|
||||
if err != nil {
|
||||
t.Fatalf("dial with allowed origin failed: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
ww := <-wrapperCh
|
||||
if ww == nil {
|
||||
t.Fatal("wrapper not initialized")
|
||||
}
|
||||
if !ww.Running() {
|
||||
t.Errorf("expected Running() true after successful upgrade")
|
||||
}
|
||||
|
||||
go func() { <-ww.CloseChan }()
|
||||
ww.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user