Add support for custom network on-connect commands

Some servers use custom IRC bots with custom commands for registering to
specific services after connection.

This adds support for setting custom raw IRC messages, that will be
sent after registering to a network.

It also adds support for a custom flag.Value type for string
slice flags (flags taking several string values).
This commit is contained in:
delthas 2020-04-16 01:40:50 +02:00 committed by Simon Ser
parent 9c463b61ec
commit 7f74055380
3 changed files with 61 additions and 23 deletions

40
db.go
View File

@ -3,6 +3,7 @@ package soju
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"sync" "sync"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -23,14 +24,15 @@ type SASL struct {
} }
type Network struct { type Network struct {
ID int64 ID int64
Name string Name string
Addr string Addr string
Nick string Nick string
Username string Username string
Realname string Realname string
Pass string Pass string
SASL SASL ConnectCommands []string
SASL SASL
} }
func (net *Network) GetName() string { func (net *Network) GetName() string {
@ -63,6 +65,7 @@ CREATE TABLE Network (
username VARCHAR(255), username VARCHAR(255),
realname VARCHAR(255), realname VARCHAR(255),
pass VARCHAR(255), pass VARCHAR(255),
connect_commands VARCHAR(1023),
sasl_mechanism VARCHAR(255), sasl_mechanism VARCHAR(255),
sasl_plain_username VARCHAR(255), sasl_plain_username VARCHAR(255),
sasl_plain_password VARCHAR(255), sasl_plain_password VARCHAR(255),
@ -82,6 +85,7 @@ CREATE TABLE Channel (
var migrations = []string{ var migrations = []string{
"", // migration #0 is reserved for schema initialization "", // migration #0 is reserved for schema initialization
"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
} }
type DB struct { type DB struct {
@ -233,7 +237,7 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
defer db.lock.RUnlock() defer db.lock.RUnlock()
rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass, rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
sasl_mechanism, sasl_plain_username, sasl_plain_password connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
FROM Network FROM Network
WHERE user = ?`, WHERE user = ?`,
username) username)
@ -245,10 +249,10 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
var networks []Network var networks []Network
for rows.Next() { for rows.Next() {
var net Network var net Network
var name, username, realname, pass *string var name, username, realname, pass, connectCommands *string
var saslMechanism, saslPlainUsername, saslPlainPassword *string var saslMechanism, saslPlainUsername, saslPlainPassword *string
err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname, err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
&pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword) &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -256,6 +260,9 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
net.Username = fromStringPtr(username) net.Username = fromStringPtr(username)
net.Realname = fromStringPtr(realname) net.Realname = fromStringPtr(realname)
net.Pass = fromStringPtr(pass) net.Pass = fromStringPtr(pass)
if connectCommands != nil {
net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
}
net.SASL.Mechanism = fromStringPtr(saslMechanism) net.SASL.Mechanism = fromStringPtr(saslMechanism)
net.SASL.Plain.Username = fromStringPtr(saslPlainUsername) net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
net.SASL.Plain.Password = fromStringPtr(saslPlainPassword) net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
@ -276,6 +283,7 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
netUsername := toStringPtr(network.Username) netUsername := toStringPtr(network.Username)
realname := toStringPtr(network.Realname) realname := toStringPtr(network.Realname)
pass := toStringPtr(network.Pass) pass := toStringPtr(network.Pass)
connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
var saslMechanism, saslPlainUsername, saslPlainPassword *string var saslMechanism, saslPlainUsername, saslPlainPassword *string
if network.SASL.Mechanism != "" { if network.SASL.Mechanism != "" {
@ -292,18 +300,18 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
var err error var err error
if network.ID != 0 { if network.ID != 0 {
_, err = db.db.Exec(`UPDATE Network _, err = db.db.Exec(`UPDATE Network
SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ? sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
WHERE id = ?`, WHERE id = ?`,
netName, network.Addr, network.Nick, netUsername, realname, pass, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword, network.ID) saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
} else { } else {
var res sql.Result var res sql.Result
res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username, res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
realname, pass, sasl_mechanism, sasl_plain_username, realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
sasl_plain_password) sasl_plain_password)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
username, netName, network.Addr, network.Nick, netUsername, realname, pass, username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword) saslMechanism, saslPlainUsername, saslPlainPassword)
if err != nil { if err != nil {
return err return err

View File

@ -104,7 +104,7 @@ func init() {
"network": { "network": {
children: serviceCommandSet{ children: serviceCommandSet{
"create": { "create": {
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick]", usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]",
desc: "add a new network", desc: "add a new network",
handle: handleServiceCreateNetwork, handle: handleServiceCreateNetwork,
}, },
@ -174,6 +174,17 @@ func newFlagSet() *flag.FlagSet {
return fs return fs
} }
type stringSliceVar []string
func (v *stringSliceVar) String() string {
return fmt.Sprint([]string(*v))
}
func (v *stringSliceVar) Set(s string) error {
*v = append(*v, s)
return nil
}
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
fs := newFlagSet() fs := newFlagSet()
addr := fs.String("addr", "", "") addr := fs.String("addr", "", "")
@ -182,6 +193,8 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
pass := fs.String("pass", "", "") pass := fs.String("pass", "", "")
realname := fs.String("realname", "", "") realname := fs.String("realname", "", "")
nick := fs.String("nick", "", "") nick := fs.String("nick", "", "")
var connectCommands stringSliceVar
fs.Var(&connectCommands, "connect-command", "")
if err := fs.Parse(params); err != nil { if err := fs.Parse(params); err != nil {
return err return err
@ -190,18 +203,26 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
return fmt.Errorf("flag -addr is required") return fmt.Errorf("flag -addr is required")
} }
for _, command := range connectCommands {
_, err := irc.ParseMessage(command)
if err != nil {
return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
}
}
if *nick == "" { if *nick == "" {
*nick = dc.nick *nick = dc.nick
} }
var err error var err error
network, err := dc.user.createNetwork(&Network{ network, err := dc.user.createNetwork(&Network{
Addr: *addr, Addr: *addr,
Name: *name, Name: *name,
Username: *username, Username: *username,
Pass: *pass, Pass: *pass,
Realname: *realname, Realname: *realname,
Nick: *nick, Nick: *nick,
ConnectCommands: connectCommands,
}) })
if err != nil { if err != nil {
return fmt.Errorf("could not create network: %v", err) return fmt.Errorf("could not create network: %v", err)

View File

@ -1189,6 +1189,15 @@ func (uc *upstreamConn) runUntilRegistered() error {
} }
} }
for _, command := range uc.network.ConnectCommands {
m, err := irc.ParseMessage(command)
if err != nil {
uc.logger.Printf("failed to parse connect command %q: %v", command, err)
} else {
uc.SendMessage(m)
}
}
return nil return nil
} }