Add /allowlist command (#399)
* move loading whitelist+ops from file to auth and save the loaded files fro reloading
* add /whitelist command with lots of open questions
* add test for /whitelist
* gofmt
* use the same auth (the tests don't seem to care, but htis is more right)
* mutex whitelistMode and remove some deferred TODOs
* s/whitelist/allowlist/ (user-facing); move helper functions outside the handler function
* check for ops in Auth.CheckPublicKey and move /allowlist handling to helper functions
* possibly fix the test timeout in HostNameCollision
* Revert "possibly fix the test timeout in HostNameCollision" (didn't work)
This reverts commit 664dbb0976
.
* managed to reproduce the timeout after updating, hopefully it's the same one
* remove some unimportant TODOs; add a message when reverify kicks people; add a reverify test
* add client connection with key; add test for /allowlist import AGE
* hopefully make test less racy
* s/whitelist/allowlist/
* fix crash on specifying exactly one more -v flag than the max level
* use a key loader function to move file reading out of auth
* add loader to allowlist test
* minor message changes
* add --whitelist with a warning; update tests for messages
* apparently, we have another prefix
* check names directly on the User objects in TestHostNameCollision
* not allowlisted -> not allowed
* small message change
* update test
This commit is contained in:
parent
84bc5c76dd
commit
621ae1b0d3
103
auth.go
103
auth.go
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shazow/ssh-chat/set"
|
||||
|
@ -15,9 +16,13 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// ErrNotWhitelisted Is the error returned when a key is checked that is not whitelisted,
|
||||
// when whitelisting is enabled.
|
||||
var ErrNotWhitelisted = errors.New("not whitelisted")
|
||||
// KeyLoader loads public keys, e.g. from an authorized_keys file.
|
||||
// It must return a nil slice on error.
|
||||
type KeyLoader func() ([]ssh.PublicKey, error)
|
||||
|
||||
// ErrNotAllowed Is the error returned when a key is checked that is not allowlisted,
|
||||
// when allowlisting is enabled.
|
||||
var ErrNotAllowed = errors.New("not allowed")
|
||||
|
||||
// ErrBanned is the error returned when a client is banned.
|
||||
var ErrBanned = errors.New("banned")
|
||||
|
@ -47,15 +52,20 @@ func newAuthAddr(addr net.Addr) string {
|
|||
return host
|
||||
}
|
||||
|
||||
// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
|
||||
// If the contained passphrase is not empty, it complements a whitelist.
|
||||
// Auth stores lookups for bans, allowlists, and ops. It implements the sshd.Auth interface.
|
||||
// If the contained passphrase is not empty, it complements a allowlist.
|
||||
type Auth struct {
|
||||
passphraseHash []byte
|
||||
bannedAddr *set.Set
|
||||
bannedClient *set.Set
|
||||
banned *set.Set
|
||||
whitelist *set.Set
|
||||
allowlist *set.Set
|
||||
ops *set.Set
|
||||
|
||||
settingsMu sync.RWMutex
|
||||
allowlistMode bool
|
||||
opLoader KeyLoader
|
||||
allowlistLoader KeyLoader
|
||||
}
|
||||
|
||||
// NewAuth creates a new empty Auth.
|
||||
|
@ -64,11 +74,23 @@ func NewAuth() *Auth {
|
|||
bannedAddr: set.New(),
|
||||
bannedClient: set.New(),
|
||||
banned: set.New(),
|
||||
whitelist: set.New(),
|
||||
allowlist: set.New(),
|
||||
ops: set.New(),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) AllowlistMode() bool {
|
||||
a.settingsMu.RLock()
|
||||
defer a.settingsMu.RUnlock()
|
||||
return a.allowlistMode
|
||||
}
|
||||
|
||||
func (a *Auth) SetAllowlistMode(value bool) {
|
||||
a.settingsMu.Lock()
|
||||
defer a.settingsMu.Unlock()
|
||||
a.allowlistMode = value
|
||||
}
|
||||
|
||||
// SetPassphrase enables passphrase authentication with the given passphrase.
|
||||
// If an empty passphrase is given, disable passphrase authentication.
|
||||
func (a *Auth) SetPassphrase(passphrase string) {
|
||||
|
@ -82,7 +104,7 @@ func (a *Auth) SetPassphrase(passphrase string) {
|
|||
|
||||
// AllowAnonymous determines if anonymous users are permitted.
|
||||
func (a *Auth) AllowAnonymous() bool {
|
||||
return a.whitelist.Len() == 0 && a.passphraseHash == nil
|
||||
return !a.AllowlistMode() && a.passphraseHash == nil
|
||||
}
|
||||
|
||||
// AcceptPassphrase determines if passphrase authentication is accepted.
|
||||
|
@ -115,11 +137,11 @@ func (a *Auth) CheckBans(addr net.Addr, key ssh.PublicKey, clientVersion string)
|
|||
// CheckPubkey determines if a pubkey fingerprint is permitted.
|
||||
func (a *Auth) CheckPublicKey(key ssh.PublicKey) error {
|
||||
authkey := newAuthKey(key)
|
||||
whitelisted := a.whitelist.In(authkey)
|
||||
if a.AllowAnonymous() || whitelisted {
|
||||
allowlisted := a.allowlist.In(authkey)
|
||||
if a.AllowAnonymous() || allowlisted || a.IsOp(key) {
|
||||
return nil
|
||||
} else {
|
||||
return ErrNotWhitelisted
|
||||
return ErrNotAllowed
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,25 +173,68 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
|
|||
|
||||
// IsOp checks if a public key is an op.
|
||||
func (a *Auth) IsOp(key ssh.PublicKey) bool {
|
||||
if key == nil {
|
||||
return false
|
||||
}
|
||||
authkey := newAuthKey(key)
|
||||
return a.ops.In(authkey)
|
||||
}
|
||||
|
||||
// Whitelist will set a public key as a whitelisted user.
|
||||
func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
|
||||
// LoadOps sets the public keys form loader to operators and saves the loader for later use
|
||||
func (a *Auth) LoadOps(loader KeyLoader) error {
|
||||
a.settingsMu.Lock()
|
||||
a.opLoader = loader
|
||||
a.settingsMu.Unlock()
|
||||
return a.ReloadOps()
|
||||
}
|
||||
|
||||
// ReloadOps sets the public keys from a loader saved in the last call to operators
|
||||
func (a *Auth) ReloadOps() error {
|
||||
a.settingsMu.RLock()
|
||||
defer a.settingsMu.RUnlock()
|
||||
return addFromLoader(a.opLoader, a.Op)
|
||||
}
|
||||
|
||||
// Allowlist will set a public key as a allowlisted user.
|
||||
func (a *Auth) Allowlist(key ssh.PublicKey, d time.Duration) {
|
||||
if key == nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
authItem := newAuthItem(key)
|
||||
if d != 0 {
|
||||
a.whitelist.Set(set.Expire(authItem, d))
|
||||
err = a.allowlist.Set(set.Expire(authItem, d))
|
||||
} else {
|
||||
a.whitelist.Set(authItem)
|
||||
err = a.allowlist.Set(authItem)
|
||||
}
|
||||
logger.Debugf("Added to whitelist: %q (for %s)", authItem.Key(), d)
|
||||
if err == nil {
|
||||
logger.Debugf("Added to allowlist: %q (for %s)", authItem.Key(), d)
|
||||
} else {
|
||||
logger.Errorf("Error adding %q to allowlist for %s: %s", authItem.Key(), d, err)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAllowlist adds the public keys from the loader to the allowlist and saves the loader for later use
|
||||
func (a *Auth) LoadAllowlist(loader KeyLoader) error {
|
||||
a.settingsMu.Lock()
|
||||
a.allowlistLoader = loader
|
||||
a.settingsMu.Unlock()
|
||||
return a.ReloadAllowlist()
|
||||
}
|
||||
|
||||
// LoadAllowlist adds the public keys from a loader saved in a previous call to the allowlist
|
||||
func (a *Auth) ReloadAllowlist() error {
|
||||
a.settingsMu.RLock()
|
||||
defer a.settingsMu.RUnlock()
|
||||
return addFromLoader(a.allowlistLoader, a.Allowlist)
|
||||
}
|
||||
|
||||
func addFromLoader(loader KeyLoader, adder func(ssh.PublicKey, time.Duration)) error {
|
||||
if loader == nil {
|
||||
return nil
|
||||
}
|
||||
keys, err := loader()
|
||||
for _, key := range keys {
|
||||
adder(key, 0)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Ban will set a public key as banned.
|
||||
|
|
|
@ -21,7 +21,7 @@ func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) {
|
|||
return ssh.ParsePublicKey(key.Marshal())
|
||||
}
|
||||
|
||||
func TestAuthWhitelist(t *testing.T) {
|
||||
func TestAuthAllowlist(t *testing.T) {
|
||||
key, err := NewRandomPublicKey(512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -33,7 +33,8 @@ func TestAuthWhitelist(t *testing.T) {
|
|||
t.Error("Failed to permit in default state:", err)
|
||||
}
|
||||
|
||||
auth.Whitelist(key, 0)
|
||||
auth.Allowlist(key, 0)
|
||||
auth.SetAllowlistMode(true)
|
||||
|
||||
keyClone, err := ClonePublicKey(key)
|
||||
if err != nil {
|
||||
|
@ -46,7 +47,7 @@ func TestAuthWhitelist(t *testing.T) {
|
|||
|
||||
err = auth.CheckPublicKey(keyClone)
|
||||
if err != nil {
|
||||
t.Error("Failed to permit whitelisted:", err)
|
||||
t.Error("Failed to permit allowlisted:", err)
|
||||
}
|
||||
|
||||
key2, err := NewRandomPublicKey(512)
|
||||
|
@ -56,7 +57,7 @@ func TestAuthWhitelist(t *testing.T) {
|
|||
|
||||
err = auth.CheckPublicKey(key2)
|
||||
if err == nil {
|
||||
t.Error("Failed to restrict not whitelisted:", err)
|
||||
t.Error("Failed to restrict not allowlisted:", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -36,8 +36,9 @@ type Options struct {
|
|||
Pprof int `long:"pprof" description:"Enable pprof http server for profiling."`
|
||||
Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."`
|
||||
Version bool `long:"version" description:"Print version and exit."`
|
||||
Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."`
|
||||
Passphrase string `long:"unsafe-passphrase" description:"Require an interactive passphrase to connect. Whitelist feature is more secure."`
|
||||
Allowlist string `long:"allowlist" description:"Optional file of public keys who are allowed to connect."`
|
||||
Whitelist string `long:"whitelist" dexcription:"Old name for allowlist option"`
|
||||
Passphrase string `long:"unsafe-passphrase" description:"Require an interactive passphrase to connect. Allowlist feature is more secure."`
|
||||
}
|
||||
|
||||
const extraHelp = `There are hidden options and easter eggs in ssh-chat. The source code is a good
|
||||
|
@ -87,7 +88,7 @@ func main() {
|
|||
|
||||
// Figure out the log level
|
||||
numVerbose := len(options.Verbose)
|
||||
if numVerbose > len(logLevels) {
|
||||
if numVerbose >= len(logLevels) {
|
||||
numVerbose = len(logLevels) - 1
|
||||
}
|
||||
|
||||
|
@ -141,35 +142,20 @@ func main() {
|
|||
auth.SetPassphrase(options.Passphrase)
|
||||
}
|
||||
|
||||
err = fromFile(options.Admin, func(line []byte) error {
|
||||
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
|
||||
if err != nil {
|
||||
if err.Error() == "ssh: no key found" {
|
||||
return nil // Skip line
|
||||
}
|
||||
return err
|
||||
}
|
||||
auth.Op(key, 0)
|
||||
return nil
|
||||
})
|
||||
err = auth.LoadOps(loaderFromFile(options.Admin, logger))
|
||||
if err != nil {
|
||||
fail(5, "Failed to load admins: %v\n", err)
|
||||
}
|
||||
|
||||
err = fromFile(options.Whitelist, func(line []byte) error {
|
||||
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
|
||||
if err != nil {
|
||||
if err.Error() == "ssh: no key found" {
|
||||
return nil // Skip line
|
||||
}
|
||||
return err
|
||||
}
|
||||
auth.Whitelist(key, 0)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
fail(6, "Failed to load whitelist: %v\n", err)
|
||||
if options.Allowlist == "" && options.Whitelist != "" {
|
||||
fmt.Println("--whitelist was renamed to --allowlist.")
|
||||
options.Allowlist = options.Whitelist
|
||||
}
|
||||
err = auth.LoadAllowlist(loaderFromFile(options.Allowlist, logger))
|
||||
if err != nil {
|
||||
fail(6, "Failed to load allowlist: %v\n", err)
|
||||
}
|
||||
auth.SetAllowlistMode(options.Allowlist != "")
|
||||
|
||||
if options.Motd != "" {
|
||||
host.GetMOTD = func() (string, error) {
|
||||
|
@ -210,24 +196,32 @@ func main() {
|
|||
fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.")
|
||||
}
|
||||
|
||||
func fromFile(path string, handler func(line []byte) error) error {
|
||||
func loaderFromFile(path string, logger *golog.Logger) sshchat.KeyLoader {
|
||||
if path == "" {
|
||||
// Skip
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
err := handler(scanner.Bytes())
|
||||
return func() ([]ssh.PublicKey, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var keys []ssh.PublicKey
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
key, _, _, _, err := ssh.ParseAuthorizedKey(scanner.Bytes())
|
||||
if err != nil {
|
||||
if err.Error() == "ssh: no key found" {
|
||||
continue // Skip line
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
if keys == nil {
|
||||
logger.Warning("file", path, "contained no keys")
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
201
host.go
201
host.go
|
@ -9,11 +9,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/shazow/rateio"
|
||||
"github.com/shazow/ssh-chat/chat"
|
||||
"github.com/shazow/ssh-chat/chat/message"
|
||||
"github.com/shazow/ssh-chat/internal/humantime"
|
||||
"github.com/shazow/ssh-chat/internal/sanitize"
|
||||
"github.com/shazow/ssh-chat/set"
|
||||
"github.com/shazow/ssh-chat/sshd"
|
||||
)
|
||||
|
||||
|
@ -695,4 +698,202 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
|||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
forConnectedUsers := func(cmd func(*chat.Member, ssh.PublicKey) error) error {
|
||||
return h.Members.Each(func(key string, item set.Item) error {
|
||||
v := item.Value()
|
||||
if v == nil { // expired between Each and here
|
||||
return nil
|
||||
}
|
||||
user := v.(*chat.Member)
|
||||
pk := user.Identifier.(*Identity).PublicKey()
|
||||
return cmd(user, pk)
|
||||
})
|
||||
}
|
||||
|
||||
forPubkeyUser := func(args []string, cmd func(ssh.PublicKey)) (errors []string) {
|
||||
invalidUsers := []string{}
|
||||
invalidKeys := []string{}
|
||||
noKeyUsers := []string{}
|
||||
var keyType string
|
||||
for _, v := range args {
|
||||
switch {
|
||||
case keyType != "":
|
||||
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyType + " " + v))
|
||||
if err == nil {
|
||||
cmd(pk)
|
||||
} else {
|
||||
invalidKeys = append(invalidKeys, keyType+" "+v)
|
||||
}
|
||||
keyType = ""
|
||||
case strings.HasPrefix(v, "ssh-"):
|
||||
keyType = v
|
||||
default:
|
||||
user, ok := h.GetUser(v)
|
||||
if ok {
|
||||
pk := user.Identifier.(*Identity).PublicKey()
|
||||
if pk == nil {
|
||||
noKeyUsers = append(noKeyUsers, user.Identifier.Name())
|
||||
} else {
|
||||
cmd(pk)
|
||||
}
|
||||
} else {
|
||||
invalidUsers = append(invalidUsers, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(noKeyUsers) != 0 {
|
||||
errors = append(errors, fmt.Sprintf("users without a public key: %v", noKeyUsers))
|
||||
}
|
||||
if len(invalidUsers) != 0 {
|
||||
errors = append(errors, fmt.Sprintf("invalid users: %v", invalidUsers))
|
||||
}
|
||||
if len(invalidKeys) != 0 {
|
||||
errors = append(errors, fmt.Sprintf("invalid keys: %v", invalidKeys))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
allowlistHelptext := []string{
|
||||
"Usage: /allowlist help | on | off | add {PUBKEY|USER}... | remove {PUBKEY|USER}... | import [AGE] | reload {keep|flush} | reverify | status",
|
||||
"help: this help message",
|
||||
"on, off: set allowlist mode (applies to new connections)",
|
||||
"add, remove: add or remove keys from the allowlist",
|
||||
"import: add all keys of users connected since AGE (default 0) ago to the allowlist",
|
||||
"reload: re-read the allowlist file and keep or discard entries in the current allowlist but not in the file",
|
||||
"reverify: kick all users not in the allowlist if allowlisting is enabled",
|
||||
"status: show status information",
|
||||
}
|
||||
|
||||
allowlistImport := func(args []string) (msgs []string, err error) {
|
||||
var since time.Duration
|
||||
if len(args) > 0 {
|
||||
since, err = time.ParseDuration(args[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
cutoff := time.Now().Add(-since)
|
||||
noKeyUsers := []string{}
|
||||
forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
|
||||
if user.Joined().Before(cutoff) {
|
||||
if pk == nil {
|
||||
noKeyUsers = append(noKeyUsers, user.Identifier.Name())
|
||||
} else {
|
||||
h.auth.Allowlist(pk, 0)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if len(noKeyUsers) != 0 {
|
||||
msgs = []string{fmt.Sprintf("users without a public key: %v", noKeyUsers)}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
allowlistReload := func(args []string) error {
|
||||
if !(len(args) > 0 && (args[0] == "keep" || args[0] == "flush")) {
|
||||
return errors.New("must specify whether to keep or flush current entries")
|
||||
}
|
||||
if args[0] == "flush" {
|
||||
h.auth.allowlist.Clear()
|
||||
}
|
||||
return h.auth.ReloadAllowlist()
|
||||
}
|
||||
|
||||
allowlistReverify := func(room *chat.Room) []string {
|
||||
if !h.auth.AllowlistMode() {
|
||||
return []string{"allowlist is disabled, so nobody will be kicked"}
|
||||
}
|
||||
var kicked []string
|
||||
forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
|
||||
if h.auth.CheckPublicKey(pk) != nil && !user.IsOp { // we do this check here as well for ops without keys
|
||||
kicked = append(kicked, user.Name())
|
||||
user.Close()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if kicked != nil {
|
||||
room.Send(message.NewAnnounceMsg("Kicked during pubkey reverification: " + strings.Join(kicked, ", ")))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
allowlistStatus := func() (msgs []string) {
|
||||
if h.auth.AllowlistMode() {
|
||||
msgs = []string{"allowlist enabled"}
|
||||
} else {
|
||||
msgs = []string{"allowlist disabled"}
|
||||
}
|
||||
allowlistedUsers := []string{}
|
||||
allowlistedKeys := []string{}
|
||||
h.auth.allowlist.Each(func(key string, item set.Item) error {
|
||||
keyFP := item.Key()
|
||||
if forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
|
||||
if pk != nil && sshd.Fingerprint(pk) == keyFP {
|
||||
allowlistedUsers = append(allowlistedUsers, user.Name())
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
}) == nil {
|
||||
// if we land here, the key matches no users
|
||||
allowlistedKeys = append(allowlistedKeys, keyFP)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if len(allowlistedUsers) != 0 {
|
||||
msgs = append(msgs, "Connected users on the allowlist: "+strings.Join(allowlistedUsers, ", "))
|
||||
}
|
||||
if len(allowlistedKeys) != 0 {
|
||||
msgs = append(msgs, "Keys on the allowlist without connected user: "+strings.Join(allowlistedKeys, ", "))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/allowlist",
|
||||
PrefixHelp: "COMMAND [ARGS...]",
|
||||
Help: "Modify the allowlist or allowlist state. See /allowlist help for subcommands",
|
||||
Handler: func(room *chat.Room, msg message.CommandMsg) (err error) {
|
||||
if !room.IsOp(msg.From()) {
|
||||
return errors.New("must be op")
|
||||
}
|
||||
|
||||
args := msg.Args()
|
||||
if len(args) == 0 {
|
||||
args = []string{"help"}
|
||||
}
|
||||
|
||||
// send exactly one message to preserve order
|
||||
var replyLines []string
|
||||
|
||||
switch args[0] {
|
||||
case "help":
|
||||
replyLines = allowlistHelptext
|
||||
case "on":
|
||||
h.auth.SetAllowlistMode(true)
|
||||
case "off":
|
||||
h.auth.SetAllowlistMode(false)
|
||||
case "add":
|
||||
replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 0) })
|
||||
case "remove":
|
||||
replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 1) })
|
||||
case "import":
|
||||
replyLines, err = allowlistImport(args[1:])
|
||||
case "reload":
|
||||
err = allowlistReload(args[1:])
|
||||
case "reverify":
|
||||
replyLines = allowlistReverify(room)
|
||||
case "status":
|
||||
replyLines = allowlistStatus()
|
||||
default:
|
||||
err = errors.New("invalid subcommand: " + args[0])
|
||||
}
|
||||
if err == nil && replyLines != nil {
|
||||
room.Send(message.NewSystemMsg(strings.Join(replyLines, "\r\n"), msg.From()))
|
||||
}
|
||||
return
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
323
host_test.go
323
host_test.go
|
@ -2,8 +2,6 @@ package sshchat
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -25,9 +23,15 @@ func stripPrompt(s string) string {
|
|||
if endPos := strings.Index(s, "\x1b[2K "); endPos > 0 {
|
||||
return s[endPos+4:]
|
||||
}
|
||||
if endPos := strings.Index(s, "\x1b[K-> "); endPos > 0 {
|
||||
return s[endPos+6:]
|
||||
}
|
||||
if endPos := strings.Index(s, "] "); endPos > 0 {
|
||||
return s[endPos+2:]
|
||||
}
|
||||
if strings.HasPrefix(s, "-> ") {
|
||||
return s[3:]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -44,6 +48,14 @@ func TestStripPrompt(t *testing.T) {
|
|||
Input: "[foo] \x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[K * Guest1 joined. (Connected: 2)\r",
|
||||
Want: " * Guest1 joined. (Connected: 2)\r",
|
||||
},
|
||||
{
|
||||
Input: "[foo] \x1b[6D\x1b[K-> From your friendly system.\r",
|
||||
Want: "From your friendly system.\r",
|
||||
},
|
||||
{
|
||||
Input: "-> Err: must be op.\r",
|
||||
Want: "Err: must be op.\r",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
|
@ -77,20 +89,29 @@ func TestHostGetPrompt(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHostNameCollision(t *testing.T) {
|
||||
key, err := sshd.NewRandomSigner(512)
|
||||
func getHost(t *testing.T, auth *Auth) (*sshd.SSHListener, *Host) {
|
||||
key, err := sshd.NewRandomSigner(1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := sshd.MakeNoAuth()
|
||||
var config *ssh.ServerConfig
|
||||
if auth == nil {
|
||||
config = sshd.MakeNoAuth()
|
||||
} else {
|
||||
config = sshd.MakeAuth(auth)
|
||||
}
|
||||
config.AddHostKey(key)
|
||||
|
||||
s, err := sshd.ListenSSH("localhost:0", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return s, NewHost(s, auth)
|
||||
}
|
||||
|
||||
func TestHostNameCollision(t *testing.T) {
|
||||
s, host := getHost(t, nil)
|
||||
defer s.Close()
|
||||
host := NewHost(s, nil)
|
||||
|
||||
newUsers := make(chan *message.User)
|
||||
host.OnUserJoined = func(u *message.User) {
|
||||
|
@ -103,51 +124,23 @@ func TestHostNameCollision(t *testing.T) {
|
|||
// First client
|
||||
g.Go(func() error {
|
||||
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
scanner.Scan()
|
||||
actual := stripPrompt(scanner.Text())
|
||||
expected := " * foo joined. (Connected: 1)\r"
|
||||
if actual != expected {
|
||||
t.Errorf("Got %q; expected %q", actual, expected)
|
||||
// second client
|
||||
name := (<-newUsers).Name()
|
||||
if name != "Guest1" {
|
||||
t.Errorf("Second client did not get Guest1 name: %q", name)
|
||||
}
|
||||
|
||||
// wait for the second client
|
||||
<-newUsers
|
||||
|
||||
scanner.Scan()
|
||||
actual = scanner.Text()
|
||||
// This check has to happen second because prompt doesn't always
|
||||
// get set before the first message.
|
||||
if !strings.HasPrefix(actual, "[foo] ") {
|
||||
t.Errorf("First client failed to get 'foo' name: %q", actual)
|
||||
}
|
||||
actual = stripPrompt(actual)
|
||||
expected = " * Guest1 joined. (Connected: 2)\r"
|
||||
if actual != expected {
|
||||
t.Errorf("Got %q; expected %q", actual, expected)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
// Second client
|
||||
g.Go(func() error {
|
||||
// wait for the first client
|
||||
<-newUsers
|
||||
// first client
|
||||
name := (<-newUsers).Name()
|
||||
if name != "foo" {
|
||||
t.Errorf("First client did not get foo name: %q", name)
|
||||
}
|
||||
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Consume the initial buffer
|
||||
scanner.Scan()
|
||||
scanner.Scan()
|
||||
scanner.Scan()
|
||||
|
||||
actual := scanner.Text()
|
||||
if !strings.HasPrefix(actual, "[Guest1] ") {
|
||||
t.Errorf("Second client did not get Guest1 name: %q", actual)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
@ -157,62 +150,193 @@ func TestHostNameCollision(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHostWhitelist(t *testing.T) {
|
||||
key, err := sshd.NewRandomSigner(512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
func TestHostAllowlist(t *testing.T) {
|
||||
auth := NewAuth()
|
||||
config := sshd.MakeAuth(auth)
|
||||
config.AddHostKey(key)
|
||||
|
||||
s, err := sshd.ListenSSH("localhost:0", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s, host := getHost(t, auth)
|
||||
defer s.Close()
|
||||
host := NewHost(s, auth)
|
||||
go host.Serve()
|
||||
|
||||
target := s.Addr().String()
|
||||
|
||||
clientPrivateKey, err := sshd.NewRandomSigner(512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
clientKey := clientPrivateKey.PublicKey()
|
||||
loadCount := -1
|
||||
loader := func() ([]ssh.PublicKey, error) {
|
||||
loadCount++
|
||||
return [][]ssh.PublicKey{
|
||||
{},
|
||||
{clientKey},
|
||||
}[loadCount], nil
|
||||
}
|
||||
auth.LoadAllowlist(loader)
|
||||
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
||||
auth.Whitelist(clientpubkey, 0)
|
||||
|
||||
auth.SetAllowlistMode(true)
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||
if err == nil {
|
||||
t.Error("Failed to block unwhitelisted connection.")
|
||||
t.Error(err)
|
||||
}
|
||||
err = sshd.ConnectShellWithKey(target, "foo", clientPrivateKey, func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
auth.ReloadAllowlist()
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||
if err == nil {
|
||||
t.Error("Failed to block unallowlisted connection.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostKick(t *testing.T) {
|
||||
key, err := sshd.NewRandomSigner(512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
auth := NewAuth()
|
||||
config := sshd.MakeAuth(auth)
|
||||
config.AddHostKey(key)
|
||||
|
||||
s, err := sshd.ListenSSH("localhost:0", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
func TestHostAllowlistCommand(t *testing.T) {
|
||||
s, host := getHost(t, NewAuth())
|
||||
defer s.Close()
|
||||
go host.Serve()
|
||||
|
||||
users := make(chan *message.User)
|
||||
host.OnUserJoined = func(u *message.User) {
|
||||
users <- u
|
||||
}
|
||||
|
||||
kickSignal := make(chan struct{})
|
||||
clientKey, err := sshd.NewRandomSigner(1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
clientKeyFP := sshd.Fingerprint(clientKey.PublicKey())
|
||||
go sshd.ConnectShellWithKey(s.Addr().String(), "bar", clientKey, func(r io.Reader, w io.WriteCloser) error {
|
||||
<-kickSignal
|
||||
n, err := w.Write([]byte("alive and well"))
|
||||
if n != 0 || err == nil {
|
||||
t.Error("could write after being kicked")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
<-users
|
||||
<-users
|
||||
m, ok := host.MemberByID("foo")
|
||||
if !ok {
|
||||
t.Fatal("can't get member foo")
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Scan() // Joined
|
||||
scanner.Scan()
|
||||
|
||||
assertLineEq := func(expected ...string) {
|
||||
if !scanner.Scan() {
|
||||
t.Error("no line available")
|
||||
}
|
||||
actual := stripPrompt(scanner.Text())
|
||||
for _, exp := range expected {
|
||||
if exp == actual {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("expected %#v, got %q", expected, actual)
|
||||
}
|
||||
sendCmd := func(cmd string, formatting ...interface{}) {
|
||||
host.HandleMsg(message.ParseInput(fmt.Sprintf(cmd, formatting...), m.User))
|
||||
}
|
||||
|
||||
sendCmd("/allowlist")
|
||||
assertLineEq("Err: must be op\r")
|
||||
m.IsOp = true
|
||||
sendCmd("/allowlist")
|
||||
for _, expected := range [...]string{"Usage", "help", "on, off", "add, remove", "import", "reload", "reverify", "status"} {
|
||||
if !scanner.Scan() {
|
||||
t.Error("no line available")
|
||||
}
|
||||
if actual := stripPrompt(scanner.Text()); !strings.HasPrefix(actual, expected) {
|
||||
t.Errorf("Unexpected help message order: have %q, want prefix %q", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
sendCmd("/allowlist on")
|
||||
if !host.auth.AllowlistMode() {
|
||||
t.Error("allowlist not enabled after /allowlist on")
|
||||
}
|
||||
sendCmd("/allowlist off")
|
||||
if host.auth.AllowlistMode() {
|
||||
t.Error("allowlist not disabled after /allowlist off")
|
||||
}
|
||||
|
||||
testKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPUiNw0nQku4pcUCbZcJlIEAIf5bXJYTy/DKI1vh5b+P"
|
||||
testKeyFP := "SHA256:GJNSl9NUcOS2pZYALn0C5Qgfh5deT+R+FfqNIUvpM9s="
|
||||
|
||||
if host.auth.allowlist.Len() != 0 {
|
||||
t.Error("allowlist not empty before adding anyone")
|
||||
}
|
||||
sendCmd("/allowlist add ssh-invalid blah ssh-rsa wrongAsWell invalid foo bar %s", testKey)
|
||||
assertLineEq("users without a public key: [foo]\r")
|
||||
assertLineEq("invalid users: [invalid]\r")
|
||||
assertLineEq("invalid keys: [ssh-invalid blah ssh-rsa wrongAsWell]\r")
|
||||
if !host.auth.allowlist.In(testKeyFP) || !host.auth.allowlist.In(clientKeyFP) {
|
||||
t.Error("failed to add keys to allowlist")
|
||||
}
|
||||
sendCmd("/allowlist remove invalid bar")
|
||||
assertLineEq("invalid users: [invalid]\r")
|
||||
if host.auth.allowlist.In(clientKeyFP) {
|
||||
t.Error("failed to remove key from allowlist")
|
||||
}
|
||||
if !host.auth.allowlist.In(testKeyFP) {
|
||||
t.Error("removed wrong key")
|
||||
}
|
||||
|
||||
sendCmd("/allowlist import 5h")
|
||||
if host.auth.allowlist.In(clientKeyFP) {
|
||||
t.Error("imporrted key not seen long enough")
|
||||
}
|
||||
sendCmd("/allowlist import")
|
||||
assertLineEq("users without a public key: [foo]\r")
|
||||
if !host.auth.allowlist.In(clientKeyFP) {
|
||||
t.Error("failed to import key")
|
||||
}
|
||||
|
||||
sendCmd("/allowlist reload keep")
|
||||
if !host.auth.allowlist.In(testKeyFP) {
|
||||
t.Error("cleared allowlist to be kept")
|
||||
}
|
||||
sendCmd("/allowlist reload flush")
|
||||
if host.auth.allowlist.In(testKeyFP) {
|
||||
t.Error("kept allowlist to be cleared")
|
||||
}
|
||||
sendCmd("/allowlist reload thisIsWrong")
|
||||
assertLineEq("Err: must specify whether to keep or flush current entries\r")
|
||||
sendCmd("/allowlist reload")
|
||||
assertLineEq("Err: must specify whether to keep or flush current entries\r")
|
||||
|
||||
sendCmd("/allowlist reverify")
|
||||
assertLineEq("allowlist is disabled, so nobody will be kicked\r")
|
||||
sendCmd("/allowlist on")
|
||||
sendCmd("/allowlist reverify")
|
||||
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
|
||||
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
|
||||
kickSignal <- struct{}{}
|
||||
|
||||
sendCmd("/allowlist add " + testKey)
|
||||
sendCmd("/allowlist status")
|
||||
assertLineEq("allowlist enabled\r")
|
||||
assertLineEq(fmt.Sprintf("Keys on the allowlist without connected user: %s\r", testKeyFP))
|
||||
|
||||
sendCmd("/allowlist invalidSubcommand")
|
||||
assertLineEq("Err: invalid subcommand: invalidSubcommand\r")
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostKick(t *testing.T) {
|
||||
s, host := getHost(t, NewAuth())
|
||||
defer s.Close()
|
||||
addr := s.Addr().String()
|
||||
host := NewHost(s, nil)
|
||||
go host.Serve()
|
||||
|
||||
g := errgroup.Group{}
|
||||
|
@ -221,7 +345,7 @@ func TestHostKick(t *testing.T) {
|
|||
|
||||
g.Go(func() error {
|
||||
// First client
|
||||
return sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
|
@ -258,7 +382,7 @@ func TestHostKick(t *testing.T) {
|
|||
|
||||
g.Go(func() error {
|
||||
// Second client
|
||||
return sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
|
||||
return sshd.ConnectShell(s.Addr().String(), "bar", func(r io.Reader, w io.WriteCloser) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
<-connected
|
||||
scanner.Scan()
|
||||
|
@ -296,12 +420,9 @@ func TestTimestampEnvConfig(t *testing.T) {
|
|||
{"datetime +8h", strptr("2006-01-02 15:04:05")},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
u, err := connectUserWithConfig("dingus", map[string]string{
|
||||
u := connectUserWithConfig(t, "dingus", map[string]string{
|
||||
"SSHCHAT_TIMESTAMP": tc.input,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
userConfig := u.Config()
|
||||
if userConfig.Timeformat != nil && tc.timeformat != nil {
|
||||
if *userConfig.Timeformat != *tc.timeformat {
|
||||
|
@ -315,20 +436,9 @@ func strptr(s string) *string {
|
|||
return &s
|
||||
}
|
||||
|
||||
func connectUserWithConfig(name string, envConfig map[string]string) (*message.User, error) {
|
||||
key, err := sshd.NewRandomSigner(512)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create signer: %w", err)
|
||||
}
|
||||
config := sshd.MakeNoAuth()
|
||||
config.AddHostKey(key)
|
||||
|
||||
s, err := sshd.ListenSSH("localhost:0", config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create a test server: %w", err)
|
||||
}
|
||||
func connectUserWithConfig(t *testing.T, name string, envConfig map[string]string) *message.User {
|
||||
s, host := getHost(t, nil)
|
||||
defer s.Close()
|
||||
host := NewHost(s, nil)
|
||||
|
||||
newUsers := make(chan *message.User)
|
||||
host.OnUserJoined = func(u *message.User) {
|
||||
|
@ -339,13 +449,13 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U
|
|||
clientConfig := sshd.NewClientConfig(name)
|
||||
conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect to test ssh-chat server: %w", err)
|
||||
t.Fatal("unable to connect to test ssh-chat server:", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
session, err := conn.NewSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open session: %w", err)
|
||||
t.Fatal("unable to open session:", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
|
@ -355,13 +465,14 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U
|
|||
|
||||
err = session.Shell()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open shell: %w", err)
|
||||
t.Fatal("unable to open shell:", err)
|
||||
}
|
||||
|
||||
for u := range newUsers {
|
||||
if u.Name() == name {
|
||||
return u, nil
|
||||
return u
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("user %s not found in the host", name)
|
||||
t.Fatalf("user %s not found in the host", name)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -30,9 +30,24 @@ func NewClientConfig(name string) *ssh.ClientConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func NewClientConfigWithKey(name string, key ssh.Signer) *ssh.ClientConfig {
|
||||
return &ssh.ClientConfig{
|
||||
User: name,
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectShell makes a barebones SSH client session, used for testing.
|
||||
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||
config := NewClientConfig(name)
|
||||
return connectShell(host, NewClientConfig(name), handler)
|
||||
}
|
||||
|
||||
func ConnectShellWithKey(host string, name string, key ssh.Signer, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||
return connectShell(host, NewClientConfigWithKey(name, key), handler)
|
||||
}
|
||||
|
||||
func connectShell(host string, config *ssh.ClientConfig, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||
conn, err := ssh.Dial("tcp", host, config)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -25,7 +25,7 @@ func TestServerInit(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServeTerminals(t *testing.T) {
|
||||
signer, err := NewRandomSigner(512)
|
||||
signer, err := NewRandomSigner(1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue