soju/database/database.go
2023-11-01 00:04:57 +01:00

297 lines
6.7 KiB
Go

package database
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v4"
)
type MessageTarget struct {
Name string
LatestMessage time.Time
}
type MessageOptions struct {
AfterID int64
AfterTime time.Time
BeforeTime time.Time
Limit int
Events bool
Sender string
Text string
TakeLast bool
}
type Database interface {
Close() error
Stats(ctx context.Context) (*DatabaseStats, error)
ListUsers(ctx context.Context) ([]User, error)
GetUser(ctx context.Context, username string) (*User, error)
StoreUser(ctx context.Context, user *User) error
DeleteUser(ctx context.Context, id int64) error
ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error)
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
StoreNetwork(ctx context.Context, userID int64, network *Network) error
DeleteNetwork(ctx context.Context, id int64) error
ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
DeleteChannel(ctx context.Context, id int64) error
ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error)
StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error
ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error)
StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error
ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error)
StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error
DeleteWebPushSubscription(ctx context.Context, id int64) error
GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error)
StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error)
ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error)
ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error)
}
type MetricsCollectorDatabase interface {
Database
RegisterMetrics(r prometheus.Registerer) error
}
func Open(driver, source string) (Database, error) {
switch driver {
case "sqlite3":
return OpenSqliteDB(source)
case "postgres":
return OpenPostgresDB(source)
default:
return nil, fmt.Errorf("unsupported database driver: %q", driver)
}
}
type DatabaseStats struct {
Users int64
Networks int64
Channels int64
}
type User struct {
ID int64
Username string
Password string // hashed
Nick string
Realname string
Admin bool
Enabled bool
DownstreamInteractedAt time.Time
}
func NewUser(username string) *User {
return &User{
Username: username,
Enabled: true,
}
}
func (u *User) CheckPassword(password string) (upgraded bool, err error) {
// Password auth disabled
if u.Password == "" {
return false, fmt.Errorf("password auth disabled")
}
err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
if err != nil {
return false, fmt.Errorf("wrong password: %v", err)
}
passCost, err := bcrypt.Cost([]byte(u.Password))
if err != nil {
return false, fmt.Errorf("invalid password cost: %v", err)
}
if passCost < bcrypt.DefaultCost {
return true, u.SetPassword(password)
}
return false, nil
}
func (u *User) SetPassword(password string) error {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
u.Password = string(hashed)
return nil
}
type SASL struct {
Mechanism string
Plain struct {
Username string
Password string
}
// TLS client certificate authentication.
External struct {
// X.509 certificate in DER form.
CertBlob []byte
// PKCS#8 private key in DER form.
PrivKeyBlob []byte
}
}
type Network struct {
ID int64
Name string
Addr string
Nick string
Username string
Realname string
Pass string
ConnectCommands []string
CertFP string
SASL SASL
AutoAway bool
Enabled bool
}
func NewNetwork(addr string) *Network {
return &Network{
Addr: addr,
AutoAway: true,
Enabled: true,
}
}
func (net *Network) GetName() string {
if net.Name != "" {
return net.Name
}
return net.Addr
}
func (net *Network) URL() (*url.URL, error) {
s := net.Addr
if !strings.Contains(s, "://") {
// This is a raw domain name, make it an URL with the default scheme
s = "ircs://" + s
}
u, err := url.Parse(s)
if err != nil {
return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
}
return u, nil
}
func GetNick(user *User, net *Network) string {
if net != nil && net.Nick != "" {
return net.Nick
}
if user.Nick != "" {
return user.Nick
}
return user.Username
}
func GetUsername(user *User, net *Network) string {
if net != nil && net.Username != "" {
return net.Username
}
return GetNick(user, net)
}
func GetRealname(user *User, net *Network) string {
if net != nil && net.Realname != "" {
return net.Realname
}
if user.Realname != "" {
return user.Realname
}
return GetNick(user, net)
}
type MessageFilter int
const (
// TODO: use customizable user defaults for FilterDefault
FilterDefault MessageFilter = iota
FilterNone
FilterHighlight
FilterMessage
)
type Channel struct {
ID int64
Name string
Key string
Detached bool
DetachedInternalMsgID string
RelayDetached MessageFilter
ReattachOn MessageFilter
DetachAfter time.Duration
DetachOn MessageFilter
}
type DeliveryReceipt struct {
ID int64
Target string // channel or nick
Client string
InternalMsgID string
}
type ReadReceipt struct {
ID int64
Target string // channel or nick
Timestamp time.Time
}
type WebPushConfig struct {
ID int64
VAPIDKeys struct {
Public, Private string
}
}
type WebPushSubscription struct {
ID int64
Endpoint string
CreatedAt, UpdatedAt time.Time // read-only
Keys struct {
Auth string
P256DH string
VAPID string
}
}
func toNullString(s string) sql.NullString {
return sql.NullString{
String: s,
Valid: s != "",
}
}
func toNullTime(t time.Time) sql.NullTime {
return sql.NullTime{
Time: t,
Valid: !t.IsZero(),
}
}