Update go dependencies

This commit is contained in:
Manuel de Brito Fontes 2018-05-26 11:27:53 -04:00 committed by Manuel Alejandro de Brito Fontes
parent 15ffb51394
commit bb4d483837
No known key found for this signature in database
GPG key ID: 786136016A8BA02A
1621 changed files with 86368 additions and 284392 deletions

View file

@ -1,122 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"io"
"net"
"testing"
)
type server struct {
*ServerConn
chans <-chan NewChannel
}
func newServer(c net.Conn, conf *ServerConfig) (*server, error) {
sconn, chans, reqs, err := NewServerConn(c, conf)
if err != nil {
return nil, err
}
go DiscardRequests(reqs)
return &server{sconn, chans}, nil
}
func (s *server) Accept() (NewChannel, error) {
n, ok := <-s.chans
if !ok {
return nil, io.EOF
}
return n, nil
}
func sshPipe() (Conn, *server, error) {
c1, c2, err := netPipe()
if err != nil {
return nil, nil, err
}
clientConf := ClientConfig{
User: "user",
}
serverConf := ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["ecdsa"])
done := make(chan *server, 1)
go func() {
server, err := newServer(c2, &serverConf)
if err != nil {
done <- nil
}
done <- server
}()
client, _, reqs, err := NewClientConn(c1, "", &clientConf)
if err != nil {
return nil, nil, err
}
server := <-done
if server == nil {
return nil, nil, errors.New("server handshake failed.")
}
go DiscardRequests(reqs)
return client, server, nil
}
func BenchmarkEndToEnd(b *testing.B) {
b.StopTimer()
client, server, err := sshPipe()
if err != nil {
b.Fatalf("sshPipe: %v", err)
}
defer client.Close()
defer server.Close()
size := (1 << 20)
input := make([]byte, size)
output := make([]byte, size)
b.SetBytes(int64(size))
done := make(chan int, 1)
go func() {
newCh, err := server.Accept()
if err != nil {
b.Fatalf("Client: %v", err)
}
ch, incoming, err := newCh.Accept()
go DiscardRequests(incoming)
for i := 0; i < b.N; i++ {
if _, err := io.ReadFull(ch, output); err != nil {
b.Fatalf("ReadFull: %v", err)
}
}
ch.Close()
done <- 1
}()
ch, in, err := client.OpenChannel("speed", nil)
if err != nil {
b.Fatalf("OpenChannel: %v", err)
}
go DiscardRequests(in)
b.ResetTimer()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := ch.Write(input); err != nil {
b.Fatalf("WriteFull: %v", err)
}
}
ch.Close()
b.StopTimer()
<-done
}

View file

@ -1,87 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"testing"
)
var alphabet = []byte("abcdefghijklmnopqrstuvwxyz")
func TestBufferReadwrite(t *testing.T) {
b := newBuffer()
b.write(alphabet[:10])
r, _ := b.Read(make([]byte, 10))
if r != 10 {
t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
}
b = newBuffer()
b.write(alphabet[:5])
r, _ = b.Read(make([]byte, 10))
if r != 5 {
t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
}
b = newBuffer()
b.write(alphabet[:10])
r, _ = b.Read(make([]byte, 5))
if r != 5 {
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
}
b = newBuffer()
b.write(alphabet[:5])
b.write(alphabet[5:15])
r, _ = b.Read(make([]byte, 10))
r2, _ := b.Read(make([]byte, 10))
if r != 10 || r2 != 5 || 15 != r+r2 {
t.Fatal("Expected written == read == 15")
}
}
func TestBufferClose(t *testing.T) {
b := newBuffer()
b.write(alphabet[:10])
b.eof()
_, err := b.Read(make([]byte, 5))
if err != nil {
t.Fatal("expected read of 5 to not return EOF")
}
b = newBuffer()
b.write(alphabet[:10])
b.eof()
r, err := b.Read(make([]byte, 5))
r2, err2 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || err != nil || err2 != nil {
t.Fatal("expected reads of 5 and 5")
}
b = newBuffer()
b.write(alphabet[:10])
b.eof()
r, err = b.Read(make([]byte, 5))
r2, err2 = b.Read(make([]byte, 10))
r3, err3 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF {
t.Fatal("expected reads of 5 and 5 and 0, with EOF")
}
b = newBuffer()
b.write(make([]byte, 5))
b.write(make([]byte, 10))
b.eof()
r, err = b.Read(make([]byte, 9))
r2, err2 = b.Read(make([]byte, 3))
r3, err3 = b.Read(make([]byte, 3))
r4, err4 := b.Read(make([]byte, 10))
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF {
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4)
}
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 {
t.Fatal("Expected written == read == 15", r, r2, r3, r4)
}
}

View file

