Skip to content

Commit 5b33b58

Browse files
cfsmp3claude
andauthored
security: add authentication and origin validation to WebSocket (#414)
* security: add authentication and origin validation to WebSocket Add session authentication and origin header validation to the WebSocket endpoint to prevent unauthorized access. - Add checkWebSocketOrigin() for origin header validation - Add AuthenticatedWebSocketHandler() requiring valid session - Update main.go to use authenticated handler - Support ALLOWED_ORIGIN and FRONTEND_ORIGIN_DEV env vars - Allow localhost in development mode - Log rejected connections for security monitoring Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Fix origin comparison security issue and improve ENV handling Addresses review feedback: 1. Security fix: Replace insecure substring check with proper hostname comparison. Previously `strings.Contains(origin, host)` could be bypassed by an attacker using "malicious-example.com" to match "example.com". Now parses the origin URL and compares hostnames exactly. 2. Add getEnv() helper that returns "development" by default, making the environment check clearer and more maintainable. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent dbeffcf commit 5b33b58

2 files changed

Lines changed: 51 additions & 2 deletions

File tree

backend/controllers/websocket.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,19 @@ import (
88

99
"ccsync_backend/utils"
1010

11+
"github.com/gorilla/sessions"
1112
"github.com/gorilla/websocket"
1213
)
1314

15+
// getEnv returns the environment mode, defaulting to "development"
16+
func getEnv() string {
17+
env := os.Getenv("ENV")
18+
if env == "" {
19+
return "development"
20+
}
21+
return env
22+
}
23+
1424
type JobStatus struct {
1525
Job string `json:"job"`
1626
Status string `json:"status"`
@@ -21,7 +31,7 @@ func checkWebSocketOrigin(r *http.Request) bool {
2131
origin := r.Header.Get("Origin")
2232

2333
// In development mode, be more permissive
24-
if os.Getenv("ENV") != "production" {
34+
if getEnv() != "production" {
2535
if origin == "" ||
2636
strings.HasPrefix(origin, "http://localhost") ||
2737
strings.HasPrefix(origin, "http://127.0.0.1") {
@@ -73,6 +83,45 @@ var upgrader = websocket.Upgrader{
7383
var clients = make(map[*websocket.Conn]bool)
7484
var broadcast = make(chan JobStatus)
7585

86+
// AuthenticatedWebSocketHandler creates a WebSocket handler that requires authentication
87+
func AuthenticatedWebSocketHandler(store *sessions.CookieStore) http.HandlerFunc {
88+
return func(w http.ResponseWriter, r *http.Request) {
89+
// Validate session before upgrading to WebSocket
90+
session, err := store.Get(r, "session-name")
91+
if err != nil {
92+
utils.Logger.Warnf("WebSocket auth failed: could not get session: %v", err)
93+
http.Error(w, "Authentication required", http.StatusUnauthorized)
94+
return
95+
}
96+
97+
userInfo, ok := session.Values["user"].(map[string]interface{})
98+
if !ok || userInfo == nil {
99+
utils.Logger.Warnf("WebSocket auth failed: no user in session")
100+
http.Error(w, "Authentication required", http.StatusUnauthorized)
101+
return
102+
}
103+
104+
// User is authenticated, proceed with WebSocket upgrade
105+
ws, err := upgrader.Upgrade(w, r, nil)
106+
if err != nil {
107+
utils.Logger.Error("WebSocket Upgrade Error:", err)
108+
return
109+
}
110+
defer ws.Close()
111+
112+
clients[ws] = true
113+
for {
114+
_, _, err := ws.ReadMessage()
115+
if err != nil {
116+
delete(clients, ws)
117+
break
118+
}
119+
}
120+
}
121+
}
122+
123+
// WebSocketHandler is kept for backward compatibility but should not be used
124+
// Use AuthenticatedWebSocketHandler instead
76125
func WebSocketHandler(w http.ResponseWriter, r *http.Request) {
77126
ws, err := upgrader.Upgrade(w, r, nil)
78127
if err != nil {

backend/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func main() {
127127

128128
mux.HandleFunc("/health", controllers.HealthCheckHandler)
129129

130-
mux.HandleFunc("/ws", controllers.WebSocketHandler)
130+
mux.HandleFunc("/ws", controllers.AuthenticatedWebSocketHandler(store))
131131

132132
// API documentation endpoint
133133
mux.HandleFunc("/api/docs/", httpSwagger.WrapHandler)

0 commit comments

Comments
 (0)