Add context arg to downstreamConn.welcome()

This commit is contained in:
Simon Ser 2021-11-17 12:33:30 +01:00
parent e459dcdb76
commit e28332a5aa
2 changed files with 10 additions and 10 deletions

View File

@ -1126,7 +1126,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
return nil return nil
} }
func (dc *downstreamConn) loadNetwork() error { func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
if dc.networkName == "" { if dc.networkName == "" {
return nil return nil
} }
@ -1139,7 +1139,7 @@ func (dc *downstreamConn) loadNetwork() error {
} }
dc.logger.Printf("trying to connect to new network %q", addr) dc.logger.Printf("trying to connect to new network %q", addr)
if err := sanityCheckServer(context.TODO(), addr); err != nil { if err := sanityCheckServer(ctx, addr); err != nil {
dc.logger.Printf("failed to connect to %q: %v", addr, err) dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH, Command: irc.ERR_PASSWDMISMATCH,
@ -1154,7 +1154,7 @@ func (dc *downstreamConn) loadNetwork() error {
dc.logger.Printf("auto-saving network %q", dc.networkName) dc.logger.Printf("auto-saving network %q", dc.networkName)
var err error var err error
network, err = dc.user.createNetwork(context.TODO(), &Network{ network, err = dc.user.createNetwork(ctx, &Network{
Addr: dc.networkName, Addr: dc.networkName,
Nick: nick, Nick: nick,
Enabled: true, Enabled: true,
@ -1168,7 +1168,7 @@ func (dc *downstreamConn) loadNetwork() error {
return nil return nil
} }
func (dc *downstreamConn) welcome() error { func (dc *downstreamConn) welcome(ctx context.Context) error {
if dc.user == nil || !dc.registered { if dc.user == nil || !dc.registered {
panic("tried to welcome an unregistered connection") panic("tried to welcome an unregistered connection")
} }
@ -1176,7 +1176,7 @@ func (dc *downstreamConn) welcome() error {
// TODO: doing this might take some time. We should do it in dc.register // TODO: doing this might take some time. We should do it in dc.register
// instead, but we'll potentially be adding a new network and this must be // instead, but we'll potentially be adding a new network and this must be
// done in the user goroutine. // done in the user goroutine.
if err := dc.loadNetwork(); err != nil { if err := dc.loadNetwork(ctx); err != nil {
return err return err
} }
@ -1322,7 +1322,7 @@ func (dc *downstreamConn) welcome() error {
return return
} }
dc.sendTargetBacklog(net, target, lastDelivered) dc.sendTargetBacklog(ctx, net, target, lastDelivered)
// Fast-forward history to last message // Fast-forward history to last message
targetCM := net.casemap(target) targetCM := net.casemap(target)
@ -1352,14 +1352,14 @@ func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool {
return false return false
} }
func (dc *downstreamConn) sendTargetBacklog(net *network, target, msgID string) { func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
return return
} }
ch := net.channels.Value(target) ch := net.channels.Value(target)
ctx, cancel := context.WithTimeout(context.TODO(), backlogTimeout) ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
defer cancel() defer cancel()
targetCM := net.casemap(target) targetCM := net.casemap(target)

View File

@ -314,7 +314,7 @@ func (net *network) attach(ch *Channel) {
} }
if detachedMsgID != "" { if detachedMsgID != "" {
dc.sendTargetBacklog(net, ch.Name, detachedMsgID) dc.sendTargetBacklog(context.TODO(), net, ch.Name, detachedMsgID)
} }
}) })
} }
@ -597,7 +597,7 @@ func (u *user) run() {
dc.monitored.SetCasemapping(dc.network.casemap) dc.monitored.SetCasemapping(dc.network.casemap)
} }
if err := dc.welcome(); err != nil { if err := dc.welcome(context.TODO()); err != nil {
dc.logger.Printf("failed to handle new registered connection: %v", err) dc.logger.Printf("failed to handle new registered connection: %v", err)
break break
} }