msgstore: add context to messageStore methods

This allows setting a hard timeout.
This commit is contained in:
Simon Ser 2021-11-03 18:18:04 +01:00
parent 2b4f0a870f
commit ce69f00e3f
5 changed files with 25 additions and 15 deletions

View File

@ -1328,9 +1328,12 @@ func (dc *downstreamConn) sendTargetBacklog(net *network, target, msgID string)
ch := net.channels.Value(target) ch := net.channels.Value(target)
ctx, cancel := context.WithTimeout(context.TODO(), messageStoreTimeout)
defer cancel()
limit := 4000 limit := 4000
targetCM := net.casemap(target) targetCM := net.casemap(target)
history, err := dc.user.msgStore.LoadLatestID(&net.Network, targetCM, msgID, limit) history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, limit)
if err != nil { if err != nil {
dc.logger.Printf("failed to send backlog for %q: %v", target, err) dc.logger.Printf("failed to send backlog for %q: %v", target, err)
return return
@ -2334,21 +2337,24 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
eventPlayback := dc.caps["draft/event-playback"] eventPlayback := dc.caps["draft/event-playback"]
ctx, cancel := context.WithTimeout(context.TODO(), messageStoreTimeout)
defer cancel()
var history []*irc.Message var history []*irc.Message
switch subcommand { switch subcommand {
case "BEFORE": case "BEFORE":
history, err = store.LoadBeforeTime(&network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback)
case "AFTER": case "AFTER":
history, err = store.LoadAfterTime(&network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback)
case "BETWEEN": case "BETWEEN":
if bounds[0].Before(bounds[1]) { if bounds[0].Before(bounds[1]) {
history, err = store.LoadAfterTime(&network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
} else { } else {
history, err = store.LoadBeforeTime(&network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
} }
case "TARGETS": case "TARGETS":
// TODO: support TARGETS in multi-upstream mode // TODO: support TARGETS in multi-upstream mode
targets, err := store.ListTargets(&network.Network, bounds[0], bounds[1], limit, eventPlayback) targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback)
if err != nil { if err != nil {
dc.logger.Printf("failed fetching targets for chathistory: %v", err) dc.logger.Printf("failed fetching targets for chathistory: %v", err)
return ircError{&irc.Message{ return ircError{&irc.Message{

View File

@ -2,6 +2,7 @@ package soju
import ( import (
"bytes" "bytes"
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"time" "time"
@ -19,7 +20,7 @@ type messageStore interface {
LastMsgID(network *Network, entity string, t time.Time) (string, error) LastMsgID(network *Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network, // LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest. // entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
Append(network *Network, entity string, msg *irc.Message) (id string, err error) Append(network *Network, entity string, msg *irc.Message) (id string, err error)
} }
@ -37,17 +38,17 @@ type chatHistoryMessageStore interface {
// It returns up to limit targets, starting from start and ending on end, // It returns up to limit targets, starting from start and ending on end,
// both excluded. end may be before or after start. // both excluded. end may be before or after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
ListTargets(network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
// LoadBeforeTime loads up to limit messages before start down to end. The // LoadBeforeTime loads up to limit messages before start down to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is before start. // end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The // LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is after start. // end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
} }
type msgIDType uint type msgIDType uint

View File

@ -2,6 +2,7 @@ package soju
import ( import (
"bufio" "bufio"
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -476,7 +477,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, re
return history, nil return history, nil
} }
func (ms *fsMessageStore) LoadBeforeTime(network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
history := make([]*irc.Message, limit) history := make([]*irc.Message, limit)
@ -501,7 +502,7 @@ func (ms *fsMessageStore) LoadBeforeTime(network *Network, entity string, start
return history[remaining:], nil return history[remaining:], nil
} }
func (ms *fsMessageStore) LoadAfterTime(network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
var history []*irc.Message var history []*irc.Message
@ -525,7 +526,7 @@ func (ms *fsMessageStore) LoadAfterTime(network *Network, entity string, start t
return history, nil return history, nil
} }
func (ms *fsMessageStore) LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
var afterTime time.Time var afterTime time.Time
var afterOffset int64 var afterOffset int64
if id != "" { if id != "" {
@ -569,7 +570,7 @@ func (ms *fsMessageStore) LoadLatestID(network *Network, entity, id string, limi
return history[remaining:], nil return history[remaining:], nil
} }
func (ms *fsMessageStore) ListTargets(network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))

View File

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@ -91,7 +92,7 @@ func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.M
return formatMemoryMsgID(network.ID, entity, seq), nil return formatMemoryMsgID(network.ID, entity, seq), nil
} }
func (ms *memoryMessageStore) LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
_, _, seq, err := parseMemoryMsgID(id) _, _, seq, err := parseMemoryMsgID(id)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -25,6 +25,7 @@ var connectTimeout = 15 * time.Second
var writeTimeout = 10 * time.Second var writeTimeout = 10 * time.Second
var upstreamMessageDelay = 2 * time.Second var upstreamMessageDelay = 2 * time.Second
var upstreamMessageBurst = 10 var upstreamMessageBurst = 10
var messageStoreTimeout = 10 * time.Second
type Logger interface { type Logger interface {
Print(v ...interface{}) Print(v ...interface{})