Allow clients to specify an upstream name in their username
This commit is contained in:
parent
d1550a3cdb
commit
c22ce793a1
@ -58,6 +58,7 @@ type downstreamConn struct {
|
|||||||
nick string
|
nick string
|
||||||
username string
|
username string
|
||||||
realname string
|
realname string
|
||||||
|
upstream *Upstream
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
|
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
|
||||||
@ -97,13 +98,38 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
|
|||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
|
||||||
|
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
||||||
|
if dc.upstream != nil && uc.upstream != dc.upstream {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f(uc)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
|
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
|
||||||
// TODO: extract network name from channel name
|
// TODO: extract network name from channel name if dc.upstream == nil
|
||||||
ch, err := dc.user.getChannel(name)
|
var channel *upstreamChannel
|
||||||
if err != nil {
|
var err error
|
||||||
return nil, "", err
|
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ch, ok := uc.channels[name]; ok {
|
||||||
|
if channel != nil {
|
||||||
|
err = fmt.Errorf("ambiguous channel name %q", name)
|
||||||
|
} else {
|
||||||
|
channel = ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if channel == nil {
|
||||||
|
return nil, "", ircError{&irc.Message{
|
||||||
|
Command: irc.ERR_NOSUCHCHANNEL,
|
||||||
|
Params: []string{name, "No such channel"},
|
||||||
|
}}
|
||||||
}
|
}
|
||||||
return ch.conn, ch.Name, nil
|
return channel.conn, channel.Name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
|
func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
|
||||||
@ -274,9 +300,18 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) register() error {
|
func (dc *downstreamConn) register() error {
|
||||||
u := dc.srv.getUser(strings.TrimPrefix(dc.username, "~"))
|
username := strings.TrimPrefix(dc.username, "~")
|
||||||
|
var network string
|
||||||
|
if i := strings.LastIndexAny(username, "/@"); i >= 0 {
|
||||||
|
network = username[i+1:]
|
||||||
|
}
|
||||||
|
if i := strings.IndexAny(username, "/@"); i >= 0 {
|
||||||
|
username = username[:i]
|
||||||
|
}
|
||||||
|
|
||||||
|
u := dc.srv.getUser(username)
|
||||||
if u == nil {
|
if u == nil {
|
||||||
dc.logger.Printf("failed authentication: unknown username %q", dc.username)
|
dc.logger.Printf("failed authentication: unknown username %q", username)
|
||||||
dc.SendMessage(&irc.Message{
|
dc.SendMessage(&irc.Message{
|
||||||
Prefix: dc.srv.prefix(),
|
Prefix: dc.srv.prefix(),
|
||||||
Command: irc.ERR_PASSWDMISMATCH,
|
Command: irc.ERR_PASSWDMISMATCH,
|
||||||
@ -285,6 +320,19 @@ func (dc *downstreamConn) register() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if network != "" {
|
||||||
|
dc.upstream = dc.srv.getUpstream(network)
|
||||||
|
if dc.upstream == nil {
|
||||||
|
dc.logger.Printf("failed registration: unknown upstream %q", network)
|
||||||
|
dc.SendMessage(&irc.Message{
|
||||||
|
Prefix: dc.srv.prefix(),
|
||||||
|
Command: irc.ERR_PASSWDMISMATCH,
|
||||||
|
Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", network)},
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dc.registered = true
|
dc.registered = true
|
||||||
dc.user = u
|
dc.user = u
|
||||||
|
|
||||||
@ -319,7 +367,7 @@ func (dc *downstreamConn) register() error {
|
|||||||
Params: []string{dc.nick, "No MOTD"},
|
Params: []string{dc.nick, "No MOTD"},
|
||||||
})
|
})
|
||||||
|
|
||||||
u.forEachUpstream(func(uc *upstreamConn) {
|
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||||
// TODO: fix races accessing upstream connection data
|
// TODO: fix races accessing upstream connection data
|
||||||
for _, ch := range uc.channels {
|
for _, ch := range uc.channels {
|
||||||
if ch.complete {
|
if ch.complete {
|
||||||
@ -327,8 +375,7 @@ func (dc *downstreamConn) register() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: let clients specify the ring buffer name in their username
|
historyName := dc.username
|
||||||
historyName := ""
|
|
||||||
|
|
||||||
var seqPtr *uint64
|
var seqPtr *uint64
|
||||||
if firstDownstream {
|
if firstDownstream {
|
||||||
@ -376,7 +423,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
|||||||
Params: []string{dc.nick, "You may not reregister"},
|
Params: []string{dc.nick, "You may not reregister"},
|
||||||
}}
|
}}
|
||||||
case "NICK":
|
case "NICK":
|
||||||
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||||
uc.SendMessage(msg)
|
uc.SendMessage(msg)
|
||||||
})
|
})
|
||||||
case "JOIN", "PART":
|
case "JOIN", "PART":
|
||||||
@ -448,7 +495,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if modeStr != "" {
|
if modeStr != "" {
|
||||||
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||||
uc.SendMessage(&irc.Message{
|
uc.SendMessage(&irc.Message{
|
||||||
Command: "MODE",
|
Command: "MODE",
|
||||||
Params: []string{uc.nick, modeStr},
|
Params: []string{uc.nick, modeStr},
|
||||||
|
14
server.go
14
server.go
@ -82,13 +82,16 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
|
|||||||
u.lock.Unlock()
|
u.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) getChannel(name string) (*upstreamChannel, error) {
|
func (u *user) getChannel(name string, upstream *Upstream) (*upstreamChannel, error) {
|
||||||
var channel *upstreamChannel
|
var channel *upstreamChannel
|
||||||
var err error
|
var err error
|
||||||
u.forEachUpstream(func(uc *upstreamConn) {
|
u.forEachUpstream(func(uc *upstreamConn) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if upstream != nil && uc.upstream != upstream {
|
||||||
|
return
|
||||||
|
}
|
||||||
if ch, ok := uc.channels[name]; ok {
|
if ch, ok := uc.channels[name]; ok {
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
err = fmt.Errorf("ambiguous channel name %q", name)
|
err = fmt.Errorf("ambiguous channel name %q", name)
|
||||||
@ -196,6 +199,15 @@ func (s *Server) getUser(name string) *user {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) getUpstream(name string) *Upstream {
|
||||||
|
for i, upstream := range s.Upstreams {
|
||||||
|
if upstream.Addr == name {
|
||||||
|
return &s.Upstreams[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) Serve(ln net.Listener) error {
|
func (s *Server) Serve(ln net.Listener) error {
|
||||||
for {
|
for {
|
||||||
netConn, err := ln.Accept()
|
netConn, err := ln.Accept()
|
||||||
|
17
upstream.go
17
upstream.go
@ -100,6 +100,15 @@ func (uc *upstreamConn) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
|
||||||
|
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||||
|
if dc.upstream != nil && dc.upstream != uc.upstream {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f(dc)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
|
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
|
||||||
ch, ok := uc.channels[name]
|
ch, ok := uc.channels[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -140,7 +149,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||||
dc.SendMessage(&irc.Message{
|
dc.SendMessage(&irc.Message{
|
||||||
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
||||||
Command: "MODE",
|
Command: "MODE",
|
||||||
@ -210,7 +219,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
ch.Members[msg.Prefix.Name] = 0
|
ch.Members[msg.Prefix.Name] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||||
dc.SendMessage(&irc.Message{
|
dc.SendMessage(&irc.Message{
|
||||||
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
||||||
Command: "JOIN",
|
Command: "JOIN",
|
||||||
@ -240,7 +249,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
delete(ch.Members, msg.Prefix.Name)
|
delete(ch.Members, msg.Prefix.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||||
dc.SendMessage(&irc.Message{
|
dc.SendMessage(&irc.Message{
|
||||||
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
||||||
Command: "PART",
|
Command: "PART",
|
||||||
@ -326,7 +335,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
ch.complete = true
|
ch.complete = true
|
||||||
|
|
||||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||||
forwardChannel(dc, ch)
|
forwardChannel(dc, ch)
|
||||||
})
|
})
|
||||||
case "PRIVMSG":
|
case "PRIVMSG":
|
||||||
|
Loading…
Reference in New Issue
Block a user