Add support for channel keys

This commit is contained in:
Simon Ser 2020-03-25 11:52:24 +01:00
parent 146906ef6b
commit 33dacc4fb0
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 72 additions and 22 deletions

11
db.go
View File

@ -35,6 +35,7 @@ type Network struct {
type Channel struct { type Channel struct {
ID int64 ID int64
Name string Name string
Key string
} }
type DB struct { type DB struct {
@ -193,7 +194,7 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID) rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -202,9 +203,11 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
var channels []Channel var channels []Channel
for rows.Next() { for rows.Next() {
var ch Channel var ch Channel
if err := rows.Scan(&ch.ID, &ch.Name); err != nil { var key *string
if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
return nil, err return nil, err
} }
ch.Key = fromStringPtr(key)
channels = append(channels, ch) channels = append(channels, ch)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@ -218,7 +221,9 @@ func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
_, err := db.db.Exec("INSERT OR REPLACE INTO Channel(network, name) VALUES (?, ?)", networkID, ch.Name) key := toStringPtr(ch.Key)
_, err := db.db.Exec(`INSERT OR REPLACE INTO Channel(network, name, key)
VALUES (?, ?, ?)`, networkID, ch.Name, key)
return err return err
} }

View File

@ -832,13 +832,18 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessage(msg) uc.SendMessage(msg)
}) })
case "JOIN", "PART": case "JOIN":
var names string var namesStr string
if err := parseMessageParams(msg, &names); err != nil { if err := parseMessageParams(msg, &namesStr); err != nil {
return err return err
} }
for _, name := range strings.Split(names, ",") { var keys []string
if len(msg.Params) > 1 {
keys = strings.Split(msg.Params[1], ",")
}
for i, name := range strings.Split(namesStr, ",") {
uc, upstreamName, err := dc.unmarshalEntity(name) uc, upstreamName, err := dc.unmarshalEntity(name)
if err != nil { if err != nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
@ -847,23 +852,59 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}} }}
} }
var key string
if len(keys) > i {
key = keys[i]
}
params := []string{upstreamName}
if key != "" {
params = append(params, key)
}
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: msg.Command, Command: "JOIN",
Params: []string{upstreamName}, Params: params,
}) })
switch msg.Command { err = dc.srv.db.StoreChannel(uc.network.ID, &Channel{
case "JOIN": Name: upstreamName,
err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{ Key: key,
Name: upstreamName, })
}) if err != nil {
if err != nil { dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err) }
} }
case "PART": case "PART":
if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil { var namesStr string
dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err) if err := parseMessageParams(msg, &namesStr); err != nil {
} return err
}
var reason string
if len(msg.Params) > 1 {
reason = msg.Params[1]
}
for _, name := range strings.Split(namesStr, ",") {
uc, upstreamName, err := dc.unmarshalEntity(name)
if err != nil {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{name, err.Error()},
}}
}
params := []string{upstreamName}
if reason != "" {
params = append(params, reason)
}
uc.SendMessage(&irc.Message{
Command: "PART",
Params: params,
})
if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
} }
} }
case "MODE": case "MODE":

View File

@ -317,9 +317,13 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
for _, ch := range channels { for _, ch := range channels {
params := []string{ch.Name}
if ch.Key != "" {
params = append(params, ch.Key)
}
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "JOIN", Command: "JOIN",
Params: []string{ch.Name}, Params: params,
}) })
} }
case irc.RPL_MYINFO: case irc.RPL_MYINFO: