Allow clients to specify an upstream name in their username

This commit is contained in:
Simon Ser 2020-03-04 15:44:13 +01:00
parent d1550a3cdb
commit c22ce793a1
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 85 additions and 17 deletions

View File

@ -58,6 +58,7 @@ type downstreamConn struct {
nick string
username string
realname string
upstream *Upstream
}
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
@ -97,13 +98,38 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
return name
}
func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
dc.user.forEachUpstream(func(uc *upstreamConn) {
if dc.upstream != nil && uc.upstream != dc.upstream {
return
}
f(uc)
})
}
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
// TODO: extract network name from channel name
ch, err := dc.user.getChannel(name)
if err != nil {
return nil, "", err
// TODO: extract network name from channel name if dc.upstream == nil
var channel *upstreamChannel
var err error
dc.forEachUpstream(func(uc *upstreamConn) {
if err != nil {
return
}
if ch, ok := uc.channels[name]; ok {
if channel != nil {
err = fmt.Errorf("ambiguous channel name %q", name)
} else {
channel = ch
}
}
})
if channel == nil {
return nil, "", ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{name, "No such channel"},
}}
}
return ch.conn, ch.Name, nil
return channel.conn, channel.Name, nil
}
func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
@ -274,9 +300,18 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
}
func (dc *downstreamConn) register() error {
u := dc.srv.getUser(strings.TrimPrefix(dc.username, "~"))
username := strings.TrimPrefix(dc.username, "~")
var network string
if i := strings.LastIndexAny(username, "/@"); i >= 0 {
network = username[i+1:]
}
if i := strings.IndexAny(username, "/@"); i >= 0 {
username = username[:i]
}
u := dc.srv.getUser(username)
if u == nil {
dc.logger.Printf("failed authentication: unknown username %q", dc.username)
dc.logger.Printf("failed authentication: unknown username %q", username)
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_PASSWDMISMATCH,
@ -285,6 +320,19 @@ func (dc *downstreamConn) register() error {
return nil
}
if network != "" {
dc.upstream = dc.srv.getUpstream(network)
if dc.upstream == nil {
dc.logger.Printf("failed registration: unknown upstream %q", network)
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", network)},
})
return nil
}
}
dc.registered = true
dc.user = u
@ -319,7 +367,7 @@ func (dc *downstreamConn) register() error {
Params: []string{dc.nick, "No MOTD"},
})
u.forEachUpstream(func(uc *upstreamConn) {
dc.forEachUpstream(func(uc *upstreamConn) {
// TODO: fix races accessing upstream connection data
for _, ch := range uc.channels {
if ch.complete {
@ -327,8 +375,7 @@ func (dc *downstreamConn) register() error {
}
}
// TODO: let clients specify the ring buffer name in their username
historyName := ""
historyName := dc.username
var seqPtr *uint64
if firstDownstream {
@ -376,7 +423,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: []string{dc.nick, "You may not reregister"},
}}
case "NICK":
dc.user.forEachUpstream(func(uc *upstreamConn) {
dc.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessage(msg)
})
case "JOIN", "PART":
@ -448,7 +495,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}
if modeStr != "" {
dc.user.forEachUpstream(func(uc *upstreamConn) {
dc.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessage(&irc.Message{
Command: "MODE",
Params: []string{uc.nick, modeStr},

View File

@ -82,13 +82,16 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
u.lock.Unlock()
}
func (u *user) getChannel(name string) (*upstreamChannel, error) {
func (u *user) getChannel(name string, upstream *Upstream) (*upstreamChannel, error) {
var channel *upstreamChannel
var err error
u.forEachUpstream(func(uc *upstreamConn) {
if err != nil {
return
}
if upstream != nil && uc.upstream != upstream {
return
}
if ch, ok := uc.channels[name]; ok {
if channel != nil {
err = fmt.Errorf("ambiguous channel name %q", name)
@ -196,6 +199,15 @@ func (s *Server) getUser(name string) *user {
return u
}
func (s *Server) getUpstream(name string) *Upstream {
for i, upstream := range s.Upstreams {
if upstream.Addr == name {
return &s.Upstreams[i]
}
}
return nil
}
func (s *Server) Serve(ln net.Listener) error {
for {
netConn, err := ln.Accept()

View File

@ -100,6 +100,15 @@ func (uc *upstreamConn) Close() error {
return nil
}
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
uc.user.forEachDownstream(func(dc *downstreamConn) {
if dc.upstream != nil && dc.upstream != uc.upstream {
return
}
f(dc)
})
}
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch, ok := uc.channels[name]
if !ok {
@ -140,7 +149,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err
}
uc.user.forEachDownstream(func(dc *downstreamConn) {
uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "MODE",
@ -210,7 +219,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
ch.Members[msg.Prefix.Name] = 0
}
uc.user.forEachDownstream(func(dc *downstreamConn) {
uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "JOIN",
@ -240,7 +249,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
delete(ch.Members, msg.Prefix.Name)
}
uc.user.forEachDownstream(func(dc *downstreamConn) {
uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
Command: "PART",
@ -326,7 +335,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}
ch.complete = true
uc.user.forEachDownstream(func(dc *downstreamConn) {
uc.forEachDownstream(func(dc *downstreamConn) {
forwardChannel(dc, ch)
})
case "PRIVMSG":