@@ -236,6 +236,13 @@ func isHopByHopHeader(name string) bool {
236236}
237237
238238func websocketShimResponseHandlerOpen (resp * http.Response , ws * wsSessionHelper ) error {
239+ if resp .StatusCode != http .StatusOK {
240+ respBody , err := io .ReadAll (resp .Body )
241+ if err != nil {
242+ return fmt .Errorf ("%v/open: http status code %v, error reading response body" , shimPath , resp .StatusCode )
243+ }
244+ return fmt .Errorf ("%v/open: http status code %v, response: %v" , shimPath , resp .StatusCode , string (respBody ))
245+ }
239246 p , err := io .ReadAll (resp .Body )
240247 if err != nil {
241248 return fmt .Errorf ("%v/open: failed to read response from agent: %v" , shimPath , err )
@@ -247,17 +254,6 @@ func websocketShimResponseHandlerOpen(resp *http.Response, ws *wsSessionHelper)
247254 return nil
248255}
249256
250- func websocketShimResponseHandlerData (resp * http.Response , ws * wsSessionHelper ) error {
251- if resp .StatusCode != http .StatusOK {
252- respBody , err := io .ReadAll (resp .Body )
253- if err != nil {
254- return fmt .Errorf ("%v/data: http status code is %v, error reading response body" , shimPath , resp .StatusCode )
255- }
256- return fmt .Errorf ("%v/data: http status code %v, response: %v" , shimPath , resp .StatusCode , string (respBody ))
257- }
258- return nil
259- }
260-
261257func websocketShimResponseHandlerPoll (resp * http.Response , ws * wsSessionHelper ) error {
262258 if resp .StatusCode != http .StatusOK {
263259 respBody , err := io .ReadAll (resp .Body )
@@ -274,13 +270,6 @@ func websocketShimResponseHandlerPoll(resp *http.Response, ws *wsSessionHelper)
274270 return nil
275271}
276272
277- func websocketShimResponseHandlerClose (resp * http.Response , ws * wsSessionHelper ) error {
278- if resp .StatusCode != http .StatusOK {
279- return fmt .Errorf ("%v/close: http status code %v" , shimPath , resp .StatusCode )
280- }
281- return nil
282- }
283-
284273// if the request is a normal HTTP request and not related to the websocket shim, nil value
285274// can be passed into the wsSessionHelper parameter
286275func (p * proxy ) handleFrontendRequest (w http.ResponseWriter , r * http.Request , ws * wsSessionHelper ) error {
@@ -318,15 +307,22 @@ func (p *proxy) handleFrontendRequest(w http.ResponseWriter, r *http.Request, ws
318307 return fmt .Errorf ("timeout waiting for the response to %q" , id )
319308 case resp := <- pending .respChan :
320309 // websocket shim endpoint handling
321- switch resp .Request .URL .Path {
322- case shimPath + "/open" :
323- return websocketShimResponseHandlerOpen (resp , ws )
324- case shimPath + "/data" :
325- return websocketShimResponseHandlerData (resp , ws )
326- case shimPath + "/poll" :
327- return websocketShimResponseHandlerPoll (resp , ws )
328- case shimPath + "/close" :
329- return websocketShimResponseHandlerClose (resp , ws )
310+ if ws != nil {
311+ // websocket shim endpoint handling
312+ if resp .StatusCode != http .StatusOK {
313+ respBody , err := io .ReadAll (resp .Body )
314+ if err != nil {
315+ return fmt .Errorf ("%v: http status code is %v, error reading response body" , r .URL .Path , resp .StatusCode )
316+ }
317+ return fmt .Errorf ("%v: http status code %v, response: %v" , r .URL .Path , resp .StatusCode , string (respBody ))
318+ }
319+ switch r .URL .Path {
320+ case shimPath + "/open" :
321+ return websocketShimResponseHandlerOpen (resp , ws )
322+ case shimPath + "/poll" :
323+ return websocketShimResponseHandlerPoll (resp , ws )
324+ }
325+ return nil
330326 }
331327
332328 // Copy all of the non-hop-by-hop headers to the proxied response
@@ -411,7 +407,7 @@ func (p *proxy) handleWebsocketRequest(w http.ResponseWriter, r *http.Request) e
411407 }
412408 req , err := http .NewRequest (http .MethodPost , fmt .Sprintf ("http://:%v%v/close" , port , shimPath ), bytes .NewBuffer (buf ))
413409 if err != nil {
414- log .Printf ("Failed to create a new open request: %v" , err )
410+ log .Printf ("Failed to create a new close request: %v" , err )
415411 return
416412 }
417413 req .Header .Set ("X-Websocket-Shim-Version" , "1" )
@@ -571,7 +567,7 @@ func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
571567
572568func main () {
573569 flag .IntVar (& port , "port" , 0 , "Port on which to listen" )
574- flag .IntVar (& setReadLimit , "ws-read-limit" , 0 , "websocket read limit from client in bytes" )
570+ flag .IntVar (& setReadLimit , "ws-read-limit" , - 1 , "websocket read limit from client in bytes" )
575571 flag .IntVar (& bufSize , "ws-buffer-size" , 1024 * 4 , "websocket buffer size for writes" )
576572 flag .StringVar (& shimPath , "shim-path" , "" , "Path under which to handle websocket shim requests" )
577573
0 commit comments