Ensure all incoming messages have a prefix set

Per the spec:

> If the source is missing from a message, it’s is assumed to have originated
> from the client/server on the other end of the connection the message was
> received on.
This commit is contained in:
Simon Ser 2022-03-21 17:01:15 +01:00
parent 883683c0b7
commit 5defd29509

View File

@ -111,6 +111,7 @@ type upstreamConn struct {
network *network network *network
user *user user *user
serverPrefix *irc.Prefix
serverName string serverName string
availableUserModes string availableUserModes string
availableChannelModes map[byte]channelModeType availableChannelModes map[byte]channelModeType
@ -244,6 +245,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
channels: upstreamChannelCasemapMap{newCasemapMap(0)}, channels: upstreamChannelCasemapMap{newCasemapMap(0)},
caps: newCapRegistry(), caps: newCapRegistry(),
batches: make(map[string]batch), batches: make(map[string]batch),
serverPrefix: &irc.Prefix{Name: "*"},
availableChannelTypes: stdChannelTypes, availableChannelTypes: stdChannelTypes,
availableChannelModes: stdChannelModes, availableChannelModes: stdChannelModes,
availableMemberships: stdMemberships, availableMemberships: stdMemberships,
@ -444,6 +446,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
} }
if msg.Prefix == nil {
msg.Prefix = uc.serverPrefix
}
if _, ok := msg.Tags["time"]; !ok { if _, ok := msg.Tags["time"]; !ok {
msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now())) msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now()))
} }
@ -456,10 +462,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}) })
return nil return nil
case "NOTICE", "PRIVMSG", "TAGMSG": case "NOTICE", "PRIVMSG", "TAGMSG":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var entity, text string var entity, text string
if msg.Command != "TAGMSG" { if msg.Command != "TAGMSG" {
if err := parseMessageParams(msg, &entity, &text); err != nil { if err := parseMessageParams(msg, &entity, &text); err != nil {
@ -738,6 +740,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
uc.registered = true uc.registered = true
uc.serverPrefix = msg.Prefix
uc.nickCM = uc.network.casemap(uc.nick) uc.nickCM = uc.network.casemap(uc.nick)
uc.logger.Printf("connection registered with nick %q", uc.nick) uc.logger.Printf("connection registered with nick %q", uc.nick)
@ -891,10 +894,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag) return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
} }
case "NICK": case "NICK":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var newNick string var newNick string
if err := parseMessageParams(msg, &newNick); err != nil { if err := parseMessageParams(msg, &newNick); err != nil {
return err return err
@ -929,10 +928,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.updateMonitor() uc.updateMonitor()
} }
case "SETNAME": case "SETNAME":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var newRealname string var newRealname string
if err := parseMessageParams(msg, &newRealname); err != nil { if err := parseMessageParams(msg, &newRealname); err != nil {
return err return err
@ -953,10 +948,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}) })
} }
case "CHGHOST": case "CHGHOST":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var newUsername, newHostname string var newUsername, newHostname string
if err := parseMessageParams(msg, &newUsername, &newHostname); err != nil { if err := parseMessageParams(msg, &newUsername, &newHostname); err != nil {
return err return err
@ -983,10 +974,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}) })
} }
case "JOIN": case "JOIN":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var channels string var channels string
if err := parseMessageParams(msg, &channels); err != nil { if err := parseMessageParams(msg, &channels); err != nil {
return err return err
@ -1021,10 +1008,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.produce(ch, chMsg, nil) uc.produce(ch, chMsg, nil)
} }
case "PART": case "PART":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var channels string var channels string
if err := parseMessageParams(msg, &channels); err != nil { if err := parseMessageParams(msg, &channels); err != nil {
return err return err
@ -1051,10 +1034,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.produce(ch, chMsg, nil) uc.produce(ch, chMsg, nil)
} }
case "KICK": case "KICK":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var channel, user string var channel, user string
if err := parseMessageParams(msg, &channel, &user); err != nil { if err := parseMessageParams(msg, &channel, &user); err != nil {
return err return err
@ -1073,10 +1052,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.produce(channel, msg, nil) uc.produce(channel, msg, nil)
case "QUIT": case "QUIT":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
if uc.isOurNick(msg.Prefix.Name) { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("quit") uc.logger.Printf("quit")
} }
@ -1110,10 +1085,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
ch.Topic = "" ch.Topic = ""
} }
case "TOPIC": case "TOPIC":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
var name string var name string
if err := parseMessageParams(msg, &name); err != nil { if err := parseMessageParams(msg, &name); err != nil {
return err return err
@ -1660,10 +1631,6 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}) })
}) })
case "AWAY", "ACCOUNT": case "AWAY", "ACCOUNT":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),