66aea1b4a2
This avoids blocking on upstream message rate limiting for too long.
268 lines
6.1 KiB
Go
268 lines
6.1 KiB
Go
package soju
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unicode"
|
|
|
|
"golang.org/x/time/rate"
|
|
"gopkg.in/irc.v3"
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
|
|
// reading and writing IRC messages.
|
|
type ircConn interface {
|
|
ReadMessage() (*irc.Message, error)
|
|
WriteMessage(*irc.Message) error
|
|
Close() error
|
|
SetReadDeadline(time.Time) error
|
|
SetWriteDeadline(time.Time) error
|
|
RemoteAddr() net.Addr
|
|
LocalAddr() net.Addr
|
|
}
|
|
|
|
func newNetIRCConn(c net.Conn) ircConn {
|
|
type netConn net.Conn
|
|
return struct {
|
|
*irc.Conn
|
|
netConn
|
|
}{irc.NewConn(c), c}
|
|
}
|
|
|
|
type websocketIRCConn struct {
|
|
conn *websocket.Conn
|
|
readDeadline, writeDeadline time.Time
|
|
remoteAddr string
|
|
}
|
|
|
|
func newWebsocketIRCConn(c *websocket.Conn, remoteAddr string) ircConn {
|
|
return &websocketIRCConn{conn: c, remoteAddr: remoteAddr}
|
|
}
|
|
|
|
func (wic *websocketIRCConn) ReadMessage() (*irc.Message, error) {
|
|
ctx := context.Background()
|
|
if !wic.readDeadline.IsZero() {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithDeadline(ctx, wic.readDeadline)
|
|
defer cancel()
|
|
}
|
|
_, b, err := wic.conn.Read(ctx)
|
|
if err != nil {
|
|
switch websocket.CloseStatus(err) {
|
|
case websocket.StatusNormalClosure, websocket.StatusGoingAway:
|
|
return nil, io.EOF
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
return irc.ParseMessage(string(b))
|
|
}
|
|
|
|
func (wic *websocketIRCConn) WriteMessage(msg *irc.Message) error {
|
|
b := []byte(strings.ToValidUTF8(msg.String(), string(unicode.ReplacementChar)))
|
|
ctx := context.Background()
|
|
if !wic.writeDeadline.IsZero() {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithDeadline(ctx, wic.writeDeadline)
|
|
defer cancel()
|
|
}
|
|
return wic.conn.Write(ctx, websocket.MessageText, b)
|
|
}
|
|
|
|
// isErrWebSocketClosed reports whether err is non-nil and its message ends
// with "failed to close WebSocket: already wrote close".
func isErrWebSocketClosed(err error) bool {
	const alreadyClosedSuffix = "failed to close WebSocket: already wrote close"

	if err == nil {
		return false
	}
	return strings.HasSuffix(err.Error(), alreadyClosedSuffix)
}
|
|
|
|
func (wic *websocketIRCConn) Close() error {
|
|
err := wic.conn.Close(websocket.StatusNormalClosure, "")
|
|
// TODO: remove once this PR is merged:
|
|
// https://github.com/nhooyr/websocket/pull/303
|
|
if isErrWebSocketClosed(err) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (wic *websocketIRCConn) SetReadDeadline(t time.Time) error {
|
|
wic.readDeadline = t
|
|
return nil
|
|
}
|
|
|
|
func (wic *websocketIRCConn) SetWriteDeadline(t time.Time) error {
|
|
wic.writeDeadline = t
|
|
return nil
|
|
}
|
|
|
|
func (wic *websocketIRCConn) RemoteAddr() net.Addr {
|
|
return websocketAddr(wic.remoteAddr)
|
|
}
|
|
|
|
func (wic *websocketIRCConn) LocalAddr() net.Addr {
|
|
// Behind a reverse HTTP proxy, we don't have access to the real listening
|
|
// address
|
|
return websocketAddr("")
|
|
}
|
|
|
|
// websocketAddr is a string-backed address with the Network and String
// methods required by net.Addr.
type websocketAddr string

// Network returns the network name, which is always "ws".
func (websocketAddr) Network() string {
	const network = "ws"
	return network
}

// String returns the address text unchanged.
func (addr websocketAddr) String() string {
	text := string(addr)
	return text
}
|
|
|
|
type connOptions struct {
|
|
Logger Logger
|
|
RateLimitDelay time.Duration
|
|
RateLimitBurst int
|
|
}
|
|
|
|
type conn struct {
|
|
conn ircConn
|
|
srv *Server
|
|
logger Logger
|
|
|
|
lock sync.Mutex
|
|
outgoing chan<- *irc.Message
|
|
closed bool
|
|
closedCh chan struct{}
|
|
}
|
|
|
|
func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
|
|
outgoing := make(chan *irc.Message, 64)
|
|
c := &conn{
|
|
conn: ic,
|
|
srv: srv,
|
|
outgoing: outgoing,
|
|
logger: options.Logger,
|
|
closedCh: make(chan struct{}),
|
|
}
|
|
|
|
go func() {
|
|
ctx, cancel := c.NewContext(context.Background())
|
|
defer cancel()
|
|
|
|
rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
|
|
for msg := range outgoing {
|
|
if err := rl.Wait(ctx); err != nil {
|
|
break
|
|
}
|
|
|
|
c.logger.Debugf("sent: %v", msg)
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
|
if err := c.conn.WriteMessage(msg); err != nil {
|
|
c.logger.Printf("failed to write message: %v", err)
|
|
break
|
|
}
|
|
}
|
|
if err := c.conn.Close(); err != nil && !isErrClosed(err) {
|
|
c.logger.Printf("failed to close connection: %v", err)
|
|
} else {
|
|
c.logger.Debugf("connection closed")
|
|
}
|
|
// Drain the outgoing channel to prevent SendMessage from blocking
|
|
for range outgoing {
|
|
// This space is intentionally left blank
|
|
}
|
|
}()
|
|
|
|
c.logger.Debugf("new connection")
|
|
return c
|
|
}
|
|
|
|
func (c *conn) isClosed() bool {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
return c.closed
|
|
}
|
|
|
|
// Close closes the connection. It is safe to call from any goroutine.
|
|
func (c *conn) Close() error {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
if c.closed {
|
|
return fmt.Errorf("connection already closed")
|
|
}
|
|
|
|
err := c.conn.Close()
|
|
c.closed = true
|
|
close(c.outgoing)
|
|
close(c.closedCh)
|
|
return err
|
|
}
|
|
|
|
func (c *conn) ReadMessage() (*irc.Message, error) {
|
|
msg, err := c.conn.ReadMessage()
|
|
if isErrClosed(err) {
|
|
return nil, io.EOF
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.logger.Debugf("received: %v", msg)
|
|
return msg, nil
|
|
}
|
|
|
|
// SendMessage queues a new outgoing message. It is safe to call from any
|
|
// goroutine.
|
|
//
|
|
// If the connection is closed before the message is sent, SendMessage silently
|
|
// drops the message.
|
|
func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
if c.closed {
|
|
return
|
|
}
|
|
|
|
select {
|
|
case c.outgoing <- msg:
|
|
// Success
|
|
case <-ctx.Done():
|
|
c.logger.Printf("failed to send message: %v", ctx.Err())
|
|
}
|
|
}
|
|
|
|
func (c *conn) RemoteAddr() net.Addr {
|
|
return c.conn.RemoteAddr()
|
|
}
|
|
|
|
func (c *conn) LocalAddr() net.Addr {
|
|
return c.conn.LocalAddr()
|
|
}
|
|
|
|
// NewContext returns a copy of the parent context with a new Done channel. The
|
|
// returned context's Done channel is closed when the connection is closed,
|
|
// when the returned cancel function is called, or when the parent context's
|
|
// Done channel is closed, whichever happens first.
|
|
//
|
|
// Canceling this context releases resources associated with it, so code should
|
|
// call cancel as soon as the operations running in this Context complete.
|
|
func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) {
|
|
ctx, cancel := context.WithCancel(parent)
|
|
|
|
go func() {
|
|
defer cancel()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
// The parent context has been cancelled, or the caller has called
|
|
// cancel()
|
|
case <-c.closedCh:
|
|
// The connection has been closed
|
|
}
|
|
}()
|
|
|
|
return ctx, cancel
|
|
}
|