Testing for net.

This commit is contained in:
Andrey Petrov 2014-12-22 15:53:30 -08:00
parent 59ac8bb037
commit 7beb7f99bb
5 changed files with 143 additions and 107 deletions

View File

@ -7,7 +7,7 @@ var logger *stdlog.Logger
func SetLogger(w io.Writer) {
flags := stdlog.Flags()
prefix := "[chat] "
prefix := "[sshd] "
logger = stdlog.New(w, prefix, flags)
}

View File

@ -2,7 +2,6 @@ package sshd
import (
"net"
"syscall"
"golang.org/x/crypto/ssh"
)
@ -19,8 +18,7 @@ func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) {
if err != nil {
return nil, err
}
l := socket.(SSHListener)
l.config = config
l := SSHListener{socket, config}
return &l, nil
}
@ -41,15 +39,14 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal {
go func() {
defer l.Close()
defer close(ch)
for {
conn, err := l.Accept()
if err != nil {
logger.Printf("Failed to accept connection: %v", err)
if err == syscall.EINVAL {
return
}
return
}
// Goroutineify to resume accepting sockets early

137
sshd/net_test.go Normal file
View File

@ -0,0 +1,137 @@
package sshd
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"io"
"testing"
"golang.org/x/crypto/ssh"
)
// TODO: Move some of these into their own package?
func MakeKey(bits int) (ssh.Signer, error) {
key, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
return ssh.NewSignerFromKey(key)
}
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
config := &ssh.ClientConfig{
User: name,
Auth: []ssh.AuthMethod{
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
return
}),
},
}
conn, err := ssh.Dial("tcp", host, config)
if err != nil {
return err
}
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
return err
}
defer session.Close()
in, err := session.StdinPipe()
if err != nil {
return err
}
out, err := session.StdoutPipe()
if err != nil {
return err
}
err = session.Shell()
if err != nil {
return err
}
handler(out, in)
return nil
}
func TestServerInit(t *testing.T) {
config := MakeNoAuth()
s, err := ListenSSH(":badport", config)
if err == nil {
t.Fatal("should fail on bad port")
}
s, err = ListenSSH(":0", config)
if err != nil {
t.Error(err)
}
err = s.Close()
if err != nil {
t.Error(err)
}
}
func TestServeTerminals(t *testing.T) {
signer, err := MakeKey(512)
config := MakeNoAuth()
config.AddHostKey(signer)
s, err := ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
terminals := s.ServeTerminal()
go func() {
// Accept one terminal, read from it, echo back, close.
term := <-terminals
term.SetPrompt("> ")
line, err := term.ReadLine()
if err != nil {
t.Error(err)
}
_, err = term.Write([]byte("echo: " + line + "\r\n"))
if err != nil {
t.Error(err)
}
term.Close()
}()
host := s.Addr().String()
name := "foo"
err = NewClientSession(host, name, func(r io.Reader, w io.WriteCloser) {
// Consume if there is anything
buf := new(bytes.Buffer)
w.Write([]byte("hello\r\n"))
buf.Reset()
_, err := io.Copy(buf, r)
if err != nil {
t.Error(err)
}
expected := "> hello\r\necho: hello\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("Got `%s`; expected `%s`", actual, expected)
}
s.Close()
})
if err != nil {
t.Fatal(err)
}
}

View File

@ -1,98 +0,0 @@
package sshd
import (
"net"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
// Server holds all the fields used by a server
type Server struct {
sshConfig *ssh.ServerConfig
done chan struct{}
started time.Time
sync.RWMutex
}
// Initialize a new server
func NewServer(privateKey []byte) (*Server, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return nil, err
}
server := Server{
done: make(chan struct{}),
started: time.Now(),
}
config := MakeNoAuth()
config.AddHostKey(signer)
server.sshConfig = config
return &server, nil
}
// Start starts the server
func (s *Server) Start(laddr string) error {
// Once a ServerConfig has been configured, connections can be
// accepted.
socket, err := net.Listen("tcp", laddr)
if err != nil {
return err
}
logger.Infof("Listening on %s", laddr)
go func() {
defer socket.Close()
for {
conn, err := socket.Accept()
if err != nil {
logger.Printf("Failed to accept connection: %v", err)
if err == syscall.EINVAL {
// TODO: Handle shutdown more gracefully?
return
}
}
// Goroutineify to resume accepting sockets early.
go func() {
// From a standard TCP connection to an encrypted SSH connection
sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
logger.Printf("Failed to handshake: %v", err)
return
}
go ssh.DiscardRequests(requests)
client := NewClient(s, sshConn)
go client.handleChannels(channels)
}()
}
}()
go func() {
<-s.done
socket.Close()
}()
return nil
}
// Stop stops the server
func (s *Server) Stop() {
s.Lock()
for _, client := range s.clients {
client.Conn.Close()
}
s.Unlock()
close(s.done)
}

View File

@ -10,7 +10,7 @@ import (
// Extending ssh/terminal to include a closer interface
type Terminal struct {
*terminal.Terminal
terminal.Terminal
Conn ssh.Conn
Channel ssh.Channel
}
@ -25,7 +25,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
return nil, err
}
term := Terminal{
terminal.NewTerminal(channel, "Connecting..."),
*terminal.NewTerminal(channel, "Connecting..."),
conn,
channel,
}