Add support for graceful shutdown

Closes: https://todo.sr.ht/~emersion/soju/45
This commit is contained in:
Simon Ser 2021-02-09 17:34:46 +01:00
parent 5aa15d5628
commit 08b1010939
2 changed files with 76 additions and 13 deletions

View File

@ -7,7 +7,10 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"os/signal"
"strings" "strings"
"syscall"
"github.com/pires/go-proxyproto" "github.com/pires/go-proxyproto"
@ -89,7 +92,9 @@ func main() {
} }
ln = proxyProtoListener(ln, srv) ln = proxyProtoListener(ln, srv)
go func() { go func() {
log.Fatal(srv.Serve(ln)) if err := srv.Serve(ln); err != nil {
log.Printf("serving %q: %v", listen, err)
}
}() }()
case "irc+insecure": case "irc+insecure":
host := u.Host host := u.Host
@ -102,7 +107,9 @@ func main() {
} }
ln = proxyProtoListener(ln, srv) ln = proxyProtoListener(ln, srv)
go func() { go func() {
log.Fatal(srv.Serve(ln)) if err := srv.Serve(ln); err != nil {
log.Printf("serving %q: %v", listen, err)
}
}() }()
case "wss": case "wss":
addr := u.Host addr := u.Host
@ -115,7 +122,9 @@ func main() {
Handler: srv, Handler: srv,
} }
go func() { go func() {
log.Fatal(httpSrv.ListenAndServeTLS("", "")) if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
log.Fatalf("serving %q: %v", listen, err)
}
}() }()
case "ws+insecure": case "ws+insecure":
addr := u.Host addr := u.Host
@ -127,7 +136,9 @@ func main() {
Handler: srv, Handler: srv,
} }
go func() { go func() {
log.Fatal(httpSrv.ListenAndServe()) if err := httpSrv.ListenAndServe(); err != nil {
log.Fatalf("serving %q: %v", listen, err)
}
}() }()
case "ident": case "ident":
if srv.Identd == nil { if srv.Identd == nil {
@ -144,7 +155,9 @@ func main() {
} }
ln = proxyProtoListener(ln, srv) ln = proxyProtoListener(ln, srv)
go func() { go func() {
log.Fatal(srv.Identd.Serve(ln)) if err := srv.Identd.Serve(ln); err != nil {
log.Printf("serving %q: %v", listen, err)
}
}() }()
default: default:
log.Fatalf("failed to listen on %q: unsupported scheme", listen) log.Fatalf("failed to listen on %q: unsupported scheme", listen)
@ -152,7 +165,17 @@ func main() {
log.Printf("server listening on %q", listen) log.Printf("server listening on %q", listen)
} }
log.Fatal(srv.Run())
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
if err := srv.Start(); err != nil {
log.Fatal(err)
}
<-sigCh
log.Print("shutting down server")
srv.Shutdown()
} }
func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener { func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {

View File

@ -5,6 +5,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -54,18 +55,21 @@ type Server struct {
AcceptProxyIPs config.IPSet AcceptProxyIPs config.IPSet
Identd *Identd // can be nil Identd *Identd // can be nil
db *DB db *DB
stopWG sync.WaitGroup
lock sync.Mutex lock sync.Mutex
users map[string]*user listeners map[net.Listener]struct{}
users map[string]*user
} }
func NewServer(db *DB) *Server { func NewServer(db *DB) *Server {
return &Server{ return &Server{
Logger: log.New(log.Writer(), "", log.LstdFlags), Logger: log.New(log.Writer(), "", log.LstdFlags),
HistoryLimit: 1000, HistoryLimit: 1000,
users: make(map[string]*user),
db: db, db: db,
listeners: make(map[net.Listener]struct{}),
users: make(map[string]*user),
} }
} }
@ -73,7 +77,7 @@ func (s *Server) prefix() *irc.Prefix {
return &irc.Prefix{Name: s.Hostname} return &irc.Prefix{Name: s.Hostname}
} }
func (s *Server) Run() error { func (s *Server) Start() error {
users, err := s.db.ListUsers() users, err := s.db.ListUsers()
if err != nil { if err != nil {
return err return err
@ -85,7 +89,22 @@ func (s *Server) Run() error {
} }
s.lock.Unlock() s.lock.Unlock()
select {} return nil
}
func (s *Server) Shutdown() {
s.lock.Lock()
for ln := range s.listeners {
if err := ln.Close(); err != nil {
s.Logger.Printf("failed to stop listener: %v", err)
}
}
for _, u := range s.users {
u.events <- eventStop{}
}
s.lock.Unlock()
s.stopWG.Wait()
} }
func (s *Server) createUser(user *User) (*user, error) { func (s *Server) createUser(user *User) (*user, error) {
@ -116,12 +135,16 @@ func (s *Server) addUserLocked(user *User) *user {
u := newUser(s, user) u := newUser(s, user)
s.users[u.Username] = u s.users[u.Username] = u
s.stopWG.Add(1)
go func() { go func() {
u.run() u.run()
s.lock.Lock() s.lock.Lock()
delete(s.users, u.Username) delete(s.users, u.Username)
s.lock.Unlock() s.lock.Unlock()
s.stopWG.Done()
}() }()
return u return u
@ -145,9 +168,26 @@ func (s *Server) handle(ic ircConn) {
} }
func (s *Server) Serve(ln net.Listener) error { func (s *Server) Serve(ln net.Listener) error {
s.lock.Lock()
s.listeners[ln] = struct{}{}
s.lock.Unlock()
s.stopWG.Add(1)
defer func() {
s.lock.Lock()
delete(s.listeners, ln)
s.lock.Unlock()
s.stopWG.Done()
}()
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { // TODO: use net.ErrClosed when available
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
return nil
} else if err != nil {
return fmt.Errorf("failed to accept connection: %v", err) return fmt.Errorf("failed to accept connection: %v", err)
} }