From c709ebfc912cfca9b9c412bc27bd811d5115ba51 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Tue, 2 Jun 2020 11:39:53 +0200 Subject: [PATCH] Add network update command The user.updateNetwork function is a bit involved because we need to make sure that the upstream connection is closed before re-connecting (would otherwise cause "Nick already used" errors) and that the downstream connections' state is kept in sync. References: https://todo.sr.ht/~emersion/soju/17 --- service.go | 170 +++++++++++++++++++++++++++++++++++++++------------- user.go | 173 ++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 253 insertions(+), 90 deletions(-) diff --git a/service.go b/service.go index 3be9683..2245f53 100644 --- a/service.go +++ b/service.go @@ -118,7 +118,7 @@ func init() { "network": { children: serviceCommandSet{ "create": { - usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]", + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...", desc: "add a new network", handle: handleServiceCreateNetwork, }, @@ -126,6 +126,11 @@ func init() { desc: "show a list of saved networks and their current status", handle: handleServiceNetworkStatus, }, + "update": { + usage: "[-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...", + desc: "update a network", + handle: handleServiceNetworkUpdate, + }, "delete": { usage: "", desc: "delete a network", @@ -338,65 +343,115 @@ func newFlagSet() *flag.FlagSet { return fs } -type stringSliceVar []string +type stringSliceFlag []string -func (v *stringSliceVar) String() string { +func (v *stringSliceFlag) String() string { return fmt.Sprint([]string(*v)) } -func (v *stringSliceVar) Set(s string) error { +func (v *stringSliceFlag) Set(s string) error { *v = append(*v, s) return nil } -func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { - fs := newFlagSet() - addr := fs.String("addr", "", "") - name := fs.String("name", "", "") - username := fs.String("username", "", "") - pass := fs.String("pass", "", "") - realname := fs.String("realname", "", "") - nick := fs.String("nick", "", "") - var connectCommands stringSliceVar - fs.Var(&connectCommands, "connect-command", "") +// stringPtrFlag is a flag value populating a string pointer. This allows to +// disambiguate between a flag that hasn't been set and a flag that has been +// set to an empty string. +type stringPtrFlag struct { + ptr **string +} +func (f stringPtrFlag) String() string { + if *f.ptr == nil { + return "" + } + return **f.ptr +} + +func (f stringPtrFlag) Set(s string) error { + *f.ptr = &s + return nil +} + +type networkFlagSet struct { + *flag.FlagSet + Addr, Name, Nick, Username, Pass, Realname *string + ConnectCommands []string +} + +func newNetworkFlagSet() *networkFlagSet { + fs := &networkFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.Addr}, "addr", "") + fs.Var(stringPtrFlag{&fs.Name}, "name", "") + fs.Var(stringPtrFlag{&fs.Nick}, "nick", "") + fs.Var(stringPtrFlag{&fs.Username}, "username", "") + fs.Var(stringPtrFlag{&fs.Pass}, "pass", "") + fs.Var(stringPtrFlag{&fs.Realname}, "realname", "") + fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "") + return fs +} + +func (fs *networkFlagSet) update(network *Network) error { + if fs.Addr != nil { + if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 { + scheme := addrParts[0] + switch scheme { + case "ircs", "irc+insecure": + default: + return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme) + } + } + network.Addr = *fs.Addr + } + if fs.Name != nil { + network.Name = *fs.Name + } + if fs.Nick != nil { + network.Nick = *fs.Nick + } + if fs.Username != nil { + network.Username = *fs.Username + } + if fs.Pass != nil { + network.Pass = *fs.Pass + } + if fs.Realname != nil { + network.Realname = *fs.Realname + } + if fs.ConnectCommands != nil { + if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" { + network.ConnectCommands = nil + } else { + for _, command := range fs.ConnectCommands { + _, err := irc.ParseMessage(command) + if err != nil { + return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) + } + } + network.ConnectCommands = fs.ConnectCommands + } + } + return nil +} + +func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { + fs := newNetworkFlagSet() if err := fs.Parse(params); err != nil { return err } - if *addr == "" { + if fs.Addr == nil { return fmt.Errorf("flag -addr is required") } - if addrParts := strings.SplitN(*addr, "://", 2); len(addrParts) == 2 { - scheme := addrParts[0] - switch scheme { - case "ircs", "irc+insecure": - default: - return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme) - } + record := &Network{ + Addr: *fs.Addr, + Nick: dc.nick, + } + if err := fs.update(record); err != nil { + return err } - for _, command := range connectCommands { - _, err := irc.ParseMessage(command) - if err != nil { - return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) - } - } - - if *nick == "" { - *nick = dc.nick - } - - var err error - network, err := dc.user.createNetwork(&Network{ - Addr: *addr, - Name: *name, - Username: *username, - Pass: *pass, - Realname: *realname, - Nick: *nick, - ConnectCommands: connectCommands, - }) + network, err := dc.user.createNetwork(record) if err != nil { return fmt.Errorf("could not create network: %v", err) } @@ -441,6 +496,35 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error { return nil } +func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error { + if len(params) < 1 { + return fmt.Errorf("expected exactly one argument") + } + + fs := newNetworkFlagSet() + if err := fs.Parse(params[1:]); err != nil { + return err + } + + net := dc.user.getNetwork(params[0]) + if net == nil { + return fmt.Errorf("unknown network %q", params[0]) + } + + record := net.Network // copy network record because we'll mutate it + if err := fs.update(&record); err != nil { + return err + } + + network, err := dc.user.updateNetwork(&record) + if err != nil { + return fmt.Errorf("could not update network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName())) + return nil +} + func handleServiceNetworkDelete(dc *downstreamConn, params []string) error { if len(params) != 1 { return fmt.Errorf("expected exactly one argument") diff --git a/user.go b/user.go index ca9d801..6d08caa 100644 --- a/user.go +++ b/user.go @@ -272,6 +272,15 @@ func (u *user) getNetwork(name string) *network { return nil } +func (u *user) getNetworkByID(id int64) *network { + for _, net := range u.networks { + if net.ID == id { + return net + } + } + return nil +} + func (u *user) run() { networks, err := u.srv.db.ListNetworks(u.Username) if err != nil { @@ -309,31 +318,18 @@ func (u *user) run() { }) uc.network.lastError = nil case eventUpstreamDisconnected: - uc := e.uc - - uc.network.conn = nil - - for _, ml := range uc.messageLoggers { - if err := ml.Close(); err != nil { - uc.logger.Printf("failed to close message logger: %v", err) - } - } - - uc.endPendingLISTs(true) - - uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateSupportedCaps() - }) - - if uc.network.lastError == nil { - uc.forEachDownstream(func(dc *downstreamConn) { - sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) - }) - } + u.handleUpstreamDisconnected(e.uc) case eventUpstreamConnectionError: net := e.net - if net.lastError == nil || net.lastError.Error() != e.err.Error() { + stopped := false + select { + case <-net.stopped: + stopped = true + default: + } + + if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) { net.forEachDownstream(func(dc *downstreamConn) { sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err)) }) @@ -425,45 +421,128 @@ func (u *user) run() { } } -func (u *user) createNetwork(net *Network) (*network, error) { - if net.ID != 0 { +func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { + uc.network.conn = nil + + for _, ml := range uc.messageLoggers { + if err := ml.Close(); err != nil { + uc.logger.Printf("failed to close message logger: %v", err) + } + } + + uc.endPendingLISTs(true) + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + + if uc.network.lastError == nil { + uc.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) + }) + } +} + +func (u *user) addNetwork(network *network) { + u.networks = append(u.networks, network) + go network.run() +} + +func (u *user) removeNetwork(network *network) { + network.stop() + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.Close() + } + }) + + for i, net := range u.networks { + if net == network { + u.networks = append(u.networks[:i], u.networks[i+1:]...) + return + } + } + + panic("tried to remove a non-existing network") +} + +func (u *user) createNetwork(record *Network) (*network, error) { + if record.ID != 0 { panic("tried creating an already-existing network") } - network := newNetwork(u, net, nil) + network := newNetwork(u, record, nil) err := u.srv.db.StoreNetwork(u.Username, &network.Network) if err != nil { return nil, err } - u.networks = append(u.networks, network) + u.addNetwork(network) - go network.run() return network, nil } -func (u *user) deleteNetwork(id int64) error { - for i, net := range u.networks { - if net.ID != id { - continue - } - - if err := u.srv.db.DeleteNetwork(net.ID); err != nil { - return err - } - - u.forEachDownstream(func(dc *downstreamConn) { - if dc.network != nil && dc.network == net { - dc.Close() - } - }) - - net.stop() - u.networks = append(u.networks[:i], u.networks[i+1:]...) - return nil +func (u *user) updateNetwork(record *Network) (*network, error) { + if record.ID == 0 { + panic("tried updating a new network") } - panic("tried deleting a non-existing network") + network := u.getNetworkByID(record.ID) + if network == nil { + panic("tried updating a non-existing network") + } + + if err := u.srv.db.StoreNetwork(u.Username, record); err != nil { + return nil, err + } + + // Most network changes require us to re-connect to the upstream server + + channels := make([]Channel, 0, len(network.channels)) + for _, ch := range network.channels { + channels = append(channels, *ch) + } + + updatedNetwork := newNetwork(u, record, channels) + + // If we're currently connected, disconnect and perform the necessary + // bookkeeping + if network.conn != nil { + network.stop() + // Note: this will set network.conn to nil + u.handleUpstreamDisconnected(network.conn) + } + + // Patch downstream connections to use our fresh updated network + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.network = updatedNetwork + } + }) + + // We need to remove the network after patching downstream connections, + // otherwise they'll get closed + u.removeNetwork(network) + + // This will re-connect to the upstream server + u.addNetwork(updatedNetwork) + + return updatedNetwork, nil +} + +func (u *user) deleteNetwork(id int64) error { + network := u.getNetworkByID(id) + if network == nil { + panic("tried deleting a non-existing network") + } + + if err := u.srv.db.DeleteNetwork(network.ID); err != nil { + return err + } + + u.removeNetwork(network) + return nil } func (u *user) updatePassword(hashed string) error {