Commit c0e60a07 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'id-go-recovery-codes' into 'master'

Provide go implementation for 2fa_recovery_codes command

See merge request gitlab-org/gitlab-shell!285
parents 81bed658 98dbdfb7
...@@ -7,28 +7,28 @@ import ( ...@@ -7,28 +7,28 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
) )
var ( var (
binDir string binDir string
rootDir string rootDir string
reporter *reporting.Reporter readWriter *readwriter.ReadWriter
) )
func init() { func init() {
binDir = filepath.Dir(os.Args[0]) binDir = filepath.Dir(os.Args[0])
rootDir = filepath.Dir(binDir) rootDir = filepath.Dir(binDir)
reporter = &reporting.Reporter{Out: os.Stdout, ErrOut: os.Stderr} readWriter = &readwriter.ReadWriter{Out: os.Stdout, In: os.Stdin, ErrOut: os.Stderr}
} }
// rubyExec will never return. It either replaces the current process with a // rubyExec will never return. It either replaces the current process with a
// Ruby interpreter, or outputs an error and kills the process. // Ruby interpreter, or outputs an error and kills the process.
func execRuby() { func execRuby() {
cmd := &fallback.Command{} cmd := &fallback.Command{}
if err := cmd.Execute(reporter); err != nil { if err := cmd.Execute(readWriter); err != nil {
fmt.Fprintf(reporter.ErrOut, "Failed to exec: %v\n", err) fmt.Fprintf(readWriter.ErrOut, "Failed to exec: %v\n", err)
os.Exit(1) os.Exit(1)
} }
} }
...@@ -38,7 +38,7 @@ func main() { ...@@ -38,7 +38,7 @@ func main() {
// warning as this isn't something we can sustain indefinitely // warning as this isn't something we can sustain indefinitely
config, err := config.NewFromDir(rootDir) config, err := config.NewFromDir(rootDir)
if err != nil { if err != nil {
fmt.Fprintln(reporter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby") fmt.Fprintln(readWriter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby")
execRuby() execRuby()
} }
...@@ -46,14 +46,14 @@ func main() { ...@@ -46,14 +46,14 @@ func main() {
if err != nil { if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on // For now this could happen if `SSH_CONNECTION` is not set on
// the environment // the environment
fmt.Fprintf(reporter.ErrOut, "%v\n", err) fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1) os.Exit(1)
} }
// The command will write to STDOUT on execution or replace the current // The command will write to STDOUT on execution or replace the current
// process in case of the `fallback.Command` // process in case of the `fallback.Command`
if err = cmd.Execute(reporter); err != nil { if err = cmd.Execute(readWriter); err != nil {
fmt.Fprintf(reporter.ErrOut, "%v\n", err) fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1) os.Exit(1)
} }
} }
...@@ -4,12 +4,13 @@ import ( ...@@ -4,12 +4,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
) )
type Command interface { type Command interface {
Execute(*reporting.Reporter) error Execute(*readwriter.ReadWriter) error
} }
func New(arguments []string, config *config.Config) (Command, error) { func New(arguments []string, config *config.Config) (Command, error) {
...@@ -30,6 +31,8 @@ func buildCommand(args *commandargs.CommandArgs, config *config.Config) Command ...@@ -30,6 +31,8 @@ func buildCommand(args *commandargs.CommandArgs, config *config.Config) Command
switch args.CommandType { switch args.CommandType {
case commandargs.Discover: case commandargs.Discover:
return &discover.Command{Config: config, Args: args} return &discover.Command{Config: config, Args: args}
case commandargs.TwoFactorRecover:
return &twofactorrecover.Command{Config: config, Args: args}
} }
return nil return nil
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper"
) )
...@@ -44,6 +45,19 @@ func TestNew(t *testing.T) { ...@@ -44,6 +45,19 @@ func TestNew(t *testing.T) {
}, },
expectedType: &fallback.Command{}, expectedType: &fallback.Command{},
}, },
{
desc: "it returns a TwoFactorRecover command if the feature is enabled",
arguments: []string{},
config: &config.Config{
GitlabUrl: "http+unix://gitlab.socket",
Migration: config.MigrationConfig{Enabled: true, Features: []string{"2fa_recovery_codes"}},
},
environment: map[string]string{
"SSH_CONNECTION": "1",
"SSH_ORIGINAL_COMMAND": "2fa_recovery_codes",
},
expectedType: &twofactorrecover.Command{},
},
} }
for _, tc := range testCases { for _, tc := range testCases {
......
...@@ -10,6 +10,7 @@ type CommandType string ...@@ -10,6 +10,7 @@ type CommandType string
const ( const (
Discover CommandType = "discover" Discover CommandType = "discover"
TwoFactorRecover CommandType = "2fa_recovery_codes"
) )
var ( var (
...@@ -79,4 +80,8 @@ func (c *CommandArgs) parseCommand(commandString string) { ...@@ -79,4 +80,8 @@ func (c *CommandArgs) parseCommand(commandString string) {
if commandString == "" { if commandString == "" {
c.CommandType = Discover c.CommandType = Discover
} }
if CommandType(commandString) == TwoFactorRecover {
c.CommandType = TwoFactorRecover
}
} }
...@@ -4,7 +4,7 @@ import ( ...@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
) )
...@@ -14,16 +14,16 @@ type Command struct { ...@@ -14,16 +14,16 @@ type Command struct {
Args *commandargs.CommandArgs Args *commandargs.CommandArgs
} }
func (c *Command) Execute(reporter *reporting.Reporter) error { func (c *Command) Execute(readWriter *readwriter.ReadWriter) error {
response, err := c.getUserInfo() response, err := c.getUserInfo()
if err != nil { if err != nil {
return fmt.Errorf("Failed to get username: %v", err) return fmt.Errorf("Failed to get username: %v", err)
} }
if response.IsAnonymous() { if response.IsAnonymous() {
fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n") fmt.Fprintf(readWriter.Out, "Welcome to GitLab, Anonymous!\n")
} else { } else {
fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username) fmt.Fprintf(readWriter.Out, "Welcome to GitLab, @%s!\n", response.Username)
} }
return nil return nil
...@@ -35,13 +35,5 @@ func (c *Command) getUserInfo() (*discover.Response, error) { ...@@ -35,13 +35,5 @@ func (c *Command) getUserInfo() (*discover.Response, error) {
return nil, err return nil, err
} }
if c.Args.GitlabKeyId != "" { return client.GetByCommandArgs(c.Args)
return client.GetByKeyId(c.Args.GitlabKeyId)
} else if c.Args.GitlabUsername != "" {
return client.GetByUsername(c.Args.GitlabUsername)
} else {
// There was no 'who' information, this matches the ruby error
// message.
return nil, fmt.Errorf("who='' is invalid")
}
} }
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
) )
...@@ -82,7 +82,7 @@ func TestExecute(t *testing.T) { ...@@ -82,7 +82,7 @@ func TestExecute(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments} cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
err := cmd.Execute(&reporting.Reporter{Out: buffer}) err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, buffer.String()) assert.Equal(t, tc.expectedOutput, buffer.String())
...@@ -122,7 +122,7 @@ func TestFailingExecute(t *testing.T) { ...@@ -122,7 +122,7 @@ func TestFailingExecute(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments} cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
err := cmd.Execute(&reporting.Reporter{Out: buffer}) err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
assert.Empty(t, buffer.String()) assert.Empty(t, buffer.String())
assert.EqualError(t, err, tc.expectedError) assert.EqualError(t, err, tc.expectedError)
......
...@@ -5,7 +5,7 @@ import ( ...@@ -5,7 +5,7 @@ import (
"path/filepath" "path/filepath"
"syscall" "syscall"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
) )
type Command struct{} type Command struct{}
...@@ -14,7 +14,7 @@ var ( ...@@ -14,7 +14,7 @@ var (
binDir = filepath.Dir(os.Args[0]) binDir = filepath.Dir(os.Args[0])
) )
func (c *Command) Execute(_ *reporting.Reporter) error { func (c *Command) Execute(_ *readwriter.ReadWriter) error {
rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby") rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby")
execErr := syscall.Exec(rubyCmd, os.Args, os.Environ()) execErr := syscall.Exec(rubyCmd, os.Args, os.Environ())
return execErr return execErr
......
package reporting package readwriter
import "io" import "io"
type Reporter struct { type ReadWriter struct {
Out io.Writer Out io.Writer
In io.Reader
ErrOut io.Writer ErrOut io.Writer
} }
package twofactorrecover
import (
"fmt"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover"
)
type Command struct {
Config *config.Config
Args *commandargs.CommandArgs
}
func (c *Command) Execute(readWriter *readwriter.ReadWriter) error {
if c.canContinue(readWriter) {
c.displayRecoveryCodes(readWriter)
} else {
fmt.Fprintln(readWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
}
return nil
}
func (c *Command) canContinue(readWriter *readwriter.ReadWriter) bool {
question :=
"Are you sure you want to generate new two-factor recovery codes?\n" +
"Any existing recovery codes you saved will be invalidated. (yes/no)"
fmt.Fprintln(readWriter.Out, question)
var answer string
fmt.Fscanln(readWriter.In, &answer)
return answer == "yes"
}
func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) {
codes, err := c.getRecoveryCodes()
if err == nil {
messageWithCodes :=
"\nYour two-factor authentication recovery codes are:\n\n" +
strings.Join(codes, "\n") +
"\n\nDuring sign in, use one of the codes above when prompted for\n" +
"your two-factor code. Then, visit your Profile Settings and add\n" +
"a new device so you do not lose access to your account again.\n"
fmt.Fprint(readWriter.Out, messageWithCodes)
} else {
fmt.Fprintf(readWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err)
}
}
func (c *Command) getRecoveryCodes() ([]string, error) {
client, err := twofactorrecover.NewClient(c.Config)
if err != nil {
return nil, err
}
return client.GetRecoveryCodes(c.Args)
}
package twofactorrecover
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover"
)
var (
testConfig *config.Config
requests []testserver.TestRequestHandler
)
func setup(t *testing.T) {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/two_factor_recovery_codes",
Handler: func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
require.NoError(t, err)
var requestBody *twofactorrecover.RequestBody
json.Unmarshal(b, &requestBody)
switch requestBody.KeyId {
case "1":
body := map[string]interface{}{
"success": true,
"recovery_codes": [2]string{"recovery", "codes"},
}
json.NewEncoder(w).Encode(body)
case "forbidden":
body := map[string]interface{}{
"success": false,
"message": "Forbidden!",
}
json.NewEncoder(w).Encode(body)
case "broken":
w.WriteHeader(http.StatusInternalServerError)
}
},
},
}
}
const (
question = "Are you sure you want to generate new two-factor recovery codes?\n" +
"Any existing recovery codes you saved will be invalidated. (yes/no)\n\n"
errorHeader = "An error occurred while trying to generate new recovery codes.\n"
)
func TestExecute(t *testing.T) {
setup(t)
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
testCases := []struct {
desc string
arguments *commandargs.CommandArgs
answer string
expectedOutput string
}{
{
desc: "With a known key id",
arguments: &commandargs.CommandArgs{GitlabKeyId: "1"},
answer: "yes\n",
expectedOutput: question +
"Your two-factor authentication recovery codes are:\n\nrecovery\ncodes\n\n" +
"During sign in, use one of the codes above when prompted for\n" +
"your two-factor code. Then, visit your Profile Settings and add\n" +
"a new device so you do not lose access to your account again.\n",
},
{
desc: "With bad response",
arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"},
answer: "yes\n",
expectedOutput: question + errorHeader + "Parsing failed\n",
},
{
desc: "With API returns an error",
arguments: &commandargs.CommandArgs{GitlabKeyId: "forbidden"},
answer: "yes\n",
expectedOutput: question + errorHeader + "Forbidden!\n",
},
{
desc: "With API fails",
arguments: &commandargs.CommandArgs{GitlabKeyId: "broken"},
answer: "yes\n",
expectedOutput: question + errorHeader + "Internal API error (500)\n",
},
{
desc: "With missing arguments",
arguments: &commandargs.CommandArgs{},
answer: "yes\n",
expectedOutput: question + errorHeader + "who='' is invalid\n",
},
{
desc: "With negative answer",
arguments: &commandargs.CommandArgs{},
answer: "no\n",
expectedOutput: question +
"New recovery codes have *not* been generated. Existing codes will remain valid.\n",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
output := &bytes.Buffer{}
input := bytes.NewBufferString(tc.answer)
cmd := &Command{Config: testConfig, Args: tc.arguments}
err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input})
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, output.String())
})
}
}
...@@ -17,8 +17,7 @@ const ( ...@@ -17,8 +17,7 @@ const (
type GitlabClient interface { type GitlabClient interface {
Get(path string) (*http.Response, error) Get(path string) (*http.Response, error)
// TODO: implement posts Post(path string, data interface{}) (*http.Response, error)
// Post(path string) (http.Response, error)
} }
type ErrorResponse struct { type ErrorResponse struct {
......
...@@ -19,9 +19,24 @@ func TestClients(t *testing.T) { ...@@ -19,9 +19,24 @@ func TestClients(t *testing.T) {
{ {
Path: "/api/v4/internal/hello", Path: "/api/v4/internal/hello",
Handler: func(w http.ResponseWriter, r *http.Request) { Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
fmt.Fprint(w, "Hello") fmt.Fprint(w, "Hello")
}, },
}, },
{
Path: "/api/v4/internal/post_endpoint",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
b, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
require.NoError(t, err)
fmt.Fprint(w, "Echo: "+string(b))
},
},
{ {
Path: "/api/v4/internal/auth", Path: "/api/v4/internal/auth",
Handler: func(w http.ResponseWriter, r *http.Request) { Handler: func(w http.ResponseWriter, r *http.Request) {
...@@ -68,6 +83,7 @@ func TestClients(t *testing.T) { ...@@ -68,6 +83,7 @@ func TestClients(t *testing.T) {
testBrokenRequest(t, tc.client) testBrokenRequest(t, tc.client)
testSuccessfulGet(t, tc.client) testSuccessfulGet(t, tc.client)
testSuccessfulPost(t, tc.client)
testMissing(t, tc.client) testMissing(t, tc.client)
testErrorMessage(t, tc.client) testErrorMessage(t, tc.client)
testAuthenticationHeader(t, tc.client) testAuthenticationHeader(t, tc.client)
...@@ -89,32 +105,66 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) { ...@@ -89,32 +105,66 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) {
}) })
} }
func testSuccessfulPost(t *testing.T, client GitlabClient) {
t.Run("Successful Post", func(t *testing.T) {
data := map[string]string{"key": "value"}
response, err := client.Post("/post_endpoint", data)
defer response.Body.Close()
require.NoError(t, err)
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody))
})
}
func testMissing(t *testing.T, client GitlabClient) { func testMissing(t *testing.T, client GitlabClient) {
t.Run("Missing error", func(t *testing.T) { t.Run("Missing error for GET", func(t *testing.T) {
response, err := client.Get("/missing") response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)") assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response) assert.Nil(t, response)
}) })
t.Run("Missing error for POST", func(t *testing.T) {
response, err := client.Post("/missing", map[string]string{})
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
})
} }
func testErrorMessage(t *testing.T, client GitlabClient) { func testErrorMessage(t *testing.T, client GitlabClient) {
t.Run("Error with message", func(t *testing.T) { t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error") response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that") assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response) assert.Nil(t, response)
}) })
t.Run("Error with message for POST", func(t *testing.T) {
response, err := client.Post("/error", map[string]string{})
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
} }
func testBrokenRequest(t *testing.T, client GitlabClient) { func testBrokenRequest(t *testing.T, client GitlabClient) {
t.Run("Broken request", func(t *testing.T) { t.Run("Broken request for GET", func(t *testing.T) {
response, err := client.Get("/broken") response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable") assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response) assert.Nil(t, response)
}) })
t.Run("Broken request for POST", func(t *testing.T) {
response, err := client.Post("/broken", map[string]string{})
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
})
} }
func testAuthenticationHeader(t *testing.T, client GitlabClient) { func testAuthenticationHeader(t *testing.T, client GitlabClient) {
t.Run("Authentication headers", func(t *testing.T) { t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth") response, err := client.Get("/auth")
defer response.Body.Close() defer response.Body.Close()
...@@ -128,4 +178,19 @@ func testAuthenticationHeader(t *testing.T, client GitlabClient) { ...@@ -128,4 +178,19 @@ func testAuthenticationHeader(t *testing.T, client GitlabClient) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sssh, it's a secret", string(header)) assert.Equal(t, "sssh, it's a secret", string(header))
}) })
t.Run("Authentication headers for POST", func(t *testing.T) {
response, err := client.Post("/auth", map[string]string{})
defer response.Body.Close()
require.NoError(t, err)
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
require.NoError(t, err)
header, err := base64.StdEncoding.DecodeString(string(responseBody))
require.NoError(t, err)
assert.Equal(t, "sssh, it's a secret", string(header))
})
} }
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
) )
...@@ -30,6 +31,18 @@ func NewClient(config *config.Config) (*Client, error) { ...@@ -30,6 +31,18 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil return &Client{config: config, client: client}, nil
} }
func (c *Client) GetByCommandArgs(args *commandargs.CommandArgs) (*Response, error) {
if args.GitlabKeyId != "" {
return c.GetByKeyId(args.GitlabKeyId)
} else if args.GitlabUsername != "" {
return c.GetByUsername(args.GitlabUsername)
} else {
// There was no 'who' information, this matches the ruby error
// message.
return nil, fmt.Errorf("who='' is invalid")
}
}
func (c *Client) GetByKeyId(keyId string) (*Response, error) { func (c *Client) GetByKeyId(keyId string) (*Response, error) {
params := url.Values{} params := url.Values{}
params.Add("key_id", keyId) params.Add("key_id", keyId)
......
package gitlabnet package gitlabnet
import ( import (
"bytes"
"context" "context"
"encoding/json"
"net" "net"
"net/http" "net/http"
"strings" "strings"
...@@ -44,3 +46,21 @@ func (c *GitlabSocketClient) Get(path string) (*http.Response, error) { ...@@ -44,3 +46,21 @@ func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
return doRequest(c.httpClient, c.config, request) return doRequest(c.httpClient, c.config, request)
} }
func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) {
path = normalizePath(path)
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
}
request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData))
request.Header.Add("Content-Type", "application/json")
if err != nil {
return nil, err
}
return doRequest(c.httpClient, c.config, request)
}
package twofactorrecover
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
)
type Client struct {
config *config.Config
client gitlabnet.GitlabClient
}
type Response struct {
Success bool `json:"success"`
RecoveryCodes []string `json:"recovery_codes"`
Message string `json:"message"`
}
type RequestBody struct {
KeyId string `json:"key_id,omitempty"`
UserId int64 `json:"user_id,omitempty"`
}
func NewClient(config *config.Config) (*Client, error) {
client, err := gitlabnet.GetClient(config)
if err != nil {
return nil, fmt.Errorf("Error creating http client: %v", err)
}
return &Client{config: config, client: client}, nil
}
func (c *Client) GetRecoveryCodes(args *commandargs.CommandArgs) ([]string, error) {
requestBody, err := c.getRequestBody(args)
if err != nil {
return nil, err
}
response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
if err != nil {
return nil, err
}
defer response.Body.Close()
parsedResponse, err := c.parseResponse(response)
if err != nil {
return nil, fmt.Errorf("Parsing failed")
}
if parsedResponse.Success {
return parsedResponse.RecoveryCodes, nil
} else {
return nil, errors.New(parsedResponse.Message)
}
}
func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
parsedResponse := &Response{}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if err := json.Unmarshal(body, parsedResponse); err != nil {
return nil, err
} else {
return parsedResponse, nil
}
}
func (c *Client) getRequestBody(args *commandargs.CommandArgs) (*RequestBody, error) {
client, err := discover.NewClient(c.config)
if err != nil {
return nil, err
}
var requestBody *RequestBody
if args.GitlabKeyId != "" {
requestBody = &RequestBody{KeyId: args.GitlabKeyId}
} else {
userInfo, err := client.GetByCommandArgs(args)
if err != nil {
return nil, err
}
requestBody = &RequestBody{UserId: userInfo.UserId}
}
return requestBody, nil
}
package twofactorrecover
import (
"encoding/json"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
var (
testConfig *config.Config
requests []testserver.TestRequestHandler
)
func initialize(t *testing.T) {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/two_factor_recovery_codes",
Handler: func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
require.NoError(t, err)
var requestBody *RequestBody
json.Unmarshal(b, &requestBody)
switch requestBody.KeyId {
case "0":
body := map[string]interface{}{
"success": true,
"recovery_codes": [2]string{"recovery 1", "codes 1"},
}
json.NewEncoder(w).Encode(body)
case "1":
body := map[string]interface{}{
"success": false,
"message": "missing user",
}
json.NewEncoder(w).Encode(body)
case "2":
w.WriteHeader(http.StatusForbidden)
body := &gitlabnet.ErrorResponse{
Message: "Not allowed!",
}
json.NewEncoder(w).Encode(body)
case "3":
w.Write([]byte("{ \"message\": \"broken json!\""))
case "4":
w.WriteHeader(http.StatusForbidden)
}
if requestBody.UserId == 1 {
body := map[string]interface{}{
"success": true,
"recovery_codes": [2]string{"recovery 2", "codes 2"},
}
json.NewEncoder(w).Encode(body)
}
},
},
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
body := &discover.Response{
UserId: 1,
Username: "jane-doe",
Name: "Jane Doe",
}
json.NewEncoder(w).Encode(body)
},
},
}
}
func TestGetRecoveryCodesByKeyId(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
args := &commandargs.CommandArgs{GitlabKeyId: "0"}
result, err := client.GetRecoveryCodes(args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 1", "codes 1"}, result)
}
func TestGetRecoveryCodesByUsername(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
args := &commandargs.CommandArgs{GitlabUsername: "jane-doe"}
result, err := client.GetRecoveryCodes(args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 2", "codes 2"}, result)
}
func TestMissingUser(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
args := &commandargs.CommandArgs{GitlabKeyId: "1"}
_, err := client.GetRecoveryCodes(args)
assert.Equal(t, "missing user", err.Error())
}
func TestErrorResponses(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
testCases := []struct {
desc string
fakeId string
expectedError string
}{
{
desc: "A response with an error message",
fakeId: "2",
expectedError: "Not allowed!",
},
{
desc: "A response with bad JSON",
fakeId: "3",
expectedError: "Parsing failed",
},
{
desc: "An error response without message",
fakeId: "4",
expectedError: "Internal API error (403)",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.CommandArgs{GitlabKeyId: tc.fakeId}
resp, err := client.GetRecoveryCodes(args)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)
})
}
}
func setup(t *testing.T) (*Client, func()) {
initialize(t)
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
client, err := NewClient(testConfig)
require.NoError(t, err)
return client, cleanup
}
require_relative 'spec_helper' require_relative 'spec_helper'
describe 'bin/gitlab-shell-authorized-keys-check' do describe 'bin/gitlab-shell-authorized-keys-check' do
def original_root_path include_context 'gitlab shell'
ROOT_PATH
end
# All this test boilerplate is mostly copy/pasted between
# gitlab_shell_gitlab_shell_spec.rb and
# gitlab_shell_authorized_keys_check_spec.rb
def tmp_root_path
@tmp_root_path ||= File.realpath(Dir.mktmpdir)
end
def config_path
File.join(tmp_root_path, 'config.yml')
end
def tmp_socket_path def tmp_socket_path
# This has to be a relative path shorter than 100 bytes due to # This has to be a relative path shorter than 100 bytes due to
...@@ -22,12 +9,8 @@ describe 'bin/gitlab-shell-authorized-keys-check' do ...@@ -22,12 +9,8 @@ describe 'bin/gitlab-shell-authorized-keys-check' do
'tmp/gitlab-shell-authorized-keys-check-socket' 'tmp/gitlab-shell-authorized-keys-check-socket'
end end
before(:all) do def mock_server(server)
FileUtils.mkdir_p(File.dirname(tmp_socket_path)) server.mount_proc('/api/v4/internal/authorized_keys') do |req, res|
FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret'))
@server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
@server.mount_proc('/api/v4/internal/authorized_keys') do |req, res|
if req.query['key'] == 'known-rsa-key' if req.query['key'] == 'known-rsa-key'
res.status = 200 res.status = 200
res.content_type = 'application/json' res.content_type = 'application/json'
...@@ -36,28 +19,14 @@ describe 'bin/gitlab-shell-authorized-keys-check' do ...@@ -36,28 +19,14 @@ describe 'bin/gitlab-shell-authorized-keys-check' do
res.status = 404 res.status = 404
end end
end end
@webrick_thread = Thread.new { @server.start }
sleep(0.1) while @webrick_thread.alive? && @server.status != :Running
raise "Couldn't start stub GitlabNet server" unless @server.status == :Running
File.open(config_path, 'w') do |f|
f.write("---\ngitlab_url: http+unix://#{CGI.escape(tmp_socket_path)}\n")
end end
copy_dirs = ['bin', 'lib'] before(:all) do
FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) }) write_config(
FileUtils.cp_r(copy_dirs, tmp_root_path) "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}",
end )
after(:all) do
@server.shutdown if @server
@webrick_thread.join if @webrick_thread
FileUtils.rm_rf(tmp_root_path)
end end
let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') }
let(:authorized_keys_check_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell-authorized-keys-check') } let(:authorized_keys_check_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell-authorized-keys-check') }
it 'succeeds when a valid key is given' do it 'succeeds when a valid key is given' do
......
...@@ -3,33 +3,10 @@ require_relative 'spec_helper' ...@@ -3,33 +3,10 @@ require_relative 'spec_helper'
require 'open3' require 'open3'
describe 'bin/gitlab-shell' do describe 'bin/gitlab-shell' do
def original_root_path include_context 'gitlab shell'
ROOT_PATH
end
# All this test boilerplate is mostly copy/pasted between
# gitlab_shell_gitlab_shell_spec.rb and
# gitlab_shell_authorized_keys_check_spec.rb
def tmp_root_path
@tmp_root_path ||= File.realpath(Dir.mktmpdir)
end
def config_path
File.join(tmp_root_path, 'config.yml')
end
def tmp_socket_path
# This has to be a relative path shorter than 100 bytes due to
# limitations in how Unix sockets work.
'tmp/gitlab-shell-socket'
end
before(:all) do
FileUtils.mkdir_p(File.dirname(tmp_socket_path))
FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret'))
@server = HTTPUNIXServer.new(BindAddress: tmp_socket_path) def mock_server(server)
@server.mount_proc('/api/v4/internal/discover') do |req, res| server.mount_proc('/api/v4/internal/discover') do |req, res|
identifier = req.query['key_id'] || req.query['username'] || req.query['user_id'] identifier = req.query['key_id'] || req.query['username'] || req.query['user_id']
known_identifiers = %w(10 someuser 100) known_identifiers = %w(10 someuser 100)
if known_identifiers.include?(identifier) if known_identifiers.include?(identifier)
...@@ -47,24 +24,16 @@ describe 'bin/gitlab-shell' do ...@@ -47,24 +24,16 @@ describe 'bin/gitlab-shell' do
res.status = 500 res.status = 500
end end
end end
@webrick_thread = Thread.new { @server.start }
sleep(0.1) while @webrick_thread.alive? && @server.status != :Running
raise "Couldn't start stub GitlabNet server" unless @server.status == :Running
system(original_root_path, 'bin/compile')
copy_dirs = ['bin', 'lib']
FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) })
FileUtils.cp_r(copy_dirs, tmp_root_path)
end end
after(:all) do def run!(args, env: {'SSH_CONNECTION' => 'fake'})
@server.shutdown if @server cmd = [
@webrick_thread.join if @webrick_thread gitlab_shell_path,
FileUtils.rm_rf(tmp_root_path) args
end ].flatten.compact.join(' ')
let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') } Open3.capture3(env, cmd)
end
shared_examples 'results with keys' do shared_examples 'results with keys' do
# Basic valid input # Basic valid input
...@@ -175,19 +144,4 @@ describe 'bin/gitlab-shell' do ...@@ -175,19 +144,4 @@ describe 'bin/gitlab-shell' do
expect(status).not_to be_success expect(status).not_to be_success
end end
end end
def run!(args, env: {'SSH_CONNECTION' => 'fake'})
cmd = [
gitlab_shell_path,
args
].flatten.compact.join(' ')
Open3.capture3(env, cmd)
end
def write_config(config)
File.open(config_path, 'w') do |f|
f.write(config.to_yaml)
end
end
end end
require_relative 'spec_helper'
require 'open3'
describe 'bin/gitlab-shell 2fa_recovery_codes' do
include_context 'gitlab shell'
def mock_server(server)
server.mount_proc('/api/v4/internal/two_factor_recovery_codes') do |req, res|
res.content_type = 'application/json'
res.status = 200
key_id = req.query['key_id'] || req.query['user_id']
unless key_id
body = JSON.parse(req.body)
key_id = body['key_id'] || body['user_id'].to_s
end
if key_id == '100'
res.body = '{"success":true, "recovery_codes": ["1", "2"]}'
else
res.body = '{"success":false, "message": "Forbidden!"}'
end
end
server.mount_proc('/api/v4/internal/discover') do |req, res|
res.status = 200
res.content_type = 'application/json'
res.body = '{"id":100, "name": "Some User", "username": "someuser"}'
end
end
shared_examples 'dialog for regenerating recovery keys' do
context 'when the user agrees to regenerate keys' do
def verify_successful_regeneration!(cmd)
Open3.popen2(env, cmd) do |stdin, stdout|
expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n")
expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n")
stdin.puts('yes')
expect(stdout.flush.read).to eq(
"\nYour two-factor authentication recovery codes are:\n\n" \
"1\n2\n\n" \
"During sign in, use one of the codes above when prompted for\n" \
"your two-factor code. Then, visit your Profile Settings and add\n" \
"a new device so you do not lose access to your account again.\n"
)
end
end
context 'when key is provided' do
let(:cmd) { "#{gitlab_shell_path} key-100" }
it 'the recovery keys are regenerated' do
verify_successful_regeneration!(cmd)
end
end
context 'when username is provided' do
let(:cmd) { "#{gitlab_shell_path} username-someone" }
it 'the recovery keys are regenerated' do
verify_successful_regeneration!(cmd)
end
end
end
context 'when the user disagrees to regenerate keys' do
let(:cmd) { "#{gitlab_shell_path} key-100" }
it 'the recovery keys are not regenerated' do
Open3.popen2(env, cmd) do |stdin, stdout|
expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n")
expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n")
stdin.puts('no')
expect(stdout.flush.read).to eq(
"\nNew recovery codes have *not* been generated. Existing codes will remain valid.\n"
)
end
end
end
context 'when API error occurs' do
let(:cmd) { "#{gitlab_shell_path} key-101" }
context 'when the user agrees to regenerate keys' do
it 'the recovery keys are regenerated' do
Open3.popen2(env, cmd) do |stdin, stdout|
expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n")
expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n")
stdin.puts('yes')
expect(stdout.flush.read).to eq("\nAn error occurred while trying to generate new recovery codes.\nForbidden!\n")
end
end
end
end
end
let(:env) { {'SSH_CONNECTION' => 'fake', 'SSH_ORIGINAL_COMMAND' => '2fa_recovery_codes' } }
describe 'without go features' do
before(:context) do
write_config(
"gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}",
)
end
it_behaves_like 'dialog for regenerating recovery keys'
end
describe 'with go features' do
before(:context) do
write_config(
"gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}",
"migration" => { "enabled" => true,
"features" => ["2fa_recovery_codes"] }
)
end
it_behaves_like 'dialog for regenerating recovery keys'
end
end
RSpec.shared_context 'gitlab shell', shared_context: :metadata do
def original_root_path
ROOT_PATH
end
def config_path
File.join(tmp_root_path, 'config.yml')
end
def write_config(config)
File.open(config_path, 'w') do |f|
f.write(config.to_yaml)
end
end
def tmp_root_path
@tmp_root_path ||= File.realpath(Dir.mktmpdir)
end
def mock_server(server)
raise NotImplementedError.new(
'mock_server method must be implemented in order to include gitlab shell context'
)
end
# This has to be a relative path shorter than 100 bytes due to
# limitations in how Unix sockets work.
def tmp_socket_path
'tmp/gitlab-shell-socket'
end
let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') }
before(:all) do
FileUtils.mkdir_p(File.dirname(tmp_socket_path))
FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret'))
@server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
mock_server(@server)
@webrick_thread = Thread.new { @server.start }
sleep(0.1) while @webrick_thread.alive? && @server.status != :Running
raise "Couldn't start stub GitlabNet server" unless @server.status == :Running
system(original_root_path, 'bin/compile')
copy_dirs = ['bin', 'lib']
FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) })
FileUtils.cp_r(copy_dirs, tmp_root_path)
end
after(:all) do
@server.shutdown if @server
@webrick_thread.join if @webrick_thread
FileUtils.rm_rf(tmp_root_path)
end
end
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment