diff --git a/db.go b/db.go index 9ba4d7b..e5ab4ef 100644 --- a/db.go +++ b/db.go @@ -46,6 +46,8 @@ type Channel struct { Key string } +var ErrNoSuchChannel = fmt.Errorf("soju: no such channel") + type DB struct { lock sync.RWMutex db *sql.DB @@ -265,6 +267,23 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) { return channels, nil } +func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ch := &Channel{Name: name} + + var key *string + row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name) + if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows { + return nil, ErrNoSuchChannel + } else if err != nil { + return nil, err + } + ch.Key = fromStringPtr(key) + return ch, nil +} + func (db *DB) StoreChannel(networkID int64, ch *Channel) error { db.lock.Lock() defer db.lock.Unlock() diff --git a/downstream.go b/downstream.go index 96ee7b3..dc75f4c 100644 --- a/downstream.go +++ b/downstream.go @@ -896,12 +896,17 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: params, }) - err = dc.srv.db.StoreChannel(uc.network.ID, &Channel{ - Name: upstreamName, - Key: key, - }) - if err != nil { - dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err) + ch, err := dc.srv.db.GetChannel(uc.network.ID, upstreamName) + if err == ErrNoSuchChannel { + ch = &Channel{Name: upstreamName} + } else if err != nil { + return err + } + + ch.Key = key + + if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { + return err } } case "PART":