From 73295e4fa7df7161d8859bc8bcb9407090af7fb6 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Tue, 16 Nov 2021 00:38:04 +0100 Subject: [PATCH] Allow most config options to be reloaded Closes: https://todo.sr.ht/~emersion/soju/42 --- cmd/soju/main.go | 114 +++++++++++++++++++++++++---------------------- conn.go | 4 +- doc/soju.1.scd | 5 ++- downstream.go | 32 +++++++------ server.go | 46 ++++++++++--------- service.go | 2 +- user.go | 6 +-- 7 files changed, 111 insertions(+), 98 deletions(-) diff --git a/cmd/soju/main.go b/cmd/soju/main.go index 186573e..42cb811 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -37,19 +37,6 @@ func (v *stringSliceFlag) Set(s string) error { return nil } -func loadMOTD(srv *soju.Server, filename string) error { - if filename == "" { - return nil - } - - b, err := ioutil.ReadFile(filename) - if err != nil { - return err - } - srv.SetMOTD(strings.TrimSuffix(string(b), "\n")) - return nil -} - func bumpOpenedFileLimit() error { var rlimit syscall.Rlimit if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { @@ -62,24 +49,65 @@ func bumpOpenedFileLimit() error { return nil } +var ( + configPath string + debug bool + + tlsCert atomic.Value // *tls.Certificate +) + +func loadConfig() (*config.Server, *soju.Config, error) { + var raw *config.Server + if configPath != "" { + var err error + raw, err = config.Load(configPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load config file: %v", err) + } + } else { + raw = config.Defaults() + } + + var motd string + if raw.MOTDPath != "" { + b, err := ioutil.ReadFile(raw.MOTDPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load MOTD: %v", err) + } + motd = strings.TrimSuffix(string(b), "\n") + } + + if raw.TLS != nil { + cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err) + } + tlsCert.Store(&cert) + } + + cfg := &soju.Config{ + Hostname: raw.Hostname, + Title: raw.Title, + LogPath: raw.LogPath, + HTTPOrigins: raw.HTTPOrigins, + AcceptProxyIPs: raw.AcceptProxyIPs, + MaxUserNetworks: raw.MaxUserNetworks, + Debug: debug, + MOTD: motd, + } + return raw, cfg, nil +} + func main() { var listen []string - var configPath string - var debug bool flag.Var((*stringSliceFlag)(&listen), "listen", "listening address") flag.StringVar(&configPath, "config", "", "path to configuration file") flag.BoolVar(&debug, "debug", false, "enable debug logging") flag.Parse() - var cfg *config.Server - if configPath != "" { - var err error - cfg, err = config.Load(configPath) - if err != nil { - log.Fatalf("failed to load config file: %v", err) - } - } else { - cfg = config.Defaults() + cfg, serverCfg, err := loadConfig() + if err != nil { + log.Fatal(err) } cfg.Listen = append(cfg.Listen, listen...) @@ -97,14 +125,7 @@ func main() { } var tlsCfg *tls.Config - var tlsCert atomic.Value if cfg.TLS != nil { - cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath) - if err != nil { - log.Fatalf("failed to load TLS certificate and key: %v", err) - } - tlsCert.Store(&cert) - tlsCfg = &tls.Config{ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return tlsCert.Load().(*tls.Certificate), nil @@ -113,17 +134,7 @@ func main() { } srv := soju.NewServer(db) - srv.Hostname = cfg.Hostname - srv.Title = cfg.Title - srv.LogPath = cfg.LogPath - srv.HTTPOrigins = cfg.HTTPOrigins - srv.AcceptProxyIPs = cfg.AcceptProxyIPs - srv.MaxUserNetworks = cfg.MaxUserNetworks - srv.Debug = debug - - if err := loadMOTD(srv, cfg.MOTDPath); err != nil { - log.Fatalf("failed to load MOTD: %v", err) - } + srv.SetConfig(serverCfg) for _, listen := range cfg.Listen { listenURI := listen @@ -258,17 +269,12 @@ func main() { for sig := range sigCh { switch sig { case syscall.SIGHUP: - log.Print("reloading TLS certificate and MOTD") - if cfg.TLS != nil { - cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath) - if err != nil { - log.Printf("failed to reload TLS certificate and key: %v", err) - break - } - tlsCert.Store(&cert) - } - if err := loadMOTD(srv, cfg.MOTDPath); err != nil { - log.Printf("failed to reload MOTD: %v", err) + log.Print("reloading configuration") + _, serverCfg, err := loadConfig() + if err != nil { + log.Printf("failed to reloading configuration: %v", err) + } else { + srv.SetConfig(serverCfg) } case syscall.SIGINT, syscall.SIGTERM: log.Print("shutting down server") @@ -286,7 +292,7 @@ func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener { if !ok { return proxyproto.IGNORE, nil } - if srv.AcceptProxyIPs.Contains(tcpAddr.IP) { + if srv.Config().AcceptProxyIPs.Contains(tcpAddr.IP) { return proxyproto.USE, nil } return proxyproto.IGNORE, nil diff --git a/conn.go b/conn.go index e25f7da..82ee4eb 100644 --- a/conn.go +++ b/conn.go @@ -195,7 +195,7 @@ func newConn(srv *Server, ic ircConn, options *connOptions) *conn { <-rl.C } - if c.srv.Debug { + if c.srv.Config().Debug { c.logger.Printf("sent: %v", msg) } c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) @@ -248,7 +248,7 @@ func (c *conn) ReadMessage() (*irc.Message, error) { return nil, err } - if c.srv.Debug { + if c.srv.Config().Debug { c.logger.Printf("received: %v", msg) } diff --git a/doc/soju.1.scd b/doc/soju.1.scd index a21979a..de70b6c 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -44,8 +44,9 @@ soju supports two connection modes: For per-client history to work, clients need to indicate their name. This can be done by adding a "@" suffix to the username. -soju will reload the TLS certificate/key and the MOTD file when it receives the -HUP signal. +soju will reload the configuration file, the TLS certificate/key and the MOTD +file when it receives the HUP signal. The configuration options _listen_, _db_ +and _log_ cannot be reloaded. Administrators can broadcast a message to all bouncer users via _/notice $ _, or via _/notice $\* _ in multi-upstream mode. All diff --git a/downstream.go b/downstream.go index 649fc66..af91df3 100644 --- a/downstream.go +++ b/downstream.go @@ -290,7 +290,10 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { for k, v := range permanentDownstreamCaps { dc.supportedCaps[k] = v } - if srv.LogPath != "" { + // TODO: this is racy, we should only enable chathistory after + // authentication and then check that user.msgStore implements + // chatHistoryMessageStore + if srv.Config().LogPath != "" { dc.supportedCaps["draft/chathistory"] = "" } return dc @@ -996,7 +999,7 @@ func (dc *downstreamConn) updateSupportedCaps() { } } - if dc.srv.LogPath != "" && dc.network != nil { + if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil { dc.setSupportedCap("draft/event-playback", "") } else { dc.unsetSupportedCap("draft/event-playback") @@ -1175,8 +1178,8 @@ func (dc *downstreamConn) welcome() error { if dc.network != nil { isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID)) } - if dc.network == nil && dc.srv.Title != "" { - isupport = append(isupport, "NETWORK="+encodeISUPPORT(dc.srv.Title)) + if title := dc.srv.Config().Title; dc.network == nil && title != "" { + isupport = append(isupport, "NETWORK="+encodeISUPPORT(title)) } if dc.network == nil && dc.caps["soju.im/bouncer-networks"] { isupport = append(isupport, "WHOX") @@ -1204,12 +1207,12 @@ func (dc *downstreamConn) welcome() error { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_YOURHOST, - Params: []string{dc.nick, "Your host is " + dc.srv.Hostname}, + Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname}, }) dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_MYINFO, - Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"}, + Params: []string{dc.nick, dc.srv.Config().Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"}, }) for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) { dc.SendMessage(msg) @@ -1229,7 +1232,7 @@ func (dc *downstreamConn) welcome() error { }) } - if motd := dc.user.srv.MOTD(); motd != "" && dc.network == nil { + if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil { for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) { dc.SendMessage(msg) } @@ -1420,7 +1423,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { if len(msg.Params) > 1 { destination = msg.Params[1] } - if destination != "" && destination != dc.srv.Hostname { + hostname := dc.srv.Config().Hostname + if destination != "" && destination != hostname { return ircError{&irc.Message{ Command: irc.ERR_NOSUCHSERVER, Params: []string{dc.nick, destination, "No such server"}, @@ -1429,7 +1433,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: "PONG", - Params: []string{dc.srv.Hostname, source}, + Params: []string{hostname, source}, }) return nil case "PONG": @@ -1946,7 +1950,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Token: whoxToken, Username: dc.user.Username, Hostname: dc.hostname, - Server: dc.srv.Hostname, + Server: dc.srv.Config().Hostname, Nickname: dc.nick, Flags: flags, Account: dc.user.Username, @@ -1965,7 +1969,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Token: whoxToken, Username: servicePrefix.User, Hostname: servicePrefix.Host, - Server: dc.srv.Hostname, + Server: dc.srv.Config().Hostname, Nickname: serviceNick, Flags: "H*", Account: serviceNick, @@ -2025,7 +2029,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISSERVER, - Params: []string{dc.nick, dc.nick, dc.srv.Hostname, "soju"}, + Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "soju"}, }) if dc.user.Admin { dc.SendMessage(&irc.Message{ @@ -2055,7 +2059,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISSERVER, - Params: []string{dc.nick, serviceNick, dc.srv.Hostname, "soju"}, + Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "soju"}, }) dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), @@ -2104,7 +2108,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { tags := copyClientTags(msg.Tags) for _, name := range strings.Split(targetsStr, ",") { - if name == "$"+dc.srv.Hostname || (name == "$*" && dc.network == nil) { + if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) { // "$" means a server mask follows. If it's the bouncer's // hostname, broadcast the message to all bouncer users. if !dc.user.Admin { diff --git a/server.go b/server.go index 1e08885..4505de9 100644 --- a/server.go +++ b/server.go @@ -53,17 +53,22 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) { l.logger.Printf("%v"+format, v...) } -type Server struct { +type Config struct { Hostname string Title string - Logger Logger LogPath string Debug bool HTTPOrigins []string AcceptProxyIPs config.IPSet MaxUserNetworks int - Identd *Identd // can be nil + MOTD string +} +type Server struct { + Logger Logger + Identd *Identd // can be nil + + config atomic.Value // *Config db Database stopWG sync.WaitGroup connCount int64 // atomic @@ -71,24 +76,29 @@ type Server struct { lock sync.Mutex listeners map[net.Listener]struct{} users map[string]*user - - motd atomic.Value // string } func NewServer(db Database) *Server { srv := &Server{ - Logger: log.New(log.Writer(), "", log.LstdFlags), - MaxUserNetworks: -1, - db: db, - listeners: make(map[net.Listener]struct{}), - users: make(map[string]*user), + Logger: log.New(log.Writer(), "", log.LstdFlags), + db: db, + listeners: make(map[net.Listener]struct{}), + users: make(map[string]*user), } - srv.motd.Store("") + srv.config.Store(&Config{Hostname: "localhost", MaxUserNetworks: -1}) return srv } func (s *Server) prefix() *irc.Prefix { - return &irc.Prefix{Name: s.Hostname} + return &irc.Prefix{Name: s.Config().Hostname} +} + +func (s *Server) Config() *Config { + return s.config.Load().(*Config) +} + +func (s *Server) SetConfig(cfg *Config) { + s.config.Store(cfg) } func (s *Server) Start() error { @@ -239,7 +249,7 @@ func (s *Server) Serve(ln net.Listener) error { func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{ Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me - OriginPatterns: s.HTTPOrigins, + OriginPatterns: s.Config().HTTPOrigins, }) if err != nil { s.Logger.Printf("failed to serve HTTP connection: %v", err) @@ -249,7 +259,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { isProxy := false if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if ip := net.ParseIP(host); ip != nil { - isProxy = s.AcceptProxyIPs.Contains(ip) + isProxy = s.Config().AcceptProxyIPs.Contains(ip) } } @@ -293,11 +303,3 @@ func (s *Server) Stats() *ServerStats { stats.Downstreams = atomic.LoadInt64(&s.connCount) return &stats } - -func (s *Server) SetMOTD(motd string) { - s.motd.Store(motd) -} - -func (s *Server) MOTD() string { - return s.motd.Load().(string) -} diff --git a/service.go b/service.go index 4cf4150..72b45f2 100644 --- a/service.go +++ b/service.go @@ -1050,7 +1050,7 @@ func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params [ broadcastMsg := &irc.Message{ Prefix: servicePrefix, Command: "NOTICE", - Params: []string{"$" + dc.srv.Hostname, text}, + Params: []string{"$" + dc.srv.Config().Hostname, text}, } var err error dc.srv.forEachUser(func(u *user) { diff --git a/user.go b/user.go index 276d527..a1e9199 100644 --- a/user.go +++ b/user.go @@ -415,8 +415,8 @@ func newUser(srv *Server, record *User) *user { logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} var msgStore messageStore - if srv.LogPath != "" { - msgStore = newFSMessageStore(srv.LogPath, record.Username) + if logPath := srv.Config().LogPath; logPath != "" { + msgStore = newFSMessageStore(logPath, record.Username) } else { msgStore = newMemoryMessageStore() } @@ -776,7 +776,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er return nil, err } - if u.srv.MaxUserNetworks >= 0 && len(u.networks) >= u.srv.MaxUserNetworks { + if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max { return nil, fmt.Errorf("maximum number of networks reached") }