Allow clients to specify an upstream name in their username
This commit is contained in:
parent
d1550a3cdb
commit
c22ce793a1
@ -58,6 +58,7 @@ type downstreamConn struct {
|
||||
nick string
|
||||
username string
|
||||
realname string
|
||||
upstream *Upstream
|
||||
}
|
||||
|
||||
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
|
||||
@ -97,13 +98,38 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
|
||||
// TODO: extract network name from channel name
|
||||
ch, err := dc.user.getChannel(name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
|
||||
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
||||
if dc.upstream != nil && uc.upstream != dc.upstream {
|
||||
return
|
||||
}
|
||||
return ch.conn, ch.Name, nil
|
||||
f(uc)
|
||||
})
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
|
||||
// TODO: extract network name from channel name if dc.upstream == nil
|
||||
var channel *upstreamChannel
|
||||
var err error
|
||||
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if ch, ok := uc.channels[name]; ok {
|
||||
if channel != nil {
|
||||
err = fmt.Errorf("ambiguous channel name %q", name)
|
||||
} else {
|
||||
channel = ch
|
||||
}
|
||||
}
|
||||
})
|
||||
if channel == nil {
|
||||
return nil, "", ircError{&irc.Message{
|
||||
Command: irc.ERR_NOSUCHCHANNEL,
|
||||
Params: []string{name, "No such channel"},
|
||||
}}
|
||||
}
|
||||
return channel.conn, channel.Name, nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
|
||||
@ -274,9 +300,18 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) register() error {
|
||||
u := dc.srv.getUser(strings.TrimPrefix(dc.username, "~"))
|
||||
username := strings.TrimPrefix(dc.username, "~")
|
||||
var network string
|
||||
if i := strings.LastIndexAny(username, "/@"); i >= 0 {
|
||||
network = username[i+1:]
|
||||
}
|
||||
if i := strings.IndexAny(username, "/@"); i >= 0 {
|
||||
username = username[:i]
|
||||
}
|
||||
|
||||
u := dc.srv.getUser(username)
|
||||
if u == nil {
|
||||
dc.logger.Printf("failed authentication: unknown username %q", dc.username)
|
||||
dc.logger.Printf("failed authentication: unknown username %q", username)
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
Command: irc.ERR_PASSWDMISMATCH,
|
||||
@ -285,6 +320,19 @@ func (dc *downstreamConn) register() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if network != "" {
|
||||
dc.upstream = dc.srv.getUpstream(network)
|
||||
if dc.upstream == nil {
|
||||
dc.logger.Printf("failed registration: unknown upstream %q", network)
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
Command: irc.ERR_PASSWDMISMATCH,
|
||||
Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", network)},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
dc.registered = true
|
||||
dc.user = u
|
||||
|
||||
@ -319,7 +367,7 @@ func (dc *downstreamConn) register() error {
|
||||
Params: []string{dc.nick, "No MOTD"},
|
||||
})
|
||||
|
||||
u.forEachUpstream(func(uc *upstreamConn) {
|
||||
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||
// TODO: fix races accessing upstream connection data
|
||||
for _, ch := range uc.channels {
|
||||
if ch.complete {
|
||||
@ -327,8 +375,7 @@ func (dc *downstreamConn) register() error {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: let clients specify the ring buffer name in their username
|
||||
historyName := ""
|
||||
historyName := dc.username
|
||||
|
||||
var seqPtr *uint64
|
||||
if firstDownstream {
|
||||
@ -376,7 +423,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
Params: []string{dc.nick, "You may not reregister"},
|
||||
}}
|
||||
case "NICK":
|
||||
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
||||
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||
uc.SendMessage(msg)
|
||||
})
|
||||
case "JOIN", "PART":
|
||||
@ -448,7 +495,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
}
|
||||
|
||||
if modeStr != "" {
|
||||
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
||||
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "MODE",
|
||||
Params: []string{uc.nick, modeStr},
|
||||
|
14
server.go
14
server.go
@ -82,13 +82,16 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
|
||||
u.lock.Unlock()
|
||||
}
|
||||
|
||||
func (u *user) getChannel(name string) (*upstreamChannel, error) {
|
||||
func (u *user) getChannel(name string, upstream *Upstream) (*upstreamChannel, error) {
|
||||
var channel *upstreamChannel
|
||||
var err error
|
||||
u.forEachUpstream(func(uc *upstreamConn) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if upstream != nil && uc.upstream != upstream {
|
||||
return
|
||||
}
|
||||
if ch, ok := uc.channels[name]; ok {
|
||||
if channel != nil {
|
||||
err = fmt.Errorf("ambiguous channel name %q", name)
|
||||
@ -196,6 +199,15 @@ func (s *Server) getUser(name string) *user {
|
||||
return u
|
||||
}
|
||||
|
||||
func (s *Server) getUpstream(name string) *Upstream {
|
||||
for i, upstream := range s.Upstreams {
|
||||
if upstream.Addr == name {
|
||||
return &s.Upstreams[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ln net.Listener) error {
|
||||
for {
|
||||
netConn, err := ln.Accept()
|
||||
|
17
upstream.go
17
upstream.go
@ -100,6 +100,15 @@ func (uc *upstreamConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
|
||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||
if dc.upstream != nil && dc.upstream != uc.upstream {
|
||||
return
|
||||
}
|
||||
f(dc)
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
|
||||
ch, ok := uc.channels[name]
|
||||
if !ok {
|
||||
@ -140,7 +149,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
return err
|
||||
}
|
||||
|
||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
||||
Command: "MODE",
|
||||
@ -210,7 +219,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
ch.Members[msg.Prefix.Name] = 0
|
||||
}
|
||||
|
||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
||||
Command: "JOIN",
|
||||
@ -240,7 +249,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
delete(ch.Members, msg.Prefix.Name)
|
||||
}
|
||||
|
||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
|
||||
Command: "PART",
|
||||
@ -326,7 +335,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
}
|
||||
ch.complete = true
|
||||
|
||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
forwardChannel(dc, ch)
|
||||
})
|
||||
case "PRIVMSG":
|
||||
|
Loading…
Reference in New Issue
Block a user