From c22ce793a176ab972f6642d0c0c3a34fcb70faae Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 4 Mar 2020 15:44:13 +0100 Subject: [PATCH] Allow clients to specify an upstream name in their username --- downstream.go | 71 ++++++++++++++++++++++++++++++++++++++++++--------- server.go | 14 +++++++++- upstream.go | 17 +++++++++--- 3 files changed, 85 insertions(+), 17 deletions(-) diff --git a/downstream.go b/downstream.go index aaf482e..bddeaa2 100644 --- a/downstream.go +++ b/downstream.go @@ -58,6 +58,7 @@ type downstreamConn struct { nick string username string realname string + upstream *Upstream } func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { @@ -97,13 +98,38 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string { return name } +func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { + dc.user.forEachUpstream(func(uc *upstreamConn) { + if dc.upstream != nil && uc.upstream != dc.upstream { + return + } + f(uc) + }) +} + func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) { - // TODO: extract network name from channel name - ch, err := dc.user.getChannel(name) - if err != nil { - return nil, "", err + // TODO: extract network name from channel name if dc.upstream == nil + var channel *upstreamChannel + var err error + dc.forEachUpstream(func(uc *upstreamConn) { + if err != nil { + return + } + if ch, ok := uc.channels[name]; ok { + if channel != nil { + err = fmt.Errorf("ambiguous channel name %q", name) + } else { + channel = ch + } + } + }) + if channel == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, "No such channel"}, + }} } - return ch.conn, ch.Name, nil + return channel.conn, channel.Name, nil } func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string { @@ -274,9 +300,18 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { } func (dc *downstreamConn) register() error { - u := dc.srv.getUser(strings.TrimPrefix(dc.username, "~")) + username := strings.TrimPrefix(dc.username, "~") + var network string + if i := strings.LastIndexAny(username, "/@"); i >= 0 { + network = username[i+1:] + } + if i := strings.IndexAny(username, "/@"); i >= 0 { + username = username[:i] + } + + u := dc.srv.getUser(username) if u == nil { - dc.logger.Printf("failed authentication: unknown username %q", dc.username) + dc.logger.Printf("failed authentication: unknown username %q", username) dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_PASSWDMISMATCH, @@ -285,6 +320,19 @@ func (dc *downstreamConn) register() error { return nil } + if network != "" { + dc.upstream = dc.srv.getUpstream(network) + if dc.upstream == nil { + dc.logger.Printf("failed registration: unknown upstream %q", network) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", network)}, + }) + return nil + } + } + dc.registered = true dc.user = u @@ -319,7 +367,7 @@ func (dc *downstreamConn) register() error { Params: []string{dc.nick, "No MOTD"}, }) - u.forEachUpstream(func(uc *upstreamConn) { + dc.forEachUpstream(func(uc *upstreamConn) { // TODO: fix races accessing upstream connection data for _, ch := range uc.channels { if ch.complete { @@ -327,8 +375,7 @@ func (dc *downstreamConn) register() error { } } - // TODO: let clients specify the ring buffer name in their username - historyName := "" + historyName := dc.username var seqPtr *uint64 if firstDownstream { @@ -376,7 +423,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: []string{dc.nick, "You may not reregister"}, }} case "NICK": - dc.user.forEachUpstream(func(uc *upstreamConn) { + dc.forEachUpstream(func(uc *upstreamConn) { uc.SendMessage(msg) }) case "JOIN", "PART": @@ -448,7 +495,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } if modeStr != "" { - dc.user.forEachUpstream(func(uc *upstreamConn) { + dc.forEachUpstream(func(uc *upstreamConn) { uc.SendMessage(&irc.Message{ Command: "MODE", Params: []string{uc.nick, modeStr}, diff --git a/server.go b/server.go index fde8ca5..4d09240 100644 --- a/server.go +++ b/server.go @@ -82,13 +82,16 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) { u.lock.Unlock() } -func (u *user) getChannel(name string) (*upstreamChannel, error) { +func (u *user) getChannel(name string, upstream *Upstream) (*upstreamChannel, error) { var channel *upstreamChannel var err error u.forEachUpstream(func(uc *upstreamConn) { if err != nil { return } + if upstream != nil && uc.upstream != upstream { + return + } if ch, ok := uc.channels[name]; ok { if channel != nil { err = fmt.Errorf("ambiguous channel name %q", name) @@ -196,6 +199,15 @@ func (s *Server) getUser(name string) *user { return u } +func (s *Server) getUpstream(name string) *Upstream { + for i, upstream := range s.Upstreams { + if upstream.Addr == name { + return &s.Upstreams[i] + } + } + return nil +} + func (s *Server) Serve(ln net.Listener) error { for { netConn, err := ln.Accept() diff --git a/upstream.go b/upstream.go index 0ce0bcb..95cea44 100644 --- a/upstream.go +++ b/upstream.go @@ -100,6 +100,15 @@ func (uc *upstreamConn) Close() error { return nil } +func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { + uc.user.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream != nil && dc.upstream != uc.upstream { + return + } + f(dc) + }) +} + func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { ch, ok := uc.channels[name] if !ok { @@ -140,7 +149,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - uc.user.forEachDownstream(func(dc *downstreamConn) { + uc.forEachDownstream(func(dc *downstreamConn) { dc.SendMessage(&irc.Message{ Prefix: dc.marshalUserPrefix(uc, msg.Prefix), Command: "MODE", @@ -210,7 +219,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { ch.Members[msg.Prefix.Name] = 0 } - uc.user.forEachDownstream(func(dc *downstreamConn) { + uc.forEachDownstream(func(dc *downstreamConn) { dc.SendMessage(&irc.Message{ Prefix: dc.marshalUserPrefix(uc, msg.Prefix), Command: "JOIN", @@ -240,7 +249,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { delete(ch.Members, msg.Prefix.Name) } - uc.user.forEachDownstream(func(dc *downstreamConn) { + uc.forEachDownstream(func(dc *downstreamConn) { dc.SendMessage(&irc.Message{ Prefix: dc.marshalUserPrefix(uc, msg.Prefix), Command: "PART", @@ -326,7 +335,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } ch.complete = true - uc.user.forEachDownstream(func(dc *downstreamConn) { + uc.forEachDownstream(func(dc *downstreamConn) { forwardChannel(dc, ch) }) case "PRIVMSG":