From 45d118dd12182414db7c32914559b0d89d9e5db3 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Fri, 20 Mar 2020 22:48:17 +0100 Subject: [PATCH] Move upstreamConn.history to network --- downstream.go | 12 ++++++------ upstream.go | 5 ----- user.go | 17 ++++++++++++----- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/downstream.go b/downstream.go index 509a8ed..10b38bc 100644 --- a/downstream.go +++ b/downstream.go @@ -708,9 +708,9 @@ func (dc *downstreamConn) register() error { var seqPtr *uint64 if firstDownstream { - uc.lock.Lock() - seq, ok := uc.history[historyName] - uc.lock.Unlock() + uc.network.lock.Lock() + seq, ok := uc.network.history[historyName] + uc.network.lock.Unlock() if ok { seqPtr = &seq } @@ -738,9 +738,9 @@ func (dc *downstreamConn) register() error { dc.user.lock.Unlock() if lastDownstream { - uc.lock.Lock() - uc.history[historyName] = seq - uc.lock.Unlock() + uc.network.lock.Lock() + uc.network.history[historyName] = seq + uc.network.lock.Unlock() } }() }) diff --git a/upstream.go b/upstream.go index f6be1f4..e35574b 100644 --- a/upstream.go +++ b/upstream.go @@ -8,7 +8,6 @@ import ( "net" "strconv" "strings" - "sync" "time" "github.com/emersion/go-sasl" @@ -53,9 +52,6 @@ type upstreamConn struct { saslClient sasl.Client saslStarted bool - - lock sync.Mutex - history map[string]uint64 // TODO: move to network } func connectToUpstream(network *network) (*upstreamConn, error) { @@ -85,7 +81,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) { outgoing: outgoing, ring: NewRing(network.user.srv.RingCap), channels: make(map[string]*upstreamChannel), - history: make(map[string]uint64), caps: make(map[string]string), } diff --git a/user.go b/user.go index 763f4ff..d896fc3 100644 --- a/user.go +++ b/user.go @@ -20,13 +20,17 @@ type downstreamIncomingMessage struct { type network struct { Network user *user - conn *upstreamConn + + lock sync.Mutex + conn *upstreamConn + history map[string]uint64 } func newNetwork(user *user, record *Network) *network { return &network{ Network: *record, user: user, + history: make(map[string]uint64), } } @@ -48,18 +52,18 @@ func (net *network) run() { uc.register() - net.user.lock.Lock() + net.lock.Lock() net.conn = uc - net.user.lock.Unlock() + net.lock.Unlock() if err := uc.readMessages(net.user.upstreamIncoming); err != nil { uc.logger.Printf("failed to handle messages: %v", err) } uc.Close() - net.user.lock.Lock() + net.lock.Lock() net.conn = nil - net.user.lock.Unlock() + net.lock.Unlock() } } @@ -95,7 +99,10 @@ func (u *user) forEachNetwork(f func(*network)) { func (u *user) forEachUpstream(f func(uc *upstreamConn)) { u.lock.Lock() for _, network := range u.networks { + network.lock.Lock() uc := network.conn + network.lock.Unlock() + if uc == nil || !uc.registered || uc.closed { continue }