From 276ce12e7c72d2924a85addfd24aad05d5424918 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Sat, 11 Apr 2020 17:00:40 +0200 Subject: [PATCH] Add network.channels, remove DB.GetChannel Store the list of configured channels in the network data structure. This removes the need for a database lookup and will be useful for detached channels. --- db.go | 19 ------------------- upstream.go | 8 +------- user.go | 34 ++++++++++++++++++++++++++-------- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/db.go b/db.go index c093064..08a1fa0 100644 --- a/db.go +++ b/db.go @@ -48,8 +48,6 @@ type Channel struct { Key string } -var ErrNoSuchChannel = fmt.Errorf("soju: no such channel") - const schema = ` CREATE TABLE User ( username VARCHAR(255) PRIMARY KEY, @@ -371,23 +369,6 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) { return channels, nil } -func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) { - db.lock.RLock() - defer db.lock.RUnlock() - - ch := &Channel{Name: name} - - var key *string - row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name) - if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows { - return nil, ErrNoSuchChannel - } else if err != nil { - return nil, err - } - ch.Key = fromStringPtr(key) - return ch, nil -} - func (db *DB) StoreChannel(networkID int64, ch *Channel) error { db.lock.Lock() defer db.lock.Unlock() diff --git a/upstream.go b/upstream.go index 599a9d7..98a05c2 100644 --- a/upstream.go +++ b/upstream.go @@ -421,13 +421,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { uc.registered = true uc.logger.Printf("connection registered") - channels, err := uc.srv.db.ListChannels(uc.network.ID) - if err != nil { - uc.logger.Printf("failed to list channels from database: %v", err) - break - } - - for _, ch := range channels { + for _, ch := range uc.network.channels { params := []string{ch.Name} if ch.Key != "" { params = append(params, ch.Key) diff --git a/user.go b/user.go index dfa0f15..d504796 100644 --- a/user.go +++ b/user.go @@ -56,16 +56,23 @@ type network struct { stopped chan struct{} conn *upstreamConn + channels map[string]*Channel history map[string]*networkHistory // indexed by entity offlineClients map[string]struct{} // indexed by client name lastError error } -func newNetwork(user *user, record *Network) *network { +func newNetwork(user *user, record *Network, channels []Channel) *network { + m := make(map[string]*Channel, len(channels)) + for _, ch := range channels { + m[ch.Name] = &ch + } + return &network{ Network: *record, user: user, stopped: make(chan struct{}), + channels: m, history: make(map[string]*networkHistory), offlineClients: make(map[string]struct{}), } @@ -140,16 +147,22 @@ func (net *network) Stop() { } func (net *network) createUpdateChannel(ch *Channel) error { - if dbCh, err := net.user.srv.db.GetChannel(net.ID, ch.Name); err == nil { - ch.ID = dbCh.ID - } else if err != ErrNoSuchChannel { + if current, ok := net.channels[ch.Name]; ok { + ch.ID = current.ID // update channel if it already exists + } + if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil { return err } - return net.user.srv.db.StoreChannel(net.ID, ch) + net.channels[ch.Name] = ch + return nil } func (net *network) deleteChannel(name string) error { - return net.user.srv.db.DeleteChannel(net.ID, name) + if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil { + return err + } + delete(net.channels, name) + return nil } type user struct { @@ -221,7 +234,12 @@ func (u *user) run() { } for _, record := range networks { - network := newNetwork(u, &record) + channels, err := u.srv.db.ListChannels(record.ID) + if err != nil { + u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err) + } + + network := newNetwork(u, &record, channels) u.networks = append(u.networks, network) go network.run() @@ -353,7 +371,7 @@ func (u *user) createNetwork(net *Network) (*network, error) { panic("tried creating an already-existing network") } - network := newNetwork(u, net) + network := newNetwork(u, net, nil) err := u.srv.db.StoreNetwork(u.Username, &network.Network) if err != nil { return nil, err