From 0c4e9b539c7fcf1c0031ee16482bf4b949300681 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 12 Mar 2020 18:33:03 +0100 Subject: [PATCH] Update DB on JOIN and PART --- db.go | 30 ++++++++++++++++++------------ downstream.go | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/db.go b/db.go index 3111441..763be7d 100644 --- a/db.go +++ b/db.go @@ -77,22 +77,12 @@ func (db *DB) CreateUser(user *User) error { db.lock.Lock() defer db.lock.Unlock() - tx, err := db.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - var password *string if user.Password != "" { password = &user.Password } - _, err = tx.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password) - if err != nil { - return err - } - - return tx.Commit() + _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password) + return err } func (db *DB) ListNetworks(username string) ([]Network, error) { @@ -151,3 +141,19 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) { return channels, nil } + +func (db *DB) StoreChannel(networkID int64, ch *Channel) error { + db.lock.Lock() + defer db.lock.Unlock() + + _, err := db.db.Exec("INSERT OR REPLACE INTO Channel(network, name) VALUES (?, ?)", networkID, ch.Name) + return err +} + +func (db *DB) DeleteChannel(networkID int64, name string) error { + db.lock.Lock() + defer db.lock.Unlock() + + _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name) + return err +} diff --git a/downstream.go b/downstream.go index e1020b5..8a70e33 100644 --- a/downstream.go +++ b/downstream.go @@ -114,7 +114,25 @@ func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { }) } +// upstream returns the upstream connection, if any. If there are zero or if +// there are multiple upstream connections, it returns nil. +func (dc *downstreamConn) upstream() *upstreamConn { + if dc.network == nil { + return nil + } + + var upstream *upstreamConn + dc.forEachUpstream(func(uc *upstreamConn) { + upstream = uc + }) + return upstream +} + func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) { + if uc := dc.upstream(); uc != nil { + return uc, name, nil + } + // TODO: extract network name from channel name if dc.upstream == nil var channel *upstreamChannel var err error @@ -461,7 +479,20 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Command: msg.Command, Params: []string{upstreamName}, }) - // TODO: add/remove channel from upstream config + + switch msg.Command { + case "JOIN": + err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{ + Name: upstreamName, + }) + if err != nil { + dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err) + } + case "PART": + if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil { + dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err) + } + } case "MODE": if msg.Prefix == nil { return fmt.Errorf("missing prefix")