@ -44,7 +44,9 @@ type Signature struct {
const CertTimeInfinity = 1<<64 - 1
// An Certificate represents an OpenSSH certificate as defined in
// [PROTOCOL.certkeys]?rev=1.8.
// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the
// PublicKey interface, so it can be unmarshaled using
// ParsePublicKey.
type Certificate struct {
Nonce []byte
Key PublicKey
@ -340,10 +342,10 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
// the signature of the certificate.
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if c.IsRevoked != nil && c.IsRevoked(cert) {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
}
for opt, _ := range cert.CriticalOptions {
for opt := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by
// serverAuthenticate
if opt == sourceAddressCriticalOption {

View file

@ -1,222 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"reflect"
"testing"
"time"
)
// Cert generated by ssh-keygen 6.0p1 Debian-4.
// % ssh-keygen -s ca-key -I test user-key
const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
func TestParseCert(t *testing.T) {
authKeyBytes := []byte(exampleSSHCert)
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
if _, ok := key.(*Certificate); !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
marshaled := MarshalAuthorizedKey(key)
// Before comparison, remove the trailing newline that
// MarshalAuthorizedKey adds.
marshaled = marshaled[:len(marshaled)-1]
if !bytes.Equal(authKeyBytes, marshaled) {
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
}
}
// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3
// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub
// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN
// Critical Options:
// force-command /bin/sleep
// source-address 192.168.1.0/24
// Extensions:
// permit-X11-forwarding
// permit-agent-forwarding
// permit-port-forwarding
// permit-pty
// permit-user-rc
const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ`
func TestParseCertWithOptions(t *testing.T) {
opts := map[string]string{
"source-address": "192.168.1.0/24",
"force-command": "/bin/sleep",
}
exts := map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
}
authKeyBytes := []byte(exampleSSHCertWithOptions)
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
cert, ok := key.(*Certificate)
if !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
if !reflect.DeepEqual(cert.CriticalOptions, opts) {
t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts)
}
if !reflect.DeepEqual(cert.Extensions, exts) {
t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts)
}
marshaled := MarshalAuthorizedKey(key)
// Before comparison, remove the trailing newline that
// MarshalAuthorizedKey adds.
marshaled = marshaled[:len(marshaled)-1]
if !bytes.Equal(authKeyBytes, marshaled) {
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
}
}
func TestValidateCert(t *testing.T) {
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
validCert, ok := key.(*Certificate)
if !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
checker := CertChecker{}
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
}
if err := checker.CheckCert("user", validCert); err != nil {
t.Errorf("Unable to validate certificate: %v", err)
}
invalidCert := &Certificate{
Key: testPublicKeys["rsa"],
SignatureKey: testPublicKeys["ecdsa"],
ValidBefore: CertTimeInfinity,
Signature: &Signature{},
}
if err := checker.CheckCert("user", invalidCert); err == nil {
t.Error("Invalid cert signature passed validation")
}
}
func TestValidateCertTime(t *testing.T) {
cert := Certificate{
ValidPrincipals: []string{"user"},
Key: testPublicKeys["rsa"],
ValidAfter: 50,
ValidBefore: 100,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
for ts, ok := range map[int64]bool{
25: false,
50: true,
99: true,
100: false,
125: false,
} {
checker := CertChecker{
Clock: func() time.Time { return time.Unix(ts, 0) },
}
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(),
testPublicKeys["ecdsa"].Marshal())
}
if v := checker.CheckCert("user", &cert); (v == nil) != ok {
t.Errorf("Authenticate(%d): %v", ts, v)
}
}
}
// TODO(hanwen): tests for
//
// host keys:
// * fallbacks
func TestHostKeyCert(t *testing.T) {
cert := &Certificate{
ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"},
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: HostCert,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
checker := &CertChecker{
IsHostAuthority: func(p PublicKey, addr string) bool {
return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
},
}
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
if err != nil {
t.Errorf("NewCertSigner: %v", err)
}
for _, test := range []struct {
addr string
succeed bool
}{
{addr: "hostname:22", succeed: true},
{addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22'
{addr: "lasthost:22", succeed: false},
} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
errc := make(chan error)
go func() {
conf := ServerConfig{
NoClientAuth: true,
}
conf.AddHostKey(certSigner)
_, _, _, err := NewServerConn(c1, &conf)
errc <- err
}()
config := &ClientConfig{
User: "user",
HostKeyCallback: checker.CheckHostKey,
}
_, _, _, err = NewClientConn(c2, test.addr, config)
if (err == nil) != test.succeed {
t.Fatalf("NewClientConn(%q): %v", test.addr, err)
}
err = <-errc
if (err == nil) != test.succeed {
t.Fatalf("NewServerConn(%q): %v", test.addr, err)
}
}
}

View file

@ -205,32 +205,32 @@ type channel struct {
// writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error {
c.writeMu.Lock()
if c.sentClose {
c.writeMu.Unlock()
func (ch *channel) writePacket(packet []byte) error {
ch.writeMu.Lock()
if ch.sentClose {
ch.writeMu.Unlock()
return io.EOF
}
c.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet)
c.writeMu.Unlock()
ch.sentClose = (packet[0] == msgChannelClose)
err := ch.mux.conn.writePacket(packet)
ch.writeMu.Unlock()
return err
}
func (c *channel) sendMessage(msg interface{}) error {
func (ch *channel) sendMessage(msg interface{}) error {
if debugMux {
log.Printf("send(%d): %#v", c.mux.chanList.offset, msg)
log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
}
p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId)
return c.writePacket(p)
binary.BigEndian.PutUint32(p[1:], ch.remoteId)
return ch.writePacket(p)
}
// WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF {
func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if ch.sentEOF {
return 0, io.EOF
}
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
@ -241,16 +241,16 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
opCode = msgChannelExtendedData
}
c.writeMu.Lock()
packet := c.packetPool[extendedCode]
ch.writeMu.Lock()
packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be
// flagged as errors by the race detector.
c.writeMu.Unlock()
ch.writeMu.Unlock()
for len(data) > 0 {
space := min(c.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil {
space := min(ch.maxRemotePayload, len(data))
if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err
}
if want := headerLength + space; uint32(cap(packet)) < want {
@ -262,13 +262,13 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
todo := data[:space]
packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId)
binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
}
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil {
if err = ch.writePacket(packet); err != nil {
return n, err
}
@ -276,14 +276,14 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
data = data[len(todo):]
}
c.writeMu.Lock()
c.packetPool[extendedCode] = packet
c.writeMu.Unlock()
ch.writeMu.Lock()
ch.packetPool[extendedCode] = packet
ch.writeMu.Unlock()
return n, err
}
func (c *channel) handleData(packet []byte) error {
func (ch *channel) handleData(packet []byte) error {
headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData {
@ -303,7 +303,7 @@ func (c *channel) handleData(packet []byte) error {
if length == 0 {
return nil
}
if length > c.maxIncomingPayload {
if length > ch.maxIncomingPayload {
// TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size")
}
@ -313,21 +313,21 @@ func (c *channel) handleData(packet []byte) error {
return errors.New("ssh: wrong packet length")
}
c.windowMu.Lock()
if c.myWindow < length {
c.windowMu.Unlock()
ch.windowMu.Lock()
if ch.myWindow < length {
ch.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much")
}
c.myWindow -= length
c.windowMu.Unlock()
ch.myWindow -= length
ch.windowMu.Unlock()
if extended == 1 {
c.extPending.write(data)
ch.extPending.write(data)
} else if extended > 0 {
// discard other extended data.
} else {
c.pending.write(data)
ch.pending.write(data)
}
return nil
}
@ -384,31 +384,31 @@ func (c *channel) close() {
// responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the
// given channel.
func (c *channel) responseMessageReceived() error {
if c.direction == channelInbound {
func (ch *channel) responseMessageReceived() error {
if ch.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel")
}
if c.decided {
if ch.decided {
return errors.New("ssh: duplicate response received for channel")
}
c.decided = true
ch.decided = true
return nil
}
func (c *channel) handlePacket(packet []byte) error {
func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] {
case msgChannelData, msgChannelExtendedData:
return c.handleData(packet)
return ch.handleData(packet)
case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
c.mux.chanList.remove(c.localId)
c.close()
ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
ch.mux.chanList.remove(ch.localId)
ch.close()
return nil
case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time.
c.extPending.eof()
c.pending.eof()
ch.extPending.eof()
ch.pending.eof()
return nil
}
@ -419,24 +419,24 @@ func (c *channel) handlePacket(packet []byte) error {
switch msg := decoded.(type) {
case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil {
if err := ch.responseMessageReceived(); err != nil {
return err
}
c.mux.chanList.remove(msg.PeersId)
c.msg <- msg
ch.mux.chanList.remove(msg.PeersID)
ch.msg <- msg
case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil {
if err := ch.responseMessageReceived(); err != nil {
return err
}
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
}
c.remoteId = msg.MyId
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow)
c.msg <- msg
ch.remoteId = msg.MyID
ch.maxRemotePayload = msg.MaxPacketSize
ch.remoteWin.add(msg.MyWindow)
ch.msg <- msg
case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) {
if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
}
case *channelRequestMsg:
@ -444,12 +444,12 @@ func (c *channel) handlePacket(packet []byte) error {
Type: msg.Request,
WantReply: msg.WantReply,
Payload: msg.RequestSpecificData,
ch: c,
ch: ch,
}
c.incomingRequests <- &req
ch.incomingRequests <- &req
default:
c.msg <- msg
ch.msg <- msg
}
return nil
}
@ -488,23 +488,23 @@ func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code)
}
func (c *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided {
func (ch *channel) Accept() (Channel, <-chan *Request, error) {
if ch.decided {
return nil, nil, errDecidedAlready
}
c.maxIncomingPayload = channelMaxPacket
ch.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{
PeersId: c.remoteId,
MyId: c.localId,
MyWindow: c.myWindow,
MaxPacketSize: c.maxIncomingPayload,
PeersID: ch.remoteId,
MyID: ch.localId,
MyWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload,
}
c.decided = true
if err := c.sendMessage(confirm); err != nil {
ch.decided = true
if err := ch.sendMessage(confirm); err != nil {
return nil, nil, err
}
return c, c.incomingRequests, nil
return ch, ch.incomingRequests, nil
}
func (ch *channel) Reject(reason RejectionReason, message string) error {
@ -512,7 +512,7 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
return errDecidedAlready
}
reject := channelOpenFailureMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
Reason: reason,
Message: message,
Language: "en",
@ -541,7 +541,7 @@ func (ch *channel) CloseWrite() error {
}
ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId})
PeersID: ch.remoteId})
}
func (ch *channel) Close() error {
@ -550,7 +550,7 @@ func (ch *channel) Close() error {
}
return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId})
PeersID: ch.remoteId})
}
// Extended returns an io.ReadWriter that sends and receives data on the given,
@ -577,7 +577,7 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
}
msg := channelRequestMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
Request: name,
WantReply: wantReply,
RequestSpecificData: payload,
@ -614,11 +614,11 @@ func (ch *channel) ackRequest(ok bool) error {
var msg interface{}
if !ok {
msg = channelRequestFailureMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
}
} else {
msg = channelRequestSuccessMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
}
}
return ch.sendMessage(msg)

View file

@ -16,6 +16,10 @@ import (
"hash"
"io"
"io/ioutil"
"math/bits"
"golang.org/x/crypto/internal/chacha20"
"golang.org/x/crypto/poly1305"
)
const (
@ -53,78 +57,78 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type streamCipherMode struct {
keySize int
ivSize int
skip int
createFunc func(key, iv []byte) (cipher.Stream, error)
type cipherMode struct {
keySize int
ivSize int
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
}
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
var streamDump []byte
if c.skip > 0 {
streamDump = make([]byte, 512)
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
stream, err := createFunc(key, iv)
if err != nil {
return nil, err
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
var streamDump []byte
if skip > 0 {
streamDump = make([]byte, 512)
}
for remainingToDump := skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
mac := macModes[algs.MAC].new(macKey)
return &streamPacketCipher{
mac: mac,
etm: macModes[algs.MAC].etm,
macResult: make([]byte, mac.Size()),
cipher: stream,
}, nil
}
}
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
var cipherModes = map[string]*streamCipherMode{
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR},
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": {16, 0, 1536, newRC4},
"arcfour256": {32, 0, 1536, newRC4},
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution."
// RFC4345 introduces improved versions of Arcfour.
"arcfour": {16, 0, 0, newRC4},
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AES-GCM is not a stream cipher, so it is constructed with a
// special case. If we add any more non-stream ciphers, we
// should invest a cleaner way to do this.
gcmCipherID: {16, 12, 0, nil},
// AEAD ciphers
gcmCipherID: {16, 12, newGCMCipher},
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
// CBC mode is insecure and so is not included in the default config.
// (See http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf). If absolutely
// needed, it's possible to specify a custom Config to enable it.
// You should expect that an active attacker can recover plaintext if
// you do.
aes128cbcID: {16, aes.BlockSize, 0, nil},
aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
// 3des-cbc is insecure and is disabled by default.
tripledescbcID: {24, des.BlockSize, 0, nil},
// 3des-cbc is insecure and is not included in the default
// config.
tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
}
// prefixLen is the length of the packet prefix that contains the packet length
@ -304,7 +308,7 @@ type gcmCipher struct {
buf []byte
}
func newGCMCipher(iv, key []byte) (packetCipher, error) {
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
@ -372,7 +376,7 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
}
length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.")
return nil, errors.New("ssh: max packet length exceeded")
}
if cap(c.buf) < int(length+gcmTagSize) {
@ -422,7 +426,7 @@ type cbcCipher struct {
oracleCamouflage uint32
}
func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
cbc := &cbcCipher{
mac: macModes[algs.MAC].new(macKey),
decrypter: cipher.NewCBCDecrypter(c, iv),
@ -436,13 +440,13 @@ func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorith
return cbc, nil
}
func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cbc, err := newCBCCipher(c, iv, key, macKey, algs)
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
if err != nil {
return nil, err
}
@ -450,13 +454,13 @@ func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCi
return cbc, nil
}
func newTripleDESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
}
cbc, err := newCBCCipher(c, iv, key, macKey, algs)
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
if err != nil {
return nil, err
}
@ -548,11 +552,11 @@ func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error)
c.packetData = c.packetData[:entirePacketSize]
}
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil {
n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err
} else {
c.oracleCamouflage -= uint32(n)
}
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)
@ -627,3 +631,140 @@ func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, pack
return nil
}
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
// AEAD, which is described here:
//
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
//
// the methods here also implement padding, which RFC4253 Section 6
// also requires of stream ciphers.
type chacha20Poly1305Cipher struct {
lengthKey [8]uint32
contentKey [8]uint32
buf []byte
}
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
if len(key) != 64 {
panic(len(key))
}
c := &chacha20Poly1305Cipher{
buf: make([]byte, 256),
}
for i := range c.contentKey {
c.contentKey[i] = binary.LittleEndian.Uint32(key[i*4 : (i+1)*4])
}
for i := range c.lengthKey {
c.lengthKey[i] = binary.LittleEndian.Uint32(key[(i+8)*4 : (i+9)*4])
}
return c, nil
}
func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
encryptedLength := c.buf[:4]
if _, err := io.ReadFull(r, encryptedLength); err != nil {
return nil, err
}
var lenBytes [4]byte
chacha20.New(c.lengthKey, nonce).XORKeyStream(lenBytes[:], encryptedLength)
length := binary.BigEndian.Uint32(lenBytes[:])
if length > maxPacket {
return nil, errors.New("ssh: invalid packet length, packet too large")
}
contentEnd := 4 + length
packetEnd := contentEnd + poly1305.TagSize
if uint32(cap(c.buf)) < packetEnd {
c.buf = make([]byte, packetEnd)
copy(c.buf[:], encryptedLength)
} else {
c.buf = c.buf[:packetEnd]
}
if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil {
return nil, err
}
var mac [poly1305.TagSize]byte
copy(mac[:], c.buf[contentEnd:packetEnd])
if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) {
return nil, errors.New("ssh: MAC failure")
}
plain := c.buf[4:contentEnd]
s.XORKeyStream(plain, plain)
padding := plain[0]
if padding < 4 {
// padding is a byte, so it automatically satisfies
// the maximum size, which is 255.
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
}
if int(padding)+1 >= len(plain) {
return nil, fmt.Errorf("ssh: padding %d too large", padding)
}
plain = plain[1 : len(plain)-int(padding)]
return plain, nil
}
func (c *chacha20Poly1305Cipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
// There is no blocksize, so fall back to multiple of 8 byte
// padding, as described in RFC 4253, Sec 6.
const packetSizeMultiple = 8
padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple
if padding < 4 {
padding += packetSizeMultiple
}
// size (4 bytes), padding (1), payload, padding, tag.
totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize
if cap(c.buf) < totalLength {
c.buf = make([]byte, totalLength)
} else {
c.buf = c.buf[:totalLength]
}
binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
chacha20.New(c.lengthKey, nonce).XORKeyStream(c.buf, c.buf[:4])
c.buf[4] = byte(padding)
copy(c.buf[5:], payload)
packetEnd := 5 + len(payload) + padding
if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil {
return err
}
s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
var mac [poly1305.TagSize]byte
poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)
copy(c.buf[packetEnd:], mac[:])
if _, err := w.Write(c.buf); err != nil {
return err
}
return nil
}

View file

@ -1,129 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/rand"
"testing"
)
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range supportedCiphers {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}
func TestPacketCiphers(t *testing.T) {
// Still test aes128cbc cipher although it's commented out.
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
defer delete(cipherModes, aes128cbcID)
for cipher := range cipherModes {
for mac := range macModes {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: cipher,
MAC: mac,
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
}
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket(%q, %q): %v", cipher, mac, err)
continue
}
packet, err := server.readPacket(0, buf)
if err != nil {
t.Errorf("readPacket(%q, %q): %v", cipher, mac, err)
continue
}
if string(packet) != want {
t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
}
}
}
}
func TestCBCOracleCounterMeasure(t *testing.T) {
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
defer delete(cipherModes, aes128cbcID)
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: aes128cbcID,
MAC: "hmac-sha1",
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client): %v", err)
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket: %v", err)
}
packetSize := buf.Len()
buf.Write(make([]byte, 2*maxPacket))
// We corrupt each byte, but this usually will only test the
// 'packet too large' or 'MAC failure' cases.
lastRead := -1
for i := 0; i < packetSize; i++ {
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client): %v", err)
}
fresh := &bytes.Buffer{}
fresh.Write(buf.Bytes())
fresh.Bytes()[i] ^= 0x01
before := fresh.Len()
_, err = server.readPacket(0, fresh)
if err == nil {
t.Errorf("corrupt byte %d: readPacket succeeded ", i)
continue
}
if _, ok := err.(cbcError); !ok {
t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
continue
}
after := fresh.Len()
bytesRead := before - after
if bytesRead < maxPacket {
t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
continue
}
if i > 0 && bytesRead != lastRead {
t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
}
lastRead = bytesRead
}
}

View file

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"net"
"os"
"sync"
"time"
)
@ -18,6 +19,8 @@ import (
type Client struct {
Conn
handleForwardsOnce sync.Once // guards calling (*Client).handleForwards
forwards forwardList // forwarded tcpip connections from the remote side
mu sync.Mutex
channelHandlers map[string]chan NewChannel
@ -59,8 +62,6 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.Wait()
conn.forwards.closeAll()
}()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
return conn
}
@ -187,6 +188,10 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
// net.Conn underlying the the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used for treat the banner sent by
// the server. A BannerCallback receives the message sent by the remote server.
type BannerCallback func(message string) error
// A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function.
type ClientConfig struct {
@ -209,6 +214,12 @@ type ClientConfig struct {
// FixedHostKey can be used for simplistic host key checks.
HostKeyCallback HostKeyCallback
// BannerCallback is called during the SSH dance to display a custom
// server's message. The client configuration can supply this callback to
// handle it as wished. The function BannerDisplayStderr can be used for
// simplistic display on Stderr.
BannerCallback BannerCallback
// ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used.
ClientVersion string
@ -255,3 +266,13 @@ func FixedHostKey(key PublicKey) HostKeyCallback {
hk := &fixedHostKey{key}
return hk.check
}
// BannerDisplayStderr returns a function that can be used for
// ClientConfig.BannerCallback to display banners on os.Stderr.
func BannerDisplayStderr() BannerCallback {
return func(banner string) error {
_, err := os.Stderr.WriteString(banner)
return err
}
}

View file

@ -11,6 +11,14 @@ import (
"io"
)
type authResult int
const (
authFailure authResult = iota
authPartialSuccess
authSuccess
)
// clientAuthenticate authenticates with the remote server. See RFC 4252.
func (c *connection) clientAuthenticate(config *ClientConfig) error {
// initiate user auth session
@ -37,11 +45,12 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
if err != nil {
return err
}
if ok {
if ok == authSuccess {
// success
return nil
} else if ok == authFailure {
tried[auth.method()] = true
}
tried[auth.method()] = true
if methods == nil {
methods = lastMethods
}
@ -82,7 +91,7 @@ type AuthMethod interface {
// If authentication is not successful, a []string of alternative
// method names is returned. If the slice is nil, it will be ignored
// and the previous set of possible methods will be reused.
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
auth(session []byte, user string, p packetConn, rand io.Reader) (authResult, []string, error)
// method returns the RFC 4252 method name.
method() string
@ -91,13 +100,13 @@ type AuthMethod interface {
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
if err := c.writePacket(Marshal(&userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: "none",
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
return handleAuthResponse(c)
@ -111,7 +120,7 @@ func (n *noneAuth) method() string {
// a function call, e.g. by prompting the user.
type passwordCallback func() (password string, err error)
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
type passwordAuthMsg struct {
User string `sshtype:"50"`
Service string
@ -125,7 +134,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
// The program may only find out that the user doesn't have a password
// when prompting.
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if err := c.writePacket(Marshal(&passwordAuthMsg{
@ -135,7 +144,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
Reply: false,
Password: pw,
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
return handleAuthResponse(c)
@ -178,7 +187,7 @@ func (cb publicKeyCallback) method() string {
return "publickey"
}
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
// Authentication is performed by sending an enquiry to test if a key is
// acceptable to the remote. If the key is acceptable, the client will
// attempt to authenticate with the valid key. If not the client will repeat
@ -186,13 +195,13 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
signers, err := cb()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
var methods []string
for _, signer := range signers {
ok, err := validateKey(signer.PublicKey(), user, c)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if !ok {
continue
@ -206,7 +215,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
Method: cb.method(),
}, []byte(pub.Type()), pubKey))
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// manually wrap the serialized signature in a string
@ -224,24 +233,24 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
}
p := Marshal(&msg)
if err := c.writePacket(p); err != nil {
return false, nil, err
return authFailure, nil, err
}
var success bool
var success authResult
success, methods, err = handleAuthResponse(c)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// If authentication succeeds or the list of available methods does not
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success || !containsMethod(methods, cb.method()) {
if success == authSuccess || !containsMethod(methods, cb.method()) {
return success, methods, err
}
}
return false, methods, nil
return authFailure, methods, nil
}
func containsMethod(methods []string, method string) bool {
@ -283,7 +292,9 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
}
switch packet[0] {
case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user
if err := handleBannerResponse(c, packet); err != nil {
return false, err
}
case msgUserAuthPubKeyOk:
var msg userAuthPubKeyOkMsg
if err := Unmarshal(packet, &msg); err != nil {
@ -316,30 +327,53 @@ func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMet
// handleAuthResponse returns whether the preceding authentication request succeeded
// along with a list of remaining authentication methods to try next and
// an error if an unexpected response was received.
func handleAuthResponse(c packetConn) (bool, []string, error) {
func handleAuthResponse(c packetConn) (authResult, []string, error) {
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
switch packet[0] {
case msgUserAuthBanner:
// TODO: add callback to present the banner to the user
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
}
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
return false, msg.Methods, nil
if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil
}
return authFailure, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
return authSuccess, nil, nil
default:
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
}
}
func handleBannerResponse(c packetConn, packet []byte) error {
var msg userAuthBannerMsg
if err := Unmarshal(packet, &msg); err != nil {
return err
}
transport, ok := c.(*handshakeTransport)
if !ok {
return nil
}
if transport.bannerCallback != nil {
return transport.bannerCallback(msg.Message)
}
return nil
}
// KeyboardInteractiveChallenge should print questions, optionally
// disabling echoing (e.g. for passwords), and return all the answers.
// Challenge may be called multiple times in a single session. After
@ -359,7 +393,7 @@ func (cb KeyboardInteractiveChallenge) method() string {
return "keyboard-interactive"
}
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
type initiateMsg struct {
User string `sshtype:"50"`
Service string
@ -373,37 +407,42 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
Service: serviceSSH,
Method: "keyboard-interactive",
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// like handleAuthResponse, but with less options.
switch packet[0] {
case msgUserAuthBanner:
// TODO: Print banners during userauth.
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
}
continue
case msgUserAuthInfoRequest:
// OK
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
return false, msg.Methods, nil
if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil
}
return authFailure, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
return authSuccess, nil, nil
default:
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
}
var msg userAuthInfoRequestMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
// Manually unpack the prompt/echo pairs.
@ -413,7 +452,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
for i := 0; i < int(msg.NumPrompts); i++ {
prompt, r, ok := parseString(rest)
if !ok || len(r) == 0 {
return false, nil, errors.New("ssh: prompt format error")
return authFailure, nil, errors.New("ssh: prompt format error")
}
prompts = append(prompts, string(prompt))
echos = append(echos, r[0] != 0)
@ -421,16 +460,16 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
}
if len(rest) != 0 {
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
}
answers, err := cb(msg.User, msg.Instruction, prompts, echos)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if len(answers) != len(prompts) {
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
return authFailure, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
}
responseLength := 1 + 4
for _, a := range answers {
@ -446,7 +485,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
}
if err := c.writePacket(serialized); err != nil {
return false, nil, err
return authFailure, nil, err
}
}
}
@ -456,10 +495,10 @@ type retryableAuthMethod struct {
maxTries int
}
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok bool, methods []string, err error) {
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok authResult, methods []string, err error) {
for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
ok, methods, err = r.authMethod.auth(session, user, c, rand)
if ok || err != nil { // either success or error terminate
if ok != authFailure || err != nil { // either success, partial success or error terminate
return ok, methods, err
}
}

View file

@ -1,628 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"os"
"strings"
"testing"
)
type keyboardInteractive map[string]string
func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
var answers []string
for _, q := range questions {
answers = append(answers, cr[q])
}
return answers, nil
}
// reused internally by tests
var clientPassword = "tiger"
// tryAuth runs a handshake with a given config against an SSH server
// with config serverConfig
func tryAuth(t *testing.T, config *ClientConfig) error {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
certChecker := CertChecker{
IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
return nil, nil
}
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
},
IsRevoked: func(c *Certificate) bool {
return c.Serial == 666
},
}
serverConfig := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == clientPassword {
return nil, nil
}
return nil, errors.New("password auth failed")
},
PublicKeyCallback: certChecker.Authenticate,
KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) {
ans, err := challenge("user",
"instruction",
[]string{"question1", "question2"},
[]bool{true, true})
if err != nil {
return nil, err
}
ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
if ok {
challenge("user", "motd", nil, nil)
return nil, nil
}
return nil, errors.New("keyboard-interactive failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", config)
return err
}
func TestClientAuthPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodFallback(t *testing.T) {
var passwordCalled bool
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
PasswordCallback(
func() (string, error) {
passwordCalled = true
return "WRONG", nil
}),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
if passwordCalled {
t.Errorf("password auth tried before public-key auth.")
}
}
func TestAuthMethodWrongPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password("wrong"),
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "answer2",
})
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodWrongKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "WRONG",
})
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
}
if err := tryAuth(t, config); err == nil {
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive")
}
}
// the mock server will only authenticate ssh-rsa keys
func TestAuthMethodInvalidPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"]),
},
}
if err := tryAuth(t, config); err == nil {
t.Fatalf("dsa private key should not have authenticated with rsa public key")
}
}
// the client should authenticate with the second key
func TestAuthMethodRSAandDSA(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"], testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err)
}
}
func TestClientHMAC(t *testing.T) {
for _, mac := range supportedMACs {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
Config: Config{
MACs: []string{mac},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
}
}
}
// issue 4285.
func TestClientUnsupportedCipher(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(),
},
Config: Config{
Ciphers: []string{"aes128-cbc"}, // not currently supported
},
}
if err := tryAuth(t, config); err == nil {
t.Errorf("expected no ciphers in common")
}
}
func TestClientUnsupportedKex(t *testing.T) {
if os.Getenv("GO_BUILDER_NAME") != "" {
t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198")
}
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(),
},
Config: Config{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") {
t.Errorf("got %v, expected 'common algorithm'", err)
}
}
func TestClientLoginCert(t *testing.T) {
cert := &Certificate{
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: UserCert,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
if err != nil {
t.Fatalf("NewCertSigner: %v", err)
}
clientConfig := &ClientConfig{
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
}
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
// should succeed
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
// corrupted signature
cert.Signature.Blob[0]++
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with corrupted sig")
}
// revoked
cert.Serial = 666
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("revoked cert login succeeded")
}
cert.Serial = 1
// sign with wrong key
cert.SignCert(rand.Reader, testSigners["dsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with non-authoritative key")
}
// host cert
cert.CertType = HostCert
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with wrong type")
}
cert.CertType = UserCert
// principal specified
cert.ValidPrincipals = []string{"user"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
// wrong principal specified
cert.ValidPrincipals = []string{"fred"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with wrong principal")
}
cert.ValidPrincipals = nil
// added critical option
cert.CriticalOptions = map[string]string{"root-access": "yes"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with unrecognized critical option")
}
// allowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24,::42/120"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login with source-address failed: %v", err)
}
// disallowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42,::42"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login with source-address succeeded")
}
}
func testPermissionsPassing(withPermissions bool, t *testing.T) {
serverConfig := &ServerConfig{
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "nopermissions" {
return nil, nil
}
return &Permissions{}, nil
},
}
serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if withPermissions {
clientConfig.User = "permissions"
} else {
clientConfig.User = "nopermissions"
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
serverConn, err := newServer(c1, serverConfig)
if err != nil {
t.Fatal(err)
}
if p := serverConn.Permissions; (p != nil) != withPermissions {
t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p)
}
}
func TestPermissionsPassing(t *testing.T) {
testPermissionsPassing(true, t)
}
func TestNoPermissionsPassing(t *testing.T) {
testPermissionsPassing(false, t)
}
func TestRetryableAuth(t *testing.T) {
n := 0
passwords := []string{"WRONG1", "WRONG2"}
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
p := passwords[n]
n++
return p, nil
}), 2),
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
if n != 2 {
t.Fatalf("Did not try all passwords")
}
}
func ExampleRetryableAuthMethod(t *testing.T) {
user := "testuser"
NumberOfPrompts := 3
// Normally this would be a callback that prompts the user to answer the
// provided questions
Cb := func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
return []string{"answer1", "answer2"}, nil
}
config := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
// Test if username is received on server side when NoClientAuth is used
func TestClientAuthNone(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
NoClientAuth: true,
}
serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{
User: user,
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
serverConn, err := newServer(c1, serverConfig)
if err != nil {
t.Fatalf("newServer: %v", err)
}
if serverConn.User() != user {
t.Fatalf("server: got %q, want %q", serverConn.User(), user)
}
}
// Test if authentication attempts are limited on server when MaxAuthTries is set
func TestClientAuthMaxAuthTries(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
MaxAuthTries: 2,
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == "right" {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
for tries := 2; tries < 4; tries++ {
n := tries
clientConfig := &ClientConfig{
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
n--
if n == 0 {
return "right", nil
}
return "wrong", nil
}), tries),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", clientConfig)
if tries > 2 {
if err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
} else {
if err != nil {
t.Fatalf("client: got %s, want no error", err)
}
}
}
}
// Test if authentication attempts are correctly limited on server
// when more public keys are provided then MaxAuthTries
func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) {
signers := []Signer{}
for i := 0; i < 6; i++ {
signers = append(signers, testSigners["dsa"])
}
validConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, validConfig); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
invalidConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append(signers, testSigners["rsa"])...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, invalidConfig); err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
}
// Test whether authentication errors are being properly logged if all
// authentication methods have been exhausted
func TestClientAuthErrorList(t *testing.T) {
publicKeyErr := errors.New("This is an error from PublicKeyCallback")
clientConfig := &ClientConfig{
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
serverConfig := &ServerConfig{
PublicKeyCallback: func(_ ConnMetadata, _ PublicKey) (*Permissions, error) {
return nil, publicKeyErr
},
}
serverConfig.AddHostKey(testSigners["rsa"])
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
_, err = newServer(c1, serverConfig)
if err == nil {
t.Fatal("newServer: got nil, expected errors")
}
authErrs, ok := err.(*ServerAuthError)
if !ok {
t.Fatalf("errors: got %T, want *ssh.ServerAuthError", err)
}
for i, e := range authErrs.Errors {
switch i {
case 0:
if e.Error() != "no auth passed yet" {
t.Fatalf("errors: got %v, want no auth passed yet", e.Error())
}
case 1:
if e != publicKeyErr {
t.Fatalf("errors: got %v, want %v", e, publicKeyErr)
}
default:
t.Fatalf("errors: got %v, expected 2 errors", authErrs.Errors)
}
}
}

View file

@ -1,81 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"net"
"strings"
"testing"
)
func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
receivedVersion := make(chan string, 1)
config.HostKeyCallback = InsecureIgnoreHostKey()
go func() {
version, err := readVersion(serverConn)
if err != nil {
receivedVersion <- ""
} else {
receivedVersion <- string(version)
}
serverConn.Close()
}()
NewClientConn(clientConn, "", config)
actual := <-receivedVersion
if actual != expected {
t.Fatalf("got %s; want %s", actual, expected)
}
}
func TestCustomClientVersion(t *testing.T) {
version := "Test-Client-Version-0.0"
testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
}
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion)
}
func TestHostKeyCheck(t *testing.T) {
for _, tt := range []struct {
name string
wantError string
key PublicKey
}{
{"no callback", "must specify HostKeyCallback", nil},
{"correct key", "", testSigners["rsa"].PublicKey()},
{"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()},
} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
go NewServerConn(c1, serverConf)
clientConf := ClientConfig{
User: "user",
}
if tt.key != nil {
clientConf.HostKeyCallback = FixedHostKey(tt.key)
}
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
}
} else if tt.wantError != "" {
t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)
}
}
}

View file

@ -24,11 +24,21 @@ const (
serviceSSH = "ssh-connection"
)
// supportedCiphers specifies the supported ciphers in preference order.
// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"aes128-gcm@openssh.com",
"arcfour256", "arcfour128",
chacha20Poly1305ID,
"arcfour256", "arcfour128", "arcfour",
aes128cbcID,
tripledescbcID,
}
// preferredCiphers specifies the default preference for ciphers.
var preferredCiphers = []string{
"aes128-gcm@openssh.com",
chacha20Poly1305ID,
"aes128-ctr", "aes192-ctr", "aes256-ctr",
}
// supportedKexAlgos specifies the supported key-exchange algorithms in
@ -211,7 +221,7 @@ func (c *Config) SetDefaults() {
c.Rand = rand.Reader
}
if c.Ciphers == nil {
c.Ciphers = supportedCiphers
c.Ciphers = preferredCiphers
}
var ciphers []string
for _, c := range c.Ciphers {
@ -242,7 +252,7 @@ func (c *Config) SetDefaults() {
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct {
Session []byte
Type byte
@ -253,7 +263,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
Algo []byte
PubKey []byte
}{
sessionId,
sessionID,
msgUserAuthRequest,
req.User,
req.Service,

View file

@ -1,320 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh_test
import (
"bufio"
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
func ExampleNewServerConn() {
// Public key authentication is done by comparing
// the public key of a received connection
// with the entries in the authorized_keys file.
authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys")
if err != nil {
log.Fatalf("Failed to load authorized_keys, err: %v", err)
}
authorizedKeysMap := map[string]bool{}
for len(authorizedKeysBytes) > 0 {
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
if err != nil {
log.Fatal(err)
}
authorizedKeysMap[string(pubKey.Marshal())] = true
authorizedKeysBytes = rest
}
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ssh.ServerConfig{
// Remove to disable password auth.
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
// Should use constant-time compare (or better, salt+hash) in
// a production setting.
if c.User() == "testuser" && string(pass) == "tiger" {
return nil, nil
}
return nil, fmt.Errorf("password rejected for %q", c.User())
},
// Remove to disable public key auth.
PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
if authorizedKeysMap[string(pubKey.Marshal())] {
return &ssh.Permissions{
// Record the public key used for authentication.
Extensions: map[string]string{
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
},
}, nil
}
return nil, fmt.Errorf("unknown public key for %q", c.User())
},
}
privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
log.Fatal("Failed to load private key: ", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
log.Fatal("Failed to parse private key: ", err)
}
config.AddHostKey(private)
// Once a ServerConfig has been configured, connections can be
// accepted.
listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
log.Fatal("failed to listen for connection: ", err)
}
nConn, err := listener.Accept()
if err != nil {
log.Fatal("failed to accept incoming connection: ", err)
}
// Before use, a handshake must be performed on the incoming
// net.Conn.
conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
if err != nil {
log.Fatal("failed to handshake: ", err)
}
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
// Service the incoming Channel channel.
for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of a shell, the type is
// "session" and ServerShell may be used to present a simple
// terminal interface.
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
log.Fatalf("Could not accept channel: %v", err)
}
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "shell" request.
go func(in <-chan *ssh.Request) {
for req := range in {
req.Reply(req.Type == "shell", nil)
}
}(requests)
term := terminal.NewTerminal(channel, "> ")
go func() {
defer channel.Close()
for {
line, err := term.ReadLine()
if err != nil {
break
}
fmt.Println(line)
}
}()
}
}
func ExampleHostKeyCheck() {
// Every client must provide a host key check. Here is a
// simple-minded parse of OpenSSH's known_hosts file
host := "hostname"
file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
var hostKey ssh.PublicKey
for scanner.Scan() {
fields := strings.Split(scanner.Text(), " ")
if len(fields) != 3 {
continue
}
if strings.Contains(fields[0], host) {
var err error
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
if err != nil {
log.Fatalf("error parsing %q: %v", fields[2], err)
}
break
}
}
if hostKey == nil {
log.Fatalf("no hostkey for %s", host)
}
config := ssh.ClientConfig{
User: os.Getenv("USER"),
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
_, err = ssh.Dial("tcp", host+":22", &config)
log.Println(err)
}
func ExampleDial() {
var hostKey ssh.PublicKey
// An SSH client is represented with a ClientConn.
//
// To authenticate with the remote server you must pass at least one
// implementation of AuthMethod via the Auth field in ClientConfig,
// and provide a HostKeyCallback.
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("yourpassword"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
client, err := ssh.Dial("tcp", "yourserver.com:22", config)
if err != nil {
log.Fatal("Failed to dial: ", err)
}
// Each ClientConn can support multiple interactive sessions,
// represented by a Session.
session, err := client.NewSession()
if err != nil {
log.Fatal("Failed to create session: ", err)
}
defer session.Close()
// Once a Session is created, you can execute a single command on
// the remote side using the Run method.
var b bytes.Buffer
session.Stdout = &b
if err := session.Run("/usr/bin/whoami"); err != nil {
log.Fatal("Failed to run: " + err.Error())
}
fmt.Println(b.String())
}
func ExamplePublicKeys() {
var hostKey ssh.PublicKey
// A public key may be used to authenticate against the remote
// server by using an unencrypted PEM-encoded private key file.
//
// If you have an encrypted private key, the crypto/x509 package
// can be used to decrypt it.
key, err := ioutil.ReadFile("/home/user/.ssh/id_rsa")
if err != nil {
log.Fatalf("unable to read private key: %v", err)
}
// Create the Signer for this private key.
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
log.Fatalf("unable to parse private key: %v", err)
}
config := &ssh.ClientConfig{
User: "user",
Auth: []ssh.AuthMethod{
// Use the PublicKeys method for remote authentication.
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to the remote server and perform the SSH handshake.
client, err := ssh.Dial("tcp", "host.com:22", config)
if err != nil {
log.Fatalf("unable to connect: %v", err)
}
defer client.Close()
}
func ExampleClient_Listen() {
var hostKey ssh.PublicKey
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Dial your ssh server.
conn, err := ssh.Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatal("unable to connect: ", err)
}
defer conn.Close()
// Request the remote side to open port 8080 on all interfaces.
l, err := conn.Listen("tcp", "0.0.0.0:8080")
if err != nil {
log.Fatal("unable to register tcp forward: ", err)
}
defer l.Close()
// Serve HTTP with your SSH server acting as a reverse proxy.
http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
fmt.Fprintf(resp, "Hello world!\n")
}))
}
func ExampleSession_RequestPty() {
var hostKey ssh.PublicKey
// Create client config
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to ssh server
conn, err := ssh.Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatal("unable to connect: ", err)
}
defer conn.Close()
// Create a session
session, err := conn.NewSession()
if err != nil {
log.Fatal("unable to create session: ", err)
}
defer session.Close()
// Set up terminal modes
modes := ssh.TerminalModes{
ssh.ECHO: 0, // disable echoing
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
// Request pseudo terminal
if err := session.RequestPty("xterm", 40, 80, modes); err != nil {
log.Fatal("request for pseudo terminal failed: ", err)
}
// Start remote shell
if err := session.Shell(); err != nil {
log.Fatal("failed to start shell: ", err)
}
}

View file

@ -78,6 +78,11 @@ type handshakeTransport struct {
dialAddress string
remoteAddr net.Addr
// bannerCallback is non-empty if we are the client and it has been set in
// ClientConfig. In that case it is called during the user authentication
// dance to handle a custom server's message.
bannerCallback BannerCallback
// Algorithms agreed in the last key exchange.
algorithms *algorithms
@ -120,6 +125,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
t.dialAddress = dialAddr
t.remoteAddr = addr
t.hostKeyCallback = config.HostKeyCallback
t.bannerCallback = config.BannerCallback
if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else {

View file

@ -1,559 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"reflect"
"runtime"
"strings"
"sync"
"testing"
)
type testChecker struct {
calls []string
}
func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
if dialAddr == "bad" {
return fmt.Errorf("dialAddr is bad")
}
if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
}
t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
return nil
}
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
return nil, nil, err
}
c2, err := listener.Accept()
if err != nil {
c1.Close()
return nil, nil, err
}
return c1, c2, nil
}
// noiseTransport inserts ignore messages to check that the read loop
// and the key exchange filters out these messages.
type noiseTransport struct {
keyingTransport
}
func (t *noiseTransport) writePacket(p []byte) error {
ignore := []byte{msgIgnore}
if err := t.keyingTransport.writePacket(ignore); err != nil {
return err
}
debug := []byte{msgDebug, 1, 2, 3}
if err := t.keyingTransport.writePacket(debug); err != nil {
return err
}
return t.keyingTransport.writePacket(p)
}
func addNoiseTransport(t keyingTransport) keyingTransport {
return &noiseTransport{t}
}
// handshakePair creates two handshakeTransports connected with each
// other. If the noise argument is true, both transports will try to
// confuse the other side by sending ignore and debug messages.
func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
a, b, err := netPipe()
if err != nil {
return nil, nil, err
}
var trC, trS keyingTransport
trC = newTransport(a, rand.Reader, true)
trS = newTransport(b, rand.Reader, false)
if noise {
trC = addNoiseTransport(trC)
trS = addNoiseTransport(trS)
}
clientConf.SetDefaults()
v := []byte("version")
client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.AddHostKey(testSigners["rsa"])
serverConf.SetDefaults()
server = newServerTransport(trS, v, v, serverConf)
if err := server.waitSession(); err != nil {
return nil, nil, fmt.Errorf("server.waitSession: %v", err)
}
if err := client.waitSession(); err != nil {
return nil, nil, fmt.Errorf("client.waitSession: %v", err)
}
return client, server, nil
}
func TestHandshakeBasic(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
checker := &syncChecker{
waitCall: make(chan int, 10),
called: make(chan int, 10),
}
checker.waitCall <- 1
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
// Let first kex complete normally.
<-checker.called
clientDone := make(chan int, 0)
gotHalf := make(chan int, 0)
const N = 20
go func() {
defer close(clientDone)
// Client writes a bunch of stuff, and does a key
// change in the middle. This should not confuse the
// handshake in progress. We do this twice, so we test
// that the packet buffer is reset correctly.
for i := 0; i < N; i++ {
p := []byte{msgRequestSuccess, byte(i)}
if err := trC.writePacket(p); err != nil {
t.Fatalf("sendPacket: %v", err)
}
if (i % 10) == 5 {
<-gotHalf
// halfway through, we request a key change.
trC.requestKeyExchange()
// Wait until we can be sure the key
// change has really started before we
// write more.
<-checker.called
}
if (i % 10) == 7 {
// write some packets until the kex
// completes, to test buffering of
// packets.
checker.waitCall <- 1
}
}
}()
// Server checks that client messages come in cleanly
i := 0
err = nil
for ; i < N; i++ {
var p []byte
p, err = trS.readPacket()
if err != nil {
break
}
if (i % 10) == 5 {
gotHalf <- 1
}
want := []byte{msgRequestSuccess, byte(i)}
if bytes.Compare(p, want) != 0 {
t.Errorf("message %d: got %v, want %v", i, p, want)
}
}
<-clientDone
if err != nil && err != io.EOF {
t.Fatalf("server error: %v", err)
}
if i != N {
t.Errorf("received %d messages, want 10.", i)
}
close(checker.called)
if _, ok := <-checker.called; ok {
// If all went well, we registered exactly 2 key changes: one
// that establishes the session, and one that we requested
// additionally.
t.Fatalf("got another host key checks after 2 handshakes")
}
}
func TestForceFirstKex(t *testing.T) {
// like handshakePair, but must access the keyingTransport.
checker := &testChecker{}
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
var trC, trS keyingTransport
trC = newTransport(a, rand.Reader, true)
// This is the disallowed packet:
trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
// Rest of the setup.
trS = newTransport(b, rand.Reader, false)
clientConf.SetDefaults()
v := []byte("version")
client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.AddHostKey(testSigners["rsa"])
serverConf.SetDefaults()
server := newServerTransport(trS, v, v, serverConf)
defer client.Close()
defer server.Close()
// We setup the initial key exchange, but the remote side
// tries to send serviceRequestMsg in cleartext, which is
// disallowed.
if err := server.waitSession(); err == nil {
t.Errorf("server first kex init should reject unexpected packet")
}
}
func TestHandshakeAutoRekeyWrite(t *testing.T) {
checker := &syncChecker{
called: make(chan int, 10),
waitCall: nil,
}
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
input := make([]byte, 251)
input[0] = msgRequestSuccess
done := make(chan int, 1)
const numPacket = 5
go func() {
defer close(done)
j := 0
for ; j < numPacket; j++ {
if p, err := trS.readPacket(); err != nil {
break
} else if !bytes.Equal(input, p) {
t.Errorf("got packet type %d, want %d", p[0], input[0])
}
}
if j != numPacket {
t.Errorf("got %d, want 5 messages", j)
}
}()
<-checker.called
for i := 0; i < numPacket; i++ {
p := make([]byte, len(input))
copy(p, input)
if err := trC.writePacket(p); err != nil {
t.Errorf("writePacket: %v", err)
}
if i == 2 {
// Make sure the kex is in progress.
<-checker.called
}
}
<-done
}
type syncChecker struct {
waitCall chan int
called chan int
}
func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
c.called <- 1
if c.waitCall != nil {
<-c.waitCall
}
return nil
}
func TestHandshakeAutoRekeyRead(t *testing.T) {
sync := &syncChecker{
called: make(chan int, 2),
waitCall: nil,
}
clientConf := &ClientConfig{
HostKeyCallback: sync.Check,
}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
packet := make([]byte, 501)
packet[0] = msgRequestSuccess
if err := trS.writePacket(packet); err != nil {
t.Fatalf("writePacket: %v", err)
}
// While we read out the packet, a key change will be
// initiated.
done := make(chan int, 1)
go func() {
defer close(done)
if _, err := trC.readPacket(); err != nil {
t.Fatalf("readPacket(client): %v", err)
}
}()
<-done
<-sync.called
}
// errorKeyingTransport generates errors after a given number of
// read/write operations.
type errorKeyingTransport struct {
packetConn
readLeft, writeLeft int
}
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
return nil
}
func (n *errorKeyingTransport) getSessionID() []byte {
return nil
}
func (n *errorKeyingTransport) writePacket(packet []byte) error {
if n.writeLeft == 0 {
n.Close()
return errors.New("barf")
}
n.writeLeft--
return n.packetConn.writePacket(packet)
}
func (n *errorKeyingTransport) readPacket() ([]byte, error) {
if n.readLeft == 0 {
n.Close()
return nil, errors.New("barf")
}
n.readLeft--
return n.packetConn.readPacket()
}
func TestHandshakeErrorHandlingRead(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1, false)
}
}
func TestHandshakeErrorHandlingWrite(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i, false)
}
}
func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1, true)
}
}
func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i, true)
}
}
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and
// panic.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
a, b := memPipe()
defer a.Close()
defer b.Close()
key := testSigners["ecdsa"]
serverConf := Config{RekeyThreshold: minRekeyThreshold}
serverConf.SetDefaults()
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
serverConn.hostKeys = []Signer{key}
go serverConn.readLoop()
go serverConn.kexLoop()
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
clientConf.SetDefaults()
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
clientConn.hostKeyCallback = InsecureIgnoreHostKey()
go clientConn.readLoop()
go clientConn.kexLoop()
var wg sync.WaitGroup
for _, hs := range []packetConn{serverConn, clientConn} {
if !coupled {
wg.Add(2)
go func(c packetConn) {
for i := 0; ; i++ {
str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
err := c.writePacket(Marshal(&serviceRequestMsg{str}))
if err != nil {
break
}
}
wg.Done()
c.Close()
}(hs)
go func(c packetConn) {
for {
_, err := c.readPacket()
if err != nil {
break
}
}
wg.Done()
}(hs)
} else {
wg.Add(1)
go func(c packetConn) {
for {
_, err := c.readPacket()
if err != nil {
break
}
if err := c.writePacket(msg); err != nil {
break
}
}
wg.Done()
}(hs)
}
}
wg.Wait()
}
func TestDisconnect(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
errMsg := &disconnectMsg{
Reason: 42,
Message: "such is life",
}
trC.writePacket(Marshal(errMsg))
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
packet, err := trS.readPacket()
if err != nil {
t.Fatalf("readPacket 1: %v", err)
}
if packet[0] != msgRequestSuccess {
t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
}
_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 2 succeeded")
} else if !reflect.DeepEqual(err, errMsg) {
t.Errorf("got error %#v, want %#v", err, errMsg)
}
_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 3 succeeded")
}
}
func TestHandshakeRekeyDefault(t *testing.T) {
clientConf := &ClientConfig{
Config: Config{
Ciphers: []string{"aes128-ctr"},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
trC.Close()
rgb := (1024 + trC.readBytesLeft) >> 30
wgb := (1024 + trC.writeBytesLeft) >> 30
if rgb != 64 {
t.Errorf("got rekey after %dG read, want 64G", rgb)
}
if wgb != 64 {
t.Errorf("got rekey after %dG write, want 64G", wgb)
}
}

View file

@ -119,7 +119,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
return nil, err
}
kInt, err := group.diffieHellman(kexDHReply.Y, x)
ki, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil {
return nil, err
}
@ -129,8 +129,8 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
writeString(h, kexDHReply.HostKey)
writeInt(h, X)
writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
return &kexResult{
@ -164,7 +164,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
}
Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y)
ki, err := group.diffieHellman(kexDHInit.X, y)
if err != nil {
return nil, err
}
@ -177,8 +177,8 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
writeInt(h, kexDHInit.X)
writeInt(h, Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
H := h.Sum(nil)
@ -462,9 +462,9 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey)
kInt := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
return &kexResult{
@ -510,9 +510,9 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:])
kInt := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
H := h.Sum(nil)

View file

@ -1,50 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Key exchange tests.
import (
"crypto/rand"
"reflect"
"testing"
)
func TestKexes(t *testing.T) {
type kexResultErr struct {
result *kexResult
err error
}
for name, kex := range kexAlgoMap {
a, b := memPipe()
s := make(chan kexResultErr, 1)
c := make(chan kexResultErr, 1)
var magics handshakeMagics
go func() {
r, e := kex.Client(a, rand.Reader, &magics)
a.Close()
c <- kexResultErr{r, e}
}()
go func() {
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
b.Close()
s <- kexResultErr{r, e}
}()
clientRes := <-c
serverRes := <-s
if clientRes.err != nil {
t.Errorf("client: %v", clientRes.err)
}
if serverRes.err != nil {
t.Errorf("server: %v", serverRes.err)
}
if !reflect.DeepEqual(clientRes.result, serverRes.result) {
t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
}
}
}

View file

@ -276,7 +276,8 @@ type PublicKey interface {
Type() string
// Marshal returns the serialized key data in SSH wire format,
// with the name prefix.
// with the name prefix. To unmarshal the returned data, use
// the ParsePublicKey function.
Marshal() []byte
// Verify that sig is a signature on the given data using this
@ -363,7 +364,7 @@ func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string {
func (k *dsaPublicKey) Type() string {
return "ssh-dss"
}
@ -481,12 +482,12 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID()
func (k *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + k.nistID()
}
func (key *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize {
func (k *ecdsaPublicKey) nistID() string {
switch k.Params().BitSize {
case 256:
return "nistp256"
case 384:
@ -499,7 +500,7 @@ func (key *ecdsaPublicKey) nistID() string {
type ed25519PublicKey ed25519.PublicKey
func (key ed25519PublicKey) Type() string {
func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519
}
@ -518,23 +519,23 @@ func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
return (ed25519PublicKey)(key), w.Rest, nil
}
func (key ed25519PublicKey) Marshal() []byte {
func (k ed25519PublicKey) Marshal() []byte {
w := struct {
Name string
KeyBytes []byte
}{
KeyAlgoED25519,
[]byte(key),
[]byte(k),
}
return Marshal(&w)
}
func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != key.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
edKey := (ed25519.PublicKey)(key)
edKey := (ed25519.PublicKey)(k)
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok {
return errors.New("ssh: signature did not verify")
}
@ -595,9 +596,9 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
return (*ecdsaPublicKey)(key), w.Rest, nil
}
func (key *ecdsaPublicKey) Marshal() []byte {
func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package.
w := struct {
@ -605,20 +606,20 @@ func (key *ecdsaPublicKey) Marshal() []byte {
ID string
Key []byte
}{
key.Type(),
key.nistID(),
k.Type(),
k.nistID(),
keyBytes,
}
return Marshal(&w)
}
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
h := ecHash(key.Curve).New()
h := ecHash(k.Curve).New()
h.Write(data)
digest := h.Sum(nil)
@ -635,7 +636,7 @@ func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
return err
}
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) {
if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) {
return nil
}
return errors.New("ssh: signature did not verify")
@ -758,7 +759,7 @@ func NewPublicKey(key interface{}) (PublicKey, error) {
return (*rsaPublicKey)(key), nil
case *ecdsa.PublicKey:
if !supportedEllipticCurve(key.Curve) {
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.")
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
}
return (*ecdsaPublicKey)(key), nil
case *dsa.PublicKey:

View file

@ -1,500 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"fmt"
"reflect"
"strings"
"testing"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh/testdata"
)
func rawKey(pub PublicKey) interface{} {
switch k := pub.(type) {
case *rsaPublicKey:
return (*rsa.PublicKey)(k)
case *dsaPublicKey:
return (*dsa.PublicKey)(k)
case *ecdsaPublicKey:
return (*ecdsa.PublicKey)(k)
case ed25519PublicKey:
return (ed25519.PublicKey)(k)
case *Certificate:
return k
}
panic("unknown key type")
}
func TestKeyMarshalParse(t *testing.T) {
for _, priv := range testSigners {
pub := priv.PublicKey()
roundtrip, err := ParsePublicKey(pub.Marshal())
if err != nil {
t.Errorf("ParsePublicKey(%T): %v", pub, err)
}
k1 := rawKey(pub)
k2 := rawKey(roundtrip)
if !reflect.DeepEqual(k1, k2) {
t.Errorf("got %#v in roundtrip, want %#v", k2, k1)
}
}
}
func TestUnsupportedCurves(t *testing.T) {
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") {
t.Fatalf("NewPrivateKey should not succeed with P-224, got: %v", err)
}
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") {
t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err)
}
}
func TestNewPublicKey(t *testing.T) {
for _, k := range testSigners {
raw := rawKey(k.PublicKey())
// Skip certificates, as NewPublicKey does not support them.
if _, ok := raw.(*Certificate); ok {
continue
}
pub, err := NewPublicKey(raw)
if err != nil {
t.Errorf("NewPublicKey(%#v): %v", raw, err)
}
if !reflect.DeepEqual(k.PublicKey(), pub) {
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey())
}
}
}
func TestKeySignVerify(t *testing.T) {
for _, priv := range testSigners {
pub := priv.PublicKey()
data := []byte("sign me")
sig, err := priv.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("Sign(%T): %v", priv, err)
}
if err := pub.Verify(data, sig); err != nil {
t.Errorf("publicKey.Verify(%T): %v", priv, err)
}
sig.Blob[5]++
if err := pub.Verify(data, sig); err == nil {
t.Errorf("publicKey.Verify on broken sig did not fail")
}
}
}
func TestParseRSAPrivateKey(t *testing.T) {
key := testPrivateKeys["rsa"]
rsa, ok := key.(*rsa.PrivateKey)
if !ok {
t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
}
if err := rsa.Validate(); err != nil {
t.Errorf("Validate: %v", err)
}
}
func TestParseECPrivateKey(t *testing.T) {
key := testPrivateKeys["ecdsa"]
ecKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey)
}
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
t.Fatalf("public key does not validate.")
}
}
// See Issue https://github.com/golang/go/issues/6650.
func TestParseEncryptedPrivateKeysFails(t *testing.T) {
const wantSubstring = "encrypted"
for i, tt := range testdata.PEMEncryptedKeys {
_, err := ParsePrivateKey(tt.PEMBytes)
if err == nil {
t.Errorf("#%d key %s: ParsePrivateKey successfully parsed, expected an error", i, tt.Name)
continue
}
if !strings.Contains(err.Error(), wantSubstring) {
t.Errorf("#%d key %s: got error %q, want substring %q", i, tt.Name, err, wantSubstring)
}
}
}
// Parse encrypted private keys with passphrase
func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) {
data := []byte("sign me")
for _, tt := range testdata.PEMEncryptedKeys {
s, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey))
if err != nil {
t.Fatalf("ParsePrivateKeyWithPassphrase returned error: %s", err)
continue
}
sig, err := s.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("dsa.Sign: %v", err)
}
if err := s.PublicKey().Verify(data, sig); err != nil {
t.Errorf("Verify failed: %v", err)
}
}
tt := testdata.PEMEncryptedKeys[0]
_, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte("incorrect"))
if err != x509.IncorrectPasswordError {
t.Fatalf("got %v want IncorrectPasswordError", err)
}
}
func TestParseDSA(t *testing.T) {
// We actually exercise the ParsePrivateKey codepath here, as opposed to
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
// uses.
s, err := ParsePrivateKey(testdata.PEMBytes["dsa"])
if err != nil {
t.Fatalf("ParsePrivateKey returned error: %s", err)
}
data := []byte("sign me")
sig, err := s.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("dsa.Sign: %v", err)
}
if err := s.PublicKey().Verify(data, sig); err != nil {
t.Errorf("Verify failed: %v", err)
}
}
// Tests for authorized_keys parsing.
// getTestKey returns a public key, and its base64 encoding.
func getTestKey() (PublicKey, string) {
k := testPublicKeys["rsa"]
b := &bytes.Buffer{}
e := base64.NewEncoder(base64.StdEncoding, b)
e.Write(k.Marshal())
e.Close()
return k, b.String()
}
func TestMarshalParsePublicKey(t *testing.T) {
pub, pubSerialized := getTestKey()
line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized)
authKeys := MarshalAuthorizedKey(pub)
actualFields := strings.Fields(string(authKeys))
if len(actualFields) == 0 {
t.Fatalf("failed authKeys: %v", authKeys)
}
// drop the comment
expectedFields := strings.Fields(line)[0:2]
if !reflect.DeepEqual(actualFields, expectedFields) {
t.Errorf("got %v, expected %v", actualFields, expectedFields)
}
actPub, _, _, _, err := ParseAuthorizedKey([]byte(line))
if err != nil {
t.Fatalf("cannot parse %v: %v", line, err)
}
if !reflect.DeepEqual(actPub, pub) {
t.Errorf("got %v, expected %v", actPub, pub)
}
}
type authResult struct {
pubKey PublicKey
options []string
comments string
rest string
ok bool
}
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
rest := authKeys
var values []authResult
for len(rest) > 0 {
var r authResult
var err error
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
r.ok = (err == nil)
t.Log(err)
r.rest = string(rest)
values = append(values, r)
}
if !reflect.DeepEqual(values, expected) {
t.Errorf("got %#v, expected %#v", values, expected)
}
}
func TestAuthorizedKeyBasic(t *testing.T) {
pub, pubSerialized := getTestKey()
line := "ssh-rsa " + pubSerialized + " user@host"
testAuthorizedKeys(t, []byte(line),
[]authResult{
{pub, nil, "user@host", "", true},
})
}
func TestAuth(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithOptions := []string{
`# comments to ignore before any keys...`,
``,
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`,
`# comments to ignore, along with a blank line`,
``,
`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`,
``,
`# more comments, plus a invalid entry`,
`ssh-rsa data-that-will-not-parse user@host3`,
}
for _, eol := range []string{"\n", "\r\n"} {
authOptions := strings.Join(authWithOptions, eol)
rest2 := strings.Join(authWithOptions[3:], eol)
rest3 := strings.Join(authWithOptions[6:], eol)
testAuthorizedKeys(t, []byte(authOptions), []authResult{
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
{nil, nil, "", "", false},
})
}
}
func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedCommaInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`)
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
})
}
func TestAuthWithInvalidSpace(t *testing.T) {
_, pubSerialized := getTestKey()
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
#more to follow but still no valid keys`)
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
{nil, nil, "", "", false},
})
}
func TestAuthWithMissingQuote(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
})
}
func TestInvalidEntry(t *testing.T) {
authInvalid := []byte(`ssh-rsa`)
_, _, _, _, err := ParseAuthorizedKey(authInvalid)
if err == nil {
t.Errorf("got valid entry for %q", authInvalid)
}
}
var knownHostsParseTests = []struct {
input string
err string
marker string
comment string
hosts []string
rest string
}{
{
"",
"EOF",
"", "", nil, "",
},
{
"# Just a comment",
"EOF",
"", "", nil, "",
},
{
" \t ",
"EOF",
"", "", nil, "",
},
{
"localhost ssh-rsa {RSAPUB}",
"",
"", "", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}",
"",
"", "", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\n",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line",
"",
"", "comment comment", []string{"localhost"}, "next line",
},
{
"localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment",
"",
"", "comment comment", []string{"localhost", "[host2:123]"}, "",
},
{
"@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}",
"",
"marker", "", []string{"localhost", "[host2:123]"}, "",
},
{
"@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd",
"short read",
"", "", nil, "",
},
}
func TestKnownHostsParsing(t *testing.T) {
rsaPub, rsaPubSerialized := getTestKey()
for i, test := range knownHostsParseTests {
var expectedKey PublicKey
const rsaKeyToken = "{RSAPUB}"
input := test.input
if strings.Contains(input, rsaKeyToken) {
expectedKey = rsaPub
input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1)
}
marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input))
if err != nil {
if len(test.err) == 0 {
t.Errorf("#%d: unexpectedly failed with %q", i, err)
} else if !strings.Contains(err.Error(), test.err) {
t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err)
}
continue
} else if len(test.err) != 0 {
t.Errorf("#%d: succeeded but expected error including %q", i, test.err)
continue
}
if !reflect.DeepEqual(expectedKey, pubKey) {
t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey)
}
if marker != test.marker {
t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker)
}
if comment != test.comment {
t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment)
}
if !reflect.DeepEqual(test.hosts, hosts) {
t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts)
}
if rest := string(rest); rest != test.rest {
t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest)
}
}
}
func TestFingerprintLegacyMD5(t *testing.T) {
pub, _ := getTestKey()
fingerprint := FingerprintLegacyMD5(pub)
want := "fb:61:6d:1a:e3:f0:95:45:3c:a0:79:be:4a:93:63:66" // ssh-keygen -lf -E md5 rsa
if fingerprint != want {
t.Errorf("got fingerprint %q want %q", fingerprint, want)
}
}
func TestFingerprintSHA256(t *testing.T) {
pub, _ := getTestKey()
fingerprint := FingerprintSHA256(pub)
want := "SHA256:Anr3LjZK8YVpjrxu79myrW9Hrb/wpcMNpVvTq/RcBm8" // ssh-keygen -lf rsa
if fingerprint != want {
t.Errorf("got fingerprint %q want %q", fingerprint, want)
}
}

View file

@ -1,110 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"sync"
"testing"
)
// An in-memory packetConn. It is safe to call Close and writePacket
// from different goroutines.
type memTransport struct {
eof bool
pending [][]byte
write *memTransport
sync.Mutex
*sync.Cond
}
func (t *memTransport) readPacket() ([]byte, error) {
t.Lock()
defer t.Unlock()
for {
if len(t.pending) > 0 {
r := t.pending[0]
t.pending = t.pending[1:]
return r, nil
}
if t.eof {
return nil, io.EOF
}
t.Cond.Wait()
}
}
func (t *memTransport) closeSelf() error {
t.Lock()
defer t.Unlock()
if t.eof {
return io.EOF
}
t.eof = true
t.Cond.Broadcast()
return nil
}
func (t *memTransport) Close() error {
err := t.write.closeSelf()
t.closeSelf()
return err
}
func (t *memTransport) writePacket(p []byte) error {
t.write.Lock()
defer t.write.Unlock()
if t.write.eof {
return io.EOF
}
c := make([]byte, len(p))
copy(c, p)
t.write.pending = append(t.write.pending, c)
t.write.Cond.Signal()
return nil
}
func memPipe() (a, b packetConn) {
t1 := memTransport{}
t2 := memTransport{}
t1.write = &t2
t2.write = &t1
t1.Cond = sync.NewCond(&t1.Mutex)
t2.Cond = sync.NewCond(&t2.Mutex)
return &t1, &t2
}
func TestMemPipe(t *testing.T) {
a, b := memPipe()
if err := a.writePacket([]byte{42}); err != nil {
t.Fatalf("writePacket: %v", err)
}
if err := a.Close(); err != nil {
t.Fatal("Close: ", err)
}
p, err := b.readPacket()
if err != nil {
t.Fatal("readPacket: ", err)
}
if len(p) != 1 || p[0] != 42 {
t.Fatalf("got %v, want {42}", p)
}
p, err = b.readPacket()
if err != io.EOF {
t.Fatalf("got %v, %v, want EOF", p, err)
}
}
func TestDoubleClose(t *testing.T) {
a, _ := memPipe()
err := a.Close()
if err != nil {
t.Errorf("Close: %v", err)
}
err = a.Close()
if err != io.EOF {
t.Errorf("expect EOF on double close.")
}
}

View file

@ -23,10 +23,6 @@ const (
msgUnimplemented = 3
msgDebug = 4
msgNewKeys = 21
// Standard authentication messages
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
)
// SSH messages:
@ -137,6 +133,18 @@ type userAuthFailureMsg struct {
PartialSuccess bool
}
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52
// See RFC 4252, section 5.4
const msgUserAuthBanner = 53
type userAuthBannerMsg struct {
Message string `sshtype:"53"`
// unused, but required to allow message parsing
Language string
}
// See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61
@ -154,7 +162,7 @@ const msgChannelOpen = 90
type channelOpenMsg struct {
ChanType string `sshtype:"90"`
PeersId uint32
PeersID uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
@ -165,7 +173,7 @@ const msgChannelData = 94
// Used for debug print outs of packets.
type channelDataMsg struct {
PeersId uint32 `sshtype:"94"`
PeersID uint32 `sshtype:"94"`
Length uint32
Rest []byte `ssh:"rest"`
}
@ -174,8 +182,8 @@ type channelDataMsg struct {
const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"`
MyId uint32
PeersID uint32 `sshtype:"91"`
MyID uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
@ -185,7 +193,7 @@ type channelOpenConfirmMsg struct {
const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"`
PeersID uint32 `sshtype:"92"`
Reason RejectionReason
Message string
Language string
@ -194,7 +202,7 @@ type channelOpenFailureMsg struct {
const msgChannelRequest = 98
type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"`
PeersID uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
@ -204,28 +212,28 @@ type channelRequestMsg struct {
const msgChannelSuccess = 99
type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"`
PeersID uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
const msgChannelFailure = 100
type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"`
PeersID uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
const msgChannelClose = 97
type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"`
PeersID uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
const msgChannelEOF = 96
type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"`
PeersID uint32 `sshtype:"96"`
}
// See RFC 4254, section 4
@ -255,7 +263,7 @@ type globalRequestFailureMsg struct {
const msgChannelWindowAdjust = 93
type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"`
PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32
}

View file

@ -1,288 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"math/big"
"math/rand"
"reflect"
"testing"
"testing/quick"
)
var intLengthTests = []struct {
val, length int
}{
{0, 4 + 0},
{1, 4 + 1},
{127, 4 + 1},
{128, 4 + 2},
{-1, 4 + 1},
}
func TestIntLength(t *testing.T) {
for _, test := range intLengthTests {
v := new(big.Int).SetInt64(int64(test.val))
length := intLength(v)
if length != test.length {
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
}
}
}
type msgAllTypes struct {
Bool bool `sshtype:"21"`
Array [16]byte
Uint64 uint64
Uint32 uint32
Uint8 uint8
String string
Strings []string
Bytes []byte
Int *big.Int
Rest []byte `ssh:"rest"`
}
func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
m := &msgAllTypes{}
m.Bool = rand.Intn(2) == 1
randomBytes(m.Array[:], rand)
m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
m.Uint8 = uint8(rand.Intn(1 << 8))
m.String = string(m.Array[:])
m.Strings = randomNameList(rand)
m.Bytes = m.Array[:]
m.Int = randomInt(rand)
m.Rest = m.Array[:]
return reflect.ValueOf(m)
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
iface := &msgAllTypes{}
ty := reflect.ValueOf(iface).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("failed to create value")
break
}
m1 := v.Elem().Interface()
m2 := iface
marshaled := Marshal(m1)
if err := Unmarshal(marshaled, m2); err != nil {
t.Errorf("Unmarshal %#v: %s", m1, err)
break
}
if !reflect.DeepEqual(v.Interface(), m2) {
t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
break
}
}
}
func TestUnmarshalEmptyPacket(t *testing.T) {
var b []byte
var m channelRequestSuccessMsg
if err := Unmarshal(b, &m); err == nil {
t.Fatalf("unmarshal of empty slice succeeded")
}
}
func TestUnmarshalUnexpectedPacket(t *testing.T) {
type S struct {
I uint32 `sshtype:"43"`
S string
B bool
}
s := S{11, "hello", true}
packet := Marshal(s)
packet[0] = 42
roundtrip := S{}
err := Unmarshal(packet, &roundtrip)
if err == nil {
t.Fatal("expected error, not nil")
}
}
func TestMarshalPtr(t *testing.T) {
s := struct {
S string
}{"hello"}
m1 := Marshal(s)
m2 := Marshal(&s)
if !bytes.Equal(m1, m2) {
t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
}
}
func TestBareMarshalUnmarshal(t *testing.T) {
type S struct {
I uint32
S string
B bool
}
s := S{42, "hello", true}
packet := Marshal(s)
roundtrip := S{}
Unmarshal(packet, &roundtrip)
if !reflect.DeepEqual(s, roundtrip) {
t.Errorf("got %#v, want %#v", roundtrip, s)
}
}
func TestBareMarshal(t *testing.T) {
type S2 struct {
I uint32
}
s := S2{42}
packet := Marshal(s)
i, rest, ok := parseUint32(packet)
if len(rest) > 0 || !ok {
t.Errorf("parseInt(%q): parse error", packet)
}
if i != s.I {
t.Errorf("got %d, want %d", i, s.I)
}
}
func TestUnmarshalShortKexInitPacket(t *testing.T) {
// This used to panic.
// Issue 11348
packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff}
kim := &kexInitMsg{}
if err := Unmarshal(packet, kim); err == nil {
t.Error("truncated packet unmarshaled without error")
}
}
func TestMarshalMultiTag(t *testing.T) {
var res struct {
A uint32 `sshtype:"1|2"`
}
good1 := struct {
A uint32 `sshtype:"1"`
}{
1,
}
good2 := struct {
A uint32 `sshtype:"2"`
}{
1,
}
if e := Unmarshal(Marshal(good1), &res); e != nil {
t.Errorf("error unmarshaling multipart tag: %v", e)
}
if e := Unmarshal(Marshal(good2), &res); e != nil {
t.Errorf("error unmarshaling multipart tag: %v", e)
}
bad1 := struct {
A uint32 `sshtype:"3"`
}{
1,
}
if e := Unmarshal(Marshal(bad1), &res); e == nil {
t.Errorf("bad struct unmarshaled without error")
}
}
func randomBytes(out []byte, rand *rand.Rand) {
for i := 0; i < len(out); i++ {
out[i] = byte(rand.Int31())
}
}
func randomNameList(rand *rand.Rand) []string {
ret := make([]string, rand.Int31()&15)
for i := range ret {
s := make([]byte, 1+(rand.Int31()&15))
for j := range s {
s[j] = 'a' + uint8(rand.Int31()&15)
}
ret[i] = string(s)
}
return ret
}
func randomInt(rand *rand.Rand) *big.Int {
return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
}
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
ki := &kexInitMsg{}
randomBytes(ki.Cookie[:], rand)
ki.KexAlgos = randomNameList(rand)
ki.ServerHostKeyAlgos = randomNameList(rand)
ki.CiphersClientServer = randomNameList(rand)
ki.CiphersServerClient = randomNameList(rand)
ki.MACsClientServer = randomNameList(rand)
ki.MACsServerClient = randomNameList(rand)
ki.CompressionClientServer = randomNameList(rand)
ki.CompressionServerClient = randomNameList(rand)
ki.LanguagesClientServer = randomNameList(rand)
ki.LanguagesServerClient = randomNameList(rand)
if rand.Int31()&1 == 1 {
ki.FirstKexFollows = true
}
return reflect.ValueOf(ki)
}
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
dhi := &kexDHInitMsg{}
dhi.X = randomInt(rand)
return reflect.ValueOf(dhi)
}
var (
_kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
_kexInit = Marshal(_kexInitMsg)
_kexDHInit = Marshal(_kexDHInitMsg)
)
func BenchmarkMarshalKexInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
Marshal(_kexInitMsg)
}
}
func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
m := new(kexInitMsg)
for i := 0; i < b.N; i++ {
Unmarshal(_kexInit, m)
}
}
func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
Marshal(_kexDHInitMsg)
}
}
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
m := new(kexDHInitMsg)
for i := 0; i < b.N; i++ {
Unmarshal(_kexDHInit, m)
}
}

View file

@ -278,7 +278,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId,
PeersID: msg.PeersID,
Reason: ConnectionFailed,
Message: "invalid request",
Language: "en_US.UTF-8",
@ -287,7 +287,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
}
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId
c.remoteId = msg.PeersID
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c
@ -313,7 +313,7 @@ func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra,
PeersId: ch.localId,
PeersID: ch.localId,
}
if err := m.sendMessage(open); err != nil {
return nil, err

View file

@ -1,505 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"io/ioutil"
"sync"
"testing"
)
func muxPair() (*mux, *mux) {
a, b := memPipe()
s := newMux(a)
c := newMux(b)
return s, c
}
// Returns both ends of a channel, and the mux for the the 2nd
// channel.
func channelPair(t *testing.T) (*channel, *channel, *mux) {
c, s := muxPair()
res := make(chan *channel, 1)
go func() {
newCh, ok := <-s.incomingChannels
if !ok {
t.Fatalf("No incoming channel")
}
if newCh.ChannelType() != "chan" {
t.Fatalf("got type %q want chan", newCh.ChannelType())
}
ch, _, err := newCh.Accept()
if err != nil {
t.Fatalf("Accept %v", err)
}
res <- ch.(*channel)
}()
ch, err := c.openChannel("chan", nil)
if err != nil {
t.Fatalf("OpenChannel: %v", err)
}
return <-res, ch, c
}
// Test that stderr and stdout can be addressed from different
// goroutines. This is intended for use with the race detector.
func TestMuxChannelExtendedThreadSafety(t *testing.T) {
writer, reader, mux := channelPair(t)
defer writer.Close()
defer reader.Close()
defer mux.Close()
var wr, rd sync.WaitGroup
magic := "hello world"
wr.Add(2)
go func() {
io.WriteString(writer, magic)
wr.Done()
}()
go func() {
io.WriteString(writer.Stderr(), magic)
wr.Done()
}()
rd.Add(2)
go func() {
c, err := ioutil.ReadAll(reader)
if string(c) != magic {
t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
go func() {
c, err := ioutil.ReadAll(reader.Stderr())
if string(c) != magic {
t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
wr.Wait()
writer.CloseWrite()
rd.Wait()
}
func TestMuxReadWrite(t *testing.T) {
s, c, mux := channelPair(t)
defer s.Close()
defer c.Close()
defer mux.Close()
magic := "hello world"
magicExt := "hello stderr"
go func() {
_, err := s.Write([]byte(magic))
if err != nil {
t.Fatalf("Write: %v", err)
}
_, err = s.Extended(1).Write([]byte(magicExt))
if err != nil {
t.Fatalf("Write: %v", err)
}
err = s.Close()
if err != nil {
t.Fatalf("Close: %v", err)
}
}()
var buf [1024]byte
n, err := c.Read(buf[:])
if err != nil {
t.Fatalf("server Read: %v", err)
}
got := string(buf[:n])
if got != magic {
t.Fatalf("server: got %q want %q", got, magic)
}
n, err = c.Extended(1).Read(buf[:])
if err != nil {
t.Fatalf("server Read: %v", err)
}
got = string(buf[:n])
if got != magicExt {
t.Fatalf("server: got %q want %q", got, magic)
}
}
func TestMuxChannelOverflow(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
writer.Write(make([]byte, 1))
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
// Send 1 byte.
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
marshalUint32(packet[1:], writer.remoteId)
marshalUint32(packet[5:], uint32(1))
packet[9] = 42
if err := writer.mux.conn.writePacket(packet); err != nil {
t.Errorf("could not send packet")
}
if _, err := reader.SendRequest("hello", true, nil); err == nil {
t.Errorf("SendRequest succeeded.")
}
<-wDone
}
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
reader.Close()
<-wDone
}
func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
mux.Close()
<-wDone
}
func TestMuxReject(t *testing.T) {
client, server := muxPair()
defer server.Close()
defer client.Close()
go func() {
ch, ok := <-server.incomingChannels
if !ok {
t.Fatalf("Accept")
}
if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
}
ch.Reject(RejectionReason(42), "message")
}()
ch, err := client.openChannel("ch", []byte("extra"))
if ch != nil {
t.Fatal("openChannel not rejected")
}
ocf, ok := err.(*OpenChannelError)
if !ok {
t.Errorf("got %#v want *OpenChannelError", err)
} else if ocf.Reason != 42 || ocf.Message != "message" {
t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
}
want := "ssh: rejected: unknown reason 42 (message)"
if err.Error() != want {
t.Errorf("got %q, want %q", err.Error(), want)
}
}
func TestMuxChannelRequest(t *testing.T) {
client, server, mux := channelPair(t)
defer server.Close()
defer client.Close()
defer mux.Close()
var received int
var wg sync.WaitGroup
wg.Add(1)
go func() {
for r := range server.incomingRequests {
received++
r.Reply(r.Type == "yes", nil)
}
wg.Done()
}()
_, err := client.SendRequest("yes", false, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
ok, err := client.SendRequest("yes", true, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
if !ok {
t.Errorf("SendRequest(yes): %v", ok)
}
ok, err = client.SendRequest("no", true, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
if ok {
t.Errorf("SendRequest(no): %v", ok)
}
client.Close()
wg.Wait()
if received != 3 {
t.Errorf("got %d requests, want %d", received, 3)
}
}
func TestMuxGlobalRequest(t *testing.T) {
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
var seen bool
go func() {
for r := range serverMux.incomingRequests {
seen = seen || r.Type == "peek"
if r.WantReply {
err := r.Reply(r.Type == "yes",
append([]byte(r.Type), r.Payload...))
if err != nil {
t.Errorf("AckRequest: %v", err)
}
}
}
}()
_, _, err := clientMux.SendRequest("peek", false, nil)
if err != nil {
t.Errorf("SendRequest: %v", err)
}
ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
if !ok || string(data) != "yesa" || err != nil {
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
ok, data, err)
}
if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
ok, data, err)
}
if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
ok, data, err)
}
if !seen {
t.Errorf("never saw 'peek' request")
}
}
func TestMuxGlobalRequestUnblock(t *testing.T) {
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
result := make(chan error, 1)
go func() {
_, _, err := clientMux.SendRequest("hello", true, nil)
result <- err
}()
<-serverMux.incomingRequests
serverMux.conn.Close()
err := <-result
if err != io.EOF {
t.Errorf("want EOF, got %v", io.EOF)
}
}
func TestMuxChannelRequestUnblock(t *testing.T) {
a, b, connB := channelPair(t)
defer a.Close()
defer b.Close()
defer connB.Close()
result := make(chan error, 1)
go func() {
_, err := a.SendRequest("hello", true, nil)
result <- err
}()
<-b.incomingRequests
connB.conn.Close()
err := <-result
if err != io.EOF {
t.Errorf("want EOF, got %v", err)
}
}
func TestMuxCloseChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
defer r.Close()
defer w.Close()
result := make(chan error, 1)
go func() {
var b [1024]byte
_, err := r.Read(b[:])
result <- err
}()
if err := w.Close(); err != nil {
t.Errorf("w.Close: %v", err)
}
if _, err := w.Write([]byte("hello")); err != io.EOF {
t.Errorf("got err %v, want io.EOF after Close", err)
}
if err := <-result; err != io.EOF {
t.Errorf("got %v (%T), want io.EOF", err, err)
}
}
func TestMuxCloseWriteChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
result := make(chan error, 1)
go func() {
var b [1024]byte
_, err := r.Read(b[:])
result <- err
}()
if err := w.CloseWrite(); err != nil {
t.Errorf("w.CloseWrite: %v", err)
}
if _, err := w.Write([]byte("hello")); err != io.EOF {
t.Errorf("got err %v, want io.EOF after CloseWrite", err)
}
if err := <-result; err != io.EOF {
t.Errorf("got %v (%T), want io.EOF", err, err)
}
}
func TestMuxInvalidRecord(t *testing.T) {
a, b := muxPair()
defer a.Close()
defer b.Close()
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
marshalUint32(packet[1:], 29348723 /* invalid channel id */)
marshalUint32(packet[5:], 1)
packet[9] = 42
a.conn.writePacket(packet)
go a.SendRequest("hello", false, nil)
// 'a' wrote an invalid packet, so 'b' has exited.
req, ok := <-b.incomingRequests
if ok {
t.Errorf("got request %#v after receiving invalid packet", req)
}
}
func TestZeroWindowAdjust(t *testing.T) {
a, b, mux := channelPair(t)
defer a.Close()
defer b.Close()
defer mux.Close()
go func() {
io.WriteString(a, "hello")
// bogus adjust.
a.sendMessage(windowAdjustMsg{})
io.WriteString(a, "world")
a.Close()
}()
want := "helloworld"
c, _ := ioutil.ReadAll(b)
if string(c) != want {
t.Errorf("got %q want %q", c, want)
}
}
func TestMuxMaxPacketSize(t *testing.T) {
a, b, mux := channelPair(t)
defer a.Close()
defer b.Close()
defer mux.Close()
large := make([]byte, a.maxRemotePayload+1)
packet := make([]byte, 1+4+4+1+len(large))
packet[0] = msgChannelData
marshalUint32(packet[1:], a.remoteId)
marshalUint32(packet[5:], uint32(len(large)))
packet[9] = 42
if err := a.mux.conn.writePacket(packet); err != nil {
t.Errorf("could not send packet")
}
go a.SendRequest("hello", false, nil)
_, ok := <-b.incomingRequests
if ok {
t.Errorf("connection still alive after receiving large packet.")
}
}
// Don't ship code with debug=true.
func TestDebug(t *testing.T) {
if debugMux {
t.Error("mux debug switched on")
}
if debugHandshake {
t.Error("handshake debug switched on")
}
if debugTransport {
t.Error("transport debug switched on")
}
}

View file

@ -95,6 +95,10 @@ type ServerConfig struct {
// Note that RFC 4253 section 4.2 requires that this string start with
// "SSH-2.0-".
ServerVersion string
// BannerCallback, if present, is called and the return string is sent to
// the client after key exchange completed but before authentication.
BannerCallback func(conn ConnMetadata) string
}
// AddHostKey adds a private key as a host key. If an existing host
@ -162,6 +166,9 @@ type ServerConn struct {
// unsuccessful, it closes the connection and returns an error. The
// Request and NewChannel channels must be serviced, or the connection
// will hang.
//
// The returned error may be of type *ServerAuthError for
// authentication errors.
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
@ -252,7 +259,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01:
return true
}
return false
@ -288,12 +295,13 @@ func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
// ServerAuthError implements the error interface. It appends any authentication
// errors that may occur, and is returned if all of the authentication methods
// provided by the user failed to authenticate.
// ServerAuthError represents server authentication errors and is
// sometimes returned by NewServerConn. It appends any authentication
// errors that may occur, and is returned if all of the authentication
// methods provided by the user failed to authenticate.
type ServerAuthError struct {
// Errors contains authentication errors returned by the authentication
// callback methods.
// callback methods. The first entry is typically ErrNoAuth.
Errors []error
}
@ -305,6 +313,13 @@ func (l ServerAuthError) Error() string {
return "[" + strings.Join(errs, ", ") + "]"
}
// ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries
// 'none' authentication to discover available methods.
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID()
var cache pubKeyCache
@ -312,6 +327,7 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
authFailures := 0
var authErrs []error
var displayedBanner bool
userAuthLoop:
for {
@ -343,8 +359,22 @@ userAuthLoop:
}
s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
perms = nil
authErr := errors.New("no auth passed yet")
authErr := ErrNoAuth
switch userAuthReq.Method {
case "none":

View file

@ -406,7 +406,7 @@ func (s *Session) Wait() error {
s.stdinPipeWriter.Close()
}
var copyError error
for _ = range s.copyFuncs {
for range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil {
copyError = err
}

View file

@ -1,774 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session tests.
import (
"bytes"
crypto_rand "crypto/rand"
"errors"
"io"
"io/ioutil"
"math/rand"
"net"
"testing"
"golang.org/x/crypto/ssh/terminal"
)
type serverType func(Channel, <-chan *Request, *testing.T)
// dial constructs a new test server and returns a *ClientConn.
func dial(handler serverType, t *testing.T) *Client {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
go func() {
defer c1.Close()
conf := ServerConfig{
NoClientAuth: true,
}
conf.AddHostKey(testSigners["rsa"])
_, chans, reqs, err := NewServerConn(c1, &conf)
if err != nil {
t.Fatalf("Unable to handshake: %v", err)
}
go DiscardRequests(reqs)
for newCh := range chans {
if newCh.ChannelType() != "session" {
newCh.Reject(UnknownChannelType, "unknown channel type")
continue
}
ch, inReqs, err := newCh.Accept()
if err != nil {
t.Errorf("Accept: %v", err)
continue
}
go func() {
handler(ch, inReqs, t)
}()
}
}()
config := &ClientConfig{
User: "testuser",
HostKeyCallback: InsecureIgnoreHostKey(),
}
conn, chans, reqs, err := NewClientConn(c2, "", config)
if err != nil {
t.Fatalf("unable to dial remote side: %v", err)
}
return NewClient(conn, chans, reqs)
}
// Test a simple string is returned to session.Stdout.
func TestSessionShell(t *testing.T) {
conn := dial(shellHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout := new(bytes.Buffer)
session.Stdout = stdout
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %v", err)
}
actual := stdout.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
// Test a simple string is returned via StdoutPipe.
func TestSessionStdoutPipe(t *testing.T) {
conn := dial(shellHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("Unable to request StdoutPipe(): %v", err)
}
var buf bytes.Buffer
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
done := make(chan bool, 1)
go func() {
if _, err := io.Copy(&buf, stdout); err != nil {
t.Errorf("Copy of stdout failed: %v", err)
}
done <- true
}()
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %v", err)
}
<-done
actual := buf.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// Test that a simple string is returned via the Output helper,
// and that stderr is discarded.
func TestSessionOutput(t *testing.T) {
conn := dial(fixedOutputHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
if err != nil {
t.Error("Remote command did not exit cleanly:", err)
}
w := "this-is-stdout."
g := string(buf)
if g != w {
t.Error("Remote command did not return expected string:")
t.Logf("want %q", w)
t.Logf("got %q", g)
}
}
// Test that both stdout and stderr are returned
// via the CombinedOutput helper.
func TestSessionCombinedOutput(t *testing.T) {
conn := dial(fixedOutputHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
if err != nil {
t.Error("Remote command did not exit cleanly:", err)
}
const stdout = "this-is-stdout."
const stderr = "this-is-stderr."
g := string(buf)
if g != stdout+stderr && g != stderr+stdout {
t.Error("Remote command did not return expected string:")
t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
t.Logf("got %q", g)
}
}
// Test non-0 exit status is returned correctly.
func TestExitStatusNonZero(t *testing.T) {
conn := dial(exitStatusNonZeroHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.ExitStatus() != 15 {
t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
}
}
// Test 0 exit status is returned correctly.
func TestExitStatusZero(t *testing.T) {
conn := dial(exitStatusZeroHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
t.Fatalf("expected nil but got %v", err)
}
}
// Test exit signal and status are both returned correctly.
func TestExitSignalAndStatus(t *testing.T) {
conn := dial(exitSignalAndStatusHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "TERM" || e.ExitStatus() != 15 {
t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
// Test exit signal and status are both returned correctly.
func TestKnownExitSignalOnly(t *testing.T) {
conn := dial(exitSignalHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "TERM" || e.ExitStatus() != 143 {
t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
// Test exit signal and status are both returned correctly.
func TestUnknownExitSignal(t *testing.T) {
conn := dial(exitSignalUnknownHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "SYS" || e.ExitStatus() != 128 {
t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
func TestExitWithoutStatusOrSignal(t *testing.T) {
conn := dial(exitWithoutSignalOrStatus, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
if _, ok := err.(*ExitMissingError); !ok {
t.Fatalf("got %T want *ExitMissingError", err)
}
}
// windowTestBytes is the number of bytes that we'll send to the SSH server.
const windowTestBytes = 16000 * 200
// TestServerWindow writes random data to the server. The server is expected to echo
// the same data back, which is compared against the original.
func TestServerWindow(t *testing.T) {
origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
origBytes := origBuf.Bytes()
conn := dial(echoHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
result := make(chan []byte)
go func() {
defer close(result)
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
serverStdout, err := session.StdoutPipe()
if err != nil {
t.Errorf("StdoutPipe failed: %v", err)
return
}
n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
if err != nil && err != io.EOF {
t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
}
result <- echoedBuf.Bytes()
}()
serverStdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
if err != nil {
t.Fatalf("failed to copy origBuf to serverStdin: %v", err)
}
if written != windowTestBytes {
t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
}
echoedBytes := <-result
if !bytes.Equal(origBytes, echoedBytes) {
t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
}
}
// Verify the client can handle a keepalive packet from the server.
func TestClientHandlesKeepalives(t *testing.T) {
conn := dial(channelKeepaliveSender, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
t.Fatalf("expected nil but got: %v", err)
}
}
type exitStatusMsg struct {
Status uint32
}
type exitSignalMsg struct {
Signal string
CoreDumped bool
Errmsg string
Lang string
}
func handleTerminalRequests(in <-chan *Request) {
for req := range in {
ok := false
switch req.Type {
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
req.Reply(ok, nil)
}
}
func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
term := terminal.NewTerminal(ch, prompt)
go handleTerminalRequests(in)
return term
}
func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(0, ch, t)
}
func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
}
func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
sendSignal("TERM", ch, t)
}
func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("TERM", ch, t)
}
func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("SYS", ch, t)
}
func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
}
func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, in, "golang")
readLine(shell, t)
sendStatus(0, ch, t)
}
// Ignores the command, writes fixed strings to stderr and stdout.
// Strings are "this-is-stdout." and "this-is-stderr.".
func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
_, err := ch.Read(nil)
req, ok := <-in
if !ok {
t.Fatalf("error: expected channel request, got: %#v", err)
return
}
// ignore request, always send some text
req.Reply(true, nil)
_, err = io.WriteString(ch, "this-is-stdout.")
if err != nil {
t.Fatalf("error writing on server: %v", err)
}
_, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
if err != nil {
t.Fatalf("error writing on server: %v", err)
}
sendStatus(0, ch, t)
}
func readLine(shell *terminal.Terminal, t *testing.T) {
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
t.Errorf("unable to read line: %v", err)
}
}
func sendStatus(status uint32, ch Channel, t *testing.T) {
msg := exitStatusMsg{
Status: status,
}
if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
t.Errorf("unable to send status: %v", err)
}
}
func sendSignal(signal string, ch Channel, t *testing.T) {
sig := exitSignalMsg{
Signal: signal,
CoreDumped: false,
Errmsg: "Process terminated",
Lang: "en-GB-oed",
}
if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
t.Errorf("unable to send signal: %v", err)
}
}
func discardHandler(ch Channel, t *testing.T) {
defer ch.Close()
io.Copy(ioutil.Discard, ch)
}
func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
}
}
// copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
// buffer size to exercise more code paths.
func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
var (
buf = make([]byte, 32*1024)
written int
remaining = n
)
for remaining > 0 {
l := rand.Intn(1 << 15)
if remaining < l {
l = remaining
}
nr, er := src.Read(buf[:l])
nw, ew := dst.Write(buf[:nr])
remaining -= nw
written += nw
if ew != nil {
return written, ew
}
if nr != nw {
return written, io.ErrShortWrite
}
if er != nil && er != io.EOF {
return written, er
}
}
return written, nil
}
func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
t.Errorf("unable to send channel keepalive request: %v", err)
}
sendStatus(0, ch, t)
}
func TestClientWriteEOF(t *testing.T) {
conn := dial(simpleEchoHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("StdoutPipe failed: %v", err)
}
data := []byte(`0000`)
_, err = stdin.Write(data)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
stdin.Close()
res, err := ioutil.ReadAll(stdout)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if !bytes.Equal(data, res) {
t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
}
}
func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
data, err := ioutil.ReadAll(ch)
if err != nil {
t.Errorf("handler read error: %v", err)
}
_, err = ch.Write(data)
if err != nil {
t.Errorf("handler write error: %v", err)
}
}
func TestSessionID(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverID := make(chan []byte, 1)
clientID := make(chan []byte, 1)
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["ecdsa"])
clientConf := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
User: "user",
}
go func() {
conn, chans, reqs, err := NewServerConn(c1, serverConf)
if err != nil {
t.Fatalf("server handshake: %v", err)
}
serverID <- conn.SessionID()
go DiscardRequests(reqs)
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()
go func() {
conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("client handshake: %v", err)
}
clientID <- conn.SessionID()
go DiscardRequests(reqs)
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()
s := <-serverID
c := <-clientID
if bytes.Compare(s, c) != 0 {
t.Errorf("server session ID (%x) != client session ID (%x)", s, c)
} else if len(s) == 0 {
t.Errorf("client and server SessionID were empty.")
}
}
type noReadConn struct {
readSeen bool
net.Conn
}
func (c *noReadConn) Close() error {
return nil
}
func (c *noReadConn) Read(b []byte) (int, error) {
c.readSeen = true
return 0, errors.New("noReadConn error")
}
func TestInvalidServerConfiguration(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serveConn := noReadConn{Conn: c1}
serverConf := &ServerConfig{}
NewServerConn(&serveConn, serverConf)
if serveConn.readSeen {
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key")
}
serverConf.AddHostKey(testSigners["ecdsa"])
NewServerConn(&serveConn, serverConf)
if serveConn.readSeen {
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method")
}
}
func TestHostKeyAlgorithms(t *testing.T) {
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
serverConf.AddHostKey(testSigners["ecdsa"])
connect := func(clientConf *ClientConfig, want string) {
var alg string
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
alg = key.Type()
return nil
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
_, _, _, err = NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}
if alg != want {
t.Errorf("selected key algorithm %s, want %s", alg, want)
}
}
// By default, we get the preferred algorithm, which is ECDSA 256.
clientConf := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
}
connect(clientConf, KeyAlgoECDSA256)
// Client asks for RSA explicitly.
clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA}
connect(clientConf, KeyAlgoRSA)
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
_, _, _, err = NewClientConn(c2, "", clientConf)
if err == nil {
t.Fatal("succeeded connecting with unknown hostkey algorithm")
}
}

View file

@ -32,6 +32,7 @@ type streamLocalChannelForwardMsg struct {
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
m := streamLocalChannelForwardMsg{
socketPath,
}

View file

@ -90,10 +90,19 @@ type channelForwardMsg struct {
rport uint32
}
// handleForwards starts goroutines handling forwarded connections.
// It's called on first use by (*Client).ListenTCP to not launch
// goroutines until needed.
func (c *Client) handleForwards() {
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
}
// ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener.
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
return c.autoPortListenWorkaround(laddr)
}

View file

@ -1,20 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"testing"
)
func TestAutoPortListenBroken(t *testing.T) {
broken := "SSH-2.0-OpenSSH_5.9hh11"
works := "SSH-2.0-OpenSSH_6.1"
if !isBrokenOpenSSHVersion(broken) {
t.Errorf("version %q not marked as broken", broken)
}
if isBrokenOpenSSHVersion(works) {
t.Errorf("version %q marked as broken", works)
}
}

View file

@ -617,7 +617,7 @@ func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
if _, err = w.Write(crlf); err != nil {
return n, err
}
n += 1
n++
buf = buf[1:]
}
}

View file

@ -1,350 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import (
"bytes"
"io"
"os"
"testing"
)
type MockTerminal struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockTerminal) Read(data []byte) (n int, err error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, io.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockTerminal) Write(data []byte) (n int, err error) {
c.received = append(c.received, data...)
return len(data), nil
}
func TestClose(t *testing.T) {
c := &MockTerminal{}
ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
}
if err != io.EOF {
t.Errorf("Error should have been EOF but got: %s", err)
}
}
var keyPressTests = []struct {
in string
line string
err error
throwAwayLines int
}{
{
err: io.EOF,
},
{
in: "\r",
line: "",
},
{
in: "foo\r",
line: "foo",
},
{
in: "a\x1b[Cb\r", // right
line: "ab",
},
{
in: "a\x1b[Db\r", // left
line: "ba",
},
{
in: "a\177b\r", // backspace
line: "b",
},
{
in: "\x1b[A\r", // up
},
{
in: "\x1b[B\r", // down
},
{
in: "line\x1b[A\x1b[B\r", // up then down
line: "line",
},
{
in: "line1\rline2\x1b[A\r", // recall previous line.
line: "line1",
throwAwayLines: 1,
},
{
// recall two previous lines and append.
in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r",
line: "line1xxx",
throwAwayLines: 2,
},
{
// Ctrl-A to move to beginning of line followed by ^K to kill
// line.
in: "a b \001\013\r",
line: "",
},
{
// Ctrl-A to move to beginning of line, Ctrl-E to move to end,
// finally ^K to kill nothing.
in: "a b \001\005\013\r",
line: "a b ",
},
{
in: "\027\r",
line: "",
},
{
in: "a\027\r",
line: "",
},
{
in: "a \027\r",
line: "",
},
{
in: "a b\027\r",
line: "a ",
},
{
in: "a b \027\r",
line: "a ",
},
{
in: "one two thr\x1b[D\027\r",
line: "one two r",
},
{
in: "\013\r",
line: "",
},
{
in: "a\013\r",
line: "a",
},
{
in: "ab\x1b[D\013\r",
line: "a",
},
{
in: "Ξεσκεπάζω\r",
line: "Ξεσκεπάζω",
},
{
in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace.
line: "",
throwAwayLines: 1,
},
{
in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter.
line: "£",
throwAwayLines: 1,
},
{
// Ctrl-D at the end of the line should be ignored.
in: "a\004\r",
line: "a",
},
{
// a, b, left, Ctrl-D should erase the b.
in: "ab\x1b[D\004\r",
line: "a",
},
{
// a, b, c, d, left, left, ^U should erase to the beginning of
// the line.
in: "abcd\x1b[D\x1b[D\025\r",
line: "cd",
},
{
// Bracketed paste mode: control sequences should be returned
// verbatim in paste mode.
in: "abc\x1b[200~de\177f\x1b[201~\177\r",
line: "abcde\177",
},
{
// Enter in bracketed paste mode should still work.
in: "abc\x1b[200~d\refg\x1b[201~h\r",
line: "efgh",
throwAwayLines: 1,
},
{
// Lines consisting entirely of pasted data should be indicated as such.
in: "\x1b[200~a\r",
line: "a",
err: ErrPasteIndicator,
},
}
func TestKeyPresses(t *testing.T) {
for i, test := range keyPressTests {
for j := 1; j < len(test.in); j++ {
c := &MockTerminal{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewTerminal(c, "> ")
for k := 0; k < test.throwAwayLines; k++ {
_, err := ss.ReadLine()
if err != nil {
t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err)
}
}
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
break
}
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
}
}
}
func TestPasswordNotSaved(t *testing.T) {
c := &MockTerminal{
toSend: []byte("password\r\x1b[A\r"),
bytesPerRead: 1,
}
ss := NewTerminal(c, "> ")
pw, _ := ss.ReadPassword("> ")
if pw != "password" {
t.Fatalf("failed to read password, got %s", pw)
}
line, _ := ss.ReadLine()
if len(line) > 0 {
t.Fatalf("password was saved in history")
}
}
var setSizeTests = []struct {
width, height int
}{
{40, 13},
{80, 24},
{132, 43},
}
func TestTerminalSetSize(t *testing.T) {
for _, setSize := range setSizeTests {
c := &MockTerminal{
toSend: []byte("password\r\x1b[A\r"),
bytesPerRead: 1,
}
ss := NewTerminal(c, "> ")
ss.SetSize(setSize.width, setSize.height)
pw, _ := ss.ReadPassword("Password: ")
if pw != "password" {
t.Fatalf("failed to read password, got %s", pw)
}
if string(c.received) != "Password: \r\n" {
t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received)
}
}
}
func TestReadPasswordLineEnd(t *testing.T) {
var tests = []struct {
input string
want string
}{
{"\n", ""},
{"\r\n", ""},
{"test\r\n", "test"},
{"testtesttesttes\n", "testtesttesttes"},
{"testtesttesttes\r\n", "testtesttesttes"},
{"testtesttesttesttest\n", "testtesttesttesttest"},
{"testtesttesttesttest\r\n", "testtesttesttesttest"},
}
for _, test := range tests {
buf := new(bytes.Buffer)
if _, err := buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err := readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
if _, err = buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err = readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
}
}
func TestMakeRawState(t *testing.T) {
fd := int(os.Stdout.Fd())
if !IsTerminal(fd) {
t.Skip("stdout is not a terminal; skipping test")
}
st, err := GetState(fd)
if err != nil {
t.Fatalf("failed to get terminal state from GetState: %s", err)
}
defer Restore(fd, st)
raw, err := MakeRaw(fd)
if err != nil {
t.Fatalf("failed to get terminal state from MakeRaw: %s", err)
}
if *st != *raw {
t.Errorf("states do not match; was %v, expected %v", raw, st)
}
}
func TestOutputNewlines(t *testing.T) {
// \n should be changed to \r\n in terminal output.
buf := new(bytes.Buffer)
term := NewTerminal(buf, ">")
term.Write([]byte("1\n2\n"))
output := string(buf.Bytes())
const expected = "1\r\n2\r\n"
if output != expected {
t.Errorf("incorrect output: was %q, expected %q", output, expected)
}
}

