soju/database/database.go
Simon Ser d7d9d45b45 Add a flag to disable users
Add a new flag to disable users. This can be useful to temporarily
deactivate an account without erasing data.

The user goroutine is kept alive for simplicity's sake. Most of the
infrastructure assumes that each user always has a running goroutine.
A disabled user's goroutine is responsible for sending back an error
to downstream connections, and listening for potential events to
re-enable the account.
2023-01-26 18:33:55 +01:00

249 lines
5.4 KiB
Go

package database
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/bcrypt"
)
// Database is the storage interface returned by Open. It is implemented by
// the SQLite ("sqlite3") and PostgreSQL ("postgres") backends.
type Database interface {
	Close() error
	// Stats reports aggregate counts; see DatabaseStats.
	Stats(ctx context.Context) (*DatabaseStats, error)

	// User accounts.
	ListUsers(ctx context.Context) ([]User, error)
	GetUser(ctx context.Context, username string) (*User, error)
	StoreUser(ctx context.Context, user *User) error
	DeleteUser(ctx context.Context, id int64) error

	// Networks, listed and stored under the owning user's ID.
	ListNetworks(ctx context.Context, userID int64) ([]Network, error)
	StoreNetwork(ctx context.Context, userID int64, network *Network) error
	DeleteNetwork(ctx context.Context, id int64) error

	// Channels, listed and stored under the owning network's ID.
	ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
	StoreChannel(ctx context.Context, networkID int64, ch *Channel) error
	DeleteChannel(ctx context.Context, id int64) error

	// Delivery and read receipts, keyed by network ID.
	ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
	StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
	GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error)
	StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error

	// Web Push VAPID configuration and per-user/per-network subscriptions.
	ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error)
	StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error
	ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error)
	StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error
	DeleteWebPushSubscription(ctx context.Context, id int64) error
}
type MetricsCollectorDatabase interface {
Database
RegisterMetrics(r prometheus.Registerer) error
}
// Open opens the database at source using the named driver. Only "sqlite3"
// and "postgres" are recognized; any other driver name yields an error.
func Open(driver, source string) (Database, error) {
	if driver == "sqlite3" {
		return OpenSqliteDB(source)
	}
	if driver == "postgres" {
		return OpenPostgresDB(source)
	}
	return nil, fmt.Errorf("unsupported database driver: %q", driver)
}
// DatabaseStats holds the number of users, networks and channels, as
// reported by Database.Stats.
type DatabaseStats struct {
	Users, Networks, Channels int64
}
// User is an account stored in the database.
type User struct {
	ID int64

	Username string
	// Password holds a bcrypt hash of the user's password. An empty value
	// disables password authentication (see CheckPassword).
	Password string

	// Nick and Realname are the user-wide defaults consulted by GetNick and
	// GetRealname when a network does not override them.
	Nick, Realname string

	// Admin marks an administrator account. Enabled is false for accounts
	// that have been disabled without erasing their data.
	Admin, Enabled bool
}
// CheckPassword verifies password against the bcrypt hash stored in
// u.Password.
//
// It returns an error if password authentication is disabled for the user
// (empty u.Password) or if the password does not match. The bcrypt error is
// wrapped, so callers can test it with errors.Is.
//
// On success, if the stored hash was generated with a cost lower than
// bcrypt.DefaultCost, the password is re-hashed at the default cost and
// stored in u.Password. In that case upgraded is true; callers presumably
// need to persist the updated User. If re-hashing fails, upgraded is false
// and the error is returned.
func (u *User) CheckPassword(password string) (upgraded bool, err error) {
	// Password auth disabled
	if u.Password == "" {
		return false, fmt.Errorf("password auth disabled")
	}

	hash := []byte(u.Password)
	if err := bcrypt.CompareHashAndPassword(hash, []byte(password)); err != nil {
		return false, fmt.Errorf("wrong password: %w", err)
	}

	cost, err := bcrypt.Cost(hash)
	if err != nil {
		return false, fmt.Errorf("invalid password cost: %w", err)
	}
	if cost >= bcrypt.DefaultCost {
		return false, nil
	}

	// Only report an upgrade once the stronger hash is actually in place.
	if err := u.SetPassword(password); err != nil {
		return false, err
	}
	return true, nil
}
// SetPassword hashes password with bcrypt at bcrypt.DefaultCost and stores
// the result in u.Password. On error, u is left unchanged and the bcrypt
// error is wrapped so callers can inspect it with errors.Is.
func (u *User) SetPassword(password string) error {
	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}
	u.Password = string(hashed)
	return nil
}
// SASL holds the SASL settings attached to a Network.
type SASL struct {
	// Mechanism names the SASL mechanism in use (presumably selecting
	// between the Plain and External credentials below).
	Mechanism string

	Plain struct {
		Username, Password string
	}

	// External holds the material for TLS client certificate
	// authentication.
	External struct {
		CertBlob    []byte // X.509 certificate, DER-encoded
		PrivKeyBlob []byte // PKCS#8 private key, DER-encoded
	}
}
// Network is an upstream IRC network configured for a user.
type Network struct {
	ID int64

	// Name is an optional display name; GetName falls back to Addr. Addr is
	// either a full URL or a bare host, which URL treats as "ircs".
	Name, Addr string

	// Per-network overrides of the user's defaults (see GetNick, GetUsername
	// and GetRealname). Empty values mean "not overridden".
	Nick, Username, Realname string

	Pass string

	ConnectCommands []string
	CertFP          string
	SASL            SASL

	AutoAway, Enabled bool
}
// GetName returns the network's display name, falling back to its address
// when no explicit name is configured.
func (net *Network) GetName() string {
	name := net.Name
	if name == "" {
		name = net.Addr
	}
	return name
}
// URL parses net.Addr as a URL. An address without a "://" scheme separator
// (e.g. "irc.example.org:6697") is treated as using the "ircs" scheme.
//
// On failure the *url.Error from url.Parse is wrapped, so callers can
// extract it with errors.As.
func (net *Network) URL() (*url.URL, error) {
	s := net.Addr
	if !strings.Contains(s, "://") {
		// This is a raw domain name, make it a URL with the default scheme
		s = "ircs://" + s
	}

	u, err := url.Parse(s)
	if err != nil {
		return nil, fmt.Errorf("failed to parse upstream server URL: %w", err)
	}
	return u, nil
}
// GetNick returns the nickname to use for user on net. It picks the first
// non-empty value among: the network-specific nick, the user's default
// nick, and the username. net may be nil.
func GetNick(user *User, net *Network) string {
	switch {
	case net != nil && net.Nick != "":
		return net.Nick
	case user.Nick != "":
		return user.Nick
	default:
		return user.Username
	}
}
// GetUsername returns the username to use for user on net. It is the
// network-specific username if one is set, otherwise the result of GetNick.
// net may be nil.
func GetUsername(user *User, net *Network) string {
	if net == nil || net.Username == "" {
		return GetNick(user, net)
	}
	return net.Username
}
// GetRealname returns the real name to use for user on net. It picks the
// first non-empty value among: the network-specific real name, the user's
// default real name, and the result of GetNick. net may be nil.
func GetRealname(user *User, net *Network) string {
	if net != nil {
		if rn := net.Realname; rn != "" {
			return rn
		}
	}
	if rn := user.Realname; rn != "" {
		return rn
	}
	return GetNick(user, net)
}
// MessageFilter is the per-channel filter setting used by Channel's
// RelayDetached, ReattachOn and DetachOn fields.
type MessageFilter int

// MessageFilter values. The numbers are spelled out explicitly; they are
// presumably persisted through StoreChannel, so keep them stable.
const (
	// TODO: use customizable user defaults for FilterDefault
	FilterDefault   MessageFilter = 0
	FilterNone      MessageFilter = 1
	FilterHighlight MessageFilter = 2
	FilterMessage   MessageFilter = 3
)
// Channel is a channel saved for a network (see Database.ListChannels).
type Channel struct {
	ID        int64
	Name, Key string

	// Detach state.
	Detached              bool
	DetachedInternalMsgID string

	// Per-channel message filters and detach timing.
	RelayDetached, ReattachOn MessageFilter
	DetachAfter               time.Duration
	DetachOn                  MessageFilter
}
// DeliveryReceipt associates a target and a client with an internal
// message ID (see Database.StoreClientDeliveryReceipts).
type DeliveryReceipt struct {
	ID int64

	// Target is a channel or nick name. Client names the client the receipt
	// belongs to. InternalMsgID identifies the message the receipt refers to.
	Target, Client, InternalMsgID string
}
// ReadReceipt records a timestamp for a target (see
// Database.GetReadReceipt and Database.StoreReadReceipt).
type ReadReceipt struct {
	ID int64

	// Target is the channel or nick the receipt applies to.
	Target string

	// Timestamp is the recorded position (presumably the read marker).
	Timestamp time.Time
}
// WebPushConfig stores a VAPID (Voluntary Application Server
// Identification) key pair used for Web Push.
type WebPushConfig struct {
	ID int64

	VAPIDKeys struct {
		Public  string
		Private string
	}
}
// WebPushSubscription is a Web Push subscription: the push endpoint plus
// the keys associated with it.
type WebPushSubscription struct {
	ID       int64
	Endpoint string

	Keys struct {
		Auth, P256DH, VAPID string
	}
}
// toNullString converts s to a sql.NullString, mapping the empty string to
// SQL NULL (Valid == false).
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}