Commit fb57134d authored by Dave Cheney's avatar Dave Cheney Committed by Adam Langley

exp/ssh: alter Session to match the exec.Cmd API

This CL inverts the direction of the Stdin/out/err members of the
Session struct so they reflect the API of the exec.Cmd. In doing so
it borrows heavily from the exec package.

Additionally Shell now returns immediately, wait for completion using
Wait. Exec calls Wait internally and so blocks until the remote
command is complete.

Credit to Gustavo Niemeyer for the impetus for this CL.

R=rsc, agl, n13m3y3r, huin, bradfitz
CC=cw, golang-dev
https://golang.org/cl/5322055
parent 05d8d112
...@@ -342,17 +342,6 @@ func (c *clientChan) Close() error { ...@@ -342,17 +342,6 @@ func (c *clientChan) Close() error {
})) }))
} }
func (c *clientChan) sendChanReq(req channelRequestMsg) error {
if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
return err
}
msg := <-c.msg
if _, ok := msg.(*channelRequestSuccessMsg); ok {
return nil
}
return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
}
// Thread safe channel list. // Thread safe channel list.
type chanlist struct { type chanlist struct {
// protects concurrent access to chans // protects concurrent access to chans
......
...@@ -8,66 +8,104 @@ package ssh ...@@ -8,66 +8,104 @@ package ssh
// "RFC 4254, section 6". // "RFC 4254, section 6".
import ( import (
"encoding/binary" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil"
) )
// A Session represents a connection to a remote command or shell. // A Session represents a connection to a remote command or shell.
type Session struct { type Session struct {
// Writes to Stdin are made available to the remote command's standard input. // Stdin specifies the remote process's standard input.
// Closing Stdin causes the command to observe an EOF on its standard input. // If Stdin is nil, the remote process reads from an empty
Stdin io.WriteCloser // bytes.Buffer.
Stdin io.Reader
// Reads from Stdout and Stderr consume from the remote command's standard
// output and error streams, respectively. // Stdout and Stderr specify the remote process's standard
// There is a fixed amount of buffering that is shared for the two streams. // output and error.
// Failing to read from either may eventually cause the command to block. //
// Closing Stdout unblocks such writes and causes them to return errors. // If either is nil, Run connects the corresponding file
Stdout io.ReadCloser // descriptor to an instance of ioutil.Discard. There is a
Stderr io.Reader // fixed amount of buffering that is shared for the two streams.
// If either blocks it may eventually cause the remote
// command to block.
Stdout io.Writer
Stderr io.Writer
*clientChan // the channel backing this session *clientChan // the channel backing this session
started bool // started is set to true once a Shell or Exec is invoked. started bool // true once a Shell or Exec is invoked.
copyFuncs []func() error
errch chan error // one send per copyFunc
}
// RFC 4254 Section 6.4.
type setenvRequest struct {
PeersId uint32
Request string
WantReply bool
Name string
Value string
} }
// Setenv sets an environment variable that will be applied to any // Setenv sets an environment variable that will be applied to any
// command executed by Shell or Exec. // command executed by Shell or Exec.
func (s *Session) Setenv(name, value string) error { func (s *Session) Setenv(name, value string) error {
n, v := []byte(name), []byte(value) req := setenvRequest{
nlen, vlen := stringLength(n), stringLength(v) PeersId: s.id,
payload := make([]byte, nlen+vlen) Request: "env",
marshalString(payload[:nlen], n) WantReply: true,
marshalString(payload[nlen:], v) Name: name,
Value: value,
return s.sendChanReq(channelRequestMsg{ }
PeersId: s.id, if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
Request: "env", return err
WantReply: true, }
RequestSpecificData: payload, return s.waitForResponse()
})
} }
// An empty mode list (a string of 1 character, opcode 0), see RFC 4254 Section 8. // An empty mode list, see RFC 4254 Section 8.
var emptyModeList = []byte{0, 0, 0, 1, 0} var emptyModelist = "\x00"
// RFC 4254 Section 6.2.
type ptyRequestMsg struct {
PeersId uint32
Request string
WantReply bool
Term string
Columns uint32
Rows uint32
Width uint32
Height uint32
Modelist string
}
// RequestPty requests the association of a pty with the session on the remote host. // RequestPty requests the association of a pty with the session on the remote host.
func (s *Session) RequestPty(term string, h, w int) error { func (s *Session) RequestPty(term string, h, w int) error {
buf := make([]byte, 4+len(term)+16+len(emptyModeList)) req := ptyRequestMsg{
b := marshalString(buf, []byte(term)) PeersId: s.id,
binary.BigEndian.PutUint32(b, uint32(h)) Request: "pty-req",
binary.BigEndian.PutUint32(b[4:], uint32(w)) WantReply: true,
binary.BigEndian.PutUint32(b[8:], uint32(h*8)) Term: term,
binary.BigEndian.PutUint32(b[12:], uint32(w*8)) Columns: uint32(w),
copy(b[16:], emptyModeList) Rows: uint32(h),
Width: uint32(w * 8),
return s.sendChanReq(channelRequestMsg{ Height: uint32(h * 8),
PeersId: s.id, Modelist: emptyModelist,
Request: "pty-req", }
WantReply: true, if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
RequestSpecificData: buf, return err
}) }
return s.waitForResponse()
}
// RFC 4254 Section 6.5.
type execMsg struct {
PeersId uint32
Request string
WantReply bool
Command string
} }
// Exec runs cmd on the remote host. Typically, the remote // Exec runs cmd on the remote host. Typically, the remote
...@@ -75,34 +113,166 @@ func (s *Session) RequestPty(term string, h, w int) error { ...@@ -75,34 +113,166 @@ func (s *Session) RequestPty(term string, h, w int) error {
// A Session only accepts one call to Exec or Shell. // A Session only accepts one call to Exec or Shell.
func (s *Session) Exec(cmd string) error { func (s *Session) Exec(cmd string) error {
if s.started { if s.started {
return errors.New("session already started") return errors.New("ssh: session already started")
} }
cmdLen := stringLength([]byte(cmd)) req := execMsg{
payload := make([]byte, cmdLen) PeersId: s.id,
marshalString(payload, []byte(cmd)) Request: "exec",
s.started = true WantReply: true,
Command: cmd,
return s.sendChanReq(channelRequestMsg{ }
PeersId: s.id, if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
Request: "exec", return err
WantReply: true, }
RequestSpecificData: payload, if err := s.waitForResponse(); err != nil {
}) return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err)
}
if err := s.start(); err != nil {
return err
}
return s.Wait()
} }
// Shell starts a login shell on the remote host. A Session only // Shell starts a login shell on the remote host. A Session only
// accepts one call to Exec or Shell. // accepts one call to Exec or Shell.
func (s *Session) Shell() error { func (s *Session) Shell() error {
if s.started { if s.started {
return errors.New("session already started") return errors.New("ssh: session already started")
} }
s.started = true req := channelRequestMsg{
return s.sendChanReq(channelRequestMsg{
PeersId: s.id, PeersId: s.id,
Request: "shell", Request: "shell",
WantReply: true, WantReply: true,
}
if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
return err
}
if err := s.waitForResponse(); err != nil {
return fmt.Errorf("ssh: cound not execute shell: %v", err)
}
return s.start()
}
func (s *Session) waitForResponse() error {
msg := <-s.msg
switch msg.(type) {
case *channelRequestSuccessMsg:
return nil
case *channelRequestFailureMsg:
return errors.New("request failed")
}
return fmt.Errorf("unknown packet %T received: %v", msg, msg)
}
func (s *Session) start() error {
s.started = true
type F func(*Session) error
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} {
if err := setupFd(s); err != nil {
return err
}
}
s.errch = make(chan error, len(s.copyFuncs))
for _, fn := range s.copyFuncs {
go func(fn func() error) {
s.errch <- fn()
}(fn)
}
return nil
}
// Wait waits for the remote command to exit.
func (s *Session) Wait() error {
if !s.started {
return errors.New("ssh: session not started")
}
waitErr := s.wait()
var copyError error
for _ = range s.copyFuncs {
if err := <-s.errch; err != nil && copyError == nil {
copyError = err
}
}
if waitErr != nil {
return waitErr
}
return copyError
}
func (s *Session) wait() error {
for {
switch msg := (<-s.msg).(type) {
case *channelRequestMsg:
// TODO(dfc) improve this behavior to match os.Waitmsg
switch msg.Request {
case "exit-status":
d := msg.RequestSpecificData
status := int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
if status > 0 {
return fmt.Errorf("remote process exited with %d", status)
}
return nil
case "exit-signal":
// TODO(dfc) make a more readable error message
return fmt.Errorf("%v", msg.RequestSpecificData)
default:
return fmt.Errorf("wait: unexpected channel request: %v", msg)
}
default:
return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg)
}
}
panic("unreachable")
}
func (s *Session) stdin() error {
if s.Stdin == nil {
s.Stdin = new(bytes.Buffer)
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(&chanWriter{
packetWriter: s,
id: s.id,
win: s.win,
}, s.Stdin)
return err
})
return nil
}
func (s *Session) stdout() error {
if s.Stdout == nil {
s.Stdout = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stdout, &chanReader{
packetWriter: s,
id: s.id,
data: s.data,
})
return err
})
return nil
}
func (s *Session) stderr() error {
if s.Stderr == nil {
s.Stderr = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stderr, &chanReader{
packetWriter: s,
id: s.id,
data: s.dataExt,
})
return err
}) })
return nil
} }
// NewSession returns a new interactive session on the remote host. // NewSession returns a new interactive session on the remote host.
...@@ -112,21 +282,6 @@ func (c *ClientConn) NewSession() (*Session, error) { ...@@ -112,21 +282,6 @@ func (c *ClientConn) NewSession() (*Session, error) {
return nil, err return nil, err
} }
return &Session{ return &Session{
Stdin: &chanWriter{
packetWriter: ch,
id: ch.id,
win: ch.win,
},
Stdout: &chanReader{
packetWriter: ch,
id: ch.id,
data: ch.data,
},
Stderr: &chanReader{
packetWriter: ch,
id: ch.id,
data: ch.dataExt,
},
clientChan: ch, clientChan: ch,
}, nil }, nil
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment