ssh-chat/server.go

282 lines
5.9 KiB
Go
Raw Normal View History

2014-12-07 08:31:23 +01:00
package main
import (
2014-12-12 23:50:14 +01:00
"crypto/md5"
2014-12-07 08:31:23 +01:00
"fmt"
"net"
2014-12-12 07:03:32 +01:00
"regexp"
2014-12-10 00:51:24 +01:00
"strings"
"sync"
2014-12-13 08:08:18 +01:00
"syscall"
2014-12-12 23:50:14 +01:00
"time"
2014-12-10 00:51:24 +01:00
"golang.org/x/crypto/ssh"
2014-12-07 08:31:23 +01:00
)
2014-12-12 07:03:32 +01:00
// Limits applied to chat state.
const (
	// MAX_NAME_LENGTH is the longest nickname, in bytes, that proposeName keeps.
	MAX_NAME_LENGTH = 32
	// HISTORY_LEN is the size argument passed to NewHistory for the server's
	// message history.
	HISTORY_LEN = 20
)

// RE_STRIP_TEXT matches every character outside [0-9A-Za-z_]. It is used to
// sanitize nicknames and the client version strings that get logged.
var RE_STRIP_TEXT = regexp.MustCompile("[^0-9A-Za-z_]")
2014-12-12 07:03:32 +01:00
2014-12-11 10:29:18 +01:00
type Clients map[string]*Client
2014-12-07 08:31:23 +01:00
type Server struct {
sshConfig *ssh.ServerConfig
done chan struct{}
2014-12-11 10:29:18 +01:00
clients Clients
2014-12-10 00:51:24 +01:00
lock sync.Mutex
2014-12-11 10:29:18 +01:00
count int
2014-12-12 07:03:32 +01:00
history *History
2014-12-12 23:50:14 +01:00
admins map[string]struct{} // fingerprint lookup
banned map[string]*time.Time // fingerprint lookup
2014-12-07 08:31:23 +01:00
}
func NewServer(privateKey []byte) (*Server, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return nil, err
}
2014-12-12 23:50:14 +01:00
server := Server{
done: make(chan struct{}),
clients: Clients{},
count: 0,
history: NewHistory(HISTORY_LEN),
admins: map[string]struct{}{},
banned: map[string]*time.Time{},
}
2014-12-07 08:31:23 +01:00
config := ssh.ServerConfig{
2014-12-10 00:51:24 +01:00
NoClientAuth: false,
2014-12-12 23:50:14 +01:00
// Auth-related things should be constant-time to avoid timing attacks.
2014-12-10 00:51:24 +01:00
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
2014-12-12 23:50:14 +01:00
fingerprint := Fingerprint(key)
if server.IsBanned(fingerprint) {
return nil, fmt.Errorf("Banned.")
}
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
return perm, nil
2014-12-10 00:51:24 +01:00
},
2014-12-07 08:31:23 +01:00
}
config.AddHostKey(signer)
2014-12-12 23:50:14 +01:00
server.sshConfig = &config
2014-12-07 08:31:23 +01:00
return &server, nil
}
2014-12-12 07:10:06 +01:00
// Len returns the number of clients currently registered with the server.
//
// s.clients is mutated under s.lock by Add/Remove/Rename, so the read is
// guarded as well. sync.Mutex is not reentrant: do not call Len while already
// holding s.lock.
func (s *Server) Len() int {
	s.lock.Lock()
	defer s.lock.Unlock()
	return len(s.clients)
}
2014-12-10 02:01:35 +01:00
func (s *Server) Broadcast(msg string, except *Client) {
logger.Debugf("Broadcast to %d: %s", s.Len(), msg)
2014-12-12 07:03:32 +01:00
s.history.Add(msg)
2014-12-11 10:29:18 +01:00
for _, client := range s.clients {
2014-12-10 04:26:55 +01:00
if except != nil && client == except {
2014-12-10 02:01:35 +01:00
continue
}
2014-12-10 00:51:24 +01:00
client.Msg <- msg
2014-12-07 08:31:23 +01:00
}
}
2014-12-11 10:29:18 +01:00
func (s *Server) Add(client *Client) {
2014-12-12 09:42:02 +01:00
go func() {
client.WriteLines(s.history.Get(10))
client.Write(fmt.Sprintf("-> Welcome to ssh-chat. Enter /help for more."))
2014-12-12 09:42:02 +01:00
}()
2014-12-11 10:29:18 +01:00
s.lock.Lock()
s.count++
2014-12-12 07:03:32 +01:00
newName, err := s.proposeName(client.Name)
if err != nil {
client.Msg <- fmt.Sprintf("-> Your name '%s' is not available, renamed to '%s'. Use /nick <name> to change it.", client.Name, newName)
2014-12-11 10:29:18 +01:00
}
client.Rename(newName)
2014-12-11 10:29:18 +01:00
s.clients[client.Name] = client
num := len(s.clients)
s.lock.Unlock()
s.Broadcast(fmt.Sprintf("* %s joined. (Total connected: %d)", client.Name, num), client)
2014-12-11 10:29:18 +01:00
}
func (s *Server) Remove(client *Client) {
s.lock.Lock()
delete(s.clients, client.Name)
s.lock.Unlock()
s.Broadcast(fmt.Sprintf("* %s left.", client.Name), nil)
2014-12-11 10:29:18 +01:00
}
2014-12-12 07:03:32 +01:00
func (s *Server) proposeName(name string) (string, error) {
// Assumes caller holds lock.
var err error
2014-12-13 07:09:34 +01:00
name = RE_STRIP_TEXT.ReplaceAllString(name, "")
2014-12-12 07:03:32 +01:00
if len(name) > MAX_NAME_LENGTH {
name = name[:MAX_NAME_LENGTH]
} else if len(name) == 0 {
name = fmt.Sprintf("Guest%d", s.count)
}
_, collision := s.clients[name]
if collision {
err = fmt.Errorf("%s is not available.", name)
name = fmt.Sprintf("Guest%d", s.count)
}
return name, err
}
2014-12-11 10:29:18 +01:00
func (s *Server) Rename(client *Client, newName string) {
s.lock.Lock()
2014-12-12 07:03:32 +01:00
newName, err := s.proposeName(newName)
if err != nil {
client.Msg <- fmt.Sprintf("-> %s", err)
2014-12-11 10:29:18 +01:00
s.lock.Unlock()
return
}
2014-12-12 07:03:32 +01:00
2014-12-12 10:15:58 +01:00
// TODO: Use a channel/goroutine for adding clients, rathern than locks?
2014-12-11 10:29:18 +01:00
delete(s.clients, client.Name)
oldName := client.Name
client.Rename(newName)
s.clients[client.Name] = client
s.lock.Unlock()
s.Broadcast(fmt.Sprintf("* %s is now known as %s.", oldName, newName), nil)
2014-12-11 10:29:18 +01:00
}
// List returns the names of all registered clients. If prefix is non-nil,
// only names starting with *prefix are returned. Order is unspecified (map
// iteration order). The result is never nil.
//
// s.clients is mutated concurrently under s.lock, so iteration takes the lock
// too. Do not call List while already holding s.lock.
func (s *Server) List(prefix *string) []string {
	s.lock.Lock()
	defer s.lock.Unlock()

	names := make([]string, 0, len(s.clients))
	for name := range s.clients {
		if prefix != nil && !strings.HasPrefix(name, *prefix) {
			continue
		}
		names = append(names, name)
	}
	return names
}
2014-12-12 07:03:32 +01:00
// Who returns the client registered under exactly name, or nil if there is
// none. The map read is guarded by s.lock because Add/Remove/Rename write the
// map concurrently. Do not call Who while already holding s.lock.
func (s *Server) Who(name string) *Client {
	s.lock.Lock()
	defer s.lock.Unlock()
	return s.clients[name]
}
2014-12-12 23:50:14 +01:00
// Op records fingerprint as an admin key. The grant is logged, and the admins
// table is updated under s.lock.
func (s *Server) Op(fingerprint string) {
	logger.Infof("Adding admin: %s", fingerprint)

	s.lock.Lock()
	defer s.lock.Unlock()
	s.admins[fingerprint] = struct{}{}
}
// IsOp reports whether client's key fingerprint has been granted admin status
// via Op. Op writes s.admins under s.lock, so the lookup is locked as well.
// The fingerprint is computed before taking the lock so that
// client.Fingerprint never runs while the lock is held.
func (s *Server) IsOp(client *Client) bool {
	fingerprint := client.Fingerprint()

	s.lock.Lock()
	defer s.lock.Unlock()
	_, isAdmin := s.admins[fingerprint]
	return isAdmin
}
// IsBanned reports whether fingerprint is currently banned.
//
// A nil expiry is a permanent ban. A temporary ban whose expiry has passed is
// deleted as a side effect and reported as not banned.
//
// This runs from the SSH PublicKeyCallback, i.e. concurrently in the
// per-connection handshake goroutines spawned by Start, while Ban/Unban
// mutate s.banned. The whole check therefore runs under s.lock. The expired
// entry is deleted directly because calling Unban here would deadlock
// (sync.Mutex is not reentrant).
func (s *Server) IsBanned(fingerprint string) bool {
	s.lock.Lock()
	defer s.lock.Unlock()

	until, banned := s.banned[fingerprint]
	switch {
	case !banned:
		return false
	case until == nil:
		return true
	case until.Before(time.Now()):
		delete(s.banned, fingerprint)
		return false
	default:
		return true
	}
}
// Ban prevents the key with the given fingerprint from authenticating. A nil
// duration bans permanently (stored as a nil expiry). Otherwise the ban
// expires duration from now.
func (s *Server) Ban(fingerprint string, duration *time.Duration) {
	s.lock.Lock()
	defer s.lock.Unlock()

	if duration == nil {
		s.banned[fingerprint] = nil
		return
	}
	expiry := time.Now().Add(*duration)
	s.banned[fingerprint] = &expiry
}
// Unban lifts any ban on fingerprint. Unbanning a fingerprint that is not
// banned is a no-op.
func (s *Server) Unban(fingerprint string) {
	s.lock.Lock()
	defer s.lock.Unlock()
	delete(s.banned, fingerprint)
}
2014-12-10 00:51:24 +01:00
func (s *Server) Start(laddr string) error {
2014-12-07 08:31:23 +01:00
// Once a ServerConfig has been configured, connections can be
// accepted.
socket, err := net.Listen("tcp", laddr)
if err != nil {
2014-12-10 00:51:24 +01:00
return err
2014-12-07 08:31:23 +01:00
}
logger.Infof("Listening on %s", laddr)
go func() {
defer socket.Close()
2014-12-07 08:31:23 +01:00
for {
conn, err := socket.Accept()
2014-12-10 00:51:24 +01:00
2014-12-07 08:31:23 +01:00
if err != nil {
2014-12-13 08:08:18 +01:00
logger.Errorf("Failed to accept connection: %v", err)
if err == syscall.EINVAL {
// TODO: Handle shutdown more gracefully?
return
}
2014-12-07 08:31:23 +01:00
}
2014-12-10 00:51:24 +01:00
// Goroutineify to resume accepting sockets early.
go func() {
// From a standard TCP connection to an encrypted SSH connection
sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
logger.Errorf("Failed to handshake: %v", err)
return
}
2014-12-13 07:09:34 +01:00
version := RE_STRIP_TEXT.ReplaceAllString(string(sshConn.ClientVersion()), "")
2014-12-13 04:33:30 +01:00
if len(version) > 100 {
2014-12-13 07:09:34 +01:00
version = "Evil Jerk with a superlong string"
2014-12-13 04:33:30 +01:00
}
logger.Infof("Connection #%d from: %s, %s, %s", s.count+1, sshConn.RemoteAddr(), sshConn.User(), version)
2014-12-10 00:51:24 +01:00
go ssh.DiscardRequests(requests)
2014-12-11 10:29:18 +01:00
client := NewClient(s, sshConn)
2014-12-12 10:15:58 +01:00
go client.handleChannels(channels)
2014-12-10 00:51:24 +01:00
}()
2014-12-07 08:31:23 +01:00
}
}()
2014-12-10 00:51:24 +01:00
go func() {
<-s.done
socket.Close()
}()
return nil
2014-12-07 08:31:23 +01:00
}
2014-12-10 00:51:24 +01:00
func (s *Server) Stop() {
2014-12-11 10:29:18 +01:00
for _, client := range s.clients {
2014-12-10 04:26:55 +01:00
client.Conn.Close()
}
2014-12-10 00:51:24 +01:00
close(s.done)
2014-12-07 08:31:23 +01:00
}
2014-12-12 23:50:14 +01:00
// Fingerprint returns the MD5 digest of k's serialized form (k.Marshal()) as
// 16 colon-separated, two-digit lowercase hex bytes, e.g. "0a:1b:...:ff".
func Fingerprint(k ssh.PublicKey) string {
	sum := md5.Sum(k.Marshal())
	octets := make([]string, len(sum))
	for i, b := range sum {
		octets[i] = fmt.Sprintf("%02x", b)
	}
	return strings.Join(octets, ":")
}