Introduce conn for common connection logic

This centralizes the common upstream & downstream bits.
This commit is contained in:
Simon Ser 2020-04-03 16:34:11 +02:00
parent 8c6328207b
commit 2a0696b6bb
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
4 changed files with 122 additions and 154 deletions

109
conn.go Normal file
View File

@ -0,0 +1,109 @@
package soju
import (
"fmt"
"net"
"time"
"gopkg.in/irc.v3"
)
func setKeepAlive(c net.Conn) error {
tcpConn, ok := c.(*net.TCPConn)
if !ok {
return fmt.Errorf("cannot enable keep-alive on a non-TCP connection")
}
if err := tcpConn.SetKeepAlive(true); err != nil {
return err
}
return tcpConn.SetKeepAlivePeriod(keepAlivePeriod)
}
type conn struct {
net net.Conn
irc *irc.Conn
srv *Server
logger Logger
outgoing chan<- *irc.Message
closed chan struct{}
}
func newConn(srv *Server, netConn net.Conn, logger Logger) *conn {
setKeepAlive(netConn)
outgoing := make(chan *irc.Message, 64)
c := &conn{
net: netConn,
irc: irc.NewConn(netConn),
srv: srv,
outgoing: outgoing,
logger: logger,
closed: make(chan struct{}),
}
go func() {
for msg := range outgoing {
if c.srv.Debug {
c.logger.Printf("sent: %v", msg)
}
c.net.SetWriteDeadline(time.Now().Add(writeTimeout))
if err := c.irc.WriteMessage(msg); err != nil {
c.logger.Printf("failed to write message: %v", err)
break
}
}
if err := c.net.Close(); err != nil {
c.logger.Printf("failed to close connection: %v", err)
} else {
c.logger.Printf("connection closed")
}
// Drain the outgoing channel to prevent SendMessage from blocking
for range outgoing {
// This space is intentionally left blank
}
}()
c.logger.Printf("new connection")
return c
}
func (c *conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
// Close closes the connection. It is safe to call from any goroutine.
func (c *conn) Close() error {
if c.isClosed() {
return fmt.Errorf("connection already closed")
}
close(c.closed)
close(c.outgoing)
return nil
}
func (c *conn) ReadMessage() (*irc.Message, error) {
msg, err := c.irc.ReadMessage()
if err != nil {
return nil, err
}
if c.srv.Debug {
c.logger.Printf("received: %v", msg)
}
return msg, nil
}
// SendMessage queues a new outgoing message. It is safe to call from any
// goroutine.
func (c *conn) SendMessage(msg *irc.Message) {
if c.isClosed() {
return
}
c.outgoing <- msg
}

View File

@ -52,13 +52,9 @@ var errAuthFailed = ircError{&irc.Message{
}}
type downstreamConn struct {
conn
id uint64
net net.Conn
irc *irc.Conn
srv *Server
logger Logger
outgoing chan<- *irc.Message
closed chan struct{}
registered bool
user *user
@ -84,15 +80,10 @@ type downstreamConn struct {
}
func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn {
outgoing := make(chan *irc.Message, 64)
logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}
dc := &downstreamConn{
conn: *newConn(srv, netConn, logger),
id: id,
net: netConn,
irc: irc.NewConn(netConn),
srv: srv,
logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
outgoing: outgoing,
closed: make(chan struct{}),
ringConsumers: make(map[*network]*RingConsumer),
caps: make(map[string]bool),
ourMessages: make(map[*irc.Message]struct{}),
@ -101,30 +92,6 @@ func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn
if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
dc.hostname = host
}
go func() {
for msg := range outgoing {
if dc.srv.Debug {
dc.logger.Printf("sent: %v", msg)
}
dc.net.SetWriteDeadline(time.Now().Add(writeTimeout))
if err := dc.irc.WriteMessage(msg); err != nil {
dc.logger.Printf("failed to write message: %v", err)
break
}
}
if err := dc.net.Close(); err != nil {
dc.logger.Printf("failed to close connection: %v", err)
} else {
dc.logger.Printf("connection closed")
}
// Drain the outgoing channel to prevent SendMessage from blocking
for range outgoing {
// This space is intentionally left blank
}
}()
dc.logger.Printf("new connection")
return dc
}
@ -227,56 +194,24 @@ func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix
}
}
func (dc *downstreamConn) isClosed() bool {
select {
case <-dc.closed:
return true
default:
return false
}
}
func (dc *downstreamConn) readMessages(ch chan<- event) error {
for {
msg, err := dc.irc.ReadMessage()
msg, err := dc.ReadMessage()
if err == io.EOF {
break
} else if err != nil {
return fmt.Errorf("failed to read IRC command: %v", err)
}
if dc.srv.Debug {
dc.logger.Printf("received: %v", msg)
}
ch <- eventDownstreamMessage{msg, dc}
}
return nil
}
func (dc *downstreamConn) writeMessages() error {
return nil
}
// Close closes the connection. It is safe to call from any goroutine.
func (dc *downstreamConn) Close() error {
if dc.isClosed() {
return fmt.Errorf("downstream connection already closed")
}
close(dc.closed)
close(dc.outgoing)
return nil
}
// SendMessage queues a new outgoing message. It is safe to call from any
// goroutine.
func (dc *downstreamConn) SendMessage(msg *irc.Message) {
if dc.isClosed() {
return
}
// TODO: strip tags if the client doesn't support them (see runNetwork)
dc.outgoing <- msg
dc.conn.SendMessage(msg)
}
func (dc *downstreamConn) handleMessage(msg *irc.Message) error {

View File

@ -16,17 +16,6 @@ var retryConnectMinDelay = time.Minute
var connectTimeout = 15 * time.Second
var writeTimeout = 10 * time.Second
func setKeepAlive(c net.Conn) error {
tcpConn, ok := c.(*net.TCPConn)
if !ok {
return fmt.Errorf("cannot enable keep-alive on a non-TCP connection")
}
if err := tcpConn.SetKeepAlive(true); err != nil {
return err
}
return tcpConn.SetKeepAlivePeriod(keepAlivePeriod)
}
type Logger interface {
Print(v ...interface{})
Printf(format string, v ...interface{})
@ -109,8 +98,6 @@ func (s *Server) Serve(ln net.Listener) error {
return fmt.Errorf("failed to accept connection: %v", err)
}
setKeepAlive(netConn)
dc := newDownstreamConn(s, netConn, nextDownstreamID)
nextDownstreamID++
go func() {

View File

@ -31,14 +31,10 @@ type upstreamChannel struct {
}
type upstreamConn struct {
conn
network *network
logger Logger
net net.Conn
irc *irc.Conn
srv *Server
user *user
outgoing chan<- *irc.Message
closed chan struct{}
serverName string
availableUserModes string
@ -90,18 +86,10 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
}
setKeepAlive(netConn)
outgoing := make(chan *irc.Message, 64)
uc := &upstreamConn{
conn: *newConn(network.user.srv, netConn, logger),
network: network,
logger: logger,
net: netConn,
irc: irc.NewConn(netConn),
srv: network.user.srv,
user: network.user,
outgoing: outgoing,
closed: make(chan struct{}),
channels: make(map[string]*upstreamChannel),
caps: make(map[string]string),
batches: make(map[string]batch),
@ -112,50 +100,9 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
logs: make(map[string]entityLog),
}
go func() {
for msg := range outgoing {
if uc.srv.Debug {
uc.logger.Printf("sent: %v", msg)
}
uc.net.SetWriteDeadline(time.Now().Add(writeTimeout))
if err := uc.irc.WriteMessage(msg); err != nil {
uc.logger.Printf("failed to write message: %v", err)
break
}
}
if err := uc.net.Close(); err != nil {
uc.logger.Printf("failed to close connection: %v", err)
} else {
uc.logger.Printf("connection closed")
}
// Drain the outgoing channel to prevent SendMessage from blocking
for range outgoing {
// This space is intentionally left blank
}
}()
return uc, nil
}
func (uc *upstreamConn) isClosed() bool {
select {
case <-uc.closed:
return true
default:
return false
}
}
// Close closes the connection. It is safe to call from any goroutine.
func (uc *upstreamConn) Close() error {
if uc.isClosed() {
return fmt.Errorf("upstream connection already closed")
}
close(uc.closed)
close(uc.outgoing)
return nil
}
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
uc.user.forEachDownstream(func(dc *downstreamConn) {
if dc.network != nil && dc.network != uc.network {
@ -1409,29 +1356,19 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
func (uc *upstreamConn) readMessages(ch chan<- event) error {
for {
msg, err := uc.irc.ReadMessage()
msg, err := uc.ReadMessage()
if err == io.EOF {
break
} else if err != nil {
return fmt.Errorf("failed to read IRC command: %v", err)
}
if uc.srv.Debug {
uc.logger.Printf("received: %v", msg)
}
ch <- eventUpstreamMessage{msg, uc}
}
return nil
}
// SendMessage queues a new outgoing message. It is safe to call from any
// goroutine.
func (uc *upstreamConn) SendMessage(msg *irc.Message) {
uc.outgoing <- msg
}
func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
if uc.labelsSupported {
if msg.Tags == nil {