From 33dacc4fb0e4bc3b490eb25e20a63b9b1ef1b869 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 25 Mar 2020 11:52:24 +0100 Subject: [PATCH] Add support for channel keys --- db.go | 11 ++++++-- downstream.go | 77 +++++++++++++++++++++++++++++++++++++++------------ upstream.go | 6 +++- 3 files changed, 72 insertions(+), 22 deletions(-) diff --git a/db.go b/db.go index 49a4fd6..e511c89 100644 --- a/db.go +++ b/db.go @@ -35,6 +35,7 @@ type Network struct { type Channel struct { ID int64 Name string + Key string } type DB struct { @@ -193,7 +194,7 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) { db.lock.RLock() defer db.lock.RUnlock() - rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID) + rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID) if err != nil { return nil, err } @@ -202,9 +203,11 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) { var channels []Channel for rows.Next() { var ch Channel - if err := rows.Scan(&ch.ID, &ch.Name); err != nil { + var key *string + if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil { return nil, err } + ch.Key = fromStringPtr(key) channels = append(channels, ch) } if err := rows.Err(); err != nil { @@ -218,7 +221,9 @@ func (db *DB) StoreChannel(networkID int64, ch *Channel) error { db.lock.Lock() defer db.lock.Unlock() - _, err := db.db.Exec("INSERT OR REPLACE INTO Channel(network, name) VALUES (?, ?)", networkID, ch.Name) + key := toStringPtr(ch.Key) + _, err := db.db.Exec(`INSERT OR REPLACE INTO Channel(network, name, key) + VALUES (?, ?, ?)`, networkID, ch.Name, key) return err } diff --git a/downstream.go b/downstream.go index 8db13d5..9636456 100644 --- a/downstream.go +++ b/downstream.go @@ -832,13 +832,18 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { dc.forEachUpstream(func(uc *upstreamConn) { uc.SendMessage(msg) }) - case "JOIN", "PART": - var names string - if err := parseMessageParams(msg, &names); err != nil { + case "JOIN": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { return err } - for _, name := range strings.Split(names, ",") { + var keys []string + if len(msg.Params) > 1 { + keys = strings.Split(msg.Params[1], ",") + } + + for i, name := range strings.Split(namesStr, ",") { uc, upstreamName, err := dc.unmarshalEntity(name) if err != nil { return ircError{&irc.Message{ @@ -847,23 +852,59 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { }} } + var key string + if len(keys) > i { + key = keys[i] + } + + params := []string{upstreamName} + if key != "" { + params = append(params, key) + } uc.SendMessage(&irc.Message{ - Command: msg.Command, - Params: []string{upstreamName}, + Command: "JOIN", + Params: params, }) - switch msg.Command { - case "JOIN": - err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{ - Name: upstreamName, - }) - if err != nil { - dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err) - } - case "PART": - if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil { - dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err) - } + err = dc.srv.db.StoreChannel(uc.network.ID, &Channel{ + Name: upstreamName, + Key: key, + }) + if err != nil { + dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err) + } + } + case "PART": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + for _, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, err.Error()}, + }} + } + + params := []string{upstreamName} + if reason != "" { + params = append(params, reason) + } + uc.SendMessage(&irc.Message{ + Command: "PART", + Params: params, + }) + + if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil { + dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err) } } case "MODE": diff --git a/upstream.go b/upstream.go index 0dff3c3..58f0727 100644 --- a/upstream.go +++ b/upstream.go @@ -317,9 +317,13 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } for _, ch := range channels { + params := []string{ch.Name} + if ch.Key != "" { + params = append(params, ch.Key) + } uc.SendMessage(&irc.Message{ Command: "JOIN", - Params: []string{ch.Name}, + Params: params, }) } case irc.RPL_MYINFO: