Add network update command

The user.updateNetwork function is a bit involved because we need to
make sure that the upstream connection is closed before re-connecting
(would otherwise cause "Nick already used" errors) and that the
downstream connections' state is kept in sync.

References: https://todo.sr.ht/~emersion/soju/17
This commit is contained in:
Simon Ser 2020-06-02 11:39:53 +02:00
parent bee2001e29
commit c709ebfc91
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 253 additions and 90 deletions

View File

@ -118,7 +118,7 @@ func init() {
"network": {
children: serviceCommandSet{
"create": {
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]",
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
desc: "add a new network",
handle: handleServiceCreateNetwork,
},
@ -126,6 +126,11 @@ func init() {
desc: "show a list of saved networks and their current status",
handle: handleServiceNetworkStatus,
},
"update": {
usage: "[-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
desc: "update a network",
handle: handleServiceNetworkUpdate,
},
"delete": {
usage: "<name>",
desc: "delete a network",
@ -338,65 +343,115 @@ func newFlagSet() *flag.FlagSet {
return fs
}
type stringSliceVar []string
type stringSliceFlag []string
func (v *stringSliceVar) String() string {
func (v *stringSliceFlag) String() string {
return fmt.Sprint([]string(*v))
}
func (v *stringSliceVar) Set(s string) error {
func (v *stringSliceFlag) Set(s string) error {
*v = append(*v, s)
return nil
}
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
fs := newFlagSet()
addr := fs.String("addr", "", "")
name := fs.String("name", "", "")
username := fs.String("username", "", "")
pass := fs.String("pass", "", "")
realname := fs.String("realname", "", "")
nick := fs.String("nick", "", "")
var connectCommands stringSliceVar
fs.Var(&connectCommands, "connect-command", "")
// stringPtrFlag is a flag value populating a string pointer. This allows to
// disambiguate between a flag that hasn't been set and a flag that has been
// set to an empty string.
type stringPtrFlag struct {
ptr **string
}
func (f stringPtrFlag) String() string {
if *f.ptr == nil {
return ""
}
return **f.ptr
}
func (f stringPtrFlag) Set(s string) error {
*f.ptr = &s
return nil
}
type networkFlagSet struct {
*flag.FlagSet
Addr, Name, Nick, Username, Pass, Realname *string
ConnectCommands []string
}
func newNetworkFlagSet() *networkFlagSet {
fs := &networkFlagSet{FlagSet: newFlagSet()}
fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
fs.Var(stringPtrFlag{&fs.Name}, "name", "")
fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
fs.Var(stringPtrFlag{&fs.Username}, "username", "")
fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
return fs
}
func (fs *networkFlagSet) update(network *Network) error {
if fs.Addr != nil {
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
scheme := addrParts[0]
switch scheme {
case "ircs", "irc+insecure":
default:
return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
}
}
network.Addr = *fs.Addr
}
if fs.Name != nil {
network.Name = *fs.Name
}
if fs.Nick != nil {
network.Nick = *fs.Nick
}
if fs.Username != nil {
network.Username = *fs.Username
}
if fs.Pass != nil {
network.Pass = *fs.Pass
}
if fs.Realname != nil {
network.Realname = *fs.Realname
}
if fs.ConnectCommands != nil {
if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
network.ConnectCommands = nil
} else {
for _, command := range fs.ConnectCommands {
_, err := irc.ParseMessage(command)
if err != nil {
return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
}
}
network.ConnectCommands = fs.ConnectCommands
}
}
return nil
}
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
fs := newNetworkFlagSet()
if err := fs.Parse(params); err != nil {
return err
}
if *addr == "" {
if fs.Addr == nil {
return fmt.Errorf("flag -addr is required")
}
if addrParts := strings.SplitN(*addr, "://", 2); len(addrParts) == 2 {
scheme := addrParts[0]
switch scheme {
case "ircs", "irc+insecure":
default:
return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
}
record := &Network{
Addr: *fs.Addr,
Nick: dc.nick,
}
if err := fs.update(record); err != nil {
return err
}
for _, command := range connectCommands {
_, err := irc.ParseMessage(command)
if err != nil {
return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
}
}
if *nick == "" {
*nick = dc.nick
}
var err error
network, err := dc.user.createNetwork(&Network{
Addr: *addr,
Name: *name,
Username: *username,
Pass: *pass,
Realname: *realname,
Nick: *nick,
ConnectCommands: connectCommands,
})
network, err := dc.user.createNetwork(record)
if err != nil {
return fmt.Errorf("could not create network: %v", err)
}
@ -441,6 +496,35 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
return nil
}
func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
if len(params) < 1 {
return fmt.Errorf("expected exactly one argument")
}
fs := newNetworkFlagSet()
if err := fs.Parse(params[1:]); err != nil {
return err
}
net := dc.user.getNetwork(params[0])
if net == nil {
return fmt.Errorf("unknown network %q", params[0])
}
record := net.Network // copy network record because we'll mutate it
if err := fs.update(&record); err != nil {
return err
}
network, err := dc.user.updateNetwork(&record)
if err != nil {
return fmt.Errorf("could not update network: %v", err)
}
sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
return nil
}
func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
if len(params) != 1 {
return fmt.Errorf("expected exactly one argument")

173
user.go
View File

@ -272,6 +272,15 @@ func (u *user) getNetwork(name string) *network {
return nil
}
func (u *user) getNetworkByID(id int64) *network {
for _, net := range u.networks {
if net.ID == id {
return net
}
}
return nil
}
func (u *user) run() {
networks, err := u.srv.db.ListNetworks(u.Username)
if err != nil {
@ -309,31 +318,18 @@ func (u *user) run() {
})
uc.network.lastError = nil
case eventUpstreamDisconnected:
uc := e.uc
uc.network.conn = nil
for _, ml := range uc.messageLoggers {
if err := ml.Close(); err != nil {
uc.logger.Printf("failed to close message logger: %v", err)
}
}
uc.endPendingLISTs(true)
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
})
if uc.network.lastError == nil {
uc.forEachDownstream(func(dc *downstreamConn) {
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
})
}
u.handleUpstreamDisconnected(e.uc)
case eventUpstreamConnectionError:
net := e.net
if net.lastError == nil || net.lastError.Error() != e.err.Error() {
stopped := false
select {
case <-net.stopped:
stopped = true
default:
}
if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
net.forEachDownstream(func(dc *downstreamConn) {
sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
})
@ -425,45 +421,128 @@ func (u *user) run() {
}
}
func (u *user) createNetwork(net *Network) (*network, error) {
if net.ID != 0 {
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.network.conn = nil
for _, ml := range uc.messageLoggers {
if err := ml.Close(); err != nil {
uc.logger.Printf("failed to close message logger: %v", err)
}
}
uc.endPendingLISTs(true)
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
})
if uc.network.lastError == nil {
uc.forEachDownstream(func(dc *downstreamConn) {
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
})
}
}
func (u *user) addNetwork(network *network) {
u.networks = append(u.networks, network)
go network.run()
}
func (u *user) removeNetwork(network *network) {
network.stop()
u.forEachDownstream(func(dc *downstreamConn) {
if dc.network != nil && dc.network == network {
dc.Close()
}
})
for i, net := range u.networks {
if net == network {
u.networks = append(u.networks[:i], u.networks[i+1:]...)
return
}
}
panic("tried to remove a non-existing network")
}
func (u *user) createNetwork(record *Network) (*network, error) {
if record.ID != 0 {
panic("tried creating an already-existing network")
}
network := newNetwork(u, net, nil)
network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(u.Username, &network.Network)
if err != nil {
return nil, err
}
u.networks = append(u.networks, network)
u.addNetwork(network)
go network.run()
return network, nil
}
func (u *user) deleteNetwork(id int64) error {
for i, net := range u.networks {
if net.ID != id {
continue
}
if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
return err
}
u.forEachDownstream(func(dc *downstreamConn) {
if dc.network != nil && dc.network == net {
dc.Close()
}
})
net.stop()
u.networks = append(u.networks[:i], u.networks[i+1:]...)
return nil
func (u *user) updateNetwork(record *Network) (*network, error) {
if record.ID == 0 {
panic("tried updating a new network")
}
panic("tried deleting a non-existing network")
network := u.getNetworkByID(record.ID)
if network == nil {
panic("tried updating a non-existing network")
}
if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
return nil, err
}
// Most network changes require us to re-connect to the upstream server
channels := make([]Channel, 0, len(network.channels))
for _, ch := range network.channels {
channels = append(channels, *ch)
}
updatedNetwork := newNetwork(u, record, channels)
// If we're currently connected, disconnect and perform the necessary
// bookkeeping
if network.conn != nil {
network.stop()
// Note: this will set network.conn to nil
u.handleUpstreamDisconnected(network.conn)
}
// Patch downstream connections to use our fresh updated network
u.forEachDownstream(func(dc *downstreamConn) {
if dc.network != nil && dc.network == network {
dc.network = updatedNetwork
}
})
// We need to remove the network after patching downstream connections,
// otherwise they'll get closed
u.removeNetwork(network)
// This will re-connect to the upstream server
u.addNetwork(updatedNetwork)
return updatedNetwork, nil
}
func (u *user) deleteNetwork(id int64) error {
network := u.getNetworkByID(id)
if network == nil {
panic("tried deleting a non-existing network")
}
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
return err
}
u.removeNetwork(network)
return nil
}
func (u *user) updatePassword(hashed string) error {