From 2713bcba343db3efb613846eb4f1ed75d6e48970 Mon Sep 17 00:00:00 2001 From: delthas Date: Fri, 20 Jan 2023 15:51:09 +0100 Subject: [PATCH] Add administrative unix listen endpoint This adds support for listening on a Unix socket for administrative connections, that then use a simple protocol for communicating with the service (BouncerServ) as an administrator with a global context. The wire format used by the Unix socket is IRC, but without registration or overheads. Example session: >>> BOUNCERSERV <<< 461 * BOUNCERSERV :Not enough parameters >>> BOUNCERSERV :n s <<< :gensou FAIL BOUNCERSERV :this command must be run as a user >>> BOUNCERSERV :u s <<< :gensou PRIVMSG * :marisa: 2 networks <<< :gensou PRIVMSG * :alice: 1 networks <<< :gensou BOUNCERSERV OK --- cmd/soju/main.go | 21 +++++++++-- doc/soju.1.scd | 4 ++- server.go | 94 +++++++++++++++++++++++++++++++++++++++++++++--- server_test.go | 2 +- 4 files changed, 112 insertions(+), 9 deletions(-) diff --git a/cmd/soju/main.go b/cmd/soju/main.go index e8897ca..0e93c86 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -177,7 +177,7 @@ func main() { ln := tls.NewListener(l, ircsTLSCfg) ln = proxyProtoListener(ln, srv) go func() { - if err := srv.Serve(ln); err != nil { + if err := srv.Serve(ln, srv.Handle); err != nil { log.Printf("serving %q: %v", listen, err) } }() @@ -195,7 +195,7 @@ func main() { } ln = proxyProtoListener(ln, srv) go func() { - if err := srv.Serve(ln); err != nil { + if err := srv.Serve(ln, srv.Handle); err != nil { log.Printf("serving %q: %v", listen, err) } }() @@ -206,7 +206,22 @@ func main() { } ln = proxyProtoListener(ln, srv) go func() { - if err := srv.Serve(ln); err != nil { + if err := srv.Serve(ln, srv.Handle); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "unix+admin": + path := u.Path + if path == "" { + path = soju.DefaultUnixAdminPath + } + ln, err := net.Listen("unix", path) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + ln = proxyProtoListener(ln, srv) + go func() { + if err := srv.Serve(ln, srv.HandleAdmin); err != nil { log.Printf("serving %q: %v", listen, err) } }() diff --git a/doc/soju.1.scd b/doc/soju.1.scd index 665dd55..8bc7d11 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -85,7 +85,7 @@ The following directives are supported: omitted: 6697) - _irc+insecure://[host][:port]_ listens with plain-text over TCP (default port if omitted: 6667) - - _unix:///_ listens on a Unix domain socket + - _unix://_ listens on a Unix domain socket - _wss://[host][:port]_ listens for WebSocket connections over TLS (default port: 443) - _ws+insecure://[host][:port]_ listens for plain-text WebSocket @@ -97,6 +97,8 @@ The following directives are supported: - _http+pprof://localhost:_ listens for plain-text HTTP connections and serves pprof runtime profiling data (host must be "localhost"). For more information, see: . + - _unix+admin://[path]_ listens on a Unix domain socket for administrative + connections, such as sojuctl (default path: /run/soju/admin) If the scheme is omitted, "ircs" is assumed. If multiple *listen* directives are specified, soju will listen on each of them. diff --git a/server.go b/server.go index 3fac69b..41d5c19 100644 --- a/server.go +++ b/server.go @@ -26,6 +26,8 @@ import ( "git.sr.ht/~emersion/soju/identd" ) +var DefaultUnixAdminPath = "/run/soju/admin" + // TODO: make configurable var retryConnectMinDelay = time.Minute var retryConnectMaxDelay = 10 * time.Minute @@ -437,7 +439,7 @@ func (s *Server) addUserLocked(user *database.User) *user { var lastDownstreamID uint64 -func (s *Server) handle(ic ircConn) { +func (s *Server) Handle(ic ircConn) { defer func() { if err := recover(); err != nil { s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, string(debug.Stack())) @@ -471,7 +473,91 @@ func (s *Server) handle(ic ircConn) { s.metrics.downstreams.Add(-1) } -func (s *Server) Serve(ln net.Listener) error { +func (s *Server) HandleAdmin(ic ircConn) { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving admin client %q: %v\n%v", ic.RemoteAddr(), err, string(debug.Stack())) + } + }() + + s.lock.Lock() + shutdown := s.shutdown + s.lock.Unlock() + + ctx := context.TODO() + remoteAddr := ic.RemoteAddr().String() + logger := &prefixLogger{s.Logger, fmt.Sprintf("admin %q: ", remoteAddr)} + c := newConn(s, ic, &connOptions{Logger: logger}) + defer c.Close() + + if shutdown { + c.SendMessage(ctx, &irc.Message{ + Command: "ERROR", + Params: []string{"Server is shutting down"}, + }) + return + } + for { + msg, err := c.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + logger.Printf("failed to read IRC command: %v", err) + break + } + switch msg.Command { + case "BOUNCERSERV": + if len(msg.Params) < 1 { + c.SendMessage(ctx, &irc.Message{ + Command: irc.ERR_NEEDMOREPARAMS, + Params: []string{ + "*", + msg.Command, + "Not enough parameters", + }, + }) + break + } + err := handleServicePRIVMSG(&serviceContext{ + Context: ctx, + srv: s, + admin: true, + print: func(text string) { + c.SendMessage(ctx, &irc.Message{ + Prefix: s.prefix(), + Command: "PRIVMSG", + Params: []string{"*", text}, + }) + }, + }, msg.Params[0]) + if err != nil { + c.SendMessage(ctx, &irc.Message{ + Prefix: s.prefix(), + Command: "FAIL", + Params: []string{msg.Command, err.Error()}, + }) + } else { + c.SendMessage(ctx, &irc.Message{ + Prefix: s.prefix(), + Command: msg.Command, + Params: []string{"OK"}, + }) + } + default: + c.SendMessage(ctx, &irc.Message{ + Prefix: s.prefix(), + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{ + "*", + msg.Command, + "Unknown command", + }, + }) + } + } +} + +func (s *Server) Serve(ln net.Listener, handler func(ircConn)) error { ln = &retryListener{ Listener: ln, Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())}, @@ -499,7 +585,7 @@ func (s *Server) Serve(ln net.Listener) error { return fmt.Errorf("failed to accept connection: %v", err) } - go s.handle(newNetIRCConn(conn)) + go handler(newNetIRCConn(conn)) } } @@ -530,7 +616,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } - s.handle(newWebsocketIRCConn(conn, remoteAddr)) + s.Handle(newWebsocketIRCConn(conn, remoteAddr)) } func parseForwarded(h http.Header) map[string]string { diff --git a/server_test.go b/server_test.go index 74a2e1f..bd53fae 100644 --- a/server_test.go +++ b/server_test.go @@ -61,7 +61,7 @@ func createTestUser(t *testing.T, db database.Database) *database.User { func createTestDownstream(t *testing.T, srv *Server) ircConn { c1, c2 := net.Pipe() - go srv.handle(newNetIRCConn(c1)) + go srv.Handle(newNetIRCConn(c1)) return newNetIRCConn(c2) }