Fixed /kick command to actually close target

This commit is contained in:
Andrey Petrov 2016-07-15 16:22:25 -04:00
parent 0fdeda8b75
commit 9bf1f53445
7 changed files with 59 additions and 44 deletions

View File

@ -3,7 +3,6 @@ package sshchat
import ( import (
"errors" "errors"
"net" "net"
"sync"
"time" "time"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
@ -36,7 +35,6 @@ func newAuthAddr(addr net.Addr) string {
// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface. // Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
type Auth struct { type Auth struct {
sync.RWMutex
bannedAddr *Set bannedAddr *Set
banned *Set banned *Set
whitelist *Set whitelist *Set

View File

@ -48,21 +48,21 @@ func NewMsg(body string) *Msg {
} }
// Render message based on a theme. // Render message based on a theme.
func (m *Msg) Render(t *Theme) string { func (m Msg) Render(t *Theme) string {
// TODO: Render based on theme // TODO: Render based on theme
// TODO: Cache based on theme // TODO: Cache based on theme
return m.String() return m.String()
} }
func (m *Msg) String() string { func (m Msg) String() string {
return m.body return m.body
} }
func (m *Msg) Command() string { func (m Msg) Command() string {
return "" return ""
} }
func (m *Msg) Timestamp() time.Time { func (m Msg) Timestamp() time.Time {
return m.timestamp return m.timestamp
} }
@ -72,8 +72,8 @@ type PublicMsg struct {
from *User from *User
} }
func NewPublicMsg(body string, from *User) *PublicMsg { func NewPublicMsg(body string, from *User) PublicMsg {
return &PublicMsg{ return PublicMsg{
Msg: Msg{ Msg: Msg{
body: body, body: body,
timestamp: time.Now(), timestamp: time.Now(),
@ -82,11 +82,11 @@ func NewPublicMsg(body string, from *User) *PublicMsg {
} }
} }
func (m *PublicMsg) From() *User { func (m PublicMsg) From() *User {
return m.from return m.from
} }
func (m *PublicMsg) ParseCommand() (*CommandMsg, bool) { func (m PublicMsg) ParseCommand() (*CommandMsg, bool) {
// Check if the message is a command // Check if the message is a command
if !strings.HasPrefix(m.body, "/") { if !strings.HasPrefix(m.body, "/") {
return nil, false return nil, false
@ -104,7 +104,7 @@ func (m *PublicMsg) ParseCommand() (*CommandMsg, bool) {
return &msg, true return &msg, true
} }
func (m *PublicMsg) Render(t *Theme) string { func (m PublicMsg) Render(t *Theme) string {
if t == nil { if t == nil {
return m.String() return m.String()
} }
@ -112,7 +112,7 @@ func (m *PublicMsg) Render(t *Theme) string {
return fmt.Sprintf("%s: %s", t.ColorName(m.from), m.body) return fmt.Sprintf("%s: %s", t.ColorName(m.from), m.body)
} }
func (m *PublicMsg) RenderFor(cfg UserConfig) string { func (m PublicMsg) RenderFor(cfg UserConfig) string {
if cfg.Highlight == nil || cfg.Theme == nil { if cfg.Highlight == nil || cfg.Theme == nil {
return m.Render(cfg.Theme) return m.Render(cfg.Theme)
} }
@ -128,7 +128,7 @@ func (m *PublicMsg) RenderFor(cfg UserConfig) string {
return fmt.Sprintf("%s: %s", cfg.Theme.ColorName(m.from), body) return fmt.Sprintf("%s: %s", cfg.Theme.ColorName(m.from), body)
} }
func (m *PublicMsg) String() string { func (m PublicMsg) String() string {
return fmt.Sprintf("%s: %s", m.from.Name(), m.body) return fmt.Sprintf("%s: %s", m.from.Name(), m.body)
} }
@ -164,9 +164,9 @@ type PrivateMsg struct {
to *User to *User
} }
func NewPrivateMsg(body string, from *User, to *User) *PrivateMsg { func NewPrivateMsg(body string, from *User, to *User) PrivateMsg {
return &PrivateMsg{ return PrivateMsg{
PublicMsg: *NewPublicMsg(body, from), PublicMsg: NewPublicMsg(body, from),
to: to, to: to,
} }
} }
@ -242,19 +242,19 @@ func (m *AnnounceMsg) String() string {
} }
type CommandMsg struct { type CommandMsg struct {
*PublicMsg PublicMsg
command string command string
args []string args []string
} }
func (m *CommandMsg) Command() string { func (m CommandMsg) Command() string {
return m.command return m.command
} }
func (m *CommandMsg) Args() []string { func (m CommandMsg) Args() []string {
return m.args return m.args
} }
func (m *CommandMsg) Body() string { func (m CommandMsg) Body() string {
return m.body return m.body
} }

View File

@ -18,14 +18,16 @@ var ErrUserClosed = errors.New("user closed")
// User definition, implemented set Item interface and io.Writer // User definition, implemented set Item interface and io.Writer
type User struct { type User struct {
Identifier Identifier
Config UserConfig Config UserConfig
colorIdx int colorIdx int
joined time.Time joined time.Time
msg chan Message msg chan Message
done chan struct{} done chan struct{}
replyTo *User // Set when user gets a /msg, for replying.
closed bool mu sync.Mutex
closeOnce sync.Once replyTo *User // Set when user gets a /msg, for replying.
screen io.Closer
closed bool
} }
func NewUser(identity Identifier) *User { func NewUser(identity Identifier) *User {
@ -41,8 +43,9 @@ func NewUser(identity Identifier) *User {
return &u return &u
} }
func NewUserScreen(identity Identifier, screen io.Writer) *User { func NewUserScreen(identity Identifier, screen io.WriteCloser) *User {
u := NewUser(identity) u := NewUser(identity)
u.screen = screen
go u.Consume(screen) go u.Consume(screen)
return u return u
@ -82,11 +85,20 @@ func (u *User) Wait() {
// Disconnect user, stop accepting messages // Disconnect user, stop accepting messages
func (u *User) Close() { func (u *User) Close() {
u.closeOnce.Do(func() { u.mu.Lock()
u.closed = true defer u.mu.Unlock()
close(u.done)
close(u.msg) if u.closed {
}) return
}
u.closed = true
close(u.done)
close(u.msg)
if u.screen != nil {
u.screen.Close()
}
} }
// Consume message buffer into an io.Writer. Will block, should be called in a // Consume message buffer into an io.Writer. Will block, should be called in a
@ -136,6 +148,9 @@ func (u *User) HandleMsg(m Message, out io.Writer) {
// Add message to consume by user // Add message to consume by user
func (u *User) Send(m Message) error { func (u *User) Send(m Message) error {
u.mu.Lock()
defer u.mu.Unlock()
if u.closed { if u.closed {
return ErrUserClosed return ErrUserClosed
} }

View File

@ -20,8 +20,8 @@ type identified interface {
// Set with string lookup. // Set with string lookup.
// TODO: Add trie for efficient prefix lookup? // TODO: Add trie for efficient prefix lookup?
type idSet struct { type idSet struct {
lookup map[string]identified
sync.RWMutex sync.RWMutex
lookup map[string]identified
} }
// newIdSet creates a new set. // newIdSet creates a new set.
@ -42,6 +42,8 @@ func (s *idSet) Clear() int {
// Len returns the size of the set right now. // Len returns the size of the set right now.
func (s *idSet) Len() int { func (s *idSet) Len() int {
s.RLock()
defer s.RUnlock()
return len(s.lookup) return len(s.lookup)
} }

12
host.go
View File

@ -90,12 +90,10 @@ func (h *Host) Connect(term *sshd.Terminal) {
id := NewIdentity(term.Conn) id := NewIdentity(term.Conn)
user := message.NewUserScreen(id, term) user := message.NewUserScreen(id, term)
user.Config.Theme = &h.theme user.Config.Theme = &h.theme
go func() {
// Close term once user is closed. // Close term once user is closed.
user.Wait()
term.Close()
}()
defer user.Close() defer user.Close()
defer term.Close()
h.mu.Lock() h.mu.Lock()
motd := h.motd motd := h.motd
@ -285,7 +283,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
} }
m := message.NewPrivateMsg(strings.Join(args[1:], " "), msg.From(), target) m := message.NewPrivateMsg(strings.Join(args[1:], " "), msg.From(), target)
room.Send(m) room.Send(&m)
return nil return nil
}, },
}) })
@ -307,7 +305,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
} }
m := message.NewPrivateMsg(strings.Join(args, " "), msg.From(), target) m := message.NewPrivateMsg(strings.Join(args, " "), msg.From(), target)
room.Send(m) room.Send(&m)
return nil return nil
}, },
}) })

View File

@ -206,7 +206,7 @@ func TestHostKick(t *testing.T) {
<-connected <-connected
// Consume while we're connected. Should break when kicked. // Consume while we're connected. Should break when kicked.
ioutil.ReadAll(r) ioutil.ReadAll(r) // XXX?
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

4
set.go
View File

@ -25,8 +25,8 @@ type setValue interface {
// Set with expire-able keys // Set with expire-able keys
type Set struct { type Set struct {
lookup map[string]setValue
sync.Mutex sync.Mutex
lookup map[string]setValue
} }
// NewSet creates a new set. // NewSet creates a new set.
@ -38,6 +38,8 @@ func NewSet() *Set {
// Len returns the size of the set right now. // Len returns the size of the set right now.
func (s *Set) Len() int { func (s *Set) Len() int {
s.Lock()
defer s.Unlock()
return len(s.lookup) return len(s.lookup)
} }