Add network update command
The user.updateNetwork function is a bit involved because we need to make sure that the upstream connection is closed before re-connecting (would otherwise cause "Nick already used" errors) and that the downstream connections' state is kept in sync. References: https://todo.sr.ht/~emersion/soju/17
This commit is contained in:
parent
bee2001e29
commit
c709ebfc91
170
service.go
170
service.go
@ -118,7 +118,7 @@ func init() {
|
|||||||
"network": {
|
"network": {
|
||||||
children: serviceCommandSet{
|
children: serviceCommandSet{
|
||||||
"create": {
|
"create": {
|
||||||
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]",
|
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
|
||||||
desc: "add a new network",
|
desc: "add a new network",
|
||||||
handle: handleServiceCreateNetwork,
|
handle: handleServiceCreateNetwork,
|
||||||
},
|
},
|
||||||
@ -126,6 +126,11 @@ func init() {
|
|||||||
desc: "show a list of saved networks and their current status",
|
desc: "show a list of saved networks and their current status",
|
||||||
handle: handleServiceNetworkStatus,
|
handle: handleServiceNetworkStatus,
|
||||||
},
|
},
|
||||||
|
"update": {
|
||||||
|
usage: "[-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
|
||||||
|
desc: "update a network",
|
||||||
|
handle: handleServiceNetworkUpdate,
|
||||||
|
},
|
||||||
"delete": {
|
"delete": {
|
||||||
usage: "<name>",
|
usage: "<name>",
|
||||||
desc: "delete a network",
|
desc: "delete a network",
|
||||||
@ -338,65 +343,115 @@ func newFlagSet() *flag.FlagSet {
|
|||||||
return fs
|
return fs
|
||||||
}
|
}
|
||||||
|
|
||||||
type stringSliceVar []string
|
type stringSliceFlag []string
|
||||||
|
|
||||||
func (v *stringSliceVar) String() string {
|
func (v *stringSliceFlag) String() string {
|
||||||
return fmt.Sprint([]string(*v))
|
return fmt.Sprint([]string(*v))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *stringSliceVar) Set(s string) error {
|
func (v *stringSliceFlag) Set(s string) error {
|
||||||
*v = append(*v, s)
|
*v = append(*v, s)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
|
// stringPtrFlag is a flag value populating a string pointer. This allows to
|
||||||
fs := newFlagSet()
|
// disambiguate between a flag that hasn't been set and a flag that has been
|
||||||
addr := fs.String("addr", "", "")
|
// set to an empty string.
|
||||||
name := fs.String("name", "", "")
|
type stringPtrFlag struct {
|
||||||
username := fs.String("username", "", "")
|
ptr **string
|
||||||
pass := fs.String("pass", "", "")
|
}
|
||||||
realname := fs.String("realname", "", "")
|
|
||||||
nick := fs.String("nick", "", "")
|
|
||||||
var connectCommands stringSliceVar
|
|
||||||
fs.Var(&connectCommands, "connect-command", "")
|
|
||||||
|
|
||||||
|
func (f stringPtrFlag) String() string {
|
||||||
|
if *f.ptr == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return **f.ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f stringPtrFlag) Set(s string) error {
|
||||||
|
*f.ptr = &s
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type networkFlagSet struct {
|
||||||
|
*flag.FlagSet
|
||||||
|
Addr, Name, Nick, Username, Pass, Realname *string
|
||||||
|
ConnectCommands []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNetworkFlagSet() *networkFlagSet {
|
||||||
|
fs := &networkFlagSet{FlagSet: newFlagSet()}
|
||||||
|
fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
|
||||||
|
fs.Var(stringPtrFlag{&fs.Name}, "name", "")
|
||||||
|
fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
|
||||||
|
fs.Var(stringPtrFlag{&fs.Username}, "username", "")
|
||||||
|
fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
|
||||||
|
fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
|
||||||
|
fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
|
||||||
|
return fs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *networkFlagSet) update(network *Network) error {
|
||||||
|
if fs.Addr != nil {
|
||||||
|
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
|
||||||
|
scheme := addrParts[0]
|
||||||
|
switch scheme {
|
||||||
|
case "ircs", "irc+insecure":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
network.Addr = *fs.Addr
|
||||||
|
}
|
||||||
|
if fs.Name != nil {
|
||||||
|
network.Name = *fs.Name
|
||||||
|
}
|
||||||
|
if fs.Nick != nil {
|
||||||
|
network.Nick = *fs.Nick
|
||||||
|
}
|
||||||
|
if fs.Username != nil {
|
||||||
|
network.Username = *fs.Username
|
||||||
|
}
|
||||||
|
if fs.Pass != nil {
|
||||||
|
network.Pass = *fs.Pass
|
||||||
|
}
|
||||||
|
if fs.Realname != nil {
|
||||||
|
network.Realname = *fs.Realname
|
||||||
|
}
|
||||||
|
if fs.ConnectCommands != nil {
|
||||||
|
if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
|
||||||
|
network.ConnectCommands = nil
|
||||||
|
} else {
|
||||||
|
for _, command := range fs.ConnectCommands {
|
||||||
|
_, err := irc.ParseMessage(command)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
network.ConnectCommands = fs.ConnectCommands
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
|
||||||
|
fs := newNetworkFlagSet()
|
||||||
if err := fs.Parse(params); err != nil {
|
if err := fs.Parse(params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if *addr == "" {
|
if fs.Addr == nil {
|
||||||
return fmt.Errorf("flag -addr is required")
|
return fmt.Errorf("flag -addr is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
if addrParts := strings.SplitN(*addr, "://", 2); len(addrParts) == 2 {
|
record := &Network{
|
||||||
scheme := addrParts[0]
|
Addr: *fs.Addr,
|
||||||
switch scheme {
|
Nick: dc.nick,
|
||||||
case "ircs", "irc+insecure":
|
}
|
||||||
default:
|
if err := fs.update(record); err != nil {
|
||||||
return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, command := range connectCommands {
|
network, err := dc.user.createNetwork(record)
|
||||||
_, err := irc.ParseMessage(command)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if *nick == "" {
|
|
||||||
*nick = dc.nick
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
network, err := dc.user.createNetwork(&Network{
|
|
||||||
Addr: *addr,
|
|
||||||
Name: *name,
|
|
||||||
Username: *username,
|
|
||||||
Pass: *pass,
|
|
||||||
Realname: *realname,
|
|
||||||
Nick: *nick,
|
|
||||||
ConnectCommands: connectCommands,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create network: %v", err)
|
return fmt.Errorf("could not create network: %v", err)
|
||||||
}
|
}
|
||||||
@ -441,6 +496,35 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
|
||||||
|
if len(params) < 1 {
|
||||||
|
return fmt.Errorf("expected exactly one argument")
|
||||||
|
}
|
||||||
|
|
||||||
|
fs := newNetworkFlagSet()
|
||||||
|
if err := fs.Parse(params[1:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
net := dc.user.getNetwork(params[0])
|
||||||
|
if net == nil {
|
||||||
|
return fmt.Errorf("unknown network %q", params[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
record := net.Network // copy network record because we'll mutate it
|
||||||
|
if err := fs.update(&record); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
network, err := dc.user.updateNetwork(&record)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not update network: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
|
func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
|
||||||
if len(params) != 1 {
|
if len(params) != 1 {
|
||||||
return fmt.Errorf("expected exactly one argument")
|
return fmt.Errorf("expected exactly one argument")
|
||||||
|
173
user.go
173
user.go
@ -272,6 +272,15 @@ func (u *user) getNetwork(name string) *network {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *user) getNetworkByID(id int64) *network {
|
||||||
|
for _, net := range u.networks {
|
||||||
|
if net.ID == id {
|
||||||
|
return net
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *user) run() {
|
func (u *user) run() {
|
||||||
networks, err := u.srv.db.ListNetworks(u.Username)
|
networks, err := u.srv.db.ListNetworks(u.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -309,31 +318,18 @@ func (u *user) run() {
|
|||||||
})
|
})
|
||||||
uc.network.lastError = nil
|
uc.network.lastError = nil
|
||||||
case eventUpstreamDisconnected:
|
case eventUpstreamDisconnected:
|
||||||
uc := e.uc
|
u.handleUpstreamDisconnected(e.uc)
|
||||||
|
|
||||||
uc.network.conn = nil
|
|
||||||
|
|
||||||
for _, ml := range uc.messageLoggers {
|
|
||||||
if err := ml.Close(); err != nil {
|
|
||||||
uc.logger.Printf("failed to close message logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
uc.endPendingLISTs(true)
|
|
||||||
|
|
||||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
|
||||||
dc.updateSupportedCaps()
|
|
||||||
})
|
|
||||||
|
|
||||||
if uc.network.lastError == nil {
|
|
||||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
|
||||||
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case eventUpstreamConnectionError:
|
case eventUpstreamConnectionError:
|
||||||
net := e.net
|
net := e.net
|
||||||
|
|
||||||
if net.lastError == nil || net.lastError.Error() != e.err.Error() {
|
stopped := false
|
||||||
|
select {
|
||||||
|
case <-net.stopped:
|
||||||
|
stopped = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
|
||||||
net.forEachDownstream(func(dc *downstreamConn) {
|
net.forEachDownstream(func(dc *downstreamConn) {
|
||||||
sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
|
sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
|
||||||
})
|
})
|
||||||
@ -425,45 +421,128 @@ func (u *user) run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) createNetwork(net *Network) (*network, error) {
|
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
|
||||||
if net.ID != 0 {
|
uc.network.conn = nil
|
||||||
|
|
||||||
|
for _, ml := range uc.messageLoggers {
|
||||||
|
if err := ml.Close(); err != nil {
|
||||||
|
uc.logger.Printf("failed to close message logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uc.endPendingLISTs(true)
|
||||||
|
|
||||||
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||||
|
dc.updateSupportedCaps()
|
||||||
|
})
|
||||||
|
|
||||||
|
if uc.network.lastError == nil {
|
||||||
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||||
|
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *user) addNetwork(network *network) {
|
||||||
|
u.networks = append(u.networks, network)
|
||||||
|
go network.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *user) removeNetwork(network *network) {
|
||||||
|
network.stop()
|
||||||
|
|
||||||
|
u.forEachDownstream(func(dc *downstreamConn) {
|
||||||
|
if dc.network != nil && dc.network == network {
|
||||||
|
dc.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for i, net := range u.networks {
|
||||||
|
if net == network {
|
||||||
|
u.networks = append(u.networks[:i], u.networks[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("tried to remove a non-existing network")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *user) createNetwork(record *Network) (*network, error) {
|
||||||
|
if record.ID != 0 {
|
||||||
panic("tried creating an already-existing network")
|
panic("tried creating an already-existing network")
|
||||||
}
|
}
|
||||||
|
|
||||||
network := newNetwork(u, net, nil)
|
network := newNetwork(u, record, nil)
|
||||||
err := u.srv.db.StoreNetwork(u.Username, &network.Network)
|
err := u.srv.db.StoreNetwork(u.Username, &network.Network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.networks = append(u.networks, network)
|
u.addNetwork(network)
|
||||||
|
|
||||||
go network.run()
|
|
||||||
return network, nil
|
return network, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) deleteNetwork(id int64) error {
|
func (u *user) updateNetwork(record *Network) (*network, error) {
|
||||||
for i, net := range u.networks {
|
if record.ID == 0 {
|
||||||
if net.ID != id {
|
panic("tried updating a new network")
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
u.forEachDownstream(func(dc *downstreamConn) {
|
|
||||||
if dc.network != nil && dc.network == net {
|
|
||||||
dc.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
net.stop()
|
|
||||||
u.networks = append(u.networks[:i], u.networks[i+1:]...)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
panic("tried deleting a non-existing network")
|
network := u.getNetworkByID(record.ID)
|
||||||
|
if network == nil {
|
||||||
|
panic("tried updating a non-existing network")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most network changes require us to re-connect to the upstream server
|
||||||
|
|
||||||
|
channels := make([]Channel, 0, len(network.channels))
|
||||||
|
for _, ch := range network.channels {
|
||||||
|
channels = append(channels, *ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedNetwork := newNetwork(u, record, channels)
|
||||||
|
|
||||||
|
// If we're currently connected, disconnect and perform the necessary
|
||||||
|
// bookkeeping
|
||||||
|
if network.conn != nil {
|
||||||
|
network.stop()
|
||||||
|
// Note: this will set network.conn to nil
|
||||||
|
u.handleUpstreamDisconnected(network.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patch downstream connections to use our fresh updated network
|
||||||
|
u.forEachDownstream(func(dc *downstreamConn) {
|
||||||
|
if dc.network != nil && dc.network == network {
|
||||||
|
dc.network = updatedNetwork
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// We need to remove the network after patching downstream connections,
|
||||||
|
// otherwise they'll get closed
|
||||||
|
u.removeNetwork(network)
|
||||||
|
|
||||||
|
// This will re-connect to the upstream server
|
||||||
|
u.addNetwork(updatedNetwork)
|
||||||
|
|
||||||
|
return updatedNetwork, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *user) deleteNetwork(id int64) error {
|
||||||
|
network := u.getNetworkByID(id)
|
||||||
|
if network == nil {
|
||||||
|
panic("tried deleting a non-existing network")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u.removeNetwork(network)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) updatePassword(hashed string) error {
|
func (u *user) updatePassword(hashed string) error {
|
||||||
|
Loading…
Reference in New Issue
Block a user