msgstore: add loadMessageOptions

A struct containing common parameters for all messageStore.Load*
functions returning messages.
This commit is contained in:
Simon Ser 2022-05-09 15:36:39 +02:00
parent 3a7dee8128
commit f508d36c38
4 changed files with 76 additions and 42 deletions

View File

@ -1664,7 +1664,12 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t
defer cancel()
targetCM := net.casemap(target)
history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit)
loadOptions := loadMessageOptions{
Network: &net.Network,
Entity: targetCM,
Limit: backlogLimit,
}
history, err := dc.user.msgStore.LoadLatestID(ctx, msgID, &loadOptions)
if err != nil {
dc.logger.Printf("failed to send backlog for %q: %v", target, err)
return
@ -2826,17 +2831,24 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
eventPlayback := dc.caps.IsEnabled("draft/event-playback")
options := loadMessageOptions{
Network: &network.Network,
Entity: entity,
Limit: limit,
Events: eventPlayback,
}
var history []*irc.Message
switch subcommand {
case "BEFORE", "LATEST":
history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback)
history, err = store.LoadBeforeTime(ctx, bounds[0], time.Time{}, &options)
case "AFTER":
history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback)
history, err = store.LoadAfterTime(ctx, bounds[0], time.Now(), &options)
case "BETWEEN":
if bounds[0].Before(bounds[1]) {
history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
history, err = store.LoadAfterTime(ctx, bounds[0], bounds[1], &options)
} else {
history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
history, err = store.LoadBeforeTime(ctx, bounds[0], bounds[1], &options)
}
case "TARGETS":
// TODO: support TARGETS in multi-upstream mode

View File

@ -13,6 +13,13 @@ import (
"git.sr.ht/~emersion/soju/database"
)
type loadMessageOptions struct {
Network *database.Network
Entity string
Limit int
Events bool
}
// messageStore is a per-user store for IRC messages.
type messageStore interface {
Close() error
@ -22,7 +29,7 @@ type messageStore interface {
LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error)
LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error)
Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
}
@ -45,12 +52,12 @@ type chatHistoryMessageStore interface {
// returned messages must be between and excluding the provided bounds.
// end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
LoadBeforeTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds.
// end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
LoadAfterTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error)
}
type searchOptions struct {

View File

@ -401,8 +401,8 @@ func (ms *fsMessageStore) parseMessage(line string, network *database.Network, e
return msg, t, nil
}
func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref)
func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, options *loadMessageOptions, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
@ -412,7 +412,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
}
defer f.Close()
historyRing := make([]*irc.Message, limit)
historyRing := make([]*irc.Message, options.Limit)
cur := 0
sc := bufio.NewScanner(f)
@ -425,7 +425,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
}
for sc.Scan() {
msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
msg, t, err := ms.parseMessage(sc.Text(), options.Network, options.Entity, ref, options.Events)
if err != nil {
return nil, err
} else if msg == nil || !t.After(end) {
@ -437,20 +437,20 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
continue
}
historyRing[cur%limit] = msg
historyRing[cur%options.Limit] = msg
cur++
}
if sc.Err() != nil {
return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err())
}
n := limit
if cur < limit {
n := options.Limit
if cur < options.Limit {
n = cur
}
start := (cur - n + limit) % limit
start := (cur - n + options.Limit) % options.Limit
if start+n <= limit { // ring doesnt wrap
if start+n <= options.Limit { // ring doesnt wrap
return historyRing[start : start+n], nil
} else { // ring wraps
history := make([]*irc.Message, n)
@ -460,8 +460,8 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity
}
}
func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref)
func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
@ -473,8 +473,8 @@ func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity s
var history []*irc.Message
sc := bufio.NewScanner(f)
for sc.Scan() && len(history) < limit {
msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
for sc.Scan() && len(history) < options.Limit {
msg, t, err := ms.parseMessage(sc.Text(), options.Network, options.Entity, ref, options.Events)
if err != nil {
return nil, err
} else if msg == nil || !t.After(ref) {
@ -495,18 +495,20 @@ func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity s
return history, nil
}
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
if start.IsZero() {
start = time.Now()
} else {
start = start.In(time.Local)
}
end = end.In(time.Local)
messages := make([]*irc.Message, limit)
remaining := limit
messages := make([]*irc.Message, options.Limit)
remaining := options.Limit
tries := 0
for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) {
buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1, selector)
parseOptions := *options
parseOptions.Limit = remaining
buf, err := ms.parseMessagesBefore(start, end, &parseOptions, -1, selector)
if err != nil {
return nil, err
}
@ -528,11 +530,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.N
return messages[remaining:], nil
}
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil)
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) {
return ms.getBeforeTime(ctx, start, end, options, nil)
}
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
start = start.In(time.Local)
if end.IsZero() {
end = time.Now()
@ -540,10 +542,12 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Ne
end = end.In(time.Local)
}
var messages []*irc.Message
remaining := limit
remaining := options.Limit
tries := 0
for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) {
buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining, selector)
parseOptions := *options
parseOptions.Limit = remaining
buf, err := ms.parseMessagesAfter(start, end, &parseOptions, selector)
if err != nil {
return nil, err
}
@ -564,11 +568,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Ne
return messages, nil
}
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil)
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) {
return ms.getAfterTime(ctx, start, end, options, nil)
}
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) {
var afterTime time.Time
var afterOffset int64
if id != "" {
@ -579,14 +583,14 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Ne
if err != nil {
return nil, err
}
if idNet != network.ID || idEntity != entity {
if idNet != options.Network.ID || idEntity != options.Entity {
return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity")
}
}
history := make([]*irc.Message, limit)
history := make([]*irc.Message, options.Limit)
t := time.Now()
remaining := limit
remaining := options.Limit
tries := 0
for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) {
var offset int64 = -1
@ -594,7 +598,9 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Ne
offset = afterOffset
}
buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset, nil)
parseOptions := *options
parseOptions.Limit = remaining
buf, err := ms.parseMessagesBefore(t, time.Time{}, &parseOptions, offset, nil)
if err != nil {
return nil, err
}
@ -706,10 +712,15 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network,
}
return true
}
loadOptions := loadMessageOptions{
Network: network,
Entity: opts.in,
Limit: opts.limit,
}
if !opts.start.IsZero() {
return ms.getAfterTime(ctx, network, opts.in, opts.start, opts.end, opts.limit, false, selector)
return ms.getAfterTime(ctx, opts.start, opts.end, &loadOptions, selector)
} else {
return ms.getBeforeTime(ctx, network, opts.in, opts.end, opts.start, opts.limit, false, selector)
return ms.getBeforeTime(ctx, opts.end, opts.start, &loadOptions, selector)
}
}

View File

@ -96,19 +96,23 @@ func (ms *memoryMessageStore) Append(network *database.Network, entity string, m
return formatMemoryMsgID(network.ID, entity, seq), nil
}
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) {
if options.Events {
return nil, fmt.Errorf("events are unsupported for memory message store")
}
_, _, seq, err := parseMemoryMsgID(id)
if err != nil {
return nil, err
}
k := ringBufferKey{networkID: network.ID, entity: entity}
k := ringBufferKey{networkID: options.Network.ID, entity: options.Entity}
rb, ok := ms.buffers[k]
if !ok {
return nil, nil
}
return rb.LoadLatestSeq(seq, limit)
return rb.LoadLatestSeq(seq, options.Limit)
}
type messageRingBuffer struct {