Commit 038ab46d authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Use a separate track for each down connection.

parent 9c9748b8
......@@ -208,8 +208,8 @@ func startClient(conn *websocket.Conn) (err error) {
}
func getUpConn(c *client, id string) *upConnection {
c.group.mu.Lock()
defer c.group.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
return nil
......@@ -222,8 +222,8 @@ func getUpConn(c *client, id string) *upConnection {
}
func getUpConns(c *client) []string {
c.group.mu.Lock()
defer c.group.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
up := make([]string, 0, len(c.up))
for id := range c.up {
up = append(up, id)
......@@ -262,34 +262,24 @@ func addUpConn(c *client, id string) (*upConnection, error) {
})
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
local, err := pc.NewTrack(
remote.PayloadType(),
remote.SSRC(),
remote.ID(),
remote.Label())
if err != nil {
log.Printf("%v", err)
return
}
c.group.mu.Lock()
c.mu.Lock()
u, ok := c.up[id]
if !ok {
log.Printf("Unknown connection")
c.group.mu.Unlock()
c.mu.Unlock()
return
}
u.pairs = append(u.pairs, trackPair{
remote: remote,
local: local,
track := &upTrack{
track: remote,
maxBitrate: ^uint64(0),
})
done := len(u.pairs) >= u.trackCount
c.group.mu.Unlock()
}
u.tracks = append(u.tracks, track)
done := len(u.tracks) >= u.trackCount
c.mu.Unlock()
clients := c.group.getClients(c)
for _, cc := range clients {
cc.action(addTrackAction{id, local, u, done})
cc.action(addTrackAction{track, u, done})
if done && u.label != "" {
cc.action(addLabelAction{id, u.label})
}
......@@ -313,9 +303,12 @@ func addUpConn(c *client, id string) (*upConnection, error) {
continue
}
err = local.WriteRTP(&packet)
if err != nil && err != io.ErrClosedPipe {
log.Printf("%v", err)
local := track.getLocal()
for _, l := range local {
err := l.track.WriteRTP(&packet)
if err != nil {
log.Printf("%v", err)
}
}
}
}()
......@@ -323,8 +316,8 @@ func addUpConn(c *client, id string) (*upConnection, error) {
conn := &upConnection{id: id, pc: pc}
c.group.mu.Lock()
defer c.group.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
c.up = make(map[string]*upConnection)
......@@ -338,8 +331,8 @@ func addUpConn(c *client, id string) (*upConnection, error) {
}
func delUpConn(c *client, id string) {
c.group.mu.Lock()
defer c.group.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
log.Printf("Deleting unknown connection")
......@@ -357,16 +350,20 @@ func delUpConn(c *client, id string) {
id string
}
cids := make([]clientId, 0)
for _, cc := range c.group.clients {
clients := c.group.getClients(c)
for _, cc := range clients {
cc.mu.Lock()
for _, otherconn := range cc.down {
if otherconn.remote == conn {
cids = append(cids, clientId{cc, otherconn.id})
}
}
cc.mu.Unlock()
}
for _, cid := range cids {
cid.client.action(delPCAction{cid.id})
cid.client.action(delConnAction{cid.id})
}
conn.pc.Close()
......@@ -378,8 +375,8 @@ func getDownConn(c *client, id string) *downConnection {
return nil
}
c.group.mu.Lock()
defer c.group.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
conn := c.down[id]
if conn == nil {
return nil
......@@ -406,8 +403,8 @@ func addDownConn(c *client, id string, remote *upConnection) (*downConnection, e
}
conn := &downConnection{id: id, pc: pc, remote: remote}
c.group.mu.Lock()
defer c.group.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.down[id] != nil {
conn.pc.Close()
return nil, errors.New("Adding duplicate connection")
......@@ -417,8 +414,8 @@ func addDownConn(c *client, id string, remote *upConnection) (*downConnection, e
}
func delDownConn(c *client, id string) {
c.group.mu.Lock()
defer c.group.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.down == nil {
log.Printf("Deleting unknown connection")
......@@ -429,31 +426,49 @@ func delDownConn(c *client, id string) {
log.Printf("Deleting unknown connection")
return
}
for _, track := range conn.tracks {
found := track.remote.delLocal(track)
if !found {
log.Printf("Couldn't find remote track")
}
track.remote = nil
}
conn.pc.Close()
delete(c.down, id)
}
func addDownTrack(c *client, id string, track *webrtc.Track, remote *upConnection) (*downConnection, *webrtc.RTPSender, error) {
func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConnection) (*downConnection, *webrtc.RTPSender, error) {
conn := getDownConn(c, id)
if conn == nil {
var err error
conn, err = addDownConn(c, id, remote)
conn, err = addDownConn(c, id, remoteConn)
if err != nil {
return nil, nil, err
}
}
s, err := conn.pc.AddTrack(track)
local, err := conn.pc.NewTrack(
remoteTrack.track.PayloadType(),
remoteTrack.track.SSRC(),
remoteTrack.track.ID(),
remoteTrack.track.Label(),
)
if err != nil {
return nil, nil, err
}
conn.tracks = append(conn.tracks,
downTrack{track.SSRC(), new(timeStampedBitrate)},
)
s, err := conn.pc.AddTrack(local)
if err != nil {
return nil, nil, err
}
go rtcpListener(c.group, conn, s,
conn.tracks[len(conn.tracks)-1].maxBitrate)
track := &downTrack{local, remoteTrack, new(timeStampedBitrate)}
conn.tracks = append(conn.tracks, track)
remoteTrack.addLocal(track)
go rtcpListener(c.group, conn, s, track.maxBitrate)
return conn, s, nil
}
......@@ -545,44 +560,26 @@ func splitBitrate(bitrate uint32, audio, video bool) (uint32, uint32) {
return audioRate, bitrate - audioRate
}
func updateUpBitrate(g *group, up *upConnection) {
for i := range up.pairs {
up.pairs[i].maxBitrate = ^uint64(0)
func updateUpBitrate(up *upConnection) {
for _, t := range up.tracks {
t.maxBitrate = ^uint64(0)
}
now := msSinceEpoch()
g.Range(func(c *client) bool {
for _, down := range c.down {
if down.remote == up {
for _, dt := range down.tracks {
ms := atomic.LoadUint64(
&dt.maxBitrate.timestamp,
)
bitrate := atomic.LoadUint64(
&dt.maxBitrate.bitrate,
)
if bitrate == 0 {
continue
}
if now-ms > 5000 {
continue
}
for i, p := range up.pairs {
if p.local.SSRC() == dt.ssrc {
if p.maxBitrate > bitrate {
up.pairs[i].maxBitrate = bitrate
break
}
}
}
}
for _, track := range up.tracks {
local := track.getLocal()
for _, l := range local {
ms := atomic.LoadUint64(&l.maxBitrate.timestamp)
bitrate := atomic.LoadUint64(&l.maxBitrate.bitrate)
if now-ms > 5000 || bitrate == 0 {
continue
}
if track.maxBitrate > bitrate {
track.maxBitrate = bitrate
}
}
return true
})
}
}
func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error {
......@@ -779,18 +776,18 @@ func clientLoop(c *client, conn *websocket.Conn) error {
case addTrackAction:
down, _, err :=
addDownTrack(
c, a.id, a.track,
c, a.remote.id, a.track,
a.remote)
if err != nil {
return err
}
if a.done {
err = negotiate(c, a.id, down.pc)
err = negotiate(c, a.remote.id, down.pc)
if err != nil {
return err
}
}
case delPCAction:
case delConnAction:
c.write(clientMessage{
Type: "close",
Id: a.id,
......@@ -805,11 +802,10 @@ func clientLoop(c *client, conn *websocket.Conn) error {
case pushTracksAction:
for _, u := range c.up {
var done bool
for i, p := range u.pairs {
for i, t := range u.tracks {
done = i >= u.trackCount-1
a.c.action(addTrackAction{
u.id, p.local, u,
done,
t, u, done,
})
}
if done && u.label != "" {
......@@ -931,22 +927,35 @@ func handleClientMessage(c *client, m clientMessage) error {
}
func sendRateUpdate(c *client) {
type remb struct {
pc *webrtc.PeerConnection
ssrc uint32
bitrate uint64
}
rembs := make([]remb, 0)
c.mu.Lock()
for _, u := range c.up {
updateUpBitrate(c.group, u)
for _, p := range u.pairs {
bitrate := p.maxBitrate
updateUpBitrate(u)
for _, t := range u.tracks {
bitrate := t.maxBitrate
if bitrate != ^uint64(0) {
if bitrate < 6000 {
bitrate = 6000
}
err := sendREMB(u.pc, p.remote.SSRC(),
uint64(bitrate))
if err != nil {
log.Printf("sendREMB: %v", err)
}
rembs = append(rembs,
remb{u.pc, t.track.SSRC(), bitrate})
}
}
}
c.mu.Unlock()
for _, r := range rembs {
err := sendREMB(r.pc, r.ssrc, r.bitrate)
if err != nil {
log.Printf("sendREMB: %v", err)
}
}
}
func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) {
......
......@@ -17,9 +17,38 @@ import (
"github.com/pion/webrtc/v2"
)
type trackPair struct {
remote, local *webrtc.Track
maxBitrate uint64
type upTrack struct {
track *webrtc.Track
maxBitrate uint64
mu sync.Mutex
local []*downTrack
}
func (up *upTrack) addLocal(local *downTrack) {
up.mu.Lock()
defer up.mu.Unlock()
up.local = append(up.local, local)
}
func (up *upTrack) delLocal(local *downTrack) bool {
up.mu.Lock()
defer up.mu.Unlock()
for i, l := range up.local {
if l == local {
up.local = append(up.local[:i], up.local[i+1:]...)
return true
}
}
return false
}
func (up *upTrack) getLocal() []*downTrack {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]*downTrack, len(up.local))
copy(local, up.local)
return local
}
type upConnection struct {
......@@ -27,7 +56,7 @@ type upConnection struct {
label string
pc *webrtc.PeerConnection
trackCount int
pairs []trackPair
tracks []*upTrack
}
type timeStampedBitrate struct {
......@@ -35,7 +64,8 @@ type timeStampedBitrate struct {
timestamp uint64
}
type downTrack struct {
ssrc uint32
track *webrtc.Track
remote *upTrack
maxBitrate *timeStampedBitrate
}
......@@ -43,7 +73,7 @@ type downConnection struct {
id string
pc *webrtc.PeerConnection
remote *upConnection
tracks []downTrack
tracks []*downTrack
}
type client struct {
......@@ -55,8 +85,10 @@ type client struct {
writeCh chan interface{}
writerDone chan struct{}
actionCh chan interface{}
down map[string]*downConnection
up map[string]*upConnection
mu sync.Mutex
down map[string]*downConnection
up map[string]*upConnection
}
type chatHistoryEntry struct {
......@@ -76,13 +108,12 @@ type group struct {
history []chatHistoryEntry
}
type delPCAction struct {
type delConnAction struct {
id string
}
type addTrackAction struct {
id string
track *webrtc.Track
track *upTrack
remote *upConnection
done bool
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment