From 36c404c50c04b790bca29740bc8b7eb2fbf17a96 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 6 Feb 2020 21:20:22 +0100 Subject: [PATCH] Allow Server to have access to upstreamConn --- downstream.go | 14 +++++++------- server.go | 11 ++++++++--- upstream.go | 47 +++++++++++++++++++++++++---------------------- 3 files changed, 40 insertions(+), 32 deletions(-) diff --git a/downstream.go b/downstream.go index 828595a..0686e97 100644 --- a/downstream.go +++ b/downstream.go @@ -39,10 +39,10 @@ func (err ircError) Error() string { } type downstreamConn struct { - net net.Conn - irc *irc.Conn - srv *Server - logger Logger + net net.Conn + irc *irc.Conn + srv *Server + logger Logger registered bool closed bool @@ -53,9 +53,9 @@ type downstreamConn struct { func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { return &downstreamConn{ - net: netConn, - irc: irc.NewConn(netConn), - srv: srv, + net: netConn, + irc: irc.NewConn(netConn), + srv: srv, logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, } } diff --git a/server.go b/server.go index 986ac65..92c7820 100644 --- a/server.go +++ b/server.go @@ -54,8 +54,13 @@ func (s *Server) Run() { upstream := &s.Upstreams[i] // TODO: retry connecting go func() { - if err := connect(s, upstream); err != nil { - s.Logger.Printf("Failed to connect to upstream server %q: %v", upstream.Addr, err) + conn, err := connectToUpstream(s, upstream) + if err != nil { + s.Logger.Printf("failed to connect to upstream server %q: %v", upstream.Addr, err) + return + } + if err := conn.readMessages(); err != nil { + conn.logger.Printf("failed to handle messages: %v", err) } }() } @@ -72,7 +77,7 @@ func (s *Server) Serve(ln net.Listener) error { s.downstreamConns = append(s.downstreamConns, conn) go func() { if err := conn.readMessages(); err != nil { - conn.logger.Printf("Error handling messages: %v", err) + conn.logger.Printf("failed to handle messages: %v", err) } }() } diff --git a/upstream.go b/upstream.go index e56d242..10e0f2b 100644 --- a/upstream.go +++ b/upstream.go @@ -175,28 +175,12 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { return nil } -func connect(s *Server, upstream *Upstream) error { - logger := &prefixLogger{s.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)} - logger.Printf("connecting to server") +func (c *upstreamConn) readMessages() error { + defer c.net.Close() - netConn, err := tls.Dial("tcp", upstream.Addr, nil) - if err != nil { - return fmt.Errorf("failed to dial %q: %v", upstream.Addr, err) - } - - c := upstreamConn{ - upstream: upstream, - logger: logger, - net: netConn, - irc: irc.NewConn(netConn), - srv: s, - channels: make(map[string]*upstreamChannel), - } - defer netConn.Close() - - err = c.irc.WriteMessage(&irc.Message{ + err := c.irc.WriteMessage(&irc.Message{ Command: "NICK", - Params: []string{upstream.Nick}, + Params: []string{c.upstream.Nick}, }) if err != nil { return err @@ -204,7 +188,7 @@ func connect(s *Server, upstream *Upstream) error { err = c.irc.WriteMessage(&irc.Message{ Command: "USER", - Params: []string{upstream.Username, "0", "*", upstream.Realname}, + Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname}, }) if err != nil { return err @@ -223,5 +207,24 @@ func connect(s *Server, upstream *Upstream) error { } } - return netConn.Close() + return c.net.Close() +} + +func connectToUpstream(s *Server, upstream *Upstream) (*upstreamConn, error) { + logger := &prefixLogger{s.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)} + logger.Printf("connecting to server") + + netConn, err := tls.Dial("tcp", upstream.Addr, nil) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err) + } + + return &upstreamConn{ + upstream: upstream, + logger: logger, + net: netConn, + irc: irc.NewConn(netConn), + srv: s, + channels: make(map[string]*upstreamChannel), + }, nil }