Add message store abstraction

Introduce a messageStore type, which will allow for multiple
implementations (e.g. in the DB or in-memory instead of on-disk).

The message store is per-user so that we don't need to deal with locking
and it's easier to implement per-user limits.
This commit is contained in:
Simon Ser 2020-10-25 11:13:51 +01:00
parent af1e578936
commit 05aafb5edf
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
4 changed files with 83 additions and 82 deletions

View File

@ -863,7 +863,7 @@ func (dc *downstreamConn) welcome() error {
continue continue
} }
lastID, err := lastMsgID(net, target, time.Now()) lastID, err := dc.user.msgStore.LastMsgID(net, target, time.Now())
if err != nil { if err != nil {
dc.logger.Printf("failed to get last message ID: %v", err) dc.logger.Printf("failed to get last message ID: %v", err)
continue continue
@ -876,7 +876,7 @@ func (dc *downstreamConn) welcome() error {
} }
func (dc *downstreamConn) sendNetworkHistory(net *network) { func (dc *downstreamConn) sendNetworkHistory(net *network) {
if dc.caps["draft/chathistory"] || dc.srv.LogPath == "" { if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
return return
} }
for target, history := range net.history { for target, history := range net.history {
@ -890,7 +890,7 @@ func (dc *downstreamConn) sendNetworkHistory(net *network) {
} }
limit := 4000 limit := 4000
history, err := loadHistoryLatestID(net, target, lastDelivered, limit) history, err := dc.user.msgStore.LoadLatestID(net, target, lastDelivered, limit)
if err != nil { if err != nil {
dc.logger.Printf("failed to send implicit history for %q: %v", target, err) dc.logger.Printf("failed to send implicit history for %q: %v", target, err)
continue continue
@ -1601,7 +1601,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}} }}
} }
if dc.srv.LogPath == "" { if dc.user.msgStore == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND, Command: irc.ERR_UNKNOWNCOMMAND,
Params: []string{dc.nick, subcommand, "Unknown command"}, Params: []string{dc.nick, subcommand, "Unknown command"},
@ -1641,9 +1641,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
var history []*irc.Message var history []*irc.Message
switch subcommand { switch subcommand {
case "BEFORE": case "BEFORE":
history, err = loadHistoryBeforeTime(uc.network, entity, timestamp, limit) history, err = dc.user.msgStore.LoadBeforeTime(uc.network, entity, timestamp, limit)
case "AFTER": case "AFTER":
history, err = loadHistoryAfterTime(uc.network, entity, timestamp, limit) history, err = dc.user.msgStore.LoadAfterTime(uc.network, entity, timestamp, limit)
default: default:
// TODO: support LATEST, BETWEEN // TODO: support LATEST, BETWEEN
return ircError{&irc.Message{ return ircError{&irc.Message{

View File

@ -12,32 +12,28 @@ import (
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
) )
const messageLoggerMaxTries = 100 const messageStoreMaxTries = 100
type messageLogger struct {
network *network
entity string
path string
file *os.File
}
func newMessageLogger(network *network, entity string) *messageLogger {
return &messageLogger{
network: network,
entity: entity,
}
}
var escapeFilename = strings.NewReplacer("/", "-", "\\", "-") var escapeFilename = strings.NewReplacer("/", "-", "\\", "-")
func logPath(network *network, entity string, t time.Time) string { // messageStore is a per-user store for IRC messages.
user := network.user type messageStore struct {
srv := user.srv root string
files map[string]*os.File // indexed by entity
}
func newMessageStore(root, username string) *messageStore {
return &messageStore{
root: filepath.Join(root, escapeFilename.Replace(username)),
files: make(map[string]*os.File),
}
}
func (ms *messageStore) logPath(network *network, entity string, t time.Time) string {
year, month, day := t.Date() year, month, day := t.Date()
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
return filepath.Join(srv.LogPath, escapeFilename.Replace(user.Username), escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename) return filepath.Join(ms.root, escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename)
} }
func parseMsgID(s string) (network, entity string, t time.Time, offset int64, err error) { func parseMsgID(s string) (network, entity string, t time.Time, offset int64, err error) {
@ -64,11 +60,11 @@ func nextMsgID(network *network, entity string, t time.Time, f *os.File) (string
return formatMsgID(network.GetName(), entity, t, offset), nil return formatMsgID(network.GetName(), entity, t, offset), nil
} }
// lastMsgID queries the last message ID for the given network, entity and // LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be // date. The message ID returned may not refer to a valid message, but can be
// used in history queries. // used in history queries.
func lastMsgID(network *network, entity string, t time.Time) (string, error) { func (ms *messageStore) LastMsgID(network *network, entity string, t time.Time) (string, error) {
p := logPath(network, entity, t) p := ms.logPath(network, entity, t)
fi, err := os.Stat(p) fi, err := os.Stat(p)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return formatMsgID(network.GetName(), entity, t, -1), nil return formatMsgID(network.GetName(), entity, t, -1), nil
@ -78,7 +74,7 @@ func lastMsgID(network *network, entity string, t time.Time) (string, error) {
return formatMsgID(network.GetName(), entity, t, fi.Size()-1), nil return formatMsgID(network.GetName(), entity, t, fi.Size()-1), nil
} }
func (ml *messageLogger) Append(msg *irc.Message) (string, error) { func (ms *messageStore) Append(network *network, entity string, msg *irc.Message) (string, error) {
s := formatMessage(msg) s := formatMessage(msg)
if s == "" { if s == "" {
return "", nil return "", nil
@ -97,44 +93,50 @@ func (ml *messageLogger) Append(msg *irc.Message) (string, error) {
} }
// TODO: enforce maximum open file handles (LRU cache of file handles) // TODO: enforce maximum open file handles (LRU cache of file handles)
f := ms.files[entity]
// TODO: handle non-monotonic clock behaviour // TODO: handle non-monotonic clock behaviour
path := logPath(ml.network, ml.entity, t) path := ms.logPath(network, entity, t)
if ml.path != path { if f == nil || f.Name() != path {
if ml.file != nil { if f != nil {
ml.file.Close() f.Close()
} }
dir := filepath.Dir(path) dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil { if err := os.MkdirAll(dir, 0700); err != nil {
return "", fmt.Errorf("failed to create logs directory %q: %v", dir, err) return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err)
} }
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) var err error
f, err = os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to open log file %q: %v", path, err) return "", fmt.Errorf("failed to open message log file %q: %v", path, err)
} }
ml.path = path ms.files[entity] = f
ml.file = f
} }
msgID, err := nextMsgID(ml.network, ml.entity, t, ml.file) msgID, err := nextMsgID(network, entity, t, f)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to generate message ID: %v", err) return "", fmt.Errorf("failed to generate message ID: %v", err)
} }
_, err = fmt.Fprintf(ml.file, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s) _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to log message to %q: %v", ml.path, err) return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err)
} }
return msgID, nil return msgID, nil
} }
func (ml *messageLogger) Close() error { func (ms *messageStore) Close() error {
if ml.file == nil { var closeErr error
return nil for _, f := range ms.files {
if err := f.Close(); err != nil {
closeErr = fmt.Errorf("failed to close message store: %v", err)
} }
return ml.file.Close() }
return closeErr
} }
// formatMessage formats a message log line. It assumes a well-formed IRC // formatMessage formats a message log line. It assumes a well-formed IRC
@ -233,8 +235,8 @@ func parseMessage(line, entity string, ref time.Time) (*irc.Message, time.Time,
return msg, t, nil return msg, t, nil
} }
func parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) { func (ms *messageStore) parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) {
path := logPath(network, entity, ref) path := ms.logPath(network, entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -289,8 +291,8 @@ func parseMessagesBefore(network *network, entity string, ref time.Time, limit i
} }
} }
func parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) { func (ms *messageStore) parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) {
path := logPath(network, entity, ref) path := ms.logPath(network, entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -319,12 +321,12 @@ func parseMessagesAfter(network *network, entity string, ref time.Time, limit in
return history, nil return history, nil
} }
func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) { func (ms *messageStore) LoadBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
history := make([]*irc.Message, limit) history := make([]*irc.Message, limit)
remaining := limit remaining := limit
tries := 0 tries := 0
for remaining > 0 && tries < messageLoggerMaxTries { for remaining > 0 && tries < messageStoreMaxTries {
buf, err := parseMessagesBefore(network, entity, t, remaining, -1) buf, err := ms.parseMessagesBefore(network, entity, t, remaining, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -342,13 +344,13 @@ func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit i
return history[remaining:], nil return history[remaining:], nil
} }
func loadHistoryAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) { func (ms *messageStore) LoadAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
var history []*irc.Message var history []*irc.Message
remaining := limit remaining := limit
tries := 0 tries := 0
now := time.Now() now := time.Now()
for remaining > 0 && tries < messageLoggerMaxTries && t.Before(now) { for remaining > 0 && tries < messageStoreMaxTries && t.Before(now) {
buf, err := parseMessagesAfter(network, entity, t, remaining) buf, err := ms.parseMessagesAfter(network, entity, t, remaining)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -370,7 +372,7 @@ func truncateDay(t time.Time) time.Time {
return time.Date(year, month, day, 0, 0, 0, 0, t.Location()) return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
} }
func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *messageStore) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) {
var afterTime time.Time var afterTime time.Time
var afterOffset int64 var afterOffset int64
if id != "" { if id != "" {
@ -389,13 +391,13 @@ func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc
t := time.Now() t := time.Now()
remaining := limit remaining := limit
tries := 0 tries := 0
for remaining > 0 && tries < messageLoggerMaxTries && !truncateDay(t).Before(afterTime) { for remaining > 0 && tries < messageStoreMaxTries && !truncateDay(t).Before(afterTime) {
var offset int64 = -1 var offset int64 = -1
if afterOffset >= 0 && truncateDay(t).Equal(afterTime) { if afterOffset >= 0 && truncateDay(t).Equal(afterTime) {
offset = afterOffset offset = afterOffset
} }
buf, err := parseMessagesBefore(network, entity, t, remaining, offset) buf, err := ms.parseMessagesBefore(network, entity, t, remaining, offset)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -81,8 +81,6 @@ type upstreamConn struct {
// set of LIST commands in progress, per downstream // set of LIST commands in progress, per downstream
pendingLISTDownstreamSet map[uint64]struct{} pendingLISTDownstreamSet map[uint64]struct{}
messageLoggers map[string]*messageLogger
} }
func connectToUpstream(network *network) (*upstreamConn, error) { func connectToUpstream(network *network) (*upstreamConn, error) {
@ -182,7 +180,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
availableChannelModes: stdChannelModes, availableChannelModes: stdChannelModes,
availableMemberships: stdMemberships, availableMemberships: stdMemberships,
pendingLISTDownstreamSet: make(map[uint64]struct{}), pendingLISTDownstreamSet: make(map[uint64]struct{}),
messageLoggers: make(map[string]*messageLogger),
} }
return uc, nil return uc, nil
} }
@ -1611,16 +1608,10 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message
} }
func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) { func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
if uc.srv.LogPath == "" { if uc.user.msgStore == nil {
return return
} }
ml, ok := uc.messageLoggers[entity]
if !ok {
ml = newMessageLogger(uc.network, entity)
uc.messageLoggers[entity] = ml
}
detached := false detached := false
if ch, ok := uc.network.channels[entity]; ok { if ch, ok := uc.network.channels[entity]; ok {
detached = ch.Detached detached = ch.Detached
@ -1628,7 +1619,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
history, ok := uc.network.history[entity] history, ok := uc.network.history[entity]
if !ok { if !ok {
lastID, err := lastMsgID(uc.network, entity, time.Now()) lastID, err := uc.user.msgStore.LastMsgID(uc.network, entity, time.Now())
if err != nil { if err != nil {
uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
return return
@ -1652,7 +1643,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
} }
} }
msgID, err := ml.Append(msg) msgID, err := uc.user.msgStore.Append(uc.network, entity, msg)
if err != nil { if err != nil {
uc.logger.Printf("failed to log message: %v", err) uc.logger.Printf("failed to log message: %v", err)
return return

22
user.go
View File

@ -249,6 +249,7 @@ type user struct {
networks []*network networks []*network
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
msgStore *messageStore
// LIST commands in progress // LIST commands in progress
pendingLISTs []pendingLIST pendingLISTs []pendingLIST
@ -261,11 +262,17 @@ type pendingLIST struct {
} }
func newUser(srv *Server, record *User) *user { func newUser(srv *Server, record *User) *user {
var msgStore *messageStore
if srv.LogPath != "" {
msgStore = newMessageStore(srv.LogPath, record.Username)
}
return &user{ return &user{
User: *record, User: *record,
srv: srv, srv: srv,
events: make(chan event, 64), events: make(chan event, 64),
done: make(chan struct{}), done: make(chan struct{}),
msgStore: msgStore,
} }
} }
@ -312,7 +319,14 @@ func (u *user) getNetworkByID(id int64) *network {
} }
func (u *user) run() { func (u *user) run() {
defer close(u.done) defer func() {
if u.msgStore != nil {
if err := u.msgStore.Close(); err != nil {
u.srv.Logger.Printf("failed to close message store for user %q: %v", u.Username, err)
}
}
close(u.done)
}()
networks, err := u.srv.db.ListNetworks(u.ID) networks, err := u.srv.db.ListNetworks(u.ID)
if err != nil { if err != nil {
@ -459,12 +473,6 @@ func (u *user) run() {
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.network.conn = nil uc.network.conn = nil
for _, ml := range uc.messageLoggers {
if err := ml.Close(); err != nil {
uc.logger.Printf("failed to close message logger: %v", err)
}
}
uc.endPendingLISTs(true) uc.endPendingLISTs(true)
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {