Allow clients to specify an upstream name in their username

This commit is contained in:
Simon Ser 2020-03-04 15:44:13 +01:00
parent d1550a3cdb
commit c22ce793a1
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 85 additions and 17 deletions

View File

@ -58,6 +58,7 @@ type downstreamConn struct {
nick string nick string
username string username string
realname string realname string
upstream *Upstream
} }
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
@ -97,13 +98,38 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
return name return name
} }
func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
dc.user.forEachUpstream(func(uc *upstreamConn) {
if dc.upstream != nil && uc.upstream != dc.upstream {
return
}
f(uc)
})
}
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) { func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
// TODO: extract network name from channel name // TODO: extract network name from channel name if dc.upstream == nil
ch, err := dc.user.getChannel(name) var channel *upstreamChannel
if err != nil { var err error
return nil, "", err dc.forEachUpstream(func(uc *upstreamConn) {
if err != nil {
return
}
if ch, ok := uc.channels[name]; ok {
if channel != nil {
err = fmt.Errorf("ambiguous channel name %q", name)
} else {
channel = ch
}
}
})
if channel == nil {
return nil, "", ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{name, "No such channel"},
}}
} }
return ch.conn, ch.Name, nil return channel.conn, channel.Name, nil
} }
func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string { func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
@ -274,9 +300,18 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
} }
func (dc *downstreamConn) register() error { func (dc *downstreamConn) register() error {
u := dc.srv.getUser(strings.TrimPrefix(dc.username, "~")) username := strings.TrimPrefix(dc.username, "~")
var network string
if i := strings.LastIndexAny(username, "/@"); i >= 0 {
network = username[i+1:]
}
if i := strings.IndexAny(username, "/@"); i >= 0 {
username = username[:i]
}
u := dc.srv.getUser(username)
if u == nil { if u == nil {
dc.logger.Printf("failed authentication: unknown username %q", dc.username) dc.logger.Printf("failed authentication: unknown username %q", username)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_PASSWDMISMATCH, Command: irc.ERR_PASSWDMISMATCH,
@ -285,6 +320,19 @@ func (dc *downstreamConn) register() error {
return nil return nil
} }
if network != "" {
dc.upstream = dc.srv.getUpstream(network)
if dc.upstream == nil {
dc.logger.Printf("failed registration: unknown upstream %q", network)
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", network)},
})
return nil
}
}
dc.registered = true dc.registered = true
dc.user = u dc.user = u
@ -319,7 +367,7 @@ func (dc *downstreamConn) register() error {
Params: []string{dc.nick, "No MOTD"}, Params: []string{dc.nick, "No MOTD"},
}) })
u.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
// TODO: fix races accessing upstream connection data // TODO: fix races accessing upstream connection data
for _, ch := range uc.channels { for _, ch := range uc.channels {
if ch.complete { if ch.complete {
@ -327,8 +375,7 @@ func (dc *downstreamConn) register() error {
} }
} }
// TODO: let clients specify the ring buffer name in their username historyName := dc.username
historyName := ""
var seqPtr *uint64 var seqPtr *uint64
if firstDownstream { if firstDownstream {
@ -376,7 +423,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: []string{dc.nick, "You may not reregister"}, Params: []string{dc.nick, "You may not reregister"},
}} }}
case "NICK": case "NICK":
dc.user.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessage(msg) uc.SendMessage(msg)
}) })
case "JOIN", "PART": case "JOIN", "PART":
@ -448,7 +495,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
} }
if modeStr != "" { if modeStr != "" {
dc.user.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "MODE", Command: "MODE",
Params: []string{uc.nick, modeStr}, Params: []string{uc.nick, modeStr},

View File

@ -82,13 +82,16 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
u.lock.Unlock() u.lock.Unlock()
} }
func (u *user) getChannel(name string) (*upstreamChannel, error) { func (u *user) getChannel(name string, upstream *Upstream) (*upstreamChannel, error) {
var channel *upstreamChannel var channel *upstreamChannel
var err error var err error
u.forEachUpstream(func(uc *upstreamConn) { u.forEachUpstream(func(uc *upstreamConn) {
if err != nil { if err != nil {
return return
} }
if upstream != nil && uc.upstream != upstream {
return
}
if ch, ok := uc.channels[name]; ok { if ch, ok := uc.channels[name]; ok {
if channel != nil { if channel != nil {
err = fmt.Errorf("ambiguous channel name %q", name) err = fmt.Errorf("ambiguous channel name %q", name)
@ -196,6 +199,15 @@ func (s *Server) getUser(name string) *user {
return u return u
} }
func (s *Server) getUpstream(name string) *Upstream {
for i, upstream := range s.Upstreams {
if upstream.Addr == name {
return &s.Upstreams[i]
}
}
return nil
}
func (s *Server) Serve(ln net.Listener) error { func (s *Server) Serve(ln net.Listener) error {
for { for {
netConn, err := ln.Accept() netConn, err := ln.Accept()

View File

@ -100,6 +100,15 @@ func (uc *upstreamConn) Close() error {
return nil return nil
} }
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
uc.user.forEachDownstream(func(dc *downstreamConn) {
if dc.upstream != nil && dc.upstream != uc.upstream {
return
}
f(dc)
})
}
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch, ok := uc.channels[name] ch, ok := uc.channels[name]
if !ok { if !ok {
@ -140,7 +149,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
uc.user.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix), Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "MODE", Command: "MODE",
@ -210,7 +219,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
ch.Members[msg.Prefix.Name] = 0 ch.Members[msg.Prefix.Name] = 0
} }
uc.user.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix), Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "JOIN", Command: "JOIN",
@ -240,7 +249,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
delete(ch.Members, msg.Prefix.Name) delete(ch.Members, msg.Prefix.Name)
} }
uc.user.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix), Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "PART", Command: "PART",
@ -326,7 +335,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
ch.complete = true ch.complete = true
uc.user.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
forwardChannel(dc, ch) forwardChannel(dc, ch)
}) })
case "PRIVMSG": case "PRIVMSG":