From 7f74055380d17440319e2e24a88c55f4adec62bd Mon Sep 17 00:00:00 2001 From: delthas Date: Thu, 16 Apr 2020 01:40:50 +0200 Subject: [PATCH] Add support for custom network on-connect commands Some servers use custom IRC bots with custom commands for registering to specific services after connection. This adds support for setting custom raw IRC messages, that will be sent after registering to a network. It also adds support for a custom flag.Value type for string slice flags (flags taking several string values). --- db.go | 40 ++++++++++++++++++++++++---------------- service.go | 35 ++++++++++++++++++++++++++++------- upstream.go | 9 +++++++++ 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/db.go b/db.go index 9bb08ab..c093064 100644 --- a/db.go +++ b/db.go @@ -3,6 +3,7 @@ package soju import ( "database/sql" "fmt" + "strings" "sync" _ "github.com/mattn/go-sqlite3" @@ -23,14 +24,15 @@ type SASL struct { } type Network struct { - ID int64 - Name string - Addr string - Nick string - Username string - Realname string - Pass string - SASL SASL + ID int64 + Name string + Addr string + Nick string + Username string + Realname string + Pass string + ConnectCommands []string + SASL SASL } func (net *Network) GetName() string { @@ -63,6 +65,7 @@ CREATE TABLE Network ( username VARCHAR(255), realname VARCHAR(255), pass VARCHAR(255), + connect_commands VARCHAR(1023), sasl_mechanism VARCHAR(255), sasl_plain_username VARCHAR(255), sasl_plain_password VARCHAR(255), @@ -82,6 +85,7 @@ CREATE TABLE Channel ( var migrations = []string{ "", // migration #0 is reserved for schema initialization + "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)", } type DB struct { @@ -233,7 +237,7 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { defer db.lock.RUnlock() rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass, - sasl_mechanism, sasl_plain_username, sasl_plain_password + connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password FROM Network WHERE user = ?`, username) @@ -245,10 +249,10 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { var networks []Network for rows.Next() { var net Network - var name, username, realname, pass *string + var name, username, realname, pass, connectCommands *string var saslMechanism, saslPlainUsername, saslPlainPassword *string err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname, - &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword) + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword) if err != nil { return nil, err } @@ -256,6 +260,9 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { net.Username = fromStringPtr(username) net.Realname = fromStringPtr(realname) net.Pass = fromStringPtr(pass) + if connectCommands != nil { + net.ConnectCommands = strings.Split(*connectCommands, "\r\n") + } net.SASL.Mechanism = fromStringPtr(saslMechanism) net.SASL.Plain.Username = fromStringPtr(saslPlainUsername) net.SASL.Plain.Password = fromStringPtr(saslPlainPassword) @@ -276,6 +283,7 @@ func (db *DB) StoreNetwork(username string, network *Network) error { netUsername := toStringPtr(network.Username) realname := toStringPtr(network.Realname) pass := toStringPtr(network.Pass) + connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n")) var saslMechanism, saslPlainUsername, saslPlainPassword *string if network.SASL.Mechanism != "" { @@ -292,18 +300,18 @@ func (db *DB) StoreNetwork(username string, network *Network) error { var err error if network.ID != 0 { _, err = db.db.Exec(`UPDATE Network - SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, + SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?, sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ? WHERE id = ?`, - netName, network.Addr, network.Nick, netUsername, realname, pass, + netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, saslMechanism, saslPlainUsername, saslPlainPassword, network.ID) } else { var res sql.Result res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username, - realname, pass, sasl_mechanism, sasl_plain_username, + realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - username, netName, network.Addr, network.Nick, netUsername, realname, pass, + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, saslMechanism, saslPlainUsername, saslPlainPassword) if err != nil { return err diff --git a/service.go b/service.go index 4e73873..ecbd6aa 100644 --- a/service.go +++ b/service.go @@ -104,7 +104,7 @@ func init() { "network": { children: serviceCommandSet{ "create": { - usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick]", + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]", desc: "add a new network", handle: handleServiceCreateNetwork, }, @@ -174,6 +174,17 @@ func newFlagSet() *flag.FlagSet { return fs } +type stringSliceVar []string + +func (v *stringSliceVar) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceVar) Set(s string) error { + *v = append(*v, s) + return nil +} + func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { fs := newFlagSet() addr := fs.String("addr", "", "") @@ -182,6 +193,8 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { pass := fs.String("pass", "", "") realname := fs.String("realname", "", "") nick := fs.String("nick", "", "") + var connectCommands stringSliceVar + fs.Var(&connectCommands, "connect-command", "") if err := fs.Parse(params); err != nil { return err @@ -190,18 +203,26 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { return fmt.Errorf("flag -addr is required") } + for _, command := range connectCommands { + _, err := irc.ParseMessage(command) + if err != nil { + return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) + } + } + if *nick == "" { *nick = dc.nick } var err error network, err := dc.user.createNetwork(&Network{ - Addr: *addr, - Name: *name, - Username: *username, - Pass: *pass, - Realname: *realname, - Nick: *nick, + Addr: *addr, + Name: *name, + Username: *username, + Pass: *pass, + Realname: *realname, + Nick: *nick, + ConnectCommands: connectCommands, }) if err != nil { return fmt.Errorf("could not create network: %v", err) diff --git a/upstream.go b/upstream.go index e563955..599a9d7 100644 --- a/upstream.go +++ b/upstream.go @@ -1189,6 +1189,15 @@ func (uc *upstreamConn) runUntilRegistered() error { } } + for _, command := range uc.network.ConnectCommands { + m, err := irc.ParseMessage(command) + if err != nil { + uc.logger.Printf("failed to parse connect command %q: %v", command, err) + } else { + uc.SendMessage(m) + } + } + return nil }