Commit bbbd41f4 authored by Dave Cheney's avatar Dave Cheney Committed by Adam Langley

exp/ssh: simplify client channel open logic

This is part one of a small set of CL's that aim to resolve
the outstanding TODOs relating to channel close and blocking
behavior.

Firstly, the hairy handling of assigning the peersId is now
done in one place. The cost of this change is the slightly
paradoxical construction of the partially created clientChan.

Secondly, by creating clientChan.stdin/out/err when the channel
is opened, the creation of consumers like tcpchan and Session
is simplified; they just have to wire themselves up to the
relevant readers/writers.

R=agl, gustav.paul, rsc
CC=golang-dev
https://golang.org/cl/5448073
parent 0a5508c6
...@@ -200,7 +200,7 @@ func (c *ClientConn) mainLoop() { ...@@ -200,7 +200,7 @@ func (c *ClientConn) mainLoop() {
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 { if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
packet = packet[9:] packet = packet[9:]
c.getChan(peersId).data <- packet[:length] c.getChan(peersId).stdout.data <- packet[:length]
} }
case msgChannelExtendedData: case msgChannelExtendedData:
if len(packet) < 13 { if len(packet) < 13 {
...@@ -215,7 +215,7 @@ func (c *ClientConn) mainLoop() { ...@@ -215,7 +215,7 @@ func (c *ClientConn) mainLoop() {
// for stderr on interactive sessions. Other data types are // for stderr on interactive sessions. Other data types are
// silently discarded. // silently discarded.
if datatype == 1 { if datatype == 1 {
c.getChan(peersId).dataExt <- packet[:length] c.getChan(peersId).stderr.data <- packet[:length]
} }
} }
default: default:
...@@ -228,9 +228,9 @@ func (c *ClientConn) mainLoop() { ...@@ -228,9 +228,9 @@ func (c *ClientConn) mainLoop() {
c.getChan(msg.PeersId).msg <- msg c.getChan(msg.PeersId).msg <- msg
case *channelCloseMsg: case *channelCloseMsg:
ch := c.getChan(msg.PeersId) ch := c.getChan(msg.PeersId)
close(ch.win) close(ch.stdin.win)
close(ch.data) close(ch.stdout.data)
close(ch.dataExt) close(ch.stderr.data)
c.chanlist.remove(msg.PeersId) c.chanlist.remove(msg.PeersId)
case *channelEOFMsg: case *channelEOFMsg:
c.getChan(msg.PeersId).msg <- msg c.getChan(msg.PeersId).msg <- msg
...@@ -241,7 +241,7 @@ func (c *ClientConn) mainLoop() { ...@@ -241,7 +241,7 @@ func (c *ClientConn) mainLoop() {
case *channelRequestMsg: case *channelRequestMsg:
c.getChan(msg.PeersId).msg <- msg c.getChan(msg.PeersId).msg <- msg
case *windowAdjustMsg: case *windowAdjustMsg:
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes) c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
default: default:
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg) fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
} }
...@@ -290,21 +290,49 @@ func (c *ClientConfig) rand() io.Reader { ...@@ -290,21 +290,49 @@ func (c *ClientConfig) rand() io.Reader {
type clientChan struct { type clientChan struct {
packetWriter packetWriter
id, peersId uint32 id, peersId uint32
data chan []byte // receives the payload of channelData messages stdin *chanWriter // receives window adjustments
dataExt chan []byte // receives the payload of channelExtendedData messages stdout *chanReader // receives the payload of channelData messages
win chan int // receives window adjustments stderr *chanReader // receives the payload of channelExtendedData messages
msg chan interface{} // incoming messages msg chan interface{} // incoming messages
} }
// newClientChan returns a partially constructed *clientChan
// using the local id provided. To be usable clientChan.peersId
// needs to be assigned once known.
func newClientChan(t *transport, id uint32) *clientChan { func newClientChan(t *transport, id uint32) *clientChan {
return &clientChan{ c := &clientChan{
packetWriter: t, packetWriter: t,
id: id, id: id,
data: make(chan []byte, 16),
dataExt: make(chan []byte, 16),
win: make(chan int, 16),
msg: make(chan interface{}, 16), msg: make(chan interface{}, 16),
} }
c.stdin = &chanWriter{
win: make(chan int, 16),
clientChan: c,
}
c.stdout = &chanReader{
data: make(chan []byte, 16),
clientChan: c,
}
c.stderr = &chanReader{
data: make(chan []byte, 16),
clientChan: c,
}
return c
}
// waitForChannelOpenResponse, if successful, fills out
// the peerId and records any initial window advertisement.
func (c *clientChan) waitForChannelOpenResponse() error {
switch msg := (<-c.msg).(type) {
case *channelOpenConfirmMsg:
// fixup peersId field
c.peersId = msg.MyId
c.stdin.win <- int(msg.MyWindow)
return nil
case *channelOpenFailureMsg:
return errors.New(safeString(msg.Message))
}
return errors.New("unexpected packet")
} }
// Close closes the channel. This does not close the underlying connection. // Close closes the channel. This does not close the underlying connection.
...@@ -355,10 +383,9 @@ func (c *chanlist) remove(id uint32) { ...@@ -355,10 +383,9 @@ func (c *chanlist) remove(id uint32) {
// A chanWriter represents the stdin of a remote process. // A chanWriter represents the stdin of a remote process.
type chanWriter struct { type chanWriter struct {
win chan int // receives window adjustments win chan int // receives window adjustments
peersId uint32 // the peer's id rwin int // current rwin size
rwin int // current rwin size clientChan *clientChan // the channel backing this writer
packetWriter // for sending channelDataMsg
} }
// Write writes data to the remote process's standard input. // Write writes data to the remote process's standard input.
...@@ -372,12 +399,13 @@ func (w *chanWriter) Write(data []byte) (n int, err error) { ...@@ -372,12 +399,13 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
w.rwin += win w.rwin += win
continue continue
} }
peersId := w.clientChan.peersId
n = len(data) n = len(data)
packet := make([]byte, 0, 9+n) packet := make([]byte, 0, 9+n)
packet = append(packet, msgChannelData, packet = append(packet, msgChannelData,
byte(w.peersId>>24), byte(w.peersId>>16), byte(w.peersId>>8), byte(w.peersId), byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId),
byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
err = w.writePacket(append(packet, data...)) err = w.clientChan.writePacket(append(packet, data...))
w.rwin -= n w.rwin -= n
return return
} }
...@@ -385,7 +413,7 @@ func (w *chanWriter) Write(data []byte) (n int, err error) { ...@@ -385,7 +413,7 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
} }
func (w *chanWriter) Close() error { func (w *chanWriter) Close() error {
return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.peersId})) return w.clientChan.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.clientChan.peersId}))
} }
// A chanReader represents stdout or stderr of a remote process. // A chanReader represents stdout or stderr of a remote process.
...@@ -393,10 +421,9 @@ type chanReader struct { ...@@ -393,10 +421,9 @@ type chanReader struct {
// TODO(dfc) a fixed size channel may not be the right data structure. // TODO(dfc) a fixed size channel may not be the right data structure.
// If writes to this channel block, they will block mainLoop, making // If writes to this channel block, they will block mainLoop, making
// it unable to receive new messages from the remote side. // it unable to receive new messages from the remote side.
data chan []byte // receives data from remote data chan []byte // receives data from remote
peersId uint32 // the peer's id clientChan *clientChan // the channel backing this reader
packetWriter // for sending windowAdjustMsg buf []byte
buf []byte
} }
// Read reads data from the remote process's stdout or stderr. // Read reads data from the remote process's stdout or stderr.
...@@ -407,10 +434,10 @@ func (r *chanReader) Read(data []byte) (int, error) { ...@@ -407,10 +434,10 @@ func (r *chanReader) Read(data []byte) (int, error) {
n := copy(data, r.buf) n := copy(data, r.buf)
r.buf = r.buf[n:] r.buf = r.buf[n:]
msg := windowAdjustMsg{ msg := windowAdjustMsg{
PeersId: r.peersId, PeersId: r.clientChan.peersId,
AdditionalBytes: uint32(n), AdditionalBytes: uint32(n),
} }
return n, r.writePacket(marshal(msgChannelWindowAdjust, msg)) return n, r.clientChan.writePacket(marshal(msgChannelWindowAdjust, msg))
} }
r.buf, ok = <-r.data r.buf, ok = <-r.data
if !ok { if !ok {
......
...@@ -285,13 +285,8 @@ func (s *Session) stdin() error { ...@@ -285,13 +285,8 @@ func (s *Session) stdin() error {
s.Stdin = new(bytes.Buffer) s.Stdin = new(bytes.Buffer)
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
w := &chanWriter{ _, err := io.Copy(s.clientChan.stdin, s.Stdin)
packetWriter: s, if err1 := s.clientChan.stdin.Close(); err == nil {
peersId: s.peersId,
win: s.win,
}
_, err := io.Copy(w, s.Stdin)
if err1 := w.Close(); err == nil {
err = err1 err = err1
} }
return err return err
...@@ -304,12 +299,7 @@ func (s *Session) stdout() error { ...@@ -304,12 +299,7 @@ func (s *Session) stdout() error {
s.Stdout = ioutil.Discard s.Stdout = ioutil.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
r := &chanReader{ _, err := io.Copy(s.Stdout, s.clientChan.stdout)
packetWriter: s,
peersId: s.peersId,
data: s.data,
}
_, err := io.Copy(s.Stdout, r)
return err return err
}) })
return nil return nil
...@@ -320,12 +310,7 @@ func (s *Session) stderr() error { ...@@ -320,12 +310,7 @@ func (s *Session) stderr() error {
s.Stderr = ioutil.Discard s.Stderr = ioutil.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
r := &chanReader{ _, err := io.Copy(s.Stderr, s.clientChan.stderr)
packetWriter: s,
peersId: s.peersId,
data: s.dataExt,
}
_, err := io.Copy(s.Stderr, r)
return err return err
}) })
return nil return nil
...@@ -398,19 +383,11 @@ func (c *ClientConn) NewSession() (*Session, error) { ...@@ -398,19 +383,11 @@ func (c *ClientConn) NewSession() (*Session, error) {
c.chanlist.remove(ch.id) c.chanlist.remove(ch.id)
return nil, err return nil, err
} }
// wait for response if err := ch.waitForChannelOpenResponse(); err != nil {
msg := <-ch.msg
switch msg := msg.(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
return &Session{
clientChan: ch,
}, nil
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id) c.chanlist.remove(ch.id)
return nil, fmt.Errorf("ssh: channel open failed: %s", msg.Message) return nil, fmt.Errorf("ssh: unable to open session: %v", err)
} }
c.chanlist.remove(ch.id) return &Session{
return nil, fmt.Errorf("ssh: unexpected message %T: %v", msg, msg) clientChan: ch,
}, nil
} }
...@@ -6,6 +6,7 @@ package ssh ...@@ -6,6 +6,7 @@ package ssh
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
) )
...@@ -42,20 +43,21 @@ func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, err ...@@ -42,20 +43,21 @@ func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, err
}, nil }, nil
} }
// RFC 4254 7.2
type channelOpenDirectMsg struct {
ChanType string
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
raddr string
rport uint32
laddr string
lport uint32
}
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as // dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
// strings and are expected to be resolveable at the remote end. // strings and are expected to be resolveable at the remote end.
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) { func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
// RFC 4254 7.2
type channelOpenDirectMsg struct {
ChanType string
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
raddr string
rport uint32
laddr string
lport uint32
}
ch := c.newChan(c.transport) ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{ if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
ChanType: "direct-tcpip", ChanType: "direct-tcpip",
...@@ -70,30 +72,14 @@ func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tc ...@@ -70,30 +72,14 @@ func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tc
c.chanlist.remove(ch.id) c.chanlist.remove(ch.id)
return nil, err return nil, err
} }
// wait for response if err := ch.waitForChannelOpenResponse(); err != nil {
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
default:
c.chanlist.remove(ch.id) c.chanlist.remove(ch.id)
return nil, errors.New("ssh: unexpected packet") return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err)
} }
return &tcpchan{ return &tcpchan{
clientChan: ch, clientChan: ch,
Reader: &chanReader{ Reader: ch.stdout,
packetWriter: ch, Writer: ch.stdin,
peersId: ch.peersId,
data: ch.data,
},
Writer: &chanWriter{
packetWriter: ch,
peersId: ch.peersId,
win: ch.win,
},
}, nil }, nil
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment