Update go dependencies

This commit is contained in:
Manuel de Brito Fontes 2018-05-26 11:27:53 -04:00 committed by Manuel Alejandro de Brito Fontes
parent 15ffb51394
commit bb4d483837
No known key found for this signature in database
GPG key ID: 786136016A8BA02A
1621 changed files with 86368 additions and 284392 deletions

View file

@ -4,16 +4,15 @@ Go is an open source project.
It is the work of hundreds of contributors. We appreciate your help!
## Filing issues
When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions:
1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using?
3. What did you do?
4. What did you expect to see?
5. What did you see instead?
1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using?
3. What did you do?
4. What did you expect to see?
5. What did you see instead?
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker.
The gophers there will answer or ask you to file an issue if you've tripped over a bug.
@ -23,9 +22,5 @@ The gophers there will answer or ask you to file an issue if you've tripped over
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html)
before sending patches.
**We do not accept GitHub pull requests**
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review).
Unless otherwise noted, the Go source files are distributed under
the BSD-style license found in the LICENSE file.

View file

@ -171,9 +171,16 @@ func Verify(publicKey PublicKey, message, sig []byte) bool {
edwards25519.ScReduce(&hReduced, &digest)
var R edwards25519.ProjectiveGroupElement
var b [32]byte
copy(b[:], sig[32:])
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &b)
var s [32]byte
copy(s[:], sig[32:])
// https://tools.ietf.org/html/rfc8032#section-5.1.7 requires that s be in
// the range [0, order) in order to prevent signature malleability.
if !edwards25519.ScMinimal(&s) {
return false
}
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &s)
var checkR [32]byte
R.ToBytes(&checkR)

View file

@ -1,183 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ed25519
import (
"bufio"
"bytes"
"compress/gzip"
"crypto"
"crypto/rand"
"encoding/hex"
"os"
"strings"
"testing"
"golang.org/x/crypto/ed25519/internal/edwards25519"
)
type zeroReader struct{}
func (zeroReader) Read(buf []byte) (int, error) {
for i := range buf {
buf[i] = 0
}
return len(buf), nil
}
func TestUnmarshalMarshal(t *testing.T) {
pub, _, _ := GenerateKey(rand.Reader)
var A edwards25519.ExtendedGroupElement
var pubBytes [32]byte
copy(pubBytes[:], pub)
if !A.FromBytes(&pubBytes) {
t.Fatalf("ExtendedGroupElement.FromBytes failed")
}
var pub2 [32]byte
A.ToBytes(&pub2)
if pubBytes != pub2 {
t.Errorf("FromBytes(%v)->ToBytes does not round-trip, got %x\n", pubBytes, pub2)
}
}
func TestSignVerify(t *testing.T) {
var zero zeroReader
public, private, _ := GenerateKey(zero)
message := []byte("test message")
sig := Sign(private, message)
if !Verify(public, message, sig) {
t.Errorf("valid signature rejected")
}
wrongMessage := []byte("wrong message")
if Verify(public, wrongMessage, sig) {
t.Errorf("signature of different message accepted")
}
}
func TestCryptoSigner(t *testing.T) {
var zero zeroReader
public, private, _ := GenerateKey(zero)
signer := crypto.Signer(private)
publicInterface := signer.Public()
public2, ok := publicInterface.(PublicKey)
if !ok {
t.Fatalf("expected PublicKey from Public() but got %T", publicInterface)
}
if !bytes.Equal(public, public2) {
t.Errorf("public keys do not match: original:%x vs Public():%x", public, public2)
}
message := []byte("message")
var noHash crypto.Hash
signature, err := signer.Sign(zero, message, noHash)
if err != nil {
t.Fatalf("error from Sign(): %s", err)
}
if !Verify(public, message, signature) {
t.Errorf("Verify failed on signature from Sign()")
}
}
func TestGolden(t *testing.T) {
// sign.input.gz is a selection of test cases from
// https://ed25519.cr.yp.to/python/sign.input
testDataZ, err := os.Open("testdata/sign.input.gz")
if err != nil {
t.Fatal(err)
}
defer testDataZ.Close()
testData, err := gzip.NewReader(testDataZ)
if err != nil {
t.Fatal(err)
}
defer testData.Close()
scanner := bufio.NewScanner(testData)
lineNo := 0
for scanner.Scan() {
lineNo++
line := scanner.Text()
parts := strings.Split(line, ":")
if len(parts) != 5 {
t.Fatalf("bad number of parts on line %d", lineNo)
}
privBytes, _ := hex.DecodeString(parts[0])
pubKey, _ := hex.DecodeString(parts[1])
msg, _ := hex.DecodeString(parts[2])
sig, _ := hex.DecodeString(parts[3])
// The signatures in the test vectors also include the message
// at the end, but we just want R and S.
sig = sig[:SignatureSize]
if l := len(pubKey); l != PublicKeySize {
t.Fatalf("bad public key length on line %d: got %d bytes", lineNo, l)
}
var priv [PrivateKeySize]byte
copy(priv[:], privBytes)
copy(priv[32:], pubKey)
sig2 := Sign(priv[:], msg)
if !bytes.Equal(sig, sig2[:]) {
t.Errorf("different signature result on line %d: %x vs %x", lineNo, sig, sig2)
}
if !Verify(pubKey, msg, sig2) {
t.Errorf("signature failed to verify on line %d", lineNo)
}
}
if err := scanner.Err(); err != nil {
t.Fatalf("error reading test data: %s", err)
}
}
func BenchmarkKeyGeneration(b *testing.B) {
var zero zeroReader
for i := 0; i < b.N; i++ {
if _, _, err := GenerateKey(zero); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSigning(b *testing.B) {
var zero zeroReader
_, priv, err := GenerateKey(zero)
if err != nil {
b.Fatal(err)
}
message := []byte("Hello, world!")
b.ResetTimer()
for i := 0; i < b.N; i++ {
Sign(priv, message)
}
}
func BenchmarkVerification(b *testing.B) {
var zero zeroReader
pub, priv, err := GenerateKey(zero)
if err != nil {
b.Fatal(err)
}
message := []byte("Hello, world!")
signature := Sign(priv, message)
b.ResetTimer()
for i := 0; i < b.N; i++ {
Verify(pub, message, signature)
}
}

View file

@ -4,6 +4,8 @@
package edwards25519
import "encoding/binary"
// This code is a port of the public domain, “ref10” implementation of ed25519
// from SUPERCOP.
@ -1769,3 +1771,23 @@ func ScReduce(out *[32]byte, s *[64]byte) {
out[30] = byte(s11 >> 9)
out[31] = byte(s11 >> 17)
}
// order is the order of Curve25519 in little-endian form.
var order = [4]uint64{0x5812631a5cf5d3ed, 0x14def9dea2f79cd6, 0, 0x1000000000000000}
// ScMinimal returns true if the given scalar is less than the order of the
// curve.
func ScMinimal(scalar *[32]byte) bool {
for i := 3; ; i-- {
v := binary.LittleEndian.Uint64(scalar[i*8:])
if v > order[i] {
return false
} else if v < order[i] {
break
} else if i == 0 {
return false
}
}
return true
}

View file

@ -1,122 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"io"
"net"
"testing"
)
type server struct {
*ServerConn
chans <-chan NewChannel
}
func newServer(c net.Conn, conf *ServerConfig) (*server, error) {
sconn, chans, reqs, err := NewServerConn(c, conf)
if err != nil {
return nil, err
}
go DiscardRequests(reqs)
return &server{sconn, chans}, nil
}
func (s *server) Accept() (NewChannel, error) {
n, ok := <-s.chans
if !ok {
return nil, io.EOF
}
return n, nil
}
func sshPipe() (Conn, *server, error) {
c1, c2, err := netPipe()
if err != nil {
return nil, nil, err
}
clientConf := ClientConfig{
User: "user",
}
serverConf := ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["ecdsa"])
done := make(chan *server, 1)
go func() {
server, err := newServer(c2, &serverConf)
if err != nil {
done <- nil
}
done <- server
}()
client, _, reqs, err := NewClientConn(c1, "", &clientConf)
if err != nil {
return nil, nil, err
}
server := <-done
if server == nil {
return nil, nil, errors.New("server handshake failed.")
}
go DiscardRequests(reqs)
return client, server, nil
}
func BenchmarkEndToEnd(b *testing.B) {
b.StopTimer()
client, server, err := sshPipe()
if err != nil {
b.Fatalf("sshPipe: %v", err)
}
defer client.Close()
defer server.Close()
size := (1 << 20)
input := make([]byte, size)
output := make([]byte, size)
b.SetBytes(int64(size))
done := make(chan int, 1)
go func() {
newCh, err := server.Accept()
if err != nil {
b.Fatalf("Client: %v", err)
}
ch, incoming, err := newCh.Accept()
go DiscardRequests(incoming)
for i := 0; i < b.N; i++ {
if _, err := io.ReadFull(ch, output); err != nil {
b.Fatalf("ReadFull: %v", err)
}
}
ch.Close()
done <- 1
}()
ch, in, err := client.OpenChannel("speed", nil)
if err != nil {
b.Fatalf("OpenChannel: %v", err)
}
go DiscardRequests(in)
b.ResetTimer()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := ch.Write(input); err != nil {
b.Fatalf("WriteFull: %v", err)
}
}
ch.Close()
b.StopTimer()
<-done
}

View file

@ -1,87 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"testing"
)
var alphabet = []byte("abcdefghijklmnopqrstuvwxyz")
func TestBufferReadwrite(t *testing.T) {
b := newBuffer()
b.write(alphabet[:10])
r, _ := b.Read(make([]byte, 10))
if r != 10 {
t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
}
b = newBuffer()
b.write(alphabet[:5])
r, _ = b.Read(make([]byte, 10))
if r != 5 {
t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
}
b = newBuffer()
b.write(alphabet[:10])
r, _ = b.Read(make([]byte, 5))
if r != 5 {
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
}
b = newBuffer()
b.write(alphabet[:5])
b.write(alphabet[5:15])
r, _ = b.Read(make([]byte, 10))
r2, _ := b.Read(make([]byte, 10))
if r != 10 || r2 != 5 || 15 != r+r2 {
t.Fatal("Expected written == read == 15")
}
}
func TestBufferClose(t *testing.T) {
b := newBuffer()
b.write(alphabet[:10])
b.eof()
_, err := b.Read(make([]byte, 5))
if err != nil {
t.Fatal("expected read of 5 to not return EOF")
}
b = newBuffer()
b.write(alphabet[:10])
b.eof()
r, err := b.Read(make([]byte, 5))
r2, err2 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || err != nil || err2 != nil {
t.Fatal("expected reads of 5 and 5")
}
b = newBuffer()
b.write(alphabet[:10])
b.eof()
r, err = b.Read(make([]byte, 5))
r2, err2 = b.Read(make([]byte, 10))
r3, err3 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF {
t.Fatal("expected reads of 5 and 5 and 0, with EOF")
}
b = newBuffer()
b.write(make([]byte, 5))
b.write(make([]byte, 10))
b.eof()
r, err = b.Read(make([]byte, 9))
r2, err2 = b.Read(make([]byte, 3))
r3, err3 = b.Read(make([]byte, 3))
r4, err4 := b.Read(make([]byte, 10))
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF {
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4)
}
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 {
t.Fatal("Expected written == read == 15", r, r2, r3, r4)
}
}

View file

@ -44,7 +44,9 @@ type Signature struct {
const CertTimeInfinity = 1<<64 - 1
// An Certificate represents an OpenSSH certificate as defined in
// [PROTOCOL.certkeys]?rev=1.8.
// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the
// PublicKey interface, so it can be unmarshaled using
// ParsePublicKey.
type Certificate struct {
Nonce []byte
Key PublicKey
@ -340,10 +342,10 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
// the signature of the certificate.
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if c.IsRevoked != nil && c.IsRevoked(cert) {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
}
for opt, _ := range cert.CriticalOptions {
for opt := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by
// serverAuthenticate
if opt == sourceAddressCriticalOption {

View file

@ -1,222 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"reflect"
"testing"
"time"
)
// Cert generated by ssh-keygen 6.0p1 Debian-4.
// % ssh-keygen -s ca-key -I test user-key
const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
func TestParseCert(t *testing.T) {
authKeyBytes := []byte(exampleSSHCert)
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
if _, ok := key.(*Certificate); !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
marshaled := MarshalAuthorizedKey(key)
// Before comparison, remove the trailing newline that
// MarshalAuthorizedKey adds.
marshaled = marshaled[:len(marshaled)-1]
if !bytes.Equal(authKeyBytes, marshaled) {
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
}
}
// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3
// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub
// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN
// Critical Options:
// force-command /bin/sleep
// source-address 192.168.1.0/24
// Extensions:
// permit-X11-forwarding
// permit-agent-forwarding
// permit-port-forwarding
// permit-pty
// permit-user-rc
const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ`
func TestParseCertWithOptions(t *testing.T) {
opts := map[string]string{
"source-address": "192.168.1.0/24",
"force-command": "/bin/sleep",
}
exts := map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
}
authKeyBytes := []byte(exampleSSHCertWithOptions)
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
cert, ok := key.(*Certificate)
if !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
if !reflect.DeepEqual(cert.CriticalOptions, opts) {
t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts)
}
if !reflect.DeepEqual(cert.Extensions, exts) {
t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts)
}
marshaled := MarshalAuthorizedKey(key)
// Before comparison, remove the trailing newline that
// MarshalAuthorizedKey adds.
marshaled = marshaled[:len(marshaled)-1]
if !bytes.Equal(authKeyBytes, marshaled) {
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
}
}
func TestValidateCert(t *testing.T) {
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
validCert, ok := key.(*Certificate)
if !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
checker := CertChecker{}
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
}
if err := checker.CheckCert("user", validCert); err != nil {
t.Errorf("Unable to validate certificate: %v", err)
}
invalidCert := &Certificate{
Key: testPublicKeys["rsa"],
SignatureKey: testPublicKeys["ecdsa"],
ValidBefore: CertTimeInfinity,
Signature: &Signature{},
}
if err := checker.CheckCert("user", invalidCert); err == nil {
t.Error("Invalid cert signature passed validation")
}
}
func TestValidateCertTime(t *testing.T) {
cert := Certificate{
ValidPrincipals: []string{"user"},
Key: testPublicKeys["rsa"],
ValidAfter: 50,
ValidBefore: 100,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
for ts, ok := range map[int64]bool{
25: false,
50: true,
99: true,
100: false,
125: false,
} {
checker := CertChecker{
Clock: func() time.Time { return time.Unix(ts, 0) },
}
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(),
testPublicKeys["ecdsa"].Marshal())
}
if v := checker.CheckCert("user", &cert); (v == nil) != ok {
t.Errorf("Authenticate(%d): %v", ts, v)
}
}
}
// TODO(hanwen): tests for
//
// host keys:
// * fallbacks
func TestHostKeyCert(t *testing.T) {
cert := &Certificate{
ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"},
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: HostCert,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
checker := &CertChecker{
IsHostAuthority: func(p PublicKey, addr string) bool {
return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
},
}
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
if err != nil {
t.Errorf("NewCertSigner: %v", err)
}
for _, test := range []struct {
addr string
succeed bool
}{
{addr: "hostname:22", succeed: true},
{addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22'
{addr: "lasthost:22", succeed: false},
} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
errc := make(chan error)
go func() {
conf := ServerConfig{
NoClientAuth: true,
}
conf.AddHostKey(certSigner)
_, _, _, err := NewServerConn(c1, &conf)
errc <- err
}()
config := &ClientConfig{
User: "user",
HostKeyCallback: checker.CheckHostKey,
}
_, _, _, err = NewClientConn(c2, test.addr, config)
if (err == nil) != test.succeed {
t.Fatalf("NewClientConn(%q): %v", test.addr, err)
}
err = <-errc
if (err == nil) != test.succeed {
t.Fatalf("NewServerConn(%q): %v", test.addr, err)
}
}
}

View file

@ -205,32 +205,32 @@ type channel struct {
// writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error {
c.writeMu.Lock()
if c.sentClose {
c.writeMu.Unlock()
func (ch *channel) writePacket(packet []byte) error {
ch.writeMu.Lock()
if ch.sentClose {
ch.writeMu.Unlock()
return io.EOF
}
c.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet)
c.writeMu.Unlock()
ch.sentClose = (packet[0] == msgChannelClose)
err := ch.mux.conn.writePacket(packet)
ch.writeMu.Unlock()
return err
}
func (c *channel) sendMessage(msg interface{}) error {
func (ch *channel) sendMessage(msg interface{}) error {
if debugMux {
log.Printf("send(%d): %#v", c.mux.chanList.offset, msg)
log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
}
p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId)
return c.writePacket(p)
binary.BigEndian.PutUint32(p[1:], ch.remoteId)
return ch.writePacket(p)
}
// WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF {
func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if ch.sentEOF {
return 0, io.EOF
}
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
@ -241,16 +241,16 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
opCode = msgChannelExtendedData
}
c.writeMu.Lock()
packet := c.packetPool[extendedCode]
ch.writeMu.Lock()
packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be
// flagged as errors by the race detector.
c.writeMu.Unlock()
ch.writeMu.Unlock()
for len(data) > 0 {
space := min(c.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil {
space := min(ch.maxRemotePayload, len(data))
if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err
}
if want := headerLength + space; uint32(cap(packet)) < want {
@ -262,13 +262,13 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
todo := data[:space]
packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId)
binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
}
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil {
if err = ch.writePacket(packet); err != nil {
return n, err
}
@ -276,14 +276,14 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
data = data[len(todo):]
}
c.writeMu.Lock()
c.packetPool[extendedCode] = packet
c.writeMu.Unlock()
ch.writeMu.Lock()
ch.packetPool[extendedCode] = packet
ch.writeMu.Unlock()
return n, err
}
func (c *channel) handleData(packet []byte) error {
func (ch *channel) handleData(packet []byte) error {
headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData {
@ -303,7 +303,7 @@ func (c *channel) handleData(packet []byte) error {
if length == 0 {
return nil
}
if length > c.maxIncomingPayload {
if length > ch.maxIncomingPayload {
// TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size")
}
@ -313,21 +313,21 @@ func (c *channel) handleData(packet []byte) error {
return errors.New("ssh: wrong packet length")
}
c.windowMu.Lock()
if c.myWindow < length {
c.windowMu.Unlock()
ch.windowMu.Lock()
if ch.myWindow < length {
ch.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much")
}
c.myWindow -= length
c.windowMu.Unlock()
ch.myWindow -= length
ch.windowMu.Unlock()
if extended == 1 {
c.extPending.write(data)
ch.extPending.write(data)
} else if extended > 0 {
// discard other extended data.
} else {
c.pending.write(data)
ch.pending.write(data)
}
return nil
}
@ -384,31 +384,31 @@ func (c *channel) close() {
// responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the
// given channel.
func (c *channel) responseMessageReceived() error {
if c.direction == channelInbound {
func (ch *channel) responseMessageReceived() error {
if ch.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel")
}
if c.decided {
if ch.decided {
return errors.New("ssh: duplicate response received for channel")
}
c.decided = true
ch.decided = true
return nil
}
func (c *channel) handlePacket(packet []byte) error {
func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] {
case msgChannelData, msgChannelExtendedData:
return c.handleData(packet)
return ch.handleData(packet)
case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
c.mux.chanList.remove(c.localId)
c.close()
ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
ch.mux.chanList.remove(ch.localId)
ch.close()
return nil
case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time.
c.extPending.eof()
c.pending.eof()
ch.extPending.eof()
ch.pending.eof()
return nil
}
@ -419,24 +419,24 @@ func (c *channel) handlePacket(packet []byte) error {
switch msg := decoded.(type) {
case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil {
if err := ch.responseMessageReceived(); err != nil {
return err
}
c.mux.chanList.remove(msg.PeersId)
c.msg <- msg
ch.mux.chanList.remove(msg.PeersID)
ch.msg <- msg
case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil {
if err := ch.responseMessageReceived(); err != nil {
return err
}
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
}
c.remoteId = msg.MyId
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow)
c.msg <- msg
ch.remoteId = msg.MyID
ch.maxRemotePayload = msg.MaxPacketSize
ch.remoteWin.add(msg.MyWindow)
ch.msg <- msg
case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) {
if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
}
case *channelRequestMsg:
@ -444,12 +444,12 @@ func (c *channel) handlePacket(packet []byte) error {
Type: msg.Request,
WantReply: msg.WantReply,
Payload: msg.RequestSpecificData,
ch: c,
ch: ch,
}
c.incomingRequests <- &req
ch.incomingRequests <- &req
default:
c.msg <- msg
ch.msg <- msg
}
return nil
}
@ -488,23 +488,23 @@ func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code)
}
func (c *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided {
func (ch *channel) Accept() (Channel, <-chan *Request, error) {
if ch.decided {
return nil, nil, errDecidedAlready
}
c.maxIncomingPayload = channelMaxPacket
ch.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{
PeersId: c.remoteId,
MyId: c.localId,
MyWindow: c.myWindow,
MaxPacketSize: c.maxIncomingPayload,
PeersID: ch.remoteId,
MyID: ch.localId,
MyWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload,
}
c.decided = true
if err := c.sendMessage(confirm); err != nil {
ch.decided = true
if err := ch.sendMessage(confirm); err != nil {
return nil, nil, err
}
return c, c.incomingRequests, nil
return ch, ch.incomingRequests, nil
}
func (ch *channel) Reject(reason RejectionReason, message string) error {
@ -512,7 +512,7 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
return errDecidedAlready
}
reject := channelOpenFailureMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
Reason: reason,
Message: message,
Language: "en",
@ -541,7 +541,7 @@ func (ch *channel) CloseWrite() error {
}
ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId})
PeersID: ch.remoteId})
}
func (ch *channel) Close() error {
@ -550,7 +550,7 @@ func (ch *channel) Close() error {
}
return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId})
PeersID: ch.remoteId})
}
// Extended returns an io.ReadWriter that sends and receives data on the given,
@ -577,7 +577,7 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
}
msg := channelRequestMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
Request: name,
WantReply: wantReply,
RequestSpecificData: payload,
@ -614,11 +614,11 @@ func (ch *channel) ackRequest(ok bool) error {
var msg interface{}
if !ok {
msg = channelRequestFailureMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
}
} else {
msg = channelRequestSuccessMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
}
}
return ch.sendMessage(msg)

View file

@ -16,6 +16,10 @@ import (
"hash"
"io"
"io/ioutil"
"math/bits"
"golang.org/x/crypto/internal/chacha20"
"golang.org/x/crypto/poly1305"
)
const (
@ -53,78 +57,78 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type streamCipherMode struct {
keySize int
ivSize int
skip int
createFunc func(key, iv []byte) (cipher.Stream, error)
type cipherMode struct {
keySize int
ivSize int
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
}
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
var streamDump []byte
if c.skip > 0 {
streamDump = make([]byte, 512)
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
stream, err := createFunc(key, iv)
if err != nil {
return nil, err
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
var streamDump []byte
if skip > 0 {
streamDump = make([]byte, 512)
}
for remainingToDump := skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
mac := macModes[algs.MAC].new(macKey)
return &streamPacketCipher{
mac: mac,
etm: macModes[algs.MAC].etm,
macResult: make([]byte, mac.Size()),
cipher: stream,
}, nil
}
}
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
var cipherModes = map[string]*streamCipherMode{
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR},
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": {16, 0, 1536, newRC4},
"arcfour256": {32, 0, 1536, newRC4},
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution."
// RFC4345 introduces improved versions of Arcfour.
"arcfour": {16, 0, 0, newRC4},
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AES-GCM is not a stream cipher, so it is constructed with a
// special case. If we add any more non-stream ciphers, we
// should invest a cleaner way to do this.
gcmCipherID: {16, 12, 0, nil},
// AEAD ciphers
gcmCipherID: {16, 12, newGCMCipher},
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
// CBC mode is insecure and so is not included in the default config.
// (See http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf). If absolutely
// needed, it's possible to specify a custom Config to enable it.
// You should expect that an active attacker can recover plaintext if
// you do.
aes128cbcID: {16, aes.BlockSize, 0, nil},
aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
// 3des-cbc is insecure and is disabled by default.
tripledescbcID: {24, des.BlockSize, 0, nil},
// 3des-cbc is insecure and is not included in the default
// config.
tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
}
// prefixLen is the length of the packet prefix that contains the packet length
@ -304,7 +308,7 @@ type gcmCipher struct {
buf []byte
}
func newGCMCipher(iv, key []byte) (packetCipher, error) {
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
@ -372,7 +376,7 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
}
length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.")
return nil, errors.New("ssh: max packet length exceeded")
}
if cap(c.buf) < int(length+gcmTagSize) {
@ -422,7 +426,7 @@ type cbcCipher struct {
oracleCamouflage uint32
}
func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
cbc := &cbcCipher{
mac: macModes[algs.MAC].new(macKey),
decrypter: cipher.NewCBCDecrypter(c, iv),
@ -436,13 +440,13 @@ func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorith
return cbc, nil
}
func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cbc, err := newCBCCipher(c, iv, key, macKey, algs)
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
if err != nil {
return nil, err
}
@ -450,13 +454,13 @@ func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCi
return cbc, nil
}
func newTripleDESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
}
cbc, err := newCBCCipher(c, iv, key, macKey, algs)
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
if err != nil {
return nil, err
}
@ -548,11 +552,11 @@ func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error)
c.packetData = c.packetData[:entirePacketSize]
}
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil {
n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err
} else {
c.oracleCamouflage -= uint32(n)
}
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)
@ -627,3 +631,140 @@ func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, pack
return nil
}
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
// AEAD, which is described here:
//
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
//
// the methods here also implement padding, which RFC4253 Section 6
// also requires of stream ciphers.
type chacha20Poly1305Cipher struct {
lengthKey [8]uint32
contentKey [8]uint32
buf []byte
}
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
if len(key) != 64 {
panic(len(key))
}
c := &chacha20Poly1305Cipher{
buf: make([]byte, 256),
}
for i := range c.contentKey {
c.contentKey[i] = binary.LittleEndian.Uint32(key[i*4 : (i+1)*4])
}
for i := range c.lengthKey {
c.lengthKey[i] = binary.LittleEndian.Uint32(key[(i+8)*4 : (i+9)*4])
}
return c, nil
}
func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
encryptedLength := c.buf[:4]
if _, err := io.ReadFull(r, encryptedLength); err != nil {
return nil, err
}
var lenBytes [4]byte
chacha20.New(c.lengthKey, nonce).XORKeyStream(lenBytes[:], encryptedLength)
length := binary.BigEndian.Uint32(lenBytes[:])
if length > maxPacket {
return nil, errors.New("ssh: invalid packet length, packet too large")
}
contentEnd := 4 + length
packetEnd := contentEnd + poly1305.TagSize
if uint32(cap(c.buf)) < packetEnd {
c.buf = make([]byte, packetEnd)
copy(c.buf[:], encryptedLength)
} else {
c.buf = c.buf[:packetEnd]
}
if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil {
return nil, err
}
var mac [poly1305.TagSize]byte
copy(mac[:], c.buf[contentEnd:packetEnd])
if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) {
return nil, errors.New("ssh: MAC failure")
}
plain := c.buf[4:contentEnd]
s.XORKeyStream(plain, plain)
padding := plain[0]
if padding < 4 {
// padding is a byte, so it automatically satisfies
// the maximum size, which is 255.
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
}
if int(padding)+1 >= len(plain) {
return nil, fmt.Errorf("ssh: padding %d too large", padding)
}
plain = plain[1 : len(plain)-int(padding)]
return plain, nil
}
func (c *chacha20Poly1305Cipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
// There is no blocksize, so fall back to multiple of 8 byte
// padding, as described in RFC 4253, Sec 6.
const packetSizeMultiple = 8
padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple
if padding < 4 {
padding += packetSizeMultiple
}
// size (4 bytes), padding (1), payload, padding, tag.
totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize
if cap(c.buf) < totalLength {
c.buf = make([]byte, totalLength)
} else {
c.buf = c.buf[:totalLength]
}
binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
chacha20.New(c.lengthKey, nonce).XORKeyStream(c.buf, c.buf[:4])
c.buf[4] = byte(padding)
copy(c.buf[5:], payload)
packetEnd := 5 + len(payload) + padding
if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil {
return err
}
s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
var mac [poly1305.TagSize]byte
poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)
copy(c.buf[packetEnd:], mac[:])
if _, err := w.Write(c.buf); err != nil {
return err
}
return nil
}

View file

@ -1,129 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/rand"
"testing"
)
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range supportedCiphers {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}
func TestPacketCiphers(t *testing.T) {
// Still test aes128cbc cipher although it's commented out.
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
defer delete(cipherModes, aes128cbcID)
for cipher := range cipherModes {
for mac := range macModes {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: cipher,
MAC: mac,
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
}
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket(%q, %q): %v", cipher, mac, err)
continue
}
packet, err := server.readPacket(0, buf)
if err != nil {
t.Errorf("readPacket(%q, %q): %v", cipher, mac, err)
continue
}
if string(packet) != want {
t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
}
}
}
}
func TestCBCOracleCounterMeasure(t *testing.T) {
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
defer delete(cipherModes, aes128cbcID)
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: aes128cbcID,
MAC: "hmac-sha1",
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client): %v", err)
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket: %v", err)
}
packetSize := buf.Len()
buf.Write(make([]byte, 2*maxPacket))
// We corrupt each byte, but this usually will only test the
// 'packet too large' or 'MAC failure' cases.
lastRead := -1
for i := 0; i < packetSize; i++ {
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client): %v", err)
}
fresh := &bytes.Buffer{}
fresh.Write(buf.Bytes())
fresh.Bytes()[i] ^= 0x01
before := fresh.Len()
_, err = server.readPacket(0, fresh)
if err == nil {
t.Errorf("corrupt byte %d: readPacket succeeded ", i)
continue
}
if _, ok := err.(cbcError); !ok {
t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
continue
}
after := fresh.Len()
bytesRead := before - after
if bytesRead < maxPacket {
t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
continue
}
if i > 0 && bytesRead != lastRead {
t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
}
lastRead = bytesRead
}
}

View file

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"net"
"os"
"sync"
"time"
)
@ -18,6 +19,8 @@ import (
type Client struct {
Conn
handleForwardsOnce sync.Once // guards calling (*Client).handleForwards
forwards forwardList // forwarded tcpip connections from the remote side
mu sync.Mutex
channelHandlers map[string]chan NewChannel
@ -59,8 +62,6 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.Wait()
conn.forwards.closeAll()
}()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
return conn
}
@ -187,6 +188,10 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
// net.Conn underlying the the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used for treat the banner sent by
// the server. A BannerCallback receives the message sent by the remote server.
type BannerCallback func(message string) error
// A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function.
type ClientConfig struct {
@ -209,6 +214,12 @@ type ClientConfig struct {
// FixedHostKey can be used for simplistic host key checks.
HostKeyCallback HostKeyCallback
// BannerCallback is called during the SSH dance to display a custom
// server's message. The client configuration can supply this callback to
// handle it as wished. The function BannerDisplayStderr can be used for
// simplistic display on Stderr.
BannerCallback BannerCallback
// ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used.
ClientVersion string
@ -255,3 +266,13 @@ func FixedHostKey(key PublicKey) HostKeyCallback {
hk := &fixedHostKey{key}
return hk.check
}
// BannerDisplayStderr returns a function that can be used for
// ClientConfig.BannerCallback to display banners on os.Stderr.
func BannerDisplayStderr() BannerCallback {
return func(banner string) error {
_, err := os.Stderr.WriteString(banner)
return err
}
}

View file

@ -11,6 +11,14 @@ import (
"io"
)
type authResult int
const (
authFailure authResult = iota
authPartialSuccess
authSuccess
)
// clientAuthenticate authenticates with the remote server. See RFC 4252.
func (c *connection) clientAuthenticate(config *ClientConfig) error {
// initiate user auth session
@ -37,11 +45,12 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
if err != nil {
return err
}
if ok {
if ok == authSuccess {
// success
return nil
} else if ok == authFailure {
tried[auth.method()] = true
}
tried[auth.method()] = true
if methods == nil {
methods = lastMethods
}
@ -82,7 +91,7 @@ type AuthMethod interface {
// If authentication is not successful, a []string of alternative
// method names is returned. If the slice is nil, it will be ignored
// and the previous set of possible methods will be reused.
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
auth(session []byte, user string, p packetConn, rand io.Reader) (authResult, []string, error)
// method returns the RFC 4252 method name.
method() string
@ -91,13 +100,13 @@ type AuthMethod interface {
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
if err := c.writePacket(Marshal(&userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: "none",
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
return handleAuthResponse(c)
@ -111,7 +120,7 @@ func (n *noneAuth) method() string {
// a function call, e.g. by prompting the user.
type passwordCallback func() (password string, err error)
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
type passwordAuthMsg struct {
User string `sshtype:"50"`
Service string
@ -125,7 +134,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
// The program may only find out that the user doesn't have a password
// when prompting.
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if err := c.writePacket(Marshal(&passwordAuthMsg{
@ -135,7 +144,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
Reply: false,
Password: pw,
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
return handleAuthResponse(c)
@ -178,7 +187,7 @@ func (cb publicKeyCallback) method() string {
return "publickey"
}
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
// Authentication is performed by sending an enquiry to test if a key is
// acceptable to the remote. If the key is acceptable, the client will
// attempt to authenticate with the valid key. If not the client will repeat
@ -186,13 +195,13 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
signers, err := cb()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
var methods []string
for _, signer := range signers {
ok, err := validateKey(signer.PublicKey(), user, c)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if !ok {
continue
@ -206,7 +215,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
Method: cb.method(),
}, []byte(pub.Type()), pubKey))
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// manually wrap the serialized signature in a string
@ -224,24 +233,24 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
}
p := Marshal(&msg)
if err := c.writePacket(p); err != nil {
return false, nil, err
return authFailure, nil, err
}
var success bool
var success authResult
success, methods, err = handleAuthResponse(c)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// If authentication succeeds or the list of available methods does not
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success || !containsMethod(methods, cb.method()) {
if success == authSuccess || !containsMethod(methods, cb.method()) {
return success, methods, err
}
}
return false, methods, nil
return authFailure, methods, nil
}
func containsMethod(methods []string, method string) bool {
@ -283,7 +292,9 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
}
switch packet[0] {
case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user
if err := handleBannerResponse(c, packet); err != nil {
return false, err
}
case msgUserAuthPubKeyOk:
var msg userAuthPubKeyOkMsg
if err := Unmarshal(packet, &msg); err != nil {
@ -316,30 +327,53 @@ func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMet
// handleAuthResponse returns whether the preceding authentication request succeeded
// along with a list of remaining authentication methods to try next and
// an error if an unexpected response was received.
func handleAuthResponse(c packetConn) (bool, []string, error) {
func handleAuthResponse(c packetConn) (authResult, []string, error) {
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
switch packet[0] {
case msgUserAuthBanner:
// TODO: add callback to present the banner to the user
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
}
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
return false, msg.Methods, nil
if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil
}
return authFailure, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
return authSuccess, nil, nil
default:
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
}
}
func handleBannerResponse(c packetConn, packet []byte) error {
var msg userAuthBannerMsg
if err := Unmarshal(packet, &msg); err != nil {
return err
}
transport, ok := c.(*handshakeTransport)
if !ok {
return nil
}
if transport.bannerCallback != nil {
return transport.bannerCallback(msg.Message)
}
return nil
}
// KeyboardInteractiveChallenge should print questions, optionally
// disabling echoing (e.g. for passwords), and return all the answers.
// Challenge may be called multiple times in a single session. After
@ -359,7 +393,7 @@ func (cb KeyboardInteractiveChallenge) method() string {
return "keyboard-interactive"
}
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
type initiateMsg struct {
User string `sshtype:"50"`
Service string
@ -373,37 +407,42 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
Service: serviceSSH,
Method: "keyboard-interactive",
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// like handleAuthResponse, but with less options.
switch packet[0] {
case msgUserAuthBanner:
// TODO: Print banners during userauth.
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
}
continue
case msgUserAuthInfoRequest:
// OK
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
return false, msg.Methods, nil
if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil
}
return authFailure, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
return authSuccess, nil, nil
default:
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
}
var msg userAuthInfoRequestMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
// Manually unpack the prompt/echo pairs.
@ -413,7 +452,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
for i := 0; i < int(msg.NumPrompts); i++ {
prompt, r, ok := parseString(rest)
if !ok || len(r) == 0 {
return false, nil, errors.New("ssh: prompt format error")
return authFailure, nil, errors.New("ssh: prompt format error")
}
prompts = append(prompts, string(prompt))
echos = append(echos, r[0] != 0)
@ -421,16 +460,16 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
}
if len(rest) != 0 {
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
}
answers, err := cb(msg.User, msg.Instruction, prompts, echos)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if len(answers) != len(prompts) {
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
return authFailure, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
}
responseLength := 1 + 4
for _, a := range answers {
@ -446,7 +485,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
}
if err := c.writePacket(serialized); err != nil {
return false, nil, err
return authFailure, nil, err
}
}
}
@ -456,10 +495,10 @@ type retryableAuthMethod struct {
maxTries int
}
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok bool, methods []string, err error) {
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok authResult, methods []string, err error) {
for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
ok, methods, err = r.authMethod.auth(session, user, c, rand)
if ok || err != nil { // either success or error terminate
if ok != authFailure || err != nil { // either success, partial success or error terminate
return ok, methods, err
}
}

View file

@ -1,628 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"os"
"strings"
"testing"
)
type keyboardInteractive map[string]string
func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
var answers []string
for _, q := range questions {
answers = append(answers, cr[q])
}
return answers, nil
}
// reused internally by tests
var clientPassword = "tiger"
// tryAuth runs a handshake with a given config against an SSH server
// with config serverConfig
func tryAuth(t *testing.T, config *ClientConfig) error {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
certChecker := CertChecker{
IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
return nil, nil
}
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
},
IsRevoked: func(c *Certificate) bool {
return c.Serial == 666
},
}
serverConfig := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == clientPassword {
return nil, nil
}
return nil, errors.New("password auth failed")
},
PublicKeyCallback: certChecker.Authenticate,
KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) {
ans, err := challenge("user",
"instruction",
[]string{"question1", "question2"},
[]bool{true, true})
if err != nil {
return nil, err
}
ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
if ok {
challenge("user", "motd", nil, nil)
return nil, nil
}
return nil, errors.New("keyboard-interactive failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", config)
return err
}
func TestClientAuthPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodFallback(t *testing.T) {
var passwordCalled bool
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
PasswordCallback(
func() (string, error) {
passwordCalled = true
return "WRONG", nil
}),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
if passwordCalled {
t.Errorf("password auth tried before public-key auth.")
}
}
func TestAuthMethodWrongPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password("wrong"),
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "answer2",
})
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodWrongKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "WRONG",
})
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
}
if err := tryAuth(t, config); err == nil {
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive")
}
}
// the mock server will only authenticate ssh-rsa keys
func TestAuthMethodInvalidPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"]),
},
}
if err := tryAuth(t, config); err == nil {
t.Fatalf("dsa private key should not have authenticated with rsa public key")
}
}
// the client should authenticate with the second key
func TestAuthMethodRSAandDSA(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"], testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err)
}
}
func TestClientHMAC(t *testing.T) {
for _, mac := range supportedMACs {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
Config: Config{
MACs: []string{mac},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
}
}
}
// issue 4285.
func TestClientUnsupportedCipher(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(),
},
Config: Config{
Ciphers: []string{"aes128-cbc"}, // not currently supported
},
}
if err := tryAuth(t, config); err == nil {
t.Errorf("expected no ciphers in common")
}
}
func TestClientUnsupportedKex(t *testing.T) {
if os.Getenv("GO_BUILDER_NAME") != "" {
t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198")
}
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(),
},
Config: Config{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") {
t.Errorf("got %v, expected 'common algorithm'", err)
}
}
func TestClientLoginCert(t *testing.T) {
cert := &Certificate{
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: UserCert,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
if err != nil {
t.Fatalf("NewCertSigner: %v", err)
}
clientConfig := &ClientConfig{
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
}
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
// should succeed
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
// corrupted signature
cert.Signature.Blob[0]++
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with corrupted sig")
}
// revoked
cert.Serial = 666
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("revoked cert login succeeded")
}
cert.Serial = 1
// sign with wrong key
cert.SignCert(rand.Reader, testSigners["dsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with non-authoritative key")
}
// host cert
cert.CertType = HostCert
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with wrong type")
}
cert.CertType = UserCert
// principal specified
cert.ValidPrincipals = []string{"user"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
// wrong principal specified
cert.ValidPrincipals = []string{"fred"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with wrong principal")
}
cert.ValidPrincipals = nil
// added critical option
cert.CriticalOptions = map[string]string{"root-access": "yes"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with unrecognized critical option")
}
// allowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24,::42/120"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login with source-address failed: %v", err)
}
// disallowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42,::42"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login with source-address succeeded")
}
}
func testPermissionsPassing(withPermissions bool, t *testing.T) {
serverConfig := &ServerConfig{
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "nopermissions" {
return nil, nil
}
return &Permissions{}, nil
},
}
serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if withPermissions {
clientConfig.User = "permissions"
} else {
clientConfig.User = "nopermissions"
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
serverConn, err := newServer(c1, serverConfig)
if err != nil {
t.Fatal(err)
}
if p := serverConn.Permissions; (p != nil) != withPermissions {
t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p)
}
}
func TestPermissionsPassing(t *testing.T) {
testPermissionsPassing(true, t)
}
func TestNoPermissionsPassing(t *testing.T) {
testPermissionsPassing(false, t)
}
func TestRetryableAuth(t *testing.T) {
n := 0
passwords := []string{"WRONG1", "WRONG2"}
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
p := passwords[n]
n++
return p, nil
}), 2),
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
if n != 2 {
t.Fatalf("Did not try all passwords")
}
}
func ExampleRetryableAuthMethod(t *testing.T) {
user := "testuser"
NumberOfPrompts := 3
// Normally this would be a callback that prompts the user to answer the
// provided questions
Cb := func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
return []string{"answer1", "answer2"}, nil
}
config := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
// Test if username is received on server side when NoClientAuth is used
func TestClientAuthNone(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
NoClientAuth: true,
}
serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{
User: user,
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
serverConn, err := newServer(c1, serverConfig)
if err != nil {
t.Fatalf("newServer: %v", err)
}
if serverConn.User() != user {
t.Fatalf("server: got %q, want %q", serverConn.User(), user)
}
}
// Test if authentication attempts are limited on server when MaxAuthTries is set
func TestClientAuthMaxAuthTries(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
MaxAuthTries: 2,
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == "right" {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
for tries := 2; tries < 4; tries++ {
n := tries
clientConfig := &ClientConfig{
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
n--
if n == 0 {
return "right", nil
}
return "wrong", nil
}), tries),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", clientConfig)
if tries > 2 {
if err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
} else {
if err != nil {
t.Fatalf("client: got %s, want no error", err)
}
}
}
}
// Test if authentication attempts are correctly limited on server
// when more public keys are provided then MaxAuthTries
func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) {
signers := []Signer{}
for i := 0; i < 6; i++ {
signers = append(signers, testSigners["dsa"])
}
validConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, validConfig); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
invalidConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append(signers, testSigners["rsa"])...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, invalidConfig); err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
}
// Test whether authentication errors are being properly logged if all
// authentication methods have been exhausted
func TestClientAuthErrorList(t *testing.T) {
publicKeyErr := errors.New("This is an error from PublicKeyCallback")
clientConfig := &ClientConfig{
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
serverConfig := &ServerConfig{
PublicKeyCallback: func(_ ConnMetadata, _ PublicKey) (*Permissions, error) {
return nil, publicKeyErr
},
}
serverConfig.AddHostKey(testSigners["rsa"])
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
_, err = newServer(c1, serverConfig)
if err == nil {
t.Fatal("newServer: got nil, expected errors")
}
authErrs, ok := err.(*ServerAuthError)
if !ok {
t.Fatalf("errors: got %T, want *ssh.ServerAuthError", err)
}
for i, e := range authErrs.Errors {
switch i {
case 0:
if e.Error() != "no auth passed yet" {
t.Fatalf("errors: got %v, want no auth passed yet", e.Error())
}
case 1:
if e != publicKeyErr {
t.Fatalf("errors: got %v, want %v", e, publicKeyErr)
}
default:
t.Fatalf("errors: got %v, expected 2 errors", authErrs.Errors)
}
}
}

View file

@ -1,81 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"net"
"strings"
"testing"
)
func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
receivedVersion := make(chan string, 1)
config.HostKeyCallback = InsecureIgnoreHostKey()
go func() {
version, err := readVersion(serverConn)
if err != nil {
receivedVersion <- ""
} else {
receivedVersion <- string(version)
}
serverConn.Close()
}()
NewClientConn(clientConn, "", config)
actual := <-receivedVersion
if actual != expected {
t.Fatalf("got %s; want %s", actual, expected)
}
}
func TestCustomClientVersion(t *testing.T) {
version := "Test-Client-Version-0.0"
testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
}
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion)
}
func TestHostKeyCheck(t *testing.T) {
for _, tt := range []struct {
name string
wantError string
key PublicKey
}{
{"no callback", "must specify HostKeyCallback", nil},
{"correct key", "", testSigners["rsa"].PublicKey()},
{"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()},
} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
go NewServerConn(c1, serverConf)
clientConf := ClientConfig{
User: "user",
}
if tt.key != nil {
clientConf.HostKeyCallback = FixedHostKey(tt.key)
}
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
}
} else if tt.wantError != "" {
t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)
}
}
}

View file

@ -24,11 +24,21 @@ const (
serviceSSH = "ssh-connection"
)
// supportedCiphers specifies the supported ciphers in preference order.
// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"aes128-gcm@openssh.com",
"arcfour256", "arcfour128",
chacha20Poly1305ID,
"arcfour256", "arcfour128", "arcfour",
aes128cbcID,
tripledescbcID,
}
// preferredCiphers specifies the default preference for ciphers.
var preferredCiphers = []string{
"aes128-gcm@openssh.com",
chacha20Poly1305ID,
"aes128-ctr", "aes192-ctr", "aes256-ctr",
}
// supportedKexAlgos specifies the supported key-exchange algorithms in
@ -211,7 +221,7 @@ func (c *Config) SetDefaults() {
c.Rand = rand.Reader
}
if c.Ciphers == nil {
c.Ciphers = supportedCiphers
c.Ciphers = preferredCiphers
}
var ciphers []string
for _, c := range c.Ciphers {
@ -242,7 +252,7 @@ func (c *Config) SetDefaults() {
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct {
Session []byte
Type byte
@ -253,7 +263,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
Algo []byte
PubKey []byte
}{
sessionId,
sessionID,
msgUserAuthRequest,
req.User,
req.Service,

View file

@ -1,320 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh_test
import (
"bufio"
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
func ExampleNewServerConn() {
// Public key authentication is done by comparing
// the public key of a received connection
// with the entries in the authorized_keys file.
authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys")
if err != nil {
log.Fatalf("Failed to load authorized_keys, err: %v", err)
}
authorizedKeysMap := map[string]bool{}
for len(authorizedKeysBytes) > 0 {
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
if err != nil {
log.Fatal(err)
}
authorizedKeysMap[string(pubKey.Marshal())] = true
authorizedKeysBytes = rest
}
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ssh.ServerConfig{
// Remove to disable password auth.
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
// Should use constant-time compare (or better, salt+hash) in
// a production setting.
if c.User() == "testuser" && string(pass) == "tiger" {
return nil, nil
}
return nil, fmt.Errorf("password rejected for %q", c.User())
},
// Remove to disable public key auth.
PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
if authorizedKeysMap[string(pubKey.Marshal())] {
return &ssh.Permissions{
// Record the public key used for authentication.
Extensions: map[string]string{
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
},
}, nil
}
return nil, fmt.Errorf("unknown public key for %q", c.User())
},
}
privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
log.Fatal("Failed to load private key: ", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
log.Fatal("Failed to parse private key: ", err)
}
config.AddHostKey(private)
// Once a ServerConfig has been configured, connections can be
// accepted.
listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
log.Fatal("failed to listen for connection: ", err)
}
nConn, err := listener.Accept()
if err != nil {
log.Fatal("failed to accept incoming connection: ", err)
}
// Before use, a handshake must be performed on the incoming
// net.Conn.
conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
if err != nil {
log.Fatal("failed to handshake: ", err)
}
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
// Service the incoming Channel channel.
for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of a shell, the type is
// "session" and ServerShell may be used to present a simple
// terminal interface.
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
log.Fatalf("Could not accept channel: %v", err)
}
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "shell" request.
go func(in <-chan *ssh.Request) {
for req := range in {
req.Reply(req.Type == "shell", nil)
}
}(requests)
term := terminal.NewTerminal(channel, "> ")
go func() {
defer channel.Close()
for {
line, err := term.ReadLine()
if err != nil {
break
}
fmt.Println(line)
}
}()
}
}
func ExampleHostKeyCheck() {
// Every client must provide a host key check. Here is a
// simple-minded parse of OpenSSH's known_hosts file
host := "hostname"
file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
var hostKey ssh.PublicKey
for scanner.Scan() {
fields := strings.Split(scanner.Text(), " ")
if len(fields) != 3 {
continue
}
if strings.Contains(fields[0], host) {
var err error
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
if err != nil {
log.Fatalf("error parsing %q: %v", fields[2], err)
}
break
}
}
if hostKey == nil {
log.Fatalf("no hostkey for %s", host)
}
config := ssh.ClientConfig{
User: os.Getenv("USER"),
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
_, err = ssh.Dial("tcp", host+":22", &config)
log.Println(err)
}
func ExampleDial() {
var hostKey ssh.PublicKey
// An SSH client is represented with a ClientConn.
//
// To authenticate with the remote server you must pass at least one
// implementation of AuthMethod via the Auth field in ClientConfig,
// and provide a HostKeyCallback.
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("yourpassword"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
client, err := ssh.Dial("tcp", "yourserver.com:22", config)
if err != nil {
log.Fatal("Failed to dial: ", err)
}
// Each ClientConn can support multiple interactive sessions,
// represented by a Session.
session, err := client.NewSession()
if err != nil {
log.Fatal("Failed to create session: ", err)
}
defer session.Close()
// Once a Session is created, you can execute a single command on
// the remote side using the Run method.
var b bytes.Buffer
session.Stdout = &b
if err := session.Run("/usr/bin/whoami"); err != nil {
log.Fatal("Failed to run: " + err.Error())
}
fmt.Println(b.String())
}
func ExamplePublicKeys() {
var hostKey ssh.PublicKey
// A public key may be used to authenticate against the remote
// server by using an unencrypted PEM-encoded private key file.
//
// If you have an encrypted private key, the crypto/x509 package
// can be used to decrypt it.
key, err := ioutil.ReadFile("/home/user/.ssh/id_rsa")
if err != nil {
log.Fatalf("unable to read private key: %v", err)
}
// Create the Signer for this private key.
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
log.Fatalf("unable to parse private key: %v", err)
}
config := &ssh.ClientConfig{
User: "user",
Auth: []ssh.AuthMethod{
// Use the PublicKeys method for remote authentication.
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to the remote server and perform the SSH handshake.
client, err := ssh.Dial("tcp", "host.com:22", config)
if err != nil {
log.Fatalf("unable to connect: %v", err)
}
defer client.Close()
}
func ExampleClient_Listen() {
var hostKey ssh.PublicKey
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Dial your ssh server.
conn, err := ssh.Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatal("unable to connect: ", err)
}
defer conn.Close()
// Request the remote side to open port 8080 on all interfaces.
l, err := conn.Listen("tcp", "0.0.0.0:8080")
if err != nil {
log.Fatal("unable to register tcp forward: ", err)
}
defer l.Close()
// Serve HTTP with your SSH server acting as a reverse proxy.
http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
fmt.Fprintf(resp, "Hello world!\n")
}))
}
func ExampleSession_RequestPty() {
var hostKey ssh.PublicKey
// Create client config
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to ssh server
conn, err := ssh.Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatal("unable to connect: ", err)
}
defer conn.Close()
// Create a session
session, err := conn.NewSession()
if err != nil {
log.Fatal("unable to create session: ", err)
}
defer session.Close()
// Set up terminal modes
modes := ssh.TerminalModes{
ssh.ECHO: 0, // disable echoing
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
// Request pseudo terminal
if err := session.RequestPty("xterm", 40, 80, modes); err != nil {
log.Fatal("request for pseudo terminal failed: ", err)
}
// Start remote shell
if err := session.Shell(); err != nil {
log.Fatal("failed to start shell: ", err)
}
}

View file

@ -78,6 +78,11 @@ type handshakeTransport struct {
dialAddress string
remoteAddr net.Addr
// bannerCallback is non-empty if we are the client and it has been set in
// ClientConfig. In that case it is called during the user authentication
// dance to handle a custom server's message.
bannerCallback BannerCallback
// Algorithms agreed in the last key exchange.
algorithms *algorithms
@ -120,6 +125,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
t.dialAddress = dialAddr
t.remoteAddr = addr
t.hostKeyCallback = config.HostKeyCallback
t.bannerCallback = config.BannerCallback
if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else {

View file

@ -1,559 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"reflect"
"runtime"
"strings"
"sync"
"testing"
)
type testChecker struct {
calls []string
}
func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
if dialAddr == "bad" {
return fmt.Errorf("dialAddr is bad")
}
if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
}
t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
return nil
}
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
return nil, nil, err
}
c2, err := listener.Accept()
if err != nil {
c1.Close()
return nil, nil, err
}
return c1, c2, nil
}
// noiseTransport inserts ignore messages to check that the read loop
// and the key exchange filters out these messages.
type noiseTransport struct {
keyingTransport
}
func (t *noiseTransport) writePacket(p []byte) error {
ignore := []byte{msgIgnore}
if err := t.keyingTransport.writePacket(ignore); err != nil {
return err
}
debug := []byte{msgDebug, 1, 2, 3}
if err := t.keyingTransport.writePacket(debug); err != nil {
return err
}
return t.keyingTransport.writePacket(p)
}
func addNoiseTransport(t keyingTransport) keyingTransport {
return &noiseTransport{t}
}
// handshakePair creates two handshakeTransports connected with each
// other. If the noise argument is true, both transports will try to
// confuse the other side by sending ignore and debug messages.
func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
a, b, err := netPipe()
if err != nil {
return nil, nil, err
}
var trC, trS keyingTransport
trC = newTransport(a, rand.Reader, true)
trS = newTransport(b, rand.Reader, false)
if noise {
trC = addNoiseTransport(trC)
trS = addNoiseTransport(trS)
}
clientConf.SetDefaults()
v := []byte("version")
client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.AddHostKey(testSigners["rsa"])
serverConf.SetDefaults()
server = newServerTransport(trS, v, v, serverConf)
if err := server.waitSession(); err != nil {
return nil, nil, fmt.Errorf("server.waitSession: %v", err)
}
if err := client.waitSession(); err != nil {
return nil, nil, fmt.Errorf("client.waitSession: %v", err)
}
return client, server, nil
}
func TestHandshakeBasic(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
checker := &syncChecker{
waitCall: make(chan int, 10),
called: make(chan int, 10),
}
checker.waitCall <- 1
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
// Let first kex complete normally.
<-checker.called
clientDone := make(chan int, 0)
gotHalf := make(chan int, 0)
const N = 20
go func() {
defer close(clientDone)
// Client writes a bunch of stuff, and does a key
// change in the middle. This should not confuse the
// handshake in progress. We do this twice, so we test
// that the packet buffer is reset correctly.
for i := 0; i < N; i++ {
p := []byte{msgRequestSuccess, byte(i)}
if err := trC.writePacket(p); err != nil {
t.Fatalf("sendPacket: %v", err)
}
if (i % 10) == 5 {
<-gotHalf
// halfway through, we request a key change.
trC.requestKeyExchange()
// Wait until we can be sure the key
// change has really started before we
// write more.
<-checker.called
}
if (i % 10) == 7 {
// write some packets until the kex
// completes, to test buffering of
// packets.
checker.waitCall <- 1
}
}
}()
// Server checks that client messages come in cleanly
i := 0
err = nil
for ; i < N; i++ {
var p []byte
p, err = trS.readPacket()
if err != nil {
break
}
if (i % 10) == 5 {
gotHalf <- 1
}
want := []byte{msgRequestSuccess, byte(i)}
if bytes.Compare(p, want) != 0 {
t.Errorf("message %d: got %v, want %v", i, p, want)
}
}
<-clientDone
if err != nil && err != io.EOF {
t.Fatalf("server error: %v", err)
}
if i != N {
t.Errorf("received %d messages, want 10.", i)
}
close(checker.called)
if _, ok := <-checker.called; ok {
// If all went well, we registered exactly 2 key changes: one
// that establishes the session, and one that we requested
// additionally.
t.Fatalf("got another host key checks after 2 handshakes")
}
}
func TestForceFirstKex(t *testing.T) {
// like handshakePair, but must access the keyingTransport.
checker := &testChecker{}
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
var trC, trS keyingTransport
trC = newTransport(a, rand.Reader, true)
// This is the disallowed packet:
trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
// Rest of the setup.
trS = newTransport(b, rand.Reader, false)
clientConf.SetDefaults()
v := []byte("version")
client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.AddHostKey(testSigners["rsa"])
serverConf.SetDefaults()
server := newServerTransport(trS, v, v, serverConf)
defer client.Close()
defer server.Close()
// We setup the initial key exchange, but the remote side
// tries to send serviceRequestMsg in cleartext, which is
// disallowed.
if err := server.waitSession(); err == nil {
t.Errorf("server first kex init should reject unexpected packet")
}
}
func TestHandshakeAutoRekeyWrite(t *testing.T) {
checker := &syncChecker{
called: make(chan int, 10),
waitCall: nil,
}
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
input := make([]byte, 251)
input[0] = msgRequestSuccess
done := make(chan int, 1)
const numPacket = 5
go func() {
defer close(done)
j := 0
for ; j < numPacket; j++ {
if p, err := trS.readPacket(); err != nil {
break
} else if !bytes.Equal(input, p) {
t.Errorf("got packet type %d, want %d", p[0], input[0])
}
}
if j != numPacket {
t.Errorf("got %d, want 5 messages", j)
}
}()
<-checker.called
for i := 0; i < numPacket; i++ {
p := make([]byte, len(input))
copy(p, input)
if err := trC.writePacket(p); err != nil {
t.Errorf("writePacket: %v", err)
}
if i == 2 {
// Make sure the kex is in progress.
<-checker.called
}
}
<-done
}
type syncChecker struct {
waitCall chan int
called chan int
}
func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
c.called <- 1
if c.waitCall != nil {
<-c.waitCall
}
return nil
}
func TestHandshakeAutoRekeyRead(t *testing.T) {
sync := &syncChecker{
called: make(chan int, 2),
waitCall: nil,
}
clientConf := &ClientConfig{
HostKeyCallback: sync.Check,
}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
packet := make([]byte, 501)
packet[0] = msgRequestSuccess
if err := trS.writePacket(packet); err != nil {
t.Fatalf("writePacket: %v", err)
}
// While we read out the packet, a key change will be
// initiated.
done := make(chan int, 1)
go func() {
defer close(done)
if _, err := trC.readPacket(); err != nil {
t.Fatalf("readPacket(client): %v", err)
}
}()
<-done
<-sync.called
}
// errorKeyingTransport generates errors after a given number of
// read/write operations.
type errorKeyingTransport struct {
packetConn
readLeft, writeLeft int
}
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
return nil
}
func (n *errorKeyingTransport) getSessionID() []byte {
return nil
}
func (n *errorKeyingTransport) writePacket(packet []byte) error {
if n.writeLeft == 0 {
n.Close()
return errors.New("barf")
}
n.writeLeft--
return n.packetConn.writePacket(packet)
}
func (n *errorKeyingTransport) readPacket() ([]byte, error) {
if n.readLeft == 0 {
n.Close()
return nil, errors.New("barf")
}
n.readLeft--
return n.packetConn.readPacket()
}
func TestHandshakeErrorHandlingRead(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1, false)
}
}
func TestHandshakeErrorHandlingWrite(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i, false)
}
}
func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1, true)
}
}
func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i, true)
}
}
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and
// panic.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
a, b := memPipe()
defer a.Close()
defer b.Close()
key := testSigners["ecdsa"]
serverConf := Config{RekeyThreshold: minRekeyThreshold}
serverConf.SetDefaults()
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
serverConn.hostKeys = []Signer{key}
go serverConn.readLoop()
go serverConn.kexLoop()
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
clientConf.SetDefaults()
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
clientConn.hostKeyCallback = InsecureIgnoreHostKey()
go clientConn.readLoop()
go clientConn.kexLoop()
var wg sync.WaitGroup
for _, hs := range []packetConn{serverConn, clientConn} {
if !coupled {
wg.Add(2)
go func(c packetConn) {
for i := 0; ; i++ {
str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
err := c.writePacket(Marshal(&serviceRequestMsg{str}))
if err != nil {
break
}
}
wg.Done()
c.Close()
}(hs)
go func(c packetConn) {
for {
_, err := c.readPacket()
if err != nil {
break
}
}
wg.Done()
}(hs)
} else {
wg.Add(1)
go func(c packetConn) {
for {
_, err := c.readPacket()
if err != nil {
break
}
if err := c.writePacket(msg); err != nil {
break
}
}
wg.Done()
}(hs)
}
}
wg.Wait()
}
func TestDisconnect(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
errMsg := &disconnectMsg{
Reason: 42,
Message: "such is life",
}
trC.writePacket(Marshal(errMsg))
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
packet, err := trS.readPacket()
if err != nil {
t.Fatalf("readPacket 1: %v", err)
}
if packet[0] != msgRequestSuccess {
t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
}
_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 2 succeeded")
} else if !reflect.DeepEqual(err, errMsg) {
t.Errorf("got error %#v, want %#v", err, errMsg)
}
_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 3 succeeded")
}
}
func TestHandshakeRekeyDefault(t *testing.T) {
clientConf := &ClientConfig{
Config: Config{
Ciphers: []string{"aes128-ctr"},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
trC.Close()
rgb := (1024 + trC.readBytesLeft) >> 30
wgb := (1024 + trC.writeBytesLeft) >> 30
if rgb != 64 {
t.Errorf("got rekey after %dG read, want 64G", rgb)
}
if wgb != 64 {
t.Errorf("got rekey after %dG write, want 64G", wgb)
}
}

View file

@ -119,7 +119,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
return nil, err
}
kInt, err := group.diffieHellman(kexDHReply.Y, x)
ki, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil {
return nil, err
}
@ -129,8 +129,8 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
writeString(h, kexDHReply.HostKey)
writeInt(h, X)
writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
return &kexResult{
@ -164,7 +164,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
}
Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y)
ki, err := group.diffieHellman(kexDHInit.X, y)
if err != nil {
return nil, err
}
@ -177,8 +177,8 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
writeInt(h, kexDHInit.X)
writeInt(h, Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
H := h.Sum(nil)
@ -462,9 +462,9 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey)
kInt := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
return &kexResult{
@ -510,9 +510,9 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:])
kInt := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
H := h.Sum(nil)

View file

@ -1,50 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Key exchange tests.
import (
"crypto/rand"
"reflect"
"testing"
)
func TestKexes(t *testing.T) {
type kexResultErr struct {
result *kexResult
err error
}
for name, kex := range kexAlgoMap {
a, b := memPipe()
s := make(chan kexResultErr, 1)
c := make(chan kexResultErr, 1)
var magics handshakeMagics
go func() {
r, e := kex.Client(a, rand.Reader, &magics)
a.Close()
c <- kexResultErr{r, e}
}()
go func() {
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
b.Close()
s <- kexResultErr{r, e}
}()
clientRes := <-c
serverRes := <-s
if clientRes.err != nil {
t.Errorf("client: %v", clientRes.err)
}
if serverRes.err != nil {
t.Errorf("server: %v", serverRes.err)
}
if !reflect.DeepEqual(clientRes.result, serverRes.result) {
t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
}
}
}

View file

@ -276,7 +276,8 @@ type PublicKey interface {
Type() string
// Marshal returns the serialized key data in SSH wire format,
// with the name prefix.
// with the name prefix. To unmarshal the returned data, use
// the ParsePublicKey function.
Marshal() []byte
// Verify that sig is a signature on the given data using this
@ -363,7 +364,7 @@ func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string {
func (k *dsaPublicKey) Type() string {
return "ssh-dss"
}
@ -481,12 +482,12 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID()
func (k *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + k.nistID()
}
func (key *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize {
func (k *ecdsaPublicKey) nistID() string {
switch k.Params().BitSize {
case 256:
return "nistp256"
case 384:
@ -499,7 +500,7 @@ func (key *ecdsaPublicKey) nistID() string {
type ed25519PublicKey ed25519.PublicKey
func (key ed25519PublicKey) Type() string {
func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519
}
@ -518,23 +519,23 @@ func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
return (ed25519PublicKey)(key), w.Rest, nil
}
func (key ed25519PublicKey) Marshal() []byte {
func (k ed25519PublicKey) Marshal() []byte {
w := struct {
Name string
KeyBytes []byte
}{
KeyAlgoED25519,
[]byte(key),
[]byte(k),
}
return Marshal(&w)
}
func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != key.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
edKey := (ed25519.PublicKey)(key)
edKey := (ed25519.PublicKey)(k)
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok {
return errors.New("ssh: signature did not verify")
}
@ -595,9 +596,9 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
return (*ecdsaPublicKey)(key), w.Rest, nil
}
func (key *ecdsaPublicKey) Marshal() []byte {
func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package.
w := struct {
@ -605,20 +606,20 @@ func (key *ecdsaPublicKey) Marshal() []byte {
ID string
Key []byte
}{
key.Type(),
key.nistID(),
k.Type(),
k.nistID(),
keyBytes,
}
return Marshal(&w)
}
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
h := ecHash(key.Curve).New()
h := ecHash(k.Curve).New()
h.Write(data)
digest := h.Sum(nil)
@ -635,7 +636,7 @@ func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
return err
}
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) {
if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) {
return nil
}
return errors.New("ssh: signature did not verify")
@ -758,7 +759,7 @@ func NewPublicKey(key interface{}) (PublicKey, error) {
return (*rsaPublicKey)(key), nil
case *ecdsa.PublicKey:
if !supportedEllipticCurve(key.Curve) {
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.")
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
}
return (*ecdsaPublicKey)(key), nil
case *dsa.PublicKey:

View file

@ -1,500 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"fmt"
"reflect"
"strings"
"testing"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh/testdata"
)
func rawKey(pub PublicKey) interface{} {
switch k := pub.(type) {
case *rsaPublicKey:
return (*rsa.PublicKey)(k)
case *dsaPublicKey:
return (*dsa.PublicKey)(k)
case *ecdsaPublicKey:
return (*ecdsa.PublicKey)(k)
case ed25519PublicKey:
return (ed25519.PublicKey)(k)
case *Certificate:
return k
}
panic("unknown key type")
}
func TestKeyMarshalParse(t *testing.T) {
for _, priv := range testSigners {
pub := priv.PublicKey()
roundtrip, err := ParsePublicKey(pub.Marshal())
if err != nil {
t.Errorf("ParsePublicKey(%T): %v", pub, err)
}
k1 := rawKey(pub)
k2 := rawKey(roundtrip)
if !reflect.DeepEqual(k1, k2) {
t.Errorf("got %#v in roundtrip, want %#v", k2, k1)
}
}
}
func TestUnsupportedCurves(t *testing.T) {
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") {
t.Fatalf("NewPrivateKey should not succeed with P-224, got: %v", err)
}
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") {
t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err)
}
}
func TestNewPublicKey(t *testing.T) {
for _, k := range testSigners {
raw := rawKey(k.PublicKey())
// Skip certificates, as NewPublicKey does not support them.
if _, ok := raw.(*Certificate); ok {
continue
}
pub, err := NewPublicKey(raw)
if err != nil {
t.Errorf("NewPublicKey(%#v): %v", raw, err)
}
if !reflect.DeepEqual(k.PublicKey(), pub) {
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey())
}
}
}
func TestKeySignVerify(t *testing.T) {
for _, priv := range testSigners {
pub := priv.PublicKey()
data := []byte("sign me")
sig, err := priv.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("Sign(%T): %v", priv, err)
}
if err := pub.Verify(data, sig); err != nil {
t.Errorf("publicKey.Verify(%T): %v", priv, err)
}
sig.Blob[5]++
if err := pub.Verify(data, sig); err == nil {
t.Errorf("publicKey.Verify on broken sig did not fail")
}
}
}
func TestParseRSAPrivateKey(t *testing.T) {
key := testPrivateKeys["rsa"]
rsa, ok := key.(*rsa.PrivateKey)
if !ok {
t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
}
if err := rsa.Validate(); err != nil {
t.Errorf("Validate: %v", err)
}
}
func TestParseECPrivateKey(t *testing.T) {
key := testPrivateKeys["ecdsa"]
ecKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey)
}
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
t.Fatalf("public key does not validate.")
}
}
// See Issue https://github.com/golang/go/issues/6650.
func TestParseEncryptedPrivateKeysFails(t *testing.T) {
const wantSubstring = "encrypted"
for i, tt := range testdata.PEMEncryptedKeys {
_, err := ParsePrivateKey(tt.PEMBytes)
if err == nil {
t.Errorf("#%d key %s: ParsePrivateKey successfully parsed, expected an error", i, tt.Name)
continue
}
if !strings.Contains(err.Error(), wantSubstring) {
t.Errorf("#%d key %s: got error %q, want substring %q", i, tt.Name, err, wantSubstring)
}
}
}
// Parse encrypted private keys with passphrase
func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) {
data := []byte("sign me")
for _, tt := range testdata.PEMEncryptedKeys {
s, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey))
if err != nil {
t.Fatalf("ParsePrivateKeyWithPassphrase returned error: %s", err)
continue
}
sig, err := s.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("dsa.Sign: %v", err)
}
if err := s.PublicKey().Verify(data, sig); err != nil {
t.Errorf("Verify failed: %v", err)
}
}
tt := testdata.PEMEncryptedKeys[0]
_, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte("incorrect"))
if err != x509.IncorrectPasswordError {
t.Fatalf("got %v want IncorrectPasswordError", err)
}
}
func TestParseDSA(t *testing.T) {
// We actually exercise the ParsePrivateKey codepath here, as opposed to
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
// uses.
s, err := ParsePrivateKey(testdata.PEMBytes["dsa"])
if err != nil {
t.Fatalf("ParsePrivateKey returned error: %s", err)
}
data := []byte("sign me")
sig, err := s.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("dsa.Sign: %v", err)
}
if err := s.PublicKey().Verify(data, sig); err != nil {
t.Errorf("Verify failed: %v", err)
}
}
// Tests for authorized_keys parsing.
// getTestKey returns a public key, and its base64 encoding.
func getTestKey() (PublicKey, string) {
k := testPublicKeys["rsa"]
b := &bytes.Buffer{}
e := base64.NewEncoder(base64.StdEncoding, b)
e.Write(k.Marshal())
e.Close()
return k, b.String()
}
func TestMarshalParsePublicKey(t *testing.T) {
pub, pubSerialized := getTestKey()
line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized)
authKeys := MarshalAuthorizedKey(pub)
actualFields := strings.Fields(string(authKeys))
if len(actualFields) == 0 {
t.Fatalf("failed authKeys: %v", authKeys)
}
// drop the comment
expectedFields := strings.Fields(line)[0:2]
if !reflect.DeepEqual(actualFields, expectedFields) {
t.Errorf("got %v, expected %v", actualFields, expectedFields)
}
actPub, _, _, _, err := ParseAuthorizedKey([]byte(line))
if err != nil {
t.Fatalf("cannot parse %v: %v", line, err)
}
if !reflect.DeepEqual(actPub, pub) {
t.Errorf("got %v, expected %v", actPub, pub)
}
}
type authResult struct {
pubKey PublicKey
options []string
comments string
rest string
ok bool
}
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
rest := authKeys
var values []authResult
for len(rest) > 0 {
var r authResult
var err error
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
r.ok = (err == nil)
t.Log(err)
r.rest = string(rest)
values = append(values, r)
}
if !reflect.DeepEqual(values, expected) {
t.Errorf("got %#v, expected %#v", values, expected)
}
}
func TestAuthorizedKeyBasic(t *testing.T) {
pub, pubSerialized := getTestKey()
line := "ssh-rsa " + pubSerialized + " user@host"
testAuthorizedKeys(t, []byte(line),
[]authResult{
{pub, nil, "user@host", "", true},
})
}
func TestAuth(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithOptions := []string{
`# comments to ignore before any keys...`,
``,
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`,
`# comments to ignore, along with a blank line`,
``,
`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`,
``,
`# more comments, plus a invalid entry`,
`ssh-rsa data-that-will-not-parse user@host3`,
}
for _, eol := range []string{"\n", "\r\n"} {
authOptions := strings.Join(authWithOptions, eol)
rest2 := strings.Join(authWithOptions[3:], eol)
rest3 := strings.Join(authWithOptions[6:], eol)
testAuthorizedKeys(t, []byte(authOptions), []authResult{
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
{nil, nil, "", "", false},
})
}
}
func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedCommaInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`)
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
})
}
func TestAuthWithInvalidSpace(t *testing.T) {
_, pubSerialized := getTestKey()
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
#more to follow but still no valid keys`)
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
{nil, nil, "", "", false},
})
}
func TestAuthWithMissingQuote(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
})
}
func TestInvalidEntry(t *testing.T) {
authInvalid := []byte(`ssh-rsa`)
_, _, _, _, err := ParseAuthorizedKey(authInvalid)
if err == nil {
t.Errorf("got valid entry for %q", authInvalid)
}
}
var knownHostsParseTests = []struct {
input string
err string
marker string
comment string
hosts []string
rest string
}{
{
"",
"EOF",
"", "", nil, "",
},
{
"# Just a comment",
"EOF",
"", "", nil, "",
},
{
" \t ",
"EOF",
"", "", nil, "",
},
{
"localhost ssh-rsa {RSAPUB}",
"",
"", "", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}",
"",
"", "", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\n",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line",
"",
"", "comment comment", []string{"localhost"}, "next line",
},
{
"localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment",
"",
"", "comment comment", []string{"localhost", "[host2:123]"}, "",
},
{
"@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}",
"",
"marker", "", []string{"localhost", "[host2:123]"}, "",
},
{
"@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd",
"short read",
"", "", nil, "",
},
}
func TestKnownHostsParsing(t *testing.T) {
rsaPub, rsaPubSerialized := getTestKey()
for i, test := range knownHostsParseTests {
var expectedKey PublicKey
const rsaKeyToken = "{RSAPUB}"
input := test.input
if strings.Contains(input, rsaKeyToken) {
expectedKey = rsaPub
input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1)
}
marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input))
if err != nil {
if len(test.err) == 0 {
t.Errorf("#%d: unexpectedly failed with %q", i, err)
} else if !strings.Contains(err.Error(), test.err) {
t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err)
}
continue
} else if len(test.err) != 0 {
t.Errorf("#%d: succeeded but expected error including %q", i, test.err)
continue
}
if !reflect.DeepEqual(expectedKey, pubKey) {
t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey)
}
if marker != test.marker {
t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker)
}
if comment != test.comment {
t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment)
}
if !reflect.DeepEqual(test.hosts, hosts) {
t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts)
}
if rest := string(rest); rest != test.rest {
t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest)
}
}
}
func TestFingerprintLegacyMD5(t *testing.T) {
pub, _ := getTestKey()
fingerprint := FingerprintLegacyMD5(pub)
want := "fb:61:6d:1a:e3:f0:95:45:3c:a0:79:be:4a:93:63:66" // ssh-keygen -lf -E md5 rsa
if fingerprint != want {
t.Errorf("got fingerprint %q want %q", fingerprint, want)
}
}
func TestFingerprintSHA256(t *testing.T) {
pub, _ := getTestKey()
fingerprint := FingerprintSHA256(pub)
want := "SHA256:Anr3LjZK8YVpjrxu79myrW9Hrb/wpcMNpVvTq/RcBm8" // ssh-keygen -lf rsa
if fingerprint != want {
t.Errorf("got fingerprint %q want %q", fingerprint, want)
}
}

View file

@ -1,110 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"sync"
"testing"
)
// An in-memory packetConn. It is safe to call Close and writePacket
// from different goroutines.
type memTransport struct {
eof bool
pending [][]byte
write *memTransport
sync.Mutex
*sync.Cond
}
func (t *memTransport) readPacket() ([]byte, error) {
t.Lock()
defer t.Unlock()
for {
if len(t.pending) > 0 {
r := t.pending[0]
t.pending = t.pending[1:]
return r, nil
}
if t.eof {
return nil, io.EOF
}
t.Cond.Wait()
}
}
func (t *memTransport) closeSelf() error {
t.Lock()
defer t.Unlock()
if t.eof {
return io.EOF
}
t.eof = true
t.Cond.Broadcast()
return nil
}
func (t *memTransport) Close() error {
err := t.write.closeSelf()
t.closeSelf()
return err
}
func (t *memTransport) writePacket(p []byte) error {
t.write.Lock()
defer t.write.Unlock()
if t.write.eof {
return io.EOF
}
c := make([]byte, len(p))
copy(c, p)
t.write.pending = append(t.write.pending, c)
t.write.Cond.Signal()
return nil
}
func memPipe() (a, b packetConn) {
t1 := memTransport{}
t2 := memTransport{}
t1.write = &t2
t2.write = &t1
t1.Cond = sync.NewCond(&t1.Mutex)
t2.Cond = sync.NewCond(&t2.Mutex)
return &t1, &t2
}
func TestMemPipe(t *testing.T) {
a, b := memPipe()
if err := a.writePacket([]byte{42}); err != nil {
t.Fatalf("writePacket: %v", err)
}
if err := a.Close(); err != nil {
t.Fatal("Close: ", err)
}
p, err := b.readPacket()
if err != nil {
t.Fatal("readPacket: ", err)
}
if len(p) != 1 || p[0] != 42 {
t.Fatalf("got %v, want {42}", p)
}
p, err = b.readPacket()
if err != io.EOF {
t.Fatalf("got %v, %v, want EOF", p, err)
}
}
func TestDoubleClose(t *testing.T) {
a, _ := memPipe()
err := a.Close()
if err != nil {
t.Errorf("Close: %v", err)
}
err = a.Close()
if err != io.EOF {
t.Errorf("expect EOF on double close.")
}
}

View file

@ -23,10 +23,6 @@ const (
msgUnimplemented = 3
msgDebug = 4
msgNewKeys = 21
// Standard authentication messages
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
)
// SSH messages:
@ -137,6 +133,18 @@ type userAuthFailureMsg struct {
PartialSuccess bool
}
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52
// See RFC 4252, section 5.4
const msgUserAuthBanner = 53
type userAuthBannerMsg struct {
Message string `sshtype:"53"`
// unused, but required to allow message parsing
Language string
}
// See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61
@ -154,7 +162,7 @@ const msgChannelOpen = 90
type channelOpenMsg struct {
ChanType string `sshtype:"90"`
PeersId uint32
PeersID uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
@ -165,7 +173,7 @@ const msgChannelData = 94
// Used for debug print outs of packets.
type channelDataMsg struct {
PeersId uint32 `sshtype:"94"`
PeersID uint32 `sshtype:"94"`
Length uint32
Rest []byte `ssh:"rest"`
}
@ -174,8 +182,8 @@ type channelDataMsg struct {
const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"`
MyId uint32
PeersID uint32 `sshtype:"91"`
MyID uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
@ -185,7 +193,7 @@ type channelOpenConfirmMsg struct {
const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"`
PeersID uint32 `sshtype:"92"`
Reason RejectionReason
Message string
Language string
@ -194,7 +202,7 @@ type channelOpenFailureMsg struct {
const msgChannelRequest = 98
type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"`
PeersID uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
@ -204,28 +212,28 @@ type channelRequestMsg struct {
const msgChannelSuccess = 99
type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"`
PeersID uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
const msgChannelFailure = 100
type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"`
PeersID uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
const msgChannelClose = 97
type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"`
PeersID uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
const msgChannelEOF = 96
type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"`
PeersID uint32 `sshtype:"96"`
}
// See RFC 4254, section 4
@ -255,7 +263,7 @@ type globalRequestFailureMsg struct {
const msgChannelWindowAdjust = 93
type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"`
PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32
}

View file

@ -1,288 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"math/big"
"math/rand"
"reflect"
"testing"
"testing/quick"
)
var intLengthTests = []struct {
val, length int
}{
{0, 4 + 0},
{1, 4 + 1},
{127, 4 + 1},
{128, 4 + 2},
{-1, 4 + 1},
}
func TestIntLength(t *testing.T) {
for _, test := range intLengthTests {
v := new(big.Int).SetInt64(int64(test.val))
length := intLength(v)
if length != test.length {
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
}
}
}
type msgAllTypes struct {
Bool bool `sshtype:"21"`
Array [16]byte
Uint64 uint64
Uint32 uint32
Uint8 uint8
String string
Strings []string
Bytes []byte
Int *big.Int
Rest []byte `ssh:"rest"`
}
func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
m := &msgAllTypes{}
m.Bool = rand.Intn(2) == 1
randomBytes(m.Array[:], rand)
m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
m.Uint8 = uint8(rand.Intn(1 << 8))
m.String = string(m.Array[:])
m.Strings = randomNameList(rand)
m.Bytes = m.Array[:]
m.Int = randomInt(rand)
m.Rest = m.Array[:]
return reflect.ValueOf(m)
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
iface := &msgAllTypes{}
ty := reflect.ValueOf(iface).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("failed to create value")
break
}
m1 := v.Elem().Interface()
m2 := iface
marshaled := Marshal(m1)
if err := Unmarshal(marshaled, m2); err != nil {
t.Errorf("Unmarshal %#v: %s", m1, err)
break
}
if !reflect.DeepEqual(v.Interface(), m2) {
t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
break
}
}
}
func TestUnmarshalEmptyPacket(t *testing.T) {
var b []byte
var m channelRequestSuccessMsg
if err := Unmarshal(b, &m); err == nil {
t.Fatalf("unmarshal of empty slice succeeded")
}
}
func TestUnmarshalUnexpectedPacket(t *testing.T) {
type S struct {
I uint32 `sshtype:"43"`
S string
B bool
}
s := S{11, "hello", true}
packet := Marshal(s)
packet[0] = 42
roundtrip := S{}
err := Unmarshal(packet, &roundtrip)
if err == nil {
t.Fatal("expected error, not nil")
}
}
func TestMarshalPtr(t *testing.T) {
s := struct {
S string
}{"hello"}
m1 := Marshal(s)
m2 := Marshal(&s)
if !bytes.Equal(m1, m2) {
t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
}
}
func TestBareMarshalUnmarshal(t *testing.T) {
type S struct {
I uint32
S string
B bool
}
s := S{42, "hello", true}
packet := Marshal(s)
roundtrip := S{}
Unmarshal(packet, &roundtrip)
if !reflect.DeepEqual(s, roundtrip) {
t.Errorf("got %#v, want %#v", roundtrip, s)
}
}
func TestBareMarshal(t *testing.T) {
type S2 struct {
I uint32
}
s := S2{42}
packet := Marshal(s)
i, rest, ok := parseUint32(packet)
if len(rest) > 0 || !ok {
t.Errorf("parseInt(%q): parse error", packet)
}
if i != s.I {
t.Errorf("got %d, want %d", i, s.I)
}
}
func TestUnmarshalShortKexInitPacket(t *testing.T) {
// This used to panic.
// Issue 11348
packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff}
kim := &kexInitMsg{}
if err := Unmarshal(packet, kim); err == nil {
t.Error("truncated packet unmarshaled without error")
}
}
func TestMarshalMultiTag(t *testing.T) {
var res struct {
A uint32 `sshtype:"1|2"`
}
good1 := struct {
A uint32 `sshtype:"1"`
}{
1,
}
good2 := struct {
A uint32 `sshtype:"2"`
}{
1,
}
if e := Unmarshal(Marshal(good1), &res); e != nil {
t.Errorf("error unmarshaling multipart tag: %v", e)
}
if e := Unmarshal(Marshal(good2), &res); e != nil {
t.Errorf("error unmarshaling multipart tag: %v", e)
}
bad1 := struct {
A uint32 `sshtype:"3"`
}{
1,
}
if e := Unmarshal(Marshal(bad1), &res); e == nil {
t.Errorf("bad struct unmarshaled without error")
}
}
func randomBytes(out []byte, rand *rand.Rand) {
for i := 0; i < len(out); i++ {
out[i] = byte(rand.Int31())
}
}
func randomNameList(rand *rand.Rand) []string {
ret := make([]string, rand.Int31()&15)
for i := range ret {
s := make([]byte, 1+(rand.Int31()&15))
for j := range s {
s[j] = 'a' + uint8(rand.Int31()&15)
}
ret[i] = string(s)
}
return ret
}
func randomInt(rand *rand.Rand) *big.Int {
return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
}
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
ki := &kexInitMsg{}
randomBytes(ki.Cookie[:], rand)
ki.KexAlgos = randomNameList(rand)
ki.ServerHostKeyAlgos = randomNameList(rand)
ki.CiphersClientServer = randomNameList(rand)
ki.CiphersServerClient = randomNameList(rand)
ki.MACsClientServer = randomNameList(rand)
ki.MACsServerClient = randomNameList(rand)
ki.CompressionClientServer = randomNameList(rand)
ki.CompressionServerClient = randomNameList(rand)
ki.LanguagesClientServer = randomNameList(rand)
ki.LanguagesServerClient = randomNameList(rand)
if rand.Int31()&1 == 1 {
ki.FirstKexFollows = true
}
return reflect.ValueOf(ki)
}
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
dhi := &kexDHInitMsg{}
dhi.X = randomInt(rand)
return reflect.ValueOf(dhi)
}
var (
_kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
_kexInit = Marshal(_kexInitMsg)
_kexDHInit = Marshal(_kexDHInitMsg)
)
func BenchmarkMarshalKexInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
Marshal(_kexInitMsg)
}
}
func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
m := new(kexInitMsg)
for i := 0; i < b.N; i++ {
Unmarshal(_kexInit, m)
}
}
func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
Marshal(_kexDHInitMsg)
}
}
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
m := new(kexDHInitMsg)
for i := 0; i < b.N; i++ {
Unmarshal(_kexDHInit, m)
}
}

View file

@ -278,7 +278,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId,
PeersID: msg.PeersID,
Reason: ConnectionFailed,
Message: "invalid request",
Language: "en_US.UTF-8",
@ -287,7 +287,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
}
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId
c.remoteId = msg.PeersID
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c
@ -313,7 +313,7 @@ func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra,
PeersId: ch.localId,
PeersID: ch.localId,
}
if err := m.sendMessage(open); err != nil {
return nil, err

View file

@ -1,505 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"io/ioutil"
"sync"
"testing"
)
func muxPair() (*mux, *mux) {
a, b := memPipe()
s := newMux(a)
c := newMux(b)
return s, c
}
// Returns both ends of a channel, and the mux for the the 2nd
// channel.
func channelPair(t *testing.T) (*channel, *channel, *mux) {
c, s := muxPair()
res := make(chan *channel, 1)
go func() {
newCh, ok := <-s.incomingChannels
if !ok {
t.Fatalf("No incoming channel")
}
if newCh.ChannelType() != "chan" {
t.Fatalf("got type %q want chan", newCh.ChannelType())
}
ch, _, err := newCh.Accept()
if err != nil {
t.Fatalf("Accept %v", err)
}
res <- ch.(*channel)
}()
ch, err := c.openChannel("chan", nil)
if err != nil {
t.Fatalf("OpenChannel: %v", err)
}
return <-res, ch, c
}
// Test that stderr and stdout can be addressed from different
// goroutines. This is intended for use with the race detector.
func TestMuxChannelExtendedThreadSafety(t *testing.T) {
writer, reader, mux := channelPair(t)
defer writer.Close()
defer reader.Close()
defer mux.Close()
var wr, rd sync.WaitGroup
magic := "hello world"
wr.Add(2)
go func() {
io.WriteString(writer, magic)
wr.Done()
}()
go func() {
io.WriteString(writer.Stderr(), magic)
wr.Done()
}()
rd.Add(2)
go func() {
c, err := ioutil.ReadAll(reader)
if string(c) != magic {
t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
go func() {
c, err := ioutil.ReadAll(reader.Stderr())
if string(c) != magic {
t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
wr.Wait()
writer.CloseWrite()
rd.Wait()
}
func TestMuxReadWrite(t *testing.T) {
s, c, mux := channelPair(t)
defer s.Close()
defer c.Close()
defer mux.Close()
magic := "hello world"
magicExt := "hello stderr"
go func() {
_, err := s.Write([]byte(magic))
if err != nil {
t.Fatalf("Write: %v", err)
}
_, err = s.Extended(1).Write([]byte(magicExt))
if err != nil {
t.Fatalf("Write: %v", err)
}
err = s.Close()
if err != nil {
t.Fatalf("Close: %v", err)
}
}()
var buf [1024]byte
n, err := c.Read(buf[:])
if err != nil {
t.Fatalf("server Read: %v", err)
}
got := string(buf[:n])
if got != magic {
t.Fatalf("server: got %q want %q", got, magic)
}
n, err = c.Extended(1).Read(buf[:])
if err != nil {
t.Fatalf("server Read: %v", err)
}
got = string(buf[:n])
if got != magicExt {
t.Fatalf("server: got %q want %q", got, magic)
}
}
func TestMuxChannelOverflow(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
writer.Write(make([]byte, 1))
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
// Send 1 byte.
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
marshalUint32(packet[1:], writer.remoteId)
marshalUint32(packet[5:], uint32(1))
packet[9] = 42
if err := writer.mux.conn.writePacket(packet); err != nil {
t.Errorf("could not send packet")
}
if _, err := reader.SendRequest("hello", true, nil); err == nil {
t.Errorf("SendRequest succeeded.")
}
<-wDone
}
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
reader.Close()
<-wDone
}
func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
mux.Close()
<-wDone
}
func TestMuxReject(t *testing.T) {
client, server := muxPair()
defer server.Close()
defer client.Close()
go func() {
ch, ok := <-server.incomingChannels
if !ok {
t.Fatalf("Accept")
}
if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
}
ch.Reject(RejectionReason(42), "message")
}()
ch, err := client.openChannel("ch", []byte("extra"))
if ch != nil {
t.Fatal("openChannel not rejected")
}
ocf, ok := err.(*OpenChannelError)
if !ok {
t.Errorf("got %#v want *OpenChannelError", err)
} else if ocf.Reason != 42 || ocf.Message != "message" {
t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
}
want := "ssh: rejected: unknown reason 42 (message)"
if err.Error() != want {
t.Errorf("got %q, want %q", err.Error(), want)
}
}
func TestMuxChannelRequest(t *testing.T) {
client, server, mux := channelPair(t)
defer server.Close()
defer client.Close()
defer mux.Close()
var received int
var wg sync.WaitGroup
wg.Add(1)
go func() {
for r := range server.incomingRequests {
received++
r.Reply(r.Type == "yes", nil)
}
wg.Done()
}()
_, err := client.SendRequest("yes", false, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
ok, err := client.SendRequest("yes", true, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
if !ok {
t.Errorf("SendRequest(yes): %v", ok)
}
ok, err = client.SendRequest("no", true, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
if ok {
t.Errorf("SendRequest(no): %v", ok)
}
client.Close()
wg.Wait()
if received != 3 {
t.Errorf("got %d requests, want %d", received, 3)
}
}
func TestMuxGlobalRequest(t *testing.T) {
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
var seen bool
go func() {
for r := range serverMux.incomingRequests {
seen = seen || r.Type == "peek"
if r.WantReply {
err := r.Reply(r.Type == "yes",
append([]byte(r.Type), r.Payload...))
if err != nil {
t.Errorf("AckRequest: %v", err)
}
}
}
}()
_, _, err := clientMux.SendRequest("peek", false, nil)
if err != nil {
t.Errorf("SendRequest: %v", err)
}
ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
if !ok || string(data) != "yesa" || err != nil {
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
ok, data, err)
}
if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
ok, data, err)
}
if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
ok, data, err)
}
if !seen {
t.Errorf("never saw 'peek' request")
}
}
func TestMuxGlobalRequestUnblock(t *testing.T) {
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
result := make(chan error, 1)
go func() {
_, _, err := clientMux.SendRequest("hello", true, nil)
result <- err
}()
<-serverMux.incomingRequests
serverMux.conn.Close()
err := <-result
if err != io.EOF {
t.Errorf("want EOF, got %v", io.EOF)
}
}
func TestMuxChannelRequestUnblock(t *testing.T) {
a, b, connB := channelPair(t)
defer a.Close()
defer b.Close()
defer connB.Close()
result := make(chan error, 1)
go func() {
_, err := a.SendRequest("hello", true, nil)
result <- err
}()
<-b.incomingRequests
connB.conn.Close()
err := <-result
if err != io.EOF {
t.Errorf("want EOF, got %v", err)
}
}
func TestMuxCloseChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
defer r.Close()
defer w.Close()
result := make(chan error, 1)
go func() {
var b [1024]byte
_, err := r.Read(b[:])
result <- err
}()
if err := w.Close(); err != nil {
t.Errorf("w.Close: %v", err)
}
if _, err := w.Write([]byte("hello")); err != io.EOF {
t.Errorf("got err %v, want io.EOF after Close", err)
}
if err := <-result; err != io.EOF {
t.Errorf("got %v (%T), want io.EOF", err, err)
}
}
func TestMuxCloseWriteChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
result := make(chan error, 1)
go func() {
var b [1024]byte
_, err := r.Read(b[:])
result <- err
}()
if err := w.CloseWrite(); err != nil {
t.Errorf("w.CloseWrite: %v", err)
}
if _, err := w.Write([]byte("hello")); err != io.EOF {
t.Errorf("got err %v, want io.EOF after CloseWrite", err)
}
if err := <-result; err != io.EOF {
t.Errorf("got %v (%T), want io.EOF", err, err)
}
}
func TestMuxInvalidRecord(t *testing.T) {
a, b := muxPair()
defer a.Close()
defer b.Close()
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
marshalUint32(packet[1:], 29348723 /* invalid channel id */)
marshalUint32(packet[5:], 1)
packet[9] = 42
a.conn.writePacket(packet)
go a.SendRequest("hello", false, nil)
// 'a' wrote an invalid packet, so 'b' has exited.
req, ok := <-b.incomingRequests
if ok {
t.Errorf("got request %#v after receiving invalid packet", req)
}
}
func TestZeroWindowAdjust(t *testing.T) {
a, b, mux := channelPair(t)
defer a.Close()
defer b.Close()
defer mux.Close()
go func() {
io.WriteString(a, "hello")
// bogus adjust.
a.sendMessage(windowAdjustMsg{})
io.WriteString(a, "world")
a.Close()
}()
want := "helloworld"
c, _ := ioutil.ReadAll(b)
if string(c) != want {
t.Errorf("got %q want %q", c, want)
}
}
func TestMuxMaxPacketSize(t *testing.T) {
a, b, mux := channelPair(t)
defer a.Close()
defer b.Close()
defer mux.Close()
large := make([]byte, a.maxRemotePayload+1)
packet := make([]byte, 1+4+4+1+len(large))
packet[0] = msgChannelData
marshalUint32(packet[1:], a.remoteId)
marshalUint32(packet[5:], uint32(len(large)))
packet[9] = 42
if err := a.mux.conn.writePacket(packet); err != nil {
t.Errorf("could not send packet")
}
go a.SendRequest("hello", false, nil)
_, ok := <-b.incomingRequests
if ok {
t.Errorf("connection still alive after receiving large packet.")
}
}
// Don't ship code with debug=true.
func TestDebug(t *testing.T) {
if debugMux {
t.Error("mux debug switched on")
}
if debugHandshake {
t.Error("handshake debug switched on")
}
if debugTransport {
t.Error("transport debug switched on")
}
}

View file

@ -95,6 +95,10 @@ type ServerConfig struct {
// Note that RFC 4253 section 4.2 requires that this string start with
// "SSH-2.0-".
ServerVersion string
// BannerCallback, if present, is called and the return string is sent to
// the client after key exchange completed but before authentication.
BannerCallback func(conn ConnMetadata) string
}
// AddHostKey adds a private key as a host key. If an existing host
@ -162,6 +166,9 @@ type ServerConn struct {
// unsuccessful, it closes the connection and returns an error. The
// Request and NewChannel channels must be serviced, or the connection
// will hang.
//
// The returned error may be of type *ServerAuthError for
// authentication errors.
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
@ -252,7 +259,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01:
return true
}
return false
@ -288,12 +295,13 @@ func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
// ServerAuthError implements the error interface. It appends any authentication
// errors that may occur, and is returned if all of the authentication methods
// provided by the user failed to authenticate.
// ServerAuthError represents server authentication errors and is
// sometimes returned by NewServerConn. It appends any authentication
// errors that may occur, and is returned if all of the authentication
// methods provided by the user failed to authenticate.
type ServerAuthError struct {
// Errors contains authentication errors returned by the authentication
// callback methods.
// callback methods. The first entry is typically ErrNoAuth.
Errors []error
}
@ -305,6 +313,13 @@ func (l ServerAuthError) Error() string {
return "[" + strings.Join(errs, ", ") + "]"
}
// ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries
// 'none' authentication to discover available methods.
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID()
var cache pubKeyCache
@ -312,6 +327,7 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
authFailures := 0
var authErrs []error
var displayedBanner bool
userAuthLoop:
for {
@ -343,8 +359,22 @@ userAuthLoop:
}
s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
perms = nil
authErr := errors.New("no auth passed yet")
authErr := ErrNoAuth
switch userAuthReq.Method {
case "none":

View file

@ -406,7 +406,7 @@ func (s *Session) Wait() error {
s.stdinPipeWriter.Close()
}
var copyError error
for _ = range s.copyFuncs {
for range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil {
copyError = err
}

View file

@ -1,774 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session tests.
import (
"bytes"
crypto_rand "crypto/rand"
"errors"
"io"
"io/ioutil"
"math/rand"
"net"
"testing"
"golang.org/x/crypto/ssh/terminal"
)
type serverType func(Channel, <-chan *Request, *testing.T)
// dial constructs a new test server and returns a *ClientConn.
func dial(handler serverType, t *testing.T) *Client {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
go func() {
defer c1.Close()
conf := ServerConfig{
NoClientAuth: true,
}
conf.AddHostKey(testSigners["rsa"])
_, chans, reqs, err := NewServerConn(c1, &conf)
if err != nil {
t.Fatalf("Unable to handshake: %v", err)
}
go DiscardRequests(reqs)
for newCh := range chans {
if newCh.ChannelType() != "session" {
newCh.Reject(UnknownChannelType, "unknown channel type")
continue
}
ch, inReqs, err := newCh.Accept()
if err != nil {
t.Errorf("Accept: %v", err)
continue
}
go func() {
handler(ch, inReqs, t)
}()
}
}()
config := &ClientConfig{
User: "testuser",
HostKeyCallback: InsecureIgnoreHostKey(),
}
conn, chans, reqs, err := NewClientConn(c2, "", config)
if err != nil {
t.Fatalf("unable to dial remote side: %v", err)
}
return NewClient(conn, chans, reqs)
}
// Test a simple string is returned to session.Stdout.
func TestSessionShell(t *testing.T) {
conn := dial(shellHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout := new(bytes.Buffer)
session.Stdout = stdout
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %v", err)
}
actual := stdout.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
// Test a simple string is returned via StdoutPipe.
func TestSessionStdoutPipe(t *testing.T) {
conn := dial(shellHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("Unable to request StdoutPipe(): %v", err)
}
var buf bytes.Buffer
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
done := make(chan bool, 1)
go func() {
if _, err := io.Copy(&buf, stdout); err != nil {
t.Errorf("Copy of stdout failed: %v", err)
}
done <- true
}()
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %v", err)
}
<-done
actual := buf.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// Test that a simple string is returned via the Output helper,
// and that stderr is discarded.
func TestSessionOutput(t *testing.T) {
conn := dial(fixedOutputHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
if err != nil {
t.Error("Remote command did not exit cleanly:", err)
}
w := "this-is-stdout."
g := string(buf)
if g != w {
t.Error("Remote command did not return expected string:")
t.Logf("want %q", w)
t.Logf("got %q", g)
}
}
// Test that both stdout and stderr are returned
// via the CombinedOutput helper.
func TestSessionCombinedOutput(t *testing.T) {
conn := dial(fixedOutputHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
if err != nil {
t.Error("Remote command did not exit cleanly:", err)
}
const stdout = "this-is-stdout."
const stderr = "this-is-stderr."
g := string(buf)
if g != stdout+stderr && g != stderr+stdout {
t.Error("Remote command did not return expected string:")
t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
t.Logf("got %q", g)
}
}
// Test non-0 exit status is returned correctly.
func TestExitStatusNonZero(t *testing.T) {
conn := dial(exitStatusNonZeroHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.ExitStatus() != 15 {
t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
}
}
// Test 0 exit status is returned correctly.
func TestExitStatusZero(t *testing.T) {
conn := dial(exitStatusZeroHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
t.Fatalf("expected nil but got %v", err)
}
}
// Test exit signal and status are both returned correctly.
func TestExitSignalAndStatus(t *testing.T) {
conn := dial(exitSignalAndStatusHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "TERM" || e.ExitStatus() != 15 {
t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
// Test exit signal and status are both returned correctly.
func TestKnownExitSignalOnly(t *testing.T) {
conn := dial(exitSignalHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "TERM" || e.ExitStatus() != 143 {
t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
// Test exit signal and status are both returned correctly.
func TestUnknownExitSignal(t *testing.T) {
conn := dial(exitSignalUnknownHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "SYS" || e.ExitStatus() != 128 {
t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
func TestExitWithoutStatusOrSignal(t *testing.T) {
conn := dial(exitWithoutSignalOrStatus, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
if _, ok := err.(*ExitMissingError); !ok {
t.Fatalf("got %T want *ExitMissingError", err)
}
}
// windowTestBytes is the number of bytes that we'll send to the SSH server.
const windowTestBytes = 16000 * 200
// TestServerWindow writes random data to the server. The server is expected to echo
// the same data back, which is compared against the original.
func TestServerWindow(t *testing.T) {
origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
origBytes := origBuf.Bytes()
conn := dial(echoHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
result := make(chan []byte)
go func() {
defer close(result)
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
serverStdout, err := session.StdoutPipe()
if err != nil {
t.Errorf("StdoutPipe failed: %v", err)
return
}
n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
if err != nil && err != io.EOF {
t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
}
result <- echoedBuf.Bytes()
}()
serverStdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
if err != nil {
t.Fatalf("failed to copy origBuf to serverStdin: %v", err)
}
if written != windowTestBytes {
t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
}
echoedBytes := <-result
if !bytes.Equal(origBytes, echoedBytes) {
t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
}
}
// Verify the client can handle a keepalive packet from the server.
func TestClientHandlesKeepalives(t *testing.T) {
conn := dial(channelKeepaliveSender, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
t.Fatalf("expected nil but got: %v", err)
}
}
type exitStatusMsg struct {
Status uint32
}
type exitSignalMsg struct {
Signal string
CoreDumped bool
Errmsg string
Lang string
}
func handleTerminalRequests(in <-chan *Request) {
for req := range in {
ok := false
switch req.Type {
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
req.Reply(ok, nil)
}
}
func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
term := terminal.NewTerminal(ch, prompt)
go handleTerminalRequests(in)
return term
}
func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(0, ch, t)
}
func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
}
func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
sendSignal("TERM", ch, t)
}
func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("TERM", ch, t)
}
func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("SYS", ch, t)
}
func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
}
func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, in, "golang")
readLine(shell, t)
sendStatus(0, ch, t)
}
// Ignores the command, writes fixed strings to stderr and stdout.
// Strings are "this-is-stdout." and "this-is-stderr.".
func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
_, err := ch.Read(nil)
req, ok := <-in
if !ok {
t.Fatalf("error: expected channel request, got: %#v", err)
return
}
// ignore request, always send some text
req.Reply(true, nil)
_, err = io.WriteString(ch, "this-is-stdout.")
if err != nil {
t.Fatalf("error writing on server: %v", err)
}
_, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
if err != nil {
t.Fatalf("error writing on server: %v", err)
}
sendStatus(0, ch, t)
}
func readLine(shell *terminal.Terminal, t *testing.T) {
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
t.Errorf("unable to read line: %v", err)
}
}
func sendStatus(status uint32, ch Channel, t *testing.T) {
msg := exitStatusMsg{
Status: status,
}
if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
t.Errorf("unable to send status: %v", err)
}
}
func sendSignal(signal string, ch Channel, t *testing.T) {
sig := exitSignalMsg{
Signal: signal,
CoreDumped: false,
Errmsg: "Process terminated",
Lang: "en-GB-oed",
}
if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
t.Errorf("unable to send signal: %v", err)
}
}
func discardHandler(ch Channel, t *testing.T) {
defer ch.Close()
io.Copy(ioutil.Discard, ch)
}
func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
}
}
// copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
// buffer size to exercise more code paths.
func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
var (
buf = make([]byte, 32*1024)
written int
remaining = n
)
for remaining > 0 {
l := rand.Intn(1 << 15)
if remaining < l {
l = remaining
}
nr, er := src.Read(buf[:l])
nw, ew := dst.Write(buf[:nr])
remaining -= nw
written += nw
if ew != nil {
return written, ew
}
if nr != nw {
return written, io.ErrShortWrite
}
if er != nil && er != io.EOF {
return written, er
}
}
return written, nil
}
func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
t.Errorf("unable to send channel keepalive request: %v", err)
}
sendStatus(0, ch, t)
}
func TestClientWriteEOF(t *testing.T) {
conn := dial(simpleEchoHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("StdoutPipe failed: %v", err)
}
data := []byte(`0000`)
_, err = stdin.Write(data)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
stdin.Close()
res, err := ioutil.ReadAll(stdout)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if !bytes.Equal(data, res) {
t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
}
}
func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
data, err := ioutil.ReadAll(ch)
if err != nil {
t.Errorf("handler read error: %v", err)
}
_, err = ch.Write(data)
if err != nil {
t.Errorf("handler write error: %v", err)
}
}
func TestSessionID(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverID := make(chan []byte, 1)
clientID := make(chan []byte, 1)
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["ecdsa"])
clientConf := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
User: "user",
}
go func() {
conn, chans, reqs, err := NewServerConn(c1, serverConf)
if err != nil {
t.Fatalf("server handshake: %v", err)
}
serverID <- conn.SessionID()
go DiscardRequests(reqs)
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()
go func() {
conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("client handshake: %v", err)
}
clientID <- conn.SessionID()
go DiscardRequests(reqs)
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()
s := <-serverID
c := <-clientID
if bytes.Compare(s, c) != 0 {
t.Errorf("server session ID (%x) != client session ID (%x)", s, c)
} else if len(s) == 0 {
t.Errorf("client and server SessionID were empty.")
}
}
type noReadConn struct {
readSeen bool
net.Conn
}
func (c *noReadConn) Close() error {
return nil
}
func (c *noReadConn) Read(b []byte) (int, error) {
c.readSeen = true
return 0, errors.New("noReadConn error")
}
func TestInvalidServerConfiguration(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serveConn := noReadConn{Conn: c1}
serverConf := &ServerConfig{}
NewServerConn(&serveConn, serverConf)
if serveConn.readSeen {
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key")
}
serverConf.AddHostKey(testSigners["ecdsa"])
NewServerConn(&serveConn, serverConf)
if serveConn.readSeen {
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method")
}
}
func TestHostKeyAlgorithms(t *testing.T) {
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
serverConf.AddHostKey(testSigners["ecdsa"])
connect := func(clientConf *ClientConfig, want string) {
var alg string
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
alg = key.Type()
return nil
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
_, _, _, err = NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}
if alg != want {
t.Errorf("selected key algorithm %s, want %s", alg, want)
}
}
// By default, we get the preferred algorithm, which is ECDSA 256.
clientConf := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
}
connect(clientConf, KeyAlgoECDSA256)
// Client asks for RSA explicitly.
clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA}
connect(clientConf, KeyAlgoRSA)
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
_, _, _, err = NewClientConn(c2, "", clientConf)
if err == nil {
t.Fatal("succeeded connecting with unknown hostkey algorithm")
}
}

View file

@ -32,6 +32,7 @@ type streamLocalChannelForwardMsg struct {
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
m := streamLocalChannelForwardMsg{
socketPath,
}

View file

@ -90,10 +90,19 @@ type channelForwardMsg struct {
rport uint32
}
// handleForwards starts goroutines handling forwarded connections.
// It's called on first use by (*Client).ListenTCP to not launch
// goroutines until needed.
func (c *Client) handleForwards() {
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
}
// ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener.
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
return c.autoPortListenWorkaround(laddr)
}

View file

@ -1,20 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"testing"
)
func TestAutoPortListenBroken(t *testing.T) {
broken := "SSH-2.0-OpenSSH_5.9hh11"
works := "SSH-2.0-OpenSSH_6.1"
if !isBrokenOpenSSHVersion(broken) {
t.Errorf("version %q not marked as broken", broken)
}
if isBrokenOpenSSHVersion(works) {
t.Errorf("version %q marked as broken", works)
}
}

View file

@ -617,7 +617,7 @@ func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
if _, err = w.Write(crlf); err != nil {
return n, err
}
n += 1
n++
buf = buf[1:]
}
}

View file

@ -1,350 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import (
"bytes"
"io"
"os"
"testing"
)
type MockTerminal struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockTerminal) Read(data []byte) (n int, err error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, io.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockTerminal) Write(data []byte) (n int, err error) {
c.received = append(c.received, data...)
return len(data), nil
}
func TestClose(t *testing.T) {
c := &MockTerminal{}
ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
}
if err != io.EOF {
t.Errorf("Error should have been EOF but got: %s", err)
}
}
var keyPressTests = []struct {
in string
line string
err error
throwAwayLines int
}{
{
err: io.EOF,
},
{
in: "\r",
line: "",
},
{
in: "foo\r",
line: "foo",
},
{
in: "a\x1b[Cb\r", // right
line: "ab",
},
{
in: "a\x1b[Db\r", // left
line: "ba",
},
{
in: "a\177b\r", // backspace
line: "b",
},
{
in: "\x1b[A\r", // up
},
{
in: "\x1b[B\r", // down
},
{
in: "line\x1b[A\x1b[B\r", // up then down
line: "line",
},
{
in: "line1\rline2\x1b[A\r", // recall previous line.
line: "line1",
throwAwayLines: 1,
},
{
// recall two previous lines and append.
in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r",
line: "line1xxx",
throwAwayLines: 2,
},
{
// Ctrl-A to move to beginning of line followed by ^K to kill
// line.
in: "a b \001\013\r",
line: "",
},
{
// Ctrl-A to move to beginning of line, Ctrl-E to move to end,
// finally ^K to kill nothing.
in: "a b \001\005\013\r",
line: "a b ",
},
{
in: "\027\r",
line: "",
},
{
in: "a\027\r",
line: "",
},
{
in: "a \027\r",
line: "",
},
{
in: "a b\027\r",
line: "a ",
},
{
in: "a b \027\r",
line: "a ",
},
{
in: "one two thr\x1b[D\027\r",
line: "one two r",
},
{
in: "\013\r",
line: "",
},
{
in: "a\013\r",
line: "a",
},
{
in: "ab\x1b[D\013\r",
line: "a",
},
{
in: "Ξεσκεπάζω\r",
line: "Ξεσκεπάζω",
},
{
in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace.
line: "",
throwAwayLines: 1,
},
{
in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter.
line: "£",
throwAwayLines: 1,
},
{
// Ctrl-D at the end of the line should be ignored.
in: "a\004\r",
line: "a",
},
{
// a, b, left, Ctrl-D should erase the b.
in: "ab\x1b[D\004\r",
line: "a",
},
{
// a, b, c, d, left, left, ^U should erase to the beginning of
// the line.
in: "abcd\x1b[D\x1b[D\025\r",
line: "cd",
},
{
// Bracketed paste mode: control sequences should be returned
// verbatim in paste mode.
in: "abc\x1b[200~de\177f\x1b[201~\177\r",
line: "abcde\177",
},
{
// Enter in bracketed paste mode should still work.
in: "abc\x1b[200~d\refg\x1b[201~h\r",
line: "efgh",
throwAwayLines: 1,
},
{
// Lines consisting entirely of pasted data should be indicated as such.
in: "\x1b[200~a\r",
line: "a",
err: ErrPasteIndicator,
},
}
func TestKeyPresses(t *testing.T) {
for i, test := range keyPressTests {
for j := 1; j < len(test.in); j++ {
c := &MockTerminal{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewTerminal(c, "> ")
for k := 0; k < test.throwAwayLines; k++ {
_, err := ss.ReadLine()
if err != nil {
t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err)
}
}
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
break
}
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
}
}
}
func TestPasswordNotSaved(t *testing.T) {
c := &MockTerminal{
toSend: []byte("password\r\x1b[A\r"),
bytesPerRead: 1,
}
ss := NewTerminal(c, "> ")
pw, _ := ss.ReadPassword("> ")
if pw != "password" {
t.Fatalf("failed to read password, got %s", pw)
}
line, _ := ss.ReadLine()
if len(line) > 0 {
t.Fatalf("password was saved in history")
}
}
var setSizeTests = []struct {
width, height int
}{
{40, 13},
{80, 24},
{132, 43},
}
func TestTerminalSetSize(t *testing.T) {
for _, setSize := range setSizeTests {
c := &MockTerminal{
toSend: []byte("password\r\x1b[A\r"),
bytesPerRead: 1,
}
ss := NewTerminal(c, "> ")
ss.SetSize(setSize.width, setSize.height)
pw, _ := ss.ReadPassword("Password: ")
if pw != "password" {
t.Fatalf("failed to read password, got %s", pw)
}
if string(c.received) != "Password: \r\n" {
t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received)
}
}
}
func TestReadPasswordLineEnd(t *testing.T) {
var tests = []struct {
input string
want string
}{
{"\n", ""},
{"\r\n", ""},
{"test\r\n", "test"},
{"testtesttesttes\n", "testtesttesttes"},
{"testtesttesttes\r\n", "testtesttesttes"},
{"testtesttesttesttest\n", "testtesttesttesttest"},
{"testtesttesttesttest\r\n", "testtesttesttesttest"},
}
for _, test := range tests {
buf := new(bytes.Buffer)
if _, err := buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err := readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
if _, err = buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err = readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
}
}
func TestMakeRawState(t *testing.T) {
fd := int(os.Stdout.Fd())
if !IsTerminal(fd) {
t.Skip("stdout is not a terminal; skipping test")
}
st, err := GetState(fd)
if err != nil {
t.Fatalf("failed to get terminal state from GetState: %s", err)
}
defer Restore(fd, st)
raw, err := MakeRaw(fd)
if err != nil {
t.Fatalf("failed to get terminal state from MakeRaw: %s", err)
}
if *st != *raw {
t.Errorf("states do not match; was %v, expected %v", raw, st)
}
}
func TestOutputNewlines(t *testing.T) {
// \n should be changed to \r\n in terminal output.
buf := new(bytes.Buffer)
term := NewTerminal(buf, ">")
term.Write([]byte("1\n2\n"))
output := string(buf.Bytes())
const expected = "1\r\n2\r\n"
if output != expected {
t.Errorf("incorrect output: was %q, expected %q", output, expected)
}
}

View file

@ -108,9 +108,7 @@ func ReadPassword(fd int) ([]byte, error) {
return nil, err
}
defer func() {
unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
}()
defer unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
return readPasswordLine(passwordReader(fd))
}

View file

@ -14,7 +14,7 @@ import (
// State contains the state of a terminal.
type State struct {
state *unix.Termios
termios unix.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
@ -75,47 +75,43 @@ func ReadPassword(fd int) ([]byte, error) {
// restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldTermios := *oldTermiosPtr
newTermios := oldTermios
newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
newTermios.Oflag &^= syscall.OPOST
newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
newTermios.Cflag |= syscall.CS8
newTermios.Cc[unix.VMIN] = 1
newTermios.Cc[unix.VTIME] = 0
oldState := State{termios: *termios}
if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil {
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, termios); err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state)
return unix.IoctlSetTermios(fd, unix.TCSETS, &oldState.termios)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
return &State{termios: *termios}, nil
}
// GetSize returns the dimensions of the given terminal.

View file

@ -17,6 +17,8 @@
package terminal
import (
"os"
"golang.org/x/sys/windows"
)
@ -71,13 +73,6 @@ func GetSize(fd int) (width, height int, err error) {
return int(info.Size.X), int(info.Size.Y), nil
}
// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return windows.Read(windows.Handle(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
@ -94,9 +89,15 @@ func ReadPassword(fd int) ([]byte, error) {
return nil, err
}
defer func() {
windows.SetConsoleMode(windows.Handle(fd), old)
}()
defer windows.SetConsoleMode(windows.Handle(fd), old)
return readPasswordLine(passwordReader(fd))
var h windows.Handle
p, _ := windows.GetCurrentProcess()
if err := windows.DuplicateHandle(p, windows.Handle(fd), p, &h, 0, false, windows.DUPLICATE_SAME_ACCESS); err != nil {
return nil, err
}
f := os.NewFile(uintptr(h), "stdin")
defer f.Close()
return readPasswordLine(f)
}

View file

@ -1,63 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.
package ssh
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/ssh/testdata"
)
var (
testPrivateKeys map[string]interface{}
testSigners map[string]Signer
testPublicKeys map[string]PublicKey
)
func init() {
var err error
n := len(testdata.PEMBytes)
testPrivateKeys = make(map[string]interface{}, n)
testSigners = make(map[string]Signer, n)
testPublicKeys = make(map[string]PublicKey, n)
for t, k := range testdata.PEMBytes {
testPrivateKeys[t], err = ParseRawPrivateKey(k)
if err != nil {
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
}
testSigners[t], err = NewSignerFromKey(testPrivateKeys[t])
if err != nil {
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
}
testPublicKeys[t] = testSigners[t].PublicKey()
}
// Create a cert and sign it for use in tests.
testCert := &Certificate{
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
ValidAfter: 0, // unix epoch
ValidBefore: CertTimeInfinity, // The end of currently representable time.
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
Key: testPublicKeys["ecdsa"],
SignatureKey: testPublicKeys["rsa"],
Permissions: Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
}
testCert.SignCert(rand.Reader, testSigners["rsa"])
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"])
if err != nil {
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
}
}

View file

@ -6,6 +6,7 @@ package ssh
import (
"bufio"
"bytes"
"errors"
"io"
"log"
@ -76,17 +77,17 @@ type connectionState struct {
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil {
return err
} else {
t.reader.pendingKeyChange <- ciph
}
t.reader.pendingKeyChange <- ciph
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil {
ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil {
return err
} else {
t.writer.pendingKeyChange <- ciph
}
t.writer.pendingKeyChange <- ciph
return nil
}
@ -139,7 +140,7 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message.")
return nil, errors.New("ssh: got bogus newkeys message")
}
case msgDisconnect:
@ -232,52 +233,22 @@ var (
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
// generateKeys generates key material for IV, MAC and encryption.
func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
cipherMode := cipherModes[algs.Cipher]
macMode := macModes[algs.MAC]
iv = make([]byte, cipherMode.ivSize)
key = make([]byte, cipherMode.keySize)
macKey = make([]byte, macMode.keySize)
generateKeyMaterial(iv, d.ivTag, kex)
generateKeyMaterial(key, d.keyTag, kex)
generateKeyMaterial(macKey, d.macKeyTag, kex)
return
}
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
iv, key, macKey := generateKeys(d, algs, kex)
cipherMode := cipherModes[algs.Cipher]
macMode := macModes[algs.MAC]
if algs.Cipher == gcmCipherID {
return newGCMCipher(iv, key)
}
iv := make([]byte, cipherMode.ivSize)
key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macMode.keySize)
if algs.Cipher == aes128cbcID {
return newAESCBCCipher(iv, key, macKey, algs)
}
generateKeyMaterial(iv, d.ivTag, kex)
generateKeyMaterial(key, d.keyTag, kex)
generateKeyMaterial(macKey, d.macKeyTag, kex)
if algs.Cipher == tripledescbcID {
return newTripleDESCBCCipher(iv, key, macKey, algs)
}
c := &streamPacketCipher{
mac: macModes[algs.MAC].new(macKey),
etm: macModes[algs.MAC].etm,
}
c.macResult = make([]byte, c.mac.Size())
var err error
c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
if err != nil {
return nil, err
}
return c, nil
return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
}
// generateKeyMaterial fills out with key material generated from tag, K, H
@ -342,7 +313,7 @@ func readVersion(r io.Reader) ([]byte, error) {
var ok bool
var buf [1]byte
for len(versionString) < maxVersionStringBytes {
for length := 0; length < maxVersionStringBytes; length++ {
_, err := io.ReadFull(r, buf[:])
if err != nil {
return nil, err
@ -350,6 +321,13 @@ func readVersion(r io.Reader) ([]byte, error) {
// The RFC says that the version should be terminated with \r\n
// but several SSH servers actually only send a \n.
if buf[0] == '\n' {
if !bytes.HasPrefix(versionString, []byte("SSH-")) {
// RFC 4253 says we need to ignore all version string lines
// except the one containing the SSH version (provided that
// all the lines do not exceed 255 bytes in total).
versionString = versionString[:0]
continue
}
ok = true
break
}

View file

@ -1,109 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"encoding/binary"
"strings"
"testing"
)
func TestReadVersion(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := map[string]string{
"SSH-2.0-bla\r\n": "SSH-2.0-bla",
"SSH-2.0-bla\n": "SSH-2.0-bla",
longversion + "\r\n": longversion,
}
for in, want := range cases {
result, err := readVersion(bytes.NewBufferString(in))
if err != nil {
t.Errorf("readVersion(%q): %s", in, err)
}
got := string(result)
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}
func TestReadVersionError(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := []string{
longversion + "too-long\r\n",
}
for _, in := range cases {
if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
t.Errorf("readVersion(%q) should have failed", in)
}
}
}
func TestExchangeVersionsBasic(t *testing.T) {
v := "SSH-2.0-bla"
buf := bytes.NewBufferString(v + "\r\n")
them, err := exchangeVersions(buf, []byte("xyz"))
if err != nil {
t.Errorf("exchangeVersions: %v", err)
}
if want := "SSH-2.0-bla"; string(them) != want {
t.Errorf("got %q want %q for our version", them, want)
}
}
func TestExchangeVersions(t *testing.T) {
cases := []string{
"not\x000allowed",
"not allowed\n",
}
for _, c := range cases {
buf := bytes.NewBufferString("SSH-2.0-bla\r\n")
if _, err := exchangeVersions(buf, []byte(c)); err == nil {
t.Errorf("exchangeVersions(%q): should have failed", c)
}
}
}
type closerBuffer struct {
bytes.Buffer
}
func (b *closerBuffer) Close() error {
return nil
}
func TestTransportMaxPacketWrite(t *testing.T) {
buf := &closerBuffer{}
tr := newTransport(buf, rand.Reader, true)
huge := make([]byte, maxPacket+1)
err := tr.writePacket(huge)
if err == nil {
t.Errorf("transport accepted write for a huge packet.")
}
}
func TestTransportMaxPacketReader(t *testing.T) {
var header [5]byte
huge := make([]byte, maxPacket+128)
binary.BigEndian.PutUint32(header[0:], uint32(len(huge)))
// padding.
header[4] = 0
buf := &closerBuffer{}
buf.Write(header[:])
buf.Write(huge)
tr := newTransport(buf, rand.Reader, true)
_, err := tr.readPacket()
if err == nil {
t.Errorf("transport succeeded reading huge packet.")
} else if !strings.Contains(err.Error(), "large") {
t.Errorf("got %q, should mention %q", err.Error(), "large")
}
}