Add functions to translate between upstream and downstream names

This commit is contained in:
Simon Ser 2020-02-19 18:25:19 +01:00
parent ef2d145d1f
commit 70fcef297b
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 153 additions and 68 deletions

View File

@ -9,23 +9,25 @@ func forwardChannel(dc *downstreamConn, ch *upstreamChannel) {
panic("Tried to forward a partial channel") panic("Tried to forward a partial channel")
} }
downstreamName := dc.marshalChannel(ch.conn, ch.Name)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.prefix(), Prefix: dc.prefix(),
Command: "JOIN", Command: "JOIN",
Params: []string{ch.Name}, Params: []string{downstreamName},
}) })
if ch.Topic != "" { if ch.Topic != "" {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_TOPIC, Command: irc.RPL_TOPIC,
Params: []string{dc.nick, ch.Name, ch.Topic}, Params: []string{dc.nick, downstreamName, ch.Topic},
}) })
} else { } else {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_NOTOPIC, Command: irc.RPL_NOTOPIC,
Params: []string{dc.nick, ch.Name, "No topic is set"}, Params: []string{dc.nick, downstreamName, "No topic is set"},
}) })
} }
@ -33,21 +35,21 @@ func forwardChannel(dc *downstreamConn, ch *upstreamChannel) {
// TODO: send multiple members in each message // TODO: send multiple members in each message
for nick, membership := range ch.Members { for nick, membership := range ch.Members {
s := nick s := dc.marshalNick(ch.conn, nick)
if membership != 0 { if membership != 0 {
s = string(membership) + nick s = string(membership) + s
} }
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_NAMREPLY, Command: irc.RPL_NAMREPLY,
Params: []string{dc.nick, string(ch.Status), ch.Name, s}, Params: []string{dc.nick, string(ch.Status), downstreamName, s},
}) })
} }
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_ENDOFNAMES, Command: irc.RPL_ENDOFNAMES,
Params: []string{dc.nick, ch.Name, "End of /NAMES list"}, Params: []string{dc.nick, downstreamName, "End of /NAMES list"},
}) })
} }

View File

@ -39,14 +39,19 @@ func (err ircError) Error() string {
return err.Message.String() return err.Message.String()
} }
type consumption struct {
consumer *RingConsumer
upstreamConn *upstreamConn
}
type downstreamConn struct { type downstreamConn struct {
net net.Conn net net.Conn
irc *irc.Conn irc *irc.Conn
srv *Server srv *Server
logger Logger logger Logger
messages chan *irc.Message messages chan *irc.Message
consumers chan *RingConsumer consumptions chan consumption
closed chan struct{} closed chan struct{}
registered bool registered bool
user *user user *user
@ -57,13 +62,13 @@ type downstreamConn struct {
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
dc := &downstreamConn{ dc := &downstreamConn{
net: netConn, net: netConn,
irc: irc.NewConn(netConn), irc: irc.NewConn(netConn),
srv: srv, srv: srv,
logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
messages: make(chan *irc.Message, 64), messages: make(chan *irc.Message, 64),
consumers: make(chan *RingConsumer), consumptions: make(chan consumption),
closed: make(chan struct{}), closed: make(chan struct{}),
} }
go func() { go func() {
@ -88,6 +93,33 @@ func (dc *downstreamConn) prefix() *irc.Prefix {
} }
} }
func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
return name
}
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
// TODO: extract network name from channel name
ch, err := dc.user.getChannel(name)
if err != nil {
return nil, "", err
}
return ch.conn, ch.Name, nil
}
func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
if nick == uc.nick {
return dc.nick
}
return nick
}
func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
if prefix.Name == uc.nick {
return dc.prefix()
}
return prefix
}
func (dc *downstreamConn) isClosed() bool { func (dc *downstreamConn) isClosed() bool {
select { select {
case <-dc.closed: case <-dc.closed:
@ -138,12 +170,21 @@ func (dc *downstreamConn) writeMessages() error {
dc.logger.Printf("sent: %v", msg) dc.logger.Printf("sent: %v", msg)
} }
err = dc.irc.WriteMessage(msg) err = dc.irc.WriteMessage(msg)
case consumer := <-dc.consumers: case consumption := <-dc.consumptions:
consumer, uc := consumption.consumer, consumption.upstreamConn
for { for {
msg := consumer.Peek() msg := consumer.Peek()
if msg == nil { if msg == nil {
break break
} }
msg = msg.Copy()
switch msg.Command {
case "PRIVMSG":
// TODO: detect whether it's a user or a channel
msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
default:
panic("expected to consume a PRIVMSG message")
}
if dc.srv.Debug { if dc.srv.Debug {
dc.logger.Printf("sent: %v", msg) dc.logger.Printf("sent: %v", msg)
} }
@ -303,7 +344,7 @@ func (dc *downstreamConn) register() error {
var closed bool var closed bool
select { select {
case <-ch: case <-ch:
dc.consumers <- consumer dc.consumptions <- consumption{consumer, uc}
case <-dc.closed: case <-dc.closed:
closed = true closed = true
} }
@ -338,35 +379,30 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.user.forEachUpstream(func(uc *upstreamConn) { dc.user.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessage(msg) uc.SendMessage(msg)
}) })
case "JOIN": case "JOIN", "PART":
var name string var name string
if err := parseMessageParams(msg, &name); err != nil { if err := parseMessageParams(msg, &name); err != nil {
return err return err
} }
if ch, _ := dc.user.getChannel(name); ch != nil { uc, upstreamName, err := dc.unmarshalChannel(name)
break // already joined
}
// TODO: extract network name from channel name
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{name, "Channel name ambiguous"},
}}
case "PART":
var name string
if err := parseMessageParams(msg, &name); err != nil {
return err
}
ch, err := dc.user.getChannel(name)
if err != nil { if err != nil {
return err return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{name, err.Error()},
}}
} }
ch.conn.SendMessage(msg) uc.SendMessage(&irc.Message{
// TODO: remove channel from upstream config Command: msg.Command,
Params: []string{upstreamName},
})
// TODO: add/remove channel from upstream config
case "MODE": case "MODE":
if msg.Prefix == nil {
return fmt.Errorf("missing prefix")
}
var name string var name string
if err := parseMessageParams(msg, &name); err != nil { if err := parseMessageParams(msg, &name); err != nil {
return err return err
@ -378,18 +414,30 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
} }
if msg.Prefix.Name != name { if msg.Prefix.Name != name {
ch, err := dc.user.getChannel(name) uc, upstreamName, err := dc.unmarshalChannel(name)
if err != nil { if err != nil {
return err return err
} }
if modeStr != "" { if modeStr != "" {
ch.conn.SendMessage(msg) uc.SendMessage(&irc.Message{
Prefix: uc.prefix(),
Command: "MODE",
Params: []string{upstreamName, modeStr},
})
} else { } else {
ch, ok := uc.channels[upstreamName]
if !ok {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{name, "No such channel"},
}}
}
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_CHANNELMODEIS, Command: irc.RPL_CHANNELMODEIS,
Params: []string{ch.Name, string(ch.modes)}, Params: []string{name, string(ch.modes)},
}) })
} }
} else { } else {
@ -402,7 +450,11 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
if modeStr != "" { if modeStr != "" {
dc.user.forEachUpstream(func(uc *upstreamConn) { dc.user.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessage(msg) uc.SendMessage(&irc.Message{
Prefix: uc.prefix(),
Command: "MODE",
Params: []string{uc.nick, modeStr},
})
}) })
} else { } else {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -419,15 +471,15 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
} }
for _, name := range strings.Split(targetsStr, ",") { for _, name := range strings.Split(targetsStr, ",") {
ch, err := dc.user.getChannel(name) uc, upstreamName, err := dc.unmarshalChannel(name)
if err != nil { if err != nil {
return err return err
} }
ch.conn.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Prefix: msg.Prefix, Prefix: uc.prefix(),
Command: "PRIVMSG", Command: "PRIVMSG",
Params: []string{name, text}, Params: []string{upstreamName, text},
}) })
} }
default: default:

View File

@ -91,6 +91,14 @@ func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) {
return uc, nil return uc, nil
} }
func (uc *upstreamConn) prefix() *irc.Prefix {
return &irc.Prefix{
Name: uc.nick,
User: uc.upstream.Username,
// TODO: fill the host?
}
}
func (uc *upstreamConn) Close() error { func (uc *upstreamConn) Close() error {
if uc.closed { if uc.closed {
return fmt.Errorf("upstream connection already closed") return fmt.Errorf("upstream connection already closed")
@ -117,6 +125,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}) })
return nil return nil
case "MODE": case "MODE":
if msg.Prefix == nil {
return fmt.Errorf("missing prefix")
}
var name, modeStr string var name, modeStr string
if err := parseMessageParams(msg, &name, &modeStr); err != nil { if err := parseMessageParams(msg, &name, &modeStr); err != nil {
return err return err
@ -135,11 +147,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if err := ch.modes.Apply(modeStr); err != nil { if err := ch.modes.Apply(modeStr); err != nil {
return err return err
} }
}
uc.user.forEachDownstream(func(dc *downstreamConn) { uc.user.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(msg) dc.SendMessage(&irc.Message{
}) Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "MODE",
Params: []string{dc.marshalChannel(uc, name), modeStr},
})
})
}
case "NOTICE": case "NOTICE":
uc.logger.Print(msg) uc.logger.Print(msg)
case irc.RPL_WELCOME: case irc.RPL_WELCOME:
@ -176,11 +192,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
ch.Members[newNick] = membership ch.Members[newNick] = membership
} }
} }
uc.user.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(msg)
})
case "JOIN": case "JOIN":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var channels string var channels string
if err := parseMessageParams(msg, &channels); err != nil { if err := parseMessageParams(msg, &channels); err != nil {
return err return err
@ -201,12 +217,20 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
ch.Members[msg.Prefix.Name] = 0 ch.Members[msg.Prefix.Name] = 0
} }
uc.user.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "JOIN",
Params: []string{dc.marshalChannel(uc, ch)},
})
})
}
case "PART":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
} }
uc.user.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(msg)
})
case "PART":
var channels string var channels string
if err := parseMessageParams(msg, &channels); err != nil { if err := parseMessageParams(msg, &channels); err != nil {
return err return err
@ -223,11 +247,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
delete(ch.Members, msg.Prefix.Name) delete(ch.Members, msg.Prefix.Name)
} }
}
uc.user.forEachDownstream(func(dc *downstreamConn) { uc.user.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(msg) dc.SendMessage(&irc.Message{
}) Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "PART",
Params: []string{dc.marshalChannel(uc, ch)},
})
})
}
case irc.RPL_TOPIC, irc.RPL_NOTOPIC: case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
var name, topic string var name, topic string
if err := parseMessageParams(msg, nil, &name, &topic); err != nil { if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
@ -310,6 +338,9 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
forwardChannel(dc, ch) forwardChannel(dc, ch)
}) })
case "PRIVMSG": case "PRIVMSG":
if err := parseMessageParams(msg, nil, nil); err != nil {
return err
}
uc.ring.Produce(msg) uc.ring.Produce(msg)
case irc.RPL_YOURHOST, irc.RPL_CREATED: case irc.RPL_YOURHOST, irc.RPL_CREATED:
// Ignore // Ignore
@ -331,7 +362,7 @@ func (uc *upstreamConn) register() {
uc.nick = uc.upstream.Nick uc.nick = uc.upstream.Nick
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "NICK", Command: "NICK",
Params: []string{uc.upstream.Nick}, Params: []string{uc.nick},
}) })
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "USER", Command: "USER",