db198335aa
Instead of having one ring buffer per network, each network has one ring buffer per entity (channel or nick). This allows history to be more fair: if there's a lot of activity in a channel, it won't prune activity in other channels. We now track history sequence numbers per client and per network in networkHistory. The overall list of offline clients is still tracked in network.offlineClients. When all clients have received history, the ring buffer can be released. In the future, we should get rid of too-old offline clients to avoid having to maintain history for them forever. We should also add a per-user limit on the number of ring buffers.
396 lines
8.8 KiB
Go
396 lines
8.8 KiB
Go
package soju
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
|
|
"gopkg.in/irc.v3"
|
|
)
|
|
|
|
// event is a value delivered to a user's event loop through user.events. The
// loop dispatches on the concrete type of the value.
type event interface{}
|
|
|
|
type eventUpstreamMessage struct {
|
|
msg *irc.Message
|
|
uc *upstreamConn
|
|
}
|
|
|
|
type eventUpstreamConnectionError struct {
|
|
net *network
|
|
err error
|
|
}
|
|
|
|
type eventUpstreamConnected struct {
|
|
uc *upstreamConn
|
|
}
|
|
|
|
type eventUpstreamDisconnected struct {
|
|
uc *upstreamConn
|
|
}
|
|
|
|
type eventUpstreamError struct {
|
|
uc *upstreamConn
|
|
err error
|
|
}
|
|
|
|
type eventDownstreamMessage struct {
|
|
msg *irc.Message
|
|
dc *downstreamConn
|
|
}
|
|
|
|
type eventDownstreamConnected struct {
|
|
dc *downstreamConn
|
|
}
|
|
|
|
type eventDownstreamDisconnected struct {
|
|
dc *downstreamConn
|
|
}
|
|
|
|
type networkHistory struct {
|
|
offlineClients map[string]uint64 // indexed by client name
|
|
ring *Ring // can be nil if there are no offline clients
|
|
}
|
|
|
|
type network struct {
|
|
Network
|
|
user *user
|
|
stopped chan struct{}
|
|
|
|
conn *upstreamConn
|
|
history map[string]*networkHistory // indexed by entity
|
|
offlineClients map[string]struct{} // indexed by client name
|
|
lastError error
|
|
}
|
|
|
|
func newNetwork(user *user, record *Network) *network {
|
|
return &network{
|
|
Network: *record,
|
|
user: user,
|
|
stopped: make(chan struct{}),
|
|
history: make(map[string]*networkHistory),
|
|
offlineClients: make(map[string]struct{}),
|
|
}
|
|
}
|
|
|
|
func (net *network) forEachDownstream(f func(*downstreamConn)) {
|
|
net.user.forEachDownstream(func(dc *downstreamConn) {
|
|
if dc.network != nil && dc.network != net {
|
|
return
|
|
}
|
|
f(dc)
|
|
})
|
|
}
|
|
|
|
func (net *network) run() {
|
|
var lastTry time.Time
|
|
for {
|
|
select {
|
|
case <-net.stopped:
|
|
return
|
|
default:
|
|
// This space is intentionally left blank
|
|
}
|
|
|
|
if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
|
|
delay := retryConnectMinDelay - dur
|
|
net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
|
|
time.Sleep(delay)
|
|
}
|
|
lastTry = time.Now()
|
|
|
|
uc, err := connectToUpstream(net)
|
|
if err != nil {
|
|
net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
|
|
net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
|
|
continue
|
|
}
|
|
|
|
uc.register()
|
|
if err := uc.runUntilRegistered(); err != nil {
|
|
uc.logger.Printf("failed to register: %v", err)
|
|
net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", err)}
|
|
uc.Close()
|
|
continue
|
|
}
|
|
|
|
net.user.events <- eventUpstreamConnected{uc}
|
|
if err := uc.readMessages(net.user.events); err != nil {
|
|
uc.logger.Printf("failed to handle messages: %v", err)
|
|
net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
|
|
}
|
|
uc.Close()
|
|
net.user.events <- eventUpstreamDisconnected{uc}
|
|
}
|
|
}
|
|
|
|
func (net *network) upstream() *upstreamConn {
|
|
return net.conn
|
|
}
|
|
|
|
func (net *network) Stop() {
|
|
select {
|
|
case <-net.stopped:
|
|
return
|
|
default:
|
|
close(net.stopped)
|
|
}
|
|
|
|
if uc := net.upstream(); uc != nil {
|
|
uc.Close()
|
|
}
|
|
}
|
|
|
|
func (net *network) createUpdateChannel(ch *Channel) error {
|
|
if dbCh, err := net.user.srv.db.GetChannel(net.ID, ch.Name); err == nil {
|
|
ch.ID = dbCh.ID
|
|
} else if err != ErrNoSuchChannel {
|
|
return err
|
|
}
|
|
return net.user.srv.db.StoreChannel(net.ID, ch)
|
|
}
|
|
|
|
func (net *network) deleteChannel(name string) error {
|
|
return net.user.srv.db.DeleteChannel(net.ID, name)
|
|
}
|
|
|
|
type user struct {
|
|
User
|
|
srv *Server
|
|
|
|
events chan event
|
|
|
|
networks []*network
|
|
downstreamConns []*downstreamConn
|
|
|
|
// LIST commands in progress
|
|
pendingLISTs []pendingLIST
|
|
}
|
|
|
|
type pendingLIST struct {
|
|
downstreamID uint64
|
|
// list of per-upstream LIST commands not yet sent or completed
|
|
pendingCommands map[int64]*irc.Message
|
|
}
|
|
|
|
func newUser(srv *Server, record *User) *user {
|
|
return &user{
|
|
User: *record,
|
|
srv: srv,
|
|
events: make(chan event, 64),
|
|
}
|
|
}
|
|
|
|
func (u *user) forEachNetwork(f func(*network)) {
|
|
for _, network := range u.networks {
|
|
f(network)
|
|
}
|
|
}
|
|
|
|
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
|
for _, network := range u.networks {
|
|
uc := network.upstream()
|
|
if uc == nil {
|
|
continue
|
|
}
|
|
f(uc)
|
|
}
|
|
}
|
|
|
|
func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
|
|
for _, dc := range u.downstreamConns {
|
|
f(dc)
|
|
}
|
|
}
|
|
|
|
func (u *user) getNetwork(name string) *network {
|
|
for _, network := range u.networks {
|
|
if network.Addr == name {
|
|
return network
|
|
}
|
|
if network.Name != "" && network.Name == name {
|
|
return network
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (u *user) run() {
|
|
networks, err := u.srv.db.ListNetworks(u.Username)
|
|
if err != nil {
|
|
u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
|
|
return
|
|
}
|
|
|
|
for _, record := range networks {
|
|
network := newNetwork(u, &record)
|
|
u.networks = append(u.networks, network)
|
|
|
|
go network.run()
|
|
}
|
|
|
|
for e := range u.events {
|
|
switch e := e.(type) {
|
|
case eventUpstreamConnected:
|
|
uc := e.uc
|
|
|
|
uc.network.conn = uc
|
|
|
|
uc.updateAway()
|
|
|
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
|
sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
|
|
})
|
|
uc.network.lastError = nil
|
|
case eventUpstreamDisconnected:
|
|
uc := e.uc
|
|
|
|
uc.network.conn = nil
|
|
|
|
for _, ml := range uc.messageLoggers {
|
|
if err := ml.Close(); err != nil {
|
|
uc.logger.Printf("failed to close message logger: %v", err)
|
|
}
|
|
}
|
|
|
|
uc.endPendingLISTs(true)
|
|
|
|
if uc.network.lastError == nil {
|
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
|
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
|
|
})
|
|
}
|
|
case eventUpstreamConnectionError:
|
|
net := e.net
|
|
|
|
if net.lastError == nil || net.lastError.Error() != e.err.Error() {
|
|
net.forEachDownstream(func(dc *downstreamConn) {
|
|
sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
|
|
})
|
|
}
|
|
net.lastError = e.err
|
|
case eventUpstreamError:
|
|
uc := e.uc
|
|
|
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
|
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
|
|
})
|
|
uc.network.lastError = e.err
|
|
case eventUpstreamMessage:
|
|
msg, uc := e.msg, e.uc
|
|
if uc.isClosed() {
|
|
uc.logger.Printf("ignoring message on closed connection: %v", msg)
|
|
break
|
|
}
|
|
if err := uc.handleMessage(msg); err != nil {
|
|
uc.logger.Printf("failed to handle message %q: %v", msg, err)
|
|
}
|
|
case eventDownstreamConnected:
|
|
dc := e.dc
|
|
|
|
if err := dc.welcome(); err != nil {
|
|
dc.logger.Printf("failed to handle new registered connection: %v", err)
|
|
break
|
|
}
|
|
|
|
u.downstreamConns = append(u.downstreamConns, dc)
|
|
|
|
u.forEachUpstream(func(uc *upstreamConn) {
|
|
uc.updateAway()
|
|
})
|
|
case eventDownstreamDisconnected:
|
|
dc := e.dc
|
|
|
|
for i := range u.downstreamConns {
|
|
if u.downstreamConns[i] == dc {
|
|
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Save history if we're the last client with this name
|
|
skipHistory := make(map[*network]bool)
|
|
u.forEachDownstream(func(conn *downstreamConn) {
|
|
if dc.clientName == conn.clientName {
|
|
skipHistory[conn.network] = true
|
|
}
|
|
})
|
|
|
|
dc.forEachNetwork(func(net *network) {
|
|
if skipHistory[net] || skipHistory[nil] {
|
|
return
|
|
}
|
|
|
|
net.offlineClients[dc.clientName] = struct{}{}
|
|
for _, history := range net.history {
|
|
history.offlineClients[dc.clientName] = history.ring.Cur()
|
|
}
|
|
})
|
|
|
|
u.forEachUpstream(func(uc *upstreamConn) {
|
|
uc.updateAway()
|
|
})
|
|
case eventDownstreamMessage:
|
|
msg, dc := e.msg, e.dc
|
|
if dc.isClosed() {
|
|
dc.logger.Printf("ignoring message on closed connection: %v", msg)
|
|
break
|
|
}
|
|
err := dc.handleMessage(msg)
|
|
if ircErr, ok := err.(ircError); ok {
|
|
ircErr.Message.Prefix = dc.srv.prefix()
|
|
dc.SendMessage(ircErr.Message)
|
|
} else if err != nil {
|
|
dc.logger.Printf("failed to handle message %q: %v", msg, err)
|
|
dc.Close()
|
|
}
|
|
default:
|
|
u.srv.Logger.Printf("received unknown event type: %T", e)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *user) createNetwork(net *Network) (*network, error) {
|
|
if net.ID != 0 {
|
|
panic("tried creating an already-existing network")
|
|
}
|
|
|
|
network := newNetwork(u, net)
|
|
err := u.srv.db.StoreNetwork(u.Username, &network.Network)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
u.networks = append(u.networks, network)
|
|
|
|
go network.run()
|
|
return network, nil
|
|
}
|
|
|
|
func (u *user) deleteNetwork(id int64) error {
|
|
for i, net := range u.networks {
|
|
if net.ID != id {
|
|
continue
|
|
}
|
|
|
|
if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
|
|
return err
|
|
}
|
|
|
|
u.forEachDownstream(func(dc *downstreamConn) {
|
|
if dc.network != nil && dc.network == net {
|
|
dc.Close()
|
|
}
|
|
})
|
|
|
|
net.Stop()
|
|
u.networks = append(u.networks[:i], u.networks[i+1:]...)
|
|
return nil
|
|
}
|
|
|
|
panic("tried deleting a non-existing network")
|
|
}
|
|
|
|
func (u *user) updatePassword(hashed string) error {
|
|
u.User.Password = hashed
|
|
return u.srv.db.UpdatePassword(&u.User)
|
|
}
|