From 08b1010939f05d4ef6d55f28b77bb397d1c54d77 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Tue, 9 Feb 2021 17:34:46 +0100 Subject: [PATCH] Add support for graceful shutdown Closes: https://todo.sr.ht/~emersion/soju/45 --- cmd/soju/main.go | 35 +++++++++++++++++++++++++------ server.go | 54 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 76 insertions(+), 13 deletions(-) diff --git a/cmd/soju/main.go b/cmd/soju/main.go index 9d74e0f..5c5115f 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -7,7 +7,10 @@ import ( "net" "net/http" "net/url" + "os" + "os/signal" "strings" + "syscall" "github.com/pires/go-proxyproto" @@ -89,7 +92,9 @@ func main() { } ln = proxyProtoListener(ln, srv) go func() { - log.Fatal(srv.Serve(ln)) + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } }() case "irc+insecure": host := u.Host @@ -102,7 +107,9 @@ func main() { } ln = proxyProtoListener(ln, srv) go func() { - log.Fatal(srv.Serve(ln)) + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } }() case "wss": addr := u.Host @@ -115,7 +122,9 @@ func main() { Handler: srv, } go func() { - log.Fatal(httpSrv.ListenAndServeTLS("", "")) + if err := httpSrv.ListenAndServeTLS("", ""); err != nil { + log.Fatalf("serving %q: %v", listen, err) + } }() case "ws+insecure": addr := u.Host @@ -127,7 +136,9 @@ func main() { Handler: srv, } go func() { - log.Fatal(httpSrv.ListenAndServe()) + if err := httpSrv.ListenAndServe(); err != nil { + log.Fatalf("serving %q: %v", listen, err) + } }() case "ident": if srv.Identd == nil { @@ -144,7 +155,9 @@ func main() { } ln = proxyProtoListener(ln, srv) go func() { - log.Fatal(srv.Identd.Serve(ln)) + if err := srv.Identd.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } }() default: log.Fatalf("failed to listen on %q: unsupported scheme", listen) @@ -152,7 +165,17 @@ func main() { log.Printf("server listening on %q", listen) } - log.Fatal(srv.Run()) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + if err := srv.Start(); err != nil { + log.Fatal(err) + } + + <-sigCh + log.Print("shutting down server") + srv.Shutdown() } func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener { diff --git a/server.go b/server.go index b04f385..aacaf87 100644 --- a/server.go +++ b/server.go @@ -5,6 +5,7 @@ import ( "log" "net" "net/http" + "strings" "sync" "sync/atomic" "time" @@ -54,18 +55,21 @@ type Server struct { AcceptProxyIPs config.IPSet Identd *Identd // can be nil - db *DB + db *DB + stopWG sync.WaitGroup - lock sync.Mutex - users map[string]*user + lock sync.Mutex + listeners map[net.Listener]struct{} + users map[string]*user } func NewServer(db *DB) *Server { return &Server{ Logger: log.New(log.Writer(), "", log.LstdFlags), HistoryLimit: 1000, - users: make(map[string]*user), db: db, + listeners: make(map[net.Listener]struct{}), + users: make(map[string]*user), } } @@ -73,7 +77,7 @@ func (s *Server) prefix() *irc.Prefix { return &irc.Prefix{Name: s.Hostname} } -func (s *Server) Run() error { +func (s *Server) Start() error { users, err := s.db.ListUsers() if err != nil { return err @@ -85,7 +89,22 @@ func (s *Server) Run() error { } s.lock.Unlock() - select {} + return nil +} + +func (s *Server) Shutdown() { + s.lock.Lock() + for ln := range s.listeners { + if err := ln.Close(); err != nil { + s.Logger.Printf("failed to stop listener: %v", err) + } + } + for _, u := range s.users { + u.events <- eventStop{} + } + s.lock.Unlock() + + s.stopWG.Wait() } func (s *Server) createUser(user *User) (*user, error) { @@ -116,12 +135,16 @@ func (s *Server) addUserLocked(user *User) *user { u := newUser(s, user) s.users[u.Username] = u + s.stopWG.Add(1) + go func() { u.run() s.lock.Lock() delete(s.users, u.Username) s.lock.Unlock() + + s.stopWG.Done() }() return u @@ -145,9 +168,26 @@ func (s *Server) handle(ic ircConn) { } func (s *Server) Serve(ln net.Listener) error { + s.lock.Lock() + s.listeners[ln] = struct{}{} + s.lock.Unlock() + + s.stopWG.Add(1) + + defer func() { + s.lock.Lock() + delete(s.listeners, ln) + s.lock.Unlock() + + s.stopWG.Done() + }() + for { conn, err := ln.Accept() - if err != nil { + // TODO: use net.ErrClosed when available + if err != nil && strings.Contains(err.Error(), "use of closed network connection") { + return nil + } else if err != nil { return fmt.Errorf("failed to accept connection: %v", err) }