2015-01-21 20:47:59 +01:00
|
|
|
package sshchat
|
2014-12-27 02:40:57 +01:00
|
|
|
|
|
|
|
import (
|
2015-01-07 06:42:57 +01:00
|
|
|
"bufio"
|
2015-01-10 21:44:06 +01:00
|
|
|
"crypto/rand"
|
|
|
|
"crypto/rsa"
|
2016-07-24 21:34:56 +02:00
|
|
|
"errors"
|
2015-01-07 06:42:57 +01:00
|
|
|
"io"
|
|
|
|
"strings"
|
2014-12-27 02:40:57 +01:00
|
|
|
"testing"
|
|
|
|
|
2015-01-21 00:57:01 +01:00
|
|
|
"github.com/shazow/ssh-chat/chat/message"
|
2015-01-07 06:42:57 +01:00
|
|
|
"github.com/shazow/ssh-chat/sshd"
|
2015-01-10 21:44:06 +01:00
|
|
|
"golang.org/x/crypto/ssh"
|
2014-12-27 02:40:57 +01:00
|
|
|
)
|
|
|
|
|
2015-01-07 06:42:57 +01:00
|
|
|
// stripPrompt returns the part of s that follows the last "\033[K"
// (erase-to-end-of-line) escape sequence. If s contains no such sequence it
// is returned unchanged. The tests use it to drop whatever prompt output
// precedes a chat line.
func stripPrompt(s string) string {
	const eraseLine = "\033[K"
	if idx := strings.LastIndex(s, eraseLine); idx >= 0 {
		return s[idx+len(eraseLine):]
	}
	return s
}
|
|
|
|
|
2014-12-27 02:40:57 +01:00
|
|
|
func TestHostGetPrompt(t *testing.T) {
|
|
|
|
var expected, actual string
|
|
|
|
|
2016-08-07 00:20:34 +02:00
|
|
|
u := message.NewUser(&Identity{id: "foo"})
|
2014-12-27 02:40:57 +01:00
|
|
|
|
|
|
|
actual = GetPrompt(u)
|
|
|
|
expected = "[foo] "
|
|
|
|
if actual != expected {
|
2015-01-07 06:42:57 +01:00
|
|
|
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
2014-12-27 02:40:57 +01:00
|
|
|
}
|
|
|
|
|
2016-08-29 16:11:39 +02:00
|
|
|
u.SetConfig(message.UserConfig{
|
|
|
|
Theme: &message.Themes[0],
|
|
|
|
})
|
2014-12-27 02:40:57 +01:00
|
|
|
actual = GetPrompt(u)
|
2016-08-29 15:58:17 +02:00
|
|
|
expected = "[\033[38;05;88mfoo\033[0m] "
|
2014-12-27 02:40:57 +01:00
|
|
|
if actual != expected {
|
2015-01-07 06:42:57 +01:00
|
|
|
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestHostNameCollision(t *testing.T) {
|
2015-01-10 21:44:06 +01:00
|
|
|
key, err := sshd.NewRandomSigner(512)
|
2015-01-07 06:42:57 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
config := sshd.MakeNoAuth()
|
|
|
|
config.AddHostKey(key)
|
|
|
|
|
2016-08-02 22:03:19 +02:00
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
2015-01-07 06:42:57 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
2015-01-10 21:44:06 +01:00
|
|
|
defer s.Close()
|
2015-01-21 20:47:59 +01:00
|
|
|
host := NewHost(s, nil)
|
2015-01-07 06:42:57 +01:00
|
|
|
go host.Serve()
|
|
|
|
|
|
|
|
done := make(chan struct{}, 1)
|
|
|
|
|
|
|
|
// First client
|
|
|
|
go func() {
|
2016-07-24 21:34:56 +02:00
|
|
|
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
2015-01-07 06:42:57 +01:00
|
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
|
|
|
|
// Consume the initial buffer
|
|
|
|
scanner.Scan()
|
2016-08-24 19:53:59 +02:00
|
|
|
actual := stripPrompt(scanner.Text())
|
2017-06-14 15:07:24 +02:00
|
|
|
expected := " * foo joined. (Connected: 1)\r"
|
2015-01-07 06:42:57 +01:00
|
|
|
if actual != expected {
|
|
|
|
t.Errorf("Got %q; expected %q", actual, expected)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ready for second client
|
|
|
|
done <- struct{}{}
|
|
|
|
|
|
|
|
scanner.Scan()
|
2016-08-24 19:53:59 +02:00
|
|
|
actual = scanner.Text()
|
|
|
|
// This check has to happen second because prompt doesn't always
|
|
|
|
// get set before the first message.
|
|
|
|
if !strings.HasPrefix(actual, "[foo] ") {
|
|
|
|
t.Errorf("First client failed to get 'foo' name: %q", actual)
|
|
|
|
}
|
|
|
|
actual = stripPrompt(actual)
|
2017-06-14 15:07:24 +02:00
|
|
|
expected = " * Guest1 joined. (Connected: 2)\r"
|
2015-01-07 06:42:57 +01:00
|
|
|
if actual != expected {
|
|
|
|
t.Errorf("Got %q; expected %q", actual, expected)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Wrap it up.
|
|
|
|
close(done)
|
2016-07-24 21:34:56 +02:00
|
|
|
return nil
|
2015-01-07 06:42:57 +01:00
|
|
|
})
|
|
|
|
if err != nil {
|
2017-06-14 15:07:24 +02:00
|
|
|
done <- struct{}{}
|
2015-01-07 06:42:57 +01:00
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
// Wait for first client
|
|
|
|
<-done
|
|
|
|
|
|
|
|
// Second client
|
2016-07-24 21:34:56 +02:00
|
|
|
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
2015-01-07 06:42:57 +01:00
|
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
|
|
|
|
// Consume the initial buffer
|
|
|
|
scanner.Scan()
|
2016-07-11 18:12:23 +02:00
|
|
|
scanner.Scan()
|
|
|
|
scanner.Scan()
|
|
|
|
|
2015-01-07 06:42:57 +01:00
|
|
|
actual := scanner.Text()
|
|
|
|
if !strings.HasPrefix(actual, "[Guest1] ") {
|
2016-07-11 18:12:23 +02:00
|
|
|
t.Errorf("Second client did not get Guest1 name: %q", actual)
|
2015-01-07 06:42:57 +01:00
|
|
|
}
|
2016-07-24 21:34:56 +02:00
|
|
|
return nil
|
2015-01-07 06:42:57 +01:00
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
2014-12-27 02:40:57 +01:00
|
|
|
}
|
2015-01-07 06:42:57 +01:00
|
|
|
|
|
|
|
<-done
|
2015-01-10 21:44:06 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func TestHostWhitelist(t *testing.T) {
|
|
|
|
key, err := sshd.NewRandomSigner(512)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
auth := NewAuth()
|
|
|
|
config := sshd.MakeAuth(auth)
|
|
|
|
config.AddHostKey(key)
|
|
|
|
|
2016-08-02 22:03:19 +02:00
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
2015-01-10 21:44:06 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
defer s.Close()
|
2015-01-21 20:47:59 +01:00
|
|
|
host := NewHost(s, auth)
|
2015-01-10 21:44:06 +01:00
|
|
|
go host.Serve()
|
|
|
|
|
|
|
|
target := s.Addr().String()
|
|
|
|
|
2016-07-24 21:34:56 +02:00
|
|
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
2015-01-10 21:44:06 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Error(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
2015-01-20 04:16:37 +01:00
|
|
|
auth.Whitelist(clientpubkey, 0)
|
2015-01-10 21:44:06 +01:00
|
|
|
|
2016-07-24 21:34:56 +02:00
|
|
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
2015-01-10 21:44:06 +01:00
|
|
|
if err == nil {
|
|
|
|
t.Error("Failed to block unwhitelisted connection.")
|
|
|
|
}
|
2014-12-27 02:40:57 +01:00
|
|
|
}
|
2015-01-11 23:12:51 +01:00
|
|
|
|
|
|
|
func TestHostKick(t *testing.T) {
|
|
|
|
key, err := sshd.NewRandomSigner(512)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
auth := NewAuth()
|
|
|
|
config := sshd.MakeAuth(auth)
|
|
|
|
config.AddHostKey(key)
|
|
|
|
|
2016-08-02 22:03:19 +02:00
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
2015-01-11 23:12:51 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
defer s.Close()
|
|
|
|
addr := s.Addr().String()
|
2015-01-21 20:47:59 +01:00
|
|
|
host := NewHost(s, nil)
|
2015-01-11 23:12:51 +01:00
|
|
|
go host.Serve()
|
|
|
|
|
|
|
|
connected := make(chan struct{})
|
2019-03-15 23:18:20 +01:00
|
|
|
kicked := make(chan struct{})
|
2015-01-11 23:12:51 +01:00
|
|
|
done := make(chan struct{})
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
// First client
|
2016-07-24 21:34:56 +02:00
|
|
|
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
|
2019-03-15 23:18:20 +01:00
|
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
|
|
|
|
// Consume the initial buffer
|
|
|
|
scanner.Scan()
|
|
|
|
|
2015-01-11 23:12:51 +01:00
|
|
|
// Make op
|
2016-08-24 19:47:22 +02:00
|
|
|
member, _ := host.Room.MemberByID("foo")
|
2016-07-11 19:02:34 +02:00
|
|
|
if member == nil {
|
2016-08-24 19:47:22 +02:00
|
|
|
return errors.New("failed to load MemberByID")
|
2016-07-11 19:02:34 +02:00
|
|
|
}
|
2019-03-15 23:18:20 +01:00
|
|
|
member.IsOp = true
|
|
|
|
|
|
|
|
// Change nicks, make sure op sticks
|
|
|
|
w.Write([]byte("/nick quux\r\n"))
|
|
|
|
scanner.Scan() // Prompt
|
|
|
|
scanner.Scan() // Nick change response
|
2015-01-11 23:12:51 +01:00
|
|
|
|
|
|
|
// Block until second client is here
|
|
|
|
connected <- struct{}{}
|
2019-03-15 23:18:20 +01:00
|
|
|
scanner.Scan() // Connected message
|
|
|
|
|
2015-01-11 23:12:51 +01:00
|
|
|
w.Write([]byte("/kick bar\r\n"))
|
2019-03-15 23:18:20 +01:00
|
|
|
scanner.Scan() // Prompt
|
|
|
|
|
|
|
|
scanner.Scan()
|
|
|
|
if actual, expected := stripPrompt(scanner.Text()), " * bar was kicked by quux.\r"; actual != expected {
|
|
|
|
t.Errorf("Got %q; expected %q", actual, expected)
|
|
|
|
}
|
|
|
|
|
|
|
|
kicked <- struct{}{}
|
|
|
|
|
2016-07-24 21:34:56 +02:00
|
|
|
return nil
|
2015-01-11 23:12:51 +01:00
|
|
|
})
|
|
|
|
if err != nil {
|
2017-06-14 15:07:24 +02:00
|
|
|
connected <- struct{}{}
|
2016-07-24 21:34:56 +02:00
|
|
|
close(connected)
|
2015-01-11 23:12:51 +01:00
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
// Second client
|
2016-07-24 21:34:56 +02:00
|
|
|
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
|
2019-03-15 23:18:20 +01:00
|
|
|
scanner := bufio.NewScanner(r)
|
2015-01-11 23:12:51 +01:00
|
|
|
<-connected
|
2019-03-15 23:18:20 +01:00
|
|
|
scanner.Scan()
|
2015-01-11 23:12:51 +01:00
|
|
|
|
2019-03-15 23:18:20 +01:00
|
|
|
<-kicked
|
|
|
|
|
|
|
|
scanner.Scan()
|
|
|
|
return scanner.Err()
|
2015-01-11 23:12:51 +01:00
|
|
|
})
|
|
|
|
if err != nil {
|
2017-06-14 15:07:24 +02:00
|
|
|
close(done)
|
2015-01-11 23:12:51 +01:00
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
close(done)
|
|
|
|
}()
|
|
|
|
|
2016-07-13 00:24:02 +02:00
|
|
|
<-done
|
2015-01-11 23:12:51 +01:00
|
|
|
}
|