Add msgstore package

This commit is contained in:
Simon Ser 2022-05-09 16:25:57 +02:00
parent b92afa7cca
commit 620a8789b0
5 changed files with 75 additions and 60 deletions

View File

@ -18,6 +18,7 @@ import (
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/msgstore"
"git.sr.ht/~emersion/soju/xirc"
)
@ -650,7 +651,7 @@ func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
// ackMsgID acknowledges that a message has been received.
func (dc *downstreamConn) ackMsgID(id string) {
netID, entity, err := parseMsgID(id, nil)
netID, entity, err := msgstore.ParseMsgID(id, nil)
if err != nil {
dc.logger.Printf("failed to ACK message ID %q: %v", id, err)
return
@ -1137,7 +1138,7 @@ func (dc *downstreamConn) updateSupportedCaps() {
dc.unsetSupportedCap("draft/account-registration")
}
if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil {
if _, ok := dc.user.msgStore.(msgstore.ChatHistoryStore); ok && dc.network != nil {
dc.setSupportedCap("draft/event-playback", "")
} else {
dc.unsetSupportedCap("draft/event-playback")
@ -1665,7 +1666,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t
defer cancel()
targetCM := net.casemap(target)
loadOptions := loadMessageOptions{
loadOptions := msgstore.LoadMessageOptions{
Network: &net.Network,
Entity: targetCM,
Limit: backlogLimit,
@ -2786,7 +2787,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil
}
store, ok := dc.user.msgStore.(chatHistoryMessageStore)
store, ok := dc.user.msgStore.(msgstore.ChatHistoryStore)
if !ok {
return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND,
@ -2832,7 +2833,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
eventPlayback := dc.caps.IsEnabled("draft/event-playback")
options := loadMessageOptions{
options := msgstore.LoadMessageOptions{
Network: &network.Network,
Entity: entity,
Limit: limit,
@ -2980,7 +2981,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
})
case "SEARCH":
store, ok := dc.user.msgStore.(searchMessageStore)
store, ok := dc.user.msgStore.(msgstore.SearchStore)
if !ok {
return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND,
@ -2995,7 +2996,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
var uc *upstreamConn
const searchMaxLimit = 100
opts := searchMessageOptions{
opts := msgstore.SearchMessageOptions{
Limit: searchMaxLimit,
}
for name, v := range attrs {

View File

@ -1,4 +1,4 @@
package soju
package msgstore
import (
"bufio"
@ -57,7 +57,7 @@ func (fsMsgID) msgIDType() msgIDType {
func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) {
var id fsMsgID
netID, entity, err = parseMsgID(s, &id)
netID, entity, err = ParseMsgID(s, &id)
if err != nil {
return 0, "", time.Time{}, 0, err
}
@ -89,11 +89,14 @@ type fsMessageStore struct {
files map[string]*fsMessageStoreFile // indexed by entity
}
var _ messageStore = (*fsMessageStore)(nil)
var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
var _ searchMessageStore = (*fsMessageStore)(nil)
var (
_ Store = (*fsMessageStore)(nil)
_ ChatHistoryStore = (*fsMessageStore)(nil)
_ SearchStore = (*fsMessageStore)(nil)
_ RenameNetworkStore = (*fsMessageStore)(nil)
)
func newFSMessageStore(root string, user *database.User) *fsMessageStore {
func NewFSStore(root string, user *database.User) *fsMessageStore {
return &fsMessageStore{
root: filepath.Join(root, escapeFilename(user.Username)),
user: user,
@ -402,7 +405,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *database.Network, e
return msg, t, nil
}
func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, options *loadMessageOptions, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, options *LoadMessageOptions, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path)
if err != nil {
@ -461,7 +464,7 @@ func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, opti
}
}
func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, options *LoadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path)
if err != nil {
@ -496,7 +499,7 @@ func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, optio
return history, nil
}
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
if start.IsZero() {
start = time.Now()
} else {
@ -531,11 +534,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, en
return messages[remaining:], nil
}
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) {
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
return ms.getBeforeTime(ctx, start, end, options, nil)
}
func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
start = start.In(time.Local)
if end.IsZero() {
end = time.Now()
@ -569,11 +572,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end
return messages, nil
}
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) {
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
return ms.getAfterTime(ctx, start, end, options, nil)
}
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) {
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error) {
var afterTime time.Time
var afterOffset int64
if id != "" {
@ -623,7 +626,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *
return history[remaining:], nil
}
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {
start = start.In(time.Local)
end = end.In(time.Local)
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
@ -642,7 +645,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
return nil, err
}
var targets []chatHistoryTarget
var targets []ChatHistoryTarget
for _, target := range targetNames {
// target is already escaped here
targetPath := filepath.Join(rootPath, target)
@ -673,7 +676,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
continue
}
targets = append(targets, chatHistoryTarget{
targets = append(targets, ChatHistoryTarget{
Name: target,
LatestMessage: t,
})
@ -702,7 +705,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
return targets, nil
}
func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts *searchMessageOptions) ([]*irc.Message, error) {
func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts *SearchMessageOptions) ([]*irc.Message, error) {
text := strings.ToLower(opts.Text)
selector := func(m *irc.Message) bool {
if opts.From != "" && m.User != opts.From {
@ -713,7 +716,7 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network,
}
return true
}
loadOptions := loadMessageOptions{
loadOptions := LoadMessageOptions{
Network: network,
Entity: opts.In,
Limit: opts.Limit,

View File

@ -1,4 +1,4 @@
package soju
package msgstore
import (
"context"
@ -23,7 +23,7 @@ func (memoryMsgID) msgIDType() msgIDType {
func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
var id memoryMsgID
netID, entity, err = parseMsgID(s, &id)
netID, entity, err = ParseMsgID(s, &id)
if err != nil {
return 0, "", 0, err
}
@ -40,13 +40,18 @@ type ringBufferKey struct {
entity string
}
func IsMemoryStore(store Store) bool {
_, ok := store.(*memoryMessageStore)
return ok
}
type memoryMessageStore struct {
buffers map[ringBufferKey]*messageRingBuffer
}
var _ messageStore = (*memoryMessageStore)(nil)
var _ Store = (*memoryMessageStore)(nil)
func newMemoryMessageStore() *memoryMessageStore {
func NewMemoryStore() *memoryMessageStore {
return &memoryMessageStore{
buffers: make(map[ringBufferKey]*messageRingBuffer),
}
@ -96,7 +101,7 @@ func (ms *memoryMessageStore) Append(network *database.Network, entity string, m
return formatMemoryMsgID(network.ID, entity, seq), nil
}
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) {
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error) {
if options.Events {
return nil, fmt.Errorf("events are unsupported for memory message store")
}

View File

@ -1,4 +1,4 @@
package soju
package msgstore
import (
"bytes"
@ -13,15 +13,15 @@ import (
"git.sr.ht/~emersion/soju/database"
)
type loadMessageOptions struct {
type LoadMessageOptions struct {
Network *database.Network
Entity string
Limit int
Events bool
}
// messageStore is a per-user store for IRC messages.
type messageStore interface {
// Store is a per-user store for IRC messages.
type Store interface {
Close() error
// LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be
@ -29,38 +29,37 @@ type messageStore interface {
LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error)
LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error)
Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
}
type chatHistoryTarget struct {
type ChatHistoryTarget struct {
Name string
LatestMessage time.Time
}
// chatHistoryMessageStore is a message store that supports chat history
// operations.
type chatHistoryMessageStore interface {
messageStore
// ChatHistoryStore is a message store that supports chat history operations.
type ChatHistoryStore interface {
Store
// ListTargets lists channels and nicknames by time of the latest message.
// It returns up to limit targets, starting from start and ending on end,
// both excluded. end may be before or after start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error)
// LoadBeforeTime loads up to limit messages before start down to end. The
// returned messages must be between and excluding the provided bounds.
// end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error)
LoadBeforeTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds.
// end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error)
LoadAfterTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error)
}
type searchMessageOptions struct {
type SearchMessageOptions struct {
Start time.Time
End time.Time
Limit int
@ -69,13 +68,20 @@ type searchMessageOptions struct {
Text string
}
// searchMessageStore is a message store that supports server-side search
// operations.
type searchMessageStore interface {
messageStore
// SearchStore is a message store that supports server-side search operations.
type SearchStore interface {
Store
// Search returns messages matching the specified options.
Search(ctx context.Context, network *database.Network, options *searchMessageOptions) ([]*irc.Message, error)
Search(ctx context.Context, network *database.Network, options *SearchMessageOptions) ([]*irc.Message, error)
}
// RenameNetworkStore is a message store which needs to be notified of network
// name changes.
type RenameNetworkStore interface {
Store
RenameNetwork(oldNet, newNet *database.Network) error
}
type msgIDType uint
@ -118,7 +124,7 @@ func formatMsgID(netID int64, target string, body msgIDBody) string {
return base64.RawURLEncoding.EncodeToString(buf.Bytes())
}
func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) {
func ParseMsgID(s string, body msgIDBody) (netID int64, target string, err error) {
b, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return 0, "", fmt.Errorf("invalid internal message ID: %v", err)

20
user.go
View File

@ -16,6 +16,7 @@ import (
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/msgstore"
)
type event interface{}
@ -454,17 +455,17 @@ type user struct {
networks []*network
downstreamConns []*downstreamConn
msgStore messageStore
msgStore msgstore.Store
}
func newUser(srv *Server, record *database.User) *user {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
var msgStore messageStore
var msgStore msgstore.Store
if logPath := srv.Config().LogPath; logPath != "" {
msgStore = newFSMessageStore(logPath, record)
msgStore = msgstore.NewFSStore(logPath, record)
} else {
msgStore = newMemoryMessageStore()
msgStore = msgstore.NewMemoryStore()
}
return &user{
@ -951,10 +952,10 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne
// The filesystem message store needs to be notified whenever the network
// is renamed
fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
if isFS && updatedNetwork.GetName() != network.GetName() {
if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
renameNetMsgStore, ok := u.msgStore.(msgstore.RenameNetworkStore)
if ok && updatedNetwork.GetName() != network.GetName() {
if err := renameNetMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
network.logger.Printf("failed to update message store network name to %q: %v", updatedNetwork.GetName(), err)
}
}
@ -1049,8 +1050,7 @@ func (u *user) hasPersistentMsgStore() bool {
if u.msgStore == nil {
return false
}
_, isMem := u.msgStore.(*memoryMessageStore)
return !isMem
return !msgstore.IsMemoryStore(u.msgStore)
}
// localAddrForHost returns the local address to use when connecting to host.