diff --git a/cmd/jounce/main.go b/cmd/jounce/main.go index e7b19cd..86ce126 100644 --- a/cmd/jounce/main.go +++ b/cmd/jounce/main.go @@ -33,6 +33,11 @@ func main() { cfg.Addr = addr } + db, err := jounce.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + var ln net.Listener if cfg.TLS != nil { cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath) @@ -53,19 +58,16 @@ func main() { } } - srv := jounce.NewServer() + srv := jounce.NewServer(db) // TODO: load from config/DB srv.Hostname = cfg.Hostname srv.Debug = debug - srv.Upstreams = []jounce.Upstream{{ - Addr: "chat.freenode.net:6697", - Nick: "jounce", - Username: "jounce", - Realname: "jounce", - Channels: []string{"#jounce"}, - }} log.Printf("server listening on %q", cfg.Addr) - go srv.Run() + go func() { + if err := srv.Run(); err != nil { + log.Fatal(err) + } + }() log.Fatal(srv.Serve(ln)) } diff --git a/config/config.go b/config/config.go index 87b2cb1..a8e958a 100644 --- a/config/config.go +++ b/config/config.go @@ -14,9 +14,11 @@ type TLS struct { } type Server struct { - Addr string - Hostname string - TLS *TLS + Addr string + Hostname string + TLS *TLS + SQLDriver string + SQLSource string } func Defaults() *Server { @@ -25,8 +27,10 @@ func Defaults() *Server { hostname = "localhost" } return &Server{ - Addr: ":6667", - Hostname: hostname, + Addr: ":6667", + Hostname: hostname, + SQLDriver: "sqlite3", + SQLSource: "jounce.db", } } @@ -64,6 +68,10 @@ func Parse(r io.Reader) (*Server, error) { return nil, err } srv.TLS = tls + case "sql": + if err := d.parseParams(&srv.SQLDriver, &srv.SQLSource); err != nil { + return nil, err + } default: return nil, fmt.Errorf("unknown directive %q", d.Name) } diff --git a/db.go b/db.go new file mode 100644 index 0000000..24c3053 --- /dev/null +++ b/db.go @@ -0,0 +1,134 @@ +package jounce + +import ( + "database/sql" + "errors" + "sync" + + _ "github.com/mattn/go-sqlite3" +) + +var ErrNoSuchUser = errors.New("jounce: no such user") + +type User struct { + Username string + Password string +} + +type Network struct { + ID int64 + Addr string + Nick string + Username string + Realname string +} + +type Channel struct { + ID int64 + Name string +} + +type DB struct { + lock sync.Mutex + db *sql.DB +} + +func OpenSQLDB(driver, source string) (*DB, error) { + db, err := sql.Open(driver, source) + if err != nil { + return nil, err + } + return &DB{db: db}, nil +} + +func (db *DB) Close() error { + db.lock.Lock() + defer db.lock.Unlock() + return db.Close() +} + +func (db *DB) ListUsers() ([]User, error) { + db.lock.Lock() + defer db.lock.Unlock() + + rows, err := db.db.Query("SELECT username, password FROM User") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password *string + if err := rows.Scan(&user.Username, &password); err != nil { + return nil, err + } + if password != nil { + user.Password = *password + } + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *DB) ListNetworks(username string) ([]Network, error) { + db.lock.Lock() + defer db.lock.Unlock() + + rows, err := db.db.Query("SELECT id, addr, nick, username, realname FROM Network WHERE user = ?", username) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var username, realname *string + if err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname); err != nil { + return nil, err + } + if username != nil { + net.Username = *username + } + if realname != nil { + net.Realname = *realname + } + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *DB) ListChannels(networkID int64) ([]Channel, error) { + db.lock.Lock() + defer db.lock.Unlock() + + rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + if err := rows.Scan(&ch.ID, &ch.Name); err != nil { + return nil, err + } + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} diff --git a/downstream.go b/downstream.go index 629e9f4..41636df 100644 --- a/downstream.go +++ b/downstream.go @@ -58,7 +58,7 @@ type downstreamConn struct { nick string username string realname string - upstream *Upstream + network *network // can be nil } func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { @@ -100,7 +100,7 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string { func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { dc.user.forEachUpstream(func(uc *upstreamConn) { - if dc.upstream != nil && uc.upstream != dc.upstream { + if dc.network != nil && uc.network != dc.network { return } f(uc) @@ -301,9 +301,9 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { func (dc *downstreamConn) register() error { username := strings.TrimPrefix(dc.username, "~") - var upstreamName string + var networkName string if i := strings.LastIndexAny(username, "/@"); i >= 0 { - upstreamName = username[i+1:] + networkName = username[i+1:] } if i := strings.IndexAny(username, "/@"); i >= 0 { username = username[:i] @@ -320,14 +320,14 @@ func (dc *downstreamConn) register() error { return nil } - if upstreamName != "" { - dc.upstream = dc.user.getUpstream(upstreamName) - if dc.upstream == nil { - dc.logger.Printf("failed registration: unknown upstream %q", upstreamName) + if networkName != "" { + dc.network = dc.user.getNetwork(networkName) + if dc.network == nil { + dc.logger.Printf("failed registration: unknown network %q", networkName) dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_PASSWDMISMATCH, - Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", upstreamName)}, + Params: []string{"*", fmt.Sprintf("Unknown network %q", networkName)}, }) return nil } diff --git a/go.mod b/go.mod index ffe4e12..1f534b3 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module git.sr.ht/~emersion/jounce go 1.13 -require gopkg.in/irc.v3 v3.1.0 +require ( + github.com/mattn/go-sqlite3 v2.0.3+incompatible + gopkg.in/irc.v3 v3.1.0 +) diff --git a/go.sum b/go.sum index 9d45687..a109788 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,5 @@ +github.com/mattn/go-sqlite3 v1.13.0 h1:LnJI81JidiW9r7pS/hXe6cFeO5EXNq7KbfvoJLRI69c= +github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= +github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= gopkg.in/irc.v3 v3.1.0 h1:AeDaEhQ/78gHfpbj/3mSi8FfiNIsFiVrWEgLzOwHWnU= gopkg.in/irc.v3 v3.1.0/go.mod h1:qE0DWv0j8Z8wCbFhA9783JBO0bufi3rttcV1Sjin8io= diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..ce26d02 --- /dev/null +++ b/schema.sql @@ -0,0 +1,21 @@ +CREATE TABLE User ( + username VARCHAR(255) PRIMARY KEY, + password VARCHAR(255) +); + +CREATE TABLE Network ( + id INTEGER PRIMARY KEY, + user VARCHAR(255) NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + FOREIGN KEY(user) REFERENCES User(username) +); + +CREATE TABLE Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name VARCHAR(255) NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id) +); diff --git a/server.go b/server.go index bfcb5b9..a095ecf 100644 --- a/server.go +++ b/server.go @@ -47,26 +47,73 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) { l.logger.Printf("%v"+format, v...) } +type network struct { + Network + user *user + conn *upstreamConn +} + +func newNetwork(user *user, record *Network) *network { + return &network{ + Network: *record, + user: user, + } +} + +func (net *network) run() { + var lastTry time.Time + for { + if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay { + delay := retryConnectMinDelay - dur + net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr) + time.Sleep(delay) + } + lastTry = time.Now() + + uc, err := connectToUpstream(net) + if err != nil { + net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) + continue + } + + uc.register() + + net.user.lock.Lock() + net.conn = uc + net.user.lock.Unlock() + + if err := uc.readMessages(); err != nil { + uc.logger.Printf("failed to handle messages: %v", err) + } + uc.Close() + + net.user.lock.Lock() + net.conn = nil + net.user.lock.Unlock() + } +} + type user struct { - username string + User srv *Server lock sync.Mutex - upstreamConns []*upstreamConn + networks []*network downstreamConns []*downstreamConn } -func newUser(srv *Server, username string) *user { +func newUser(srv *Server, record *User) *user { return &user{ - username: username, - srv: srv, + User: *record, + srv: srv, } } func (u *user) forEachUpstream(f func(uc *upstreamConn)) { u.lock.Lock() - for _, uc := range u.upstreamConns { - if !uc.registered || uc.closed { + for _, network := range u.networks { + uc := network.conn + if uc == nil || !uc.registered || uc.closed { continue } f(uc) @@ -82,21 +129,30 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) { u.lock.Unlock() } -func (u *user) getUpstream(name string) *Upstream { - for i, upstream := range u.srv.Upstreams { - if upstream.Addr == name { - return &u.srv.Upstreams[i] +func (u *user) getNetwork(name string) *network { + for _, network := range u.networks { + if network.Addr == name { + return network } } return nil } -type Upstream struct { - Addr string - Nick string - Username string - Realname string - Channels []string +func (u *user) run() { + networks, err := u.srv.db.ListNetworks(u.Username) + if err != nil { + u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err) + return + } + + u.lock.Lock() + for _, record := range networks { + network := newNetwork(u, &record) + u.networks = append(u.networks, network) + + go network.run() + } + u.lock.Unlock() } type Server struct { @@ -104,18 +160,20 @@ type Server struct { Logger Logger RingCap int Debug bool - Upstreams []Upstream // TODO: per-user + + db *DB lock sync.Mutex users map[string]*user downstreamConns []*downstreamConn } -func NewServer() *Server { +func NewServer(db *DB) *Server { return &Server{ Logger: log.New(log.Writer(), "", log.LstdFlags), RingCap: 4096, users: make(map[string]*user), + db: db, } } @@ -123,55 +181,23 @@ func (s *Server) prefix() *irc.Prefix { return &irc.Prefix{Name: s.Hostname} } -func (s *Server) runUpstream(u *user, upstream *Upstream) { - var lastTry time.Time - for { - if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay { - delay := retryConnectMinDelay - dur - s.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), upstream.Addr) - time.Sleep(delay) - } - lastTry = time.Now() - - uc, err := connectToUpstream(u, upstream) - if err != nil { - s.Logger.Printf("failed to connect to upstream server %q: %v", upstream.Addr, err) - continue - } - - uc.register() - - u.lock.Lock() - u.upstreamConns = append(u.upstreamConns, uc) - u.lock.Unlock() - - if err := uc.readMessages(); err != nil { - uc.logger.Printf("failed to handle messages: %v", err) - } - uc.Close() - - u.lock.Lock() - for i := range u.upstreamConns { - if u.upstreamConns[i] == uc { - u.upstreamConns = append(u.upstreamConns[:i], u.upstreamConns[i+1:]...) - break - } - } - u.lock.Unlock() +func (s *Server) Run() error { + users, err := s.db.ListUsers() + if err != nil { + return err } -} - -func (s *Server) Run() { - // TODO: multi-user - u := newUser(s, "jounce") s.lock.Lock() - s.users[u.username] = u + for _, record := range users { + s.Logger.Printf("starting bouncer for user %q", record.Username) + u := newUser(s, &record) + s.users[u.Username] = u + + go u.run() + } s.lock.Unlock() - for i := range s.Upstreams { - go s.runUpstream(u, &s.Upstreams[i]) - } + select {} } func (s *Server) getUser(name string) *user { diff --git a/upstream.go b/upstream.go index c627ad8..e07e4c9 100644 --- a/upstream.go +++ b/upstream.go @@ -25,7 +25,7 @@ type upstreamChannel struct { } type upstreamConn struct { - upstream *Upstream + network *network logger Logger net net.Conn irc *irc.Conn @@ -41,33 +41,40 @@ type upstreamConn struct { registered bool nick string + username string + realname string closed bool modes modeSet channels map[string]*upstreamChannel history map[string]uint64 } -func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) { - logger := &prefixLogger{u.srv.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)} - logger.Printf("connecting to server") +func connectToUpstream(network *network) (*upstreamConn, error) { + logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)} - netConn, err := tls.Dial("tcp", upstream.Addr, nil) + addr := network.Addr + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + logger.Printf("connecting to TLS server at address %q", addr) + netConn, err := tls.Dial("tcp", addr, nil) if err != nil { - return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err) + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) } setKeepAlive(netConn) msgs := make(chan *irc.Message, 64) uc := &upstreamConn{ - upstream: upstream, + network: network, logger: logger, net: netConn, irc: irc.NewConn(netConn), - srv: u.srv, - user: u, + srv: network.user.srv, + user: network.user, messages: msgs, - ring: NewRing(u.srv.RingCap), + ring: NewRing(network.user.srv.RingCap), channels: make(map[string]*upstreamChannel), history: make(map[string]uint64), } @@ -102,7 +109,7 @@ func (uc *upstreamConn) Close() error { func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { uc.user.forEachDownstream(func(dc *downstreamConn) { - if dc.upstream != nil && dc.upstream != uc.upstream { + if dc.network != nil && dc.network != uc.network { return } f(dc) @@ -163,10 +170,16 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { uc.registered = true uc.logger.Printf("connection registered") - for _, ch := range uc.upstream.Channels { + channels, err := uc.srv.db.ListChannels(uc.network.ID) + if err != nil { + uc.logger.Printf("failed to list channels from database: %v", err) + break + } + + for _, ch := range channels { uc.SendMessage(&irc.Message{ Command: "JOIN", - Params: []string{ch}, + Params: []string{ch.Name}, }) } case irc.RPL_MYINFO: @@ -371,14 +384,23 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } func (uc *upstreamConn) register() { - uc.nick = uc.upstream.Nick + uc.nick = uc.network.Nick + uc.username = uc.network.Username + if uc.username == "" { + uc.username = uc.nick + } + uc.realname = uc.network.Realname + if uc.realname == "" { + uc.realname = uc.nick + } + uc.SendMessage(&irc.Message{ Command: "NICK", Params: []string{uc.nick}, }) uc.SendMessage(&irc.Message{ Command: "USER", - Params: []string{uc.upstream.Username, "0", "*", uc.upstream.Realname}, + Params: []string{uc.username, "0", "*", uc.realname}, }) }