Introduce permanentUpstreamCaps

This commit is contained in:
Simon Ser 2020-04-30 16:10:39 +02:00
parent 2a569c3b27
commit bbb5e79f59
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48

View File

@ -15,6 +15,16 @@ import (
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
) )
// permanentUpstreamCaps is the static list of upstream capabilities always
// requested when supported.
var permanentUpstreamCaps = map[string]bool{
"away-notify": true,
"batch": true,
"labeled-response": true,
"message-tags": true,
"server-time": true,
}
type upstreamChannel struct { type upstreamChannel struct {
Name string Name string
conn *upstreamConn conn *upstreamConn
@ -1209,7 +1219,7 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
func (uc *upstreamConn) requestCaps() { func (uc *upstreamConn) requestCaps() {
var requestCaps []string var requestCaps []string
for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time", "away-notify"} { for c := range permanentUpstreamCaps {
if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] {
requestCaps = append(requestCaps, c) requestCaps = append(requestCaps, c)
} }
@ -1219,12 +1229,72 @@ func (uc *upstreamConn) requestCaps() {
requestCaps = append(requestCaps, "sasl") requestCaps = append(requestCaps, "sasl")
} }
if len(requestCaps) > 0 { if len(requestCaps) == 0 {
return
}
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
Command: "CAP", Command: "CAP",
Params: []string{"REQ", strings.Join(requestCaps, " ")}, Params: []string{"REQ", strings.Join(requestCaps, " ")},
}) })
}
func (uc *upstreamConn) requestSASL() bool {
if uc.network.SASL.Mechanism == "" {
return false
} }
v, ok := uc.supportedCaps["sasl"]
if !ok {
return false
}
if v != "" {
mechanisms := strings.Split(v, ",")
found := false
for _, mech := range mechanisms {
if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
uc.caps[name] = ok
switch name {
case "sasl":
if !ok {
uc.logger.Printf("server refused to acknowledge the SASL capability")
return nil
}
auth := &uc.network.SASL
switch auth.Mechanism {
case "PLAIN":
uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
default:
return fmt.Errorf("unsupported SASL mechanism %q", name)
}
uc.SendMessage(&irc.Message{
Command: "AUTHENTICATE",
Params: []string{auth.Mechanism},
})
default:
if permanentUpstreamCaps[name] {
break
}
uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
}
return nil
} }
func splitSpace(s string) []string { func splitSpace(s string) []string {
@ -1290,63 +1360,6 @@ func (uc *upstreamConn) runUntilRegistered() error {
return nil return nil
} }
func (uc *upstreamConn) requestSASL() bool {
if uc.network.SASL.Mechanism == "" {
return false
}
v, ok := uc.supportedCaps["sasl"]
if !ok {
return false
}
if v != "" {
mechanisms := strings.Split(v, ",")
found := false
for _, mech := range mechanisms {
if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
uc.caps[name] = ok
switch name {
case "sasl":
if !ok {
uc.logger.Printf("server refused to acknowledge the SASL capability")
return nil
}
auth := &uc.network.SASL
switch auth.Mechanism {
case "PLAIN":
uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
default:
return fmt.Errorf("unsupported SASL mechanism %q", name)
}
uc.SendMessage(&irc.Message{
Command: "AUTHENTICATE",
Params: []string{auth.Mechanism},
})
case "message-tags", "labeled-response", "away-notify", "batch", "server-time":
// Nothing to do
default:
uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
}
return nil
}
func (uc *upstreamConn) readMessages(ch chan<- event) error { func (uc *upstreamConn) readMessages(ch chan<- event) error {
for { for {
msg, err := uc.ReadMessage() msg, err := uc.ReadMessage()