Introduce conn for common connection logic
This centralizes the common upstream & downstream bits.
This commit is contained in:
parent
8c6328207b
commit
2a0696b6bb
109
conn.go
Normal file
109
conn.go
Normal file
@ -0,0 +1,109 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"gopkg.in/irc.v3"
|
||||
)
|
||||
|
||||
func setKeepAlive(c net.Conn) error {
|
||||
tcpConn, ok := c.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot enable keep-alive on a non-TCP connection")
|
||||
}
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return err
|
||||
}
|
||||
return tcpConn.SetKeepAlivePeriod(keepAlivePeriod)
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
net net.Conn
|
||||
irc *irc.Conn
|
||||
srv *Server
|
||||
logger Logger
|
||||
outgoing chan<- *irc.Message
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newConn(srv *Server, netConn net.Conn, logger Logger) *conn {
|
||||
setKeepAlive(netConn)
|
||||
|
||||
outgoing := make(chan *irc.Message, 64)
|
||||
c := &conn{
|
||||
net: netConn,
|
||||
irc: irc.NewConn(netConn),
|
||||
srv: srv,
|
||||
outgoing: outgoing,
|
||||
logger: logger,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for msg := range outgoing {
|
||||
if c.srv.Debug {
|
||||
c.logger.Printf("sent: %v", msg)
|
||||
}
|
||||
c.net.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if err := c.irc.WriteMessage(msg); err != nil {
|
||||
c.logger.Printf("failed to write message: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := c.net.Close(); err != nil {
|
||||
c.logger.Printf("failed to close connection: %v", err)
|
||||
} else {
|
||||
c.logger.Printf("connection closed")
|
||||
}
|
||||
// Drain the outgoing channel to prevent SendMessage from blocking
|
||||
for range outgoing {
|
||||
// This space is intentionally left blank
|
||||
}
|
||||
}()
|
||||
|
||||
c.logger.Printf("new connection")
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *conn) isClosed() bool {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection. It is safe to call from any goroutine.
|
||||
func (c *conn) Close() error {
|
||||
if c.isClosed() {
|
||||
return fmt.Errorf("connection already closed")
|
||||
}
|
||||
close(c.closed)
|
||||
close(c.outgoing)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) ReadMessage() (*irc.Message, error) {
|
||||
msg, err := c.irc.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.srv.Debug {
|
||||
c.logger.Printf("received: %v", msg)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// SendMessage queues a new outgoing message. It is safe to call from any
|
||||
// goroutine.
|
||||
func (c *conn) SendMessage(msg *irc.Message) {
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
c.outgoing <- msg
|
||||
}
|
@ -52,13 +52,9 @@ var errAuthFailed = ircError{&irc.Message{
|
||||
}}
|
||||
|
||||
type downstreamConn struct {
|
||||
id uint64
|
||||
net net.Conn
|
||||
irc *irc.Conn
|
||||
srv *Server
|
||||
logger Logger
|
||||
outgoing chan<- *irc.Message
|
||||
closed chan struct{}
|
||||
conn
|
||||
|
||||
id uint64
|
||||
|
||||
registered bool
|
||||
user *user
|
||||
@ -84,15 +80,10 @@ type downstreamConn struct {
|
||||
}
|
||||
|
||||
func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn {
|
||||
outgoing := make(chan *irc.Message, 64)
|
||||
logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}
|
||||
dc := &downstreamConn{
|
||||
conn: *newConn(srv, netConn, logger),
|
||||
id: id,
|
||||
net: netConn,
|
||||
irc: irc.NewConn(netConn),
|
||||
srv: srv,
|
||||
logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
|
||||
outgoing: outgoing,
|
||||
closed: make(chan struct{}),
|
||||
ringConsumers: make(map[*network]*RingConsumer),
|
||||
caps: make(map[string]bool),
|
||||
ourMessages: make(map[*irc.Message]struct{}),
|
||||
@ -101,30 +92,6 @@ func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn
|
||||
if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
|
||||
dc.hostname = host
|
||||
}
|
||||
|
||||
go func() {
|
||||
for msg := range outgoing {
|
||||
if dc.srv.Debug {
|
||||
dc.logger.Printf("sent: %v", msg)
|
||||
}
|
||||
dc.net.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if err := dc.irc.WriteMessage(msg); err != nil {
|
||||
dc.logger.Printf("failed to write message: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := dc.net.Close(); err != nil {
|
||||
dc.logger.Printf("failed to close connection: %v", err)
|
||||
} else {
|
||||
dc.logger.Printf("connection closed")
|
||||
}
|
||||
// Drain the outgoing channel to prevent SendMessage from blocking
|
||||
for range outgoing {
|
||||
// This space is intentionally left blank
|
||||
}
|
||||
}()
|
||||
|
||||
dc.logger.Printf("new connection")
|
||||
return dc
|
||||
}
|
||||
|
||||
@ -227,56 +194,24 @@ func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) isClosed() bool {
|
||||
select {
|
||||
case <-dc.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) readMessages(ch chan<- event) error {
|
||||
for {
|
||||
msg, err := dc.irc.ReadMessage()
|
||||
msg, err := dc.ReadMessage()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to read IRC command: %v", err)
|
||||
}
|
||||
|
||||
if dc.srv.Debug {
|
||||
dc.logger.Printf("received: %v", msg)
|
||||
}
|
||||
|
||||
ch <- eventDownstreamMessage{msg, dc}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) writeMessages() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection. It is safe to call from any goroutine.
|
||||
func (dc *downstreamConn) Close() error {
|
||||
if dc.isClosed() {
|
||||
return fmt.Errorf("downstream connection already closed")
|
||||
}
|
||||
close(dc.closed)
|
||||
close(dc.outgoing)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage queues a new outgoing message. It is safe to call from any
|
||||
// goroutine.
|
||||
func (dc *downstreamConn) SendMessage(msg *irc.Message) {
|
||||
if dc.isClosed() {
|
||||
return
|
||||
}
|
||||
// TODO: strip tags if the client doesn't support them (see runNetwork)
|
||||
dc.outgoing <- msg
|
||||
dc.conn.SendMessage(msg)
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
|
||||
|
13
server.go
13
server.go
@ -16,17 +16,6 @@ var retryConnectMinDelay = time.Minute
|
||||
var connectTimeout = 15 * time.Second
|
||||
var writeTimeout = 10 * time.Second
|
||||
|
||||
func setKeepAlive(c net.Conn) error {
|
||||
tcpConn, ok := c.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot enable keep-alive on a non-TCP connection")
|
||||
}
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return err
|
||||
}
|
||||
return tcpConn.SetKeepAlivePeriod(keepAlivePeriod)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Print(v ...interface{})
|
||||
Printf(format string, v ...interface{})
|
||||
@ -109,8 +98,6 @@ func (s *Server) Serve(ln net.Listener) error {
|
||||
return fmt.Errorf("failed to accept connection: %v", err)
|
||||
}
|
||||
|
||||
setKeepAlive(netConn)
|
||||
|
||||
dc := newDownstreamConn(s, netConn, nextDownstreamID)
|
||||
nextDownstreamID++
|
||||
go func() {
|
||||
|
75
upstream.go
75
upstream.go
@ -31,14 +31,10 @@ type upstreamChannel struct {
|
||||
}
|
||||
|
||||
type upstreamConn struct {
|
||||
network *network
|
||||
logger Logger
|
||||
net net.Conn
|
||||
irc *irc.Conn
|
||||
srv *Server
|
||||
user *user
|
||||
outgoing chan<- *irc.Message
|
||||
closed chan struct{}
|
||||
conn
|
||||
|
||||
network *network
|
||||
user *user
|
||||
|
||||
serverName string
|
||||
availableUserModes string
|
||||
@ -90,18 +86,10 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||
}
|
||||
|
||||
setKeepAlive(netConn)
|
||||
|
||||
outgoing := make(chan *irc.Message, 64)
|
||||
uc := &upstreamConn{
|
||||
conn: *newConn(network.user.srv, netConn, logger),
|
||||
network: network,
|
||||
logger: logger,
|
||||
net: netConn,
|
||||
irc: irc.NewConn(netConn),
|
||||
srv: network.user.srv,
|
||||
user: network.user,
|
||||
outgoing: outgoing,
|
||||
closed: make(chan struct{}),
|
||||
channels: make(map[string]*upstreamChannel),
|
||||
caps: make(map[string]string),
|
||||
batches: make(map[string]batch),
|
||||
@ -112,50 +100,9 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||
logs: make(map[string]entityLog),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for msg := range outgoing {
|
||||
if uc.srv.Debug {
|
||||
uc.logger.Printf("sent: %v", msg)
|
||||
}
|
||||
uc.net.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if err := uc.irc.WriteMessage(msg); err != nil {
|
||||
uc.logger.Printf("failed to write message: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := uc.net.Close(); err != nil {
|
||||
uc.logger.Printf("failed to close connection: %v", err)
|
||||
} else {
|
||||
uc.logger.Printf("connection closed")
|
||||
}
|
||||
// Drain the outgoing channel to prevent SendMessage from blocking
|
||||
for range outgoing {
|
||||
// This space is intentionally left blank
|
||||
}
|
||||
}()
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) isClosed() bool {
|
||||
select {
|
||||
case <-uc.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection. It is safe to call from any goroutine.
|
||||
func (uc *upstreamConn) Close() error {
|
||||
if uc.isClosed() {
|
||||
return fmt.Errorf("upstream connection already closed")
|
||||
}
|
||||
close(uc.closed)
|
||||
close(uc.outgoing)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
|
||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||
if dc.network != nil && dc.network != uc.network {
|
||||
@ -1409,29 +1356,19 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
|
||||
|
||||
func (uc *upstreamConn) readMessages(ch chan<- event) error {
|
||||
for {
|
||||
msg, err := uc.irc.ReadMessage()
|
||||
msg, err := uc.ReadMessage()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to read IRC command: %v", err)
|
||||
}
|
||||
|
||||
if uc.srv.Debug {
|
||||
uc.logger.Printf("received: %v", msg)
|
||||
}
|
||||
|
||||
ch <- eventUpstreamMessage{msg, uc}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage queues a new outgoing message. It is safe to call from any
|
||||
// goroutine.
|
||||
func (uc *upstreamConn) SendMessage(msg *irc.Message) {
|
||||
uc.outgoing <- msg
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
|
||||
if uc.labelsSupported {
|
||||
if msg.Tags == nil {
|
||||
|
Loading…
Reference in New Issue
Block a user