diff --git a/chat/channel.go b/chat/channel.go index 3bdca40..b4de26e 100644 --- a/chat/channel.go +++ b/chat/channel.go @@ -13,12 +13,17 @@ const channelBuffer = 10 // closed. var ErrChannelClosed = errors.New("channel closed") +// Member is a User with per-Channel metadata attached to it. +type Member struct { + *User + Op bool +} + // Channel definition, also a Set of User Items type Channel struct { topic string history *History - users *Set - ops *Set + members *Set broadcast chan Message commands Commands closed bool @@ -32,8 +37,7 @@ func NewChannel() *Channel { return &Channel{ broadcast: broadcast, history: NewHistory(historyLen), - users: NewSet(), - ops: NewSet(), + members: NewSet(), commands: *defaultCommands, } } @@ -47,10 +51,10 @@ func (ch *Channel) SetCommands(commands Commands) { func (ch *Channel) Close() { ch.closeOnce.Do(func() { ch.closed = true - ch.users.Each(func(u Item) { + ch.members.Each(func(u Item) { u.(*User).Close() }) - ch.users.Clear() + ch.members.Clear() close(ch.broadcast) }) } @@ -75,8 +79,8 @@ func (ch *Channel) HandleMsg(m Message) { skipUser = fromMsg.From() } - ch.users.Each(func(u Item) { - user := u.(*User) + ch.members.Each(func(u Item) { + user := u.(*Member).User if skip && skipUser == user { // Skip return @@ -108,18 +112,18 @@ func (ch *Channel) Join(u *User) error { if ch.closed { return ErrChannelClosed } - err := ch.users.Add(u) + err := ch.members.Add(&Member{u, false}) if err != nil { return err } - s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.users.Len()) + s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len()) ch.Send(NewAnnounceMsg(s)) return nil } -// Leave the channel as a user, will announce. +// Leave the channel as a user, will announce. Mostly used during setup. func (ch *Channel) Leave(u *User) error { - err := ch.users.Remove(u) + err := ch.members.Remove(u) if err != nil { return err } @@ -128,6 +132,26 @@ func (ch *Channel) Leave(u *User) error { return nil } +// Member returns a corresponding Member object to a User if the Member is +// present in this channel. +func (ch *Channel) Member(u *User) (*Member, bool) { + m, err := ch.members.Get(u.Id()) + if err != nil { + return nil, false + } + // Check that it's the same user + if m.(*Member).User != u { + return nil, false + } + return m.(*Member), true +} + +// IsOp returns whether a user is an operator in this channel. +func (ch *Channel) IsOp(u *User) bool { + m, ok := ch.Member(u) + return ok && m.Op +} + // Topic of the channel. func (ch *Channel) Topic() string { return ch.topic @@ -141,9 +165,9 @@ func (ch *Channel) SetTopic(s string) { // NamesPrefix lists all members' names with a given prefix, used to query // for autocompletion purposes. func (ch *Channel) NamesPrefix(prefix string) []string { - users := ch.users.ListPrefix(prefix) - names := make([]string, len(users)) - for i, u := range users { + members := ch.members.ListPrefix(prefix) + names := make([]string, len(members)) + for i, u := range members { names[i] = u.(*User).Name() } return names diff --git a/chat/command.go b/chat/command.go index 50d50e5..431ecb9 100644 --- a/chat/command.go +++ b/chat/command.go @@ -98,9 +98,8 @@ func init() { c.Add(Command{ Prefix: "/help", Handler: func(channel *Channel, msg CommandMsg) error { - user := msg.From() - op := channel.ops.In(user) - channel.Send(NewSystemMsg(channel.commands.Help(op), user)) + op := channel.IsOp(msg.From()) + channel.Send(NewSystemMsg(channel.commands.Help(op), msg.From())) return nil }, }) @@ -193,11 +192,12 @@ func init() { }) c.Add(Command{ + Op: true, Prefix: "/op", PrefixHelp: "USER", Help: "Mark user as admin.", Handler: func(channel *Channel, msg CommandMsg) error { - if !channel.ops.In(msg.From()) { + if !channel.IsOp(msg.From()) { return errors.New("must be op") } @@ -206,13 +206,14 @@ func init() { return errors.New("must specify user") } - // TODO: Add support for fingerprint-based op'ing. - user, err := channel.users.Get(Id(args[0])) + // TODO: Add support for fingerprint-based op'ing. This will + // probably need to live in host land. + member, err := channel.members.Get(Id(args[0])) if err != nil { return errors.New("user not found") } - channel.ops.Add(user) + member.(*Member).Op = true return nil }, }) diff --git a/cmd.go b/cmd.go index 8c8a576..e7bfc68 100644 --- a/cmd.go +++ b/cmd.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "fmt" "io/ioutil" "net/http" @@ -93,8 +94,8 @@ func main() { return } - // TODO: MakeAuth - config := sshd.MakeNoAuth() + auth := Auth{} + config := sshd.MakeAuth(auth) config.AddHostKey(signer) s, err := sshd.ListenSSH(options.Bind, config) @@ -106,11 +107,10 @@ func main() { defer s.Close() host := NewHost(s) - go host.Serve() + host.auth = &auth - /* TODO: for _, fingerprint := range options.Admin { - server.Op(fingerprint) + auth.Op(fingerprint) } if options.Whitelist != "" { @@ -123,7 +123,7 @@ func main() { scanner := bufio.NewScanner(file) for scanner.Scan() { - server.Whitelist(scanner.Text()) + auth.Whitelist(scanner.Text()) } } @@ -137,9 +137,10 @@ func main() { // hack to normalize line endings into \r\n motdString = strings.Replace(motdString, "\r\n", "\n", -1) motdString = strings.Replace(motdString, "\n", "\r\n", -1) - server.SetMotd(motdString) + host.SetMotd(motdString) } - */ + + go host.Serve() // Construct interrupt handler sig := make(chan os.Signal, 1) diff --git a/host.go b/host.go index 680d553..c6c8de7 100644 --- a/host.go +++ b/host.go @@ -14,9 +14,12 @@ import ( type Host struct { listener *sshd.SSHListener channel *chat.Channel + + motd string + auth *Auth } -// NewHost creates a Host on top of an existing listener +// NewHost creates a Host on top of an existing listener. func NewHost(listener *sshd.SSHListener) *Host { ch := chat.NewChannel() h := Host{ @@ -27,7 +30,12 @@ func NewHost(listener *sshd.SSHListener) *Host { return &h } -// Connect a specific Terminal to this host and its channel +// SetMotd sets the host's message of the day. +func (h *Host) SetMotd(motd string) { + h.motd = motd +} + +// Connect a specific Terminal to this host and its channel. func (h *Host) Connect(term *sshd.Terminal) { name := term.Conn.User() term.AutoCompleteCallback = h.AutoCompleteFunction diff --git a/sshd/auth.go b/sshd/auth.go index d271a85..90134e5 100644 --- a/sshd/auth.go +++ b/sshd/auth.go @@ -9,13 +9,9 @@ import ( "golang.org/x/crypto/ssh" ) -var errBanned = errors.New("banned") -var errNotWhitelisted = errors.New("not whitelisted") -var errNoInteractive = errors.New("public key authentication required") - type Auth interface { - IsBanned(ssh.PublicKey) bool - IsWhitelisted(ssh.PublicKey) bool + AllowAnonymous() bool + Check(string) (bool, error) } func MakeAuth(auth Auth) *ssh.ServerConfig { @@ -23,21 +19,17 @@ func MakeAuth(auth Auth) *ssh.ServerConfig { NoClientAuth: false, // Auth-related things should be constant-time to avoid timing attacks. PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if auth.IsBanned(key) { - return nil, errBanned + fingerprint := Fingerprint(key) + ok, err := auth.Check(fingerprint) + if !ok { + return nil, err } - if !auth.IsWhitelisted(key) { - return nil, errNotWhitelisted - } - perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}} + perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}} return perm, nil }, KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { - if auth.IsBanned(nil) { - return nil, errNoInteractive - } - if !auth.IsWhitelisted(nil) { - return nil, errNotWhitelisted + if !auth.AllowAnonymous() { + return nil, errors.New("public key authentication required") } return nil, nil }, @@ -51,7 +43,8 @@ func MakeNoAuth() *ssh.ServerConfig { NoClientAuth: false, // Auth-related things should be constant-time to avoid timing attacks. PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - return nil, nil + perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}} + return perm, nil }, KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { return nil, nil diff --git a/sshd/terminal.go b/sshd/terminal.go index 14318b3..196b9be 100644 --- a/sshd/terminal.go +++ b/sshd/terminal.go @@ -11,12 +11,12 @@ import ( // Extending ssh/terminal to include a closer interface type Terminal struct { terminal.Terminal - Conn ssh.Conn + Conn *ssh.ServerConn Channel ssh.Channel } // Make new terminal from a session channel -func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) { +func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) { if ch.ChannelType() != "session" { return nil, errors.New("terminal requires session channel") } @@ -41,7 +41,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) { } // Find session channel and make a Terminal from it -func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) { +func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) { for ch := range channels { if t := ch.ChannelType(); t != "session" { ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))