Stop setting *user in downstreamConn.register
Set it in downstreamConn.welcome instead. Makes it clearer that it must not be accessed before welcome is called (because it can only be accessed from the user goroutine).
This commit is contained in:
parent
c5079f7ac3
commit
f6043e5b98
@ -309,6 +309,8 @@ type downstreamRegistration struct {
|
|||||||
networkID int64
|
networkID int64
|
||||||
|
|
||||||
negotiatingCaps bool
|
negotiatingCaps bool
|
||||||
|
|
||||||
|
authUsername string
|
||||||
}
|
}
|
||||||
|
|
||||||
func serverSASLMechanisms(srv *Server) []string {
|
func serverSASLMechanisms(srv *Server) []string {
|
||||||
@ -686,13 +688,6 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
|||||||
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
|
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
if username == "" {
|
|
||||||
panic(fmt.Errorf("username unset after SASL authentication"))
|
|
||||||
}
|
|
||||||
err = dc.setUser(username, clientName, networkName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
|
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
|
||||||
dc.endSASL(&irc.Message{
|
dc.endSASL(&irc.Message{
|
||||||
@ -703,6 +698,11 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if username == "" {
|
||||||
|
panic(fmt.Errorf("username unset after SASL authentication"))
|
||||||
|
}
|
||||||
|
dc.setAuthUsername(username, clientName, networkName)
|
||||||
|
|
||||||
// Technically we should send RPL_LOGGEDIN here. However we use
|
// Technically we should send RPL_LOGGEDIN here. However we use
|
||||||
// RPL_LOGGEDIN to mirror the upstream connection status. Let's
|
// RPL_LOGGEDIN to mirror the upstream connection status. Let's
|
||||||
// see how many clients that breaks. See:
|
// see how many clients that breaks. See:
|
||||||
@ -721,7 +721,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if dc.user == nil {
|
if dc.registration.authUsername == "" {
|
||||||
return ircError{&irc.Message{
|
return ircError{&irc.Message{
|
||||||
Command: "FAIL",
|
Command: "FAIL",
|
||||||
Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
|
Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
|
||||||
@ -1247,28 +1247,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
|
|||||||
return username, client, network
|
return username, client, network
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) setUser(username, clientName, networkName string) error {
|
func (dc *downstreamConn) setAuthUsername(username, clientName, networkName string) {
|
||||||
dc.user = dc.srv.getUser(username)
|
|
||||||
if dc.user == nil && dc.srv.Config().EnableUsersOnAuth {
|
|
||||||
ctx := context.TODO()
|
|
||||||
if _, err := dc.srv.db.GetUser(ctx, username); err != nil {
|
|
||||||
// Can't find the user in the DB -- try to create it
|
|
||||||
record := database.User{
|
|
||||||
Username: username,
|
|
||||||
Enabled: true,
|
|
||||||
}
|
|
||||||
dc.user, err = dc.srv.createUser(ctx, &record)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if dc.user == nil {
|
|
||||||
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
|
|
||||||
}
|
|
||||||
dc.clientName = clientName
|
dc.clientName = clientName
|
||||||
|
dc.registration.authUsername = username
|
||||||
dc.registration.networkName = networkName
|
dc.registration.networkName = networkName
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) register(ctx context.Context) error {
|
func (dc *downstreamConn) register(ctx context.Context) error {
|
||||||
@ -1286,7 +1268,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
|
|||||||
|
|
||||||
password := dc.registration.password
|
password := dc.registration.password
|
||||||
dc.registration.password = ""
|
dc.registration.password = ""
|
||||||
if dc.user == nil {
|
if dc.registration.authUsername == "" {
|
||||||
if password == "" {
|
if password == "" {
|
||||||
if dc.caps.IsEnabled("sasl") {
|
if dc.caps.IsEnabled("sasl") {
|
||||||
return ircError{&irc.Message{
|
return ircError{&irc.Message{
|
||||||
@ -1318,9 +1300,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
|
|||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := dc.setUser(username, clientName, networkName); err != nil {
|
dc.setAuthUsername(username, clientName, networkName)
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username)
|
_, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username)
|
||||||
@ -1343,8 +1323,8 @@ func (dc *downstreamConn) register(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dc.registered = true
|
dc.registered = true
|
||||||
dc.username = dc.user.Username
|
dc.username = dc.registration.authUsername
|
||||||
dc.logger.Printf("registration complete for user %q", dc.user.Username)
|
dc.logger.Printf("registration complete for user %q", dc.username)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1421,10 +1401,15 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) welcome(ctx context.Context) error {
|
func (dc *downstreamConn) welcome(ctx context.Context, user *user) error {
|
||||||
if dc.user == nil || !dc.registered {
|
if !dc.registered {
|
||||||
panic("tried to welcome an unregistered connection")
|
panic("tried to welcome an unregistered connection")
|
||||||
}
|
}
|
||||||
|
if dc.user != nil {
|
||||||
|
panic("tried to welcome the same connection twice")
|
||||||
|
}
|
||||||
|
|
||||||
|
dc.user = user
|
||||||
|
|
||||||
remoteAddr := dc.conn.RemoteAddr().String()
|
remoteAddr := dc.conn.RemoteAddr().String()
|
||||||
dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}
|
dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}
|
||||||
|
41
server.go
41
server.go
@ -476,11 +476,46 @@ func (s *Server) Handle(ic ircConn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dc.user.events <- eventDownstreamConnected{dc}
|
user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername)
|
||||||
if err := dc.readMessages(dc.user.events); err != nil {
|
if err != nil {
|
||||||
|
dc.SendMessage(&irc.Message{
|
||||||
|
Command: "ERROR",
|
||||||
|
Params: []string{"Internal server error"},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user.events <- eventDownstreamConnected{dc}
|
||||||
|
if err := dc.readMessages(user.events); err != nil {
|
||||||
dc.logger.Printf("%v", err)
|
dc.logger.Printf("%v", err)
|
||||||
}
|
}
|
||||||
dc.user.events <- eventDownstreamDisconnected{dc}
|
user.events <- eventDownstreamDisconnected{dc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getOrCreateUser(ctx context.Context, username string) (*user, error) {
|
||||||
|
user := s.getUser(username)
|
||||||
|
if user != nil {
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.db.GetUser(ctx, username); err == nil {
|
||||||
|
return nil, fmt.Errorf("user %q exists in the DB but hasn't been loaded by the bouncer -- a restart may help", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.Config().EnableUsersOnAuth {
|
||||||
|
return nil, fmt.Errorf("cannot find user %q in the DB", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can't find the user in the DB -- try to create it
|
||||||
|
record := database.User{
|
||||||
|
Username: username,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
user, err := s.createUser(ctx, &record)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) HandleAdmin(ic ircConn) {
|
func (s *Server) HandleAdmin(ic ircConn) {
|
||||||
|
2
user.go
2
user.go
@ -737,7 +737,7 @@ func (u *user) run() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := dc.welcome(ctx); err != nil {
|
if err := dc.welcome(ctx, u); err != nil {
|
||||||
if ircErr, ok := err.(ircError); ok {
|
if ircErr, ok := err.(ircError); ok {
|
||||||
msg := ircErr.Message.Copy()
|
msg := ircErr.Message.Copy()
|
||||||
msg.Prefix = dc.srv.prefix()
|
msg.Prefix = dc.srv.prefix()
|
||||||
|
Loading…
Reference in New Issue
Block a user