contrib/znc-import: new utility
Allows populating the soju database from a ZNC config file.
This commit is contained in:
parent
7ebe47ad4a
commit
1ac895430a
469
contrib/znc-import.go
Normal file
469
contrib/znc-import.go
Normal file
@ -0,0 +1,469 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"git.sr.ht/~emersion/soju"
|
||||
"git.sr.ht/~emersion/soju/config"
|
||||
)
|
||||
|
||||
const usage = `usage: znc-import [options...] <znc config path>
|
||||
|
||||
Imports configuration from a ZNC file. Users and networks are merged if they
|
||||
already exist in the soju database. ZNC settings overwrite existing soju
|
||||
settings.
|
||||
|
||||
Options:
|
||||
|
||||
-help Show this help message
|
||||
-config <path> Path to soju config file
|
||||
-user <username> Limit import to username (may be specified multiple times)
|
||||
-network <name> Limit import to network (may be specified multiple times)
|
||||
`
|
||||
|
||||
func init() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), usage)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
var configPath string
|
||||
users := make(map[string]bool)
|
||||
networks := make(map[string]bool)
|
||||
flag.StringVar(&configPath, "config", "", "path to configuration file")
|
||||
flag.Var((*stringSetFlag)(&users), "user", "")
|
||||
flag.Var((*stringSetFlag)(&networks), "network", "")
|
||||
flag.Parse()
|
||||
|
||||
zncPath := flag.Arg(0)
|
||||
if zncPath == "" {
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var cfg *config.Server
|
||||
if configPath != "" {
|
||||
var err error
|
||||
cfg, err = config.Load(configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load config file: %v", err)
|
||||
}
|
||||
} else {
|
||||
cfg = config.Defaults()
|
||||
}
|
||||
|
||||
db, err := soju.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
f, err := os.Open(zncPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open ZNC configuration file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
zp := zncParser{bufio.NewReader(f), 1}
|
||||
root, err := zp.sectionBody("", "")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
|
||||
}
|
||||
|
||||
l, err := db.ListUsers()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list users in DB: %v", err)
|
||||
}
|
||||
existingUsers := make(map[string]*soju.User, len(l))
|
||||
for i, u := range l {
|
||||
existingUsers[u.Username] = &l[i]
|
||||
}
|
||||
|
||||
usersCreated := 0
|
||||
usersImported := 0
|
||||
networksImported := 0
|
||||
channelsImported := 0
|
||||
root.ForEach("User", func(section *zncSection) {
|
||||
username := section.Name
|
||||
if len(users) > 0 && !users[username] {
|
||||
return
|
||||
}
|
||||
usersImported++
|
||||
|
||||
u, ok := existingUsers[username]
|
||||
if ok {
|
||||
log.Printf("user %q: updating existing user", username)
|
||||
} else {
|
||||
// "!!" is an invalid crypt format, thus disables password auth
|
||||
u = &soju.User{Username: username, Password: "!!"}
|
||||
usersCreated++
|
||||
log.Printf("user %q: creating new user", username)
|
||||
}
|
||||
|
||||
u.Admin = section.Values.Get("Admin") == "true"
|
||||
|
||||
if err := db.StoreUser(u); err != nil {
|
||||
log.Fatalf("failed to store user %q: %v", username, err)
|
||||
}
|
||||
|
||||
l, err := db.ListNetworks(username)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list networks for user %q: %v", username, err)
|
||||
}
|
||||
existingNetworks := make(map[string]*soju.Network, len(l))
|
||||
for i, n := range l {
|
||||
existingNetworks[n.GetName()] = &l[i]
|
||||
}
|
||||
|
||||
nick := section.Values.Get("Nick")
|
||||
realname := section.Values.Get("RealName")
|
||||
ident := section.Values.Get("Ident")
|
||||
|
||||
section.ForEach("Network", func(section *zncSection) {
|
||||
netName := section.Name
|
||||
if len(networks) > 0 && !networks[netName] {
|
||||
return
|
||||
}
|
||||
networksImported++
|
||||
|
||||
logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
|
||||
logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
|
||||
|
||||
netNick := section.Values.Get("Nick")
|
||||
if netNick == "" {
|
||||
netNick = nick
|
||||
}
|
||||
netRealname := section.Values.Get("RealName")
|
||||
if netRealname == "" {
|
||||
netRealname = realname
|
||||
}
|
||||
netIdent := section.Values.Get("Ident")
|
||||
if netIdent == "" {
|
||||
netIdent = ident
|
||||
}
|
||||
|
||||
for _, name := range section.Values["LoadModule"] {
|
||||
switch name {
|
||||
case "sasl":
|
||||
logger.Printf("warning: SASL credentials not imported")
|
||||
case "nickserv":
|
||||
logger.Printf("warning: NickServ credentials not imported")
|
||||
case "perform":
|
||||
logger.Printf("warning: \"perform\" plugin commands not imported")
|
||||
}
|
||||
}
|
||||
|
||||
u, pass, err := importNetworkServer(section.Values.Get("Server"))
|
||||
if err != nil {
|
||||
logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
|
||||
}
|
||||
|
||||
n, ok := existingNetworks[netName]
|
||||
if ok {
|
||||
logger.Printf("updating existing network")
|
||||
} else {
|
||||
n = &soju.Network{Name: netName}
|
||||
logger.Printf("creating new network")
|
||||
}
|
||||
|
||||
n.Addr = u.String()
|
||||
n.Nick = netNick
|
||||
n.Username = netIdent
|
||||
n.Realname = netRealname
|
||||
n.Pass = pass
|
||||
|
||||
if err := db.StoreNetwork(username, n); err != nil {
|
||||
logger.Fatalf("failed to store network: %v", err)
|
||||
}
|
||||
|
||||
l, err := db.ListChannels(n.ID)
|
||||
if err != nil {
|
||||
logger.Fatalf("failed to list channels: %v", err)
|
||||
}
|
||||
existingChannels := make(map[string]*soju.Channel, len(l))
|
||||
for i, ch := range l {
|
||||
existingChannels[ch.Name] = &l[i]
|
||||
}
|
||||
|
||||
section.ForEach("Chan", func(section *zncSection) {
|
||||
chName := section.Name
|
||||
|
||||
if section.Values.Get("Disabled") == "true" {
|
||||
logger.Printf("skipping import of disabled channel %q", chName)
|
||||
return
|
||||
}
|
||||
|
||||
channelsImported++
|
||||
|
||||
ch, ok := existingChannels[chName]
|
||||
if ok {
|
||||
logger.Printf("channel %q: updating existing channel", chName)
|
||||
} else {
|
||||
ch = &soju.Channel{Name: chName}
|
||||
logger.Printf("channel %q: creating new channel", chName)
|
||||
}
|
||||
|
||||
ch.Key = section.Values.Get("Key")
|
||||
ch.Detached = section.Values.Get("Detached") == "true"
|
||||
|
||||
if err := db.StoreChannel(n.ID, ch); err != nil {
|
||||
logger.Printf("channel %q: failed to store channel: %v", chName, err)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close database: %v", err)
|
||||
}
|
||||
|
||||
if usersCreated > 0 {
|
||||
log.Printf("warning: user passwords haven't been imported, please set them with `sojuctl change-password <username>`")
|
||||
}
|
||||
|
||||
log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
|
||||
}
|
||||
|
||||
func importNetworkServer(s string) (u *url.URL, pass string, err error) {
|
||||
parts := strings.Fields(s)
|
||||
if len(parts) < 2 {
|
||||
return nil, "", fmt.Errorf("expected space-separated host and port")
|
||||
}
|
||||
|
||||
scheme := "irc+insecure"
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
if strings.HasPrefix(port, "+") {
|
||||
port = port[1:]
|
||||
scheme = "ircs"
|
||||
}
|
||||
|
||||
if len(parts) > 2 {
|
||||
pass = parts[2]
|
||||
}
|
||||
|
||||
u = &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: host + ":" + port,
|
||||
}
|
||||
return u, pass, nil
|
||||
}
|
||||
|
||||
type zncSection struct {
|
||||
Type string
|
||||
Name string
|
||||
Values zncValues
|
||||
Children []zncSection
|
||||
}
|
||||
|
||||
func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
|
||||
for _, section := range s.Children {
|
||||
if section.Type == typ {
|
||||
f(§ion)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type zncValues map[string][]string
|
||||
|
||||
func (zv zncValues) Get(k string) string {
|
||||
if len(zv[k]) == 0 {
|
||||
return ""
|
||||
}
|
||||
return zv[k][0]
|
||||
}
|
||||
|
||||
type zncParser struct {
|
||||
br *bufio.Reader
|
||||
line int
|
||||
}
|
||||
|
||||
func (zp *zncParser) readByte() (byte, error) {
|
||||
b, err := zp.br.ReadByte()
|
||||
if b == '\n' {
|
||||
zp.line++
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
func (zp *zncParser) readRune() (rune, int, error) {
|
||||
r, n, err := zp.br.ReadRune()
|
||||
if r == '\n' {
|
||||
zp.line++
|
||||
}
|
||||
return r, n, err
|
||||
}
|
||||
|
||||
func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
|
||||
section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
|
||||
|
||||
Loop:
|
||||
for {
|
||||
if err := zp.skipSpace(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, err := zp.br.Peek(2)
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch b[0] {
|
||||
case '<':
|
||||
if b[1] == '/' {
|
||||
break Loop
|
||||
} else {
|
||||
childType, childName, err := zp.sectionHeader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
child, err := zp.sectionBody(childType, childName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if footerType, err := zp.sectionFooter(); err != nil {
|
||||
return nil, err
|
||||
} else if footerType != childType {
|
||||
return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
|
||||
}
|
||||
section.Children = append(section.Children, *child)
|
||||
}
|
||||
case '/':
|
||||
if b[1] == '/' {
|
||||
if err := zp.skipComment(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
k, v, err := zp.keyValuePair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
section.Values[k] = append(section.Values[k], v)
|
||||
}
|
||||
}
|
||||
|
||||
return section, nil
|
||||
}
|
||||
|
||||
func (zp *zncParser) skipSpace() error {
|
||||
for {
|
||||
r, _, err := zp.readRune()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !unicode.IsSpace(r) {
|
||||
zp.br.UnreadRune()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (zp *zncParser) skipComment() error {
|
||||
if err := zp.expectRune('/'); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := zp.expectRune('/'); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
b, err := zp.readByte()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if b == '\n' {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (zp *zncParser) sectionHeader() (string, string, error) {
|
||||
if err := zp.expectRune('<'); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
typ, err := zp.readWord(' ')
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
name, err := zp.readWord('>')
|
||||
return typ, name, err
|
||||
}
|
||||
|
||||
func (zp *zncParser) sectionFooter() (string, error) {
|
||||
if err := zp.expectRune('<'); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := zp.expectRune('/'); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return zp.readWord('>')
|
||||
}
|
||||
|
||||
func (zp *zncParser) keyValuePair() (string, string, error) {
|
||||
k, err := zp.readWord('=')
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
v, err := zp.readWord('\n')
|
||||
return strings.TrimSpace(k), strings.TrimSpace(v), err
|
||||
}
|
||||
|
||||
func (zp *zncParser) expectRune(expected rune) error {
|
||||
r, _, err := zp.readRune()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r != expected {
|
||||
return fmt.Errorf("expected %q, got %q", expected, r)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (zp *zncParser) readWord(delim byte) (string, error) {
|
||||
var sb strings.Builder
|
||||
for {
|
||||
b, err := zp.readByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if b == delim {
|
||||
return sb.String(), nil
|
||||
}
|
||||
if b == '\n' {
|
||||
return "", fmt.Errorf("expected %q before newline", delim)
|
||||
}
|
||||
|
||||
sb.WriteByte(b)
|
||||
}
|
||||
}
|
||||
|
||||
type stringSetFlag map[string]bool
|
||||
|
||||
func (v *stringSetFlag) String() string {
|
||||
return fmt.Sprint(map[string]bool(*v))
|
||||
}
|
||||
|
||||
func (v *stringSetFlag) Set(s string) error {
|
||||
(*v)[s] = true
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user