Channel Member now wrapping User with metadata, new Auth struct.

This commit is contained in:
Andrey Petrov 2015-01-01 18:40:10 -08:00
parent 6874601c0b
commit 6a662bf358
6 changed files with 80 additions and 53 deletions

View File

@ -13,12 +13,17 @@ const channelBuffer = 10
// closed.
var ErrChannelClosed = errors.New("channel closed")
// Member is a User with per-Channel metadata attached to it.
type Member struct {
*User
Op bool
}
// Channel definition, also a Set of User Items
type Channel struct {
topic string
history *History
users *Set
ops *Set
members *Set
broadcast chan Message
commands Commands
closed bool
@ -32,8 +37,7 @@ func NewChannel() *Channel {
return &Channel{
broadcast: broadcast,
history: NewHistory(historyLen),
users: NewSet(),
ops: NewSet(),
members: NewSet(),
commands: *defaultCommands,
}
}
@ -47,10 +51,10 @@ func (ch *Channel) SetCommands(commands Commands) {
func (ch *Channel) Close() {
ch.closeOnce.Do(func() {
ch.closed = true
ch.users.Each(func(u Item) {
ch.members.Each(func(u Item) {
u.(*User).Close()
})
ch.users.Clear()
ch.members.Clear()
close(ch.broadcast)
})
}
@ -75,8 +79,8 @@ func (ch *Channel) HandleMsg(m Message) {
skipUser = fromMsg.From()
}
ch.users.Each(func(u Item) {
user := u.(*User)
ch.members.Each(func(u Item) {
user := u.(*Member).User
if skip && skipUser == user {
// Skip
return
@ -108,18 +112,18 @@ func (ch *Channel) Join(u *User) error {
if ch.closed {
return ErrChannelClosed
}
err := ch.users.Add(u)
err := ch.members.Add(&Member{u, false})
if err != nil {
return err
}
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.users.Len())
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len())
ch.Send(NewAnnounceMsg(s))
return nil
}
// Leave the channel as a user, will announce.
// Leave the channel as a user, will announce. Mostly used during setup.
func (ch *Channel) Leave(u *User) error {
err := ch.users.Remove(u)
err := ch.members.Remove(u)
if err != nil {
return err
}
@ -128,6 +132,26 @@ func (ch *Channel) Leave(u *User) error {
return nil
}
// Member returns a corresponding Member object to a User if the Member is
// present in this channel.
func (ch *Channel) Member(u *User) (*Member, bool) {
m, err := ch.members.Get(u.Id())
if err != nil {
return nil, false
}
// Check that it's the same user
if m.(*Member).User != u {
return nil, false
}
return m.(*Member), true
}
// IsOp returns whether a user is an operator in this channel.
func (ch *Channel) IsOp(u *User) bool {
m, ok := ch.Member(u)
return ok && m.Op
}
// Topic of the channel.
func (ch *Channel) Topic() string {
return ch.topic
@ -141,9 +165,9 @@ func (ch *Channel) SetTopic(s string) {
// NamesPrefix lists all members' names with a given prefix, used to query
// for autocompletion purposes.
func (ch *Channel) NamesPrefix(prefix string) []string {
users := ch.users.ListPrefix(prefix)
names := make([]string, len(users))
for i, u := range users {
members := ch.members.ListPrefix(prefix)
names := make([]string, len(members))
for i, u := range members {
names[i] = u.(*User).Name()
}
return names

View File

@ -98,9 +98,8 @@ func init() {
c.Add(Command{
Prefix: "/help",
Handler: func(channel *Channel, msg CommandMsg) error {
user := msg.From()
op := channel.ops.In(user)
channel.Send(NewSystemMsg(channel.commands.Help(op), user))
op := channel.IsOp(msg.From())
channel.Send(NewSystemMsg(channel.commands.Help(op), msg.From()))
return nil
},
})
@ -193,11 +192,12 @@ func init() {
})
c.Add(Command{
Op: true,
Prefix: "/op",
PrefixHelp: "USER",
Help: "Mark user as admin.",
Handler: func(channel *Channel, msg CommandMsg) error {
if !channel.ops.In(msg.From()) {
if !channel.IsOp(msg.From()) {
return errors.New("must be op")
}
@ -206,13 +206,14 @@ func init() {
return errors.New("must specify user")
}
// TODO: Add support for fingerprint-based op'ing.
user, err := channel.users.Get(Id(args[0]))
// TODO: Add support for fingerprint-based op'ing. This will
// probably need to live in host land.
member, err := channel.members.Get(Id(args[0]))
if err != nil {
return errors.New("user not found")
}
channel.ops.Add(user)
member.(*Member).Op = true
return nil
},
})

17
cmd.go
View File

@ -1,6 +1,7 @@
package main
import (
"bufio"
"fmt"
"io/ioutil"
"net/http"
@ -93,8 +94,8 @@ func main() {
return
}
// TODO: MakeAuth
config := sshd.MakeNoAuth()
auth := Auth{}
config := sshd.MakeAuth(auth)
config.AddHostKey(signer)
s, err := sshd.ListenSSH(options.Bind, config)
@ -106,11 +107,10 @@ func main() {
defer s.Close()
host := NewHost(s)
go host.Serve()
host.auth = &auth
/* TODO:
for _, fingerprint := range options.Admin {
server.Op(fingerprint)
auth.Op(fingerprint)
}
if options.Whitelist != "" {
@ -123,7 +123,7 @@ func main() {
scanner := bufio.NewScanner(file)
for scanner.Scan() {
server.Whitelist(scanner.Text())
auth.Whitelist(scanner.Text())
}
}
@ -137,9 +137,10 @@ func main() {
// hack to normalize line endings into \r\n
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
server.SetMotd(motdString)
host.SetMotd(motdString)
}
*/
go host.Serve()
// Construct interrupt handler
sig := make(chan os.Signal, 1)

12
host.go
View File

@ -14,9 +14,12 @@ import (
type Host struct {
listener *sshd.SSHListener
channel *chat.Channel
motd string
auth *Auth
}
// NewHost creates a Host on top of an existing listener
// NewHost creates a Host on top of an existing listener.
func NewHost(listener *sshd.SSHListener) *Host {
ch := chat.NewChannel()
h := Host{
@ -27,7 +30,12 @@ func NewHost(listener *sshd.SSHListener) *Host {
return &h
}
// Connect a specific Terminal to this host and its channel
// SetMotd sets the host's message of the day.
func (h *Host) SetMotd(motd string) {
h.motd = motd
}
// Connect a specific Terminal to this host and its channel.
func (h *Host) Connect(term *sshd.Terminal) {
name := term.Conn.User()
term.AutoCompleteCallback = h.AutoCompleteFunction

View File

@ -9,13 +9,9 @@ import (
"golang.org/x/crypto/ssh"
)
var errBanned = errors.New("banned")
var errNotWhitelisted = errors.New("not whitelisted")
var errNoInteractive = errors.New("public key authentication required")
type Auth interface {
IsBanned(ssh.PublicKey) bool
IsWhitelisted(ssh.PublicKey) bool
AllowAnonymous() bool
Check(string) (bool, error)
}
func MakeAuth(auth Auth) *ssh.ServerConfig {
@ -23,21 +19,17 @@ func MakeAuth(auth Auth) *ssh.ServerConfig {
NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if auth.IsBanned(key) {
return nil, errBanned
fingerprint := Fingerprint(key)
ok, err := auth.Check(fingerprint)
if !ok {
return nil, err
}
if !auth.IsWhitelisted(key) {
return nil, errNotWhitelisted
}
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
return perm, nil
},
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
if auth.IsBanned(nil) {
return nil, errNoInteractive
}
if !auth.IsWhitelisted(nil) {
return nil, errNotWhitelisted
if !auth.AllowAnonymous() {
return nil, errors.New("public key authentication required")
}
return nil, nil
},
@ -51,7 +43,8 @@ func MakeNoAuth() *ssh.ServerConfig {
NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
return nil, nil
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
return perm, nil
},
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
return nil, nil

View File

@ -11,12 +11,12 @@ import (
// Extending ssh/terminal to include a closer interface
type Terminal struct {
terminal.Terminal
Conn ssh.Conn
Conn *ssh.ServerConn
Channel ssh.Channel
}
// Make new terminal from a session channel
func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
if ch.ChannelType() != "session" {
return nil, errors.New("terminal requires session channel")
}
@ -41,7 +41,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
}
// Find session channel and make a Terminal from it
func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
for ch := range channels {
if t := ch.ChannelType(); t != "session" {
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))