Per-user dispatcher goroutine
This allows message handlers to read upstream/downstream connection information without causing any race condition. References: https://todo.sr.ht/~emersion/soju/1
This commit is contained in:
parent
cdab0dc825
commit
3919ee2036
@ -191,7 +191,7 @@ func (dc *downstreamConn) isClosed() bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) readMessages() error {
|
func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
|
||||||
dc.logger.Printf("new connection")
|
dc.logger.Printf("new connection")
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -206,17 +206,7 @@ func (dc *downstreamConn) readMessages() error {
|
|||||||
dc.logger.Printf("received: %v", msg)
|
dc.logger.Printf("received: %v", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dc.handleMessage(msg)
|
ch <- downstreamIncomingMessage{msg, dc}
|
||||||
if ircErr, ok := err.(ircError); ok {
|
|
||||||
ircErr.Message.Prefix = dc.srv.prefix()
|
|
||||||
dc.SendMessage(ircErr.Message)
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dc.isClosed() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -484,6 +474,27 @@ func (dc *downstreamConn) register() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dc *downstreamConn) runUntilRegistered() error {
|
||||||
|
for !dc.registered {
|
||||||
|
msg, err := dc.irc.ReadMessage()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to read IRC command: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dc.handleMessage(msg)
|
||||||
|
if ircErr, ok := err.(ircError); ok {
|
||||||
|
ircErr.Message.Prefix = dc.srv.prefix()
|
||||||
|
dc.SendMessage(ircErr.Message)
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||||
switch msg.Command {
|
switch msg.Command {
|
||||||
case "USER":
|
case "USER":
|
||||||
|
@ -114,8 +114,12 @@ func (s *Server) Serve(ln net.Listener) error {
|
|||||||
s.downstreamConns = append(s.downstreamConns, dc)
|
s.downstreamConns = append(s.downstreamConns, dc)
|
||||||
s.lock.Unlock()
|
s.lock.Unlock()
|
||||||
|
|
||||||
if err := dc.readMessages(); err != nil {
|
if err := dc.runUntilRegistered(); err != nil {
|
||||||
dc.logger.Printf("failed to handle messages: %v", err)
|
dc.logger.Print(err)
|
||||||
|
} else {
|
||||||
|
if err := dc.readMessages(dc.user.downstreamIncoming); err != nil {
|
||||||
|
dc.logger.Print(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
dc.Close()
|
dc.Close()
|
||||||
|
|
||||||
|
@ -659,7 +659,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *upstreamConn) readMessages() error {
|
func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
|
||||||
for {
|
for {
|
||||||
msg, err := uc.irc.ReadMessage()
|
msg, err := uc.irc.ReadMessage()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@ -672,9 +672,7 @@ func (uc *upstreamConn) readMessages() error {
|
|||||||
uc.logger.Printf("received: %v", msg)
|
uc.logger.Printf("received: %v", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := uc.handleMessage(msg); err != nil {
|
ch <- upstreamIncomingMessage{msg, uc}
|
||||||
uc.logger.Printf("failed to handle message %q: %v", msg, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
43
user.go
43
user.go
@ -3,8 +3,20 @@ package soju
|
|||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/irc.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type upstreamIncomingMessage struct {
|
||||||
|
msg *irc.Message
|
||||||
|
uc *upstreamConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type downstreamIncomingMessage struct {
|
||||||
|
msg *irc.Message
|
||||||
|
dc *downstreamConn
|
||||||
|
}
|
||||||
|
|
||||||
type network struct {
|
type network struct {
|
||||||
Network
|
Network
|
||||||
user *user
|
user *user
|
||||||
@ -40,7 +52,7 @@ func (net *network) run() {
|
|||||||
net.conn = uc
|
net.conn = uc
|
||||||
net.user.lock.Unlock()
|
net.user.lock.Unlock()
|
||||||
|
|
||||||
if err := uc.readMessages(); err != nil {
|
if err := uc.readMessages(net.user.upstreamIncoming); err != nil {
|
||||||
uc.logger.Printf("failed to handle messages: %v", err)
|
uc.logger.Printf("failed to handle messages: %v", err)
|
||||||
}
|
}
|
||||||
uc.Close()
|
uc.Close()
|
||||||
@ -55,6 +67,9 @@ type user struct {
|
|||||||
User
|
User
|
||||||
srv *Server
|
srv *Server
|
||||||
|
|
||||||
|
upstreamIncoming chan upstreamIncomingMessage
|
||||||
|
downstreamIncoming chan downstreamIncomingMessage
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
networks []*network
|
networks []*network
|
||||||
downstreamConns []*downstreamConn
|
downstreamConns []*downstreamConn
|
||||||
@ -62,8 +77,10 @@ type user struct {
|
|||||||
|
|
||||||
func newUser(srv *Server, record *User) *user {
|
func newUser(srv *Server, record *User) *user {
|
||||||
return &user{
|
return &user{
|
||||||
User: *record,
|
User: *record,
|
||||||
srv: srv,
|
srv: srv,
|
||||||
|
upstreamIncoming: make(chan upstreamIncomingMessage, 64),
|
||||||
|
downstreamIncoming: make(chan downstreamIncomingMessage, 64),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,6 +136,26 @@ func (u *user) run() {
|
|||||||
go network.run()
|
go network.run()
|
||||||
}
|
}
|
||||||
u.lock.Unlock()
|
u.lock.Unlock()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case upstreamMsg := <-u.upstreamIncoming:
|
||||||
|
msg, uc := upstreamMsg.msg, upstreamMsg.uc
|
||||||
|
if err := uc.handleMessage(msg); err != nil {
|
||||||
|
uc.logger.Printf("failed to handle message %q: %v", msg, err)
|
||||||
|
}
|
||||||
|
case downstreamMsg := <-u.downstreamIncoming:
|
||||||
|
msg, dc := downstreamMsg.msg, downstreamMsg.dc
|
||||||
|
err := dc.handleMessage(msg)
|
||||||
|
if ircErr, ok := err.(ircError); ok {
|
||||||
|
ircErr.Message.Prefix = dc.srv.prefix()
|
||||||
|
dc.SendMessage(ircErr.Message)
|
||||||
|
} else if err != nil {
|
||||||
|
dc.logger.Printf("failed to handle message %q: %v", msg, err)
|
||||||
|
dc.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) createNetwork(addr, nick string) (*network, error) {
|
func (u *user) createNetwork(addr, nick string) (*network, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user