View file

@ -108,9 +108,7 @@ func ReadPassword(fd int) ([]byte, error) {
return nil, err
}
defer func() {
unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
}()
defer unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
return readPasswordLine(passwordReader(fd))
}

View file

@ -14,7 +14,7 @@ import (
// State contains the state of a terminal.
type State struct {
state *unix.Termios
termios unix.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
@ -75,47 +75,43 @@ func ReadPassword(fd int) ([]byte, error) {
// restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldTermios := *oldTermiosPtr
newTermios := oldTermios
newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
newTermios.Oflag &^= syscall.OPOST
newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
newTermios.Cflag |= syscall.CS8
newTermios.Cc[unix.VMIN] = 1
newTermios.Cc[unix.VTIME] = 0
oldState := State{termios: *termios}
if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil {
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, termios); err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state)
return unix.IoctlSetTermios(fd, unix.TCSETS, &oldState.termios)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
return &State{termios: *termios}, nil
}
// GetSize returns the dimensions of the given terminal.

View file

@ -17,6 +17,8 @@
package terminal
import (
"os"
"golang.org/x/sys/windows"
)
@ -71,13 +73,6 @@ func GetSize(fd int) (width, height int, err error) {
return int(info.Size.X), int(info.Size.Y), nil
}
// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return windows.Read(windows.Handle(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
@ -94,9 +89,15 @@ func ReadPassword(fd int) ([]byte, error) {
return nil, err
}
defer func() {
windows.SetConsoleMode(windows.Handle(fd), old)
}()
defer windows.SetConsoleMode(windows.Handle(fd), old)
return readPasswordLine(passwordReader(fd))
var h windows.Handle
p, _ := windows.GetCurrentProcess()
if err := windows.DuplicateHandle(p, windows.Handle(fd), p, &h, 0, false, windows.DUPLICATE_SAME_ACCESS); err != nil {
return nil, err
}
f := os.NewFile(uintptr(h), "stdin")
defer f.Close()
return readPasswordLine(f)
}

View file

@ -1,63 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.
package ssh
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/ssh/testdata"
)
var (
testPrivateKeys map[string]interface{}
testSigners map[string]Signer
testPublicKeys map[string]PublicKey
)
func init() {
var err error
n := len(testdata.PEMBytes)
testPrivateKeys = make(map[string]interface{}, n)
testSigners = make(map[string]Signer, n)
testPublicKeys = make(map[string]PublicKey, n)
for t, k := range testdata.PEMBytes {
testPrivateKeys[t], err = ParseRawPrivateKey(k)
if err != nil {
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
}
testSigners[t], err = NewSignerFromKey(testPrivateKeys[t])
if err != nil {
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
}
testPublicKeys[t] = testSigners[t].PublicKey()
}
// Create a cert and sign it for use in tests.
testCert := &Certificate{
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
ValidAfter: 0, // unix epoch
ValidBefore: CertTimeInfinity, // The end of currently representable time.
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
Key: testPublicKeys["ecdsa"],
SignatureKey: testPublicKeys["rsa"],
Permissions: Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
}
testCert.SignCert(rand.Reader, testSigners["rsa"])
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"])
if err != nil {
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
}
}

View file

@ -6,6 +6,7 @@ package ssh
import (
"bufio"
"bytes"
"errors"
"io"
"log"
@ -76,17 +77,17 @@ type connectionState struct {
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil {
return err
} else {
t.reader.pendingKeyChange <- ciph
}
t.reader.pendingKeyChange <- ciph
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil {
ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil {
return err
} else {
t.writer.pendingKeyChange <- ciph
}
t.writer.pendingKeyChange <- ciph
return nil
}
@ -139,7 +140,7 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message.")
return nil, errors.New("ssh: got bogus newkeys message")
}
case msgDisconnect:
@ -232,52 +233,22 @@ var (
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
// generateKeys generates key material for IV, MAC and encryption.
func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
cipherMode := cipherModes[algs.Cipher]
macMode := macModes[algs.MAC]
iv = make([]byte, cipherMode.ivSize)
key = make([]byte, cipherMode.keySize)
macKey = make([]byte, macMode.keySize)
generateKeyMaterial(iv, d.ivTag, kex)
generateKeyMaterial(key, d.keyTag, kex)
generateKeyMaterial(macKey, d.macKeyTag, kex)
return
}
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
iv, key, macKey := generateKeys(d, algs, kex)
cipherMode := cipherModes[algs.Cipher]
macMode := macModes[algs.MAC]
if algs.Cipher == gcmCipherID {
return newGCMCipher(iv, key)
}
iv := make([]byte, cipherMode.ivSize)
key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macMode.keySize)
if algs.Cipher == aes128cbcID {
return newAESCBCCipher(iv, key, macKey, algs)
}
generateKeyMaterial(iv, d.ivTag, kex)
generateKeyMaterial(key, d.keyTag, kex)
generateKeyMaterial(macKey, d.macKeyTag, kex)
if algs.Cipher == tripledescbcID {
return newTripleDESCBCCipher(iv, key, macKey, algs)
}
c := &streamPacketCipher{
mac: macModes[algs.MAC].new(macKey),
etm: macModes[algs.MAC].etm,
}
c.macResult = make([]byte, c.mac.Size())
var err error
c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
if err != nil {
return nil, err
}
return c, nil
return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
}
// generateKeyMaterial fills out with key material generated from tag, K, H
@ -342,7 +313,7 @@ func readVersion(r io.Reader) ([]byte, error) {
var ok bool
var buf [1]byte
for len(versionString) < maxVersionStringBytes {
for length := 0; length < maxVersionStringBytes; length++ {
_, err := io.ReadFull(r, buf[:])
if err != nil {
return nil, err
@ -350,6 +321,13 @@ func readVersion(r io.Reader) ([]byte, error) {
// The RFC says that the version should be terminated with \r\n
// but several SSH servers actually only send a \n.
if buf[0] == '\n' {
if !bytes.HasPrefix(versionString, []byte("SSH-")) {
// RFC 4253 says we need to ignore all version string lines
// except the one containing the SSH version (provided that
// all the lines do not exceed 255 bytes in total).
versionString = versionString[:0]
continue
}
ok = true
break
}

View file

@ -1,109 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"encoding/binary"
"strings"
"testing"
)
func TestReadVersion(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := map[string]string{
"SSH-2.0-bla\r\n": "SSH-2.0-bla",
"SSH-2.0-bla\n": "SSH-2.0-bla",
longversion + "\r\n": longversion,
}
for in, want := range cases {
result, err := readVersion(bytes.NewBufferString(in))
if err != nil {
t.Errorf("readVersion(%q): %s", in, err)
}
got := string(result)
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}
func TestReadVersionError(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := []string{
longversion + "too-long\r\n",
}
for _, in := range cases {
if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
t.Errorf("readVersion(%q) should have failed", in)
}
}
}
func TestExchangeVersionsBasic(t *testing.T) {
v := "SSH-2.0-bla"
buf := bytes.NewBufferString(v + "\r\n")
them, err := exchangeVersions(buf, []byte("xyz"))
if err != nil {
t.Errorf("exchangeVersions: %v", err)
}
if want := "SSH-2.0-bla"; string(them) != want {
t.Errorf("got %q want %q for our version", them, want)
}
}
func TestExchangeVersions(t *testing.T) {
cases := []string{
"not\x000allowed",
"not allowed\n",
}
for _, c := range cases {
buf := bytes.NewBufferString("SSH-2.0-bla\r\n")
if _, err := exchangeVersions(buf, []byte(c)); err == nil {
t.Errorf("exchangeVersions(%q): should have failed", c)
}
}
}
type closerBuffer struct {
bytes.Buffer
}
func (b *closerBuffer) Close() error {
return nil
}
func TestTransportMaxPacketWrite(t *testing.T) {
buf := &closerBuffer{}
tr := newTransport(buf, rand.Reader, true)
huge := make([]byte, maxPacket+1)
err := tr.writePacket(huge)
if err == nil {
t.Errorf("transport accepted write for a huge packet.")
}
}
func TestTransportMaxPacketReader(t *testing.T) {
var header [5]byte
huge := make([]byte, maxPacket+128)
binary.BigEndian.PutUint32(header[0:], uint32(len(huge)))
// padding.
header[4] = 0
buf := &closerBuffer{}
buf.Write(header[:])
buf.Write(huge)
tr := newTransport(buf, rand.Reader, true)
_, err := tr.readPacket()
if err == nil {
t.Errorf("transport succeeded reading huge packet.")
} else if !strings.Contains(err.Error(), "large") {
t.Errorf("got %q, should mention %q", err.Error(), "large")
}
}