Use a dedicated goroutine to write upstream messages

This commit is contained in:
Simon Ser 2020-02-06 22:46:46 +01:00
parent 774872d655
commit 8493f5b255
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48

View File

@ -28,6 +28,7 @@ type upstreamConn struct {
net net.Conn net net.Conn
irc *irc.Conn irc *irc.Conn
srv *Server srv *Server
messages chan<- *irc.Message
serverName string serverName string
availableUserModes string availableUserModes string
@ -35,10 +36,54 @@ type upstreamConn struct {
channelModesWithParam string channelModesWithParam string
registered bool registered bool
closed bool
modes modeSet modes modeSet
channels map[string]*upstreamChannel channels map[string]*upstreamChannel
} }
func connectToUpstream(s *Server, upstream *Upstream) (*upstreamConn, error) {
logger := &prefixLogger{s.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
logger.Printf("connecting to server")
netConn, err := tls.Dial("tcp", upstream.Addr, nil)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
}
msgs := make(chan *irc.Message, 64)
conn := &upstreamConn{
upstream: upstream,
logger: logger,
net: netConn,
irc: irc.NewConn(netConn),
srv: s,
messages: msgs,
channels: make(map[string]*upstreamChannel),
}
go func() {
for msg := range msgs {
if err := conn.irc.WriteMessage(msg); err != nil {
conn.logger.Printf("failed to write message: %v", err)
}
}
}()
return conn, nil
}
func (c *upstreamConn) Close() error {
if c.closed {
return fmt.Errorf("upstream connection already closed")
}
if err := c.net.Close(); err != nil {
return err
}
close(c.messages)
c.closed = true
return nil
}
func (c *upstreamConn) getChannel(name string) (*upstreamChannel, error) { func (c *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch, ok := c.channels[name] ch, ok := c.channels[name]
if !ok { if !ok {
@ -51,10 +96,11 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
switch msg.Command { switch msg.Command {
case "PING": case "PING":
// TODO: handle params // TODO: handle params
return c.irc.WriteMessage(&irc.Message{ c.messages <- &irc.Message{
Command: "PONG", Command: "PONG",
Params: []string{c.srv.Hostname}, Params: []string{c.srv.Hostname},
}) }
return nil
case "MODE": case "MODE":
if len(msg.Params) < 2 { if len(msg.Params) < 2 {
return newNeedMoreParamsError(msg.Command) return newNeedMoreParamsError(msg.Command)
@ -70,12 +116,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
c.logger.Printf("connection registered") c.logger.Printf("connection registered")
for _, ch := range c.upstream.Channels { for _, ch := range c.upstream.Channels {
err := c.irc.WriteMessage(&irc.Message{ c.messages <- &irc.Message{
Command: "JOIN", Command: "JOIN",
Params: []string{ch}, Params: []string{ch},
})
if err != nil {
return err
} }
} }
case irc.RPL_MYINFO: case irc.RPL_MYINFO:
@ -191,22 +234,16 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
} }
func (c *upstreamConn) readMessages() error { func (c *upstreamConn) readMessages() error {
defer c.net.Close() defer c.Close()
err := c.irc.WriteMessage(&irc.Message{ c.messages <- &irc.Message{
Command: "NICK", Command: "NICK",
Params: []string{c.upstream.Nick}, Params: []string{c.upstream.Nick},
})
if err != nil {
return err
} }
err = c.irc.WriteMessage(&irc.Message{ c.messages <- &irc.Message{
Command: "USER", Command: "USER",
Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname}, Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname},
})
if err != nil {
return err
} }
for { for {
@ -222,24 +259,5 @@ func (c *upstreamConn) readMessages() error {
} }
} }
return c.net.Close() return c.Close()
}
func connectToUpstream(s *Server, upstream *Upstream) (*upstreamConn, error) {
logger := &prefixLogger{s.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
logger.Printf("connecting to server")
netConn, err := tls.Dial("tcp", upstream.Addr, nil)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
}
return &upstreamConn{
upstream: upstream,
logger: logger,
net: netConn,
irc: irc.NewConn(netConn),
srv: s,
channels: make(map[string]*upstreamChannel),
}, nil
} }