Channel Member now wrapping User with metadata, new Auth struct.
This commit is contained in:
parent
6874601c0b
commit
6a662bf358
|
@ -13,12 +13,17 @@ const channelBuffer = 10
|
|||
// closed.
|
||||
var ErrChannelClosed = errors.New("channel closed")
|
||||
|
||||
// Member is a User with per-Channel metadata attached to it.
|
||||
type Member struct {
|
||||
*User
|
||||
Op bool
|
||||
}
|
||||
|
||||
// Channel definition, also a Set of User Items
|
||||
type Channel struct {
|
||||
topic string
|
||||
history *History
|
||||
users *Set
|
||||
ops *Set
|
||||
members *Set
|
||||
broadcast chan Message
|
||||
commands Commands
|
||||
closed bool
|
||||
|
@ -32,8 +37,7 @@ func NewChannel() *Channel {
|
|||
return &Channel{
|
||||
broadcast: broadcast,
|
||||
history: NewHistory(historyLen),
|
||||
users: NewSet(),
|
||||
ops: NewSet(),
|
||||
members: NewSet(),
|
||||
commands: *defaultCommands,
|
||||
}
|
||||
}
|
||||
|
@ -47,10 +51,10 @@ func (ch *Channel) SetCommands(commands Commands) {
|
|||
func (ch *Channel) Close() {
|
||||
ch.closeOnce.Do(func() {
|
||||
ch.closed = true
|
||||
ch.users.Each(func(u Item) {
|
||||
ch.members.Each(func(u Item) {
|
||||
u.(*User).Close()
|
||||
})
|
||||
ch.users.Clear()
|
||||
ch.members.Clear()
|
||||
close(ch.broadcast)
|
||||
})
|
||||
}
|
||||
|
@ -75,8 +79,8 @@ func (ch *Channel) HandleMsg(m Message) {
|
|||
skipUser = fromMsg.From()
|
||||
}
|
||||
|
||||
ch.users.Each(func(u Item) {
|
||||
user := u.(*User)
|
||||
ch.members.Each(func(u Item) {
|
||||
user := u.(*Member).User
|
||||
if skip && skipUser == user {
|
||||
// Skip
|
||||
return
|
||||
|
@ -108,18 +112,18 @@ func (ch *Channel) Join(u *User) error {
|
|||
if ch.closed {
|
||||
return ErrChannelClosed
|
||||
}
|
||||
err := ch.users.Add(u)
|
||||
err := ch.members.Add(&Member{u, false})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.users.Len())
|
||||
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len())
|
||||
ch.Send(NewAnnounceMsg(s))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Leave the channel as a user, will announce.
|
||||
// Leave the channel as a user, will announce. Mostly used during setup.
|
||||
func (ch *Channel) Leave(u *User) error {
|
||||
err := ch.users.Remove(u)
|
||||
err := ch.members.Remove(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -128,6 +132,26 @@ func (ch *Channel) Leave(u *User) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Member returns a corresponding Member object to a User if the Member is
|
||||
// present in this channel.
|
||||
func (ch *Channel) Member(u *User) (*Member, bool) {
|
||||
m, err := ch.members.Get(u.Id())
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
// Check that it's the same user
|
||||
if m.(*Member).User != u {
|
||||
return nil, false
|
||||
}
|
||||
return m.(*Member), true
|
||||
}
|
||||
|
||||
// IsOp returns whether a user is an operator in this channel.
|
||||
func (ch *Channel) IsOp(u *User) bool {
|
||||
m, ok := ch.Member(u)
|
||||
return ok && m.Op
|
||||
}
|
||||
|
||||
// Topic of the channel.
|
||||
func (ch *Channel) Topic() string {
|
||||
return ch.topic
|
||||
|
@ -141,9 +165,9 @@ func (ch *Channel) SetTopic(s string) {
|
|||
// NamesPrefix lists all members' names with a given prefix, used to query
|
||||
// for autocompletion purposes.
|
||||
func (ch *Channel) NamesPrefix(prefix string) []string {
|
||||
users := ch.users.ListPrefix(prefix)
|
||||
names := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
members := ch.members.ListPrefix(prefix)
|
||||
names := make([]string, len(members))
|
||||
for i, u := range members {
|
||||
names[i] = u.(*User).Name()
|
||||
}
|
||||
return names
|
||||
|
|
|
@ -98,9 +98,8 @@ func init() {
|
|||
c.Add(Command{
|
||||
Prefix: "/help",
|
||||
Handler: func(channel *Channel, msg CommandMsg) error {
|
||||
user := msg.From()
|
||||
op := channel.ops.In(user)
|
||||
channel.Send(NewSystemMsg(channel.commands.Help(op), user))
|
||||
op := channel.IsOp(msg.From())
|
||||
channel.Send(NewSystemMsg(channel.commands.Help(op), msg.From()))
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
@ -193,11 +192,12 @@ func init() {
|
|||
})
|
||||
|
||||
c.Add(Command{
|
||||
Op: true,
|
||||
Prefix: "/op",
|
||||
PrefixHelp: "USER",
|
||||
Help: "Mark user as admin.",
|
||||
Handler: func(channel *Channel, msg CommandMsg) error {
|
||||
if !channel.ops.In(msg.From()) {
|
||||
if !channel.IsOp(msg.From()) {
|
||||
return errors.New("must be op")
|
||||
}
|
||||
|
||||
|
@ -206,13 +206,14 @@ func init() {
|
|||
return errors.New("must specify user")
|
||||
}
|
||||
|
||||
// TODO: Add support for fingerprint-based op'ing.
|
||||
user, err := channel.users.Get(Id(args[0]))
|
||||
// TODO: Add support for fingerprint-based op'ing. This will
|
||||
// probably need to live in host land.
|
||||
member, err := channel.members.Get(Id(args[0]))
|
||||
if err != nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
channel.ops.Add(user)
|
||||
member.(*Member).Op = true
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
|
17
cmd.go
17
cmd.go
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -93,8 +94,8 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
// TODO: MakeAuth
|
||||
config := sshd.MakeNoAuth()
|
||||
auth := Auth{}
|
||||
config := sshd.MakeAuth(auth)
|
||||
config.AddHostKey(signer)
|
||||
|
||||
s, err := sshd.ListenSSH(options.Bind, config)
|
||||
|
@ -106,11 +107,10 @@ func main() {
|
|||
defer s.Close()
|
||||
|
||||
host := NewHost(s)
|
||||
go host.Serve()
|
||||
host.auth = &auth
|
||||
|
||||
/* TODO:
|
||||
for _, fingerprint := range options.Admin {
|
||||
server.Op(fingerprint)
|
||||
auth.Op(fingerprint)
|
||||
}
|
||||
|
||||
if options.Whitelist != "" {
|
||||
|
@ -123,7 +123,7 @@ func main() {
|
|||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
server.Whitelist(scanner.Text())
|
||||
auth.Whitelist(scanner.Text())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -137,9 +137,10 @@ func main() {
|
|||
// hack to normalize line endings into \r\n
|
||||
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
|
||||
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
|
||||
server.SetMotd(motdString)
|
||||
host.SetMotd(motdString)
|
||||
}
|
||||
*/
|
||||
|
||||
go host.Serve()
|
||||
|
||||
// Construct interrupt handler
|
||||
sig := make(chan os.Signal, 1)
|
||||
|
|
12
host.go
12
host.go
|
@ -14,9 +14,12 @@ import (
|
|||
type Host struct {
|
||||
listener *sshd.SSHListener
|
||||
channel *chat.Channel
|
||||
|
||||
motd string
|
||||
auth *Auth
|
||||
}
|
||||
|
||||
// NewHost creates a Host on top of an existing listener
|
||||
// NewHost creates a Host on top of an existing listener.
|
||||
func NewHost(listener *sshd.SSHListener) *Host {
|
||||
ch := chat.NewChannel()
|
||||
h := Host{
|
||||
|
@ -27,7 +30,12 @@ func NewHost(listener *sshd.SSHListener) *Host {
|
|||
return &h
|
||||
}
|
||||
|
||||
// Connect a specific Terminal to this host and its channel
|
||||
// SetMotd sets the host's message of the day.
|
||||
func (h *Host) SetMotd(motd string) {
|
||||
h.motd = motd
|
||||
}
|
||||
|
||||
// Connect a specific Terminal to this host and its channel.
|
||||
func (h *Host) Connect(term *sshd.Terminal) {
|
||||
name := term.Conn.User()
|
||||
term.AutoCompleteCallback = h.AutoCompleteFunction
|
||||
|
|
29
sshd/auth.go
29
sshd/auth.go
|
@ -9,13 +9,9 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var errBanned = errors.New("banned")
|
||||
var errNotWhitelisted = errors.New("not whitelisted")
|
||||
var errNoInteractive = errors.New("public key authentication required")
|
||||
|
||||
type Auth interface {
|
||||
IsBanned(ssh.PublicKey) bool
|
||||
IsWhitelisted(ssh.PublicKey) bool
|
||||
AllowAnonymous() bool
|
||||
Check(string) (bool, error)
|
||||
}
|
||||
|
||||
func MakeAuth(auth Auth) *ssh.ServerConfig {
|
||||
|
@ -23,21 +19,17 @@ func MakeAuth(auth Auth) *ssh.ServerConfig {
|
|||
NoClientAuth: false,
|
||||
// Auth-related things should be constant-time to avoid timing attacks.
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if auth.IsBanned(key) {
|
||||
return nil, errBanned
|
||||
fingerprint := Fingerprint(key)
|
||||
ok, err := auth.Check(fingerprint)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
if !auth.IsWhitelisted(key) {
|
||||
return nil, errNotWhitelisted
|
||||
}
|
||||
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
|
||||
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
|
||||
return perm, nil
|
||||
},
|
||||
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
if auth.IsBanned(nil) {
|
||||
return nil, errNoInteractive
|
||||
}
|
||||
if !auth.IsWhitelisted(nil) {
|
||||
return nil, errNotWhitelisted
|
||||
if !auth.AllowAnonymous() {
|
||||
return nil, errors.New("public key authentication required")
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
|
@ -51,7 +43,8 @@ func MakeNoAuth() *ssh.ServerConfig {
|
|||
NoClientAuth: false,
|
||||
// Auth-related things should be constant-time to avoid timing attacks.
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
return nil, nil
|
||||
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
|
||||
return perm, nil
|
||||
},
|
||||
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
return nil, nil
|
||||
|
|
|
@ -11,12 +11,12 @@ import (
|
|||
// Extending ssh/terminal to include a closer interface
|
||||
type Terminal struct {
|
||||
terminal.Terminal
|
||||
Conn ssh.Conn
|
||||
Conn *ssh.ServerConn
|
||||
Channel ssh.Channel
|
||||
}
|
||||
|
||||
// Make new terminal from a session channel
|
||||
func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
|
||||
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
|
||||
if ch.ChannelType() != "session" {
|
||||
return nil, errors.New("terminal requires session channel")
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
|
|||
}
|
||||
|
||||
// Find session channel and make a Terminal from it
|
||||
func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
|
||||
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
|
||||
for ch := range channels {
|
||||
if t := ch.ChannelType(); t != "session" {
|
||||
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
||||
|
|
Loading…
Reference in New Issue