Commit de68567e authored by Adam Niedzielski's avatar Adam Niedzielski

Terminate terminal session after configurable time limit

parent 1ee9219a
...@@ -26,6 +26,10 @@ type TerminalSettings struct { ...@@ -26,6 +26,10 @@ type TerminalSettings struct {
// The CA roots to validate the remote endpoint with, for wss:// URLs. The // The CA roots to validate the remote endpoint with, for wss:// URLs. The
// system-provided CA pool will be used if this is blank. PEM-encoded data. // system-provided CA pool will be used if this is blank. PEM-encoded data.
CAPem string CAPem string
// The value is specified in seconds. It is converted to time.Duration
// later.
MaxSessionTime int
} }
func (t *TerminalSettings) URL() (*url.URL, error) { func (t *TerminalSettings) URL() (*url.URL, error) {
...@@ -112,5 +116,7 @@ func (t *TerminalSettings) IsEqual(other *TerminalSettings) bool { ...@@ -112,5 +116,7 @@ func (t *TerminalSettings) IsEqual(other *TerminalSettings) bool {
} }
} }
return t.Url == other.Url && t.CAPem == other.CAPem return t.Url == other.Url &&
t.CAPem == other.CAPem &&
t.MaxSessionTime == other.MaxSessionTime
} }
...@@ -7,8 +7,9 @@ import ( ...@@ -7,8 +7,9 @@ import (
func terminal(url string, subprotocols ...string) *TerminalSettings { func terminal(url string, subprotocols ...string) *TerminalSettings {
return &TerminalSettings{ return &TerminalSettings{
Url: url, Url: url,
Subprotocols: subprotocols, Subprotocols: subprotocols,
MaxSessionTime: 0,
} }
} }
...@@ -19,6 +20,13 @@ func ca(term *TerminalSettings) *TerminalSettings { ...@@ -19,6 +20,13 @@ func ca(term *TerminalSettings) *TerminalSettings {
return term return term
} }
func timeout(term *TerminalSettings) *TerminalSettings {
term = term.Clone()
term.MaxSessionTime = 600
return term
}
func header(term *TerminalSettings, values ...string) *TerminalSettings { func header(term *TerminalSettings, values ...string) *TerminalSettings {
if len(values) == 0 { if len(values) == 0 {
values = []string{"Dummy Value"} values = []string{"Dummy Value"}
...@@ -134,6 +142,7 @@ func TestIsEqual(t *testing.T) { ...@@ -134,6 +142,7 @@ func TestIsEqual(t *testing.T) {
{term, ca(term), false}, {term, ca(term), false},
{ca(header(term)), ca(header(term)), true}, {ca(header(term)), ca(header(term)), true},
{term_ca2, ca(term), false}, {term_ca2, ca(term), false},
{term, timeout(term), false},
} { } {
if actual := tc.termA.IsEqual(tc.termB); tc.expected != actual { if actual := tc.termA.IsEqual(tc.termB); tc.expected != actual {
t.Fatalf( t.Fatalf(
......
package terminal package terminal
import ( import (
"errors"
"fmt"
"log" "log"
"net/http" "net/http"
"time" "time"
...@@ -26,7 +28,7 @@ func Handler(myAPI *api.API) http.Handler { ...@@ -26,7 +28,7 @@ func Handler(myAPI *api.API) http.Handler {
return return
} }
proxy := NewProxy(1) // one stopper: auth checker proxy := NewProxy(2) // two stoppers: auth checker, max time
checker := NewAuthChecker( checker := NewAuthChecker(
authCheckFunc(myAPI, r, "authorize"), authCheckFunc(myAPI, r, "authorize"),
a.Terminal, a.Terminal,
...@@ -34,6 +36,7 @@ func Handler(myAPI *api.API) http.Handler { ...@@ -34,6 +36,7 @@ func Handler(myAPI *api.API) http.Handler {
) )
defer checker.Close() defer checker.Close()
go checker.Loop(ReauthenticationInterval) go checker.Loop(ReauthenticationInterval)
go closeAfterMaxTime(proxy, a.Terminal.MaxSessionTime)
ProxyTerminal(w, r, a.Terminal, proxy) ProxyTerminal(w, r, a.Terminal, proxy)
}, "authorize") }, "authorize")
...@@ -109,3 +112,17 @@ func connectToServer(terminal *api.TerminalSettings, r *http.Request) (Connectio ...@@ -109,3 +112,17 @@ func connectToServer(terminal *api.TerminalSettings, r *http.Request) (Connectio
return Wrap(conn, conn.Subprotocol()), nil return Wrap(conn, conn.Subprotocol()), nil
} }
func closeAfterMaxTime(proxy *Proxy, maxSessionTime int) {
if maxSessionTime == 0 {
return
}
<-time.After(time.Duration(maxSessionTime) * time.Second)
proxy.StopCh <- errors.New(
fmt.Sprintf(
"Connection closed: session time greater than maximum time allowed - %v seconds",
maxSessionTime,
),
)
}
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"path" "path"
"strings" "strings"
"testing" "testing"
"time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -71,6 +72,26 @@ func TestTerminalBadTLS(t *testing.T) { ...@@ -71,6 +72,26 @@ func TestTerminalBadTLS(t *testing.T) {
} }
} }
func TestTerminalSessionTimeout(t *testing.T) {
serverConns, clientURL, close := wireupTerminal(timeout, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
sc := <-serverConns
defer sc.conn.Close()
client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second))
_, _, err = client.ReadMessage()
if !websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
t.Fatalf("Client connection was not closed, got %v", err)
}
}
func TestTerminalProxyForwardsHeadersFromUpstream(t *testing.T) { func TestTerminalProxyForwardsHeadersFromUpstream(t *testing.T) {
hdr := make(http.Header) hdr := make(http.Header)
hdr.Set("Random-Header", "Value") hdr.Set("Random-Header", "Value")
...@@ -151,9 +172,10 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S ...@@ -151,9 +172,10 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S
func terminalOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response { func terminalOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
out := &api.Response{ out := &api.Response{
Terminal: &api.TerminalSettings{ Terminal: &api.TerminalSettings{
Url: websocketURL(remote.URL), Url: websocketURL(remote.URL),
Header: header, Header: header,
Subprotocols: subprotocols, Subprotocols: subprotocols,
MaxSessionTime: 0,
}, },
} }
...@@ -170,6 +192,10 @@ func badCA(authResponse *api.Response) { ...@@ -170,6 +192,10 @@ func badCA(authResponse *api.Response) {
authResponse.Terminal.CAPem = "Bad CA" authResponse.Terminal.CAPem = "Bad CA"
} }
func timeout(authResponse *api.Response) {
authResponse.Terminal.MaxSessionTime = 1
}
func setHeader(hdr http.Header) func(*api.Response) { func setHeader(hdr http.Header) func(*api.Response) {
return func(authResponse *api.Response) { return func(authResponse *api.Response) {
authResponse.Terminal.Header = hdr authResponse.Terminal.Header = hdr
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment