Allow most config options to be reloaded

Closes: https://todo.sr.ht/~emersion/soju/42
This commit is contained in:
Simon Ser 2021-11-16 00:38:04 +01:00
parent e44f4b2eee
commit 73295e4fa7
7 changed files with 111 additions and 98 deletions

View File

@ -37,19 +37,6 @@ func (v *stringSliceFlag) Set(s string) error {
return nil return nil
} }
func loadMOTD(srv *soju.Server, filename string) error {
if filename == "" {
return nil
}
b, err := ioutil.ReadFile(filename)
if err != nil {
return err
}
srv.SetMOTD(strings.TrimSuffix(string(b), "\n"))
return nil
}
func bumpOpenedFileLimit() error { func bumpOpenedFileLimit() error {
var rlimit syscall.Rlimit var rlimit syscall.Rlimit
if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
@ -62,24 +49,65 @@ func bumpOpenedFileLimit() error {
return nil return nil
} }
var (
configPath string
debug bool
tlsCert atomic.Value // *tls.Certificate
)
func loadConfig() (*config.Server, *soju.Config, error) {
var raw *config.Server
if configPath != "" {
var err error
raw, err = config.Load(configPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load config file: %v", err)
}
} else {
raw = config.Defaults()
}
var motd string
if raw.MOTDPath != "" {
b, err := ioutil.ReadFile(raw.MOTDPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
}
motd = strings.TrimSuffix(string(b), "\n")
}
if raw.TLS != nil {
cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
}
tlsCert.Store(&cert)
}
cfg := &soju.Config{
Hostname: raw.Hostname,
Title: raw.Title,
LogPath: raw.LogPath,
HTTPOrigins: raw.HTTPOrigins,
AcceptProxyIPs: raw.AcceptProxyIPs,
MaxUserNetworks: raw.MaxUserNetworks,
Debug: debug,
MOTD: motd,
}
return raw, cfg, nil
}
func main() { func main() {
var listen []string var listen []string
var configPath string
var debug bool
flag.Var((*stringSliceFlag)(&listen), "listen", "listening address") flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
flag.StringVar(&configPath, "config", "", "path to configuration file") flag.StringVar(&configPath, "config", "", "path to configuration file")
flag.BoolVar(&debug, "debug", false, "enable debug logging") flag.BoolVar(&debug, "debug", false, "enable debug logging")
flag.Parse() flag.Parse()
var cfg *config.Server cfg, serverCfg, err := loadConfig()
if configPath != "" { if err != nil {
var err error log.Fatal(err)
cfg, err = config.Load(configPath)
if err != nil {
log.Fatalf("failed to load config file: %v", err)
}
} else {
cfg = config.Defaults()
} }
cfg.Listen = append(cfg.Listen, listen...) cfg.Listen = append(cfg.Listen, listen...)
@ -97,14 +125,7 @@ func main() {
} }
var tlsCfg *tls.Config var tlsCfg *tls.Config
var tlsCert atomic.Value
if cfg.TLS != nil { if cfg.TLS != nil {
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
if err != nil {
log.Fatalf("failed to load TLS certificate and key: %v", err)
}
tlsCert.Store(&cert)
tlsCfg = &tls.Config{ tlsCfg = &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return tlsCert.Load().(*tls.Certificate), nil return tlsCert.Load().(*tls.Certificate), nil
@ -113,17 +134,7 @@ func main() {
} }
srv := soju.NewServer(db) srv := soju.NewServer(db)
srv.Hostname = cfg.Hostname srv.SetConfig(serverCfg)
srv.Title = cfg.Title
srv.LogPath = cfg.LogPath
srv.HTTPOrigins = cfg.HTTPOrigins
srv.AcceptProxyIPs = cfg.AcceptProxyIPs
srv.MaxUserNetworks = cfg.MaxUserNetworks
srv.Debug = debug
if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
log.Fatalf("failed to load MOTD: %v", err)
}
for _, listen := range cfg.Listen { for _, listen := range cfg.Listen {
listenURI := listen listenURI := listen
@ -258,17 +269,12 @@ func main() {
for sig := range sigCh { for sig := range sigCh {
switch sig { switch sig {
case syscall.SIGHUP: case syscall.SIGHUP:
log.Print("reloading TLS certificate and MOTD") log.Print("reloading configuration")
if cfg.TLS != nil { _, serverCfg, err := loadConfig()
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath) if err != nil {
if err != nil { log.Printf("failed to reloading configuration: %v", err)
log.Printf("failed to reload TLS certificate and key: %v", err) } else {
break srv.SetConfig(serverCfg)
}
tlsCert.Store(&cert)
}
if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
log.Printf("failed to reload MOTD: %v", err)
} }
case syscall.SIGINT, syscall.SIGTERM: case syscall.SIGINT, syscall.SIGTERM:
log.Print("shutting down server") log.Print("shutting down server")
@ -286,7 +292,7 @@ func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
if !ok { if !ok {
return proxyproto.IGNORE, nil return proxyproto.IGNORE, nil
} }
if srv.AcceptProxyIPs.Contains(tcpAddr.IP) { if srv.Config().AcceptProxyIPs.Contains(tcpAddr.IP) {
return proxyproto.USE, nil return proxyproto.USE, nil
} }
return proxyproto.IGNORE, nil return proxyproto.IGNORE, nil

View File

@ -195,7 +195,7 @@ func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
<-rl.C <-rl.C
} }
if c.srv.Debug { if c.srv.Config().Debug {
c.logger.Printf("sent: %v", msg) c.logger.Printf("sent: %v", msg)
} }
c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
@ -248,7 +248,7 @@ func (c *conn) ReadMessage() (*irc.Message, error) {
return nil, err return nil, err
} }
if c.srv.Debug { if c.srv.Config().Debug {
c.logger.Printf("received: %v", msg) c.logger.Printf("received: %v", msg)
} }

View File

@ -44,8 +44,9 @@ soju supports two connection modes:
For per-client history to work, clients need to indicate their name. This can For per-client history to work, clients need to indicate their name. This can
be done by adding a "@<client>" suffix to the username. be done by adding a "@<client>" suffix to the username.
soju will reload the TLS certificate/key and the MOTD file when it receives the soju will reload the configuration file, the TLS certificate/key and the MOTD
HUP signal. file when it receives the HUP signal. The configuration options _listen_, _db_
and _log_ cannot be reloaded.
Administrators can broadcast a message to all bouncer users via _/notice Administrators can broadcast a message to all bouncer users via _/notice
$<hostname> <text>_, or via _/notice $\* <text>_ in multi-upstream mode. All $<hostname> <text>_, or via _/notice $\* <text>_ in multi-upstream mode. All

View File

@ -290,7 +290,10 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
for k, v := range permanentDownstreamCaps { for k, v := range permanentDownstreamCaps {
dc.supportedCaps[k] = v dc.supportedCaps[k] = v
} }
if srv.LogPath != "" { // TODO: this is racy, we should only enable chathistory after
// authentication and then check that user.msgStore implements
// chatHistoryMessageStore
if srv.Config().LogPath != "" {
dc.supportedCaps["draft/chathistory"] = "" dc.supportedCaps["draft/chathistory"] = ""
} }
return dc return dc
@ -996,7 +999,7 @@ func (dc *downstreamConn) updateSupportedCaps() {
} }
} }
if dc.srv.LogPath != "" && dc.network != nil { if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil {
dc.setSupportedCap("draft/event-playback", "") dc.setSupportedCap("draft/event-playback", "")
} else { } else {
dc.unsetSupportedCap("draft/event-playback") dc.unsetSupportedCap("draft/event-playback")
@ -1175,8 +1178,8 @@ func (dc *downstreamConn) welcome() error {
if dc.network != nil { if dc.network != nil {
isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID)) isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID))
} }
if dc.network == nil && dc.srv.Title != "" { if title := dc.srv.Config().Title; dc.network == nil && title != "" {
isupport = append(isupport, "NETWORK="+encodeISUPPORT(dc.srv.Title)) isupport = append(isupport, "NETWORK="+encodeISUPPORT(title))
} }
if dc.network == nil && dc.caps["soju.im/bouncer-networks"] { if dc.network == nil && dc.caps["soju.im/bouncer-networks"] {
isupport = append(isupport, "WHOX") isupport = append(isupport, "WHOX")
@ -1204,12 +1207,12 @@ func (dc *downstreamConn) welcome() error {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_YOURHOST, Command: irc.RPL_YOURHOST,
Params: []string{dc.nick, "Your host is " + dc.srv.Hostname}, Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname},
}) })
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_MYINFO, Command: irc.RPL_MYINFO,
Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"}, Params: []string{dc.nick, dc.srv.Config().Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
}) })
for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) { for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) {
dc.SendMessage(msg) dc.SendMessage(msg)
@ -1229,7 +1232,7 @@ func (dc *downstreamConn) welcome() error {
}) })
} }
if motd := dc.user.srv.MOTD(); motd != "" && dc.network == nil { if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil {
for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) { for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) {
dc.SendMessage(msg) dc.SendMessage(msg)
} }
@ -1420,7 +1423,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
if len(msg.Params) > 1 { if len(msg.Params) > 1 {
destination = msg.Params[1] destination = msg.Params[1]
} }
if destination != "" && destination != dc.srv.Hostname { hostname := dc.srv.Config().Hostname
if destination != "" && destination != hostname {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NOSUCHSERVER, Command: irc.ERR_NOSUCHSERVER,
Params: []string{dc.nick, destination, "No such server"}, Params: []string{dc.nick, destination, "No such server"},
@ -1429,7 +1433,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "PONG", Command: "PONG",
Params: []string{dc.srv.Hostname, source}, Params: []string{hostname, source},
}) })
return nil return nil
case "PONG": case "PONG":
@ -1946,7 +1950,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Token: whoxToken, Token: whoxToken,
Username: dc.user.Username, Username: dc.user.Username,
Hostname: dc.hostname, Hostname: dc.hostname,
Server: dc.srv.Hostname, Server: dc.srv.Config().Hostname,
Nickname: dc.nick, Nickname: dc.nick,
Flags: flags, Flags: flags,
Account: dc.user.Username, Account: dc.user.Username,
@ -1965,7 +1969,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Token: whoxToken, Token: whoxToken,
Username: servicePrefix.User, Username: servicePrefix.User,
Hostname: servicePrefix.Host, Hostname: servicePrefix.Host,
Server: dc.srv.Hostname, Server: dc.srv.Config().Hostname,
Nickname: serviceNick, Nickname: serviceNick,
Flags: "H*", Flags: "H*",
Account: serviceNick, Account: serviceNick,
@ -2025,7 +2029,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISSERVER, Command: irc.RPL_WHOISSERVER,
Params: []string{dc.nick, dc.nick, dc.srv.Hostname, "soju"}, Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "soju"},
}) })
if dc.user.Admin { if dc.user.Admin {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -2055,7 +2059,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISSERVER, Command: irc.RPL_WHOISSERVER,
Params: []string{dc.nick, serviceNick, dc.srv.Hostname, "soju"}, Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "soju"},
}) })
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
@ -2104,7 +2108,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
tags := copyClientTags(msg.Tags) tags := copyClientTags(msg.Tags)
for _, name := range strings.Split(targetsStr, ",") { for _, name := range strings.Split(targetsStr, ",") {
if name == "$"+dc.srv.Hostname || (name == "$*" && dc.network == nil) { if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) {
// "$" means a server mask follows. If it's the bouncer's // "$" means a server mask follows. If it's the bouncer's
// hostname, broadcast the message to all bouncer users. // hostname, broadcast the message to all bouncer users.
if !dc.user.Admin { if !dc.user.Admin {

View File

@ -53,17 +53,22 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) {
l.logger.Printf("%v"+format, v...) l.logger.Printf("%v"+format, v...)
} }
type Server struct { type Config struct {
Hostname string Hostname string
Title string Title string
Logger Logger
LogPath string LogPath string
Debug bool Debug bool
HTTPOrigins []string HTTPOrigins []string
AcceptProxyIPs config.IPSet AcceptProxyIPs config.IPSet
MaxUserNetworks int MaxUserNetworks int
Identd *Identd // can be nil MOTD string
}
type Server struct {
Logger Logger
Identd *Identd // can be nil
config atomic.Value // *Config
db Database db Database
stopWG sync.WaitGroup stopWG sync.WaitGroup
connCount int64 // atomic connCount int64 // atomic
@ -71,24 +76,29 @@ type Server struct {
lock sync.Mutex lock sync.Mutex
listeners map[net.Listener]struct{} listeners map[net.Listener]struct{}
users map[string]*user users map[string]*user
motd atomic.Value // string
} }
func NewServer(db Database) *Server { func NewServer(db Database) *Server {
srv := &Server{ srv := &Server{
Logger: log.New(log.Writer(), "", log.LstdFlags), Logger: log.New(log.Writer(), "", log.LstdFlags),
MaxUserNetworks: -1, db: db,
db: db, listeners: make(map[net.Listener]struct{}),
listeners: make(map[net.Listener]struct{}), users: make(map[string]*user),
users: make(map[string]*user),
} }
srv.motd.Store("") srv.config.Store(&Config{Hostname: "localhost", MaxUserNetworks: -1})
return srv return srv
} }
func (s *Server) prefix() *irc.Prefix { func (s *Server) prefix() *irc.Prefix {
return &irc.Prefix{Name: s.Hostname} return &irc.Prefix{Name: s.Config().Hostname}
}
func (s *Server) Config() *Config {
return s.config.Load().(*Config)
}
func (s *Server) SetConfig(cfg *Config) {
s.config.Store(cfg)
} }
func (s *Server) Start() error { func (s *Server) Start() error {
@ -239,7 +249,7 @@ func (s *Server) Serve(ln net.Listener) error {
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{ conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
OriginPatterns: s.HTTPOrigins, OriginPatterns: s.Config().HTTPOrigins,
}) })
if err != nil { if err != nil {
s.Logger.Printf("failed to serve HTTP connection: %v", err) s.Logger.Printf("failed to serve HTTP connection: %v", err)
@ -249,7 +259,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
isProxy := false isProxy := false
if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if ip := net.ParseIP(host); ip != nil { if ip := net.ParseIP(host); ip != nil {
isProxy = s.AcceptProxyIPs.Contains(ip) isProxy = s.Config().AcceptProxyIPs.Contains(ip)
} }
} }
@ -293,11 +303,3 @@ func (s *Server) Stats() *ServerStats {
stats.Downstreams = atomic.LoadInt64(&s.connCount) stats.Downstreams = atomic.LoadInt64(&s.connCount)
return &stats return &stats
} }
func (s *Server) SetMOTD(motd string) {
s.motd.Store(motd)
}
func (s *Server) MOTD() string {
return s.motd.Load().(string)
}

View File

@ -1050,7 +1050,7 @@ func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params [
broadcastMsg := &irc.Message{ broadcastMsg := &irc.Message{
Prefix: servicePrefix, Prefix: servicePrefix,
Command: "NOTICE", Command: "NOTICE",
Params: []string{"$" + dc.srv.Hostname, text}, Params: []string{"$" + dc.srv.Config().Hostname, text},
} }
var err error var err error
dc.srv.forEachUser(func(u *user) { dc.srv.forEachUser(func(u *user) {

View File

@ -415,8 +415,8 @@ func newUser(srv *Server, record *User) *user {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
var msgStore messageStore var msgStore messageStore
if srv.LogPath != "" { if logPath := srv.Config().LogPath; logPath != "" {
msgStore = newFSMessageStore(srv.LogPath, record.Username) msgStore = newFSMessageStore(logPath, record.Username)
} else { } else {
msgStore = newMemoryMessageStore() msgStore = newMemoryMessageStore()
} }
@ -776,7 +776,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
return nil, err return nil, err
} }
if u.srv.MaxUserNetworks >= 0 && len(u.networks) >= u.srv.MaxUserNetworks { if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
return nil, fmt.Errorf("maximum number of networks reached") return nil, fmt.Errorf("maximum number of networks reached")
} }