Add network update command
The user.updateNetwork function is a bit involved because we need to make sure that the upstream connection is closed before re-connecting (would otherwise cause "Nick already used" errors) and that the downstream connections' state is kept in sync. References: https://todo.sr.ht/~emersion/soju/17
This commit is contained in:
parent
bee2001e29
commit
c709ebfc91
152
service.go
152
service.go
@ -118,7 +118,7 @@ func init() {
|
||||
"network": {
|
||||
children: serviceCommandSet{
|
||||
"create": {
|
||||
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]",
|
||||
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
|
||||
desc: "add a new network",
|
||||
handle: handleServiceCreateNetwork,
|
||||
},
|
||||
@ -126,6 +126,11 @@ func init() {
|
||||
desc: "show a list of saved networks and their current status",
|
||||
handle: handleServiceNetworkStatus,
|
||||
},
|
||||
"update": {
|
||||
usage: "[-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
|
||||
desc: "update a network",
|
||||
handle: handleServiceNetworkUpdate,
|
||||
},
|
||||
"delete": {
|
||||
usage: "<name>",
|
||||
desc: "delete a network",
|
||||
@ -338,36 +343,57 @@ func newFlagSet() *flag.FlagSet {
|
||||
return fs
|
||||
}
|
||||
|
||||
type stringSliceVar []string
|
||||
type stringSliceFlag []string
|
||||
|
||||
func (v *stringSliceVar) String() string {
|
||||
func (v *stringSliceFlag) String() string {
|
||||
return fmt.Sprint([]string(*v))
|
||||
}
|
||||
|
||||
func (v *stringSliceVar) Set(s string) error {
|
||||
func (v *stringSliceFlag) Set(s string) error {
|
||||
*v = append(*v, s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
|
||||
fs := newFlagSet()
|
||||
addr := fs.String("addr", "", "")
|
||||
name := fs.String("name", "", "")
|
||||
username := fs.String("username", "", "")
|
||||
pass := fs.String("pass", "", "")
|
||||
realname := fs.String("realname", "", "")
|
||||
nick := fs.String("nick", "", "")
|
||||
var connectCommands stringSliceVar
|
||||
fs.Var(&connectCommands, "connect-command", "")
|
||||
// stringPtrFlag is a flag value populating a string pointer. This allows to
|
||||
// disambiguate between a flag that hasn't been set and a flag that has been
|
||||
// set to an empty string.
|
||||
type stringPtrFlag struct {
|
||||
ptr **string
|
||||
}
|
||||
|
||||
if err := fs.Parse(params); err != nil {
|
||||
return err
|
||||
}
|
||||
if *addr == "" {
|
||||
return fmt.Errorf("flag -addr is required")
|
||||
func (f stringPtrFlag) String() string {
|
||||
if *f.ptr == nil {
|
||||
return ""
|
||||
}
|
||||
return **f.ptr
|
||||
}
|
||||
|
||||
if addrParts := strings.SplitN(*addr, "://", 2); len(addrParts) == 2 {
|
||||
func (f stringPtrFlag) Set(s string) error {
|
||||
*f.ptr = &s
|
||||
return nil
|
||||
}
|
||||
|
||||
type networkFlagSet struct {
|
||||
*flag.FlagSet
|
||||
Addr, Name, Nick, Username, Pass, Realname *string
|
||||
ConnectCommands []string
|
||||
}
|
||||
|
||||
func newNetworkFlagSet() *networkFlagSet {
|
||||
fs := &networkFlagSet{FlagSet: newFlagSet()}
|
||||
fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
|
||||
fs.Var(stringPtrFlag{&fs.Name}, "name", "")
|
||||
fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
|
||||
fs.Var(stringPtrFlag{&fs.Username}, "username", "")
|
||||
fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
|
||||
fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
|
||||
fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
|
||||
return fs
|
||||
}
|
||||
|
||||
func (fs *networkFlagSet) update(network *Network) error {
|
||||
if fs.Addr != nil {
|
||||
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
|
||||
scheme := addrParts[0]
|
||||
switch scheme {
|
||||
case "ircs", "irc+insecure":
|
||||
@ -375,28 +401,57 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
|
||||
return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
|
||||
}
|
||||
}
|
||||
|
||||
for _, command := range connectCommands {
|
||||
network.Addr = *fs.Addr
|
||||
}
|
||||
if fs.Name != nil {
|
||||
network.Name = *fs.Name
|
||||
}
|
||||
if fs.Nick != nil {
|
||||
network.Nick = *fs.Nick
|
||||
}
|
||||
if fs.Username != nil {
|
||||
network.Username = *fs.Username
|
||||
}
|
||||
if fs.Pass != nil {
|
||||
network.Pass = *fs.Pass
|
||||
}
|
||||
if fs.Realname != nil {
|
||||
network.Realname = *fs.Realname
|
||||
}
|
||||
if fs.ConnectCommands != nil {
|
||||
if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
|
||||
network.ConnectCommands = nil
|
||||
} else {
|
||||
for _, command := range fs.ConnectCommands {
|
||||
_, err := irc.ParseMessage(command)
|
||||
if err != nil {
|
||||
return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
|
||||
}
|
||||
}
|
||||
network.ConnectCommands = fs.ConnectCommands
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if *nick == "" {
|
||||
*nick = dc.nick
|
||||
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
|
||||
fs := newNetworkFlagSet()
|
||||
if err := fs.Parse(params); err != nil {
|
||||
return err
|
||||
}
|
||||
if fs.Addr == nil {
|
||||
return fmt.Errorf("flag -addr is required")
|
||||
}
|
||||
|
||||
var err error
|
||||
network, err := dc.user.createNetwork(&Network{
|
||||
Addr: *addr,
|
||||
Name: *name,
|
||||
Username: *username,
|
||||
Pass: *pass,
|
||||
Realname: *realname,
|
||||
Nick: *nick,
|
||||
ConnectCommands: connectCommands,
|
||||
})
|
||||
record := &Network{
|
||||
Addr: *fs.Addr,
|
||||
Nick: dc.nick,
|
||||
}
|
||||
if err := fs.update(record); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
network, err := dc.user.createNetwork(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create network: %v", err)
|
||||
}
|
||||
@ -441,6 +496,35 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
|
||||
if len(params) < 1 {
|
||||
return fmt.Errorf("expected exactly one argument")
|
||||
}
|
||||
|
||||
fs := newNetworkFlagSet()
|
||||
if err := fs.Parse(params[1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
net := dc.user.getNetwork(params[0])
|
||||
if net == nil {
|
||||
return fmt.Errorf("unknown network %q", params[0])
|
||||
}
|
||||
|
||||
record := net.Network // copy network record because we'll mutate it
|
||||
if err := fs.update(&record); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
network, err := dc.user.updateNetwork(&record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not update network: %v", err)
|
||||
}
|
||||
|
||||
sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
|
||||
if len(params) != 1 {
|
||||
return fmt.Errorf("expected exactly one argument")
|
||||
|
157
user.go
157
user.go
@ -272,6 +272,15 @@ func (u *user) getNetwork(name string) *network {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *user) getNetworkByID(id int64) *network {
|
||||
for _, net := range u.networks {
|
||||
if net.ID == id {
|
||||
return net
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *user) run() {
|
||||
networks, err := u.srv.db.ListNetworks(u.Username)
|
||||
if err != nil {
|
||||
@ -309,31 +318,18 @@ func (u *user) run() {
|
||||
})
|
||||
uc.network.lastError = nil
|
||||
case eventUpstreamDisconnected:
|
||||
uc := e.uc
|
||||
|
||||
uc.network.conn = nil
|
||||
|
||||
for _, ml := range uc.messageLoggers {
|
||||
if err := ml.Close(); err != nil {
|
||||
uc.logger.Printf("failed to close message logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
uc.endPendingLISTs(true)
|
||||
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.updateSupportedCaps()
|
||||
})
|
||||
|
||||
if uc.network.lastError == nil {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
|
||||
})
|
||||
}
|
||||
u.handleUpstreamDisconnected(e.uc)
|
||||
case eventUpstreamConnectionError:
|
||||
net := e.net
|
||||
|
||||
if net.lastError == nil || net.lastError.Error() != e.err.Error() {
|
||||
stopped := false
|
||||
select {
|
||||
case <-net.stopped:
|
||||
stopped = true
|
||||
default:
|
||||
}
|
||||
|
||||
if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
|
||||
net.forEachDownstream(func(dc *downstreamConn) {
|
||||
sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
|
||||
})
|
||||
@ -425,45 +421,128 @@ func (u *user) run() {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *user) createNetwork(net *Network) (*network, error) {
|
||||
if net.ID != 0 {
|
||||
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
|
||||
uc.network.conn = nil
|
||||
|
||||
for _, ml := range uc.messageLoggers {
|
||||
if err := ml.Close(); err != nil {
|
||||
uc.logger.Printf("failed to close message logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
uc.endPendingLISTs(true)
|
||||
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.updateSupportedCaps()
|
||||
})
|
||||
|
||||
if uc.network.lastError == nil {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (u *user) addNetwork(network *network) {
|
||||
u.networks = append(u.networks, network)
|
||||
go network.run()
|
||||
}
|
||||
|
||||
func (u *user) removeNetwork(network *network) {
|
||||
network.stop()
|
||||
|
||||
u.forEachDownstream(func(dc *downstreamConn) {
|
||||
if dc.network != nil && dc.network == network {
|
||||
dc.Close()
|
||||
}
|
||||
})
|
||||
|
||||
for i, net := range u.networks {
|
||||
if net == network {
|
||||
u.networks = append(u.networks[:i], u.networks[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
panic("tried to remove a non-existing network")
|
||||
}
|
||||
|
||||
func (u *user) createNetwork(record *Network) (*network, error) {
|
||||
if record.ID != 0 {
|
||||
panic("tried creating an already-existing network")
|
||||
}
|
||||
|
||||
network := newNetwork(u, net, nil)
|
||||
network := newNetwork(u, record, nil)
|
||||
err := u.srv.db.StoreNetwork(u.Username, &network.Network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.networks = append(u.networks, network)
|
||||
u.addNetwork(network)
|
||||
|
||||
go network.run()
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func (u *user) deleteNetwork(id int64) error {
|
||||
for i, net := range u.networks {
|
||||
if net.ID != id {
|
||||
continue
|
||||
func (u *user) updateNetwork(record *Network) (*network, error) {
|
||||
if record.ID == 0 {
|
||||
panic("tried updating a new network")
|
||||
}
|
||||
|
||||
if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
|
||||
return err
|
||||
network := u.getNetworkByID(record.ID)
|
||||
if network == nil {
|
||||
panic("tried updating a non-existing network")
|
||||
}
|
||||
|
||||
if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Most network changes require us to re-connect to the upstream server
|
||||
|
||||
channels := make([]Channel, 0, len(network.channels))
|
||||
for _, ch := range network.channels {
|
||||
channels = append(channels, *ch)
|
||||
}
|
||||
|
||||
updatedNetwork := newNetwork(u, record, channels)
|
||||
|
||||
// If we're currently connected, disconnect and perform the necessary
|
||||
// bookkeeping
|
||||
if network.conn != nil {
|
||||
network.stop()
|
||||
// Note: this will set network.conn to nil
|
||||
u.handleUpstreamDisconnected(network.conn)
|
||||
}
|
||||
|
||||
// Patch downstream connections to use our fresh updated network
|
||||
u.forEachDownstream(func(dc *downstreamConn) {
|
||||
if dc.network != nil && dc.network == net {
|
||||
dc.Close()
|
||||
if dc.network != nil && dc.network == network {
|
||||
dc.network = updatedNetwork
|
||||
}
|
||||
})
|
||||
|
||||
net.stop()
|
||||
u.networks = append(u.networks[:i], u.networks[i+1:]...)
|
||||
return nil
|
||||
// We need to remove the network after patching downstream connections,
|
||||
// otherwise they'll get closed
|
||||
u.removeNetwork(network)
|
||||
|
||||
// This will re-connect to the upstream server
|
||||
u.addNetwork(updatedNetwork)
|
||||
|
||||
return updatedNetwork, nil
|
||||
}
|
||||
|
||||
func (u *user) deleteNetwork(id int64) error {
|
||||
network := u.getNetworkByID(id)
|
||||
if network == nil {
|
||||
panic("tried deleting a non-existing network")
|
||||
}
|
||||
|
||||
panic("tried deleting a non-existing network")
|
||||
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.removeNetwork(network)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *user) updatePassword(hashed string) error {
|
||||
|
Loading…
Reference in New Issue
Block a user