479 lines
12 KiB
Go
479 lines
12 KiB
Go
package sshchat
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
mathRand "math/rand"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/shazow/ssh-chat/chat/message"
|
|
"github.com/shazow/ssh-chat/sshd"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
func stripPrompt(s string) string {
|
|
// FIXME: Is there a better way to do this?
|
|
if endPos := strings.Index(s, "\x1b[K "); endPos > 0 {
|
|
return s[endPos+3:]
|
|
}
|
|
if endPos := strings.Index(s, "\x1b[2K "); endPos > 0 {
|
|
return s[endPos+4:]
|
|
}
|
|
if endPos := strings.Index(s, "\x1b[K-> "); endPos > 0 {
|
|
return s[endPos+6:]
|
|
}
|
|
if endPos := strings.Index(s, "] "); endPos > 0 {
|
|
return s[endPos+2:]
|
|
}
|
|
if strings.HasPrefix(s, "-> ") {
|
|
return s[3:]
|
|
}
|
|
return s
|
|
}
|
|
|
|
func TestStripPrompt(t *testing.T) {
|
|
tests := []struct {
|
|
Input string
|
|
Want string
|
|
}{
|
|
{
|
|
Input: "\x1b[A\x1b[2K[quux] hello",
|
|
Want: "hello",
|
|
},
|
|
{
|
|
Input: "[foo] \x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[K * Guest1 joined. (Connected: 2)\r",
|
|
Want: " * Guest1 joined. (Connected: 2)\r",
|
|
},
|
|
{
|
|
Input: "[foo] \x1b[6D\x1b[K-> From your friendly system.\r",
|
|
Want: "From your friendly system.\r",
|
|
},
|
|
{
|
|
Input: "-> Err: must be op.\r",
|
|
Want: "Err: must be op.\r",
|
|
},
|
|
}
|
|
|
|
for i, tc := range tests {
|
|
if got, want := stripPrompt(tc.Input), tc.Want; got != want {
|
|
t.Errorf("case #%d:\n got: %q\nwant: %q", i, got, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHostGetPrompt(t *testing.T) {
|
|
var expected, actual string
|
|
|
|
// Make the random colors consistent across tests
|
|
mathRand.Seed(1)
|
|
|
|
u := message.NewUser(&Identity{id: "foo"})
|
|
|
|
actual = GetPrompt(u)
|
|
expected = "[foo] "
|
|
if actual != expected {
|
|
t.Errorf("Invalid host prompt:\n Got: %q;\nWant: %q", actual, expected)
|
|
}
|
|
|
|
u.SetConfig(message.UserConfig{
|
|
Theme: &message.Themes[0],
|
|
})
|
|
actual = GetPrompt(u)
|
|
expected = "[\033[38;05;88mfoo\033[0m] "
|
|
if actual != expected {
|
|
t.Errorf("Invalid host prompt:\n Got: %q;\nWant: %q", actual, expected)
|
|
}
|
|
}
|
|
|
|
func getHost(t *testing.T, auth *Auth) (*sshd.SSHListener, *Host) {
|
|
key, err := sshd.NewRandomSigner(1024)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var config *ssh.ServerConfig
|
|
if auth == nil {
|
|
config = sshd.MakeNoAuth()
|
|
} else {
|
|
config = sshd.MakeAuth(auth)
|
|
}
|
|
config.AddHostKey(key)
|
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return s, NewHost(s, auth)
|
|
}
|
|
|
|
func TestHostNameCollision(t *testing.T) {
|
|
s, host := getHost(t, nil)
|
|
defer s.Close()
|
|
|
|
newUsers := make(chan *message.User)
|
|
host.OnUserJoined = func(u *message.User) {
|
|
newUsers <- u
|
|
}
|
|
go host.Serve()
|
|
|
|
g := errgroup.Group{}
|
|
|
|
// First client
|
|
g.Go(func() error {
|
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
|
// second client
|
|
name := (<-newUsers).Name()
|
|
if name != "Guest1" {
|
|
t.Errorf("Second client did not get Guest1 name: %q", name)
|
|
}
|
|
return nil
|
|
})
|
|
})
|
|
|
|
// Second client
|
|
g.Go(func() error {
|
|
// first client
|
|
name := (<-newUsers).Name()
|
|
if name != "foo" {
|
|
t.Errorf("First client did not get foo name: %q", name)
|
|
}
|
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
|
return nil
|
|
})
|
|
})
|
|
|
|
if err := g.Wait(); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
func TestHostAllowlist(t *testing.T) {
|
|
auth := NewAuth()
|
|
s, host := getHost(t, auth)
|
|
defer s.Close()
|
|
go host.Serve()
|
|
|
|
target := s.Addr().String()
|
|
|
|
clientPrivateKey, err := sshd.NewRandomSigner(512)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
clientKey := clientPrivateKey.PublicKey()
|
|
loadCount := -1
|
|
loader := func() ([]ssh.PublicKey, error) {
|
|
loadCount++
|
|
return [][]ssh.PublicKey{
|
|
{},
|
|
{clientKey},
|
|
}[loadCount], nil
|
|
}
|
|
auth.LoadAllowlist(loader)
|
|
|
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
auth.SetAllowlistMode(true)
|
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
|
if err == nil {
|
|
t.Error(err)
|
|
}
|
|
err = sshd.ConnectShellWithKey(target, "foo", clientPrivateKey, func(r io.Reader, w io.WriteCloser) error { return nil })
|
|
if err == nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
auth.ReloadAllowlist()
|
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
|
if err == nil {
|
|
t.Error("Failed to block unallowlisted connection.")
|
|
}
|
|
}
|
|
|
|
func TestHostAllowlistCommand(t *testing.T) {
|
|
s, host := getHost(t, NewAuth())
|
|
defer s.Close()
|
|
go host.Serve()
|
|
|
|
users := make(chan *message.User)
|
|
host.OnUserJoined = func(u *message.User) {
|
|
users <- u
|
|
}
|
|
|
|
kickSignal := make(chan struct{})
|
|
clientKey, err := sshd.NewRandomSigner(1024)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
clientKeyFP := sshd.Fingerprint(clientKey.PublicKey())
|
|
go sshd.ConnectShellWithKey(s.Addr().String(), "bar", clientKey, func(r io.Reader, w io.WriteCloser) error {
|
|
<-kickSignal
|
|
n, err := w.Write([]byte("alive and well"))
|
|
if n != 0 || err == nil {
|
|
t.Error("could write after being kicked")
|
|
}
|
|
return nil
|
|
})
|
|
|
|
sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
|
<-users
|
|
<-users
|
|
m, ok := host.MemberByID("foo")
|
|
if !ok {
|
|
t.Fatal("can't get member foo")
|
|
}
|
|
|
|
scanner := bufio.NewScanner(r)
|
|
scanner.Scan() // Joined
|
|
scanner.Scan()
|
|
|
|
assertLineEq := func(expected ...string) {
|
|
if !scanner.Scan() {
|
|
t.Error("no line available")
|
|
}
|
|
actual := stripPrompt(scanner.Text())
|
|
for _, exp := range expected {
|
|
if exp == actual {
|
|
return
|
|
}
|
|
}
|
|
t.Errorf("expected %#v, got %q", expected, actual)
|
|
}
|
|
sendCmd := func(cmd string, formatting ...interface{}) {
|
|
host.HandleMsg(message.ParseInput(fmt.Sprintf(cmd, formatting...), m.User))
|
|
}
|
|
|
|
sendCmd("/allowlist")
|
|
assertLineEq("Err: must be op\r")
|
|
m.IsOp = true
|
|
sendCmd("/allowlist")
|
|
for _, expected := range [...]string{"Usage", "help", "on, off", "add, remove", "import", "reload", "reverify", "status"} {
|
|
if !scanner.Scan() {
|
|
t.Error("no line available")
|
|
}
|
|
if actual := stripPrompt(scanner.Text()); !strings.HasPrefix(actual, expected) {
|
|
t.Errorf("Unexpected help message order: have %q, want prefix %q", actual, expected)
|
|
}
|
|
}
|
|
|
|
sendCmd("/allowlist on")
|
|
if !host.auth.AllowlistMode() {
|
|
t.Error("allowlist not enabled after /allowlist on")
|
|
}
|
|
sendCmd("/allowlist off")
|
|
if host.auth.AllowlistMode() {
|
|
t.Error("allowlist not disabled after /allowlist off")
|
|
}
|
|
|
|
testKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPUiNw0nQku4pcUCbZcJlIEAIf5bXJYTy/DKI1vh5b+P"
|
|
testKeyFP := "SHA256:GJNSl9NUcOS2pZYALn0C5Qgfh5deT+R+FfqNIUvpM9s="
|
|
|
|
if host.auth.allowlist.Len() != 0 {
|
|
t.Error("allowlist not empty before adding anyone")
|
|
}
|
|
sendCmd("/allowlist add ssh-invalid blah ssh-rsa wrongAsWell invalid foo bar %s", testKey)
|
|
assertLineEq("users without a public key: [foo]\r")
|
|
assertLineEq("invalid users: [invalid]\r")
|
|
assertLineEq("invalid keys: [ssh-invalid blah ssh-rsa wrongAsWell]\r")
|
|
if !host.auth.allowlist.In(testKeyFP) || !host.auth.allowlist.In(clientKeyFP) {
|
|
t.Error("failed to add keys to allowlist")
|
|
}
|
|
sendCmd("/allowlist remove invalid bar")
|
|
assertLineEq("invalid users: [invalid]\r")
|
|
if host.auth.allowlist.In(clientKeyFP) {
|
|
t.Error("failed to remove key from allowlist")
|
|
}
|
|
if !host.auth.allowlist.In(testKeyFP) {
|
|
t.Error("removed wrong key")
|
|
}
|
|
|
|
sendCmd("/allowlist import 5h")
|
|
if host.auth.allowlist.In(clientKeyFP) {
|
|
t.Error("imporrted key not seen long enough")
|
|
}
|
|
sendCmd("/allowlist import")
|
|
assertLineEq("users without a public key: [foo]\r")
|
|
if !host.auth.allowlist.In(clientKeyFP) {
|
|
t.Error("failed to import key")
|
|
}
|
|
|
|
sendCmd("/allowlist reload keep")
|
|
if !host.auth.allowlist.In(testKeyFP) {
|
|
t.Error("cleared allowlist to be kept")
|
|
}
|
|
sendCmd("/allowlist reload flush")
|
|
if host.auth.allowlist.In(testKeyFP) {
|
|
t.Error("kept allowlist to be cleared")
|
|
}
|
|
sendCmd("/allowlist reload thisIsWrong")
|
|
assertLineEq("Err: must specify whether to keep or flush current entries\r")
|
|
sendCmd("/allowlist reload")
|
|
assertLineEq("Err: must specify whether to keep or flush current entries\r")
|
|
|
|
sendCmd("/allowlist reverify")
|
|
assertLineEq("allowlist is disabled, so nobody will be kicked\r")
|
|
sendCmd("/allowlist on")
|
|
sendCmd("/allowlist reverify")
|
|
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
|
|
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
|
|
kickSignal <- struct{}{}
|
|
|
|
sendCmd("/allowlist add " + testKey)
|
|
sendCmd("/allowlist status")
|
|
assertLineEq("allowlist enabled\r")
|
|
assertLineEq(fmt.Sprintf("Keys on the allowlist without connected user: %s\r", testKeyFP))
|
|
|
|
sendCmd("/allowlist invalidSubcommand")
|
|
assertLineEq("Err: invalid subcommand: invalidSubcommand\r")
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func TestHostKick(t *testing.T) {
|
|
s, host := getHost(t, NewAuth())
|
|
defer s.Close()
|
|
go host.Serve()
|
|
|
|
g := errgroup.Group{}
|
|
connected := make(chan struct{})
|
|
kicked := make(chan struct{})
|
|
|
|
g.Go(func() error {
|
|
// First client
|
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
// Consume the initial buffer
|
|
scanner.Scan() // Joined
|
|
|
|
// Make op
|
|
member, _ := host.Room.MemberByID("foo")
|
|
if member == nil {
|
|
return errors.New("failed to load MemberByID")
|
|
}
|
|
member.IsOp = true
|
|
|
|
// Change nicks, make sure op sticks
|
|
w.Write([]byte("/nick quux\r\n"))
|
|
scanner.Scan() // Prompt
|
|
scanner.Scan() // Nick change response
|
|
|
|
// Block until second client is here
|
|
connected <- struct{}{}
|
|
scanner.Scan() // Connected message
|
|
|
|
w.Write([]byte("/kick bar\r\n"))
|
|
scanner.Scan() // Prompt
|
|
|
|
scanner.Scan() // Kick result
|
|
if actual, expected := stripPrompt(scanner.Text()), " * bar was kicked by quux.\r"; actual != expected {
|
|
t.Errorf("Failed to detect kick:\n Got: %q;\nWant: %q", actual, expected)
|
|
}
|
|
|
|
kicked <- struct{}{}
|
|
return nil
|
|
})
|
|
})
|
|
|
|
g.Go(func() error {
|
|
// Second client
|
|
return sshd.ConnectShell(s.Addr().String(), "bar", func(r io.Reader, w io.WriteCloser) error {
|
|
scanner := bufio.NewScanner(r)
|
|
<-connected
|
|
scanner.Scan()
|
|
|
|
<-kicked
|
|
|
|
if _, err := w.Write([]byte("am I still here?\r\n")); err != io.EOF {
|
|
return errors.New("expected to be kicked")
|
|
}
|
|
|
|
scanner.Scan()
|
|
if err := scanner.Err(); err == io.EOF {
|
|
// All good, we got kicked.
|
|
return nil
|
|
} else {
|
|
return err
|
|
}
|
|
})
|
|
})
|
|
|
|
if err := g.Wait(); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
func TestTimestampEnvConfig(t *testing.T) {
|
|
cases := []struct {
|
|
input string
|
|
timeformat *string
|
|
}{
|
|
{"", strptr("15:04")},
|
|
{"1", strptr("15:04")},
|
|
{"0", nil},
|
|
{"time +8h", strptr("15:04")},
|
|
{"datetime +8h", strptr("2006-01-02 15:04:05")},
|
|
}
|
|
for _, tc := range cases {
|
|
u := connectUserWithConfig(t, "dingus", map[string]string{
|
|
"SSHCHAT_TIMESTAMP": tc.input,
|
|
})
|
|
userConfig := u.Config()
|
|
if userConfig.Timeformat != nil && tc.timeformat != nil {
|
|
if *userConfig.Timeformat != *tc.timeformat {
|
|
t.Fatal("unexpected timeformat:", *userConfig.Timeformat, "expected:", *tc.timeformat)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func strptr(s string) *string {
|
|
return &s
|
|
}
|
|
|
|
func connectUserWithConfig(t *testing.T, name string, envConfig map[string]string) *message.User {
|
|
s, host := getHost(t, nil)
|
|
defer s.Close()
|
|
|
|
newUsers := make(chan *message.User)
|
|
host.OnUserJoined = func(u *message.User) {
|
|
newUsers <- u
|
|
}
|
|
go host.Serve()
|
|
|
|
clientConfig := sshd.NewClientConfig(name)
|
|
conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig)
|
|
if err != nil {
|
|
t.Fatal("unable to connect to test ssh-chat server:", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
session, err := conn.NewSession()
|
|
if err != nil {
|
|
t.Fatal("unable to open session:", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
for key := range envConfig {
|
|
session.Setenv(key, envConfig[key])
|
|
}
|
|
|
|
err = session.Shell()
|
|
if err != nil {
|
|
t.Fatal("unable to open shell:", err)
|
|
}
|
|
|
|
for u := range newUsers {
|
|
if u.Name() == name {
|
|
return u
|
|
}
|
|
}
|
|
t.Fatalf("user %s not found in the host", name)
|
|
return nil
|
|
}
|