diff --git a/vm/proxy/proxy.go b/vm/proxy/proxy.go index 8f11b2a..c6ebe5b 100644 --- a/vm/proxy/proxy.go +++ b/vm/proxy/proxy.go @@ -211,8 +211,10 @@ func (p *Proxy) ServeHTTP(wr http.ResponseWriter, req *http.Request) { } default: - //http: Request.RequestURI can't be set in client requests. - //http://golang.org/src/pkg/net/http/client.go + if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + p.proxyWebsocket(wr, req) + return + } req.RequestURI = "" req.URL.Scheme = "http" @@ -266,6 +268,69 @@ func (p *Proxy) ServeHTTP(wr http.ResponseWriter, req *http.Request) { } } +func (p *Proxy) proxyWebsocket(wr http.ResponseWriter, req *http.Request) { + upgrader.CheckOrigin = func(r *http.Request) bool { return true } + + clientConn, err := upgrader.Upgrade(wr, req, nil) + if err != nil { + log.Println("websocket upgrade (client) failed:", err) + return + } + + upstreamURL := fmt.Sprintf("ws://%s%s", MesheryServerHost, req.URL.RequestURI()) + + header := http.Header{} + copyHeader(header, req.Header) + + if p.token != "" { + header.Add("Cookie", fmt.Sprintf("token=%s", p.token)) + header.Add("Cookie", "meshery-provider=Layer5") + } + + backendConn, _, err := websocket.DefaultDialer.Dial(upstreamURL, header) + if err != nil { + log.Println("websocket dial (backend) failed:", err) + clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseTryAgainLater, err.Error())) + clientConn.Close() + return + } + + errc := make(chan error, 2) + + go func() { + for { + mt, message, err := clientConn.ReadMessage() + if err != nil { + errc <- err + return + } + if err := backendConn.WriteMessage(mt, message); err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + mt, message, err := backendConn.ReadMessage() + if err != nil { + errc <- err + return + } + if err := clientConn.WriteMessage(mt, message); err != nil { + errc <- err + return + } + } + }() + + <-errc + + clientConn.Close() + backendConn.Close() +} + func main() { var addr = flag.String("addr", "127.0.0.1:8080", "The addr of the application.") flag.Parse()