Stop setting *user in downstreamConn.register

Set it in downstreamConn.welcome instead. Makes it clearer that it
must not be accessed before welcome is called (because it can only
be accessed from the user goroutine).
This commit is contained in:
Simon Ser 2023-04-05 16:54:55 +02:00
parent c5079f7ac3
commit f6043e5b98
3 changed files with 60 additions and 40 deletions

View File

@ -309,6 +309,8 @@ type downstreamRegistration struct {
networkID int64
negotiatingCaps bool
authUsername string
}
func serverSASLMechanisms(srv *Server) []string {
@ -686,13 +688,6 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
}
if err == nil {
if username == "" {
panic(fmt.Errorf("username unset after SASL authentication"))
}
err = dc.setUser(username, clientName, networkName)
}
if err != nil {
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
dc.endSASL(&irc.Message{
@ -703,6 +698,11 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
break
}
if username == "" {
panic(fmt.Errorf("username unset after SASL authentication"))
}
dc.setAuthUsername(username, clientName, networkName)
// Technically we should send RPL_LOGGEDIN here. However we use
// RPL_LOGGEDIN to mirror the upstream connection status. Let's
// see how many clients that breaks. See:
@ -721,7 +721,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
return err
}
if dc.user == nil {
if dc.registration.authUsername == "" {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
@ -1247,28 +1247,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
return username, client, network
}
func (dc *downstreamConn) setUser(username, clientName, networkName string) error {
dc.user = dc.srv.getUser(username)
if dc.user == nil && dc.srv.Config().EnableUsersOnAuth {
ctx := context.TODO()
if _, err := dc.srv.db.GetUser(ctx, username); err != nil {
// Can't find the user in the DB -- try to create it
record := database.User{
Username: username,
Enabled: true,
}
dc.user, err = dc.srv.createUser(ctx, &record)
if err != nil {
return fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
}
}
}
if dc.user == nil {
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
}
func (dc *downstreamConn) setAuthUsername(username, clientName, networkName string) {
dc.clientName = clientName
dc.registration.authUsername = username
dc.registration.networkName = networkName
return nil
}
func (dc *downstreamConn) register(ctx context.Context) error {
@ -1286,7 +1268,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
password := dc.registration.password
dc.registration.password = ""
if dc.user == nil {
if dc.registration.authUsername == "" {
if password == "" {
if dc.caps.IsEnabled("sasl") {
return ircError{&irc.Message{
@ -1318,9 +1300,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
}}
}
if err := dc.setUser(username, clientName, networkName); err != nil {
return err
}
dc.setAuthUsername(username, clientName, networkName)
}
_, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username)
@ -1343,8 +1323,8 @@ func (dc *downstreamConn) register(ctx context.Context) error {
}
dc.registered = true
dc.username = dc.user.Username
dc.logger.Printf("registration complete for user %q", dc.user.Username)
dc.username = dc.registration.authUsername
dc.logger.Printf("registration complete for user %q", dc.username)
return nil
}
@ -1421,10 +1401,15 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
return nil
}
func (dc *downstreamConn) welcome(ctx context.Context) error {
if dc.user == nil || !dc.registered {
func (dc *downstreamConn) welcome(ctx context.Context, user *user) error {
if !dc.registered {
panic("tried to welcome an unregistered connection")
}
if dc.user != nil {
panic("tried to welcome the same connection twice")
}
dc.user = user
remoteAddr := dc.conn.RemoteAddr().String()
dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}

View File

@ -476,11 +476,46 @@ func (s *Server) Handle(ic ircConn) {
return
}
dc.user.events <- eventDownstreamConnected{dc}
if err := dc.readMessages(dc.user.events); err != nil {
user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername)
if err != nil {
dc.SendMessage(&irc.Message{
Command: "ERROR",
Params: []string{"Internal server error"},
})
return
}
user.events <- eventDownstreamConnected{dc}
if err := dc.readMessages(user.events); err != nil {
dc.logger.Printf("%v", err)
}
dc.user.events <- eventDownstreamDisconnected{dc}
user.events <- eventDownstreamDisconnected{dc}
}
func (s *Server) getOrCreateUser(ctx context.Context, username string) (*user, error) {
user := s.getUser(username)
if user != nil {
return user, nil
}
if _, err := s.db.GetUser(ctx, username); err == nil {
return nil, fmt.Errorf("user %q exists in the DB but hasn't been loaded by the bouncer -- a restart may help", username)
}
if !s.Config().EnableUsersOnAuth {
return nil, fmt.Errorf("cannot find user %q in the DB", username)
}
// Can't find the user in the DB -- try to create it
record := database.User{
Username: username,
Enabled: true,
}
user, err := s.createUser(ctx, &record)
if err != nil {
return nil, fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
}
return user, nil
}
func (s *Server) HandleAdmin(ic ircConn) {

View File

@ -737,7 +737,7 @@ func (u *user) run() {
break
}
if err := dc.welcome(ctx); err != nil {
if err := dc.welcome(ctx, u); err != nil {
if ircErr, ok := err.(ircError); ok {
msg := ircErr.Message.Copy()
msg.Prefix = dc.srv.prefix()