diff --git a/contrib/znc-import.go b/contrib/znc-import.go new file mode 100644 index 0000000..53ba5b3 --- /dev/null +++ b/contrib/znc-import.go @@ -0,0 +1,469 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "io" + "log" + "net/url" + "os" + "strings" + "unicode" + + "git.sr.ht/~emersion/soju" + "git.sr.ht/~emersion/soju/config" +) + +const usage = `usage: znc-import [options...] + +Imports configuration from a ZNC file. Users and networks are merged if they +already exist in the soju database. ZNC settings overwrite existing soju +settings. + +Options: + + -help Show this help message + -config Path to soju config file + -user Limit import to username (may be specified multiple times) + -network Limit import to network (may be specified multiple times) +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + users := make(map[string]bool) + networks := make(map[string]bool) + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Var((*stringSetFlag)(&users), "user", "") + flag.Var((*stringSetFlag)(&networks), "network", "") + flag.Parse() + + zncPath := flag.Arg(0) + if zncPath == "" { + flag.Usage() + os.Exit(1) + } + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + db, err := soju.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + f, err := os.Open(zncPath) + if err != nil { + log.Fatalf("failed to open ZNC configuration file: %v", err) + } + defer f.Close() + + zp := zncParser{bufio.NewReader(f), 1} + root, err := zp.sectionBody("", "") + if err != nil { + log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err) + } + + l, err := db.ListUsers() + if err != nil { + log.Fatalf("failed to list users in DB: %v", err) + } + existingUsers := make(map[string]*soju.User, len(l)) + for i, u := range l { + existingUsers[u.Username] = &l[i] + } + + usersCreated := 0 + usersImported := 0 + networksImported := 0 + channelsImported := 0 + root.ForEach("User", func(section *zncSection) { + username := section.Name + if len(users) > 0 && !users[username] { + return + } + usersImported++ + + u, ok := existingUsers[username] + if ok { + log.Printf("user %q: updating existing user", username) + } else { + // "!!" is an invalid crypt format, thus disables password auth + u = &soju.User{Username: username, Password: "!!"} + usersCreated++ + log.Printf("user %q: creating new user", username) + } + + u.Admin = section.Values.Get("Admin") == "true" + + if err := db.StoreUser(u); err != nil { + log.Fatalf("failed to store user %q: %v", username, err) + } + + l, err := db.ListNetworks(username) + if err != nil { + log.Fatalf("failed to list networks for user %q: %v", username, err) + } + existingNetworks := make(map[string]*soju.Network, len(l)) + for i, n := range l { + existingNetworks[n.GetName()] = &l[i] + } + + nick := section.Values.Get("Nick") + realname := section.Values.Get("RealName") + ident := section.Values.Get("Ident") + + section.ForEach("Network", func(section *zncSection) { + netName := section.Name + if len(networks) > 0 && !networks[netName] { + return + } + networksImported++ + + logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName) + logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix) + + netNick := section.Values.Get("Nick") + if netNick == "" { + netNick = nick + } + netRealname := section.Values.Get("RealName") + if netRealname == "" { + netRealname = realname + } + netIdent := section.Values.Get("Ident") + if netIdent == "" { + netIdent = ident + } + + for _, name := range section.Values["LoadModule"] { + switch name { + case "sasl": + logger.Printf("warning: SASL credentials not imported") + case "nickserv": + logger.Printf("warning: NickServ credentials not imported") + case "perform": + logger.Printf("warning: \"perform\" plugin commands not imported") + } + } + + u, pass, err := importNetworkServer(section.Values.Get("Server")) + if err != nil { + logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err) + } + + n, ok := existingNetworks[netName] + if ok { + logger.Printf("updating existing network") + } else { + n = &soju.Network{Name: netName} + logger.Printf("creating new network") + } + + n.Addr = u.String() + n.Nick = netNick + n.Username = netIdent + n.Realname = netRealname + n.Pass = pass + + if err := db.StoreNetwork(username, n); err != nil { + logger.Fatalf("failed to store network: %v", err) + } + + l, err := db.ListChannels(n.ID) + if err != nil { + logger.Fatalf("failed to list channels: %v", err) + } + existingChannels := make(map[string]*soju.Channel, len(l)) + for i, ch := range l { + existingChannels[ch.Name] = &l[i] + } + + section.ForEach("Chan", func(section *zncSection) { + chName := section.Name + + if section.Values.Get("Disabled") == "true" { + logger.Printf("skipping import of disabled channel %q", chName) + return + } + + channelsImported++ + + ch, ok := existingChannels[chName] + if ok { + logger.Printf("channel %q: updating existing channel", chName) + } else { + ch = &soju.Channel{Name: chName} + logger.Printf("channel %q: creating new channel", chName) + } + + ch.Key = section.Values.Get("Key") + ch.Detached = section.Values.Get("Detached") == "true" + + if err := db.StoreChannel(n.ID, ch); err != nil { + logger.Printf("channel %q: failed to store channel: %v", chName, err) + } + }) + }) + }) + + if err := db.Close(); err != nil { + log.Printf("failed to close database: %v", err) + } + + if usersCreated > 0 { + log.Printf("warning: user passwords haven't been imported, please set them with `sojuctl change-password `") + } + + log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported) +} + +func importNetworkServer(s string) (u *url.URL, pass string, err error) { + parts := strings.Fields(s) + if len(parts) < 2 { + return nil, "", fmt.Errorf("expected space-separated host and port") + } + + scheme := "irc+insecure" + host := parts[0] + port := parts[1] + if strings.HasPrefix(port, "+") { + port = port[1:] + scheme = "ircs" + } + + if len(parts) > 2 { + pass = parts[2] + } + + u = &url.URL{ + Scheme: scheme, + Host: host + ":" + port, + } + return u, pass, nil +} + +type zncSection struct { + Type string + Name string + Values zncValues + Children []zncSection +} + +func (s *zncSection) ForEach(typ string, f func(*zncSection)) { + for _, section := range s.Children { + if section.Type == typ { + f(§ion) + } + } +} + +type zncValues map[string][]string + +func (zv zncValues) Get(k string) string { + if len(zv[k]) == 0 { + return "" + } + return zv[k][0] +} + +type zncParser struct { + br *bufio.Reader + line int +} + +func (zp *zncParser) readByte() (byte, error) { + b, err := zp.br.ReadByte() + if b == '\n' { + zp.line++ + } + return b, err +} + +func (zp *zncParser) readRune() (rune, int, error) { + r, n, err := zp.br.ReadRune() + if r == '\n' { + zp.line++ + } + return r, n, err +} + +func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) { + section := &zncSection{Type: typ, Name: name, Values: make(zncValues)} + +Loop: + for { + if err := zp.skipSpace(); err != nil { + return nil, err + } + + b, err := zp.br.Peek(2) + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + switch b[0] { + case '<': + if b[1] == '/' { + break Loop + } else { + childType, childName, err := zp.sectionHeader() + if err != nil { + return nil, err + } + child, err := zp.sectionBody(childType, childName) + if err != nil { + return nil, err + } + if footerType, err := zp.sectionFooter(); err != nil { + return nil, err + } else if footerType != childType { + return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType) + } + section.Children = append(section.Children, *child) + } + case '/': + if b[1] == '/' { + if err := zp.skipComment(); err != nil { + return nil, err + } + break + } + fallthrough + default: + k, v, err := zp.keyValuePair() + if err != nil { + return nil, err + } + section.Values[k] = append(section.Values[k], v) + } + } + + return section, nil +} + +func (zp *zncParser) skipSpace() error { + for { + r, _, err := zp.readRune() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if !unicode.IsSpace(r) { + zp.br.UnreadRune() + return nil + } + } +} + +func (zp *zncParser) skipComment() error { + if err := zp.expectRune('/'); err != nil { + return err + } + if err := zp.expectRune('/'); err != nil { + return err + } + + for { + b, err := zp.readByte() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if b == '\n' { + return nil + } + } +} + +func (zp *zncParser) sectionHeader() (string, string, error) { + if err := zp.expectRune('<'); err != nil { + return "", "", err + } + typ, err := zp.readWord(' ') + if err != nil { + return "", "", err + } + name, err := zp.readWord('>') + return typ, name, err +} + +func (zp *zncParser) sectionFooter() (string, error) { + if err := zp.expectRune('<'); err != nil { + return "", err + } + if err := zp.expectRune('/'); err != nil { + return "", err + } + return zp.readWord('>') +} + +func (zp *zncParser) keyValuePair() (string, string, error) { + k, err := zp.readWord('=') + if err != nil { + return "", "", err + } + v, err := zp.readWord('\n') + return strings.TrimSpace(k), strings.TrimSpace(v), err +} + +func (zp *zncParser) expectRune(expected rune) error { + r, _, err := zp.readRune() + if err != nil { + return err + } else if r != expected { + return fmt.Errorf("expected %q, got %q", expected, r) + } + return nil +} + +func (zp *zncParser) readWord(delim byte) (string, error) { + var sb strings.Builder + for { + b, err := zp.readByte() + if err != nil { + return "", err + } + + if b == delim { + return sb.String(), nil + } + if b == '\n' { + return "", fmt.Errorf("expected %q before newline", delim) + } + + sb.WriteByte(b) + } +} + +type stringSetFlag map[string]bool + +func (v *stringSetFlag) String() string { + return fmt.Sprint(map[string]bool(*v)) +} + +func (v *stringSetFlag) Set(s string) error { + (*v)[s] = true + return nil +}