Commit 2347417f authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Merge branch 'modular' into master

parents 714a0939 709a6857
...@@ -3,3 +3,4 @@ data/*.pem ...@@ -3,3 +3,4 @@ data/*.pem
sfu sfu
passwd passwd
groups/*.json groups/*.json
static/*.d.ts
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
// This is not open source software. Copy it, and I'll break into your // This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist. // house and tell your three year-old that Santa doesn't exist.
package main // Package conn defines interfaces for connections and tracks.
package conn
import ( import (
"errors" "errors"
...@@ -15,29 +16,33 @@ import ( ...@@ -15,29 +16,33 @@ import (
var ErrConnectionClosed = errors.New("connection is closed") var ErrConnectionClosed = errors.New("connection is closed")
var ErrKeyframeNeeded = errors.New("keyframe needed") var ErrKeyframeNeeded = errors.New("keyframe needed")
type upConnection interface { // Type Up represents a connection in the client to server direction.
addLocal(downConnection) error type Up interface {
delLocal(downConnection) bool AddLocal(Down) error
DelLocal(Down) bool
Id() string Id() string
Label() string Label() string
} }
type upTrack interface { // Type UpTrack represents a track in the client to server direction.
addLocal(downTrack) error type UpTrack interface {
delLocal(downTrack) bool AddLocal(DownTrack) error
DelLocal(DownTrack) bool
Label() string Label() string
Codec() *webrtc.RTPCodec Codec() *webrtc.RTPCodec
// get a recent packet. Returns 0 if the packet is not in cache. // get a recent packet. Returns 0 if the packet is not in cache.
getRTP(seqno uint16, result []byte) uint16 GetRTP(seqno uint16, result []byte) uint16
} }
type downConnection interface { // Type Down represents a connection in the server to client direction.
type Down interface {
GetMaxBitrate(now uint64) uint64 GetMaxBitrate(now uint64) uint64
} }
type downTrack interface { // Type DownTrack represents a track in the server to client direction.
type DownTrack interface {
WriteRTP(packat *rtp.Packet) error WriteRTP(packat *rtp.Packet) error
Accumulate(bytes uint32) Accumulate(bytes uint32)
setTimeOffset(ntp uint64, rtp uint32) SetTimeOffset(ntp uint64, rtp uint32)
setCname(string) SetCname(string)
} }
package main package disk
import ( import (
crand "crypto/rand"
"errors" "errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "encoding/hex"
"sync" "sync"
"time" "time"
...@@ -14,10 +15,15 @@ import ( ...@@ -14,10 +15,15 @@ import (
"github.com/pion/rtp/codecs" "github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media/samplebuilder" "github.com/pion/webrtc/v3/pkg/media/samplebuilder"
"sfu/conn"
"sfu/group"
) )
type diskClient struct { var Directory string
group *group
type Client struct {
group *group.Group
id string id string
mu sync.Mutex mu sync.Mutex
...@@ -25,45 +31,37 @@ type diskClient struct { ...@@ -25,45 +31,37 @@ type diskClient struct {
closed bool closed bool
} }
var idCounter struct {
mu sync.Mutex
counter int
}
func newId() string { func newId() string {
idCounter.mu.Lock() b := make([]byte, 16)
defer idCounter.mu.Unlock() crand.Read(b)
return hex.EncodeToString(b)
s := strconv.FormatInt(int64(idCounter.counter), 16)
idCounter.counter++
return s
} }
func NewDiskClient(g *group) *diskClient { func New(g *group.Group) *Client {
return &diskClient{group: g, id: newId()} return &Client{group: g, id: newId()}
} }
func (client *diskClient) Group() *group { func (client *Client) Group() *group.Group {
return client.group return client.group
} }
func (client *diskClient) Id() string { func (client *Client) Id() string {
return client.id return client.id
} }
func (client *diskClient) Credentials() clientCredentials { func (client *Client) Credentials() group.ClientCredentials {
return clientCredentials{"RECORDING", ""} return group.ClientCredentials{"RECORDING", ""}
} }
func (client *diskClient) SetPermissions(perms clientPermissions) { func (client *Client) SetPermissions(perms group.ClientPermissions) {
return return
} }
func (client *diskClient) pushClient(id, username string, add bool) error { func (client *Client) PushClient(id, username string, add bool) error {
return nil return nil
} }
func (client *diskClient) Close() error { func (client *Client) Close() error {
client.mu.Lock() client.mu.Lock()
defer client.mu.Unlock() defer client.mu.Unlock()
...@@ -75,13 +73,13 @@ func (client *diskClient) Close() error { ...@@ -75,13 +73,13 @@ func (client *diskClient) Close() error {
return nil return nil
} }
func (client *diskClient) kick(message string) error { func (client *Client) kick(message string) error {
err := client.Close() err := client.Close()
delClient(client) group.DelClient(client)
return err return err
} }
func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error { func (client *Client) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
client.mu.Lock() client.mu.Lock()
defer client.mu.Unlock() defer client.mu.Unlock()
...@@ -95,11 +93,11 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac ...@@ -95,11 +93,11 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac
delete(client.down, id) delete(client.down, id)
} }
if conn == nil { if up == nil {
return nil return nil
} }
directory := filepath.Join(recordingsDir, client.group.name) directory := filepath.Join(Directory, client.group.Name())
err := os.MkdirAll(directory, 0700) err := os.MkdirAll(directory, 0700)
if err != nil { if err != nil {
return err return err
...@@ -109,12 +107,12 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac ...@@ -109,12 +107,12 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac
client.down = make(map[string]*diskConn) client.down = make(map[string]*diskConn)
} }
down, err := newDiskConn(directory, label, conn, tracks) down, err := newDiskConn(directory, label, up, tracks)
if err != nil { if err != nil {
return err return err
} }
client.down[conn.Id()] = down client.down[up.Id()] = down
return nil return nil
} }
...@@ -125,7 +123,7 @@ type diskConn struct { ...@@ -125,7 +123,7 @@ type diskConn struct {
mu sync.Mutex mu sync.Mutex
file *os.File file *os.File
remote upConnection remote conn.Up
tracks []*diskTrack tracks []*diskTrack
width, height uint32 width, height uint32
} }
...@@ -150,7 +148,7 @@ func (conn *diskConn) reopen() error { ...@@ -150,7 +148,7 @@ func (conn *diskConn) reopen() error {
} }
func (conn *diskConn) Close() error { func (conn *diskConn) Close() error {
conn.remote.delLocal(conn) conn.remote.DelLocal(conn)
conn.mu.Lock() conn.mu.Lock()
tracks := make([]*diskTrack, 0, len(conn.tracks)) tracks := make([]*diskTrack, 0, len(conn.tracks))
...@@ -164,7 +162,7 @@ func (conn *diskConn) Close() error { ...@@ -164,7 +162,7 @@ func (conn *diskConn) Close() error {
conn.mu.Unlock() conn.mu.Unlock()
for _, t := range tracks { for _, t := range tracks {
t.remote.delLocal(t) t.remote.DelLocal(t)
} }
return nil return nil
} }
...@@ -196,7 +194,7 @@ func openDiskFile(directory, label string) (*os.File, error) { ...@@ -196,7 +194,7 @@ func openDiskFile(directory, label string) (*os.File, error) {
} }
type diskTrack struct { type diskTrack struct {
remote upTrack remote conn.UpTrack
conn *diskConn conn *diskConn
writer webm.BlockWriteCloser writer webm.BlockWriteCloser
...@@ -206,7 +204,7 @@ type diskTrack struct { ...@@ -206,7 +204,7 @@ type diskTrack struct {
origin uint64 origin uint64
} }
func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrack) (*diskConn, error) { func newDiskConn(directory, label string, up conn.Up, remoteTracks []conn.UpTrack) (*diskConn, error) {
conn := diskConn{ conn := diskConn{
directory: directory, directory: directory,
label: label, label: label,
...@@ -231,10 +229,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac ...@@ -231,10 +229,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac
conn: &conn, conn: &conn,
} }
conn.tracks = append(conn.tracks, track) conn.tracks = append(conn.tracks, track)
remote.addLocal(track) remote.AddLocal(track)
} }
err := up.addLocal(&conn) err := up.AddLocal(&conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -242,10 +240,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac ...@@ -242,10 +240,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac
return &conn, nil return &conn, nil
} }
func (t *diskTrack) setTimeOffset(ntp uint64, rtp uint32) { func (t *diskTrack) SetTimeOffset(ntp uint64, rtp uint32) {
} }
func (t *diskTrack) setCname(string) { func (t *diskTrack) SetCname(string) {
} }
func clonePacket(packet *rtp.Packet) *rtp.Packet { func clonePacket(packet *rtp.Packet) *rtp.Packet {
...@@ -310,7 +308,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { ...@@ -310,7 +308,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
if t.writer == nil { if t.writer == nil {
if !keyframe { if !keyframe {
return ErrKeyframeNeeded return conn.ErrKeyframeNeeded
} }
return nil return nil
} }
......
package main package group
type clientCredentials struct { import (
"sfu/conn"
)
type ClientCredentials struct {
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`
} }
type clientPermissions struct { type ClientPermissions struct {
Op bool `json:"op,omitempty"` Op bool `json:"op,omitempty"`
Present bool `json:"present,omitempty"` Present bool `json:"present,omitempty"`
Record bool `json:"record,omitempty"` Record bool `json:"record,omitempty"`
} }
type client interface { type Client interface {
Group() *group Group() *Group
Id() string Id() string
Credentials() clientCredentials Credentials() ClientCredentials
SetPermissions(clientPermissions) SetPermissions(ClientPermissions)
pushConn(id string, conn upConnection, tracks []upTrack, label string) error PushConn(id string, conn conn.Up, tracks []conn.UpTrack, label string) error
pushClient(id, username string, add bool) error PushClient(id, username string, add bool) error
} }
type kickable interface { type Kickable interface {
kick(message string) error Kick(message string) error
} }
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
// This is not open source software. Copy it, and I'll break into your // This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist. // house and tell your three year-old that Santa doesn't exist.
package main package group
import ( import (
"encoding/json" "encoding/json"
...@@ -18,18 +18,60 @@ import ( ...@@ -18,18 +18,60 @@ import (
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
) )
type chatHistoryEntry struct { var Directory string
id string
user string type UserError string
kind string
value string func (err UserError) Error() string {
return string(err)
}
type ProtocolError string
func (err ProtocolError) Error() string {
return string(err)
}
var IceFilename string
var iceConf webrtc.Configuration
var iceOnce sync.Once
func IceConfiguration() webrtc.Configuration {
iceOnce.Do(func() {
var iceServers []webrtc.ICEServer
file, err := os.Open(IceFilename)
if err != nil {
log.Printf("Open %v: %v", IceFilename, err)
return
}
defer file.Close()
d := json.NewDecoder(file)
err = d.Decode(&iceServers)
if err != nil {
log.Printf("Get ICE configuration: %v", err)
return
}
iceConf = webrtc.Configuration{
ICEServers: iceServers,
}
})
return iceConf
}
type ChatHistoryEntry struct {
Id string
User string
Kind string
Value string
} }
const ( const (
minBitrate = 200000 MinBitrate = 200000
) )
type group struct { type Group struct {
name string name string
mu sync.Mutex mu sync.Mutex
...@@ -37,35 +79,39 @@ type group struct { ...@@ -37,35 +79,39 @@ type group struct {
// indicates that the group no longer exists, but it still has clients // indicates that the group no longer exists, but it still has clients
dead bool dead bool
locked bool locked bool
clients map[string]client clients map[string]Client
history []chatHistoryEntry history []ChatHistoryEntry
} }
func (g *group) Locked() bool { func (g *Group) Name() string {
return g.name
}
func (g *Group) Locked() bool {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
return g.locked return g.locked
} }
func (g *group) SetLocked(locked bool) { func (g *Group) SetLocked(locked bool) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
g.locked = locked g.locked = locked
} }
func (g *group) Public() bool { func (g *Group) Public() bool {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
return g.description.Public return g.description.Public
} }
func (g *group) Redirect() string { func (g *Group) Redirect() string {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
return g.description.Redirect return g.description.Redirect
} }
func (g *group) AllowRecording() bool { func (g *Group) AllowRecording() bool {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
return g.description.AllowRecording return g.description.AllowRecording
...@@ -73,20 +119,20 @@ func (g *group) AllowRecording() bool { ...@@ -73,20 +119,20 @@ func (g *group) AllowRecording() bool {
var groups struct { var groups struct {
mu sync.Mutex mu sync.Mutex
groups map[string]*group groups map[string]*Group
api *webrtc.API api *webrtc.API
} }
func (g *group) API() *webrtc.API { func (g *Group) API() *webrtc.API {
return groups.api return groups.api
} }
func addGroup(name string, desc *groupDescription) (*group, error) { func Add(name string, desc *groupDescription) (*Group, error) {
groups.mu.Lock() groups.mu.Lock()
defer groups.mu.Unlock() defer groups.mu.Unlock()
if groups.groups == nil { if groups.groups == nil {
groups.groups = make(map[string]*group) groups.groups = make(map[string]*Group)
s := webrtc.SettingEngine{} s := webrtc.SettingEngine{}
m := webrtc.MediaEngine{} m := webrtc.MediaEngine{}
m.RegisterCodec(webrtc.NewRTPVP8CodecExt( m.RegisterCodec(webrtc.NewRTPVP8CodecExt(
...@@ -113,15 +159,15 @@ func addGroup(name string, desc *groupDescription) (*group, error) { ...@@ -113,15 +159,15 @@ func addGroup(name string, desc *groupDescription) (*group, error) {
g := groups.groups[name] g := groups.groups[name]
if g == nil { if g == nil {
if desc == nil { if desc == nil {
desc, err = getDescription(name) desc, err = GetDescription(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
g = &group{ g = &Group{
name: name, name: name,
description: desc, description: desc,
clients: make(map[string]client), clients: make(map[string]Client),
} }
groups.groups[name] = g groups.groups[name] = g
return g, nil return g, nil
...@@ -147,7 +193,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) { ...@@ -147,7 +193,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) {
return nil, err return nil, err
} }
if changed { if changed {
desc, err := getDescription(name) desc, err := GetDescription(name)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
log.Printf("Reading group %v: %v", log.Printf("Reading group %v: %v",
...@@ -167,7 +213,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) { ...@@ -167,7 +213,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) {
return g, nil return g, nil
} }
func rangeGroups(f func(g *group) bool) { func Range(f func(g *Group) bool) {
groups.mu.Lock() groups.mu.Lock()
defer groups.mu.Unlock() defer groups.mu.Unlock()
...@@ -179,17 +225,17 @@ func rangeGroups(f func(g *group) bool) { ...@@ -179,17 +225,17 @@ func rangeGroups(f func(g *group) bool) {
} }
} }
func getGroupNames() []string { func GetNames() []string {
names := make([]string, 0) names := make([]string, 0)
rangeGroups(func(g *group) bool { Range(func(g *Group) bool {
names = append(names, g.name) names = append(names, g.name)
return true return true
}) })
return names return names
} }
func getGroup(name string) *group { func Get(name string) *Group {
groups.mu.Lock() groups.mu.Lock()
defer groups.mu.Unlock() defer groups.mu.Unlock()
...@@ -210,8 +256,8 @@ func delGroupUnlocked(name string) bool { ...@@ -210,8 +256,8 @@ func delGroupUnlocked(name string) bool {
return true return true
} }
func addClient(name string, c client) (*group, error) { func AddClient(name string, c Client) (*Group, error) {
g, err := addGroup(name, nil) g, err := Add(name, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -219,7 +265,7 @@ func addClient(name string, c client) (*group, error) { ...@@ -219,7 +265,7 @@ func addClient(name string, c client) (*group, error) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
perms, err := getPermission(g.description, c.Credentials()) perms, err := g.description.GetPermission(c.Credentials())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -227,37 +273,34 @@ func addClient(name string, c client) (*group, error) { ...@@ -227,37 +273,34 @@ func addClient(name string, c client) (*group, error) {
c.SetPermissions(perms) c.SetPermissions(perms)
if !perms.Op && g.locked { if !perms.Op && g.locked {
return nil, userError("group is locked") return nil, UserError("group is locked")
} }
if !perms.Op && g.description.MaxClients > 0 { if !perms.Op && g.description.MaxClients > 0 {
if len(g.clients) >= g.description.MaxClients { if len(g.clients) >= g.description.MaxClients {
return nil, userError("too many users") return nil, UserError("too many users")
} }
} }
if g.clients[c.Id()] != nil { if g.clients[c.Id()] != nil {
return nil, protocolError("duplicate client id") return nil, ProtocolError("duplicate client id")
} }
g.clients[c.Id()] = c g.clients[c.Id()] = c
go func(clients []client) { go func(clients []Client) {
u := c.Credentials().Username u := c.Credentials().Username
c.pushClient(c.Id(), u, true) c.PushClient(c.Id(), u, true)
for _, cc := range clients { for _, cc := range clients {
uu := cc.Credentials().Username uu := cc.Credentials().Username
err := c.pushClient(cc.Id(), uu, true) c.PushClient(cc.Id(), uu, true)
if err == ErrClientDead { cc.PushClient(c.Id(), u, true)
return
}
cc.pushClient(c.Id(), u, true)
} }
}(g.getClientsUnlocked(c)) }(g.getClientsUnlocked(c))
return g, nil return g, nil
} }
func delClient(c client) { func DelClient(c Client) {
g := c.Group() g := c.Group()
if g == nil { if g == nil {
return return
...@@ -271,21 +314,21 @@ func delClient(c client) { ...@@ -271,21 +314,21 @@ func delClient(c client) {
} }
delete(g.clients, c.Id()) delete(g.clients, c.Id())
go func(clients []client) { go func(clients []Client) {
for _, cc := range clients { for _, cc := range clients {
cc.pushClient(c.Id(), c.Credentials().Username, false) cc.PushClient(c.Id(), c.Credentials().Username, false)
} }
}(g.getClientsUnlocked(nil)) }(g.getClientsUnlocked(nil))
} }
func (g *group) getClients(except client) []client { func (g *Group) GetClients(except Client) []Client {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
return g.getClientsUnlocked(except) return g.getClientsUnlocked(except)
} }
func (g *group) getClientsUnlocked(except client) []client { func (g *Group) getClientsUnlocked(except Client) []Client {
clients := make([]client, 0, len(g.clients)) clients := make([]Client, 0, len(g.clients))
for _, c := range g.clients { for _, c := range g.clients {
if c != except { if c != except {
clients = append(clients, c) clients = append(clients, c)
...@@ -294,13 +337,13 @@ func (g *group) getClientsUnlocked(except client) []client { ...@@ -294,13 +337,13 @@ func (g *group) getClientsUnlocked(except client) []client {
return clients return clients
} }
func (g *group) getClient(id string) client { func (g *Group) GetClient(id string) Client {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
return g.getClientUnlocked(id) return g.getClientUnlocked(id)
} }
func (g *group) getClientUnlocked(id string) client { func (g *Group) getClientUnlocked(id string) Client {
for idd, c := range g.clients { for idd, c := range g.clients {
if idd == id { if idd == id {
return c return c
...@@ -309,7 +352,7 @@ func (g *group) getClientUnlocked(id string) client { ...@@ -309,7 +352,7 @@ func (g *group) getClientUnlocked(id string) client {
return nil return nil
} }
func (g *group) Range(f func(c client) bool) { func (g *Group) Range(f func(c Client) bool) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
for _, c := range g.clients { for _, c := range g.clients {
...@@ -320,11 +363,11 @@ func (g *group) Range(f func(c client) bool) { ...@@ -320,11 +363,11 @@ func (g *group) Range(f func(c client) bool) {
} }
} }
func (g *group) shutdown(message string) { func (g *Group) Shutdown(message string) {
g.Range(func(c client) bool { g.Range(func(c Client) bool {
cc, ok := c.(kickable) cc, ok := c.(Kickable)
if ok { if ok {
cc.kick(message) cc.Kick(message)
} }
return true return true
}) })
...@@ -332,13 +375,13 @@ func (g *group) shutdown(message string) { ...@@ -332,13 +375,13 @@ func (g *group) shutdown(message string) {
const maxChatHistory = 20 const maxChatHistory = 20
func (g *group) clearChatHistory() { func (g *Group) ClearChatHistory() {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
g.history = nil g.history = nil
} }
func (g *group) addToChatHistory(id, user, kind, value string) { func (g *Group) AddToChatHistory(id, user, kind, value string) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
...@@ -347,20 +390,20 @@ func (g *group) addToChatHistory(id, user, kind, value string) { ...@@ -347,20 +390,20 @@ func (g *group) addToChatHistory(id, user, kind, value string) {
g.history = g.history[:len(g.history)-1] g.history = g.history[:len(g.history)-1]
} }
g.history = append(g.history, g.history = append(g.history,
chatHistoryEntry{id: id, user: user, kind: kind, value: value}, ChatHistoryEntry{Id: id, User: user, Kind: kind, Value: value},
) )
} }
func (g *group) getChatHistory() []chatHistoryEntry { func (g *Group) GetChatHistory() []ChatHistoryEntry {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
h := make([]chatHistoryEntry, len(g.history)) h := make([]ChatHistoryEntry, len(g.history))
copy(h, g.history) copy(h, g.history)
return h return h
} }
func matchUser(user clientCredentials, users []clientCredentials) (bool, bool) { func matchUser(user ClientCredentials, users []ClientCredentials) (bool, bool) {
for _, u := range users { for _, u := range users {
if u.Username == "" { if u.Username == "" {
if u.Password == "" || u.Password == user.Password { if u.Password == "" || u.Password == user.Password {
...@@ -383,13 +426,13 @@ type groupDescription struct { ...@@ -383,13 +426,13 @@ type groupDescription struct {
MaxClients int `json:"max-clients,omitempty"` MaxClients int `json:"max-clients,omitempty"`
AllowAnonymous bool `json:"allow-anonymous,omitempty"` AllowAnonymous bool `json:"allow-anonymous,omitempty"`
AllowRecording bool `json:"allow-recording,omitempty"` AllowRecording bool `json:"allow-recording,omitempty"`
Op []clientCredentials `json:"op,omitempty"` Op []ClientCredentials `json:"op,omitempty"`
Presenter []clientCredentials `json:"presenter,omitempty"` Presenter []ClientCredentials `json:"presenter,omitempty"`
Other []clientCredentials `json:"other,omitempty"` Other []ClientCredentials `json:"other,omitempty"`
} }
func descriptionChanged(name string, old *groupDescription) (bool, error) { func descriptionChanged(name string, old *groupDescription) (bool, error) {
fi, err := os.Stat(filepath.Join(groupsDir, name+".json")) fi, err := os.Stat(filepath.Join(Directory, name+".json"))
if err != nil { if err != nil {
return false, err return false, err
} }
...@@ -399,8 +442,8 @@ func descriptionChanged(name string, old *groupDescription) (bool, error) { ...@@ -399,8 +442,8 @@ func descriptionChanged(name string, old *groupDescription) (bool, error) {
return false, err return false, err
} }
func getDescription(name string) (*groupDescription, error) { func GetDescription(name string) (*groupDescription, error) {
r, err := os.Open(filepath.Join(groupsDir, name+".json")) r, err := os.Open(filepath.Join(Directory, name+".json"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -424,10 +467,10 @@ func getDescription(name string) (*groupDescription, error) { ...@@ -424,10 +467,10 @@ func getDescription(name string) (*groupDescription, error) {
return &desc, nil return &desc, nil
} }
func getPermission(desc *groupDescription, creds clientCredentials) (clientPermissions, error) { func (desc *groupDescription) GetPermission (creds ClientCredentials) (ClientPermissions, error) {
var p clientPermissions var p ClientPermissions
if !desc.AllowAnonymous && creds.Username == "" { if !desc.AllowAnonymous && creds.Username == "" {
return p, userError("anonymous users not allowed in this group, please choose a username") return p, UserError("anonymous users not allowed in this group, please choose a username")
} }
if found, good := matchUser(creds, desc.Op); found { if found, good := matchUser(creds, desc.Op); found {
if good { if good {
...@@ -438,34 +481,34 @@ func getPermission(desc *groupDescription, creds clientCredentials) (clientPermi ...@@ -438,34 +481,34 @@ func getPermission(desc *groupDescription, creds clientCredentials) (clientPermi
} }
return p, nil return p, nil
} }
return p, userError("not authorised") return p, UserError("not authorised")
} }
if found, good := matchUser(creds, desc.Presenter); found { if found, good := matchUser(creds, desc.Presenter); found {
if good { if good {
p.Present = true p.Present = true
return p, nil return p, nil
} }
return p, userError("not authorised") return p, UserError("not authorised")
} }
if found, good := matchUser(creds, desc.Other); found { if found, good := matchUser(creds, desc.Other); found {
if good { if good {
return p, nil return p, nil
} }
return p, userError("not authorised") return p, UserError("not authorised")
} }
return p, userError("not authorised") return p, UserError("not authorised")
} }
type publicGroup struct { type Public struct {
Name string `json:"name"` Name string `json:"name"`
ClientCount int `json:"clientCount"` ClientCount int `json:"clientCount"`
} }
func getPublicGroups() []publicGroup { func GetPublic() []Public {
gs := make([]publicGroup, 0) gs := make([]Public, 0)
rangeGroups(func(g *group) bool { Range(func(g *Group) bool {
if g.Public() { if g.Public() {
gs = append(gs, publicGroup{ gs = append(gs, Public{
Name: g.name, Name: g.name,
ClientCount: len(g.clients), ClientCount: len(g.clients),
}) })
...@@ -478,8 +521,8 @@ func getPublicGroups() []publicGroup { ...@@ -478,8 +521,8 @@ func getPublicGroups() []publicGroup {
return gs return gs
} }
func readPublicGroups() { func ReadPublicGroups() {
dir, err := os.Open(groupsDir) dir, err := os.Open(Directory)
if err != nil { if err != nil {
return return
} }
...@@ -496,7 +539,7 @@ func readPublicGroups() { ...@@ -496,7 +539,7 @@ func readPublicGroups() {
continue continue
} }
name := fi.Name()[:len(fi.Name())-5] name := fi.Name()[:len(fi.Name())-5]
desc, err := getDescription(name) desc, err := GetDescription(name)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
log.Printf("Reading group %v: %v", name, err) log.Printf("Reading group %v: %v", name, err)
...@@ -504,7 +547,7 @@ func readPublicGroups() { ...@@ -504,7 +547,7 @@ func readPublicGroups() {
continue continue
} }
if desc.Public { if desc.Public {
addGroup(name, desc) Add(name, desc)
} }
} }
} }
// Copyright (c) 2020 by Juliusz Chroboczek. package rtpconn
// This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist.
package main
import ( import (
"errors" "errors"
...@@ -14,14 +9,16 @@ import ( ...@@ -14,14 +9,16 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
"sfu/conn"
"sfu/estimator" "sfu/estimator"
"sfu/group"
"sfu/jitter" "sfu/jitter"
"sfu/packetcache" "sfu/packetcache"
"sfu/rtptime" "sfu/rtptime"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
) )
type bitrate struct { type bitrate struct {
...@@ -71,7 +68,7 @@ type iceConnection interface { ...@@ -71,7 +68,7 @@ type iceConnection interface {
type rtpDownTrack struct { type rtpDownTrack struct {
track *webrtc.Track track *webrtc.Track
remote upTrack remote conn.UpTrack
maxBitrate *bitrate maxBitrate *bitrate
rate *estimator.Estimator rate *estimator.Estimator
stats *receiverStats stats *receiverStats
...@@ -91,26 +88,26 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) { ...@@ -91,26 +88,26 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) {
down.rate.Accumulate(bytes) down.rate.Accumulate(bytes)
} }
func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) { func (down *rtpDownTrack) SetTimeOffset(ntp uint64, rtp uint32) {
atomic.StoreUint64(&down.remoteNTPTime, ntp) atomic.StoreUint64(&down.remoteNTPTime, ntp)
atomic.StoreUint32(&down.remoteRTPTime, rtp) atomic.StoreUint32(&down.remoteRTPTime, rtp)
} }
func (down *rtpDownTrack) setCname(cname string) { func (down *rtpDownTrack) SetCname(cname string) {
down.cname.Store(cname) down.cname.Store(cname)
} }
type rtpDownConnection struct { type rtpDownConnection struct {
id string id string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
remote upConnection remote conn.Up
tracks []*rtpDownTrack tracks []*rtpDownTrack
maxREMBBitrate *bitrate maxREMBBitrate *bitrate
iceCandidates []*webrtc.ICECandidateInit iceCandidates []*webrtc.ICECandidateInit
} }
func newDownConn(c client, id string, remote upConnection) (*rtpDownConnection, error) { func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) {
pc, err := c.Group().API().NewPeerConnection(iceConfiguration()) pc, err := c.Group().API().NewPeerConnection(group.IceConfiguration())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -193,7 +190,7 @@ type rtpUpTrack struct { ...@@ -193,7 +190,7 @@ type rtpUpTrack struct {
mu sync.Mutex mu sync.Mutex
cname string cname string
local []downTrack local []conn.DownTrack
srTime uint64 srTime uint64
srNTPTime uint64 srNTPTime uint64
srRTPTime uint32 srRTPTime uint32
...@@ -201,17 +198,17 @@ type rtpUpTrack struct { ...@@ -201,17 +198,17 @@ type rtpUpTrack struct {
type localTrackAction struct { type localTrackAction struct {
add bool add bool
track downTrack track conn.DownTrack
} }
func (up *rtpUpTrack) notifyLocal(add bool, track downTrack) { func (up *rtpUpTrack) notifyLocal(add bool, track conn.DownTrack) {
select { select {
case up.localCh <- localTrackAction{add, track}: case up.localCh <- localTrackAction{add, track}:
case <-up.readerDone: case <-up.readerDone:
} }
} }
func (up *rtpUpTrack) addLocal(local downTrack) error { func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
up.mu.Lock() up.mu.Lock()
for _, t := range up.local { for _, t := range up.local {
if t == local { if t == local {
...@@ -226,7 +223,7 @@ func (up *rtpUpTrack) addLocal(local downTrack) error { ...@@ -226,7 +223,7 @@ func (up *rtpUpTrack) addLocal(local downTrack) error {
return nil return nil
} }
func (up *rtpUpTrack) delLocal(local downTrack) bool { func (up *rtpUpTrack) DelLocal(local conn.DownTrack) bool {
up.mu.Lock() up.mu.Lock()
for i, l := range up.local { for i, l := range up.local {
if l == local { if l == local {
...@@ -240,15 +237,15 @@ func (up *rtpUpTrack) delLocal(local downTrack) bool { ...@@ -240,15 +237,15 @@ func (up *rtpUpTrack) delLocal(local downTrack) bool {
return false return false
} }
func (up *rtpUpTrack) getLocal() []downTrack { func (up *rtpUpTrack) getLocal() []conn.DownTrack {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
local := make([]downTrack, len(up.local)) local := make([]conn.DownTrack, len(up.local))
copy(local, up.local) copy(local, up.local)
return local return local
} }
func (up *rtpUpTrack) getRTP(seqno uint16, result []byte) uint16 { func (up *rtpUpTrack) GetRTP(seqno uint16, result []byte) uint16 {
return up.cache.Get(seqno, result) return up.cache.Get(seqno, result)
} }
...@@ -278,7 +275,7 @@ type rtpUpConnection struct { ...@@ -278,7 +275,7 @@ type rtpUpConnection struct {
mu sync.Mutex mu sync.Mutex
tracks []*rtpUpTrack tracks []*rtpUpTrack
local []downConnection local []conn.Down
} }
func (up *rtpUpConnection) getTracks() []*rtpUpTrack { func (up *rtpUpConnection) getTracks() []*rtpUpTrack {
...@@ -297,7 +294,7 @@ func (up *rtpUpConnection) Label() string { ...@@ -297,7 +294,7 @@ func (up *rtpUpConnection) Label() string {
return up.label return up.label
} }
func (up *rtpUpConnection) addLocal(local downConnection) error { func (up *rtpUpConnection) AddLocal(local conn.Down) error {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
for _, t := range up.local { for _, t := range up.local {
...@@ -309,7 +306,7 @@ func (up *rtpUpConnection) addLocal(local downConnection) error { ...@@ -309,7 +306,7 @@ func (up *rtpUpConnection) addLocal(local downConnection) error {
return nil return nil
} }
func (up *rtpUpConnection) delLocal(local downConnection) bool { func (up *rtpUpConnection) DelLocal(local conn.Down) bool {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
for i, l := range up.local { for i, l := range up.local {
...@@ -321,10 +318,10 @@ func (up *rtpUpConnection) delLocal(local downConnection) bool { ...@@ -321,10 +318,10 @@ func (up *rtpUpConnection) delLocal(local downConnection) bool {
return false return false
} }
func (up *rtpUpConnection) getLocal() []downConnection { func (up *rtpUpConnection) getLocal() []conn.Down {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
local := make([]downConnection, len(up.local)) local := make([]conn.Down, len(up.local))
copy(local, up.local) copy(local, up.local)
return local return local
} }
...@@ -370,8 +367,8 @@ func (up *rtpUpConnection) complete() bool { ...@@ -370,8 +367,8 @@ func (up *rtpUpConnection) complete() bool {
return true return true
} }
func newUpConn(c client, id string) (*rtpUpConnection, error) { func newUpConn(c group.Client, id string) (*rtpUpConnection, error) {
pc, err := c.Group().API().NewPeerConnection(iceConfiguration()) pc, err := c.Group().API().NewPeerConnection(group.IceConfiguration())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -396,10 +393,10 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { ...@@ -396,10 +393,10 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
return nil, err return nil, err
} }
conn := &rtpUpConnection{id: id, pc: pc} up := &rtpUpConnection{id: id, pc: pc}
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) { pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
conn.mu.Lock() up.mu.Lock()
mid := getTrackMid(pc, remote) mid := getTrackMid(pc, remote)
if mid == "" { if mid == "" {
...@@ -407,7 +404,7 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { ...@@ -407,7 +404,7 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
return return
} }
label, ok := conn.labels[mid] label, ok := up.labels[mid]
if !ok { if !ok {
log.Printf("Couldn't get track's label") log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
...@@ -428,34 +425,34 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { ...@@ -428,34 +425,34 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
} }
conn.tracks = append(conn.tracks, track) up.tracks = append(up.tracks, track)
go readLoop(conn, track) go readLoop(up, track)
go rtcpUpListener(conn, track, receiver) go rtcpUpListener(up, track, receiver)
complete := conn.complete() complete := up.complete()
var tracks []upTrack var tracks []conn.UpTrack
if(complete) { if complete {
tracks = make([]upTrack, len(conn.tracks)) tracks = make([]conn.UpTrack, len(up.tracks))
for i, t := range conn.tracks { for i, t := range up.tracks {
tracks[i] = t tracks[i] = t
} }
} }
// pushConn might need to take the lock // pushConn might need to take the lock
conn.mu.Unlock() up.mu.Unlock()
if complete { if complete {
clients := c.Group().getClients(c) clients := c.Group().GetClients(c)
for _, cc := range clients { for _, cc := range clients {
cc.pushConn(conn.id, conn, tracks, conn.label) cc.PushConn(up.id, up, tracks, up.label)
} }
go rtcpUpSender(conn) go rtcpUpSender(up)
} }
}) })
return conn, nil return up, nil
} }
func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
...@@ -606,7 +603,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) { ...@@ -606,7 +603,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) {
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
for _, nack := range p.Nacks { for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() { for _, seqno := range nack.PacketList() {
l := track.remote.getRTP(seqno, buf) l := track.remote.GetRTP(seqno, buf)
if l == 0 { if l == 0 {
continue continue
} }
...@@ -650,7 +647,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei ...@@ -650,7 +647,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
track.srRTPTime = p.RTPTime track.srRTPTime = p.RTPTime
track.mu.Unlock() track.mu.Unlock()
for _, l := range local { for _, l := range local {
l.setTimeOffset(p.NTPTime, p.RTPTime) l.SetTimeOffset(p.NTPTime, p.RTPTime)
} }
case *rtcp.SourceDescription: case *rtcp.SourceDescription:
for _, c := range p.Chunks { for _, c := range p.Chunks {
...@@ -665,7 +662,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei ...@@ -665,7 +662,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
track.cname = i.Text track.cname = i.Text
track.mu.Unlock() track.mu.Unlock()
for _, l := range local { for _, l := range local {
l.setCname(i.Text) l.SetCname(i.Text)
} }
} }
} }
...@@ -749,8 +746,8 @@ func sendUpRTCP(conn *rtpUpConnection) error { ...@@ -749,8 +746,8 @@ func sendUpRTCP(conn *rtpUpConnection) error {
rate = r rate = r
} }
} }
if rate < minBitrate { if rate < group.MinBitrate {
rate = minBitrate rate = group.MinBitrate
} }
var ssrcs []uint32 var ssrcs []uint32
......
package main package rtpconn
import ( import (
"sort" "sort"
...@@ -6,76 +6,20 @@ import ( ...@@ -6,76 +6,20 @@ import (
"time" "time"
"sfu/rtptime" "sfu/rtptime"
"sfu/stats"
) )
type groupStats struct { func (c *webClient) GetStats() *stats.Client {
name string
clients []clientStats
}
type clientStats struct {
id string
up, down []connStats
}
type connStats struct {
id string
maxBitrate uint64
tracks []trackStats
}
type trackStats struct {
bitrate uint64
maxBitrate uint64
loss uint8
rtt time.Duration
jitter time.Duration
}
func getGroupStats() []groupStats {
names := getGroupNames()
gs := make([]groupStats, 0, len(names))
for _, name := range names {
g := getGroup(name)
if g == nil {
continue
}
clients := g.getClients(nil)
stats := groupStats{
name: name,
clients: make([]clientStats, 0, len(clients)),
}
for _, c := range clients {
c, ok := c.(*webClient)
if ok {
cs := getClientStats(c)
stats.clients = append(stats.clients, cs)
}
}
sort.Slice(stats.clients, func(i, j int) bool {
return stats.clients[i].id < stats.clients[j].id
})
gs = append(gs, stats)
}
sort.Slice(gs, func(i, j int) bool {
return gs[i].name < gs[j].name
})
return gs
}
func getClientStats(c *webClient) clientStats {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
cs := clientStats{ cs := stats.Client{
id: c.id, Id: c.id,
} }
for _, up := range c.up { for _, up := range c.up {
conns := connStats{ conns := stats.Conn{
id: up.id, Id: up.id,
} }
tracks := up.getTracks() tracks := up.getTracks()
for _, t := range tracks { for _, t := range tracks {
...@@ -87,23 +31,23 @@ func getClientStats(c *webClient) clientStats { ...@@ -87,23 +31,23 @@ func getClientStats(c *webClient) clientStats {
jitter := time.Duration(t.jitter.Jitter()) * jitter := time.Duration(t.jitter.Jitter()) *
(time.Second / time.Duration(t.jitter.HZ())) (time.Second / time.Duration(t.jitter.HZ()))
rate, _ := t.rate.Estimate() rate, _ := t.rate.Estimate()
conns.tracks = append(conns.tracks, trackStats{ conns.Tracks = append(conns.Tracks, stats.Track{
bitrate: uint64(rate) * 8, Bitrate: uint64(rate) * 8,
loss: loss, Loss: loss,
jitter: jitter, Jitter: jitter,
}) })
} }
cs.up = append(cs.up, conns) cs.Up = append(cs.Up, conns)
} }
sort.Slice(cs.up, func(i, j int) bool { sort.Slice(cs.Up, func(i, j int) bool {
return cs.up[i].id < cs.up[j].id return cs.Up[i].Id < cs.Up[j].Id
}) })
jiffies := rtptime.Jiffies() jiffies := rtptime.Jiffies()
for _, down := range c.down { for _, down := range c.down {
conns := connStats{ conns := stats.Conn{
id: down.id, Id: down.id,
maxBitrate: down.GetMaxBitrate(jiffies), MaxBitrate: down.GetMaxBitrate(jiffies),
} }
for _, t := range down.tracks { for _, t := range down.tracks {
rate, _ := t.rate.Estimate() rate, _ := t.rate.Estimate()
...@@ -112,19 +56,19 @@ func getClientStats(c *webClient) clientStats { ...@@ -112,19 +56,19 @@ func getClientStats(c *webClient) clientStats {
loss, jitter := t.stats.Get(jiffies) loss, jitter := t.stats.Get(jiffies)
j := time.Duration(jitter) * time.Second / j := time.Duration(jitter) * time.Second /
time.Duration(t.track.Codec().ClockRate) time.Duration(t.track.Codec().ClockRate)
conns.tracks = append(conns.tracks, trackStats{ conns.Tracks = append(conns.Tracks, stats.Track{
bitrate: uint64(rate) * 8, Bitrate: uint64(rate) * 8,
maxBitrate: t.maxBitrate.Get(jiffies), MaxBitrate: t.maxBitrate.Get(jiffies),
loss: uint8(uint32(loss) * 100 / 256), Loss: uint8(uint32(loss) * 100 / 256),
rtt: rtt, Rtt: rtt,
jitter: j, Jitter: j,
}) })
} }
cs.down = append(cs.down, conns) cs.Down = append(cs.Down, conns)
} }
sort.Slice(cs.down, func(i, j int) bool { sort.Slice(cs.Down, func(i, j int) bool {
return cs.down[i].id < cs.down[j].id return cs.Down[i].Id < cs.Down[j].Id
}) })
return cs return &cs
} }
package main package rtpconn
import ( import (
"errors" "errors"
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"github.com/pion/rtp" "github.com/pion/rtp"
"sfu/conn"
"sfu/packetcache" "sfu/packetcache"
"sfu/rtptime" "sfu/rtptime"
) )
...@@ -43,7 +44,7 @@ func sqrt(n int) int { ...@@ -43,7 +44,7 @@ func sqrt(n int) int {
} }
// add adds or removes a track from a writer pool // add adds or removes a track from a writer pool
func (wp *rtpWriterPool) add(track downTrack, add bool) error { func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error {
n := 4 n := 4
if wp.count > 16 { if wp.count > 16 {
n = sqrt(wp.count) n = sqrt(wp.count)
...@@ -166,7 +167,7 @@ var ErrUnknownTrack = errors.New("unknown track") ...@@ -166,7 +167,7 @@ var ErrUnknownTrack = errors.New("unknown track")
type writerAction struct { type writerAction struct {
add bool add bool
track downTrack track conn.DownTrack
maxTracks int maxTracks int
ch chan error ch chan error
} }
...@@ -192,7 +193,7 @@ func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter { ...@@ -192,7 +193,7 @@ func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter {
} }
// add adds or removes a track from a writer. // add adds or removes a track from a writer.
func (writer *rtpWriter) add(track downTrack, add bool, max int) error { func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error {
ch := make(chan error, 1) ch := make(chan error, 1)
select { select {
case writer.action <- writerAction{add, track, max, ch}: case writer.action <- writerAction{add, track, max, ch}:
...@@ -208,13 +209,13 @@ func (writer *rtpWriter) add(track downTrack, add bool, max int) error { ...@@ -208,13 +209,13 @@ func (writer *rtpWriter) add(track downTrack, add bool, max int) error {
} }
// rtpWriterLoop is the main loop of an rtpWriter. // rtpWriterLoop is the main loop of an rtpWriter.
func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) { func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
defer close(writer.done) defer close(writer.done)
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet var packet rtp.Packet
local := make([]downTrack, 0) local := make([]conn.DownTrack, 0)
// reset whenever a new track is inserted // reset whenever a new track is inserted
firSent := false firSent := false
...@@ -239,10 +240,10 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) ...@@ -239,10 +240,10 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
cname := track.cname cname := track.cname
track.mu.Unlock() track.mu.Unlock()
if ntp != 0 { if ntp != 0 {
action.track.setTimeOffset(ntp, rtp) action.track.SetTimeOffset(ntp, rtp)
} }
if cname != "" { if cname != "" {
action.track.setCname(cname) action.track.SetCname(cname)
} }
} else { } else {
found := false found := false
...@@ -283,7 +284,7 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) ...@@ -283,7 +284,7 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
for _, l := range local { for _, l := range local {
err := l.WriteRTP(&packet) err := l.WriteRTP(&packet)
if err != nil { if err != nil {
if err == ErrKeyframeNeeded { if err == conn.ErrKeyframeNeeded {
kfNeeded = true kfNeeded = true
} }
continue continue
...@@ -292,9 +293,9 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) ...@@ -292,9 +293,9 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
} }
if kfNeeded { if kfNeeded {
err := conn.sendFIR(track, !firSent) err := up.sendFIR(track, !firSent)
if err == ErrUnsupportedFeedback { if err == ErrUnsupportedFeedback {
conn.sendPLI(track) up.sendPLI(track)
} }
firSent = true firSent = true
} }
......
// Copyright (c) 2020 by Juliusz Chroboczek. package rtpconn
// This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist.
package main
import ( import (
"encoding/json" "encoding/json"
...@@ -13,49 +8,14 @@ import ( ...@@ -13,49 +8,14 @@ import (
"sync" "sync"
"time" "time"
"sfu/estimator"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
)
var iceConf webrtc.Configuration
var iceOnce sync.Once
func iceConfiguration() webrtc.Configuration {
iceOnce.Do(func() {
var iceServers []webrtc.ICEServer
file, err := os.Open(iceFilename)
if err != nil {
log.Printf("Open %v: %v", iceFilename, err)
return
}
defer file.Close()
d := json.NewDecoder(file)
err = d.Decode(&iceServers)
if err != nil {
log.Printf("Get ICE configuration: %v", err)
return
}
iceConf = webrtc.Configuration{
ICEServers: iceServers,
}
})
return iceConf
}
type protocolError string
func (err protocolError) Error() string { "sfu/conn"
return string(err) "sfu/disk"
} "sfu/estimator"
"sfu/group"
type userError string )
func (err userError) Error() string {
return string(err)
}
func errorToWSCloseMessage(err error) (string, []byte) { func errorToWSCloseMessage(err error) (string, []byte) {
var code int var code int
...@@ -63,10 +23,10 @@ func errorToWSCloseMessage(err error) (string, []byte) { ...@@ -63,10 +23,10 @@ func errorToWSCloseMessage(err error) (string, []byte) {
switch e := err.(type) { switch e := err.(type) {
case *websocket.CloseError: case *websocket.CloseError:
code = websocket.CloseNormalClosure code = websocket.CloseNormalClosure
case protocolError: case group.ProtocolError:
code = websocket.CloseProtocolError code = websocket.CloseProtocolError
text = string(e) text = string(e)
case userError: case group.UserError:
code = websocket.CloseNormalClosure code = websocket.CloseNormalClosure
text = string(e) text = string(e)
default: default:
...@@ -82,10 +42,10 @@ func isWSNormalError(err error) bool { ...@@ -82,10 +42,10 @@ func isWSNormalError(err error) bool {
} }
type webClient struct { type webClient struct {
group *group group *group.Group
id string id string
credentials clientCredentials credentials group.ClientCredentials
permissions clientPermissions permissions group.ClientPermissions
requested map[string]uint32 requested map[string]uint32
done chan struct{} done chan struct{}
writeCh chan interface{} writeCh chan interface{}
...@@ -97,7 +57,7 @@ type webClient struct { ...@@ -97,7 +57,7 @@ type webClient struct {
up map[string]*rtpUpConnection up map[string]*rtpUpConnection
} }
func (c *webClient) Group() *group { func (c *webClient) Group() *group.Group {
return c.group return c.group
} }
...@@ -105,15 +65,15 @@ func (c *webClient) Id() string { ...@@ -105,15 +65,15 @@ func (c *webClient) Id() string {
return c.id return c.id
} }
func (c *webClient) Credentials() clientCredentials { func (c *webClient) Credentials() group.ClientCredentials {
return c.credentials return c.credentials
} }
func (c *webClient) SetPermissions(perms clientPermissions) { func (c *webClient) SetPermissions(perms group.ClientPermissions) {
c.permissions = perms c.permissions = perms
} }
func (c *webClient) pushClient(id, username string, add bool) error { func (c *webClient) PushClient(id, username string, add bool) error {
kind := "add" kind := "add"
if !add { if !add {
kind = "delete" kind = "delete"
...@@ -179,7 +139,7 @@ type clientMessage struct { ...@@ -179,7 +139,7 @@ type clientMessage struct {
Id string `json:"id,omitempty"` Id string `json:"id,omitempty"`
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`
Permissions clientPermissions `json:"permissions,omitempty"` Permissions group.ClientPermissions `json:"permissions,omitempty"`
Group string `json:"group,omitempty"` Group string `json:"group,omitempty"`
Value string `json:"value,omitempty"` Value string `json:"value,omitempty"`
Offer *webrtc.SessionDescription `json:"offer,omitempty"` Offer *webrtc.SessionDescription `json:"offer,omitempty"`
...@@ -263,11 +223,11 @@ func delUpConn(c *webClient, id string) bool { ...@@ -263,11 +223,11 @@ func delUpConn(c *webClient, id string) bool {
delete(c.up, id) delete(c.up, id)
c.mu.Unlock() c.mu.Unlock()
go func(clients []client) { go func(clients []group.Client) {
for _, c := range clients { for _, c := range clients {
c.pushConn(conn.id, nil, nil, "") c.PushConn(conn.id, nil, nil, "")
} }
}(c.Group().getClients(c)) }(c.Group().GetClients(c))
conn.pc.Close() conn.pc.Close()
return true return true
...@@ -299,7 +259,7 @@ func getConn(c *webClient, id string) iceConnection { ...@@ -299,7 +259,7 @@ func getConn(c *webClient, id string) iceConnection {
return nil return nil
} }
func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnection, error) { func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) {
conn, err := newDownConn(c, id, remote) conn, err := newDownConn(c, id, remote)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -332,7 +292,7 @@ func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnecti ...@@ -332,7 +292,7 @@ func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnecti
} }
}) })
err = remote.addLocal(conn) err = remote.AddLocal(conn)
if err != nil { if err != nil {
conn.pc.Close() conn.pc.Close()
return nil, err return nil, err
...@@ -354,18 +314,18 @@ func delDownConn(c *webClient, id string) bool { ...@@ -354,18 +314,18 @@ func delDownConn(c *webClient, id string) bool {
return false return false
} }
conn.remote.delLocal(conn) conn.remote.DelLocal(conn)
for _, track := range conn.tracks { for _, track := range conn.tracks {
// we only insert the track after we get an answer, so // we only insert the track after we get an answer, so
// ignore errors here. // ignore errors here.
track.remote.delLocal(track) track.remote.DelLocal(track)
} }
conn.pc.Close() conn.pc.Close()
delete(c.down, id) delete(c.down, id)
return true return true
} }
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, remoteConn upConnection) (*webrtc.RTPSender, error) { func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) {
var pt uint8 var pt uint8
var ssrc uint32 var ssrc uint32
var id, label string var id, label string
...@@ -510,7 +470,7 @@ func gotOffer(c *webClient, id string, offer webrtc.SessionDescription, renegoti ...@@ -510,7 +470,7 @@ func gotOffer(c *webClient, id string, offer webrtc.SessionDescription, renegoti
func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error { func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error {
down := getDownConn(c, id) down := getDownConn(c, id)
if down == nil { if down == nil {
return protocolError("unknown id in answer") return group.ProtocolError("unknown id in answer")
} }
err := down.pc.SetRemoteDescription(answer) err := down.pc.SetRemoteDescription(answer)
if err != nil { if err != nil {
...@@ -523,7 +483,7 @@ func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error ...@@ -523,7 +483,7 @@ func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error
} }
for _, t := range down.tracks { for _, t := range down.tracks {
t.remote.addLocal(t) t.remote.AddLocal(t)
} }
return nil return nil
} }
...@@ -553,8 +513,8 @@ func (c *webClient) setRequested(requested map[string]uint32) error { ...@@ -553,8 +513,8 @@ func (c *webClient) setRequested(requested map[string]uint32) error {
return nil return nil
} }
func pushConns(c client) { func pushConns(c group.Client) {
clients := c.Group().getClients(c) clients := c.Group().GetClients(c)
for _, cc := range clients { for _, cc := range clients {
ccc, ok := cc.(*webClient) ccc, ok := cc.(*webClient)
if ok { if ok {
...@@ -567,7 +527,7 @@ func (c *webClient) isRequested(label string) bool { ...@@ -567,7 +527,7 @@ func (c *webClient) isRequested(label string) bool {
return c.requested[label] != 0 return c.requested[label] != 0
} }
func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rtpDownConnection, error) { func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) {
requested := false requested := false
for _, t := range tracks { for _, t := range tracks {
if c.isRequested(t.Label()) { if c.isRequested(t.Label()) {
...@@ -600,13 +560,13 @@ func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rt ...@@ -600,13 +560,13 @@ func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rt
return down, nil return down, nil
} }
func (c *webClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error { func (c *webClient) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
err := c.action(pushConnAction{id, conn, tracks}) err := c.action(pushConnAction{id, up, tracks})
if err != nil { if err != nil {
return err return err
} }
if conn != nil && label != "" { if up != nil && label != "" {
err := c.action(addLabelAction{conn.Id(), conn.Label()}) err := c.action(addLabelAction{up.Id(), up.Label()})
if err != nil { if err != nil {
return err return err
} }
...@@ -614,7 +574,7 @@ func (c *webClient) pushConn(id string, conn upConnection, tracks []upTrack, lab ...@@ -614,7 +574,7 @@ func (c *webClient) pushConn(id string, conn upConnection, tracks []upTrack, lab
return nil return nil
} }
func startClient(conn *websocket.Conn) (err error) { func StartClient(conn *websocket.Conn) (err error) {
var m clientMessage var m clientMessage
err = conn.SetReadDeadline(time.Now().Add(15 * time.Second)) err = conn.SetReadDeadline(time.Now().Add(15 * time.Second))
...@@ -646,7 +606,7 @@ func startClient(conn *websocket.Conn) (err error) { ...@@ -646,7 +606,7 @@ func startClient(conn *websocket.Conn) (err error) {
c := &webClient{ c := &webClient{
id: m.Id, id: m.Id,
credentials: clientCredentials{ credentials: group.ClientCredentials{
m.Username, m.Username,
m.Password, m.Password,
}, },
...@@ -683,32 +643,32 @@ func startClient(conn *websocket.Conn) (err error) { ...@@ -683,32 +643,32 @@ func startClient(conn *websocket.Conn) (err error) {
} }
if m.Type != "join" { if m.Type != "join" {
return protocolError("you must join a group first") return group.ProtocolError("you must join a group first")
} }
g, err := addClient(m.Group, c) g, err := group.AddClient(m.Group, c)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
err = userError("group does not exist") err = group.UserError("group does not exist")
} }
return return
} }
if redirect := g.Redirect(); redirect != "" { if redirect := g.Redirect(); redirect != "" {
// We normally redirect at the HTTP level, but the group // We normally redirect at the HTTP level, but the group
// description could have been edited in the meantime. // description could have been edited in the meantime.
err = userError("group is now at " + redirect) err = group.UserError("group is now at " + redirect)
return return
} }
c.group = g c.group = g
defer delClient(c) defer group.DelClient(c)
return clientLoop(c, conn) return clientLoop(c, conn)
} }
type pushConnAction struct { type pushConnAction struct {
id string id string
conn upConnection conn conn.Up
tracks []upTrack tracks []conn.UpTrack
} }
type addLabelAction struct { type addLabelAction struct {
...@@ -717,7 +677,7 @@ type addLabelAction struct { ...@@ -717,7 +677,7 @@ type addLabelAction struct {
} }
type pushConnsAction struct { type pushConnsAction struct {
c client c group.Client
} }
type connectionFailedAction struct { type connectionFailedAction struct {
...@@ -730,9 +690,9 @@ type kickAction struct { ...@@ -730,9 +690,9 @@ type kickAction struct {
message string message string
} }
func clientLoop(c *webClient, conn *websocket.Conn) error { func clientLoop(c *webClient, ws *websocket.Conn) error {
read := make(chan interface{}, 1) read := make(chan interface{}, 1)
go clientReader(conn, read, c.done) go clientReader(ws, read, c.done)
defer func() { defer func() {
c.setRequested(map[string]uint32{}) c.setRequested(map[string]uint32{})
...@@ -748,14 +708,14 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -748,14 +708,14 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
Permissions: c.permissions, Permissions: c.permissions,
}) })
h := c.group.getChatHistory() h := c.group.GetChatHistory()
for _, m := range h { for _, m := range h {
err := c.write(clientMessage{ err := c.write(clientMessage{
Type: "chat", Type: "chat",
Id: m.id, Id: m.Id,
Username: m.user, Username: m.User,
Value: m.value, Value: m.Value,
Kind: m.kind, Kind: m.Kind,
}) })
if err != nil { if err != nil {
return err return err
...@@ -829,11 +789,11 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -829,11 +789,11 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
case pushConnsAction: case pushConnsAction:
for _, u := range c.up { for _, u := range c.up {
tracks := u.getTracks() tracks := u.getTracks()
ts := make([]upTrack, len(tracks)) ts := make([]conn.UpTrack, len(tracks))
for i, t := range tracks { for i, t := range tracks {
ts[i] = t ts[i] = t
} }
go a.c.pushConn(u.id, u, ts, u.label) go a.c.PushConn(u.id, u, ts, u.label)
} }
case connectionFailedAction: case connectionFailedAction:
if down := getDownConn(c, a.id); down != nil { if down := getDownConn(c, a.id); down != nil {
...@@ -842,12 +802,12 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -842,12 +802,12 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
return err return err
} }
tracks := make( tracks := make(
[]upTrack, len(down.tracks), []conn.UpTrack, len(down.tracks),
) )
for i, t := range down.tracks { for i, t := range down.tracks {
tracks[i] = t.remote tracks[i] = t.remote
} }
go c.pushConn( go c.PushConn(
down.remote.Id(), down.remote, down.remote.Id(), down.remote,
tracks, down.remote.Label(), tracks, down.remote.Label(),
) )
...@@ -879,7 +839,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -879,7 +839,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
} }
} }
case kickAction: case kickAction:
return userError(a.message) return group.UserError(a.message)
default: default:
log.Printf("unexpected action %T", a) log.Printf("unexpected action %T", a)
return errors.New("unexpected action") return errors.New("unexpected action")
...@@ -911,7 +871,7 @@ func failConnection(c *webClient, id string, message string) error { ...@@ -911,7 +871,7 @@ func failConnection(c *webClient, id string, message string) error {
} }
} }
if message != "" { if message != "" {
err := c.error(userError(message)) err := c.error(group.UserError(message))
if err != nil { if err != nil {
return err return err
} }
...@@ -919,15 +879,15 @@ func failConnection(c *webClient, id string, message string) error { ...@@ -919,15 +879,15 @@ func failConnection(c *webClient, id string, message string) error {
return nil return nil
} }
func setPermissions(g *group, id string, perm string) error { func setPermissions(g *group.Group, id string, perm string) error {
client := g.getClient(id) client := g.GetClient(id)
if client == nil { if client == nil {
return userError("no such user") return group.UserError("no such user")
} }
c, ok := client.(*webClient) c, ok := client.(*webClient)
if !ok { if !ok {
return userError("this is not a real user") return group.UserError("this is not a real user")
} }
switch perm { switch perm {
...@@ -944,7 +904,7 @@ func setPermissions(g *group, id string, perm string) error { ...@@ -944,7 +904,7 @@ func setPermissions(g *group, id string, perm string) error {
case "unpresent": case "unpresent":
c.permissions.Present = false c.permissions.Present = false
default: default:
return userError("unknown permission") return group.UserError("unknown permission")
} }
return c.action(permissionsChangedAction{}) return c.action(permissionsChangedAction{})
} }
...@@ -953,18 +913,18 @@ func (c *webClient) kick(message string) error { ...@@ -953,18 +913,18 @@ func (c *webClient) kick(message string) error {
return c.action(kickAction{message}) return c.action(kickAction{message})
} }
func kickClient(g *group, id string, message string) error { func kickClient(g *group.Group, id string, message string) error {
client := g.getClient(id) client := g.GetClient(id)
if client == nil { if client == nil {
return userError("no such user") return group.UserError("no such user")
} }
c, ok := client.(kickable) c, ok := client.(group.Kickable)
if !ok { if !ok {
return userError("this client is not kickable") return group.UserError("this client is not kickable")
} }
return c.kick(message) return c.Kick(message)
} }
func handleClientMessage(c *webClient, m clientMessage) error { func handleClientMessage(c *webClient, m clientMessage) error {
...@@ -980,10 +940,10 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -980,10 +940,10 @@ func handleClientMessage(c *webClient, m clientMessage) error {
Type: "abort", Type: "abort",
Id: m.Id, Id: m.Id,
}) })
return c.error(userError("not authorised")) return c.error(group.UserError("not authorised"))
} }
if m.Offer == nil { if m.Offer == nil {
return protocolError("null offer") return group.ProtocolError("null offer")
} }
err := gotOffer( err := gotOffer(
c, m.Id, *m.Offer, m.Kind == "renegotiate", m.Labels, c, m.Id, *m.Offer, m.Kind == "renegotiate", m.Labels,
...@@ -994,7 +954,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -994,7 +954,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
} }
case "answer": case "answer":
if m.Answer == nil { if m.Answer == nil {
return protocolError("null answer") return group.ProtocolError("null answer")
} }
err := gotAnswer(c, m.Id, *m.Answer) err := gotAnswer(c, m.Id, *m.Answer)
if err != nil { if err != nil {
...@@ -1017,15 +977,15 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1017,15 +977,15 @@ func handleClientMessage(c *webClient, m clientMessage) error {
} }
case "ice": case "ice":
if m.Candidate == nil { if m.Candidate == nil {
return protocolError("null candidate") return group.ProtocolError("null candidate")
} }
err := gotICE(c, m.Candidate, m.Id) err := gotICE(c, m.Candidate, m.Id)
if err != nil { if err != nil {
log.Printf("ICE: %v", err) log.Printf("ICE: %v", err)
} }
case "chat": case "chat":
c.group.addToChatHistory(m.Id, m.Username, m.Kind, m.Value) c.group.AddToChatHistory(m.Id, m.Username, m.Kind, m.Value)
clients := c.group.getClients(nil) clients := c.group.GetClients(nil)
for _, cc := range clients { for _, cc := range clients {
cc, ok := cc.(*webClient) cc, ok := cc.(*webClient)
if ok { if ok {
...@@ -1035,9 +995,9 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1035,9 +995,9 @@ func handleClientMessage(c *webClient, m clientMessage) error {
case "groupaction": case "groupaction":
switch m.Kind { switch m.Kind {
case "clearchat": case "clearchat":
c.group.clearChatHistory() c.group.ClearChatHistory()
m := clientMessage{Type: "clearchat"} m := clientMessage{Type: "clearchat"}
clients := c.group.getClients(nil) clients := c.group.GetClients(nil)
for _, cc := range clients { for _, cc := range clients {
cc, ok := cc.(*webClient) cc, ok := cc.(*webClient)
if ok { if ok {
...@@ -1046,21 +1006,21 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1046,21 +1006,21 @@ func handleClientMessage(c *webClient, m clientMessage) error {
} }
case "lock", "unlock": case "lock", "unlock":
if !c.permissions.Op { if !c.permissions.Op {
return c.error(userError("not authorised")) return c.error(group.UserError("not authorised"))
} }
c.group.SetLocked(m.Kind == "lock") c.group.SetLocked(m.Kind == "lock")
case "record": case "record":
if !c.permissions.Record { if !c.permissions.Record {
return c.error(userError("not authorised")) return c.error(group.UserError("not authorised"))
} }
for _, cc := range c.group.getClients(c) { for _, cc := range c.group.GetClients(c) {
_, ok := cc.(*diskClient) _, ok := cc.(*disk.Client)
if ok { if ok {
return c.error(userError("already recording")) return c.error(group.UserError("already recording"))
} }
} }
disk := NewDiskClient(c.group) disk := disk.New(c.group)
_, err := addClient(c.group.name, disk) _, err := group.AddClient(c.group.Name(), disk)
if err != nil { if err != nil {
disk.Close() disk.Close()
return c.error(err) return c.error(err)
...@@ -1068,23 +1028,23 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1068,23 +1028,23 @@ func handleClientMessage(c *webClient, m clientMessage) error {
go pushConns(disk) go pushConns(disk)
case "unrecord": case "unrecord":
if !c.permissions.Record { if !c.permissions.Record {
return c.error(userError("not authorised")) return c.error(group.UserError("not authorised"))
} }
for _, cc := range c.group.getClients(c) { for _, cc := range c.group.GetClients(c) {
disk, ok := cc.(*diskClient) disk, ok := cc.(*disk.Client)
if ok { if ok {
disk.Close() disk.Close()
delClient(disk) group.DelClient(disk)
} }
} }
default: default:
return protocolError("unknown group action") return group.ProtocolError("unknown group action")
} }
case "useraction": case "useraction":
switch m.Kind { switch m.Kind {
case "op", "unop", "present", "unpresent": case "op", "unop", "present", "unpresent":
if !c.permissions.Op { if !c.permissions.Op {
return c.error(userError("not authorised")) return c.error(group.UserError("not authorised"))
} }
err := setPermissions(c.group, m.Id, m.Kind) err := setPermissions(c.group, m.Id, m.Kind)
if err != nil { if err != nil {
...@@ -1092,7 +1052,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1092,7 +1052,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
} }
case "kick": case "kick":
if !c.permissions.Op { if !c.permissions.Op {
return c.error(userError("not authorised")) return c.error(group.UserError("not authorised"))
} }
message := m.Value message := m.Value
if message == "" { if message == "" {
...@@ -1103,7 +1063,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1103,7 +1063,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
return c.error(err) return c.error(err)
} }
default: default:
return protocolError("unknown user action") return group.ProtocolError("unknown user action")
} }
case "pong": case "pong":
// nothing // nothing
...@@ -1113,7 +1073,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1113,7 +1073,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
}) })
default: default:
log.Printf("unexpected message: %v", m.Type) log.Printf("unexpected message: %v", m.Type)
return protocolError("unexpected message") return group.ProtocolError("unexpected message")
} }
return nil return nil
} }
...@@ -1207,7 +1167,7 @@ func (c *webClient) close(data []byte) error { ...@@ -1207,7 +1167,7 @@ func (c *webClient) close(data []byte) error {
func (c *webClient) error(err error) error { func (c *webClient) error(err error) error {
switch e := err.(type) { switch e := err.(type) {
case userError: case group.UserError:
return c.write(clientMessage{ return c.write(clientMessage{
Type: "usermessage", Type: "usermessage",
Kind: "error", Kind: "error",
......
...@@ -14,14 +14,14 @@ import ( ...@@ -14,14 +14,14 @@ import (
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"syscall" "syscall"
"sfu/disk"
"sfu/group"
) )
var httpAddr string var httpAddr string
var staticRoot string var staticRoot string
var dataDir string var dataDir string
var groupsDir string
var recordingsDir string
var iceFilename string
func main() { func main() {
var cpuprofile, memprofile, mutexprofile string var cpuprofile, memprofile, mutexprofile string
...@@ -31,9 +31,9 @@ func main() { ...@@ -31,9 +31,9 @@ func main() {
"web server root `directory`") "web server root `directory`")
flag.StringVar(&dataDir, "data", "./data/", flag.StringVar(&dataDir, "data", "./data/",
"data `directory`") "data `directory`")
flag.StringVar(&groupsDir, "groups", "./groups/", flag.StringVar(&group.Directory, "groups", "./groups/",
"group description `directory`") "group description `directory`")
flag.StringVar(&recordingsDir, "recordings", "./recordings/", flag.StringVar(&disk.Directory, "recordings", "./recordings/",
"recordings `directory`") "recordings `directory`")
flag.StringVar(&cpuprofile, "cpuprofile", "", flag.StringVar(&cpuprofile, "cpuprofile", "",
"store CPU profile in `file`") "store CPU profile in `file`")
...@@ -81,9 +81,9 @@ func main() { ...@@ -81,9 +81,9 @@ func main() {
}() }()
} }
iceFilename = filepath.Join(dataDir, "ice-servers.json") group.IceFilename = filepath.Join(dataDir, "ice-servers.json")
go readPublicGroups() go group.ReadPublicGroups()
webserver() webserver()
terminate := make(chan os.Signal, 1) terminate := make(chan os.Signal, 1)
......
package stats
import (
"sort"
"time"
"sfu/group"
)
type GroupStats struct {
Name string
Clients []*Client
}
type Client struct {
Id string
Up, Down []Conn
}
type Statable interface {
GetStats() *Client
}
type Conn struct {
Id string
MaxBitrate uint64
Tracks []Track
}
type Track struct {
Bitrate uint64
MaxBitrate uint64
Loss uint8
Rtt time.Duration
Jitter time.Duration
}
func GetGroups() []GroupStats {
names := group.GetNames()
gs := make([]GroupStats, 0, len(names))
for _, name := range names {
g := group.Get(name)
if g == nil {
continue
}
clients := g.GetClients(nil)
stats := GroupStats{
Name: name,
Clients: make([]*Client, 0, len(clients)),
}
for _, c := range clients {
s, ok := c.(Statable)
if ok {
cs := s.GetStats()
stats.Clients = append(stats.Clients, cs)
} else {
stats.Clients = append(stats.Clients,
&Client{Id: c.Id()},
)
}
}
sort.Slice(stats.Clients, func(i, j int) bool {
return stats.Clients[i].Id < stats.Clients[j].Id
})
gs = append(gs, stats)
}
sort.Slice(gs, func(i, j int) bool {
return gs[i].Name < gs[j].Name
})
return gs
}
...@@ -18,6 +18,11 @@ import ( ...@@ -18,6 +18,11 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"sfu/disk"
"sfu/group"
"sfu/rtpconn"
"sfu/stats"
) )
var server *http.Server var server *http.Server
...@@ -47,8 +52,8 @@ func webserver() { ...@@ -47,8 +52,8 @@ func webserver() {
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
} }
server.RegisterOnShutdown(func() { server.RegisterOnShutdown(func() {
rangeGroups(func (g *group) bool { group.Range(func(g *group.Group) bool {
go g.shutdown("server is shutting down") go g.Shutdown("server is shutting down")
return true return true
}) })
}) })
...@@ -139,7 +144,7 @@ func groupHandler(w http.ResponseWriter, r *http.Request) { ...@@ -139,7 +144,7 @@ func groupHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
g, err := addGroup(name, nil) g, err := group.Add(name, nil)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
notFound(w) notFound(w)
...@@ -168,7 +173,7 @@ func publicHandler(w http.ResponseWriter, r *http.Request) { ...@@ -168,7 +173,7 @@ func publicHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
g := getPublicGroups() g := group.GetPublic()
e := json.NewEncoder(w) e := json.NewEncoder(w)
e.Encode(g) e.Encode(g)
return return
...@@ -222,7 +227,7 @@ func statsHandler(w http.ResponseWriter, r *http.Request) { ...@@ -222,7 +227,7 @@ func statsHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
stats := getGroupStats() ss := stats.GetGroups()
fmt.Fprintf(w, "<!DOCTYPE html>\n<html><head>\n") fmt.Fprintf(w, "<!DOCTYPE html>\n<html><head>\n")
fmt.Fprintf(w, "<title>Stats</title>\n") fmt.Fprintf(w, "<title>Stats</title>\n")
...@@ -239,51 +244,51 @@ func statsHandler(w http.ResponseWriter, r *http.Request) { ...@@ -239,51 +244,51 @@ func statsHandler(w http.ResponseWriter, r *http.Request) {
return err return err
} }
printTrack := func(w io.Writer, t trackStats) { printTrack := func(w io.Writer, t stats.Track) {
fmt.Fprintf(w, "<tr><td></td><td></td><td></td>") fmt.Fprintf(w, "<tr><td></td><td></td><td></td>")
fmt.Fprintf(w, "<td>") fmt.Fprintf(w, "<td>")
printBitrate(w, t.bitrate, t.maxBitrate) printBitrate(w, t.Bitrate, t.MaxBitrate)
fmt.Fprintf(w, "</td>") fmt.Fprintf(w, "</td>")
fmt.Fprintf(w, "<td>%d%%</td>", fmt.Fprintf(w, "<td>%d%%</td>",
t.loss, t.Loss,
) )
fmt.Fprintf(w, "<td>") fmt.Fprintf(w, "<td>")
if t.rtt > 0 { if t.Rtt > 0 {
fmt.Fprintf(w, "%v", t.rtt) fmt.Fprintf(w, "%v", t.Rtt)
} }
if t.jitter > 0 { if t.Jitter > 0 {
fmt.Fprintf(w, "&#177;%v", t.jitter) fmt.Fprintf(w, "&#177;%v", t.Jitter)
} }
fmt.Fprintf(w, "</td>") fmt.Fprintf(w, "</td>")
fmt.Fprintf(w, "</tr>") fmt.Fprintf(w, "</tr>")
} }
for _, gs := range stats { for _, gs := range ss {
fmt.Fprintf(w, "<p>%v</p>\n", html.EscapeString(gs.name)) fmt.Fprintf(w, "<p>%v</p>\n", html.EscapeString(gs.Name))
fmt.Fprintf(w, "<table>") fmt.Fprintf(w, "<table>")
for _, cs := range gs.clients { for _, cs := range gs.Clients {
fmt.Fprintf(w, "<tr><td>%v</td></tr>\n", cs.id) fmt.Fprintf(w, "<tr><td>%v</td></tr>\n", cs.Id)
for _, up := range cs.up { for _, up := range cs.Up {
fmt.Fprintf(w, "<tr><td></td><td>Up</td><td>%v</td>", fmt.Fprintf(w, "<tr><td></td><td>Up</td><td>%v</td>",
up.id) up.Id)
if up.maxBitrate > 0 { if up.MaxBitrate > 0 {
fmt.Fprintf(w, "<td>%v</td>", fmt.Fprintf(w, "<td>%v</td>",
up.maxBitrate) up.MaxBitrate)
} }
fmt.Fprintf(w, "</tr>\n") fmt.Fprintf(w, "</tr>\n")
for _, t := range up.tracks { for _, t := range up.Tracks {
printTrack(w, t) printTrack(w, t)
} }
} }
for _, down := range cs.down { for _, down := range cs.Down {
fmt.Fprintf(w, "<tr><td></td><td>Down</td><td> %v</td>", fmt.Fprintf(w, "<tr><td></td><td>Down</td><td> %v</td>",
down.id) down.Id)
if down.maxBitrate > 0 { if down.MaxBitrate > 0 {
fmt.Fprintf(w, "<td>%v</td>", fmt.Fprintf(w, "<td>%v</td>",
down.maxBitrate) down.MaxBitrate)
} }
fmt.Fprintf(w, "</tr>\n") fmt.Fprintf(w, "</tr>\n")
for _, t := range down.tracks { for _, t := range down.Tracks {
printTrack(w, t) printTrack(w, t)
} }
} }
...@@ -302,7 +307,7 @@ func wsHandler(w http.ResponseWriter, r *http.Request) { ...@@ -302,7 +307,7 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
go func() { go func() {
err := startClient(conn) err := rtpconn.StartClient(conn)
if err != nil { if err != nil {
log.Printf("client: %v", err) log.Printf("client: %v", err)
} }
...@@ -322,7 +327,7 @@ func recordingsHandler(w http.ResponseWriter, r *http.Request) { ...@@ -322,7 +327,7 @@ func recordingsHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
f, err := os.Open(filepath.Join(recordingsDir, pth)) f, err := os.Open(filepath.Join(disk.Directory, pth))
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
notFound(w) notFound(w)
...@@ -389,7 +394,7 @@ func handleGroupAction(w http.ResponseWriter, r *http.Request, group string) { ...@@ -389,7 +394,7 @@ func handleGroupAction(w http.ResponseWriter, r *http.Request, group string) {
return return
} }
err := os.Remove( err := os.Remove(
filepath.Join(recordingsDir, group+"/"+filename), filepath.Join(disk.Directory, group+"/"+filename),
) )
if err != nil { if err != nil {
if os.IsPermission(err) { if os.IsPermission(err) {
...@@ -409,8 +414,8 @@ func handleGroupAction(w http.ResponseWriter, r *http.Request, group string) { ...@@ -409,8 +414,8 @@ func handleGroupAction(w http.ResponseWriter, r *http.Request, group string) {
} }
} }
func checkGroupPermissions(w http.ResponseWriter, r *http.Request, group string) bool { func checkGroupPermissions(w http.ResponseWriter, r *http.Request, groupname string) bool {
desc, err := getDescription(group) desc, err := group.GetDescription(groupname)
if err != nil { if err != nil {
return false return false
} }
...@@ -420,7 +425,7 @@ func checkGroupPermissions(w http.ResponseWriter, r *http.Request, group string) ...@@ -420,7 +425,7 @@ func checkGroupPermissions(w http.ResponseWriter, r *http.Request, group string)
return false return false
} }
p, err := getPermission(desc, clientCredentials{user, pass}) p, err := desc.GetPermission(group.ClientCredentials{user, pass})
if err != nil || !p.Record { if err != nil || !p.Record {
return false return false
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment