Use a dedicated goroutine to write upstream messages
This commit is contained in:
parent
774872d655
commit
8493f5b255
88
upstream.go
88
upstream.go
@ -28,6 +28,7 @@ type upstreamConn struct {
|
|||||||
net net.Conn
|
net net.Conn
|
||||||
irc *irc.Conn
|
irc *irc.Conn
|
||||||
srv *Server
|
srv *Server
|
||||||
|
messages chan<- *irc.Message
|
||||||
|
|
||||||
serverName string
|
serverName string
|
||||||
availableUserModes string
|
availableUserModes string
|
||||||
@ -35,10 +36,54 @@ type upstreamConn struct {
|
|||||||
channelModesWithParam string
|
channelModesWithParam string
|
||||||
|
|
||||||
registered bool
|
registered bool
|
||||||
|
closed bool
|
||||||
modes modeSet
|
modes modeSet
|
||||||
channels map[string]*upstreamChannel
|
channels map[string]*upstreamChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func connectToUpstream(s *Server, upstream *Upstream) (*upstreamConn, error) {
|
||||||
|
logger := &prefixLogger{s.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
|
||||||
|
logger.Printf("connecting to server")
|
||||||
|
|
||||||
|
netConn, err := tls.Dial("tcp", upstream.Addr, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := make(chan *irc.Message, 64)
|
||||||
|
conn := &upstreamConn{
|
||||||
|
upstream: upstream,
|
||||||
|
logger: logger,
|
||||||
|
net: netConn,
|
||||||
|
irc: irc.NewConn(netConn),
|
||||||
|
srv: s,
|
||||||
|
messages: msgs,
|
||||||
|
channels: make(map[string]*upstreamChannel),
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for msg := range msgs {
|
||||||
|
if err := conn.irc.WriteMessage(msg); err != nil {
|
||||||
|
conn.logger.Printf("failed to write message: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *upstreamConn) Close() error {
|
||||||
|
if c.closed {
|
||||||
|
return fmt.Errorf("upstream connection already closed")
|
||||||
|
}
|
||||||
|
if err := c.net.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
close(c.messages)
|
||||||
|
c.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
|
func (c *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
|
||||||
ch, ok := c.channels[name]
|
ch, ok := c.channels[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -51,10 +96,11 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
switch msg.Command {
|
switch msg.Command {
|
||||||
case "PING":
|
case "PING":
|
||||||
// TODO: handle params
|
// TODO: handle params
|
||||||
return c.irc.WriteMessage(&irc.Message{
|
c.messages <- &irc.Message{
|
||||||
Command: "PONG",
|
Command: "PONG",
|
||||||
Params: []string{c.srv.Hostname},
|
Params: []string{c.srv.Hostname},
|
||||||
})
|
}
|
||||||
|
return nil
|
||||||
case "MODE":
|
case "MODE":
|
||||||
if len(msg.Params) < 2 {
|
if len(msg.Params) < 2 {
|
||||||
return newNeedMoreParamsError(msg.Command)
|
return newNeedMoreParamsError(msg.Command)
|
||||||
@ -70,12 +116,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
c.logger.Printf("connection registered")
|
c.logger.Printf("connection registered")
|
||||||
|
|
||||||
for _, ch := range c.upstream.Channels {
|
for _, ch := range c.upstream.Channels {
|
||||||
err := c.irc.WriteMessage(&irc.Message{
|
c.messages <- &irc.Message{
|
||||||
Command: "JOIN",
|
Command: "JOIN",
|
||||||
Params: []string{ch},
|
Params: []string{ch},
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case irc.RPL_MYINFO:
|
case irc.RPL_MYINFO:
|
||||||
@ -191,22 +234,16 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *upstreamConn) readMessages() error {
|
func (c *upstreamConn) readMessages() error {
|
||||||
defer c.net.Close()
|
defer c.Close()
|
||||||
|
|
||||||
err := c.irc.WriteMessage(&irc.Message{
|
c.messages <- &irc.Message{
|
||||||
Command: "NICK",
|
Command: "NICK",
|
||||||
Params: []string{c.upstream.Nick},
|
Params: []string{c.upstream.Nick},
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.irc.WriteMessage(&irc.Message{
|
c.messages <- &irc.Message{
|
||||||
Command: "USER",
|
Command: "USER",
|
||||||
Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname},
|
Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname},
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -222,24 +259,5 @@ func (c *upstreamConn) readMessages() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.net.Close()
|
return c.Close()
|
||||||
}
|
|
||||||
|
|
||||||
func connectToUpstream(s *Server, upstream *Upstream) (*upstreamConn, error) {
|
|
||||||
logger := &prefixLogger{s.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
|
|
||||||
logger.Printf("connecting to server")
|
|
||||||
|
|
||||||
netConn, err := tls.Dial("tcp", upstream.Addr, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &upstreamConn{
|
|
||||||
upstream: upstream,
|
|
||||||
logger: logger,
|
|
||||||
net: netConn,
|
|
||||||
irc: irc.NewConn(netConn),
|
|
||||||
srv: s,
|
|
||||||
channels: make(map[string]*upstreamChannel),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user