parent
eacd4e6f0c
commit
84fe3ae255
@ -33,6 +33,11 @@ func main() {
|
|||||||
cfg.Addr = addr
|
cfg.Addr = addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db, err := jounce.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var ln net.Listener
|
var ln net.Listener
|
||||||
if cfg.TLS != nil {
|
if cfg.TLS != nil {
|
||||||
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
|
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
|
||||||
@ -53,19 +58,16 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := jounce.NewServer()
|
srv := jounce.NewServer(db)
|
||||||
// TODO: load from config/DB
|
// TODO: load from config/DB
|
||||||
srv.Hostname = cfg.Hostname
|
srv.Hostname = cfg.Hostname
|
||||||
srv.Debug = debug
|
srv.Debug = debug
|
||||||
srv.Upstreams = []jounce.Upstream{{
|
|
||||||
Addr: "chat.freenode.net:6697",
|
|
||||||
Nick: "jounce",
|
|
||||||
Username: "jounce",
|
|
||||||
Realname: "jounce",
|
|
||||||
Channels: []string{"#jounce"},
|
|
||||||
}}
|
|
||||||
|
|
||||||
log.Printf("server listening on %q", cfg.Addr)
|
log.Printf("server listening on %q", cfg.Addr)
|
||||||
go srv.Run()
|
go func() {
|
||||||
|
if err := srv.Run(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
log.Fatal(srv.Serve(ln))
|
log.Fatal(srv.Serve(ln))
|
||||||
}
|
}
|
||||||
|
@ -14,9 +14,11 @@ type TLS struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Addr string
|
Addr string
|
||||||
Hostname string
|
Hostname string
|
||||||
TLS *TLS
|
TLS *TLS
|
||||||
|
SQLDriver string
|
||||||
|
SQLSource string
|
||||||
}
|
}
|
||||||
|
|
||||||
func Defaults() *Server {
|
func Defaults() *Server {
|
||||||
@ -25,8 +27,10 @@ func Defaults() *Server {
|
|||||||
hostname = "localhost"
|
hostname = "localhost"
|
||||||
}
|
}
|
||||||
return &Server{
|
return &Server{
|
||||||
Addr: ":6667",
|
Addr: ":6667",
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
|
SQLDriver: "sqlite3",
|
||||||
|
SQLSource: "jounce.db",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,6 +68,10 @@ func Parse(r io.Reader) (*Server, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
srv.TLS = tls
|
srv.TLS = tls
|
||||||
|
case "sql":
|
||||||
|
if err := d.parseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown directive %q", d.Name)
|
return nil, fmt.Errorf("unknown directive %q", d.Name)
|
||||||
}
|
}
|
||||||
|
134
db.go
Normal file
134
db.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package jounce
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrNoSuchUser = errors.New("jounce: no such user")
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Network struct {
|
||||||
|
ID int64
|
||||||
|
Addr string
|
||||||
|
Nick string
|
||||||
|
Username string
|
||||||
|
Realname string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Channel struct {
|
||||||
|
ID int64
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenSQLDB(driver, source string) (*DB, error) {
|
||||||
|
db, err := sql.Open(driver, source)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DB{db: db}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
db.lock.Lock()
|
||||||
|
defer db.lock.Unlock()
|
||||||
|
return db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ListUsers() ([]User, error) {
|
||||||
|
db.lock.Lock()
|
||||||
|
defer db.lock.Unlock()
|
||||||
|
|
||||||
|
rows, err := db.db.Query("SELECT username, password FROM User")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
for rows.Next() {
|
||||||
|
var user User
|
||||||
|
var password *string
|
||||||
|
if err := rows.Scan(&user.Username, &password); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if password != nil {
|
||||||
|
user.Password = *password
|
||||||
|
}
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ListNetworks(username string) ([]Network, error) {
|
||||||
|
db.lock.Lock()
|
||||||
|
defer db.lock.Unlock()
|
||||||
|
|
||||||
|
rows, err := db.db.Query("SELECT id, addr, nick, username, realname FROM Network WHERE user = ?", username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var networks []Network
|
||||||
|
for rows.Next() {
|
||||||
|
var net Network
|
||||||
|
var username, realname *string
|
||||||
|
if err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if username != nil {
|
||||||
|
net.Username = *username
|
||||||
|
}
|
||||||
|
if realname != nil {
|
||||||
|
net.Realname = *realname
|
||||||
|
}
|
||||||
|
networks = append(networks, net)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return networks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
|
||||||
|
db.lock.Lock()
|
||||||
|
defer db.lock.Unlock()
|
||||||
|
|
||||||
|
rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var channels []Channel
|
||||||
|
for rows.Next() {
|
||||||
|
var ch Channel
|
||||||
|
if err := rows.Scan(&ch.ID, &ch.Name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
channels = append(channels, ch)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return channels, nil
|
||||||
|
}
|
@ -58,7 +58,7 @@ type downstreamConn struct {
|
|||||||
nick string
|
nick string
|
||||||
username string
|
username string
|
||||||
realname string
|
realname string
|
||||||
upstream *Upstream
|
network *network // can be nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
|
func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
|
||||||
@ -100,7 +100,7 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
|
|||||||
|
|
||||||
func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
|
func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
|
||||||
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
dc.user.forEachUpstream(func(uc *upstreamConn) {
|
||||||
if dc.upstream != nil && uc.upstream != dc.upstream {
|
if dc.network != nil && uc.network != dc.network {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f(uc)
|
f(uc)
|
||||||
@ -301,9 +301,9 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
|
|||||||
|
|
||||||
func (dc *downstreamConn) register() error {
|
func (dc *downstreamConn) register() error {
|
||||||
username := strings.TrimPrefix(dc.username, "~")
|
username := strings.TrimPrefix(dc.username, "~")
|
||||||
var upstreamName string
|
var networkName string
|
||||||
if i := strings.LastIndexAny(username, "/@"); i >= 0 {
|
if i := strings.LastIndexAny(username, "/@"); i >= 0 {
|
||||||
upstreamName = username[i+1:]
|
networkName = username[i+1:]
|
||||||
}
|
}
|
||||||
if i := strings.IndexAny(username, "/@"); i >= 0 {
|
if i := strings.IndexAny(username, "/@"); i >= 0 {
|
||||||
username = username[:i]
|
username = username[:i]
|
||||||
@ -320,14 +320,14 @@ func (dc *downstreamConn) register() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if upstreamName != "" {
|
if networkName != "" {
|
||||||
dc.upstream = dc.user.getUpstream(upstreamName)
|
dc.network = dc.user.getNetwork(networkName)
|
||||||
if dc.upstream == nil {
|
if dc.network == nil {
|
||||||
dc.logger.Printf("failed registration: unknown upstream %q", upstreamName)
|
dc.logger.Printf("failed registration: unknown network %q", networkName)
|
||||||
dc.SendMessage(&irc.Message{
|
dc.SendMessage(&irc.Message{
|
||||||
Prefix: dc.srv.prefix(),
|
Prefix: dc.srv.prefix(),
|
||||||
Command: irc.ERR_PASSWDMISMATCH,
|
Command: irc.ERR_PASSWDMISMATCH,
|
||||||
Params: []string{"*", fmt.Sprintf("Unknown upstream server %q", upstreamName)},
|
Params: []string{"*", fmt.Sprintf("Unknown network %q", networkName)},
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
5
go.mod
5
go.mod
@ -2,4 +2,7 @@ module git.sr.ht/~emersion/jounce
|
|||||||
|
|
||||||
go 1.13
|
go 1.13
|
||||||
|
|
||||||
require gopkg.in/irc.v3 v3.1.0
|
require (
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||||
|
gopkg.in/irc.v3 v3.1.0
|
||||||
|
)
|
||||||
|
3
go.sum
3
go.sum
@ -1,2 +1,5 @@
|
|||||||
|
github.com/mattn/go-sqlite3 v1.13.0 h1:LnJI81JidiW9r7pS/hXe6cFeO5EXNq7KbfvoJLRI69c=
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||||
gopkg.in/irc.v3 v3.1.0 h1:AeDaEhQ/78gHfpbj/3mSi8FfiNIsFiVrWEgLzOwHWnU=
|
gopkg.in/irc.v3 v3.1.0 h1:AeDaEhQ/78gHfpbj/3mSi8FfiNIsFiVrWEgLzOwHWnU=
|
||||||
gopkg.in/irc.v3 v3.1.0/go.mod h1:qE0DWv0j8Z8wCbFhA9783JBO0bufi3rttcV1Sjin8io=
|
gopkg.in/irc.v3 v3.1.0/go.mod h1:qE0DWv0j8Z8wCbFhA9783JBO0bufi3rttcV1Sjin8io=
|
||||||
|
21
schema.sql
Normal file
21
schema.sql
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
CREATE TABLE User (
|
||||||
|
username VARCHAR(255) PRIMARY KEY,
|
||||||
|
password VARCHAR(255)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE Network (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
user VARCHAR(255) NOT NULL,
|
||||||
|
addr VARCHAR(255) NOT NULL,
|
||||||
|
nick VARCHAR(255) NOT NULL,
|
||||||
|
username VARCHAR(255),
|
||||||
|
realname VARCHAR(255),
|
||||||
|
FOREIGN KEY(user) REFERENCES User(username)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE Channel (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
network INTEGER NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
FOREIGN KEY(network) REFERENCES Network(id)
|
||||||
|
);
|
152
server.go
152
server.go
@ -47,26 +47,73 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) {
|
|||||||
l.logger.Printf("%v"+format, v...)
|
l.logger.Printf("%v"+format, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type network struct {
|
||||||
|
Network
|
||||||
|
user *user
|
||||||
|
conn *upstreamConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNetwork(user *user, record *Network) *network {
|
||||||
|
return &network{
|
||||||
|
Network: *record,
|
||||||
|
user: user,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (net *network) run() {
|
||||||
|
var lastTry time.Time
|
||||||
|
for {
|
||||||
|
if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
|
||||||
|
delay := retryConnectMinDelay - dur
|
||||||
|
net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
|
||||||
|
time.Sleep(delay)
|
||||||
|
}
|
||||||
|
lastTry = time.Now()
|
||||||
|
|
||||||
|
uc, err := connectToUpstream(net)
|
||||||
|
if err != nil {
|
||||||
|
net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
uc.register()
|
||||||
|
|
||||||
|
net.user.lock.Lock()
|
||||||
|
net.conn = uc
|
||||||
|
net.user.lock.Unlock()
|
||||||
|
|
||||||
|
if err := uc.readMessages(); err != nil {
|
||||||
|
uc.logger.Printf("failed to handle messages: %v", err)
|
||||||
|
}
|
||||||
|
uc.Close()
|
||||||
|
|
||||||
|
net.user.lock.Lock()
|
||||||
|
net.conn = nil
|
||||||
|
net.user.lock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type user struct {
|
type user struct {
|
||||||
username string
|
User
|
||||||
srv *Server
|
srv *Server
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
upstreamConns []*upstreamConn
|
networks []*network
|
||||||
downstreamConns []*downstreamConn
|
downstreamConns []*downstreamConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUser(srv *Server, username string) *user {
|
func newUser(srv *Server, record *User) *user {
|
||||||
return &user{
|
return &user{
|
||||||
username: username,
|
User: *record,
|
||||||
srv: srv,
|
srv: srv,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
for _, uc := range u.upstreamConns {
|
for _, network := range u.networks {
|
||||||
if !uc.registered || uc.closed {
|
uc := network.conn
|
||||||
|
if uc == nil || !uc.registered || uc.closed {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
f(uc)
|
f(uc)
|
||||||
@ -82,21 +129,30 @@ func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
|
|||||||
u.lock.Unlock()
|
u.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) getUpstream(name string) *Upstream {
|
func (u *user) getNetwork(name string) *network {
|
||||||
for i, upstream := range u.srv.Upstreams {
|
for _, network := range u.networks {
|
||||||
if upstream.Addr == name {
|
if network.Addr == name {
|
||||||
return &u.srv.Upstreams[i]
|
return network
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Upstream struct {
|
func (u *user) run() {
|
||||||
Addr string
|
networks, err := u.srv.db.ListNetworks(u.Username)
|
||||||
Nick string
|
if err != nil {
|
||||||
Username string
|
u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
|
||||||
Realname string
|
return
|
||||||
Channels []string
|
}
|
||||||
|
|
||||||
|
u.lock.Lock()
|
||||||
|
for _, record := range networks {
|
||||||
|
network := newNetwork(u, &record)
|
||||||
|
u.networks = append(u.networks, network)
|
||||||
|
|
||||||
|
go network.run()
|
||||||
|
}
|
||||||
|
u.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@ -104,18 +160,20 @@ type Server struct {
|
|||||||
Logger Logger
|
Logger Logger
|
||||||
RingCap int
|
RingCap int
|
||||||
Debug bool
|
Debug bool
|
||||||
Upstreams []Upstream // TODO: per-user
|
|
||||||
|
db *DB
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
users map[string]*user
|
users map[string]*user
|
||||||
downstreamConns []*downstreamConn
|
downstreamConns []*downstreamConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer(db *DB) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
Logger: log.New(log.Writer(), "", log.LstdFlags),
|
Logger: log.New(log.Writer(), "", log.LstdFlags),
|
||||||
RingCap: 4096,
|
RingCap: 4096,
|
||||||
users: make(map[string]*user),
|
users: make(map[string]*user),
|
||||||
|
db: db,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,55 +181,23 @@ func (s *Server) prefix() *irc.Prefix {
|
|||||||
return &irc.Prefix{Name: s.Hostname}
|
return &irc.Prefix{Name: s.Hostname}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) runUpstream(u *user, upstream *Upstream) {
|
func (s *Server) Run() error {
|
||||||
var lastTry time.Time
|
users, err := s.db.ListUsers()
|
||||||
for {
|
if err != nil {
|
||||||
if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
|
return err
|
||||||
delay := retryConnectMinDelay - dur
|
|
||||||
s.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), upstream.Addr)
|
|
||||||
time.Sleep(delay)
|
|
||||||
}
|
|
||||||
lastTry = time.Now()
|
|
||||||
|
|
||||||
uc, err := connectToUpstream(u, upstream)
|
|
||||||
if err != nil {
|
|
||||||
s.Logger.Printf("failed to connect to upstream server %q: %v", upstream.Addr, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
uc.register()
|
|
||||||
|
|
||||||
u.lock.Lock()
|
|
||||||
u.upstreamConns = append(u.upstreamConns, uc)
|
|
||||||
u.lock.Unlock()
|
|
||||||
|
|
||||||
if err := uc.readMessages(); err != nil {
|
|
||||||
uc.logger.Printf("failed to handle messages: %v", err)
|
|
||||||
}
|
|
||||||
uc.Close()
|
|
||||||
|
|
||||||
u.lock.Lock()
|
|
||||||
for i := range u.upstreamConns {
|
|
||||||
if u.upstreamConns[i] == uc {
|
|
||||||
u.upstreamConns = append(u.upstreamConns[:i], u.upstreamConns[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
u.lock.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Run() {
|
|
||||||
// TODO: multi-user
|
|
||||||
u := newUser(s, "jounce")
|
|
||||||
|
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
s.users[u.username] = u
|
for _, record := range users {
|
||||||
|
s.Logger.Printf("starting bouncer for user %q", record.Username)
|
||||||
|
u := newUser(s, &record)
|
||||||
|
s.users[u.Username] = u
|
||||||
|
|
||||||
|
go u.run()
|
||||||
|
}
|
||||||
s.lock.Unlock()
|
s.lock.Unlock()
|
||||||
|
|
||||||
for i := range s.Upstreams {
|
select {}
|
||||||
go s.runUpstream(u, &s.Upstreams[i])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getUser(name string) *user {
|
func (s *Server) getUser(name string) *user {
|
||||||
|
52
upstream.go
52
upstream.go
@ -25,7 +25,7 @@ type upstreamChannel struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type upstreamConn struct {
|
type upstreamConn struct {
|
||||||
upstream *Upstream
|
network *network
|
||||||
logger Logger
|
logger Logger
|
||||||
net net.Conn
|
net net.Conn
|
||||||
irc *irc.Conn
|
irc *irc.Conn
|
||||||
@ -41,33 +41,40 @@ type upstreamConn struct {
|
|||||||
|
|
||||||
registered bool
|
registered bool
|
||||||
nick string
|
nick string
|
||||||
|
username string
|
||||||
|
realname string
|
||||||
closed bool
|
closed bool
|
||||||
modes modeSet
|
modes modeSet
|
||||||
channels map[string]*upstreamChannel
|
channels map[string]*upstreamChannel
|
||||||
history map[string]uint64
|
history map[string]uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) {
|
func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||||
logger := &prefixLogger{u.srv.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
|
logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
|
||||||
logger.Printf("connecting to server")
|
|
||||||
|
|
||||||
netConn, err := tls.Dial("tcp", upstream.Addr, nil)
|
addr := network.Addr
|
||||||
|
if !strings.ContainsRune(addr, ':') {
|
||||||
|
addr = addr + ":6697"
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Printf("connecting to TLS server at address %q", addr)
|
||||||
|
netConn, err := tls.Dial("tcp", addr, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
|
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
setKeepAlive(netConn)
|
setKeepAlive(netConn)
|
||||||
|
|
||||||
msgs := make(chan *irc.Message, 64)
|
msgs := make(chan *irc.Message, 64)
|
||||||
uc := &upstreamConn{
|
uc := &upstreamConn{
|
||||||
upstream: upstream,
|
network: network,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
net: netConn,
|
net: netConn,
|
||||||
irc: irc.NewConn(netConn),
|
irc: irc.NewConn(netConn),
|
||||||
srv: u.srv,
|
srv: network.user.srv,
|
||||||
user: u,
|
user: network.user,
|
||||||
messages: msgs,
|
messages: msgs,
|
||||||
ring: NewRing(u.srv.RingCap),
|
ring: NewRing(network.user.srv.RingCap),
|
||||||
channels: make(map[string]*upstreamChannel),
|
channels: make(map[string]*upstreamChannel),
|
||||||
history: make(map[string]uint64),
|
history: make(map[string]uint64),
|
||||||
}
|
}
|
||||||
@ -102,7 +109,7 @@ func (uc *upstreamConn) Close() error {
|
|||||||
|
|
||||||
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
|
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
|
||||||
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
uc.user.forEachDownstream(func(dc *downstreamConn) {
|
||||||
if dc.upstream != nil && dc.upstream != uc.upstream {
|
if dc.network != nil && dc.network != uc.network {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f(dc)
|
f(dc)
|
||||||
@ -163,10 +170,16 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
uc.registered = true
|
uc.registered = true
|
||||||
uc.logger.Printf("connection registered")
|
uc.logger.Printf("connection registered")
|
||||||
|
|
||||||
for _, ch := range uc.upstream.Channels {
|
channels, err := uc.srv.db.ListChannels(uc.network.ID)
|
||||||
|
if err != nil {
|
||||||
|
uc.logger.Printf("failed to list channels from database: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ch := range channels {
|
||||||
uc.SendMessage(&irc.Message{
|
uc.SendMessage(&irc.Message{
|
||||||
Command: "JOIN",
|
Command: "JOIN",
|
||||||
Params: []string{ch},
|
Params: []string{ch.Name},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
case irc.RPL_MYINFO:
|
case irc.RPL_MYINFO:
|
||||||
@ -371,14 +384,23 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (uc *upstreamConn) register() {
|
func (uc *upstreamConn) register() {
|
||||||
uc.nick = uc.upstream.Nick
|
uc.nick = uc.network.Nick
|
||||||
|
uc.username = uc.network.Username
|
||||||
|
if uc.username == "" {
|
||||||
|
uc.username = uc.nick
|
||||||
|
}
|
||||||
|
uc.realname = uc.network.Realname
|
||||||
|
if uc.realname == "" {
|
||||||
|
uc.realname = uc.nick
|
||||||
|
}
|
||||||
|
|
||||||
uc.SendMessage(&irc.Message{
|
uc.SendMessage(&irc.Message{
|
||||||
Command: "NICK",
|
Command: "NICK",
|
||||||
Params: []string{uc.nick},
|
Params: []string{uc.nick},
|
||||||
})
|
})
|
||||||
uc.SendMessage(&irc.Message{
|
uc.SendMessage(&irc.Message{
|
||||||
Command: "USER",
|
Command: "USER",
|
||||||
Params: []string{uc.upstream.Username, "0", "*", uc.upstream.Realname},
|
Params: []string{uc.username, "0", "*", uc.realname},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user