Add message store abstraction

Introduce a messageStore type, which will allow for multiple
implementations (e.g. in the DB or in-memory instead of on-disk).

The message store is per-user so that we don't need to deal with locking
and it's easier to implement per-user limits.
This commit is contained in:
Simon Ser 2020-10-25 11:13:51 +01:00
parent af1e578936
commit 05aafb5edf
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
4 changed files with 83 additions and 82 deletions

View File

@ -863,7 +863,7 @@ func (dc *downstreamConn) welcome() error {
continue
}
lastID, err := lastMsgID(net, target, time.Now())
lastID, err := dc.user.msgStore.LastMsgID(net, target, time.Now())
if err != nil {
dc.logger.Printf("failed to get last message ID: %v", err)
continue
@ -876,7 +876,7 @@ func (dc *downstreamConn) welcome() error {
}
func (dc *downstreamConn) sendNetworkHistory(net *network) {
if dc.caps["draft/chathistory"] || dc.srv.LogPath == "" {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
return
}
for target, history := range net.history {
@ -890,7 +890,7 @@ func (dc *downstreamConn) sendNetworkHistory(net *network) {
}
limit := 4000
history, err := loadHistoryLatestID(net, target, lastDelivered, limit)
history, err := dc.user.msgStore.LoadLatestID(net, target, lastDelivered, limit)
if err != nil {
dc.logger.Printf("failed to send implicit history for %q: %v", target, err)
continue
@ -1601,7 +1601,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}}
}
if dc.srv.LogPath == "" {
if dc.user.msgStore == nil {
return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND,
Params: []string{dc.nick, subcommand, "Unknown command"},
@ -1641,9 +1641,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
var history []*irc.Message
switch subcommand {
case "BEFORE":
history, err = loadHistoryBeforeTime(uc.network, entity, timestamp, limit)
history, err = dc.user.msgStore.LoadBeforeTime(uc.network, entity, timestamp, limit)
case "AFTER":
history, err = loadHistoryAfterTime(uc.network, entity, timestamp, limit)
history, err = dc.user.msgStore.LoadAfterTime(uc.network, entity, timestamp, limit)
default:
// TODO: support LATEST, BETWEEN
return ircError{&irc.Message{

View File

@ -12,32 +12,28 @@ import (
"gopkg.in/irc.v3"
)
const messageLoggerMaxTries = 100
type messageLogger struct {
network *network
entity string
path string
file *os.File
}
func newMessageLogger(network *network, entity string) *messageLogger {
return &messageLogger{
network: network,
entity: entity,
}
}
const messageStoreMaxTries = 100
var escapeFilename = strings.NewReplacer("/", "-", "\\", "-")
func logPath(network *network, entity string, t time.Time) string {
user := network.user
srv := user.srv
// messageStore is a per-user store for IRC messages.
type messageStore struct {
root string
files map[string]*os.File // indexed by entity
}
func newMessageStore(root, username string) *messageStore {
return &messageStore{
root: filepath.Join(root, escapeFilename.Replace(username)),
files: make(map[string]*os.File),
}
}
func (ms *messageStore) logPath(network *network, entity string, t time.Time) string {
year, month, day := t.Date()
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
return filepath.Join(srv.LogPath, escapeFilename.Replace(user.Username), escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename)
return filepath.Join(ms.root, escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename)
}
func parseMsgID(s string) (network, entity string, t time.Time, offset int64, err error) {
@ -64,11 +60,11 @@ func nextMsgID(network *network, entity string, t time.Time, f *os.File) (string
return formatMsgID(network.GetName(), entity, t, offset), nil
}
// lastMsgID queries the last message ID for the given network, entity and
// LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be
// used in history queries.
func lastMsgID(network *network, entity string, t time.Time) (string, error) {
p := logPath(network, entity, t)
func (ms *messageStore) LastMsgID(network *network, entity string, t time.Time) (string, error) {
p := ms.logPath(network, entity, t)
fi, err := os.Stat(p)
if os.IsNotExist(err) {
return formatMsgID(network.GetName(), entity, t, -1), nil
@ -78,7 +74,7 @@ func lastMsgID(network *network, entity string, t time.Time) (string, error) {
return formatMsgID(network.GetName(), entity, t, fi.Size()-1), nil
}
func (ml *messageLogger) Append(msg *irc.Message) (string, error) {
func (ms *messageStore) Append(network *network, entity string, msg *irc.Message) (string, error) {
s := formatMessage(msg)
if s == "" {
return "", nil
@ -97,44 +93,50 @@ func (ml *messageLogger) Append(msg *irc.Message) (string, error) {
}
// TODO: enforce maximum open file handles (LRU cache of file handles)
f := ms.files[entity]
// TODO: handle non-monotonic clock behaviour
path := logPath(ml.network, ml.entity, t)
if ml.path != path {
if ml.file != nil {
ml.file.Close()
path := ms.logPath(network, entity, t)
if f == nil || f.Name() != path {
if f != nil {
f.Close()
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
return "", fmt.Errorf("failed to create logs directory %q: %v", dir, err)
return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err)
}
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
var err error
f, err = os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
if err != nil {
return "", fmt.Errorf("failed to open log file %q: %v", path, err)
return "", fmt.Errorf("failed to open message log file %q: %v", path, err)
}
ml.path = path
ml.file = f
ms.files[entity] = f
}
msgID, err := nextMsgID(ml.network, ml.entity, t, ml.file)
msgID, err := nextMsgID(network, entity, t, f)
if err != nil {
return "", fmt.Errorf("failed to generate message ID: %v", err)
}
_, err = fmt.Fprintf(ml.file, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
_, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
if err != nil {
return "", fmt.Errorf("failed to log message to %q: %v", ml.path, err)
return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err)
}
return msgID, nil
}
func (ml *messageLogger) Close() error {
if ml.file == nil {
return nil
func (ms *messageStore) Close() error {
var closeErr error
for _, f := range ms.files {
if err := f.Close(); err != nil {
closeErr = fmt.Errorf("failed to close message store: %v", err)
}
return ml.file.Close()
}
return closeErr
}
// formatMessage formats a message log line. It assumes a well-formed IRC
@ -233,8 +235,8 @@ func parseMessage(line, entity string, ref time.Time) (*irc.Message, time.Time,
return msg, t, nil
}
func parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) {
path := logPath(network, entity, ref)
func (ms *messageStore) parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref)
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
@ -289,8 +291,8 @@ func parseMessagesBefore(network *network, entity string, ref time.Time, limit i
}
}
func parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) {
path := logPath(network, entity, ref)
func (ms *messageStore) parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref)
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
@ -319,12 +321,12 @@ func parseMessagesAfter(network *network, entity string, ref time.Time, limit in
return history, nil
}
func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
func (ms *messageStore) LoadBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
history := make([]*irc.Message, limit)
remaining := limit
tries := 0
for remaining > 0 && tries < messageLoggerMaxTries {
buf, err := parseMessagesBefore(network, entity, t, remaining, -1)
for remaining > 0 && tries < messageStoreMaxTries {
buf, err := ms.parseMessagesBefore(network, entity, t, remaining, -1)
if err != nil {
return nil, err
}
@ -342,13 +344,13 @@ func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit i
return history[remaining:], nil
}
func loadHistoryAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
func (ms *messageStore) LoadAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
var history []*irc.Message
remaining := limit
tries := 0
now := time.Now()
for remaining > 0 && tries < messageLoggerMaxTries && t.Before(now) {
buf, err := parseMessagesAfter(network, entity, t, remaining)
for remaining > 0 && tries < messageStoreMaxTries && t.Before(now) {
buf, err := ms.parseMessagesAfter(network, entity, t, remaining)
if err != nil {
return nil, err
}
@ -370,7 +372,7 @@ func truncateDay(t time.Time) time.Time {
return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
}
func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) {
func (ms *messageStore) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) {
var afterTime time.Time
var afterOffset int64
if id != "" {
@ -389,13 +391,13 @@ func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc
t := time.Now()
remaining := limit
tries := 0
for remaining > 0 && tries < messageLoggerMaxTries && !truncateDay(t).Before(afterTime) {
for remaining > 0 && tries < messageStoreMaxTries && !truncateDay(t).Before(afterTime) {
var offset int64 = -1
if afterOffset >= 0 && truncateDay(t).Equal(afterTime) {
offset = afterOffset
}
buf, err := parseMessagesBefore(network, entity, t, remaining, offset)
buf, err := ms.parseMessagesBefore(network, entity, t, remaining, offset)
if err != nil {
return nil, err
}

View File

@ -81,8 +81,6 @@ type upstreamConn struct {
// set of LIST commands in progress, per downstream
pendingLISTDownstreamSet map[uint64]struct{}
messageLoggers map[string]*messageLogger
}
func connectToUpstream(network *network) (*upstreamConn, error) {
@ -182,7 +180,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
availableChannelModes: stdChannelModes,
availableMemberships: stdMemberships,
pendingLISTDownstreamSet: make(map[uint64]struct{}),
messageLoggers: make(map[string]*messageLogger),
}
return uc, nil
}
@ -1611,16 +1608,10 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message
}
func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
if uc.srv.LogPath == "" {
if uc.user.msgStore == nil {
return
}
ml, ok := uc.messageLoggers[entity]
if !ok {
ml = newMessageLogger(uc.network, entity)
uc.messageLoggers[entity] = ml
}
detached := false
if ch, ok := uc.network.channels[entity]; ok {
detached = ch.Detached
@ -1628,7 +1619,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
history, ok := uc.network.history[entity]
if !ok {
lastID, err := lastMsgID(uc.network, entity, time.Now())
lastID, err := uc.user.msgStore.LastMsgID(uc.network, entity, time.Now())
if err != nil {
uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
return
@ -1652,7 +1643,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
}
}
msgID, err := ml.Append(msg)
msgID, err := uc.user.msgStore.Append(uc.network, entity, msg)
if err != nil {
uc.logger.Printf("failed to log message: %v", err)
return

22
user.go
View File

@ -249,6 +249,7 @@ type user struct {
networks []*network
downstreamConns []*downstreamConn
msgStore *messageStore
// LIST commands in progress
pendingLISTs []pendingLIST
@ -261,11 +262,17 @@ type pendingLIST struct {
}
func newUser(srv *Server, record *User) *user {
var msgStore *messageStore
if srv.LogPath != "" {
msgStore = newMessageStore(srv.LogPath, record.Username)
}
return &user{
User: *record,
srv: srv,
events: make(chan event, 64),
done: make(chan struct{}),
msgStore: msgStore,
}
}
@ -312,7 +319,14 @@ func (u *user) getNetworkByID(id int64) *network {
}
func (u *user) run() {
defer close(u.done)
defer func() {
if u.msgStore != nil {
if err := u.msgStore.Close(); err != nil {
u.srv.Logger.Printf("failed to close message store for user %q: %v", u.Username, err)
}
}
close(u.done)
}()
networks, err := u.srv.db.ListNetworks(u.ID)
if err != nil {
@ -459,12 +473,6 @@ func (u *user) run() {
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.network.conn = nil
for _, ml := range uc.messageLoggers {
if err := ml.Close(); err != nil {
uc.logger.Printf("failed to close message logger: %v", err)
}
}
uc.endPendingLISTs(true)
uc.forEachDownstream(func(dc *downstreamConn) {