diff --git a/Makefile b/Makefile index 48efb15..a2b03db 100644 --- a/Makefile +++ b/Makefile @@ -28,5 +28,5 @@ debug: $(BINARY) $(KEY) ./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv test: - go test . - golint + go test ./... + golint ./... diff --git a/README.md b/README.md index 5cf42ab..92895ce 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,6 @@ The server's RSA key fingerprint is `e5:d5:d1:75:90:38:42:f6:c7:03:d7:d0:56:7d:6 ## Compiling / Developing -**If you're going to be diving into the code, please use the [refactor branch](https://github.com/shazow/ssh-chat/tree/refactor) or see [issue #87](https://github.com/shazow/ssh-chat/pull/87).** It's not quite at feature parity yet, but the code is way nicer. The master branch is what's running on chat.shazow.net, but that will change soon. - You can compile ssh-chat by using `make build`. The resulting binary is portable and can be run on any system with a similar OS and CPU arch. Go 1.3 or higher is required to compile. @@ -40,7 +38,7 @@ Usage: Application Options: -v, --verbose Show verbose logging. -i, --identity= Private key to identify server with. (~/.ssh/id_rsa) - --bind= Host and port to listen on. (0.0.0.0:22) + --bind= Host and port to listen on. (0.0.0.0:2022) --admin= Fingerprint of pubkey to mark as admin. --whitelist= Optional file of pubkey fingerprints that are allowed to connect --motd= Message of the Day file (optional) @@ -54,7 +52,7 @@ After doing `go get github.com/shazow/ssh-chat` on this repo, you should be able to run a command like: ``` -$ ssh-chat --verbose --bind ":2022" --identity ~/.ssh/id_dsa +$ ssh-chat --verbose --bind ":22" --identity ~/.ssh/id_dsa ``` To bind on port 22, you'll need to make sure it's free (move any other ssh diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..8c86b26 --- /dev/null +++ b/auth.go @@ -0,0 +1,150 @@ +package main + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/shazow/ssh-chat/sshd" + "golang.org/x/crypto/ssh" +) + +// The error returned a key is checked that is not whitelisted, with whitelisting required. +var ErrNotWhitelisted = errors.New("not whitelisted") + +// The error returned a key is checked that is banned. +var ErrBanned = errors.New("banned") + +// NewAuthKey returns string from an ssh.PublicKey. +func NewAuthKey(key ssh.PublicKey) string { + if key == nil { + return "" + } + // FIXME: Is there a way to index pubkeys without marshal'ing them into strings? + return sshd.Fingerprint(key) +} + +// NewAuthAddr returns a string from a net.Addr +func NewAuthAddr(addr net.Addr) string { + if addr == nil { + return "" + } + host, _, _ := net.SplitHostPort(addr.String()) + return host +} + +// Auth stores fingerprint lookups +// TODO: Add timed auth by using a time.Time instead of struct{} for values. +type Auth struct { + sync.RWMutex + bannedAddr *Set + banned *Set + whitelist *Set + ops *Set +} + +// NewAuth creates a new default Auth. +func NewAuth() *Auth { + return &Auth{ + bannedAddr: NewSet(), + banned: NewSet(), + whitelist: NewSet(), + ops: NewSet(), + } +} + +// AllowAnonymous determines if anonymous users are permitted. +func (a Auth) AllowAnonymous() bool { + return a.whitelist.Len() == 0 +} + +// Check determines if a pubkey fingerprint is permitted. +func (a *Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) { + authkey := NewAuthKey(key) + + if a.whitelist.Len() != 0 { + // Only check whitelist if there is something in it, otherwise it's disabled. + whitelisted := a.whitelist.In(authkey) + if !whitelisted { + return false, ErrNotWhitelisted + } + return true, nil + } + + banned := a.banned.In(authkey) + if !banned { + banned = a.bannedAddr.In(NewAuthAddr(addr)) + } + if banned { + return false, ErrBanned + } + + return true, nil +} + +// Op sets a public key as a known operator. +func (a *Auth) Op(key ssh.PublicKey, d time.Duration) { + if key == nil { + return + } + authkey := NewAuthKey(key) + if d != 0 { + a.ops.AddExpiring(authkey, d) + } else { + a.ops.Add(authkey) + } + logger.Debugf("Added to ops: %s (for %s)", authkey, d) +} + +// IsOp checks if a public key is an op. +func (a *Auth) IsOp(key ssh.PublicKey) bool { + if key == nil { + return false + } + authkey := NewAuthKey(key) + return a.ops.In(authkey) +} + +// Whitelist will set a public key as a whitelisted user. +func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) { + if key == nil { + return + } + authkey := NewAuthKey(key) + if d != 0 { + a.whitelist.AddExpiring(authkey, d) + } else { + a.whitelist.Add(authkey) + } + logger.Debugf("Added to whitelist: %s (for %s)", authkey, d) +} + +// Ban will set a public key as banned. +func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) { + if key == nil { + return + } + a.BanFingerprint(NewAuthKey(key), d) +} + +// BanFingerprint will set a public key fingerprint as banned. +func (a *Auth) BanFingerprint(authkey string, d time.Duration) { + if d != 0 { + a.banned.AddExpiring(authkey, d) + } else { + a.banned.Add(authkey) + } + logger.Debugf("Added to banned: %s (for %s)", authkey, d) +} + +// Ban will set an IP address as banned. +func (a *Auth) BanAddr(addr net.Addr, d time.Duration) { + key := NewAuthAddr(addr) + if d != 0 { + a.bannedAddr.AddExpiring(key, d) + } else { + a.bannedAddr.Add(key) + } + logger.Debugf("Added to bannedAddr: %s (for %s)", key, d) +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..981a1d6 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,62 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "testing" + + "golang.org/x/crypto/ssh" +) + +func NewRandomPublicKey(bits int) (ssh.PublicKey, error) { + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, err + } + + return ssh.NewPublicKey(key.Public()) +} + +func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) { + return ssh.ParsePublicKey(key.Marshal()) +} + +func TestAuthWhitelist(t *testing.T) { + key, err := NewRandomPublicKey(512) + if err != nil { + t.Fatal(err) + } + + auth := NewAuth() + ok, err := auth.Check(nil, key) + if !ok || err != nil { + t.Error("Failed to permit in default state:", err) + } + + auth.Whitelist(key, 0) + + keyClone, err := ClonePublicKey(key) + if err != nil { + t.Fatal(err) + } + + if string(keyClone.Marshal()) != string(key.Marshal()) { + t.Error("Clone key does not match.") + } + + ok, err = auth.Check(nil, keyClone) + if !ok || err != nil { + t.Error("Failed to permit whitelisted:", err) + } + + key2, err := NewRandomPublicKey(512) + if err != nil { + t.Fatal(err) + } + + ok, err = auth.Check(nil, key2) + if ok || err == nil { + t.Error("Failed to restrict not whitelisted:", err) + } + +} diff --git a/chat/command.go b/chat/command.go new file mode 100644 index 0000000..eb085c9 --- /dev/null +++ b/chat/command.go @@ -0,0 +1,241 @@ +package chat + +// FIXME: Would be sweet if we could piggyback on a cli parser or something. + +import ( + "errors" + "fmt" + "strings" +) + +// The error returned when an invalid command is issued. +var ErrInvalidCommand = errors.New("invalid command") + +// The error returned when a command is given without an owner. +var ErrNoOwner = errors.New("command without owner") + +// The error returned when a command is performed without the necessary number +// of arguments. +var ErrMissingArg = errors.New("missing argument") + +// The error returned when a command is added without a prefix. +var ErrMissingPrefix = errors.New("command missing prefix") + +// Command is a definition of a handler for a command. +type Command struct { + // The command's key, such as /foo + Prefix string + // Extra help regarding arguments + PrefixHelp string + // If omitted, command is hidden from /help + Help string + Handler func(*Room, CommandMsg) error + // Command requires Op permissions + Op bool +} + +// Commands is a registry of available commands. +type Commands map[string]*Command + +// Add will register a command. If help string is empty, it will be hidden from +// Help(). +func (c Commands) Add(cmd Command) error { + if cmd.Prefix == "" { + return ErrMissingPrefix + } + + c[cmd.Prefix] = &cmd + return nil +} + +// Alias will add another command for the same handler, won't get added to help. +func (c Commands) Alias(command string, alias string) error { + cmd, ok := c[command] + if !ok { + return ErrInvalidCommand + } + c[alias] = cmd + return nil +} + +// Run executes a command message. +func (c Commands) Run(room *Room, msg CommandMsg) error { + if msg.From == nil { + return ErrNoOwner + } + + cmd, ok := c[msg.Command()] + if !ok { + return ErrInvalidCommand + } + + return cmd.Handler(room, msg) +} + +// Help will return collated help text as one string. +func (c Commands) Help(showOp bool) string { + // Filter by op + op := []*Command{} + normal := []*Command{} + for _, cmd := range c { + if cmd.Op { + op = append(op, cmd) + } else { + normal = append(normal, cmd) + } + } + help := "Available commands:" + Newline + NewCommandsHelp(normal).String() + if showOp { + help += Newline + "-> Operator commands:" + Newline + NewCommandsHelp(op).String() + } + return help +} + +var defaultCommands *Commands + +func init() { + defaultCommands = &Commands{} + InitCommands(defaultCommands) +} + +// InitCommands injects default commands into a Commands registry. +func InitCommands(c *Commands) { + c.Add(Command{ + Prefix: "/help", + Handler: func(room *Room, msg CommandMsg) error { + op := room.IsOp(msg.From()) + room.Send(NewSystemMsg(room.commands.Help(op), msg.From())) + return nil + }, + }) + + c.Add(Command{ + Prefix: "/me", + Handler: func(room *Room, msg CommandMsg) error { + me := strings.TrimLeft(msg.body, "/me") + if me == "" { + me = "is at a loss for words." + } else { + me = me[1:] + } + + room.Send(NewEmoteMsg(me, msg.From())) + return nil + }, + }) + + c.Add(Command{ + Prefix: "/exit", + Help: "Exit the chat.", + Handler: func(room *Room, msg CommandMsg) error { + msg.From().Close() + return nil + }, + }) + c.Alias("/exit", "/quit") + + c.Add(Command{ + Prefix: "/nick", + PrefixHelp: "NAME", + Help: "Rename yourself.", + Handler: func(room *Room, msg CommandMsg) error { + args := msg.Args() + if len(args) != 1 { + return ErrMissingArg + } + u := msg.From() + + member, ok := room.MemberById(u.Id()) + if !ok { + return errors.New("failed to find member") + } + + oldId := member.Id() + member.SetId(SanitizeName(args[0])) + err := room.Rename(oldId, member) + if err != nil { + member.SetId(oldId) + return err + } + return nil + }, + }) + + c.Add(Command{ + Prefix: "/names", + Help: "List users who are connected.", + Handler: func(room *Room, msg CommandMsg) error { + // TODO: colorize + names := room.NamesPrefix("") + body := fmt.Sprintf("%d connected: %s", len(names), strings.Join(names, ", ")) + room.Send(NewSystemMsg(body, msg.From())) + return nil + }, + }) + c.Alias("/names", "/list") + + c.Add(Command{ + Prefix: "/theme", + PrefixHelp: "[mono|colors]", + Help: "Set your color theme.", + Handler: func(room *Room, msg CommandMsg) error { + user := msg.From() + args := msg.Args() + if len(args) == 0 { + theme := "plain" + if user.Config.Theme != nil { + theme = user.Config.Theme.Id() + } + body := fmt.Sprintf("Current theme: %s", theme) + room.Send(NewSystemMsg(body, user)) + return nil + } + + id := args[0] + for _, t := range Themes { + if t.Id() == id { + user.Config.Theme = &t + body := fmt.Sprintf("Set theme: %s", id) + room.Send(NewSystemMsg(body, user)) + return nil + } + } + return errors.New("theme not found") + }, + }) + + c.Add(Command{ + Prefix: "/quiet", + Help: "Silence room announcements.", + Handler: func(room *Room, msg CommandMsg) error { + u := msg.From() + u.ToggleQuietMode() + + var body string + if u.Config.Quiet { + body = "Quiet mode is toggled ON" + } else { + body = "Quiet mode is toggled OFF" + } + room.Send(NewSystemMsg(body, u)) + return nil + }, + }) + + c.Add(Command{ + Prefix: "/slap", + PrefixHelp: "NAME", + Handler: func(room *Room, msg CommandMsg) error { + var me string + args := msg.Args() + if len(args) == 0 { + me = "slaps themselves around a bit with a large trout." + } else { + me = fmt.Sprintf("slaps %s around a bit with a large trout.", strings.Join(args, " ")) + } + + room.Send(NewEmoteMsg(me, msg.From())) + return nil + }, + }) +} diff --git a/chat/doc.go b/chat/doc.go new file mode 100644 index 0000000..22760e7 --- /dev/null +++ b/chat/doc.go @@ -0,0 +1,13 @@ +/* +`chat` package is a server-agnostic implementation of a chat interface, built +with the intention of using with the intention of using as the backend for +ssh-chat. + +This package should not know anything about sockets. It should expose io-style +interfaces and rooms for communicating with any method of transnport. + +TODO: Add usage examples here. + +*/ + +package chat diff --git a/chat/help.go b/chat/help.go new file mode 100644 index 0000000..0ab62c6 --- /dev/null +++ b/chat/help.go @@ -0,0 +1,58 @@ +package chat + +import ( + "fmt" + "sort" + "strings" +) + +type helpItem struct { + Prefix string + Text string +} + +type help struct { + items []helpItem + prefixWidth int +} + +// NewCommandsHelp creates a help container from a commands container. +func NewCommandsHelp(c []*Command) fmt.Stringer { + lookup := map[string]struct{}{} + h := help{ + items: []helpItem{}, + } + for _, cmd := range c { + if cmd.Help == "" { + // Skip hidden commands. + continue + } + _, exists := lookup[cmd.Prefix] + if exists { + // Duplicate (alias) + continue + } + lookup[cmd.Prefix] = struct{}{} + prefix := fmt.Sprintf("%s %s", cmd.Prefix, cmd.PrefixHelp) + h.add(helpItem{prefix, cmd.Help}) + } + return &h +} + +func (h *help) add(item helpItem) { + h.items = append(h.items, item) + if len(item.Prefix) > h.prefixWidth { + h.prefixWidth = len(item.Prefix) + } +} + +func (h help) String() string { + r := []string{} + format := fmt.Sprintf("%%-%ds - %%s", h.prefixWidth) + for _, item := range h.items { + r = append(r, fmt.Sprintf(format, item.Prefix, item.Text)) + } + + sort.Strings(r) + return strings.Join(r, Newline) +} diff --git a/history.go b/chat/history.go similarity index 54% rename from history.go rename to chat/history.go index 74ef513..6b999ca 100644 --- a/history.go +++ b/chat/history.go @@ -1,27 +1,33 @@ -// TODO: Split this out into its own module, it's kinda neat. -package main +package chat -import "sync" +import ( + "fmt" + "io" + "sync" +) + +const timestampFmt = "2006-01-02 15:04:05" // History contains the history entries type History struct { - entries []string + sync.RWMutex + entries []Message head int size int - lock sync.Mutex + out io.Writer } // NewHistory constructs a new history of the given size func NewHistory(size int) *History { return &History{ - entries: make([]string, size), + entries: make([]Message, size), } } // Add adds the given entry to the entries in the history -func (h *History) Add(entry string) { - h.lock.Lock() - defer h.lock.Unlock() +func (h *History) Add(entry Message) { + h.Lock() + defer h.Unlock() max := cap(h.entries) h.head = (h.head + 1) % max @@ -29,6 +35,10 @@ func (h *History) Add(entry string) { if h.size < max { h.size++ } + + if h.out != nil { + fmt.Fprintf(h.out, "[%s] %s\n", entry.Timestamp().UTC().Format(timestampFmt), entry.String()) + } } // Len returns the number of entries in the history @@ -37,16 +47,16 @@ func (h *History) Len() int { } // Get the entry with the given number -func (h *History) Get(num int) []string { - h.lock.Lock() - defer h.lock.Unlock() +func (h *History) Get(num int) []Message { + h.RLock() + defer h.RUnlock() max := cap(h.entries) if num > h.size { num = h.size } - r := make([]string, num) + r := make([]Message, num) for i := 0; i < num; i++ { idx := (h.head - i) % max if idx < 0 { @@ -57,3 +67,10 @@ func (h *History) Get(num int) []string { return r } + +// SetOutput sets the output for logging added messages +func (h *History) SetOutput(w io.Writer) { + h.Lock() + h.out = w + h.Unlock() +} diff --git a/chat/history_test.go b/chat/history_test.go new file mode 100644 index 0000000..de767ec --- /dev/null +++ b/chat/history_test.go @@ -0,0 +1,62 @@ +package chat + +import "testing" + +func msgEqual(a []Message, b []Message) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].String() != b[i].String() { + return false + } + } + return true +} + +func TestHistory(t *testing.T) { + var r, expected []Message + var size int + + h := NewHistory(5) + + r = h.Get(10) + expected = []Message{} + if !msgEqual(r, expected) { + t.Errorf("Got: %v, Expected: %v", r, expected) + } + + h.Add(NewMsg("1")) + + if size = h.Len(); size != 1 { + t.Errorf("Wrong size: %v", size) + } + + r = h.Get(1) + expected = []Message{NewMsg("1")} + if !msgEqual(r, expected) { + t.Errorf("Got: %v, Expected: %v", r, expected) + } + + h.Add(NewMsg("2")) + h.Add(NewMsg("3")) + h.Add(NewMsg("4")) + h.Add(NewMsg("5")) + h.Add(NewMsg("6")) + + if size = h.Len(); size != 5 { + t.Errorf("Wrong size: %v", size) + } + + r = h.Get(2) + expected = []Message{NewMsg("5"), NewMsg("6")} + if !msgEqual(r, expected) { + t.Errorf("Got: %v, Expected: %v", r, expected) + } + + r = h.Get(10) + expected = []Message{NewMsg("2"), NewMsg("3"), NewMsg("4"), NewMsg("5"), NewMsg("6")} + if !msgEqual(r, expected) { + t.Errorf("Got: %v, Expected: %v", r, expected) + } +} diff --git a/chat/logger.go b/chat/logger.go new file mode 100644 index 0000000..93b2761 --- /dev/null +++ b/chat/logger.go @@ -0,0 +1,22 @@ +package chat + +import "io" +import stdlog "log" + +var logger *stdlog.Logger + +func SetLogger(w io.Writer) { + flags := stdlog.Flags() + prefix := "[chat] " + logger = stdlog.New(w, prefix, flags) +} + +type nullWriter struct{} + +func (nullWriter) Write(data []byte) (int, error) { + return len(data), nil +} + +func init() { + SetLogger(nullWriter{}) +} diff --git a/chat/message.go b/chat/message.go new file mode 100644 index 0000000..9a0b30a --- /dev/null +++ b/chat/message.go @@ -0,0 +1,257 @@ +package chat + +import ( + "fmt" + "strings" + "time" +) + +// Message is an interface for messages. +type Message interface { + Render(*Theme) string + String() string + Command() string + Timestamp() time.Time +} + +type MessageTo interface { + Message + To() *User +} + +type MessageFrom interface { + Message + From() *User +} + +func ParseInput(body string, from *User) Message { + m := NewPublicMsg(body, from) + cmd, isCmd := m.ParseCommand() + if isCmd { + return cmd + } + return m +} + +// Msg is a base type for other message types. +type Msg struct { + body string + timestamp time.Time + // TODO: themeCache *map[*Theme]string +} + +func NewMsg(body string) *Msg { + return &Msg{ + body: body, + timestamp: time.Now(), + } +} + +// Render message based on a theme. +func (m *Msg) Render(t *Theme) string { + // TODO: Render based on theme + // TODO: Cache based on theme + return m.String() +} + +func (m *Msg) String() string { + return m.body +} + +func (m *Msg) Command() string { + return "" +} + +func (m *Msg) Timestamp() time.Time { + return m.timestamp +} + +// PublicMsg is any message from a user sent to the room. +type PublicMsg struct { + Msg + from *User +} + +func NewPublicMsg(body string, from *User) *PublicMsg { + return &PublicMsg{ + Msg: Msg{ + body: body, + timestamp: time.Now(), + }, + from: from, + } +} + +func (m *PublicMsg) From() *User { + return m.from +} + +func (m *PublicMsg) ParseCommand() (*CommandMsg, bool) { + // Check if the message is a command + if !strings.HasPrefix(m.body, "/") { + return nil, false + } + + // Parse + // TODO: Handle quoted fields properly + fields := strings.Fields(m.body) + command, args := fields[0], fields[1:] + msg := CommandMsg{ + PublicMsg: m, + command: command, + args: args, + } + return &msg, true +} + +func (m *PublicMsg) Render(t *Theme) string { + if t == nil { + return m.String() + } + + return fmt.Sprintf("%s: %s", t.ColorName(m.from), m.body) +} + +func (m *PublicMsg) RenderFor(cfg UserConfig) string { + if cfg.Highlight == nil || cfg.Theme == nil { + return m.Render(cfg.Theme) + } + + if !cfg.Highlight.MatchString(m.body) { + return m.Render(cfg.Theme) + } + + body := cfg.Highlight.ReplaceAllString(m.body, cfg.Theme.Highlight("${1}")) + if cfg.Bell { + body += Bel + } + return fmt.Sprintf("%s: %s", cfg.Theme.ColorName(m.from), body) +} + +func (m *PublicMsg) String() string { + return fmt.Sprintf("%s: %s", m.from.Name(), m.body) +} + +// EmoteMsg is a /me message sent to the room. It specifically does not +// extend PublicMsg because it doesn't implement MessageFrom to allow the +// sender to see the emote. +type EmoteMsg struct { + Msg + from *User +} + +func NewEmoteMsg(body string, from *User) *EmoteMsg { + return &EmoteMsg{ + Msg: Msg{ + body: body, + timestamp: time.Now(), + }, + from: from, + } +} + +func (m *EmoteMsg) Render(t *Theme) string { + return fmt.Sprintf("** %s %s", m.from.Name(), m.body) +} + +func (m *EmoteMsg) String() string { + return m.Render(nil) +} + +// PrivateMsg is a message sent to another user, not shown to anyone else. +type PrivateMsg struct { + PublicMsg + to *User +} + +func NewPrivateMsg(body string, from *User, to *User) *PrivateMsg { + return &PrivateMsg{ + PublicMsg: *NewPublicMsg(body, from), + to: to, + } +} + +func (m *PrivateMsg) To() *User { + return m.to +} + +func (m *PrivateMsg) Render(t *Theme) string { + return fmt.Sprintf("[PM from %s] %s", m.from.Name(), m.body) +} + +func (m *PrivateMsg) String() string { + return m.Render(nil) +} + +// SystemMsg is a response sent from the server directly to a user, not shown +// to anyone else. Usually in response to something, like /help. +type SystemMsg struct { + Msg + to *User +} + +func NewSystemMsg(body string, to *User) *SystemMsg { + return &SystemMsg{ + Msg: Msg{ + body: body, + timestamp: time.Now(), + }, + to: to, + } +} + +func (m *SystemMsg) Render(t *Theme) string { + if t == nil { + return m.String() + } + return t.ColorSys(m.String()) +} + +func (m *SystemMsg) String() string { + return fmt.Sprintf("-> %s", m.body) +} + +func (m *SystemMsg) To() *User { + return m.to +} + +// AnnounceMsg is a message sent from the server to everyone, like a join or +// leave event. +type AnnounceMsg struct { + Msg +} + +func NewAnnounceMsg(body string) *AnnounceMsg { + return &AnnounceMsg{ + Msg: Msg{ + body: body, + timestamp: time.Now(), + }, + } +} + +func (m *AnnounceMsg) Render(t *Theme) string { + if t == nil { + return m.String() + } + return t.ColorSys(m.String()) +} + +func (m *AnnounceMsg) String() string { + return fmt.Sprintf(" * %s", m.body) +} + +type CommandMsg struct { + *PublicMsg + command string + args []string + room *Room +} + +func (m *CommandMsg) Command() string { + return m.command +} + +func (m *CommandMsg) Args() []string { + return m.args +} diff --git a/chat/message_test.go b/chat/message_test.go new file mode 100644 index 0000000..bafe014 --- /dev/null +++ b/chat/message_test.go @@ -0,0 +1,52 @@ +package chat + +import "testing" + +type testId string + +func (i testId) Id() string { + return string(i) +} +func (i testId) SetId(s string) { + // no-op +} +func (i testId) Name() string { + return i.Id() +} + +func TestMessage(t *testing.T) { + var expected, actual string + + expected = " * foo" + actual = NewAnnounceMsg("foo").String() + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + u := NewUser(testId("foo")) + expected = "foo: hello" + actual = NewPublicMsg("hello", u).String() + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + expected = "** foo sighs." + actual = NewEmoteMsg("sighs.", u).String() + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + expected = "-> hello" + actual = NewSystemMsg("hello", u).String() + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + expected = "[PM from foo] hello" + actual = NewPrivateMsg("hello", u, u).String() + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} + +// TODO: Add theme rendering tests diff --git a/chat/room.go b/chat/room.go new file mode 100644 index 0000000..7d1b3af --- /dev/null +++ b/chat/room.go @@ -0,0 +1,222 @@ +package chat + +import ( + "errors" + "fmt" + "io" + "sync" +) + +const historyLen = 20 +const roomBuffer = 10 + +// The error returned when a message is sent to a room that is already +// closed. +var ErrRoomClosed = errors.New("room closed") + +// The error returned when a user attempts to join with an invalid name, such +// as empty string. +var ErrInvalidName = errors.New("invalid name") + +// Member is a User with per-Room metadata attached to it. +type Member struct { + *User + Op bool +} + +// Room definition, also a Set of User Items +type Room struct { + topic string + history *History + members *Set + broadcast chan Message + commands Commands + closed bool + closeOnce sync.Once +} + +// NewRoom creates a new room. +func NewRoom() *Room { + broadcast := make(chan Message, roomBuffer) + + return &Room{ + broadcast: broadcast, + history: NewHistory(historyLen), + members: NewSet(), + commands: *defaultCommands, + } +} + +// SetCommands sets the room's command handlers. +func (r *Room) SetCommands(commands Commands) { + r.commands = commands +} + +// Close the room and all the users it contains. +func (r *Room) Close() { + r.closeOnce.Do(func() { + r.closed = true + r.members.Each(func(m Identifier) { + m.(*Member).Close() + }) + r.members.Clear() + close(r.broadcast) + }) +} + +// SetLogging sets logging output for the room's history +func (r *Room) SetLogging(out io.Writer) { + r.history.SetOutput(out) +} + +// HandleMsg reacts to a message, will block until done. +func (r *Room) HandleMsg(m Message) { + switch m := m.(type) { + case *CommandMsg: + cmd := *m + err := r.commands.Run(r, cmd) + if err != nil { + m := NewSystemMsg(fmt.Sprintf("Err: %s", err), cmd.from) + go r.HandleMsg(m) + } + case MessageTo: + user := m.To() + user.Send(m) + default: + fromMsg, skip := m.(MessageFrom) + var skipUser *User + if skip { + skipUser = fromMsg.From() + } + + r.history.Add(m) + r.members.Each(func(u Identifier) { + user := u.(*Member).User + if skip && skipUser == user { + // Skip + return + } + if _, ok := m.(*AnnounceMsg); ok { + if user.Config.Quiet { + // Skip + return + } + } + user.Send(m) + }) + } +} + +// Serve will consume the broadcast room and handle the messages, should be +// run in a goroutine. +func (r *Room) Serve() { + for m := range r.broadcast { + go r.HandleMsg(m) + } +} + +// Send message, buffered by a chan. +func (r *Room) Send(m Message) { + r.broadcast <- m +} + +// History feeds the room's recent message history to the user's handler. +func (r *Room) History(u *User) { + for _, m := range r.history.Get(historyLen) { + u.Send(m) + } +} + +// Join the room as a user, will announce. +func (r *Room) Join(u *User) (*Member, error) { + if r.closed { + return nil, ErrRoomClosed + } + if u.Id() == "" { + return nil, ErrInvalidName + } + member := Member{u, false} + err := r.members.Add(&member) + if err != nil { + return nil, err + } + r.History(u) + s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len()) + r.Send(NewAnnounceMsg(s)) + return &member, nil +} + +// Leave the room as a user, will announce. Mostly used during setup. +func (r *Room) Leave(u *User) error { + err := r.members.Remove(u) + if err != nil { + return err + } + s := fmt.Sprintf("%s left.", u.Name()) + r.Send(NewAnnounceMsg(s)) + return nil +} + +// Rename member with a new identity. This will not call rename on the member. +func (r *Room) Rename(oldId string, identity Identifier) error { + if identity.Id() == "" { + return ErrInvalidName + } + err := r.members.Replace(oldId, identity) + if err != nil { + return err + } + + s := fmt.Sprintf("%s is now known as %s.", oldId, identity.Id()) + r.Send(NewAnnounceMsg(s)) + return nil +} + +// Member returns a corresponding Member object to a User if the Member is +// present in this room. +func (r *Room) Member(u *User) (*Member, bool) { + m, ok := r.MemberById(u.Id()) + if !ok { + return nil, false + } + // Check that it's the same user + if m.User != u { + return nil, false + } + return m, true +} + +func (r *Room) MemberById(id string) (*Member, bool) { + m, err := r.members.Get(id) + if err != nil { + return nil, false + } + return m.(*Member), true +} + +// IsOp returns whether a user is an operator in this room. +func (r *Room) IsOp(u *User) bool { + m, ok := r.Member(u) + return ok && m.Op +} + +// Topic of the room. +func (r *Room) Topic() string { + return r.topic +} + +// SetTopic will set the topic of the room. +func (r *Room) SetTopic(s string) { + r.topic = s +} + +// NamesPrefix lists all members' names with a given prefix, used to query +// for autocompletion purposes. +func (r *Room) NamesPrefix(prefix string) []string { + members := r.members.ListPrefix(prefix) + names := make([]string, len(members)) + for i, u := range members { + names[i] = u.(*Member).User.Name() + } + return names +} diff --git a/chat/room_test.go b/chat/room_test.go new file mode 100644 index 0000000..e415a0c --- /dev/null +++ b/chat/room_test.go @@ -0,0 +1,192 @@ +package chat + +import ( + "reflect" + "testing" +) + +func TestRoomServe(t *testing.T) { + ch := NewRoom() + ch.Send(NewAnnounceMsg("hello")) + + received := <-ch.broadcast + actual := received.String() + expected := " * hello" + + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} + +func TestRoomJoin(t *testing.T) { + var expected, actual []byte + + s := &MockScreen{} + u := NewUser(testId("foo")) + + ch := NewRoom() + go ch.Serve() + defer ch.Close() + + _, err := ch.Join(u) + if err != nil { + t.Fatal(err) + } + + u.ConsumeOne(s) + expected = []byte(" * foo joined. (Connected: 1)" + Newline) + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + ch.Send(NewSystemMsg("hello", u)) + u.ConsumeOne(s) + expected = []byte("-> hello" + Newline) + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + ch.Send(ParseInput("/me says hello.", u)) + u.ConsumeOne(s) + expected = []byte("** foo says hello." + Newline) + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} + +func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) { + u := NewUser(testId("foo")) + u.Config = UserConfig{ + Quiet: true, + } + + ch := NewRoom() + defer ch.Close() + + _, err := ch.Join(u) + if err != nil { + t.Fatal(err) + } + + // Drain the initial Join message + <-ch.broadcast + + go func() { + for msg := range u.msg { + if _, ok := msg.(*AnnounceMsg); ok { + t.Errorf("Got unexpected `%T`", msg) + } + } + }() + + // Call with an AnnounceMsg and all the other types + // and assert we received only non-announce messages + ch.HandleMsg(NewAnnounceMsg("Ignored")) + // Assert we still get all other types of messages + ch.HandleMsg(NewEmoteMsg("hello", u)) + ch.HandleMsg(NewSystemMsg("hello", u)) + ch.HandleMsg(NewPrivateMsg("hello", u, u)) + ch.HandleMsg(NewPublicMsg("hello", u)) +} + +func TestRoomQuietToggleBroadcasts(t *testing.T) { + u := NewUser(testId("foo")) + u.Config = UserConfig{ + Quiet: true, + } + + ch := NewRoom() + defer ch.Close() + + _, err := ch.Join(u) + if err != nil { + t.Fatal(err) + } + + // Drain the initial Join message + <-ch.broadcast + + u.ToggleQuietMode() + + expectedMsg := NewAnnounceMsg("Ignored") + ch.HandleMsg(expectedMsg) + msg := <-u.msg + if _, ok := msg.(*AnnounceMsg); !ok { + t.Errorf("Got: `%T`; Expected: `%T`", msg, expectedMsg) + } + + u.ToggleQuietMode() + + ch.HandleMsg(NewAnnounceMsg("Ignored")) + ch.HandleMsg(NewSystemMsg("hello", u)) + msg = <-u.msg + if _, ok := msg.(*AnnounceMsg); ok { + t.Errorf("Got unexpected `%T`", msg) + } +} + +func TestQuietToggleDisplayState(t *testing.T) { + var expected, actual []byte + + s := &MockScreen{} + u := NewUser(testId("foo")) + + ch := NewRoom() + go ch.Serve() + defer ch.Close() + + _, err := ch.Join(u) + if err != nil { + t.Fatal(err) + } + + // Drain the initial Join message + <-ch.broadcast + + ch.Send(ParseInput("/quiet", u)) + u.ConsumeOne(s) + expected = []byte("-> Quiet mode is toggled ON" + Newline) + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + ch.Send(ParseInput("/quiet", u)) + u.ConsumeOne(s) + expected = []byte("-> Quiet mode is toggled OFF" + Newline) + + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} + +func TestRoomNames(t *testing.T) { + var expected, actual []byte + + s := &MockScreen{} + u := NewUser(testId("foo")) + + ch := NewRoom() + go ch.Serve() + defer ch.Close() + + _, err := ch.Join(u) + if err != nil { + t.Fatal(err) + } + + // Drain the initial Join message + <-ch.broadcast + + ch.Send(ParseInput("/names", u)) + u.ConsumeOne(s) + expected = []byte("-> 1 connected: foo" + Newline) + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} diff --git a/chat/sanitize.go b/chat/sanitize.go new file mode 100644 index 0000000..8b162cd --- /dev/null +++ b/chat/sanitize.go @@ -0,0 +1,17 @@ +package chat + +import "regexp" + +var reStripName = regexp.MustCompile("[^\\w.-]") + +// SanitizeName returns a name with only allowed characters. +func SanitizeName(s string) string { + return reStripName.ReplaceAllString(s, "") +} + +var reStripData = regexp.MustCompile("[^[:ascii:]]") + +// SanitizeData returns a string with only allowed characters for client-provided metadata inputs. +func SanitizeData(s string) string { + return reStripData.ReplaceAllString(s, "") +} diff --git a/chat/screen_test.go b/chat/screen_test.go new file mode 100644 index 0000000..c530f94 --- /dev/null +++ b/chat/screen_test.go @@ -0,0 +1,51 @@ +package chat + +import ( + "reflect" + "testing" +) + +// Used for testing +type MockScreen struct { + buffer []byte +} + +func (s *MockScreen) Write(data []byte) (n int, err error) { + s.buffer = append(s.buffer, data...) + return len(data), nil +} + +func (s *MockScreen) Read(p *[]byte) (n int, err error) { + *p = s.buffer + s.buffer = []byte{} + return len(*p), nil +} + +func TestScreen(t *testing.T) { + var actual, expected []byte + + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: %v; Expected: %v", actual, expected) + } + + actual = []byte("foo") + expected = []byte("foo") + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: %v; Expected: %v", actual, expected) + } + + s := &MockScreen{} + + expected = nil + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: %v; Expected: %v", actual, expected) + } + + expected = []byte("hello, world") + s.Write(expected) + s.Read(&actual) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: %v; Expected: %v", actual, expected) + } +} diff --git a/chat/set.go b/chat/set.go new file mode 100644 index 0000000..8617e14 --- /dev/null +++ b/chat/set.go @@ -0,0 +1,142 @@ +package chat + +import ( + "errors" + "strings" + "sync" +) + +// The error returned when an added id already exists in the set. +var ErrIdTaken = errors.New("id already taken") + +// The error returned when a requested item does not exist in the set. +var ErrItemMissing = errors.New("item does not exist") + +// Set with string lookup. +// TODO: Add trie for efficient prefix lookup? +type Set struct { + lookup map[string]Identifier + sync.RWMutex +} + +// NewSet creates a new set. +func NewSet() *Set { + return &Set{ + lookup: map[string]Identifier{}, + } +} + +// Clear removes all items and returns the number removed. +func (s *Set) Clear() int { + s.Lock() + n := len(s.lookup) + s.lookup = map[string]Identifier{} + s.Unlock() + return n +} + +// Len returns the size of the set right now. +func (s *Set) Len() int { + return len(s.lookup) +} + +// In checks if an item exists in this set. +func (s *Set) In(item Identifier) bool { + s.RLock() + _, ok := s.lookup[item.Id()] + s.RUnlock() + return ok +} + +// Get returns an item with the given Id. +func (s *Set) Get(id string) (Identifier, error) { + s.RLock() + item, ok := s.lookup[id] + s.RUnlock() + + if !ok { + return nil, ErrItemMissing + } + + return item, nil +} + +// Add item to this set if it does not exist already. +func (s *Set) Add(item Identifier) error { + s.Lock() + defer s.Unlock() + + _, found := s.lookup[item.Id()] + if found { + return ErrIdTaken + } + + s.lookup[item.Id()] = item + return nil +} + +// Remove item from this set. +func (s *Set) Remove(item Identifier) error { + s.Lock() + defer s.Unlock() + id := item.Id() + _, found := s.lookup[id] + if !found { + return ErrItemMissing + } + delete(s.lookup, id) + return nil +} + +// Replace item from old id with new Identifier. +// Used for moving the same identifier to a new Id, such as a rename. +func (s *Set) Replace(oldId string, item Identifier) error { + s.Lock() + defer s.Unlock() + + // Check if it already exists + _, found := s.lookup[item.Id()] + if found { + return ErrIdTaken + } + + // Remove oldId + _, found = s.lookup[oldId] + if !found { + return ErrItemMissing + } + delete(s.lookup, oldId) + + // Add new identifier + s.lookup[item.Id()] = item + + return nil +} + +// Each loops over every item while holding a read lock and applies fn to each +// element. +func (s *Set) Each(fn func(item Identifier)) { + s.RLock() + for _, item := range s.lookup { + fn(item) + } + s.RUnlock() +} + +// ListPrefix returns a list of items with a prefix, case insensitive. +func (s *Set) ListPrefix(prefix string) []Identifier { + r := []Identifier{} + prefix = strings.ToLower(prefix) + + s.RLock() + defer s.RUnlock() + + for id, item := range s.lookup { + if !strings.HasPrefix(string(id), prefix) { + continue + } + r = append(r, item) + } + + return r +} diff --git a/chat/set_test.go b/chat/set_test.go new file mode 100644 index 0000000..b92bdeb --- /dev/null +++ b/chat/set_test.go @@ -0,0 +1,38 @@ +package chat + +import "testing" + +func TestSet(t *testing.T) { + var err error + s := NewSet() + u := NewUser(testId("foo")) + + if s.In(u) { + t.Errorf("Set should be empty.") + } + + err = s.Add(u) + if err != nil { + t.Error(err) + } + + if !s.In(u) { + t.Errorf("Set should contain user.") + } + + u2 := NewUser(testId("bar")) + err = s.Add(u2) + if err != nil { + t.Error(err) + } + + err = s.Add(u2) + if err != ErrIdTaken { + t.Error(err) + } + + size := s.Len() + if size != 2 { + t.Errorf("Set wrong size: %d (expected %d)", size, 2) + } +} diff --git a/chat/theme.go b/chat/theme.go new file mode 100644 index 0000000..27085c0 --- /dev/null +++ b/chat/theme.go @@ -0,0 +1,195 @@ +package chat + +import "fmt" + +const ( + // Reset resets the color + Reset = "\033[0m" + + // Bold makes the following text bold + Bold = "\033[1m" + + // Dim dims the following text + Dim = "\033[2m" + + // Italic makes the following text italic + Italic = "\033[3m" + + // Underline underlines the following text + Underline = "\033[4m" + + // Blink blinks the following text + Blink = "\033[5m" + + // Invert inverts the following text + Invert = "\033[7m" + + // Newline + Newline = "\r\n" + + // BEL + Bel = "\007" +) + +// Interface for Styles +type Style interface { + String() string + Format(string) string +} + +// General hardcoded style, mostly used as a crutch until we flesh out the +// framework to support backgrounds etc. +type style string + +func (c style) String() string { + return string(c) +} + +func (c style) Format(s string) string { + return c.String() + s + Reset +} + +// 256 color type, for terminals who support it +type Color256 uint8 + +// String version of this color +func (c Color256) String() string { + return fmt.Sprintf("38;05;%d", c) +} + +// Return formatted string with this color +func (c Color256) Format(s string) string { + return "\033[" + c.String() + "m" + s + Reset +} + +// No color, used for mono theme +type Color0 struct{} + +// No-op for Color0 +func (c Color0) String() string { + return "" +} + +// No-op for Color0 +func (c Color0) Format(s string) string { + return s +} + +// Container for a collection of colors +type Palette struct { + colors []Style + size int +} + +// Get a color by index, overflows are looped around. +func (p Palette) Get(i int) Style { + return p.colors[i%(p.size-1)] +} + +func (p Palette) Len() int { + return p.size +} + +func (p Palette) String() string { + r := "" + for _, c := range p.colors { + r += c.Format("X") + } + return r +} + +// Collection of settings for chat +type Theme struct { + id string + sys Style + pm Style + highlight Style + names *Palette +} + +func (t Theme) Id() string { + return t.id +} + +// Colorize name string given some index +func (t Theme) ColorName(u *User) string { + if t.names == nil { + return u.Name() + } + + return t.names.Get(u.colorIdx).Format(u.Name()) +} + +// Colorize the PM string +func (t Theme) ColorPM(s string) string { + if t.pm == nil { + return s + } + + return t.pm.Format(s) +} + +// Colorize the Sys message +func (t Theme) ColorSys(s string) string { + if t.sys == nil { + return s + } + + return t.sys.Format(s) +} + +// Highlight a matched string, usually name +func (t Theme) Highlight(s string) string { + if t.highlight == nil { + return s + } + return t.highlight.Format(s) +} + +// List of initialzied themes +var Themes []Theme + +// Default theme to use +var DefaultTheme *Theme + +func readableColors256() *Palette { + size := 247 + p := Palette{ + colors: make([]Style, size), + size: size, + } + j := 0 + for i := 0; i < 256; i++ { + if (16 <= i && i <= 18) || (232 <= i && i <= 237) { + // Remove the ones near black, this is kinda sadpanda. + continue + } + p.colors[j] = Color256(i) + j++ + } + return &p +} + +func init() { + palette := readableColors256() + + Themes = []Theme{ + Theme{ + id: "colors", + names: palette, + sys: palette.Get(8), // Grey + pm: palette.Get(7), // White + highlight: style(Bold + "\033[48;5;11m\033[38;5;16m"), // Yellow highlight + }, + Theme{ + id: "mono", + }, + } + + // Debug for printing colors: + //for _, color := range palette.colors { + // fmt.Print(color.Format(color.String() + " ")) + //} + + DefaultTheme = &Themes[0] +} diff --git a/chat/theme_test.go b/chat/theme_test.go new file mode 100644 index 0000000..28b9d43 --- /dev/null +++ b/chat/theme_test.go @@ -0,0 +1,71 @@ +package chat + +import ( + "fmt" + "testing" +) + +func TestThemePalette(t *testing.T) { + var expected, actual string + + palette := readableColors256() + color := palette.Get(5) + if color == nil { + t.Fatal("Failed to return a color from palette.") + } + + actual = color.String() + expected = "38;05;5" + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + actual = color.Format("foo") + expected = "\033[38;05;5mfoo\033[0m" + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + actual = palette.Get(palette.Len() + 1).String() + expected = fmt.Sprintf("38;05;%d", 2) + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + +} + +func TestTheme(t *testing.T) { + var expected, actual string + + colorTheme := Themes[0] + color := colorTheme.sys + if color == nil { + t.Fatal("Sys color should not be empty for first theme.") + } + + actual = color.Format("foo") + expected = "\033[38;05;8mfoo\033[0m" + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + actual = colorTheme.ColorSys("foo") + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + u := NewUser(testId("foo")) + u.colorIdx = 4 + actual = colorTheme.ColorName(u) + expected = "\033[38;05;4mfoo\033[0m" + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + + msg := NewPublicMsg("hello", u) + actual = msg.Render(&colorTheme) + expected = "\033[38;05;4mfoo\033[0m: hello" + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} diff --git a/chat/user.go b/chat/user.go new file mode 100644 index 0000000..4a31816 --- /dev/null +++ b/chat/user.go @@ -0,0 +1,178 @@ +package chat + +import ( + "errors" + "fmt" + "io" + "math/rand" + "regexp" + "sync" + "time" +) + +const messageBuffer = 20 +const reHighlight = `\b(%s)\b` + +var ErrUserClosed = errors.New("user closed") + +// Identifier is an interface that can uniquely identify itself. +type Identifier interface { + Id() string + SetId(string) + Name() string +} + +// User definition, implemented set Item interface and io.Writer +type User struct { + Identifier + Config UserConfig + colorIdx int + joined time.Time + msg chan Message + done chan struct{} + replyTo *User // Set when user gets a /msg, for replying. + closed bool + closeOnce sync.Once +} + +func NewUser(identity Identifier) *User { + u := User{ + Identifier: identity, + Config: *DefaultUserConfig, + joined: time.Now(), + msg: make(chan Message, messageBuffer), + done: make(chan struct{}, 1), + } + u.SetColorIdx(rand.Int()) + + return &u +} + +func NewUserScreen(identity Identifier, screen io.Writer) *User { + u := NewUser(identity) + go u.Consume(screen) + + return u +} + +// Rename the user with a new Identifier. +func (u *User) SetId(id string) { + u.Identifier.SetId(id) + u.SetColorIdx(rand.Int()) +} + +// ReplyTo returns the last user that messaged this user. +func (u *User) ReplyTo() *User { + return u.replyTo +} + +// SetReplyTo sets the last user to message this user. +func (u *User) SetReplyTo(user *User) { + u.replyTo = user +} + +// ToggleQuietMode will toggle whether or not quiet mode is enabled +func (u *User) ToggleQuietMode() { + u.Config.Quiet = !u.Config.Quiet +} + +// SetColorIdx will set the colorIdx to a specific value, primarily used for +// testing. +func (u *User) SetColorIdx(idx int) { + u.colorIdx = idx +} + +// Block until user is closed +func (u *User) Wait() { + <-u.done +} + +// Disconnect user, stop accepting messages +func (u *User) Close() { + u.closeOnce.Do(func() { + u.closed = true + close(u.done) + close(u.msg) + }) +} + +// Consume message buffer into an io.Writer. Will block, should be called in a +// goroutine. +// TODO: Not sure if this is a great API. +func (u *User) Consume(out io.Writer) { + for m := range u.msg { + u.HandleMsg(m, out) + } +} + +// Consume one message and stop, mostly for testing +func (u *User) ConsumeOne(out io.Writer) { + u.HandleMsg(<-u.msg, out) +} + +// SetHighlight sets the highlighting regular expression to match string. +func (u *User) SetHighlight(s string) error { + re, err := regexp.Compile(fmt.Sprintf(reHighlight, s)) + if err != nil { + return err + } + u.Config.Highlight = re + return nil +} + +func (u *User) render(m Message) string { + switch m := m.(type) { + case *PublicMsg: + return m.RenderFor(u.Config) + Newline + case *PrivateMsg: + u.SetReplyTo(m.From()) + return m.Render(u.Config.Theme) + Newline + default: + return m.Render(u.Config.Theme) + Newline + } +} + +func (u *User) HandleMsg(m Message, out io.Writer) { + r := u.render(m) + _, err := out.Write([]byte(r)) + if err != nil { + logger.Printf("Write failed to %s, closing: %s", u.Name(), err) + u.Close() + } +} + +// Add message to consume by user +func (u *User) Send(m Message) error { + if u.closed { + return ErrUserClosed + } + + select { + case u.msg <- m: + default: + logger.Printf("Msg buffer full, closing: %s", u.Name()) + u.Close() + return ErrUserClosed + } + return nil +} + +// Container for per-user configurations. +type UserConfig struct { + Highlight *regexp.Regexp + Bell bool + Quiet bool + Theme *Theme +} + +// Default user configuration to use +var DefaultUserConfig *UserConfig + +func init() { + DefaultUserConfig = &UserConfig{ + Bell: true, + Quiet: false, + } + + // TODO: Seed random? +} diff --git a/chat/user_test.go b/chat/user_test.go new file mode 100644 index 0000000..37ecc29 --- /dev/null +++ b/chat/user_test.go @@ -0,0 +1,24 @@ +package chat + +import ( + "reflect" + "testing" +) + +func TestMakeUser(t *testing.T) { + var actual, expected []byte + + s := &MockScreen{} + u := NewUser(testId("foo")) + m := NewAnnounceMsg("hello") + + defer u.Close() + u.Send(m) + u.ConsumeOne(s) + + s.Read(&actual) + expected = []byte(m.String() + Newline) + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} diff --git a/client.go b/client.go deleted file mode 100644 index 781b4d8..0000000 --- a/client.go +++ /dev/null @@ -1,547 +0,0 @@ -package main - -import ( - "fmt" - "strings" - "sync" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" -) - -const ( - // MsgBuffer is the length of the message buffer - MsgBuffer int = 20 - - // MaxMsgLength is the maximum length of a message - MaxMsgLength int = 1024 - - // MaxNamesList is the max number of items to return in a /names command - MaxNamesList int = 20 - - // HelpText is the text returned by /help - HelpText string = `Available commands: - /about - About this chat. - /exit - Exit the chat. - /help - Show this help text. - /list - List the users that are currently connected. - /beep - Enable BEL notifications on mention. - /me $ACTION - Show yourself doing an action. - /nick $NAME - Rename yourself to a new name. - /whois $NAME - Display information about another connected user. - /msg $NAME $MESSAGE - Sends a private message to a user. - /motd - Prints the Message of the Day. - /theme [color|mono] - Set client theme.` - - // OpHelpText is the additional text returned by /help if the client is an Op - OpHelpText string = `Available operator commands: - /ban $NAME - Banish a user from the chat - /unban $FINGERPRINT - Unban a fingerprint - /banned - List all banned fingerprints - /kick $NAME - Kick em' out. - /op $NAME - Promote a user to server operator. - /silence $NAME - Revoke a user's ability to speak. - /shutdown $MESSAGE - Broadcast message and shutdown server. - /motd $MESSAGE - Set message shown whenever somebody joins. - /whitelist $FINGERPRINT - Add fingerprint to whitelist, prevent anyone else from joining. - /whitelist github.com/$USER - Add github user's pubkeys to whitelist.` - - // AboutText is the text returned by /about - AboutText string = `ssh-chat is made by @shazow. - - It is a custom ssh server built in Go to serve a chat experience - instead of a shell. - - Source: https://github.com/shazow/ssh-chat - - For more, visit shazow.net or follow at twitter.com/shazow` - - // RequiredWait is the time a client is required to wait between messages - RequiredWait time.Duration = time.Second / 2 -) - -// Client holds all the fields used by the client -type Client struct { - Server *Server - Conn *ssh.ServerConn - Msg chan string - Name string - Color string - Op bool - ready chan struct{} - term *terminal.Terminal - termWidth int - termHeight int - silencedUntil time.Time - lastTX time.Time - beepMe bool - colorMe bool - closed bool - sync.RWMutex -} - -// NewClient constructs a new client -func NewClient(server *Server, conn *ssh.ServerConn) *Client { - return &Client{ - Server: server, - Conn: conn, - Name: conn.User(), - Color: RandomColor256(), - Msg: make(chan string, MsgBuffer), - ready: make(chan struct{}, 1), - lastTX: time.Now(), - colorMe: true, - } -} - -// ColoredName returns the client name in its color -func (c *Client) ColoredName() string { - return ColorString(c.Color, c.Name) -} - -// SysMsg sends a message in continuous format over the message channel -func (c *Client) SysMsg(msg string, args ...interface{}) { - c.Send(ContinuousFormat(systemMessageFormat, "-> "+fmt.Sprintf(msg, args...))) -} - -// Write writes the given message -func (c *Client) Write(msg string) { - if !c.colorMe { - msg = DeColorString(msg) - } - c.term.Write([]byte(msg + "\r\n")) -} - -// WriteLines writes multiple messages -func (c *Client) WriteLines(msg []string) { - for _, line := range msg { - c.Write(line) - } -} - -// Send sends the given message -func (c *Client) Send(msg string) { - if len(msg) > MaxMsgLength || c.closed { - return - } - select { - case c.Msg <- msg: - default: - logger.Errorf("Msg buffer full, dropping: %s (%s)", c.Name, c.Conn.RemoteAddr()) - c.Conn.Conn.Close() - } -} - -// SendLines sends multiple messages -func (c *Client) SendLines(msg []string) { - for _, line := range msg { - c.Send(line) - } -} - -// IsSilenced checks if the client is silenced -func (c *Client) IsSilenced() bool { - return c.silencedUntil.After(time.Now()) -} - -// Silence silences a client for the given duration -func (c *Client) Silence(d time.Duration) { - c.silencedUntil = time.Now().Add(d) -} - -// Resize resizes the client to the given width and height -func (c *Client) Resize(width, height int) error { - width = 1000000 // TODO: Remove this dirty workaround for text overflow once ssh/terminal is fixed - err := c.term.SetSize(width, height) - if err != nil { - logger.Errorf("Resize failed: %dx%d", width, height) - return err - } - c.termWidth, c.termHeight = width, height - return nil -} - -// Rename renames the client to the given name -func (c *Client) Rename(name string) { - c.Name = name - var prompt string - - if c.colorMe { - prompt = c.ColoredName() - } else { - prompt = c.Name - } - - c.term.SetPrompt(fmt.Sprintf("[%s] ", prompt)) -} - -// Fingerprint returns the fingerprint -func (c *Client) Fingerprint() string { - if c.Conn.Permissions == nil { - return "" - } - return c.Conn.Permissions.Extensions["fingerprint"] -} - -// Emote formats and sends an emote -func (c *Client) Emote(message string) { - formatted := fmt.Sprintf("** %s%s", c.ColoredName(), message) - if c.IsSilenced() || len(message) > 1000 { - c.SysMsg("Message rejected") - } - c.Server.Broadcast(formatted, nil) -} - -func (c *Client) handleShell(channel ssh.Channel) { - defer channel.Close() - - // FIXME: This shouldn't live here, need to restructure the call chaining. - c.Server.Add(c) - go func() { - // Block until done, then remove. - c.Conn.Wait() - c.closed = true - c.Server.Remove(c) - close(c.Msg) - }() - - go func() { - for msg := range c.Msg { - c.Write(msg) - } - }() - - for { - line, err := c.term.ReadLine() - if err != nil { - break - } - - parts := strings.SplitN(line, " ", 3) - isCmd := strings.HasPrefix(parts[0], "/") - - if isCmd { - // TODO: Factor this out. - switch parts[0] { - case "/test-colors": // Shh, this command is a secret! - c.Write(ColorString("32", "Lorem ipsum dolor sit amet,")) - c.Write("consectetur " + ColorString("31;1", "adipiscing") + " elit.") - case "/exit": - channel.Close() - case "/help": - c.SysMsg(strings.Replace(HelpText, "\n", "\r\n", -1)) - if c.Server.IsOp(c) { - c.SysMsg(strings.Replace(OpHelpText, "\n", "\r\n", -1)) - } - case "/about": - c.SysMsg(strings.Replace(AboutText, "\n", "\r\n", -1)) - case "/uptime": - c.SysMsg(c.Server.Uptime()) - case "/beep": - c.beepMe = !c.beepMe - if c.beepMe { - c.SysMsg("I'll beep you good.") - } else { - c.SysMsg("No more beeps. :(") - } - case "/me": - me := strings.TrimLeft(line, "/me") - if me == "" { - me = " is at a loss for words." - } - c.Emote(me) - case "/slap": - slappee := "themself" - if len(parts) > 1 { - slappee = parts[1] - if len(parts[1]) > 100 { - slappee = "some long-named jerk" - } - } - c.Emote(fmt.Sprintf(" slaps %s around a bit with a large trout.", slappee)) - case "/nick": - if len(parts) == 2 { - c.Server.Rename(c, parts[1]) - } else { - c.SysMsg("Missing $NAME from: /nick $NAME") - } - case "/whois": - if len(parts) >= 2 { - client := c.Server.Who(parts[1]) - if client != nil { - version := reStripText.ReplaceAllString(string(client.Conn.ClientVersion()), "") - if len(version) > 100 { - version = "Evil Jerk with a superlong string" - } - c.SysMsg("%s is %s via %s", client.ColoredName(), client.Fingerprint(), version) - } else { - c.SysMsg("No such name: %s", parts[1]) - } - } else { - c.SysMsg("Missing $NAME from: /whois $NAME") - } - case "/names", "/list": - coloredNames := []string{} - for _, name := range c.Server.List(nil) { - coloredNames = append(coloredNames, c.Server.Who(name).ColoredName()) - } - num := len(coloredNames) - if len(coloredNames) > MaxNamesList { - others := fmt.Sprintf("and %d others.", len(coloredNames)-MaxNamesList) - coloredNames = coloredNames[:MaxNamesList] - coloredNames = append(coloredNames, others) - } - - c.SysMsg("%d connected: %s", num, strings.Join(coloredNames, systemMessageFormat+", ")) - case "/ban": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $NAME from: /ban $NAME") - } else { - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - fingerprint := client.Fingerprint() - client.SysMsg("Banned by %s.", c.ColoredName()) - c.Server.Ban(fingerprint, nil) - client.Conn.Close() - c.Server.Broadcast(fmt.Sprintf("* %s was banned by %s", parts[1], c.ColoredName()), nil) - } - } - case "/unban": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $FINGERPRINT from: /unban $FINGERPRINT") - } else { - fingerprint := parts[1] - isBanned := c.Server.IsBanned(fingerprint) - if !isBanned { - c.SysMsg("No such banned fingerprint: %s", fingerprint) - } else { - c.Server.Unban(fingerprint) - c.Server.Broadcast(fmt.Sprintf("* %s was unbanned by %s", fingerprint, c.ColoredName()), nil) - } - } - case "/banned": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 1 { - c.SysMsg("Too many arguments for /banned") - } else { - for fingerprint := range c.Server.bannedPK { - c.SysMsg("Banned fingerprint: %s", fingerprint) - } - } - case "/op": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $NAME from: /op $NAME") - } else { - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - fingerprint := client.Fingerprint() - if fingerprint == "" { - c.SysMsg("Cannot op user without fingerprint.") - } else { - client.SysMsg("Made op by %s.", c.ColoredName()) - c.Server.Op(fingerprint) - } - } - } - case "/kick": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $NAME from: /kick $NAME") - } else { - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - client.SysMsg("Kicked by %s.", c.ColoredName()) - client.Conn.Close() - c.Server.Broadcast(fmt.Sprintf("* %s was kicked by %s", parts[1], c.ColoredName()), nil) - } - } - case "/silence": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) < 2 { - c.SysMsg("Missing $NAME from: /silence $NAME") - } else { - duration := time.Duration(5) * time.Minute - if len(parts) >= 3 { - parsedDuration, err := time.ParseDuration(parts[2]) - if err == nil { - duration = parsedDuration - } - } - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - client.Silence(duration) - client.SysMsg("Silenced for %s by %s.", duration, c.ColoredName()) - } - } - case "/shutdown": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else { - var split = strings.SplitN(line, " ", 2) - var msg string - if len(split) > 1 { - msg = split[1] - } else { - msg = "" - } - // Shutdown after 5 seconds - go func() { - c.Server.Broadcast(ColorString("31", msg), nil) - time.Sleep(time.Second * 5) - c.Server.Stop() - }() - } - case "/msg": /* Send a PM */ - /* Make sure we have a recipient and a message */ - if len(parts) < 2 { - c.SysMsg("Missing $NAME from: /msg $NAME $MESSAGE") - break - } else if len(parts) < 3 { - c.SysMsg("Missing $MESSAGE from: /msg $NAME $MESSAGE") - break - } - /* Ask the server to send the message */ - if err := c.Server.Privmsg(parts[1], parts[2], c); nil != err { - c.SysMsg("Unable to send message to %v: %v", parts[1], err) - } - case "/motd": /* print motd */ - if !c.Server.IsOp(c) { - c.Server.MotdUnicast(c) - } else if len(parts) < 2 { - c.Server.MotdUnicast(c) - } else { - var newmotd string - if len(parts) == 2 { - newmotd = parts[1] - } else { - newmotd = parts[1] + " " + parts[2] - } - c.Server.SetMotd(newmotd) - c.Server.MotdBroadcast(c) - } - case "/theme": - if len(parts) < 2 { - c.SysMsg("Missing $THEME from: /theme $THEME") - c.SysMsg("Choose either color or mono") - } else { - // Sets colorMe attribute of client - if parts[1] == "mono" { - c.colorMe = false - } else if parts[1] == "color" { - c.colorMe = true - } - // Rename to reset prompt - c.Rename(c.Name) - } - - case "/whitelist": /* whitelist a fingerprint */ - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $FINGERPRINT from: /whitelist $FINGERPRINT") - } else { - fingerprint := parts[1] - go func() { - err = c.Server.Whitelist(fingerprint) - if err != nil { - c.SysMsg("Error adding to whitelist: %s", err) - } else { - c.SysMsg("Added %s to the whitelist", fingerprint) - } - }() - } - case "/version": - c.SysMsg("Version " + buildCommit) - - default: - c.SysMsg("Invalid command: %s", line) - } - continue - } - - msg := fmt.Sprintf("%s: %s", c.ColoredName(), line) - /* Rate limit */ - if time.Now().Sub(c.lastTX) < RequiredWait { - c.SysMsg("Rate limiting in effect.") - continue - } - if c.IsSilenced() || len(msg) > 1000 || len(line) < 1 { - c.SysMsg("Message rejected.") - continue - } - c.Server.Broadcast(msg, c) - c.lastTX = time.Now() - } - -} - -func (c *Client) handleChannels(channels <-chan ssh.NewChannel) { - prompt := fmt.Sprintf("[%s] ", c.ColoredName()) - - hasShell := false - - for ch := range channels { - if t := ch.ChannelType(); t != "session" { - ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - continue - } - - channel, requests, err := ch.Accept() - if err != nil { - logger.Errorf("Could not accept channel: %v", err) - continue - } - defer channel.Close() - - c.term = terminal.NewTerminal(channel, prompt) - c.term.AutoCompleteCallback = c.Server.AutoCompleteFunction - - for req := range requests { - var width, height int - var ok bool - - switch req.Type { - case "shell": - if c.term != nil && !hasShell { - go c.handleShell(channel) - ok = true - hasShell = true - } - case "pty-req": - width, height, ok = parsePtyRequest(req.Payload) - if ok { - err := c.Resize(width, height) - ok = err == nil - } - case "window-change": - width, height, ok = parseWinchRequest(req.Payload) - if ok { - err := c.Resize(width, height) - ok = err == nil - } - } - - if req.WantReply { - req.Reply(ok, nil) - } - } - } -} diff --git a/cmd.go b/cmd.go index 34840ad..e60acb9 100644 --- a/cmd.go +++ b/cmd.go @@ -13,18 +13,23 @@ import ( "github.com/alexcesaro/log" "github.com/alexcesaro/log/golog" "github.com/jessevdk/go-flags" + "golang.org/x/crypto/ssh" + + "github.com/shazow/ssh-chat/chat" + "github.com/shazow/ssh-chat/sshd" ) import _ "net/http/pprof" // Options contains the flag options type Options struct { - Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."` - Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"` - Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:22"` - Admin []string `long:"admin" description:"Fingerprint of pubkey to mark as admin."` - Whitelist string `long:"whitelist" description:"Optional file of pubkey fingerprints who are allowed to connect."` - Motd string `long:"motd" description:"Optional Message of the Day file."` - Pprof int `long:"pprof" description:"Enable pprof http server for profiling."` + Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."` + Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"` + Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:2022"` + Admin string `long:"admin" description:"File of public keys who are admins."` + Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."` + Motd string `long:"motd" description:"Optional Message of the Day file."` + Log string `long:"log" description:"Write chat log to this file."` + Pprof int `long:"pprof" description:"Enable pprof http server for profiling."` } var logLevels = []log.Level{ @@ -34,6 +39,7 @@ var logLevels = []log.Level{ } var buildCommit string + func main() { options := Options{} parser := flags.NewParser(&options, flags.Default) @@ -42,6 +48,7 @@ func main() { if p == nil { fmt.Print(err) } + os.Exit(1) return } @@ -51,54 +58,84 @@ func main() { }() } - // Initialize seed for random colors - RandomColorInit() - // Figure out the log level numVerbose := len(options.Verbose) if numVerbose > len(logLevels) { - numVerbose = len(logLevels) + numVerbose = len(logLevels) - 1 } logLevel := logLevels[numVerbose] logger = golog.New(os.Stderr, logLevel) + if logLevel == log.Debug { + // Enable logging from submodules + chat.SetLogger(os.Stderr) + sshd.SetLogger(os.Stderr) + } + privateKeyPath := options.Identity - if strings.HasPrefix(privateKeyPath, "~") { + if strings.HasPrefix(privateKeyPath, "~/") { user, err := user.Current() if err == nil { privateKeyPath = strings.Replace(privateKeyPath, "~", user.HomeDir, 1) } } - privateKey, err := ioutil.ReadFile(privateKeyPath) + privateKey, err := ReadPrivateKey(privateKeyPath) if err != nil { - logger.Errorf("Failed to load identity: %v", err) - return + logger.Errorf("Couldn't read private key: %v", err) + os.Exit(2) } - server, err := NewServer(privateKey) + signer, err := ssh.ParsePrivateKey(privateKey) if err != nil { - logger.Errorf("Failed to create server: %v", err) - return + logger.Errorf("Failed to parse key: %v", err) + os.Exit(3) } - for _, fingerprint := range options.Admin { - server.Op(fingerprint) - } + auth := NewAuth() + config := sshd.MakeAuth(auth) + config.AddHostKey(signer) - if options.Whitelist != "" { - file, err := os.Open(options.Whitelist) + s, err := sshd.ListenSSH(options.Bind, config) + if err != nil { + logger.Errorf("Failed to listen on socket: %v", err) + os.Exit(4) + } + defer s.Close() + s.RateLimit = true + + fmt.Printf("Listening for connections on %v\n", s.Addr().String()) + + host := NewHost(s) + host.auth = auth + host.theme = &chat.Themes[0] + + err = fromFile(options.Admin, func(line []byte) error { + key, _, _, _, err := ssh.ParseAuthorizedKey(line) if err != nil { - logger.Errorf("Could not open whitelist file") - return + return err } - defer file.Close() + auth.Op(key, 0) + return nil + }) + if err != nil { + logger.Errorf("Failed to load admins: %v", err) + os.Exit(5) + } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - server.Whitelist(scanner.Text()) + err = fromFile(options.Whitelist, func(line []byte) error { + key, _, _, _, err := ssh.ParseAuthorizedKey(line) + if err != nil { + return err } + auth.Whitelist(key, 0) + logger.Debugf("Whitelisted: %s", line) + return nil + }) + if err != nil { + logger.Errorf("Failed to load whitelist: %v", err) + os.Exit(5) } if options.Motd != "" { @@ -107,24 +144,53 @@ func main() { logger.Errorf("Failed to load MOTD file: %v", err) return } - motdString := string(motd[:]) - /* hack to normalize line endings into \r\n */ + motdString := strings.TrimSpace(string(motd)) + // hack to normalize line endings into \r\n motdString = strings.Replace(motdString, "\r\n", "\n", -1) motdString = strings.Replace(motdString, "\n", "\r\n", -1) - server.SetMotd(motdString) + host.SetMotd(motdString) } + if options.Log == "-" { + host.SetLogging(os.Stdout) + } else if options.Log != "" { + fp, err := os.OpenFile(options.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + logger.Errorf("Failed to open log file for writing: %v", err) + return + } + host.SetLogging(fp) + } + + go host.Serve() + // Construct interrupt handler sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt) - err = server.Start(options.Bind) - if err != nil { - logger.Errorf("Failed to start server: %v", err) - return - } - <-sig // Wait for ^C signal logger.Warningf("Interrupt signal detected, shutting down.") - server.Stop() + os.Exit(0) +} + +func fromFile(path string, handler func(line []byte) error) error { + if path == "" { + // Skip + return nil + } + + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + err := handler(scanner.Bytes()) + if err != nil { + return err + } + } + return nil } diff --git a/colors.go b/colors.go deleted file mode 100644 index 6bfc5ca..0000000 --- a/colors.go +++ /dev/null @@ -1,82 +0,0 @@ -package main - -import ( - "fmt" - "math/rand" - "regexp" - "strings" - "time" -) - -const ( - // Reset resets the color - Reset = "\033[0m" - - // Bold makes the following text bold - Bold = "\033[1m" - - // Dim dims the following text - Dim = "\033[2m" - - // Italic makes the following text italic - Italic = "\033[3m" - - // Underline underlines the following text - Underline = "\033[4m" - - // Blink blinks the following text - Blink = "\033[5m" - - // Invert inverts the following text - Invert = "\033[7m" -) - -var colors = []string{"31", "32", "33", "34", "35", "36", "37", "91", "92", "93", "94", "95", "96", "97"} - -// deColor is used for removing ANSI Escapes -var deColor = regexp.MustCompile("\033\\[[\\d;]+m") - -// DeColorString removes all color from the given string -func DeColorString(s string) string { - s = deColor.ReplaceAllString(s, "") - return s -} - -func randomReadableColor() int { - for { - i := rand.Intn(256) - if (16 <= i && i <= 18) || (232 <= i && i <= 237) { - // Remove the ones near black, this is kinda sadpanda. - continue - } - return i - } -} - -// RandomColor256 returns a random (of 256) color -func RandomColor256() string { - return fmt.Sprintf("38;05;%d", randomReadableColor()) -} - -// RandomColor returns a random color -func RandomColor() string { - return colors[rand.Intn(len(colors))] -} - -// ColorString returns a message in the given color -func ColorString(color string, msg string) string { - return Bold + "\033[" + color + "m" + msg + Reset -} - -// RandomColorInit initializes the random seed -func RandomColorInit() { - rand.Seed(time.Now().UTC().UnixNano()) -} - -// ContinuousFormat is a horrible hack to "continue" the previous string color -// and format after a RESET has been encountered. -// -// This is not HTML where you can just do a to resume your previous formatting! -func ContinuousFormat(format string, str string) string { - return systemMessageFormat + strings.Replace(str, Reset, format, -1) + Reset -} diff --git a/history_test.go b/history_test.go deleted file mode 100644 index 0eab1c7..0000000 --- a/history_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package main - -import ( - "reflect" - "testing" -) - -func TestHistory(t *testing.T) { - var r, expected []string - var size int - - h := NewHistory(5) - - r = h.Get(10) - expected = []string{} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } - - h.Add("1") - - if size = h.Len(); size != 1 { - t.Errorf("Wrong size: %v", size) - } - - r = h.Get(1) - expected = []string{"1"} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } - - h.Add("2") - h.Add("3") - h.Add("4") - h.Add("5") - h.Add("6") - - if size = h.Len(); size != 5 { - t.Errorf("Wrong size: %v", size) - } - - r = h.Get(2) - expected = []string{"5", "6"} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } - - r = h.Get(10) - expected = []string{"2", "3", "4", "5", "6"} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } -} diff --git a/host.go b/host.go new file mode 100644 index 0000000..baa5095 --- /dev/null +++ b/host.go @@ -0,0 +1,461 @@ +package main + +import ( + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/shazow/rateio" + "github.com/shazow/ssh-chat/chat" + "github.com/shazow/ssh-chat/sshd" +) + +const maxInputLength int = 1024 + +// GetPrompt will render the terminal prompt string based on the user. +func GetPrompt(user *chat.User) string { + name := user.Name() + if user.Config.Theme != nil { + name = user.Config.Theme.ColorName(user) + } + return fmt.Sprintf("[%s] ", name) +} + +// Host is the bridge between sshd and chat modules +// TODO: Should be easy to add support for multiple rooms, if we want. +type Host struct { + *chat.Room + listener *sshd.SSHListener + commands chat.Commands + + motd string + auth *Auth + count int + + // Default theme + theme *chat.Theme +} + +// NewHost creates a Host on top of an existing listener. +func NewHost(listener *sshd.SSHListener) *Host { + room := chat.NewRoom() + h := Host{ + Room: room, + listener: listener, + commands: chat.Commands{}, + } + + // Make our own commands registry instance. + chat.InitCommands(&h.commands) + h.InitCommands(&h.commands) + room.SetCommands(h.commands) + + go room.Serve() + return &h +} + +// SetMotd sets the host's message of the day. +func (h *Host) SetMotd(motd string) { + h.motd = motd +} + +func (h Host) isOp(conn sshd.Connection) bool { + key := conn.PublicKey() + if key == nil { + return false + } + return h.auth.IsOp(key) +} + +// Connect a specific Terminal to this host and its room. +func (h *Host) Connect(term *sshd.Terminal) { + id := NewIdentity(term.Conn) + user := chat.NewUserScreen(id, term) + user.Config.Theme = h.theme + go func() { + // Close term once user is closed. + user.Wait() + term.Close() + }() + defer user.Close() + + // Send MOTD + if h.motd != "" { + user.Send(chat.NewAnnounceMsg(h.motd)) + } + + member, err := h.Join(user) + if err != nil { + // Try again... + id.SetName(fmt.Sprintf("Guest%d", h.count)) + member, err = h.Join(user) + } + if err != nil { + logger.Errorf("Failed to join: %s", err) + return + } + + // Successfully joined. + term.SetPrompt(GetPrompt(user)) + term.AutoCompleteCallback = h.AutoCompleteFunction(user) + user.SetHighlight(user.Name()) + h.count++ + + // Should the user be op'd on join? + member.Op = h.isOp(term.Conn) + ratelimit := rateio.NewSimpleLimiter(3, time.Second*3) + + for { + line, err := term.ReadLine() + if err == io.EOF { + // Closed + break + } else if err != nil { + logger.Errorf("Terminal reading error: %s", err) + break + } + + err = ratelimit.Count(1) + if err != nil { + user.Send(chat.NewSystemMsg("Message rejected: Rate limiting is in effect.", user)) + continue + } + if len(line) > maxInputLength { + user.Send(chat.NewSystemMsg("Message rejected: Input too long.", user)) + continue + } + if line == "" { + // Silently ignore empty lines. + continue + } + + m := chat.ParseInput(line, user) + + // FIXME: Any reason to use h.room.Send(m) instead? + h.HandleMsg(m) + + cmd := m.Command() + if cmd == "/nick" || cmd == "/theme" { + // Hijack /nick command to update terminal synchronously. Wouldn't + // work if we use h.room.Send(m) above. + // + // FIXME: This is hacky, how do we improve the API to allow for + // this? Chat module shouldn't know about terminals. + term.SetPrompt(GetPrompt(user)) + user.SetHighlight(user.Name()) + } + } + + err = h.Leave(user) + if err != nil { + logger.Errorf("Failed to leave: %s", err) + return + } +} + +// Serve our chat room onto the listener +func (h *Host) Serve() { + terminals := h.listener.ServeTerminal() + + for term := range terminals { + go h.Connect(term) + } +} + +func (h Host) completeName(partial string) string { + names := h.NamesPrefix(partial) + if len(names) == 0 { + // Didn't find anything + return "" + } + + return names[len(names)-1] +} + +func (h Host) completeCommand(partial string) string { + for cmd, _ := range h.commands { + if strings.HasPrefix(cmd, partial) { + return cmd + } + } + return "" +} + +// AutoCompleteFunction returns a callback for terminal autocompletion +func (h *Host) AutoCompleteFunction(u *chat.User) func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { + return func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { + if key != 9 { + return + } + + if strings.HasSuffix(line[:pos], " ") { + // Don't autocomplete spaces. + return + } + + fields := strings.Fields(line[:pos]) + isFirst := len(fields) < 2 + partial := fields[len(fields)-1] + posPartial := pos - len(partial) + + var completed string + if isFirst && strings.HasPrefix(partial, "/") { + // Command + completed = h.completeCommand(partial) + if completed == "/reply" { + replyTo := u.ReplyTo() + if replyTo != nil { + completed = "/msg " + replyTo.Name() + } + } + } else { + // Name + completed = h.completeName(partial) + if completed == "" { + return + } + if isFirst { + completed += ":" + } + } + completed += " " + + // Reposition the cursor + newLine = strings.Replace(line[posPartial:], partial, completed, 1) + newLine = line[:posPartial] + newLine + newPos = pos + (len(completed) - len(partial)) + ok = true + return + } +} + +// GetUser returns a chat.User based on a name. +func (h *Host) GetUser(name string) (*chat.User, bool) { + m, ok := h.MemberById(name) + if !ok { + return nil, false + } + return m.User, true +} + +// InitCommands adds host-specific commands to a Commands container. These will +// override any existing commands. +func (h *Host) InitCommands(c *chat.Commands) { + c.Add(chat.Command{ + Prefix: "/msg", + PrefixHelp: "USER MESSAGE", + Help: "Send MESSAGE to USER.", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + args := msg.Args() + switch len(args) { + case 0: + return errors.New("must specify user") + case 1: + return errors.New("must specify message") + } + + target, ok := h.GetUser(args[0]) + if !ok { + return errors.New("user not found") + } + + m := chat.NewPrivateMsg(strings.Join(args[1:], " "), msg.From(), target) + room.Send(m) + return nil + }, + }) + + c.Add(chat.Command{ + Prefix: "/reply", + PrefixHelp: "MESSAGE", + Help: "Reply with MESSAGE to the previous private message.", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + args := msg.Args() + switch len(args) { + case 0: + return errors.New("must specify message") + } + + target := msg.From().ReplyTo() + if target == nil { + return errors.New("no message to reply to") + } + + m := chat.NewPrivateMsg(strings.Join(args, " "), msg.From(), target) + room.Send(m) + return nil + }, + }) + + c.Add(chat.Command{ + Prefix: "/whois", + PrefixHelp: "USER", + Help: "Information about USER.", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + args := msg.Args() + if len(args) == 0 { + return errors.New("must specify user") + } + + target, ok := h.GetUser(args[0]) + if !ok { + return errors.New("user not found") + } + + id := target.Identifier.(*Identity) + room.Send(chat.NewSystemMsg(id.Whois(), msg.From())) + + return nil + }, + }) + + // Hidden commands + c.Add(chat.Command{ + Prefix: "/version", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + room.Send(chat.NewSystemMsg(buildCommit, msg.From())) + return nil + }, + }) + + timeStarted := time.Now() + c.Add(chat.Command{ + Prefix: "/uptime", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + room.Send(chat.NewSystemMsg(time.Now().Sub(timeStarted).String(), msg.From())) + return nil + }, + }) + + // Op commands + c.Add(chat.Command{ + Op: true, + Prefix: "/kick", + PrefixHelp: "USER", + Help: "Kick USER from the server.", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + if !room.IsOp(msg.From()) { + return errors.New("must be op") + } + + args := msg.Args() + if len(args) == 0 { + return errors.New("must specify user") + } + + target, ok := h.GetUser(args[0]) + if !ok { + return errors.New("user not found") + } + + body := fmt.Sprintf("%s was kicked by %s.", target.Name(), msg.From().Name()) + room.Send(chat.NewAnnounceMsg(body)) + target.Close() + return nil + }, + }) + + c.Add(chat.Command{ + Op: true, + Prefix: "/ban", + PrefixHelp: "USER [DURATION]", + Help: "Ban USER from the server.", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + // TODO: Would be nice to specify what to ban. Key? Ip? etc. + if !room.IsOp(msg.From()) { + return errors.New("must be op") + } + + args := msg.Args() + if len(args) == 0 { + return errors.New("must specify user") + } + + target, ok := h.GetUser(args[0]) + if !ok { + return errors.New("user not found") + } + + var until time.Duration = 0 + if len(args) > 1 { + until, _ = time.ParseDuration(args[1]) + } + + id := target.Identifier.(*Identity) + h.auth.Ban(id.PublicKey(), until) + h.auth.BanAddr(id.RemoteAddr(), until) + + body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name()) + room.Send(chat.NewAnnounceMsg(body)) + target.Close() + + logger.Debugf("Banned: \n-> %s", id.Whois()) + + return nil + }, + }) + + c.Add(chat.Command{ + Op: true, + Prefix: "/motd", + PrefixHelp: "MESSAGE", + Help: "Set the MESSAGE of the day.", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + if !room.IsOp(msg.From()) { + return errors.New("must be op") + } + + motd := "" + args := msg.Args() + if len(args) > 0 { + motd = strings.Join(args, " ") + } + + h.motd = motd + body := fmt.Sprintf("New message of the day set by %s:", msg.From().Name()) + room.Send(chat.NewAnnounceMsg(body)) + if motd != "" { + room.Send(chat.NewAnnounceMsg(motd)) + } + + return nil + }, + }) + + c.Add(chat.Command{ + Op: true, + Prefix: "/op", + PrefixHelp: "USER [DURATION]", + Help: "Set USER as admin.", + Handler: func(room *chat.Room, msg chat.CommandMsg) error { + if !room.IsOp(msg.From()) { + return errors.New("must be op") + } + + args := msg.Args() + if len(args) == 0 { + return errors.New("must specify user") + } + + var until time.Duration = 0 + if len(args) > 1 { + until, _ = time.ParseDuration(args[1]) + } + + member, ok := room.MemberById(args[0]) + if !ok { + return errors.New("user not found") + } + member.Op = true + id := member.Identifier.(*Identity) + h.auth.Op(id.PublicKey(), until) + + body := fmt.Sprintf("Made op by %s.", msg.From().Name()) + room.Send(chat.NewSystemMsg(body, member.User)) + + return nil + }, + }) +} diff --git a/host_test.go b/host_test.go new file mode 100644 index 0000000..76bbe6b --- /dev/null +++ b/host_test.go @@ -0,0 +1,218 @@ +package main + +import ( + "bufio" + "crypto/rand" + "crypto/rsa" + "io" + "io/ioutil" + "strings" + "testing" + "time" + + "github.com/shazow/ssh-chat/chat" + "github.com/shazow/ssh-chat/sshd" + "golang.org/x/crypto/ssh" +) + +func stripPrompt(s string) string { + pos := strings.LastIndex(s, "\033[K") + if pos < 0 { + return s + } + return s[pos+3:] +} + +func TestHostGetPrompt(t *testing.T) { + var expected, actual string + + u := chat.NewUser(&Identity{nil, "foo"}) + u.SetColorIdx(2) + + actual = GetPrompt(u) + expected = "[foo] " + if actual != expected { + t.Errorf("Got: %q; Expected: %q", actual, expected) + } + + u.Config.Theme = &chat.Themes[0] + actual = GetPrompt(u) + expected = "[\033[38;05;2mfoo\033[0m] " + if actual != expected { + t.Errorf("Got: %q; Expected: %q", actual, expected) + } +} + +func TestHostNameCollision(t *testing.T) { + key, err := sshd.NewRandomSigner(512) + if err != nil { + t.Fatal(err) + } + config := sshd.MakeNoAuth() + config.AddHostKey(key) + + s, err := sshd.ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + defer s.Close() + host := NewHost(s) + go host.Serve() + + done := make(chan struct{}, 1) + + // First client + go func() { + err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + scanner := bufio.NewScanner(r) + + // Consume the initial buffer + scanner.Scan() + actual := scanner.Text() + if !strings.HasPrefix(actual, "[foo] ") { + t.Errorf("First client failed to get 'foo' name.") + } + + actual = stripPrompt(actual) + expected := " * foo joined. (Connected: 1)" + if actual != expected { + t.Errorf("Got %q; expected %q", actual, expected) + } + + // Ready for second client + done <- struct{}{} + + scanner.Scan() + actual = stripPrompt(scanner.Text()) + expected = " * Guest1 joined. (Connected: 2)" + if actual != expected { + t.Errorf("Got %q; expected %q", actual, expected) + } + + // Wrap it up. + close(done) + }) + if err != nil { + t.Fatal(err) + } + }() + + // Wait for first client + <-done + + // Second client + err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + scanner := bufio.NewScanner(r) + + // Consume the initial buffer + scanner.Scan() + actual := scanner.Text() + if !strings.HasPrefix(actual, "[Guest1] ") { + t.Errorf("Second client did not get Guest1 name.") + } + }) + if err != nil { + t.Fatal(err) + } + + <-done +} + +func TestHostWhitelist(t *testing.T) { + key, err := sshd.NewRandomSigner(512) + if err != nil { + t.Fatal(err) + } + + auth := NewAuth() + config := sshd.MakeAuth(auth) + config.AddHostKey(key) + + s, err := sshd.ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + defer s.Close() + host := NewHost(s) + host.auth = auth + go host.Serve() + + target := s.Addr().String() + + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) + if err != nil { + t.Error(err) + } + + clientkey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatal(err) + } + + clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) + auth.Whitelist(clientpubkey, 0) + + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) + if err == nil { + t.Error("Failed to block unwhitelisted connection.") + } +} + +func TestHostKick(t *testing.T) { + key, err := sshd.NewRandomSigner(512) + if err != nil { + t.Fatal(err) + } + + auth := NewAuth() + config := sshd.MakeAuth(auth) + config.AddHostKey(key) + + s, err := sshd.ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + defer s.Close() + addr := s.Addr().String() + host := NewHost(s) + go host.Serve() + + connected := make(chan struct{}) + done := make(chan struct{}) + + go func() { + // First client + err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) { + // Make op + member, _ := host.Room.MemberById("foo") + member.Op = true + + // Block until second client is here + connected <- struct{}{} + w.Write([]byte("/kick bar\r\n")) + }) + if err != nil { + t.Fatal(err) + } + }() + + go func() { + // Second client + err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) { + <-connected + + // Consume while we're connected. Should break when kicked. + ioutil.ReadAll(r) + }) + if err != nil { + t.Fatal(err) + } + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second * 1): + t.Fatal("Timeout.") + } +} diff --git a/identity.go b/identity.go new file mode 100644 index 0000000..bfd46fa --- /dev/null +++ b/identity.go @@ -0,0 +1,50 @@ +package main + +import ( + "fmt" + "net" + + "github.com/shazow/ssh-chat/chat" + "github.com/shazow/ssh-chat/sshd" +) + +// Identity is a container for everything that identifies a client. +type Identity struct { + sshd.Connection + id string +} + +// NewIdentity returns a new identity object from an sshd.Connection. +func NewIdentity(conn sshd.Connection) *Identity { + return &Identity{ + Connection: conn, + id: chat.SanitizeName(conn.Name()), + } +} + +func (i Identity) Id() string { + return i.id +} + +func (i *Identity) SetId(id string) { + i.id = id +} + +func (i *Identity) SetName(name string) { + i.SetId(name) +} + +func (i Identity) Name() string { + return i.id +} + +func (i Identity) Whois() string { + ip, _, _ := net.SplitHostPort(i.RemoteAddr().String()) + fingerprint := "(no public key)" + if i.PublicKey() != nil { + fingerprint = sshd.Fingerprint(i.PublicKey()) + } + return fmt.Sprintf("name: %s"+chat.Newline+ + " > ip: %s"+chat.Newline+ + " > fingerprint: %s", i.Name(), ip, fingerprint) +} diff --git a/key.go b/key.go new file mode 100644 index 0000000..0135e1b --- /dev/null +++ b/key.go @@ -0,0 +1,49 @@ +package main + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + + "code.google.com/p/gopass" +) + +// ReadPrivateKey attempts to read your private key and possibly decrypt it if it +// requires a passphrase. +// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`), +// is not set. +func ReadPrivateKey(path string) ([]byte, error) { + privateKey, err := ioutil.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to load identity: %v", err) + } + + block, rest := pem.Decode(privateKey) + if len(rest) > 0 { + return nil, fmt.Errorf("extra data when decoding private key") + } + if !x509.IsEncryptedPEMBlock(block) { + return privateKey, nil + } + + passphrase := os.Getenv("IDENTITY_PASSPHRASE") + if passphrase == "" { + passphrase, err = gopass.GetPass("Enter passphrase: ") + if err != nil { + return nil, fmt.Errorf("couldn't read passphrase: %v", err) + } + } + der, err := x509.DecryptPEMBlock(block, []byte(passphrase)) + if err != nil { + return nil, fmt.Errorf("decrypt failed: %v", err) + } + + privateKey = pem.EncodeToMemory(&pem.Block{ + Type: block.Type, + Bytes: der, + }) + + return privateKey, nil +} diff --git a/logger.go b/logger.go index 8fe9842..4fabd05 100644 --- a/logger.go +++ b/logger.go @@ -1,7 +1,16 @@ package main import ( + "bytes" + + "github.com/alexcesaro/log" "github.com/alexcesaro/log/golog" ) var logger *golog.Logger + +func init() { + // Set a default null logger + var b bytes.Buffer + logger = golog.New(&b, log.Debug) +} diff --git a/server.go b/server.go deleted file mode 100644 index c5f4127..0000000 --- a/server.go +++ /dev/null @@ -1,519 +0,0 @@ -package main - -import ( - "bufio" - "crypto/md5" - "encoding/base64" - "fmt" - "net" - "net/http" - "regexp" - "strings" - "sync" - "syscall" - "time" - - "golang.org/x/crypto/ssh" -) - -const ( - maxNameLength = 32 - historyLength = 20 - systemMessageFormat = "\033[1;90m" - privateMessageFormat = "\033[1m" - highlightFormat = Bold + "\033[48;5;11m\033[38;5;16m" - beep = "\007" -) - -var ( - reStripText = regexp.MustCompile("[^0-9A-Za-z_.-]") -) - -// Clients is a map of clients -type Clients map[string]*Client - -// Server holds all the fields used by a server -type Server struct { - sshConfig *ssh.ServerConfig - done chan struct{} - clients Clients - count int - history *History - motd string - whitelist map[string]struct{} // fingerprint lookup - admins map[string]struct{} // fingerprint lookup - bannedPK map[string]*time.Time // fingerprint lookup - started time.Time - sync.RWMutex -} - -// NewServer constructs a new server -func NewServer(privateKey []byte) (*Server, error) { - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - return nil, err - } - - server := Server{ - done: make(chan struct{}), - clients: Clients{}, - count: 0, - history: NewHistory(historyLength), - motd: "", - whitelist: map[string]struct{}{}, - admins: map[string]struct{}{}, - bannedPK: map[string]*time.Time{}, - started: time.Now(), - } - - config := ssh.ServerConfig{ - NoClientAuth: false, - // Auth-related things should be constant-time to avoid timing attacks. - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - fingerprint := Fingerprint(key) - if server.IsBanned(fingerprint) { - return nil, fmt.Errorf("Banned.") - } - if !server.IsWhitelisted(fingerprint) { - return nil, fmt.Errorf("Not Whitelisted.") - } - perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}} - return perm, nil - }, - KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { - if server.IsBanned("") { - return nil, fmt.Errorf("Interactive login disabled.") - } - if !server.IsWhitelisted("") { - return nil, fmt.Errorf("Not Whitelisted.") - } - return nil, nil - }, - } - config.AddHostKey(signer) - - server.sshConfig = &config - - return &server, nil -} - -// Len returns the number of clients -func (s *Server) Len() int { - return len(s.clients) -} - -// SysMsg broadcasts the given message to everyone -func (s *Server) SysMsg(msg string, args ...interface{}) { - s.Broadcast(ContinuousFormat(systemMessageFormat, " * "+fmt.Sprintf(msg, args...)), nil) -} - -// Broadcast broadcasts the given message to everyone except for the given client -func (s *Server) Broadcast(msg string, except *Client) { - logger.Debugf("Broadcast to %d: %s", s.Len(), msg) - s.history.Add(msg) - - s.RLock() - defer s.RUnlock() - - for _, client := range s.clients { - if except != nil && client == except { - continue - } - - if strings.Contains(msg, client.Name) { - // Turn message red if client's name is mentioned, and send BEL if they have enabled beeping - personalMsg := strings.Replace(msg, client.Name, highlightFormat+client.Name+Reset, -1) - if client.beepMe { - personalMsg += beep - } - client.Send(personalMsg) - } else { - client.Send(msg) - } - } -} - -// Privmsg sends a message to a particular nick, if it exists -func (s *Server) Privmsg(nick, message string, sender *Client) error { - // Get the recipient - target, ok := s.clients[strings.ToLower(nick)] - if !ok { - return fmt.Errorf("no client with that nick") - } - // Send the message - target.Msg <- fmt.Sprintf(beep+"[PM from %v] %s%v%s", sender.ColoredName(), privateMessageFormat, message, Reset) - logger.Debugf("PM from %v to %v: %v", sender.Name, nick, message) - return nil -} - -// SetMotd sets the Message of the Day (MOTD) -func (s *Server) SetMotd(motd string) { - s.motd = motd -} - -// MotdUnicast sends the MOTD as a SysMsg -func (s *Server) MotdUnicast(client *Client) { - if s.motd == "" { - return - } - client.SysMsg(s.motd) -} - -// MotdBroadcast broadcasts the MOTD -func (s *Server) MotdBroadcast(client *Client) { - if s.motd == "" { - return - } - s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * New MOTD set by %s.", client.ColoredName())), client) - s.Broadcast(s.motd, client) -} - -// Add adds the client to the list of clients -func (s *Server) Add(client *Client) { - go func() { - s.MotdUnicast(client) - client.SendLines(s.history.Get(10)) - }() - - s.Lock() - s.count++ - - newName, err := s.proposeName(client.Name) - if err != nil { - client.SysMsg("Your name '%s' is not available, renamed to '%s'. Use /nick to change it.", client.Name, ColorString(client.Color, newName)) - } - - client.Rename(newName) - s.clients[strings.ToLower(client.Name)] = client - num := len(s.clients) - s.Unlock() - - s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * %s joined. (Total connected: %d)", client.Name, num)), client) -} - -// Remove removes the given client from the list of clients -func (s *Server) Remove(client *Client) { - s.Lock() - delete(s.clients, strings.ToLower(client.Name)) - s.Unlock() - - s.SysMsg("%s left.", client.Name) -} - -func (s *Server) proposeName(name string) (string, error) { - // Assumes caller holds lock. - var err error - name = reStripText.ReplaceAllString(name, "") - - if len(name) > maxNameLength { - name = name[:maxNameLength] - } else if len(name) == 0 { - name = fmt.Sprintf("Guest%d", s.count) - } - - _, collision := s.clients[strings.ToLower(name)] - if collision { - err = fmt.Errorf("%s is not available", name) - name = fmt.Sprintf("Guest%d", s.count) - } - - return name, err -} - -// Rename renames the given client (user) -func (s *Server) Rename(client *Client, newName string) { - var oldName string - if strings.ToLower(newName) == strings.ToLower(client.Name) { - oldName = client.Name - client.Rename(newName) - } else { - s.Lock() - newName, err := s.proposeName(newName) - if err != nil { - client.SysMsg("%s", err) - s.Unlock() - return - } - - // TODO: Use a channel/goroutine for adding clients, rather than locks? - delete(s.clients, strings.ToLower(client.Name)) - oldName = client.Name - client.Rename(newName) - s.clients[strings.ToLower(client.Name)] = client - s.Unlock() - } - s.SysMsg("%s is now known as %s.", ColorString(client.Color, oldName), ColorString(client.Color, client.Name)) -} - -// List lists the clients with the given prefix -func (s *Server) List(prefix *string) []string { - r := []string{} - - s.RLock() - defer s.RUnlock() - - for name, client := range s.clients { - if prefix != nil && !strings.HasPrefix(name, strings.ToLower(*prefix)) { - continue - } - r = append(r, client.Name) - } - - return r -} - -// Who returns the client with a given name -func (s *Server) Who(name string) *Client { - return s.clients[strings.ToLower(name)] -} - -// Op adds the given fingerprint to the list of admins -func (s *Server) Op(fingerprint string) { - logger.Infof("Adding admin: %s", fingerprint) - s.Lock() - s.admins[fingerprint] = struct{}{} - s.Unlock() -} - -// Whitelist adds the given fingerprint to the whitelist -func (s *Server) Whitelist(fingerprint string) error { - if fingerprint == "" { - return fmt.Errorf("Invalid fingerprint.") - } - if strings.HasPrefix(fingerprint, "github.com/") { - return s.whitelistIdentityURL(fingerprint) - } - - return s.whitelistFingerprint(fingerprint) -} - -func (s *Server) whitelistIdentityURL(user string) error { - logger.Infof("Adding github account %s to whitelist", user) - - user = strings.Replace(user, "github.com/", "", -1) - keys, err := getGithubPubKeys(user) - if err != nil { - return err - } - if len(keys) == 0 { - return fmt.Errorf("No keys for github user %s", user) - } - for _, key := range keys { - fingerprint := Fingerprint(key) - s.whitelistFingerprint(fingerprint) - } - return nil -} - -func (s *Server) whitelistFingerprint(fingerprint string) error { - logger.Infof("Adding whitelist: %s", fingerprint) - s.Lock() - s.whitelist[fingerprint] = struct{}{} - s.Unlock() - return nil -} - -// Client for getting github pub keys -var client = http.Client{ - Timeout: time.Duration(10 * time.Second), -} - -// Returns an array of public keys for the given github user URL -func getGithubPubKeys(user string) ([]ssh.PublicKey, error) { - resp, err := client.Get("http://github.com/" + user + ".keys") - if err != nil { - return nil, err - } - defer resp.Body.Close() - - pubs := []ssh.PublicKey{} - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - text := scanner.Text() - if text == "Not Found" { - continue - } - - splitKey := strings.SplitN(text, " ", -1) - - // In case of malformated key - if len(splitKey) < 2 { - continue - } - - bodyDecoded, err := base64.StdEncoding.DecodeString(splitKey[1]) - if err != nil { - return nil, err - } - - pub, err := ssh.ParsePublicKey(bodyDecoded) - if err != nil { - return nil, err - } - - pubs = append(pubs, pub) - } - return pubs, nil -} - -// Uptime returns the time since the server was started -func (s *Server) Uptime() string { - return time.Now().Sub(s.started).String() -} - -// IsOp checks if the given client is Op -func (s *Server) IsOp(client *Client) bool { - fingerprint := client.Fingerprint() - if fingerprint == "" { - return false - } - _, r := s.admins[client.Fingerprint()] - return r -} - -// IsWhitelisted checks if the given fingerprint is whitelisted -func (s *Server) IsWhitelisted(fingerprint string) bool { - /* if no whitelist, anyone is welcome */ - if len(s.whitelist) == 0 { - return true - } - - /* otherwise, check for whitelist presence */ - _, r := s.whitelist[fingerprint] - return r -} - -// IsBanned checks if the given fingerprint is banned -func (s *Server) IsBanned(fingerprint string) bool { - ban, hasBan := s.bannedPK[fingerprint] - if !hasBan { - return false - } - if ban == nil { - return true - } - if ban.Before(time.Now()) { - s.Unban(fingerprint) - return false - } - return true -} - -// Ban bans a fingerprint for the given duration -func (s *Server) Ban(fingerprint string, duration *time.Duration) { - var until *time.Time - s.Lock() - if duration != nil { - when := time.Now().Add(*duration) - until = &when - } - s.bannedPK[fingerprint] = until - s.Unlock() -} - -// Unban unbans a banned fingerprint -func (s *Server) Unban(fingerprint string) { - s.Lock() - delete(s.bannedPK, fingerprint) - s.Unlock() -} - -// Start starts the server -func (s *Server) Start(laddr string) error { - // Once a ServerConfig has been configured, connections can be - // accepted. - socket, err := net.Listen("tcp", laddr) - if err != nil { - return err - } - - logger.Infof("Listening on %s", laddr) - - go func() { - defer socket.Close() - for { - conn, err := socket.Accept() - - if err != nil { - logger.Errorf("Failed to accept connection: %v", err) - if err == syscall.EINVAL { - // TODO: Handle shutdown more gracefully? - return - } - } - - // Goroutineify to resume accepting sockets early. - go func() { - // From a standard TCP connection to an encrypted SSH connection - sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig) - if err != nil { - logger.Errorf("Failed to handshake: %v", err) - return - } - - version := reStripText.ReplaceAllString(string(sshConn.ClientVersion()), "") - if len(version) > 100 { - version = "Evil Jerk with a superlong string" - } - logger.Infof("Connection #%d from: %s, %s, %s", s.count+1, sshConn.RemoteAddr(), sshConn.User(), version) - - go ssh.DiscardRequests(requests) - - client := NewClient(s, sshConn) - go client.handleChannels(channels) - }() - } - }() - - go func() { - <-s.done - socket.Close() - }() - - return nil -} - -// AutoCompleteFunction handles auto completion of nicks -func (s *Server) AutoCompleteFunction(line string, pos int, key rune) (newLine string, newPos int, ok bool) { - if key == 9 { - shortLine := strings.Split(line[:pos], " ") - partialNick := shortLine[len(shortLine)-1] - - nicks := s.List(&partialNick) - if len(nicks) > 0 { - nick := nicks[len(nicks)-1] - posPartialNick := pos - len(partialNick) - if len(shortLine) < 2 { - nick += ": " - } else { - nick += " " - } - newLine = strings.Replace(line[posPartialNick:], - partialNick, nick, 1) - newLine = line[:posPartialNick] + newLine - newPos = pos + (len(nick) - len(partialNick)) - ok = true - } - } else { - ok = false - } - return -} - -// Stop stops the server -func (s *Server) Stop() { - s.Lock() - for _, client := range s.clients { - client.Conn.Close() - } - s.Unlock() - - close(s.done) -} - -// Fingerprint returns the fingerprint based on a public key -func Fingerprint(k ssh.PublicKey) string { - hash := md5.Sum(k.Marshal()) - r := fmt.Sprintf("% x", hash) - return strings.Replace(r, " ", ":", -1) -} diff --git a/set.go b/set.go new file mode 100644 index 0000000..86afe13 --- /dev/null +++ b/set.go @@ -0,0 +1,70 @@ +package main + +import ( + "sync" + "time" +) + +type expiringValue struct { + time.Time +} + +func (v expiringValue) Bool() bool { + return time.Now().Before(v.Time) +} + +type value struct{} + +func (v value) Bool() bool { + return true +} + +type SetValue interface { + Bool() bool +} + +// Set with expire-able keys +type Set struct { + lookup map[string]SetValue + sync.Mutex +} + +// NewSet creates a new set. +func NewSet() *Set { + return &Set{ + lookup: map[string]SetValue{}, + } +} + +// Len returns the size of the set right now. +func (s *Set) Len() int { + return len(s.lookup) +} + +// In checks if an item exists in this set. +func (s *Set) In(key string) bool { + s.Lock() + v, ok := s.lookup[key] + if ok && !v.Bool() { + ok = false + delete(s.lookup, key) + } + s.Unlock() + return ok +} + +// Add item to this set, replace if it exists. +func (s *Set) Add(key string) { + s.Lock() + s.lookup[key] = value{} + s.Unlock() +} + +// Add item to this set, replace if it exists. +func (s *Set) AddExpiring(key string, d time.Duration) time.Time { + until := time.Now().Add(d) + s.Lock() + s.lookup[key] = expiringValue{until} + s.Unlock() + return until +} diff --git a/set_test.go b/set_test.go new file mode 100644 index 0000000..0a4b9ea --- /dev/null +++ b/set_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "testing" + "time" +) + +func TestSetExpiring(t *testing.T) { + s := NewSet() + if s.In("foo") { + t.Error("Matched before set.") + } + + s.Add("foo") + if !s.In("foo") { + t.Errorf("Not matched after set") + } + if s.Len() != 1 { + t.Error("Not len 1 after set") + } + + v := expiringValue{time.Now().Add(-time.Nanosecond * 1)} + if v.Bool() { + t.Errorf("expiringValue now is not expiring") + } + + v = expiringValue{time.Now().Add(time.Minute * 2)} + if !v.Bool() { + t.Errorf("expiringValue in 2 minutes is expiring now") + } + + until := s.AddExpiring("bar", time.Minute*2) + if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) { + t.Errorf("until is not a minute after %s: %s", time.Now(), until) + } + val, ok := s.lookup["bar"] + if !ok { + t.Errorf("bar not in lookup") + } + if !until.Equal(val.(expiringValue).Time) { + t.Errorf("bar's until is not equal to the expected value") + } + if !val.Bool() { + t.Errorf("bar expired immediately") + } + + if !s.In("bar") { + t.Errorf("Not matched after timed set") + } + if s.Len() != 2 { + t.Error("Not len 2 after set") + } + + s.AddExpiring("bar", time.Nanosecond*1) + if s.In("bar") { + t.Error("Matched after expired timer") + } +} diff --git a/sshd/auth.go b/sshd/auth.go new file mode 100644 index 0000000..163caa0 --- /dev/null +++ b/sshd/auth.go @@ -0,0 +1,72 @@ +package sshd + +import ( + "crypto/sha256" + "encoding/base64" + "errors" + "net" + + "golang.org/x/crypto/ssh" +) + +// Auth is used to authenticate connections based on public keys. +type Auth interface { + // Whether to allow connections without a public key. + AllowAnonymous() bool + // Given address and public key, return if the connection should be permitted. + Check(net.Addr, ssh.PublicKey) (bool, error) +} + +// MakeAuth makes an ssh.ServerConfig which performs authentication against an Auth implementation. +func MakeAuth(auth Auth) *ssh.ServerConfig { + config := ssh.ServerConfig{ + NoClientAuth: false, + // Auth-related things should be constant-time to avoid timing attacks. + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + ok, err := auth.Check(conn.RemoteAddr(), key) + if !ok { + return nil, err + } + perm := &ssh.Permissions{Extensions: map[string]string{ + "pubkey": string(key.Marshal()), + }} + return perm, nil + }, + KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { + if !auth.AllowAnonymous() { + return nil, errors.New("public key authentication required") + } + _, err := auth.Check(conn.RemoteAddr(), nil) + return nil, err + }, + } + + return &config +} + +// MakeNoAuth makes a simple ssh.ServerConfig which allows all connections. +// Primarily used for testing. +func MakeNoAuth() *ssh.ServerConfig { + config := ssh.ServerConfig{ + NoClientAuth: false, + // Auth-related things should be constant-time to avoid timing attacks. + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + perm := &ssh.Permissions{Extensions: map[string]string{ + "pubkey": string(key.Marshal()), + }} + return perm, nil + }, + KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { + return nil, nil + }, + } + + return &config +} + +// Fingerprint performs a SHA256 BASE64 fingerprint of the PublicKey, similar to OpenSSH. +// See: https://anongit.mindrot.org/openssh.git/commit/?id=56d1c83cdd1ac +func Fingerprint(k ssh.PublicKey) string { + hash := sha256.Sum256(k.Marshal()) + return base64.StdEncoding.EncodeToString(hash[:]) +} diff --git a/sshd/client.go b/sshd/client.go new file mode 100644 index 0000000..13d5dea --- /dev/null +++ b/sshd/client.go @@ -0,0 +1,72 @@ +package sshd + +import ( + "crypto/rand" + "crypto/rsa" + "io" + + "golang.org/x/crypto/ssh" +) + +// NewRandomSigner generates a random key of a desired bit length. +func NewRandomSigner(bits int) (ssh.Signer, error) { + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, err + } + return ssh.NewSignerFromKey(key) +} + +// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial. +func NewClientConfig(name string) *ssh.ClientConfig { + return &ssh.ClientConfig{ + User: name, + Auth: []ssh.AuthMethod{ + ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return + }), + }, + } +} + +// ConnectShell makes a barebones SSH client session, used for testing. +func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { + config := NewClientConfig(name) + conn, err := ssh.Dial("tcp", host, config) + if err != nil { + return err + } + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + return err + } + defer session.Close() + + in, err := session.StdinPipe() + if err != nil { + return err + } + + out, err := session.StdoutPipe() + if err != nil { + return err + } + + /* FIXME: Do we want to request a PTY? + err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}) + if err != nil { + return err + } + */ + + err = session.Shell() + if err != nil { + return err + } + + handler(out, in) + + return nil +} diff --git a/sshd/client_test.go b/sshd/client_test.go new file mode 100644 index 0000000..651c67e --- /dev/null +++ b/sshd/client_test.go @@ -0,0 +1,46 @@ +package sshd + +import ( + "errors" + "net" + "testing" + + "golang.org/x/crypto/ssh" +) + +var errRejectAuth = errors.New("not welcome here") + +type RejectAuth struct{} + +func (a RejectAuth) AllowAnonymous() bool { + return false +} +func (a RejectAuth) Check(net.Addr, ssh.PublicKey) (bool, error) { + return false, errRejectAuth +} + +func consume(ch <-chan *Terminal) { + for _ = range ch { + } +} + +func TestClientReject(t *testing.T) { + signer, err := NewRandomSigner(512) + config := MakeAuth(RejectAuth{}) + config.AddHostKey(signer) + + s, err := ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + defer s.Close() + + go consume(s.ServeTerminal()) + + conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo")) + if err == nil { + defer conn.Close() + t.Error("Failed to reject conncetion") + } + t.Log(err) +} diff --git a/sshd/doc.go b/sshd/doc.go new file mode 100644 index 0000000..21cd914 --- /dev/null +++ b/sshd/doc.go @@ -0,0 +1,34 @@ +package sshd + +/* + + signer, err := ssh.ParsePrivateKey(privateKey) + + config := MakeNoAuth() + config.AddHostKey(signer) + + s, err := ListenSSH("0.0.0.0:2022", config) + if err != nil { + // Handle opening socket error + } + defer s.Close() + + terminals := s.ServeTerminal() + + for term := range terminals { + go func() { + defer term.Close() + term.SetPrompt("...") + term.AutoCompleteCallback = nil // ... + + for { + line, err := term.ReadLine() + if err != nil { + break + } + term.Write(...) + } + + }() + } +*/ diff --git a/sshd/logger.go b/sshd/logger.go new file mode 100644 index 0000000..9f6998f --- /dev/null +++ b/sshd/logger.go @@ -0,0 +1,22 @@ +package sshd + +import "io" +import stdlog "log" + +var logger *stdlog.Logger + +func SetLogger(w io.Writer) { + flags := stdlog.Flags() + prefix := "[sshd] " + logger = stdlog.New(w, prefix, flags) +} + +type nullWriter struct{} + +func (nullWriter) Write(data []byte) (int, error) { + return len(data), nil +} + +func init() { + SetLogger(nullWriter{}) +} diff --git a/sshd/net.go b/sshd/net.go new file mode 100644 index 0000000..69a30da --- /dev/null +++ b/sshd/net.go @@ -0,0 +1,74 @@ +package sshd + +import ( + "net" + "time" + + "github.com/shazow/rateio" + "golang.org/x/crypto/ssh" +) + +// Container for the connection and ssh-related configuration +type SSHListener struct { + net.Listener + config *ssh.ServerConfig + RateLimit bool +} + +// Make an SSH listener socket +func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) { + socket, err := net.Listen("tcp", laddr) + if err != nil { + return nil, err + } + l := SSHListener{Listener: socket, config: config} + return &l, nil +} + +func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) { + if l.RateLimit { + // TODO: Configurable Limiter? + conn = ReadLimitConn(conn, rateio.NewGracefulLimiter(1024*10, time.Minute*2, time.Second*3)) + } + + // Upgrade TCP connection to SSH connection + sshConn, channels, requests, err := ssh.NewServerConn(conn, l.config) + if err != nil { + return nil, err + } + + // FIXME: Disconnect if too many faulty requests? (Avoid DoS.) + go ssh.DiscardRequests(requests) + return NewSession(sshConn, channels) +} + +// Accept incoming connections as terminal requests and yield them +func (l *SSHListener) ServeTerminal() <-chan *Terminal { + ch := make(chan *Terminal) + + go func() { + defer l.Close() + defer close(ch) + + for { + conn, err := l.Accept() + + if err != nil { + logger.Printf("Failed to accept connection: %v", err) + return + } + + // Goroutineify to resume accepting sockets early + go func() { + term, err := l.handleConn(conn) + if err != nil { + logger.Printf("Failed to handshake: %v", err) + return + } + ch <- term + }() + } + }() + + return ch +} diff --git a/sshd/net_test.go b/sshd/net_test.go new file mode 100644 index 0000000..c250525 --- /dev/null +++ b/sshd/net_test.go @@ -0,0 +1,81 @@ +package sshd + +import ( + "bytes" + "io" + "testing" +) + +func TestServerInit(t *testing.T) { + config := MakeNoAuth() + s, err := ListenSSH(":badport", config) + if err == nil { + t.Fatal("should fail on bad port") + } + + s, err = ListenSSH(":0", config) + if err != nil { + t.Error(err) + } + + err = s.Close() + if err != nil { + t.Error(err) + } +} + +func TestServeTerminals(t *testing.T) { + signer, err := NewRandomSigner(512) + config := MakeNoAuth() + config.AddHostKey(signer) + + s, err := ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + + terminals := s.ServeTerminal() + + go func() { + // Accept one terminal, read from it, echo back, close. + term := <-terminals + term.SetPrompt("> ") + + line, err := term.ReadLine() + if err != nil { + t.Error(err) + } + _, err = term.Write([]byte("echo: " + line + "\r\n")) + if err != nil { + t.Error(err) + } + + term.Close() + }() + + host := s.Addr().String() + name := "foo" + + err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) { + // Consume if there is anything + buf := new(bytes.Buffer) + w.Write([]byte("hello\r\n")) + + buf.Reset() + _, err := io.Copy(buf, r) + if err != nil { + t.Error(err) + } + + expected := "> hello\r\necho: hello\r\n" + actual := buf.String() + if actual != expected { + t.Errorf("Got %q; expected %q", actual, expected) + } + s.Close() + }) + + if err != nil { + t.Fatal(err) + } +} diff --git a/pty.go b/sshd/pty.go similarity index 94% rename from pty.go rename to sshd/pty.go index e635fba..06d34f0 100644 --- a/pty.go +++ b/sshd/pty.go @@ -1,8 +1,9 @@ -// Borrowed from go.crypto circa 2011 -package main +package sshd import "encoding/binary" +// Helpers below are borrowed from go.crypto circa 2011: + // parsePtyRequest parses the payload of the pty-req message and extracts the // dimensions of the terminal. See RFC 4254, section 6.2. func parsePtyRequest(s []byte) (width, height int, ok bool) { diff --git a/sshd/ratelimit.go b/sshd/ratelimit.go new file mode 100644 index 0000000..c80f0ac --- /dev/null +++ b/sshd/ratelimit.go @@ -0,0 +1,25 @@ +package sshd + +import ( + "io" + "net" + + "github.com/shazow/rateio" +) + +type limitedConn struct { + net.Conn + io.Reader // Our rate-limited io.Reader for net.Conn +} + +func (r *limitedConn) Read(p []byte) (n int, err error) { + return r.Reader.Read(p) +} + +// ReadLimitConn returns a net.Conn whose io.Reader interface is rate-limited by limiter. +func ReadLimitConn(conn net.Conn, limiter rateio.Limiter) net.Conn { + return &limitedConn{ + Conn: conn, + Reader: rateio.NewReader(conn, limiter), + } +} diff --git a/sshd/terminal.go b/sshd/terminal.go new file mode 100644 index 0000000..3da2d65 --- /dev/null +++ b/sshd/terminal.go @@ -0,0 +1,144 @@ +package sshd + +import ( + "errors" + "fmt" + "net" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +// Connection is an interface with fields necessary to operate an sshd host. +type Connection interface { + PublicKey() ssh.PublicKey + RemoteAddr() net.Addr + Name() string + Close() error +} + +type sshConn struct { + *ssh.ServerConn +} + +func (c sshConn) PublicKey() ssh.PublicKey { + if c.Permissions == nil { + return nil + } + + s, ok := c.Permissions.Extensions["pubkey"] + if !ok { + return nil + } + + key, err := ssh.ParsePublicKey([]byte(s)) + if err != nil { + return nil + } + + return key +} + +func (c sshConn) Name() string { + return c.User() +} + +// Extending ssh/terminal to include a closer interface +type Terminal struct { + terminal.Terminal + Conn Connection + Channel ssh.Channel +} + +// Make new terminal from a session channel +func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) { + if ch.ChannelType() != "session" { + return nil, errors.New("terminal requires session channel") + } + channel, requests, err := ch.Accept() + if err != nil { + return nil, err + } + term := Terminal{ + *terminal.NewTerminal(channel, "Connecting..."), + sshConn{conn}, + channel, + } + + go term.listen(requests) + go func() { + // FIXME: Is this necessary? + conn.Wait() + channel.Close() + }() + + return &term, nil +} + +// Find session channel and make a Terminal from it +func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) { + for ch := range channels { + if t := ch.ChannelType(); t != "session" { + ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + continue + } + + term, err = NewTerminal(conn, ch) + if err == nil { + break + } + } + + if term != nil { + // Reject the rest. + // FIXME: Do we need this? + go func() { + for ch := range channels { + ch.Reject(ssh.Prohibited, "only one session allowed") + } + }() + } + + return term, err +} + +// Close terminal and ssh connection +func (t *Terminal) Close() error { + return t.Conn.Close() +} + +// Negotiate terminal type and settings +func (t *Terminal) listen(requests <-chan *ssh.Request) { + hasShell := false + + for req := range requests { + var width, height int + var ok bool + + switch req.Type { + case "shell": + if !hasShell { + ok = true + hasShell = true + } + case "pty-req": + width, height, ok = parsePtyRequest(req.Payload) + if ok { + // TODO: Hardcode width to 100000? + err := t.SetSize(width, height) + ok = err == nil + } + case "window-change": + width, height, ok = parseWinchRequest(req.Payload) + if ok { + // TODO: Hardcode width to 100000? + err := t.SetSize(width, height) + ok = err == nil + } + } + + if req.WantReply { + req.Reply(ok, nil) + } + } +}