Stop setting *user in downstreamConn.register
Set it in downstreamConn.welcome instead. Makes it clearer that it must not be accessed before welcome is called (because it can only be accessed from the user goroutine).
This commit is contained in:
parent
c5079f7ac3
commit
f6043e5b98
@ -309,6 +309,8 @@ type downstreamRegistration struct {
|
||||
networkID int64
|
||||
|
||||
negotiatingCaps bool
|
||||
|
||||
authUsername string
|
||||
}
|
||||
|
||||
func serverSASLMechanisms(srv *Server) []string {
|
||||
@ -686,13 +688,6 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
||||
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if username == "" {
|
||||
panic(fmt.Errorf("username unset after SASL authentication"))
|
||||
}
|
||||
err = dc.setUser(username, clientName, networkName)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
|
||||
dc.endSASL(&irc.Message{
|
||||
@ -703,6 +698,11 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
||||
break
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
panic(fmt.Errorf("username unset after SASL authentication"))
|
||||
}
|
||||
dc.setAuthUsername(username, clientName, networkName)
|
||||
|
||||
// Technically we should send RPL_LOGGEDIN here. However we use
|
||||
// RPL_LOGGEDIN to mirror the upstream connection status. Let's
|
||||
// see how many clients that breaks. See:
|
||||
@ -721,7 +721,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
||||
return err
|
||||
}
|
||||
|
||||
if dc.user == nil {
|
||||
if dc.registration.authUsername == "" {
|
||||
return ircError{&irc.Message{
|
||||
Command: "FAIL",
|
||||
Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
|
||||
@ -1247,28 +1247,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
|
||||
return username, client, network
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) setUser(username, clientName, networkName string) error {
|
||||
dc.user = dc.srv.getUser(username)
|
||||
if dc.user == nil && dc.srv.Config().EnableUsersOnAuth {
|
||||
ctx := context.TODO()
|
||||
if _, err := dc.srv.db.GetUser(ctx, username); err != nil {
|
||||
// Can't find the user in the DB -- try to create it
|
||||
record := database.User{
|
||||
Username: username,
|
||||
Enabled: true,
|
||||
}
|
||||
dc.user, err = dc.srv.createUser(ctx, &record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if dc.user == nil {
|
||||
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
|
||||
}
|
||||
func (dc *downstreamConn) setAuthUsername(username, clientName, networkName string) {
|
||||
dc.clientName = clientName
|
||||
dc.registration.authUsername = username
|
||||
dc.registration.networkName = networkName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) register(ctx context.Context) error {
|
||||
@ -1286,7 +1268,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
|
||||
|
||||
password := dc.registration.password
|
||||
dc.registration.password = ""
|
||||
if dc.user == nil {
|
||||
if dc.registration.authUsername == "" {
|
||||
if password == "" {
|
||||
if dc.caps.IsEnabled("sasl") {
|
||||
return ircError{&irc.Message{
|
||||
@ -1318,9 +1300,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
|
||||
}}
|
||||
}
|
||||
|
||||
if err := dc.setUser(username, clientName, networkName); err != nil {
|
||||
return err
|
||||
}
|
||||
dc.setAuthUsername(username, clientName, networkName)
|
||||
}
|
||||
|
||||
_, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username)
|
||||
@ -1343,8 +1323,8 @@ func (dc *downstreamConn) register(ctx context.Context) error {
|
||||
}
|
||||
|
||||
dc.registered = true
|
||||
dc.username = dc.user.Username
|
||||
dc.logger.Printf("registration complete for user %q", dc.user.Username)
|
||||
dc.username = dc.registration.authUsername
|
||||
dc.logger.Printf("registration complete for user %q", dc.username)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1421,10 +1401,15 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) welcome(ctx context.Context) error {
|
||||
if dc.user == nil || !dc.registered {
|
||||
func (dc *downstreamConn) welcome(ctx context.Context, user *user) error {
|
||||
if !dc.registered {
|
||||
panic("tried to welcome an unregistered connection")
|
||||
}
|
||||
if dc.user != nil {
|
||||
panic("tried to welcome the same connection twice")
|
||||
}
|
||||
|
||||
dc.user = user
|
||||
|
||||
remoteAddr := dc.conn.RemoteAddr().String()
|
||||
dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}
|
||||
|
41
server.go
41
server.go
@ -476,11 +476,46 @@ func (s *Server) Handle(ic ircConn) {
|
||||
return
|
||||
}
|
||||
|
||||
dc.user.events <- eventDownstreamConnected{dc}
|
||||
if err := dc.readMessages(dc.user.events); err != nil {
|
||||
user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername)
|
||||
if err != nil {
|
||||
dc.SendMessage(&irc.Message{
|
||||
Command: "ERROR",
|
||||
Params: []string{"Internal server error"},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user.events <- eventDownstreamConnected{dc}
|
||||
if err := dc.readMessages(user.events); err != nil {
|
||||
dc.logger.Printf("%v", err)
|
||||
}
|
||||
dc.user.events <- eventDownstreamDisconnected{dc}
|
||||
user.events <- eventDownstreamDisconnected{dc}
|
||||
}
|
||||
|
||||
func (s *Server) getOrCreateUser(ctx context.Context, username string) (*user, error) {
|
||||
user := s.getUser(username)
|
||||
if user != nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
if _, err := s.db.GetUser(ctx, username); err == nil {
|
||||
return nil, fmt.Errorf("user %q exists in the DB but hasn't been loaded by the bouncer -- a restart may help", username)
|
||||
}
|
||||
|
||||
if !s.Config().EnableUsersOnAuth {
|
||||
return nil, fmt.Errorf("cannot find user %q in the DB", username)
|
||||
}
|
||||
|
||||
// Can't find the user in the DB -- try to create it
|
||||
record := database.User{
|
||||
Username: username,
|
||||
Enabled: true,
|
||||
}
|
||||
user, err := s.createUser(ctx, &record)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Server) HandleAdmin(ic ircConn) {
|
||||
|
Loading…
Reference in New Issue
Block a user