msgstore: add loadMessageOptions

A struct containing common parameters for all messageStore.Load*
functions returning messages.
This commit is contained in:
Simon Ser 2022-05-09 15:36:39 +02:00
parent 3a7dee8128
commit f508d36c38
4 changed files with 76 additions and 42 deletions

View File

@ -1664,7 +1664,12 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t
defer cancel() defer cancel()
targetCM := net.casemap(target) targetCM := net.casemap(target)
history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit) loadOptions := loadMessageOptions{
Network: &net.Network,
Entity: targetCM,
Limit: backlogLimit,
}
history, err := dc.user.msgStore.LoadLatestID(ctx, msgID, &loadOptions)
if err != nil { if err != nil {
dc.logger.Printf("failed to send backlog for %q: %v", target, err) dc.logger.Printf("failed to send backlog for %q: %v", target, err)
return return
@ -2826,17 +2831,24 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
eventPlayback := dc.caps.IsEnabled("draft/event-playback") eventPlayback := dc.caps.IsEnabled("draft/event-playback")
options := loadMessageOptions{
Network: &network.Network,
Entity: entity,
Limit: limit,
Events: eventPlayback,
}
var history []*irc.Message var history []*irc.Message
switch subcommand { switch subcommand {
case "BEFORE", "LATEST": case "BEFORE", "LATEST":
history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) history, err = store.LoadBeforeTime(ctx, bounds[0], time.Time{}, &options)
case "AFTER": case "AFTER":
history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) history, err = store.LoadAfterTime(ctx, bounds[0], time.Now(), &options)
case "BETWEEN": case "BETWEEN":
if bounds[0].Before(bounds[1]) { if bounds[0].Before(bounds[1]) {
history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) history, err = store.LoadAfterTime(ctx, bounds[0], bounds[1], &options)
} else { } else {
history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) history, err = store.LoadBeforeTime(ctx, bounds[0], bounds[1], &options)
} }
case "TARGETS": case "TARGETS":
// TODO: support TARGETS in multi-upstream mode // TODO: support TARGETS in multi-upstream mode

View File

@ -13,6 +13,13 @@ import (
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
) )
type loadMessageOptions struct {
Network *database.Network
Entity string
Limit int
Events bool
}
// messageStore is a per-user store for IRC messages. // messageStore is a per-user store for IRC messages.
type messageStore interface { type messageStore interface {
Close() error Close() error
@ -22,7 +29,7 @@ type messageStore interface {
LastMsgID(network *database.Network, entity string, t time.Time) (string, error) LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network, // LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest. // entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error)
Append(network *database.Network, entity string, msg *irc.Message) (id string, err error) Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
} }
@ -45,12 +52,12 @@ type chatHistoryMessageStore interface {
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is before start. // end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadBeforeTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The // LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is after start. // end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadAfterTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error)
} }
type searchOptions struct { type searchOptions struct {

View File

@ -401,8 +401,8 @@ func (ms *fsMessageStore) parseMessage(line string, network *database.Network, e
return msg, t, nil return msg, t, nil
} }
func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, options *loadMessageOptions, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref) path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -412,7 +412,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
} }
defer f.Close() defer f.Close()
historyRing := make([]*irc.Message, limit) historyRing := make([]*irc.Message, options.Limit)
cur := 0 cur := 0
sc := bufio.NewScanner(f) sc := bufio.NewScanner(f)
@ -425,7 +425,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
} }
for sc.Scan() { for sc.Scan() {
msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) msg, t, err := ms.parseMessage(sc.Text(), options.Network, options.Entity, ref, options.Events)
if err != nil { if err != nil {
return nil, err return nil, err
} else if msg == nil || !t.After(end) { } else if msg == nil || !t.After(end) {
@ -437,20 +437,20 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
continue continue
} }
historyRing[cur%limit] = msg historyRing[cur%options.Limit] = msg
cur++ cur++
} }
if sc.Err() != nil { if sc.Err() != nil {
return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err()) return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err())
} }
n := limit n := options.Limit
if cur < limit { if cur < options.Limit {
n = cur n = cur
} }
start := (cur - n + limit) % limit start := (cur - n + options.Limit) % options.Limit
if start+n <= limit { // ring doesnt wrap if start+n <= options.Limit { // ring doesnt wrap
return historyRing[start : start+n], nil return historyRing[start : start+n], nil
} else { // ring wraps } else { // ring wraps
history := make([]*irc.Message, n) history := make([]*irc.Message, n)
@ -460,8 +460,8 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
} }
} }
func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref) path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -473,8 +473,8 @@ func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity s
var history []*irc.Message var history []*irc.Message
sc := bufio.NewScanner(f) sc := bufio.NewScanner(f)
for sc.Scan() && len(history) < limit { for sc.Scan() && len(history) < options.Limit {
msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) msg, t, err := ms.parseMessage(sc.Text(), options.Network, options.Entity, ref, options.Events)
if err != nil { if err != nil {
return nil, err return nil, err
} else if msg == nil || !t.After(ref) { } else if msg == nil || !t.After(ref) {
@ -495,18 +495,20 @@ func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity s
return history, nil return history, nil
} }
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
if start.IsZero() { if start.IsZero() {
start = time.Now() start = time.Now()
} else { } else {
start = start.In(time.Local) start = start.In(time.Local)
} }
end = end.In(time.Local) end = end.In(time.Local)
messages := make([]*irc.Message, limit) messages := make([]*irc.Message, options.Limit)
remaining := limit remaining := options.Limit
tries := 0 tries := 0
for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) { for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) {
buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1, selector) parseOptions := *options
parseOptions.Limit = remaining
buf, err := ms.parseMessagesBefore(start, end, &parseOptions, -1, selector)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -528,11 +530,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.N
return messages[remaining:], nil return messages[remaining:], nil
} }
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) {
return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil) return ms.getBeforeTime(ctx, start, end, options, nil)
} }
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
start = start.In(time.Local) start = start.In(time.Local)
if end.IsZero() { if end.IsZero() {
end = time.Now() end = time.Now()
@ -540,10 +542,12 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Ne
end = end.In(time.Local) end = end.In(time.Local)
} }
var messages []*irc.Message var messages []*irc.Message
remaining := limit remaining := options.Limit
tries := 0 tries := 0
for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) { for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) {
buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining, selector) parseOptions := *options
parseOptions.Limit = remaining
buf, err := ms.parseMessagesAfter(start, end, &parseOptions, selector)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -564,11 +568,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Ne
return messages, nil return messages, nil
} }
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) {
return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil) return ms.getAfterTime(ctx, start, end, options, nil)
} }
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) {
var afterTime time.Time var afterTime time.Time
var afterOffset int64 var afterOffset int64
if id != "" { if id != "" {
@ -579,14 +583,14 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Ne
if err != nil { if err != nil {
return nil, err return nil, err
} }
if idNet != network.ID || idEntity != entity { if idNet != options.Network.ID || idEntity != options.Entity {
return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity") return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity")
} }
} }
history := make([]*irc.Message, limit) history := make([]*irc.Message, options.Limit)
t := time.Now() t := time.Now()
remaining := limit remaining := options.Limit
tries := 0 tries := 0
for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) { for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) {
var offset int64 = -1 var offset int64 = -1
@ -594,7 +598,9 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Ne
offset = afterOffset offset = afterOffset
} }
buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset, nil) parseOptions := *options
parseOptions.Limit = remaining
buf, err := ms.parseMessagesBefore(t, time.Time{}, &parseOptions, offset, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -706,10 +712,15 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network,
} }
return true return true
} }
loadOptions := loadMessageOptions{
Network: network,
Entity: opts.in,
Limit: opts.limit,
}
if !opts.start.IsZero() { if !opts.start.IsZero() {
return ms.getAfterTime(ctx, network, opts.in, opts.start, opts.end, opts.limit, false, selector) return ms.getAfterTime(ctx, opts.start, opts.end, &loadOptions, selector)
} else { } else {
return ms.getBeforeTime(ctx, network, opts.in, opts.end, opts.start, opts.limit, false, selector) return ms.getBeforeTime(ctx, opts.end, opts.start, &loadOptions, selector)
} }
} }

View File

@ -96,19 +96,23 @@ func (ms *memoryMessageStore) Append(network *database.Network, entity string, m
return formatMemoryMsgID(network.ID, entity, seq), nil return formatMemoryMsgID(network.ID, entity, seq), nil
} }
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) {
if options.Events {
return nil, fmt.Errorf("events are unsupported for memory message store")
}
_, _, seq, err := parseMemoryMsgID(id) _, _, seq, err := parseMemoryMsgID(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
k := ringBufferKey{networkID: network.ID, entity: entity} k := ringBufferKey{networkID: options.Network.ID, entity: options.Entity}
rb, ok := ms.buffers[k] rb, ok := ms.buffers[k]
if !ok { if !ok {
return nil, nil return nil, nil
} }
return rb.LoadLatestSeq(seq, limit) return rb.LoadLatestSeq(seq, options.Limit)
} }
type messageRingBuffer struct { type messageRingBuffer struct {