Add SQLite database

Closes: https://todo.sr.ht/~emersion/jounce/9
This commit is contained in:
Simon Ser 2020-03-04 18:22:58 +01:00
parent eacd4e6f0c
commit 84fe3ae255
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
9 changed files with 321 additions and 102 deletions

View File

@ -33,6 +33,11 @@ func main() {
cfg.Addr = addr cfg.Addr = addr
} }
db, err := jounce.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
var ln net.Listener var ln net.Listener
if cfg.TLS != nil { if cfg.TLS != nil {
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath) cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
@ -53,19 +58,16 @@ func main() {
} }
} }
srv := jounce.NewServer() srv := jounce.NewServer(db)
// TODO: load from config/DB // TODO: load from config/DB
srv.Hostname = cfg.Hostname srv.Hostname = cfg.Hostname
srv.Debug = debug srv.Debug = debug
srv.Upstreams = []jounce.Upstream{{
Addr: "chat.freenode.net:6697",
Nick: "jounce",
Username: "jounce",
Realname: "jounce",
Channels: []string{"#jounce"},
}}
log.Printf("server listening on %q", cfg.Addr) log.Printf("server listening on %q", cfg.Addr)
go srv.Run() go func() {
if err := srv.Run(); err != nil {
log.Fatal(err)
}
}()
log.Fatal(srv.Serve(ln)) log.Fatal(srv.Serve(ln))
} }

View File

@ -17,6 +17,8 @@ type Server struct {
Addr string Addr string
Hostname string Hostname string
TLS *TLS TLS *TLS
SQLDriver string
SQLSource string
} }
func Defaults() *Server { func Defaults() *Server {
@ -27,6 +29,8 @@ func Defaults() *Server {
return &Server{ return &Server{
Addr: ":6667", Addr: ":6667",
Hostname: hostname, Hostname: hostname,
SQLDriver: "sqlite3",
SQLSource: "jounce.db",
} }
} }
@ -64,6 +68,10 @@ func Parse(r io.Reader) (*Server, error) {
return nil, err return nil, err
} }
srv.TLS = tls srv.TLS = tls
case "sql":
if err := d.parseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
return nil, err
}
default: default:
return nil, fmt.Errorf("unknown directive %q", d.Name) return nil, fmt.Errorf("unknown directive %q", d.Name)
} }

134
db.go Normal file
View File

@ -0,0 +1,134 @@
package jounce
import (
"database/sql"
"errors"
"sync"
_ "github.com/mattn/go-sqlite3"
)
var ErrNoSuchUser = errors.New("jounce: no such user")
type User struct {
Username string
Password string
}
type Network struct {
ID int64
Addr string
Nick string
Username string
Realname string
}
type Channel struct {
ID int64
Name string
}
type DB struct {
lock sync.Mutex
db *sql.DB
}
func OpenSQLDB(driver, source string) (*DB, error) {
db, err := sql.Open(driver, source)
if err != nil {
return nil, err
}
return &DB{db: db}, nil
}
func (db *DB) Close() error {
db.lock.Lock()
defer db.lock.Unlock()
return db.Close()
}
func (db *DB) ListUsers() ([]User, error) {
db.lock.Lock()
defer db.lock.Unlock()
rows, err := db.db.Query("SELECT username, password FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var users []User
for rows.Next() {
var user User
var password *string
if err := rows.Scan(&user.Username, &password); err != nil {
return nil, err
}
if password != nil {
user.Password = *password
}
users = append(users, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return users, nil
}
func (db *DB) ListNetworks(username string) ([]Network, error) {
db.lock.Lock()
defer db.lock.Unlock()
rows, err := db.db.Query("SELECT id, addr, nick, username, realname FROM Network WHERE user = ?", username)
if err != nil {
return nil, err
}
defer rows.Close()
var networks []Network
for rows.Next() {
var net Network
var username, realname *string
if err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname); err != nil {
return nil, err
}
if username != nil {
net.Username = *username
}
if realname != nil {
net.Realname = *realname
}
networks = append(networks, net)
}
if err := rows.Err(); err != nil {
return nil, err
}
return networks, nil
}
func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
db.lock.Lock()
defer db.lock.Unlock()
rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID)
if err != nil {
return nil, err
}
defer rows.Close()
var channels []Channel
for rows.Next() {
var ch Channel
if err := rows.Scan(&ch.ID, &ch.Name); err != nil {
return nil, err
}
channels = append(channels, ch)
}
if err := rows.Err(); err != nil {
return nil, err
}
return channels, nil
}

View File

@ -58,7 +58,7 @@ type downstreamConn struct {
nick string nick string
username string username string
realname string realname string
upstream *Upstream network *network // can be nil
} }
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
@ -100,7 +100,7 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
dc.user.forEachUpstream(func(uc *upstreamConn) { dc.user.forEachUpstream(func(uc *upstreamConn) {
if dc.upstream != nil && uc.upstream != dc.upstream { if dc.network != nil && uc.network != dc.network {
return return
} }
f(uc) f(uc)
@ -301,9 +301,9 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
func (dc *downstreamConn) register() error { func (dc *downstreamConn) register() error {
username := strings.TrimPrefix(dc.username, "~") username := strings.TrimPrefix(dc.username, "~")
var upstreamName string var networkName string
if i := strings.LastIndexAny(username, "/@"); i >= 0 { if i := strings.LastIndexAny(username, "/@"); i >= 0 {
upstreamName = username[i+1:] networkName = username[i+1:]
} }
if i := strings.IndexAny(username, "/@"); i >= 0 { if i := strings.IndexAny(username, "/@"); i >= 0 {
username = username[:i] username = username[:i]
@ -320,14 +320,14 @@ func (dc *downstreamConn) register() error {
return nil return nil
} }
if upstreamName != "" { if networkName != "" {
dc.upstream = dc.user.getUpstream(upstreamName) dc.network = dc.user.getNetwork(networkName)
if dc.upstream == nil { if dc.network == nil {
dc.logger.Printf("failed registration: unknown upstream %q", upstreamName) dc.logger.Printf("failed registration: unknown network %q", networkName)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_PASSWDMISMATCH, Command: irc.ERR_PASSWDMISMATCH,
Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", upstreamName)}, Params: []string{"*", fmt.Sprintf("Unknown network %q", networkName)},
}) })
return nil return nil
} }

5
go.mod
View File

@ -2,4 +2,7 @@ module git.sr.ht/~emersion/jounce
go 1.13 go 1.13
require gopkg.in/irc.v3 v3.1.0 require (
github.com/mattn/go-sqlite3 v2.0.3+incompatible
gopkg.in/irc.v3 v3.1.0
)

3
go.sum
View File

@ -1,2 +1,5 @@
github.com/mattn/go-sqlite3 v1.13.0 h1:LnJI81JidiW9r7pS/hXe6cFeO5EXNq7KbfvoJLRI69c=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
gopkg.in/irc.v3 v3.1.0 h1:AeDaEhQ/78gHfpbj/3mSi8FfiNIsFiVrWEgLzOwHWnU= gopkg.in/irc.v3 v3.1.0 h1:AeDaEhQ/78gHfpbj/3mSi8FfiNIsFiVrWEgLzOwHWnU=
gopkg.in/irc.v3 v3.1.0/go.mod h1:qE0DWv0j8Z8wCbFhA9783JBO0bufi3rttcV1Sjin8io= gopkg.in/irc.v3 v3.1.0/go.mod h1:qE0DWv0j8Z8wCbFhA9783JBO0bufi3rttcV1Sjin8io=

21
schema.sql Normal file
View File

@ -0,0 +1,21 @@
CREATE TABLE User (
username VARCHAR(255) PRIMARY KEY,
password VARCHAR(255)
);
CREATE TABLE Network (
id INTEGER PRIMARY KEY,
user VARCHAR(255) NOT NULL,
addr VARCHAR(255) NOT NULL,
nick VARCHAR(255) NOT NULL,
username VARCHAR(255),
realname VARCHAR(255),
FOREIGN KEY(user) REFERENCES User(username)
);
CREATE TABLE Channel (
id INTEGER PRIMARY KEY,
network INTEGER NOT NULL,
name VARCHAR(255) NOT NULL,
FOREIGN KEY(network) REFERENCES Network(id)
);

148
server.go
View File

@ -47,26 +47,73 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) {
l.logger.Printf("%v"+format, v...) l.logger.Printf("%v"+format, v...)
} }
type network struct {
Network
user *user
conn *upstreamConn
}
func newNetwork(user *user, record *Network) *network {
return &network{
Network: *record,
user: user,
}
}
func (net *network) run() {
var lastTry time.Time
for {
if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
delay := retryConnectMinDelay - dur
net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
time.Sleep(delay)
}
lastTry = time.Now()
uc, err := connectToUpstream(net)
if err != nil {
net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
continue
}
uc.register()
net.user.lock.Lock()
net.conn = uc
net.user.lock.Unlock()
if err := uc.readMessages(); err != nil {
uc.logger.Printf("failed to handle messages: %v", err)
}
uc.Close()
net.user.lock.Lock()
net.conn = nil
net.user.lock.Unlock()
}
}
type user struct { type user struct {
username string User
srv *Server srv *Server
lock sync.Mutex lock sync.Mutex
upstreamConns []*upstreamConn networks []*network
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
} }
func newUser(srv *Server, username string) *user { func newUser(srv *Server, record *User) *user {
return &user{ return &user{
username: username, User: *record,
srv: srv, srv: srv,
} }
} }
func (u *user) forEachUpstream(f func(uc *upstreamConn)) { func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
u.lock.Lock() u.lock.Lock()
for _, uc := range u.upstreamConns { for _, network := range u.networks {
if !uc.registered || uc.closed { uc := network.conn
if uc == nil || !uc.registered || uc.closed {
continue continue
} }
f(uc) f(uc)
@ -82,21 +129,30 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
u.lock.Unlock() u.lock.Unlock()
} }
func (u *user) getUpstream(name string) *Upstream { func (u *user) getNetwork(name string) *network {
for i, upstream := range u.srv.Upstreams { for _, network := range u.networks {
if upstream.Addr == name { if network.Addr == name {
return &u.srv.Upstreams[i] return network
} }
} }
return nil return nil
} }
type Upstream struct { func (u *user) run() {
Addr string networks, err := u.srv.db.ListNetworks(u.Username)
Nick string if err != nil {
Username string u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
Realname string return
Channels []string }
u.lock.Lock()
for _, record := range networks {
network := newNetwork(u, &record)
u.networks = append(u.networks, network)
go network.run()
}
u.lock.Unlock()
} }
type Server struct { type Server struct {
@ -104,18 +160,20 @@ type Server struct {
Logger Logger Logger Logger
RingCap int RingCap int
Debug bool Debug bool
Upstreams []Upstream // TODO: per-user
db *DB
lock sync.Mutex lock sync.Mutex
users map[string]*user users map[string]*user
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
} }
func NewServer() *Server { func NewServer(db *DB) *Server {
return &Server{ return &Server{
Logger: log.New(log.Writer(), "", log.LstdFlags), Logger: log.New(log.Writer(), "", log.LstdFlags),
RingCap: 4096, RingCap: 4096,
users: make(map[string]*user), users: make(map[string]*user),
db: db,
} }
} }
@ -123,55 +181,23 @@ func (s *Server) prefix() *irc.Prefix {
return &irc.Prefix{Name: s.Hostname} return &irc.Prefix{Name: s.Hostname}
} }
func (s *Server) runUpstream(u *user, upstream *Upstream) { func (s *Server) Run() error {
var lastTry time.Time users, err := s.db.ListUsers()
for {
if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
delay := retryConnectMinDelay - dur
s.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), upstream.Addr)
time.Sleep(delay)
}
lastTry = time.Now()
uc, err := connectToUpstream(u, upstream)
if err != nil { if err != nil {
s.Logger.Printf("failed to connect to upstream server %q: %v", upstream.Addr, err) return err
continue
} }
uc.register()
u.lock.Lock()
u.upstreamConns = append(u.upstreamConns, uc)
u.lock.Unlock()
if err := uc.readMessages(); err != nil {
uc.logger.Printf("failed to handle messages: %v", err)
}
uc.Close()
u.lock.Lock()
for i := range u.upstreamConns {
if u.upstreamConns[i] == uc {
u.upstreamConns = append(u.upstreamConns[:i], u.upstreamConns[i+1:]...)
break
}
}
u.lock.Unlock()
}
}
func (s *Server) Run() {
// TODO: multi-user
u := newUser(s, "jounce")
s.lock.Lock() s.lock.Lock()
s.users[u.username] = u for _, record := range users {
s.Logger.Printf("starting bouncer for user %q", record.Username)
u := newUser(s, &record)
s.users[u.Username] = u
go u.run()
}
s.lock.Unlock() s.lock.Unlock()
for i := range s.Upstreams { select {}
go s.runUpstream(u, &s.Upstreams[i])
}
} }
func (s *Server) getUser(name string) *user { func (s *Server) getUser(name string) *user {

View File

@ -25,7 +25,7 @@ type upstreamChannel struct {
} }
type upstreamConn struct { type upstreamConn struct {
upstream *Upstream network *network
logger Logger logger Logger
net net.Conn net net.Conn
irc *irc.Conn irc *irc.Conn
@ -41,33 +41,40 @@ type upstreamConn struct {
registered bool registered bool
nick string nick string
username string
realname string
closed bool closed bool
modes modeSet modes modeSet
channels map[string]*upstreamChannel channels map[string]*upstreamChannel
history map[string]uint64 history map[string]uint64
} }
func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) { func connectToUpstream(network *network) (*upstreamConn, error) {
logger := &prefixLogger{u.srv.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)} logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
logger.Printf("connecting to server")
netConn, err := tls.Dial("tcp", upstream.Addr, nil) addr := network.Addr
if !strings.ContainsRune(addr, ':') {
addr = addr + ":6697"
}
logger.Printf("connecting to TLS server at address %q", addr)
netConn, err := tls.Dial("tcp", addr, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err) return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
} }
setKeepAlive(netConn) setKeepAlive(netConn)
msgs := make(chan *irc.Message, 64) msgs := make(chan *irc.Message, 64)
uc := &upstreamConn{ uc := &upstreamConn{
upstream: upstream, network: network,
logger: logger, logger: logger,
net: netConn, net: netConn,
irc: irc.NewConn(netConn), irc: irc.NewConn(netConn),
srv: u.srv, srv: network.user.srv,
user: u, user: network.user,
messages: msgs, messages: msgs,
ring: NewRing(u.srv.RingCap), ring: NewRing(network.user.srv.RingCap),
channels: make(map[string]*upstreamChannel), channels: make(map[string]*upstreamChannel),
history: make(map[string]uint64), history: make(map[string]uint64),
} }
@ -102,7 +109,7 @@ func (uc *upstreamConn) Close() error {
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
uc.user.forEachDownstream(func(dc *downstreamConn) { uc.user.forEachDownstream(func(dc *downstreamConn) {
if dc.upstream != nil && dc.upstream != uc.upstream { if dc.network != nil && dc.network != uc.network {
return return
} }
f(dc) f(dc)
@ -163,10 +170,16 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.registered = true uc.registered = true
uc.logger.Printf("connection registered") uc.logger.Printf("connection registered")
for _, ch := range uc.upstream.Channels { channels, err := uc.srv.db.ListChannels(uc.network.ID)
if err != nil {
uc.logger.Printf("failed to list channels from database: %v", err)
break
}
for _, ch := range channels {
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "JOIN", Command: "JOIN",
Params: []string{ch}, Params: []string{ch.Name},
}) })
} }
case irc.RPL_MYINFO: case irc.RPL_MYINFO:
@ -371,14 +384,23 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
func (uc *upstreamConn) register() { func (uc *upstreamConn) register() {
uc.nick = uc.upstream.Nick uc.nick = uc.network.Nick
uc.username = uc.network.Username
if uc.username == "" {
uc.username = uc.nick
}
uc.realname = uc.network.Realname
if uc.realname == "" {
uc.realname = uc.nick
}
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "NICK", Command: "NICK",
Params: []string{uc.nick}, Params: []string{uc.nick},
}) })
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "USER", Command: "USER",
Params: []string{uc.upstream.Username, "0", "*", uc.upstream.Realname}, Params: []string{uc.username, "0", "*", uc.realname},
}) })
} }