Commit fe16eae3 authored by Nick Thomas's avatar Nick Thomas

Handle environments/:id/terminal.ws

A GitLab environment may expose a terminal connection for out-of-band access.
Workhorse is responsible for providing a websocket connection to the terminal
if present.

It authenticates the user and retrieves connection details from GitLab using
the environments/:id/terminal.ws/authorize endpoint, and sets up a proxy to
the terminal provider, converting from the remote's subprotocol to a common
format.

Authentication is periodically re-done, and the connection will be broken if
it fails, or if the connection details change in any way.
parent f2d0435b
{
"ImportPath": "gitlab.com/gitlab-org/gitlab-workhorse",
"GoVersion": "go1.7",
"GoVersion": "go1.5",
"GodepVersion": "v74",
"Packages": [
"./..."
],
"Deps": [
{
"ImportPath": "github.com/beorn7/perks/quantile",
......@@ -29,6 +32,11 @@
"ImportPath": "github.com/golang/protobuf/proto",
"Rev": "8ee79997227bf9b34611aee7946ae64735e6fd93"
},
{
"ImportPath": "github.com/gorilla/websocket",
"Comment": "v1.0.0-39-ge8f0f8a",
"Rev": "e8f0f8aaa98dfb6586cbdf2978d511e3199a960a"
},
{
"ImportPath": "github.com/matttproud/golang_protobuf_extensions/pbutil",
"Comment": "v1.0.0-2-gc12348c",
......
......@@ -18,6 +18,8 @@ push/pull and Git archive downloads.
when handling a Git LFS upload Workhorse first asks permission from
Rails, then it stores the request body in a tempfile, then it sends
a modified request containing the tempfile path to Rails.
- Workhorse can manage long-lived WebSocket connections for Rails.
Example: handling the terminal websocket for environments.
- Workhorse does not connect to Redis or Postgres, only to Rails.
- We assume that all requests that reach Workhorse pass through an
upstream proxy such as NGINX or Apache first.
......
# Terminal support
In some cases, GitLab can provide in-browser terminal access to an
environment (which is a running server or container, onto which a
project has been deployed) through a WebSocket. Workhorse manages
the WebSocket upgrade and long-lived connection to the terminal for
the environment, which frees up GitLab to process other requests.
This document outlines the architecture of these connections.
## Introduction to WebSockets
A websocket is an "upgraded" HTTP/1.1 request. Their purpose is to
permit bidirectional communication between a client and a server.
**Websockets are not HTTP**. Clients can send messages (known as
frames) to the server at any time, and vice-versa. Client messages
are not necessarily requests, and server messages are not necessarily
responses. WebSocket URLs have schemes like `ws://` (unencrypted) or
`wss://` (TLS-secured).
When requesting an upgrade to WebSocket, the browser sends a HTTP/1.1
request that looks like this:
```
GET /path.ws HTTP/1.1
Connection: upgrade
Upgrade: websocket
Sec-WebSocket-Protocol: terminal.gitlab.com
# More headers, including security measures
```
At this point, the connection is still HTTP, so this is a request and
the server can send a normal HTTP response, including `404 Not Found`,
`500 Internal Server Error`, etc.
If the server decides to permit the upgrade, it will send a HTTP
`101 Switching Protocols` response. From this point, the connection
is no longer HTTP. It is a WebSocket and frames, not HTTP requests,
will flow over it. The connection will persist until the client or
server closes the connection.
In addition to the subprotocol, individual websocket frames may
also specify a message type - examples include `BinaryMessage`,
`TextMessage`, `Ping`, `Pong` or `Close`. Only binary frames can
contain arbitrary data - other frames are expected to be valid
UTF-8 strings, in addition to any subprotocol expectations.
## Browser to Workhorse
GitLab serves a JavaScript terminal emulator to the browser on
a URL like `https://gitlab.com/group/project/environments/1/terminal`.
This opens a websocket connection to, e.g.,
`wss://gitlab.com/group/project/environments/1/terminal.ws`,
This endpoint doesn't exist in GitLab - only in Workhorse.
When receiving the connection, Workhorse first checks that the
client is authorized to access the requested terminal. It does
this by performing a "preauthentication" request to GitLab.
If the client has the appropriate permissions and the terminal
exists, GitLab responds with a successful response that includes
details of the terminal that the client should be connected to.
Otherwise, it returns an appropriate HTTP error response.
Errors are passed back to the client as HTTP responses, but if
GitLab returns valid terminal details to Workhorse, it will
connect to the specified terminal, upgrade the browser to a
WebSocket, and proxy between the two connections for as long
as the browser's credentials are valid. Workhorse will also
send regular `PingMessage` control frames to the browser, to
keep intervening proxies from terminating the connection
while the browser is present.
The browser must request an upgrade with a specific subprotocol:
### `terminal.gitlab.com`
This subprotocol considers `TextMessage` frames to be invalid.
Control frames, such as `PingMessage` or `CloseMessage`, have
their usual meanings.
`BinaryMessage` frames sent from the browser to the server are
arbitrary terminal input.
`BinaryMessage` frames sent from the server to the browser are
arbitrary terminal output.
These frames are expected to contain ANSI terminal control codes
and may be in any encoding.
### `base64.terminal.gitlab.com`
This subprotocol considers `BinaryMessage` frames to be invalid.
Control frames, such as `PingMessage` or `CloseMessage`, have
their usual meanings.
`TextMessage` frames sent from the browser to the server are
base64-encoded arbitrary terminal input (so the server must
base64-decode them before inputting them).
`TextMessage` frames sent from the server to the browser are
base64-encoded arbitrary terminal output (so the browser must
base64-decode them before outputting them).
In their base64-encoded form, these frames are expected to
contain ANSI terminal control codes, and may be in any encoding.
## Workhorse to GitLab
Before upgrading the browser, Workhorse sends a normal HTTP
request to GitLab on a URL like
`https://gitlab.com/group/project/environments/1/terminal.ws/authorize`.
This returns a JSON response containing details of where the
terminal can be found, and how to connect it. In particular,
the following details are returned in case of success:
* WebSocket URL to **connect** to, e.g.: `wss://example.com/terminals/1.ws?tty=1`
* WebSocket subprotocols to support, e.g.: `["channel.k8s.io"]`
* Headers to send, e.g.: `Authorization: Token xxyyz..`
* Certificate authority to verify `wss` connections with (optional)
Workhorse periodically re-checks this endpoint, and if it gets an
error response, or the details of the terminal change, it will
terminate the websocket session.
## Workhorse to Terminal
In GitLab, environments may have a deployment service (e.g.,
`KubernetesService`) associated with them. This service knows
where the terminals for an environment may be found, and these
details are returned to Workhorse by GitLab.
These URLs are *also* WebSocket URLs, and GitLab tells Workhorse
which subprotocols to speak over the connection, along with any
authentication details required by the remote end.
Before upgrading the browser's connection to a websocket,
Workhorse opens a HTTP client connection, according to the
details given to it by Workhorse, and attempts to upgrade
that connection to a websocket. If it fails, an error
response is sent to the browser; otherwise, the browser is
also upgraded.
Workhorse now has two websocket connections, albeit with
differing subprotocols. It decodes incoming frames from the
browser, re-encodes them to the terminal's subprotocol, and
sends them to the terminal. Similarly, it decodes incoming
frames from the terminal, re-encodes them to the browser's
subprotocol, and sends them to the browser.
When either connection closes or enters an error state,
Workhorse detects the error and closes the other connection,
terminating the terminal session. If the browser is the
connection that has disconnected, Workhorse will send an ANSI
`End of Transmission` control code (the `0x04` byte) to the
terminal, encoded according to the appropriate subprotocol.
Workhorse will automatically reply to any websocket ping frame
sent by the terminal, to avoid being disconnected.
Currently, Workhorse only supports the following subprotocols.
Supporting new deployment services will require new subprotocols
to be supported:
### `channel.k8s.io`
Used by Kubernetes, this subprotocol defines a simple multiplexed
channel.
Control frames have their usual meanings. `TextMessage` frames are
invalid. `BinaryMessage` frames represent I/O to a specific file
descriptor.
The first byte of each `BinaryMessage` frame represents the file
descriptor (fd) number, as a `uint8` (so the value `0x00` corresponds
to fd 0, `STDIN`, while `0x01` corresponds to fd 1, `STDOUT`).
The remaining bytes represent arbitrary data. For frames received
from the server, they are bytes that have been received from that
fd. For frames sent to the server, they are bytes that should be
written to that fd.
### `base64.channel.k8s.io`
Also used by Kubernetes, this subprotocol defines a similar multiplexed
channel to `channel.k8s.io`. The main differences are:
* `TextMessage` frames are valid, rather than `BinaryMessage` frames.
* The first byte of each `TextMessage` frame represents the file
descriptor as a numeric UTF-8 character, so the character `U+0030`,
or "0", is fd 0, STDIN).
* The remaining bytes represent base64-encoded arbitrary data.
......@@ -59,6 +59,8 @@ type Response struct {
Archive string `json:"archive"`
// Entry is a filename inside the archive point to file that needs to be extracted
Entry string `json:"entry"`
// Used to communicate terminal session details
Terminal *TerminalSettings
}
// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy
......@@ -143,23 +145,60 @@ func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*htt
return authReq, nil
}
func (api *API) PreAuthorizeHandler(h HandleFunc, suffix string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authReq, err := api.newRequest(r, nil, suffix)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler newUpstreamRequest: %v", err))
return
// Perform a pre-authorization check against the API for the given HTTP request
//
// If `outErr` is set, the other fields will be nil and it should be treated as
// a 500 error.
//
// If httpResponse is present, the caller is responsible for closing its body
//
// authResponse will only be present if the authorization check was successful
func (api *API) PreAuthorize(suffix string, r *http.Request) (httpResponse *http.Response, authResponse *Response, outErr error) {
authReq, err := api.newRequest(r, nil, suffix)
if err != nil {
return nil, nil, fmt.Errorf("preAuthorizeHandler newUpstreamRequest: %v", err)
}
httpResponse, err = api.Client.Do(authReq)
if err != nil {
return nil, nil, fmt.Errorf("preAuthorizeHandler: do request: %v", err)
}
defer func() {
if outErr != nil {
httpResponse.Body.Close()
httpResponse = nil
}
}()
if httpResponse.StatusCode != http.StatusOK {
return httpResponse, nil, nil
}
if contentType := httpResponse.Header.Get("Content-Type"); contentType != ResponseContentType {
return httpResponse, nil, fmt.Errorf("preAuthorizeHandler: API responded with wrong content type: %v", contentType)
}
authResponse = &Response{}
// The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth
// response body.
if err := json.NewDecoder(httpResponse.Body).Decode(authResponse); err != nil {
return httpResponse, nil, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err)
}
authResponse, err := api.Client.Do(authReq)
return httpResponse, authResponse, nil
}
func (api *API) PreAuthorizeHandler(next HandleFunc, suffix string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpResponse, authResponse, err := api.PreAuthorize(suffix, r)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler: do request: %v", err))
helper.Fail500(w, r, err)
return
}
defer authResponse.Body.Close()
if authResponse.StatusCode != 200 {
for k, v := range authResponse.Header {
if httpResponse.StatusCode != http.StatusOK {
for k, v := range httpResponse.Header {
// Accomodate broken clients that do case-sensitive header lookup
if k == "Www-Authenticate" {
w.Header()["WWW-Authenticate"] = v
......@@ -167,36 +206,25 @@ func (api *API) PreAuthorizeHandler(h HandleFunc, suffix string) http.Handler {
w.Header()[k] = v
}
}
w.WriteHeader(authResponse.StatusCode)
io.Copy(w, authResponse.Body)
return
}
if contentType := authResponse.Header.Get("Content-Type"); contentType != ResponseContentType {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler: API responded with wrong content type: %v", contentType))
return
}
w.WriteHeader(httpResponse.StatusCode)
io.Copy(w, httpResponse.Body)
httpResponse.Body.Close()
a := &Response{}
// The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth
// response body.
if err := json.NewDecoder(authResponse.Body).Decode(a); err != nil {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
return
}
// Don't hog a TCP connection in CLOSE_WAIT, we can already close it now
authResponse.Body.Close()
// Close the body immediately, rather than waiting for the next handler
// to complete
httpResponse.Body.Close()
// Negotiate authentication (Kerberos) may need to return a WWW-Authenticate
// header to the client even in case of success as per RFC4559.
for k, v := range authResponse.Header {
for k, v := range httpResponse.Header {
// Case-insensitive comparison as per RFC7230
if strings.EqualFold(k, "WWW-Authenticate") {
w.Header()[k] = v
}
}
h(w, r, a)
next(w, r, authResponse)
})
}
package api
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"net/url"
"github.com/gorilla/websocket"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
type TerminalSettings struct {
// The terminal provider may require use of a particular subprotocol. If so,
// it must be specified here, and Workhorse must have a matching codec.
Subprotocols []string
// The websocket URL to connect to.
Url string
// Any headers (e.g., Authorization) to send with the websocket request
Header http.Header
// The CA roots to validate the remote endpoint with, for wss:// URLs. The
// system-provided CA pool will be used if this is blank. PEM-encoded data.
CAPem string
}
func (t *TerminalSettings) URL() (*url.URL, error) {
return url.Parse(t.Url)
}
func (t *TerminalSettings) Dialer() *websocket.Dialer {
dialer := &websocket.Dialer{
Subprotocols: t.Subprotocols,
}
if len(t.CAPem) > 0 {
pool := x509.NewCertPool()
pool.AppendCertsFromPEM([]byte(t.CAPem))
dialer.TLSClientConfig = &tls.Config{RootCAs: pool}
}
return dialer
}
func (t *TerminalSettings) Clone() *TerminalSettings {
// Doesn't clone the strings, but that's OK as strings are immutable in go
cloned := *t
cloned.Header = helper.HeaderClone(t.Header)
return &cloned
}
func (t *TerminalSettings) Dial() (*websocket.Conn, *http.Response, error) {
return t.Dialer().Dial(t.Url, t.Header)
}
func (t *TerminalSettings) Validate() error {
if t == nil {
return fmt.Errorf("Terminal details not specified")
}
if len(t.Subprotocols) == 0 {
return fmt.Errorf("No subprotocol specified")
}
parsedURL, err := t.URL()
if err != nil {
return fmt.Errorf("Invalid URL")
}
if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" {
return fmt.Errorf("Invalid websocket scheme: %q", parsedURL.Scheme)
}
return nil
}
func (t *TerminalSettings) IsEqual(other *TerminalSettings) bool {
if t == nil && other == nil {
return true
}
if t == nil || other == nil {
return false
}
if len(t.Subprotocols) != len(other.Subprotocols) {
return false
}
for i, subprotocol := range t.Subprotocols {
if other.Subprotocols[i] != subprotocol {
return false
}
}
if len(t.Header) != len(other.Header) {
return false
}
for header, values := range t.Header {
if len(values) != len(other.Header[header]) {
return false
}
for i, value := range values {
if other.Header[header][i] != value {
return false
}
}
}
return t.Url == other.Url && t.CAPem == other.CAPem
}
package api
import (
"net/http"
"testing"
)
func terminal(url string, subprotocols ...string) *TerminalSettings {
return &TerminalSettings{
Url: url,
Subprotocols: subprotocols,
}
}
func ca(term *TerminalSettings) *TerminalSettings {
term = term.Clone()
term.CAPem = "Valid CA data"
return term
}
func header(term *TerminalSettings, values ...string) *TerminalSettings {
if len(values) == 0 {
values = []string{"Dummy Value"}
}
term = term.Clone()
term.Header = http.Header{
"Header": values,
}
return term
}
func TestClone(t *testing.T) {
a := ca(header(terminal("ws:", "", "")))
b := a.Clone()
if a == b {
t.Fatalf("Address of cloned terminal didn't change")
}
if &a.Subprotocols == &b.Subprotocols {
t.Fatalf("Address of cloned subprotocols didn't change")
}
if &a.Header == &b.Header {
t.Fatalf("Address of cloned header didn't change")
}
}
func TestValidate(t *testing.T) {
for i, tc := range []struct {
terminal *TerminalSettings
valid bool
msg string
}{
{nil, false, "nil terminal"},
{terminal("", ""), false, "empty URL"},
{terminal("ws:"), false, "empty subprotocols"},
{terminal("ws:", "foo"), true, "any subprotocol"},
{terminal("ws:", "foo", "bar"), true, "multiple subprotocols"},
{terminal("ws:", ""), true, "websocket URL"},
{terminal("wss:", ""), true, "secure websocket URL"},
{terminal("http:", ""), false, "HTTP URL"},
{terminal("https:", ""), false, " HTTPS URL"},
{ca(terminal("ws:", "")), true, "any CA pem"},
{header(terminal("ws:", "")), true, "any headers"},
{ca(header(terminal("ws:", ""))), true, "PEM and headers"},
} {
if err := tc.terminal.Validate(); (err != nil) == tc.valid {
t.Fatalf("test case %d: "+tc.msg+": valid=%v: %s: %+v", i, tc.valid, err, tc.terminal)
}
}
}
func TestDialer(t *testing.T) {
terminal := terminal("ws:", "foo")
dialer := terminal.Dialer()
if len(dialer.Subprotocols) != len(terminal.Subprotocols) {
t.Fatalf("Subprotocols don't match: %+v vs. %+v", terminal.Subprotocols, dialer.Subprotocols)
}
for i, subprotocol := range terminal.Subprotocols {
if dialer.Subprotocols[i] != subprotocol {
t.Fatalf("Subprotocols don't match: %+v vs. %+v", terminal.Subprotocols, dialer.Subprotocols)
}
}
if dialer.TLSClientConfig != nil {
t.Fatalf("Unexpected TLSClientConfig: %+v", dialer)
}
terminal = ca(terminal)
dialer = terminal.Dialer()
if dialer.TLSClientConfig == nil || dialer.TLSClientConfig.RootCAs == nil {
t.Fatalf("Custom CA certificates not recognised!")
}
}
func TestIsEqual(t *testing.T) {
term := terminal("ws:", "foo")
term_header2 := header(term, "extra")
term_header3 := header(term)
term_header3.Header.Add("Extra", "extra")
term_ca2 := ca(term)
term_ca2.CAPem = "other value"
for i, tc := range []struct {
termA *TerminalSettings
termB *TerminalSettings
expected bool
}{
{nil, nil, true},
{term, nil, false},
{nil, term, false},
{term, term, true},
{term.Clone(), term.Clone(), true},
{term, terminal("foo:"), false},
{term, terminal(term.Url), false},
{header(term), header(term), true},
{term_header2, term_header2, true},
{term_header3, term_header3, true},
{header(term), term_header2, false},
{header(term), term_header3, false},
{header(term), term, false},
{term, header(term), false},
{ca(term), ca(term), true},
{ca(term), term, false},
{term, ca(term), false},
{ca(header(term)), ca(header(term)), true},
{term_ca2, ca(term), false},
} {
if actual := tc.termA.IsEqual(tc.termB); tc.expected != actual {
t.Fatalf(
"test case %d: Comparison:\n-%+v\n+%+v\nexpected=%v: actual=%v",
i, tc.termA, tc.termB, tc.expected, actual,
)
}
}
}
package helper
import (
"bufio"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strconv"
......@@ -43,26 +45,49 @@ func registerPrometheusMetrics() {
prometheus.MustRegister(requestsTotal)
}
type LoggingResponseWriter struct {
type LoggingResponseWriter interface {
http.ResponseWriter
Log(r *http.Request)
}
type loggingResponseWriter struct {
rw http.ResponseWriter
status int
written int64
started time.Time
}
type hijackingResponseWriter struct {
loggingResponseWriter
}
func NewLoggingResponseWriter(rw http.ResponseWriter) LoggingResponseWriter {
sessionsActive.Inc()
return LoggingResponseWriter{
out := loggingResponseWriter{
rw: rw,
started: time.Now(),
}
if _, ok := rw.(http.Hijacker); ok {
return &hijackingResponseWriter{out}
}
return &out
}
func (l *hijackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// The only way to gethere is through NewLoggingResponseWriter(), which
// checks that this cast will be valid.
hijacker := l.rw.(http.Hijacker)
return hijacker.Hijack()
}
func (l *LoggingResponseWriter) Header() http.Header {
func (l *loggingResponseWriter) Header() http.Header {
return l.rw.Header()
}
func (l *LoggingResponseWriter) Write(data []byte) (n int, err error) {
func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
if l.status == 0 {
l.WriteHeader(http.StatusOK)
}
......@@ -71,7 +96,7 @@ func (l *LoggingResponseWriter) Write(data []byte) (n int, err error) {
return
}
func (l *LoggingResponseWriter) WriteHeader(status int) {
func (l *loggingResponseWriter) WriteHeader(status int) {
if l.status != 0 {
return
}
......@@ -80,7 +105,7 @@ func (l *LoggingResponseWriter) WriteHeader(status int) {
l.rw.WriteHeader(status)
}
func (l *LoggingResponseWriter) Log(r *http.Request) {
func (l *loggingResponseWriter) Log(r *http.Request) {
duration := time.Since(l.started)
responseLogger.Printf("%s %s - - [%s] %q %d %d %q %q %f\n",
r.Host, r.RemoteAddr, l.started,
......
package terminal
import (
"errors"
"net/http"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)
type AuthCheckerFunc func() *api.TerminalSettings
// Regularly checks that authorization is still valid for a terminal, outputting
// to the stopper when it isn't
type AuthChecker struct {
Checker AuthCheckerFunc
Template *api.TerminalSettings
StopCh chan error
Done chan struct{}
Count int64
}
var ErrAuthChanged = errors.New("Connection closed: authentication changed or endpoint unavailable.")
func NewAuthChecker(f AuthCheckerFunc, template *api.TerminalSettings, stopCh chan error) *AuthChecker {
return &AuthChecker{
Checker: f,
Template: template,
StopCh: stopCh,
Done: make(chan struct{}),
}
}
func (c *AuthChecker) Loop(interval time.Duration) {
for {
select {
case <-time.After(interval):
settings := c.Checker()
if !c.Template.IsEqual(settings) {
c.StopCh <- ErrAuthChanged
return
}
c.Count = c.Count + 1
case <-c.Done:
return
}
}
}
func (c *AuthChecker) Close() error {
close(c.Done)
return nil
}
// Generates a CheckerFunc from an *api.API + request needing authorization
func authCheckFunc(myAPI *api.API, r *http.Request, suffix string) AuthCheckerFunc {
return func() *api.TerminalSettings {
httpResponse, authResponse, err := myAPI.PreAuthorize(suffix, r)
if err != nil {
return nil
}
defer httpResponse.Body.Close()
if httpResponse.StatusCode != http.StatusOK || authResponse == nil {
return nil
}
return authResponse.Terminal
}
}
package terminal
import (
"testing"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)
func checkerSeries(values ...*api.TerminalSettings) AuthCheckerFunc {
return func() *api.TerminalSettings {
if len(values) == 0 {
return nil
}
out := values[0]
values = values[1:]
return out
}
}
func TestAuthCheckerStopsWhenAuthFails(t *testing.T) {
template := &api.TerminalSettings{Url: "ws://example.com"}
stopCh := make(chan error)
series := checkerSeries(template, template, template)
ac := NewAuthChecker(series, template, stopCh)
go ac.Loop(1 * time.Millisecond)
if err := <-stopCh; err != ErrAuthChanged {
t.Fatalf("Expected ErrAuthChanged, got %v", err)
}
if ac.Count != 3 {
t.Fatalf("Expected 3 successful checks, got %v", ac.Count)
}
}
func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) {
template := &api.TerminalSettings{Url: "ws://example.com"}
changed := template.Clone()
changed.Url = "wss://example.com"
stopCh := make(chan error)
series := checkerSeries(template, changed, template)
ac := NewAuthChecker(series, template, stopCh)
go ac.Loop(1 * time.Millisecond)
if err := <-stopCh; err != ErrAuthChanged {
t.Fatalf("Expected ErrAuthChanged, got %v", err)
}
if ac.Count != 1 {
t.Fatalf("Expected 1 successful check, got %v", ac.Count)
}
}
package terminal
import (
"fmt"
"net"
"time"
"github.com/gorilla/websocket"
)
// ANSI "end of terminal" code
var eot = []byte{0x04}
// An abstraction of gorilla's *websocket.Conn
type Connection interface {
UnderlyingConn() net.Conn
ReadMessage() (int, []byte, error)
WriteMessage(int, []byte) error
WriteControl(int, []byte, time.Time) error
}
type Proxy struct {
StopCh chan error
}
// stoppers is the number of goroutines that may attempt to call Stop()
func NewProxy(stoppers int) *Proxy {
return &Proxy{
StopCh: make(chan error, stoppers+2), // each proxy() call is a stopper
}
}
func (p *Proxy) Serve(upstream, downstream Connection, upstreamAddr, downstreamAddr string) error {
// This signals the upstream terminal to kill the exec'd process
defer upstream.WriteMessage(websocket.BinaryMessage, eot)
go p.proxy(upstream, downstream, upstreamAddr, downstreamAddr)
go p.proxy(downstream, upstream, downstreamAddr, upstreamAddr)
return <-p.StopCh
}
func (p *Proxy) proxy(to, from Connection, toAddr, fromAddr string) {
for {
messageType, data, err := from.ReadMessage()
if err != nil {
p.StopCh <- fmt.Errorf("reading from %s: %s", fromAddr, err)
break
}
if err := to.WriteMessage(messageType, data); err != nil {
p.StopCh <- fmt.Errorf("writing to %s: %s", toAddr, err)
break
}
}
}
package terminal
import (
"log"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
var (
// See doc/terminal.md for documentation of this subprotocol
subprotocols = []string{"terminal.gitlab.com", "base64.terminal.gitlab.com"}
upgrader = &websocket.Upgrader{Subprotocols: subprotocols}
ReauthenticationInterval = 5 * time.Minute
BrowserPingInterval = 30 * time.Second
)
func Handler(myAPI *api.API) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if err := a.Terminal.Validate(); err != nil {
helper.Fail500(w, r, err)
return
}
proxy := NewProxy(1) // one stopper: auth checker
checker := NewAuthChecker(
authCheckFunc(myAPI, r, "authorize"),
a.Terminal,
proxy.StopCh,
)
defer checker.Close()
go checker.Loop(ReauthenticationInterval)
ProxyTerminal(w, r, a.Terminal, proxy)
}, "authorize")
}
func ProxyTerminal(w http.ResponseWriter, r *http.Request, terminal *api.TerminalSettings, proxy *Proxy) {
server, err := connectToServer(terminal, r)
if err != nil {
helper.Fail500(w, r, err)
log.Printf("Terminal: connecting to server failed: %s", err)
return
}
defer server.UnderlyingConn().Close()
serverAddr := server.UnderlyingConn().RemoteAddr().String()
client, err := upgradeClient(w, r)
if err != nil {
log.Printf("Terminal: upgrading client to websocket failed: %s", err)
return
}
// Regularly send ping messages to the browser to keep the websocket from
// being timed out by intervening proxies.
go pingLoop(client)
defer client.UnderlyingConn().Close()
clientAddr := getClientAddr(r) // We can't know the port with confidence
log.Printf("Terminal: started proxying from %s to %s", clientAddr, serverAddr)
defer log.Printf("Terminal: finished proxying from %s to %s", clientAddr, serverAddr)
if err := proxy.Serve(server, client, serverAddr, clientAddr); err != nil {
log.Printf("Terminal: error proxying from %s to %s: %s", clientAddr, serverAddr, err)
}
}
// In the future, we might want to look at X-Client-Ip or X-Forwarded-For
func getClientAddr(r *http.Request) string {
return r.RemoteAddr
}
func upgradeClient(w http.ResponseWriter, r *http.Request) (Connection, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, err
}
return Wrap(conn, conn.Subprotocol()), nil
}
func pingLoop(conn Connection) {
for {
time.Sleep(BrowserPingInterval)
deadline := time.Now().Add(5 * time.Second)
if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
// Either the connection was already closed so no further pings are
// needed, or this connection is now dead and no further pings can
// be sent.
break
}
}
}
func connectToServer(terminal *api.TerminalSettings, r *http.Request) (Connection, error) {
terminal = terminal.Clone()
// Pass along X-Forwarded-For, appending request.RemoteAddr, to the server
// we're connecting to.
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
if chains, ok := r.Header["X-Forwarded-For"]; ok {
terminal.Header.Set("X-Forwarded-For", strings.Join(chains, ", ")+", "+ip)
} else {
terminal.Header.Set("X-Forwarded-For", ip)
}
}
conn, _, err := terminal.Dial()
if err != nil {
return nil, err
}
return Wrap(conn, conn.Subprotocol()), nil
}
package terminal
import (
"encoding/base64"
"net"
"time"
"github.com/gorilla/websocket"
)
func Wrap(conn Connection, subprotocol string) Connection {
switch subprotocol {
case "channel.k8s.io":
return &kubeWrapper{base64: false, conn: conn}
case "base64.channel.k8s.io":
return &kubeWrapper{base64: true, conn: conn}
case "terminal.gitlab.com":
return &gitlabWrapper{base64: false, conn: conn}
case "base64.terminal.gitlab.com":
return &gitlabWrapper{base64: true, conn: conn}
}
return conn
}
type kubeWrapper struct {
base64 bool
conn Connection
}
type gitlabWrapper struct {
base64 bool
conn Connection
}
func (w *gitlabWrapper) ReadMessage() (int, []byte, error) {
mt, data, err := w.conn.ReadMessage()
if err != nil {
return mt, data, err
}
if isData(mt) {
mt = websocket.BinaryMessage
if w.base64 {
data, err = decodeBase64(data)
}
}
return mt, data, err
}
func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error {
if isData(mt) {
if w.base64 {
mt = websocket.TextMessage
data = encodeBase64(data)
} else {
mt = websocket.BinaryMessage
}
}
return w.conn.WriteMessage(mt, data)
}
func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
return w.conn.WriteControl(mt, data, deadline)
}
func (w *gitlabWrapper) UnderlyingConn() net.Conn {
return w.conn.UnderlyingConn()
}
// Coalesces all wsstreams into a single stream. In practice, we should only
// receive data on stream 1.
func (w *kubeWrapper) ReadMessage() (int, []byte, error) {
mt, data, err := w.conn.ReadMessage()
if err != nil {
return mt, data, err
}
if isData(mt) {
mt = websocket.BinaryMessage
// Remove the WSStream channel number, decode to raw
if len(data) > 0 {
data = data[1:]
if w.base64 {
data, err = decodeBase64(data)
}
}
}
return mt, data, err
}
// Always sends to wsstream 0
func (w *kubeWrapper) WriteMessage(mt int, data []byte) error {
if isData(mt) {
if w.base64 {
mt = websocket.TextMessage
data = append([]byte{'0'}, encodeBase64(data)...)
} else {
mt = websocket.BinaryMessage
data = append([]byte{0}, data...)
}
}
return w.conn.WriteMessage(mt, data)
}
func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
return w.conn.WriteControl(mt, data, deadline)
}
func (w *kubeWrapper) UnderlyingConn() net.Conn {
return w.conn.UnderlyingConn()
}
func isData(mt int) bool {
return mt == websocket.BinaryMessage || mt == websocket.TextMessage
}
func encodeBase64(data []byte) []byte {
buf := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
base64.StdEncoding.Encode(buf, data)
return buf
}
func decodeBase64(data []byte) ([]byte, error) {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
n, err := base64.StdEncoding.Decode(buf, data)
return buf[:n], err
}
package terminal
import (
"bytes"
"errors"
"net"
"testing"
"time"
"github.com/gorilla/websocket"
)
type testcase struct {
input *fakeConn
expected *fakeConn
}
type fakeConn struct {
// WebSocket message type
mt int
data []byte
err error
}
func (f *fakeConn) ReadMessage() (int, []byte, error) {
return f.mt, f.data, f.err
}
func (f *fakeConn) WriteMessage(mt int, data []byte) error {
f.mt = mt
f.data = data
return f.err
}
func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error {
f.mt = mt
f.data = data
return f.err
}
func (f *fakeConn) UnderlyingConn() net.Conn {
return nil
}
func fake(mt int, data []byte, err error) *fakeConn {
return &fakeConn{mt: mt, data: []byte(data), err: err}
}
var (
msg = []byte("foo bar")
msgBase64 = []byte("Zm9vIGJhcg==")
kubeMsg = append([]byte{0}, msg...)
kubeMsgBase64 = append([]byte{'0'}, msgBase64...)
fakeErr = errors.New("fake error")
text = websocket.TextMessage
binary = websocket.BinaryMessage
other = 999
fakeOther = fake(other, []byte("foo"), nil)
)
func assertEqual(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) {
if expected.mt != actual.mt {
t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt)
t.Fatalf(msg, args...)
}
if bytes.Compare(expected.data, actual.data) != 0 {
t.Logf("data expected to be %q but was %q: ", expected.data, actual.data)
t.Fatalf(msg, args...)
}
if expected.err != actual.err {
t.Logf("error expected to be %v but was %v", expected.err, actual.err)
t.Fatalf(msg, args...)
}
}
func TestReadMessage(t *testing.T) {
testCases := map[string][]testcase{
"channel.k8s.io": {
{fake(binary, kubeMsg, fakeErr), fake(binary, kubeMsg, fakeErr)},
{fake(binary, kubeMsg, nil), fake(binary, msg, nil)},
{fake(text, kubeMsg, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"base64.channel.k8s.io": {
{fake(text, kubeMsgBase64, fakeErr), fake(text, kubeMsgBase64, fakeErr)},
{fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)},
{fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"terminal.gitlab.com": {
{fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)},
{fake(binary, msg, nil), fake(binary, msg, nil)},
{fake(text, msg, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"base64.terminal.gitlab.com": {
{fake(text, msgBase64, fakeErr), fake(text, msgBase64, fakeErr)},
{fake(text, msgBase64, nil), fake(binary, msg, nil)},
{fake(binary, msgBase64, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
}
for subprotocol, cases := range testCases {
for i, tc := range cases {
conn := Wrap(tc.input, subprotocol)
mt, data, err := conn.ReadMessage()
actual := fake(mt, data, err)
assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i)
}
}
}
func TestWriteMessage(t *testing.T) {
testCases := map[string][]testcase{
"channel.k8s.io": {
{fake(binary, msg, fakeErr), fake(binary, kubeMsg, fakeErr)},
{fake(binary, msg, nil), fake(binary, kubeMsg, nil)},
{fake(text, msg, nil), fake(binary, kubeMsg, nil)},
{fakeOther, fakeOther},
},
"base64.channel.k8s.io": {
{fake(binary, msg, fakeErr), fake(text, kubeMsgBase64, fakeErr)},
{fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)},
{fake(text, msg, nil), fake(text, kubeMsgBase64, nil)},
{fakeOther, fakeOther},
},
"terminal.gitlab.com": {
{fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)},
{fake(binary, msg, nil), fake(binary, msg, nil)},
{fake(text, msg, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"base64.terminal.gitlab.com": {
{fake(binary, msg, fakeErr), fake(text, msgBase64, fakeErr)},
{fake(binary, msg, nil), fake(text, msgBase64, nil)},
{fake(text, msg, nil), fake(text, msgBase64, nil)},
{fakeOther, fakeOther},
},
}
for subprotocol, cases := range testCases {
for i, tc := range cases {
actual := fake(0, nil, tc.input.err)
conn := Wrap(actual, subprotocol)
actual.err = conn.WriteMessage(tc.input.mt, tc.input.data)
assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i)
}
}
}
......@@ -4,21 +4,28 @@ import (
"net/http"
"regexp"
"github.com/gorilla/websocket"
apipkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/artifacts"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/git"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/lfs"
proxypkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/terminal"
)
type route struct {
method string
regex *regexp.Regexp
handler http.Handler
type matcherFunc func(*http.Request) bool
type routeEntry struct {
method string
regex *regexp.Regexp
handler http.Handler
matchers []matcherFunc
}
const projectPattern = `^/[^/]+/[^/]+/`
......@@ -30,6 +37,52 @@ const apiPattern = `^/api/`
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
func compileRegexp(regexpStr string) *regexp.Regexp {
if len(regexpStr) == 0 {
return nil
}
return regexp.MustCompile(regexpStr)
}
func route(method, regexpStr string, handler http.Handler, matchers ...matcherFunc) routeEntry {
return routeEntry{
method: method,
regex: compileRegexp(regexpStr),
handler: denyWebsocket(handler),
matchers: matchers,
}
}
func wsRoute(regexpStr string, handler http.Handler, matchers ...matcherFunc) routeEntry {
return routeEntry{
method: "GET",
regex: compileRegexp(regexpStr),
handler: handler,
matchers: append(matchers, websocket.IsWebSocketUpgrade),
}
}
func (ro *routeEntry) isMatch(cleanedPath string, req *http.Request) bool {
if ro.method != "" && req.Method != ro.method {
return false
}
if ro.regex != nil && !ro.regex.MatchString(cleanedPath) {
return false
}
ok := true
for _, matcher := range ro.matchers {
ok = matcher(req)
if !ok {
break
}
}
return ok
}
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
......@@ -58,47 +111,61 @@ func (u *Upstream) configureRoutes() {
)
ciAPIProxyQueue := queueing.QueueRequests(proxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout)
u.Routes = []route{
u.Routes = []routeEntry{
// Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(api)},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(api, proxy)},
route("GET", gitProjectPattern+`info/refs\z`, git.GetInfoRefs(api)),
route("POST", gitProjectPattern+`git-upload-pack\z`, contentEncodingHandler(git.PostRPC(api))),
route("POST", gitProjectPattern+`git-receive-pack\z`, contentEncodingHandler(git.PostRPC(api))),
route("PUT", gitProjectPattern+`gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`, lfs.PutStore(api, proxy)),
// CI Artifacts
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(artifacts.UploadArtifacts(api, proxy))},
route("POST", ciAPIPattern+`v1/builds/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, proxy))),
// Terminal websocket
wsRoute(projectPattern+`environments/[0-9]+/terminal.ws\z`, terminal.Handler(api)),
// Limit capacity given to builds/register.json
route{"", regexp.MustCompile(ciAPIPattern + `v1/builds/register.json\z`), ciAPIProxyQueue},
route("", ciAPIPattern+`v1/builds/register.json\z`, ciAPIProxyQueue),
// Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), proxy},
route{"", regexp.MustCompile(ciAPIPattern), proxy},
route("", apiPattern, proxy),
route("", ciAPIPattern, proxy),
// Serve assets
route{"", regexp.MustCompile(`^/assets/`),
static.ServeExisting(u.URLPrefix, staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode,
proxy,
),
route(
"", `^/assets/`,
static.ServeExisting(
u.URLPrefix,
staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode, proxy),
),
},
),
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting.
route{"", regexp.MustCompile(`^/uploads/`), static.ErrorPagesUnless(u.DevelopmentMode, proxy)},
route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, proxy)),
// Serve static files or forward the requests
route{"", nil,
static.ServeExisting(u.URLPrefix, staticpages.CacheDisabled,
static.DeployPage(
static.ErrorPagesUnless(u.DevelopmentMode,
proxy,
),
),
route(
"", "",
static.ServeExisting(
u.URLPrefix,
staticpages.CacheDisabled,
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, proxy)),
),
},
),
}
}
func denyWebsocket(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if websocket.IsWebSocketUpgrade(r) {
helper.HTTPError(w, r, "websocket upgrade not allowed", http.StatusBadRequest)
return
}
next.ServeHTTP(w, r)
})
}
......@@ -36,7 +36,7 @@ type Config struct {
type Upstream struct {
Config
URLPrefix urlprefix.Prefix
Routes []route
Routes []routeEntry
RoundTripper *badgateway.RoundTripper
}
......@@ -65,17 +65,17 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := helper.NewLoggingResponseWriter(ow)
defer w.Log(r)
helper.DisableResponseBuffering(&w)
helper.DisableResponseBuffering(w)
// Drop WebSocket connection and CONNECT method
// Drop RequestURI == "*" (FIXME: why?)
if r.RequestURI == "*" {
helper.HTTPError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
helper.HTTPError(w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
helper.HTTPError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
helper.HTTPError(w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
......@@ -83,29 +83,25 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
URIPath := urlprefix.CleanURIPath(r.URL.Path)
prefix := u.URLPrefix
if !prefix.Match(URIPath) {
helper.HTTPError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Look for a matching Git service
var ro route
foundService := false
for _, ro = range u.Routes {
if ro.method != "" && r.Method != ro.method {
continue
}
if ro.regex == nil || ro.regex.MatchString(prefix.Strip(URIPath)) {
foundService = true
// Look for a matching route
var route *routeEntry
for _, ro := range u.Routes {
if ro.isMatch(prefix.Strip(URIPath), r) {
route = &ro
break
}
}
if !foundService {
if route == nil {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
helper.HTTPError(&w, r, "Forbidden", http.StatusForbidden)
helper.HTTPError(w, r, "Forbidden", http.StatusForbidden)
return
}
ro.handler.ServeHTTP(&w, r)
route.handler.ServeHTTP(w, r)
}
package main
import (
"bytes"
"encoding/pem"
"fmt"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
"github.com/gorilla/websocket"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)
var terminalPath = fmt.Sprintf("%s/environments/1/terminal.ws", testProject)
type connWithReq struct {
conn *websocket.Conn
req *http.Request
}
func TestTerminalHappyPath(t *testing.T) {
serverConns, clientURL, close := wireupTerminal(nil, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
server := (<-serverConns).conn
defer server.Close()
message := "test message"
// channel.k8s.io: server writes to channel 1, STDOUT
if err := say(server, "\x01"+message); err != nil {
t.Fatal(err)
}
assertReadMessage(t, client, websocket.BinaryMessage, message)
if err := say(client, message); err != nil {
t.Fatal(err)
}
// channel.k8s.io: client writes get put on channel 0, STDIN
assertReadMessage(t, server, websocket.BinaryMessage, "\x00"+message)
// Closing the client should send an EOT signal to the server's STDIN
client.Close()
assertReadMessage(t, server, websocket.BinaryMessage, "\x00\x04")
}
func TestTerminalBadTLS(t *testing.T) {
_, clientURL, close := wireupTerminal(badCA, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != websocket.ErrBadHandshake {
t.Fatalf("Expected connection to fail ErrBadHandshake, got: %v", err)
}
if err == nil {
log.Println("TLS negotiation should have failed!")
defer client.Close()
}
}
func TestTerminalProxyForwardsHeadersFromUpstream(t *testing.T) {
hdr := make(http.Header)
hdr.Set("Random-Header", "Value")
serverConns, clientURL, close := wireupTerminal(setHeader(hdr), "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
defer client.Close()
sc := <-serverConns
defer sc.conn.Close()
if sc.req.Header.Get("Random-Header") != "Value" {
t.Fatal("Header specified by upstream not sent to remote")
}
}
func TestTerminalProxyForwardsXForwardedForFromClient(t *testing.T) {
serverConns, clientURL, close := wireupTerminal(nil, "channel.k8s.io")
defer close()
hdr := make(http.Header)
hdr.Set("X-Forwarded-For", "127.0.0.2")
client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientIP, _, err := net.SplitHostPort(client.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
sc := <-serverConns
defer sc.conn.Close()
if xff := sc.req.Header.Get("X-Forwarded-For"); xff != "127.0.0.2, "+clientIP {
t.Fatalf("X-Forwarded-For from client not sent to remote: %+v", xff)
}
}
func wireupTerminal(modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
serverConns, remote := startWebsocketServer(subprotocols...)
authResponse := terminalOkBody(remote, nil, subprotocols...)
if modifier != nil {
modifier(authResponse)
}
upstream := testAuthServer(nil, 200, authResponse)
workhorse := startWorkhorseServer(upstream.URL)
return serverConns, websocketURL(workhorse.URL, terminalPath), func() {
workhorse.Close()
upstream.Close()
remote.Close()
}
}
func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.Server) {
upgrader := &websocket.Upgrader{Subprotocols: subprotocols}
connCh := make(chan connWithReq, 1)
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("WEBSOCKET", r.Method, r.URL, r.Header)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("WEBSOCKET", r.Method, r.URL, "Upgrade failed", err)
return
}
connCh <- connWithReq{conn, r}
// The connection has been hijacked so it's OK to end here
}))
return connCh, server
}
func terminalOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
out := &api.Response{
Terminal: &api.TerminalSettings{
Url: websocketURL(remote.URL),
Header: header,
Subprotocols: subprotocols,
},
}
if len(remote.TLS.Certificates) > 0 {
data := bytes.NewBuffer(nil)
pem.Encode(data, &pem.Block{Type: "CERTIFICATE", Bytes: remote.TLS.Certificates[0].Certificate[0]})
out.Terminal.CAPem = data.String()
}
return out
}
func badCA(authResponse *api.Response) {
authResponse.Terminal.CAPem = "Bad CA"
}
func setHeader(hdr http.Header) func(*api.Response) {
return func(authResponse *api.Response) {
authResponse.Terminal.Header = hdr
}
}
func dialWebsocket(url string, header http.Header, subprotocols ...string) (*websocket.Conn, *http.Response, error) {
dialer := &websocket.Dialer{
Subprotocols: subprotocols,
}
return dialer.Dial(url, header)
}
func websocketURL(httpURL string, suffix ...string) string {
url, err := url.Parse(httpURL)
if err != nil {
panic(err)
}
switch url.Scheme {
case "http":
url.Scheme = "ws"
case "https":
url.Scheme = "wss"
default:
panic("Unknown scheme: " + url.Scheme)
}
url.Path = path.Join(url.Path, strings.Join(suffix, "/"))
return url.String()
}
func say(conn *websocket.Conn, message string) error {
return conn.WriteMessage(websocket.TextMessage, []byte(message))
}
func assertReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) {
messageType, data, err := conn.ReadMessage()
if err != nil {
t.Fatal(err)
}
if messageType != expectedMessageType {
t.Fatalf("Expected message, %d, got %d", expectedMessageType, messageType)
}
if string(data) != expectedData {
t.Fatalf("Message was mangled in transit. Expected %q, got %q", expectedData, string(data))
}
}
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
.idea/
*.iml
\ No newline at end of file
language: go
sudo: false
matrix:
include:
- go: 1.4
- go: 1.5
- go: 1.6
- go: 1.7
- go: tip
allow_failures:
- go: tip
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d .)
- go vet $(go list ./... | grep -v /vendor/)
- go test -v -race ./...
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.
Gary Burd <gary@beagledreams.com>
Joachim Bauch <mail@joachim-bauch.de>
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Gorilla WebSocket
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket)
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
### Documentation
* [API Reference](http://godoc.org/github.com/gorilla/websocket)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
### Status
The Gorilla WebSocket package provides a complete and tested implementation of
the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The
package API is stable.
### Installation
go get github.com/gorilla/websocket
### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="http://autobahn.ws/testsuite/">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.
This diff is collapsed.
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"compress/flate"
"errors"
"io"
"strings"
)
func decompressNoContextTakeover(r io.Reader) io.Reader {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
}
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
tw := &truncWriter{w: w}
fw, err := flate.NewWriter(tw, 3)
return &flateWrapper{fw: fw, tw: tw}, err
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}
m := len(p)
if m > len(w.p) {
m = len(w.p)
}
if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}
type flateWrapper struct {
fw *flate.Writer
tw *truncWriter
}
func (w *flateWrapper) Write(p []byte) (int, error) {
return w.fw.Write(p)
}
func (w *flateWrapper) Close() error {
err1 := w.fw.Flush()
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
This diff is collapsed.
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
if len(p) > 0 {
// advance over the bytes just read
io.ReadFull(c.br, p)
}
return p, err
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements the WebSocket protocol defined in RFC 6455.
//
// Overview
//
// The Conn type represents a WebSocket connection. A server application uses
// the Upgrade function from an Upgrader object with a HTTP request handler
// to get a pointer to a Conn:
//
// var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024,
// WriteBufferSize: 1024,
// }
//
// func handler(w http.ResponseWriter, r *http.Request) {
// conn, err := upgrader.Upgrade(w, r, nil)
// if err != nil {
// log.Println(err)
// return
// }
// ... Use conn to send and receive messages.
// }
//
// Call the connection's WriteMessage and ReadMessage methods to send and
// receive messages as a slice of bytes. This snippet of code shows how to echo
// messages using these methods:
//
// for {
// messageType, p, err := conn.ReadMessage()
// if err != nil {
// return
// }
// if err = conn.WriteMessage(messageType, p); err != nil {
// return err
// }
// }
//
// In above snippet of code, p is a []byte and messageType is an int with value
// websocket.BinaryMessage or websocket.TextMessage.
//
// An application can also send and receive messages using the io.WriteCloser
// and io.Reader interfaces. To send a message, call the connection NextWriter
// method to get an io.WriteCloser, write the message to the writer and close
// the writer when done. To receive a message, call the connection NextReader
// method to get an io.Reader and read until io.EOF is returned. This snippet
// shows how to echo messages using the NextWriter and NextReader methods:
//
// for {
// messageType, r, err := conn.NextReader()
// if err != nil {
// return
// }
// w, err := conn.NextWriter(messageType)
// if err != nil {
// return err
// }
// if _, err := io.Copy(w, r); err != nil {
// return err
// }
// if err := w.Close(); err != nil {
// return err
// }
// }
//
// Data Messages
//
// The WebSocket protocol distinguishes between text and binary data messages.
// Text messages are interpreted as UTF-8 encoded text. The interpretation of
// binary messages is left to the application.
//
// This package uses the TextMessage and BinaryMessage integer constants to
// identify the two data message types. The ReadMessage and NextReader methods
// return the type of the received message. The messageType argument to the
// WriteMessage and NextWriter methods specifies the type of a sent message.
//
// It is the application's responsibility to ensure that text messages are
// valid UTF-8 encoded text.
//
// Control Messages
//
// The WebSocket protocol defines three types of control messages: close, ping
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer.
//
// Connections handle received close messages by sending a close message to the
// peer and returning a *CloseError from the the NextReader, ReadMessage or the
// message Read method.
//
// Connections handle received ping and pong messages by invoking callback
// functions set with SetPingHandler and SetPongHandler methods. The callback
// functions are called from the NextReader, ReadMessage and the message Read
// methods.
//
// The default ping handler sends a pong to the peer. The application's reading
// goroutine can block for a short time while the handler writes the pong data
// to the connection.
//
// The application must read the connection to process ping, pong and close
// messages sent from the peer. If the application is not otherwise interested
// in messages from the peer, then the application should start a goroutine to
// read and discard messages from the peer. A simple example is:
//
// func readLoop(c *websocket.Conn) {
// for {
// if _, _, err := c.NextReader(); err != nil {
// c.Close()
// break
// }
// }
// }
//
// Concurrency
//
// Connections support one concurrent reader and one concurrent writer.
//
// Applications are responsible for ensuring that no more than one goroutine
// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage,
// WriteJSON) concurrently and that no more than one goroutine calls the read
// methods (NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler,
// SetPingHandler) concurrently.
//
// The Close and WriteControl methods can be called concurrently with all other
// methods.
//
// Origin Considerations
//
// Web browsers allow Javascript applications to open a WebSocket connection to
// any host. It's up to the server to enforce an origin policy using the Origin
// request header sent by the browser.
//
// The Upgrader calls the function specified in the CheckOrigin field to check
// the origin. If the CheckOrigin function returns false, then the Upgrade
// method fails the WebSocket handshake with HTTP status 403.
//
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
// the handshake if the Origin request header is present and not equal to the
// Host request header.
//
// An application can allow connections from any origin by specifying a
// function that always returns true:
//
// var upgrader = websocket.Upgrader{
// CheckOrigin: func(r *http.Request) bool { return true },
// }
//
// The deprecated Upgrade function does not enforce an origin policy. It's the
// application's responsibility to check the Origin header before calling
// Upgrade.
//
// Compression [Experimental]
//
// Per message compression extensions (RFC 7692) are experimentally supported
// by this package in a limited capacity. Setting the EnableCompression option
// to true in Dialer or Upgrader will attempt to negotiate per message deflate
// support. If compression was successfully negotiated with the connection's
// peer, any message received in compressed form will be automatically
// decompressed. All Read methods will return uncompressed bytes.
//
// Per message compression of messages written to a connection can be enabled
// or disabled by calling the corresponding Conn method:
//
// conn.EnableWriteCompression(true)
//
// Currently this package does not support compression with "context takeover".
// This means that messages must be compressed and decompressed in isolation,
// without retaining sliding window or dictionary state across messages. For
// more details refer to RFC 7692.
//
// Use of compression is experimental and may result in decreased performance.
package websocket
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"encoding/json"
"io"
)
// WriteJSON is deprecated, use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v)
}
// WriteJSON writes the JSON encoding of v to the connection.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *Conn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON is deprecated, use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v)
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
package websocket
import (
"math/rand"
"unsafe"
)
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func newMaskKey() [4]byte {
n := rand.Uint32()
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}
// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
message string
}
func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then a default value of 4096 is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client.
Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or
// must match the host of the request.
CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
err := HandshakeError{reason}
if u.Error != nil {
u.Error(w, r, status, err)
} else {
w.Header().Set("Sec-Websocket-Version", "13")
http.Error(w, http.StatusText(status), status)
}
return nil, err
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return u.Host == r.Host
}
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific Sec-Websocket-Extensions headers are unsupported")
}
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
}
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'")
}
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'")
}
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed")
}
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank")
}
subprotocol := u.selectSubprotocol(r, responseHeader)
// Negotiate PMCE
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
var (
netConn net.Conn
br *bufio.Reader
err error
)
h, ok := w.(http.Hijacker)
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var rw *bufio.ReadWriter
netConn, rw, err = h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
br = rw.Reader
if br.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
p := c.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {
continue
}
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
for i := 0; i < len(v); i++ {
b := v[i]
if b <= 31 {
// prevent response splitting.
b = ' '
}
p = append(p, b)
}
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
return c, nil
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// This function is deprecated, use websocket.Upgrader instead.
//
// The application is responsible for checking the request origin before
// calling Upgrade. An example implementation of the same origin policy is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403)
// return
// }
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(r *http.Request) bool {
// allow all connections by default
return true
}
return u.Upgrade(w, r, responseHeader)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(r *http.Request) []string {
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket")
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
"strings"
)
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func generateChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Octet types from RFC 2616.
var octetTypes [256]byte
const (
isTokenOctet = 1 << iota
isSpaceOctet
)
func init() {
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
}
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
break
}
}
return s[i:]
}
func nextToken(s string) (token, rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 {
break
}
}
return s[:i], s[i:]
}
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j += 1
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j += 1
}
}
return "", ""
}
}
return "", ""
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool {
headers:
for _, s := range header[name] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
continue headers
}
if strings.EqualFold(t, value) {
return true
}
if s == "" {
continue headers
}
s = s[1:]
}
}
return false
}
// parseExtensiosn parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment