Commit c441b49d authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Send rate updates over RTCP.

parent 98034c0f
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"errors" "errors"
"io" "io"
"log" "log"
"math"
"os" "os"
"strings" "strings"
"sync" "sync"
...@@ -98,8 +97,6 @@ type clientMessage struct { ...@@ -98,8 +97,6 @@ type clientMessage struct {
Answer *webrtc.SessionDescription `json:"answer,omitempty"` Answer *webrtc.SessionDescription `json:"answer,omitempty"`
Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"` Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"`
Del bool `json:"del,omitempty"` Del bool `json:"del,omitempty"`
AudioRate int `json:"audiorate,omitempty"`
VideoRate int `json:"videorate,omitempty"`
} }
type closeMessage struct { type closeMessage struct {
...@@ -284,6 +281,7 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -284,6 +281,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
u.pairs = append(u.pairs, trackPair{ u.pairs = append(u.pairs, trackPair{
remote: remote, remote: remote,
local: local, local: local,
maxBitrate: ^uint64(0),
}) })
done := len(u.pairs) >= u.trackCount done := len(u.pairs) >= u.trackCount
c.group.mu.Unlock() c.group.mu.Unlock()
...@@ -442,12 +440,24 @@ func addDownTrack(c *client, id string, track *webrtc.Track, remote *upConnectio ...@@ -442,12 +440,24 @@ func addDownTrack(c *client, id string, track *webrtc.Track, remote *upConnectio
return nil, nil, err return nil, nil, err
} }
go rtcpListener(c.group, conn, s) conn.tracks = append(conn.tracks,
downTrack{track.SSRC(), new(timeStampedBitrate)},
)
go rtcpListener(c.group, conn, s,
conn.tracks[len(conn.tracks)-1].maxBitrate)
return conn, s, nil return conn, s, nil
} }
func rtcpListener(g *group, c *downConnection, s *webrtc.RTPSender) { var epoch = time.Now()
func msSinceEpoch() uint64 {
return uint64(time.Since(epoch) / time.Millisecond)
}
func rtcpListener(g *group, c *downConnection, s *webrtc.RTPSender,
bitrate *timeStampedBitrate) {
for { for {
ps, err := s.ReadRTCP() ps, err := s.ReadRTCP()
if err != nil { if err != nil {
...@@ -460,16 +470,23 @@ func rtcpListener(g *group, c *downConnection, s *webrtc.RTPSender) { ...@@ -460,16 +470,23 @@ func rtcpListener(g *group, c *downConnection, s *webrtc.RTPSender) {
for _, p := range ps { for _, p := range ps {
switch p := p.(type) { switch p := p.(type) {
case *rtcp.PictureLossIndication: case *rtcp.PictureLossIndication:
err := sendPLI(c.remote, p.MediaSSRC) err := sendPLI(c.remote.pc, p.MediaSSRC)
if err != nil { if err != nil {
log.Printf("sendPLI: %v", err) log.Printf("sendPLI: %v", err)
} }
case *rtcp.ReceiverEstimatedMaximumBitrate: case *rtcp.ReceiverEstimatedMaximumBitrate:
bitrate := uint32(math.MaxInt32) ms := msSinceEpoch()
if p.Bitrate < math.MaxInt32 { // this is racy -- a reader might read the
bitrate = uint32(p.Bitrate) // data between the two writes. This shouldn't
} // matter, we'll recover at the next sample.
atomic.StoreUint32(&c.maxBitrate, bitrate) atomic.StoreUint64(
&bitrate.bitrate,
p.Bitrate,
)
atomic.StoreUint64(
&bitrate.timestamp,
uint64(ms),
)
case *rtcp.ReceiverReport: case *rtcp.ReceiverReport:
default: default:
log.Printf("RTCP: %T", p) log.Printf("RTCP: %T", p)
...@@ -520,42 +537,61 @@ func splitBitrate(bitrate uint32, audio, video bool) (uint32, uint32) { ...@@ -520,42 +537,61 @@ func splitBitrate(bitrate uint32, audio, video bool) (uint32, uint32) {
return audioRate, bitrate - audioRate return audioRate, bitrate - audioRate
} }
func updateBitrate(g *group, up *upConnection) (uint32, uint32) { func updateUpBitrate(g *group, up *upConnection) {
audio := uint32(math.MaxInt32) for i := range up.pairs {
video := uint32(math.MaxInt32) up.pairs[i].maxBitrate = ^uint64(0)
}
now := msSinceEpoch()
g.Range(func(c *client) bool { g.Range(func(c *client) bool {
for _, down := range c.down { for _, down := range c.down {
if down.remote == up { if down.remote == up {
bitrate := atomic.LoadUint32(&down.maxBitrate) for _, dt := range down.tracks {
ms := atomic.LoadUint64(
&dt.maxBitrate.timestamp,
)
bitrate := atomic.LoadUint64(
&dt.maxBitrate.bitrate,
)
if bitrate == 0 { if bitrate == 0 {
bitrate = 256000 continue
} else if bitrate < 6000 { }
bitrate = 6000
if now - ms > 5000 {
continue
}
for i, p := range up.pairs {
if p.local.SSRC() == dt.ssrc {
if p.maxBitrate > bitrate {
up.pairs[i].maxBitrate = bitrate
break
}
} }
hasAudio, hasVideo := trackKinds(down)
a, v := splitBitrate(bitrate, hasAudio, hasVideo)
if a < audio {
audio = a
} }
if v < video {
video = v
} }
} }
} }
return true return true
}) })
up.maxAudioBitrate = audio
up.maxVideoBitrate = video
return audio, video
} }
func sendPLI(up *upConnection, ssrc uint32) error { func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error {
// we use equal SSRC values on both sides return pc.WriteRTCP([]rtcp.Packet{
return up.pc.WriteRTCP([]rtcp.Packet{
&rtcp.PictureLossIndication{MediaSSRC: ssrc}, &rtcp.PictureLossIndication{MediaSSRC: ssrc},
}) })
} }
func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: bitrate,
SSRCs: []uint32{ssrc},
},
})
}
func countMediaStreams(data string) (int, error) { func countMediaStreams(data string) (int, error) {
desc := sdp.NewJSEPSessionDescription(false) desc := sdp.NewJSEPSessionDescription(false)
err := desc.Unmarshal(data) err := desc.Unmarshal(data)
...@@ -709,7 +745,7 @@ func clientLoop(c *client, conn *websocket.Conn) error { ...@@ -709,7 +745,7 @@ func clientLoop(c *client, conn *websocket.Conn) error {
readTime := time.Now() readTime := time.Now()
ticker := time.NewTicker(2 * time.Second) ticker := time.NewTicker(time.Second)
defer ticker.Stop() defer ticker.Stop()
slowTicker := time.NewTicker(10 * time.Second) slowTicker := time.NewTicker(10 * time.Second)
defer slowTicker.Stop() defer slowTicker.Stop()
...@@ -888,16 +924,19 @@ func handleClientMessage(c *client, m clientMessage) error { ...@@ -888,16 +924,19 @@ func handleClientMessage(c *client, m clientMessage) error {
func sendRateUpdate(c *client) { func sendRateUpdate(c *client) {
for _, u := range c.up { for _, u := range c.up {
oldaudio := u.maxAudioBitrate updateUpBitrate(c.group, u)
oldvideo := u.maxVideoBitrate for _, p := range u.pairs {
audio, video := updateBitrate(c.group, u) bitrate := p.maxBitrate
if audio != oldaudio || video != oldvideo { if bitrate != ^uint64(0) {
c.write(clientMessage{ if bitrate < 6000 {
Type: "maxbitrate", bitrate = 6000
Id: u.id, }
AudioRate: int(audio), err := sendREMB(u.pc, p.remote.SSRC(),
VideoRate: int(video), uint64(bitrate))
}) if err != nil {
log.Printf("sendREMB: %v", err)
}
}
} }
} }
} }
......
...@@ -19,23 +19,31 @@ import ( ...@@ -19,23 +19,31 @@ import (
type trackPair struct { type trackPair struct {
remote, local *webrtc.Track remote, local *webrtc.Track
maxBitrate uint64
} }
type upConnection struct { type upConnection struct {
id string id string
label string label string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
maxAudioBitrate uint32
maxVideoBitrate uint32
trackCount int trackCount int
pairs []trackPair pairs []trackPair
} }
type timeStampedBitrate struct {
bitrate uint64
timestamp uint64
}
type downTrack struct {
ssrc uint32
maxBitrate *timeStampedBitrate
}
type downConnection struct { type downConnection struct {
id string id string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
remote *upConnection remote *upConnection
maxBitrate uint32 tracks []downTrack
} }
type client struct { type client struct {
......
...@@ -323,9 +323,6 @@ function serverConnect() { ...@@ -323,9 +323,6 @@ function serverConnect() {
case 'ice': case 'ice':
gotICE(m.id, m.candidate); gotICE(m.id, m.candidate);
break; break;
case 'maxbitrate':
setMaxBitrate(m.id, m.audiorate, m.videorate);
break;
case 'label': case 'label':
gotLabel(m.id, m.value); gotLabel(m.id, m.value);
break; break;
...@@ -450,49 +447,6 @@ async function gotICE(id, candidate) { ...@@ -450,49 +447,6 @@ async function gotICE(id, candidate) {
conn.iceCandidates.push(candidate) conn.iceCandidates.push(candidate)
} }
let maxaudiorate, maxvideorate;
async function setMaxBitrate(id, audio, video) {
let conn = up[id];
if(!conn)
throw new Error("Setting bitrate of unknown id");
let senders = conn.pc.getSenders();
for(let i = 0; i < senders.length; i++) {
let s = senders[i];
if(!s.track)
return;
let p = s.getParameters();
let bitrate;
if(s.track.kind == 'audio')
bitrate = audio;
else if(s.track.kind == 'video')
bitrate = video;
for(let j = 0; j < p.encodings.length; j++) {
let e = p.encodings[j];
if(bitrate)
e.maxBitrate = bitrate;
else
delete(e.maxBitrate);
await s.setParameters(p);
}
}
if((audio && audio < 128000) || (video && video < 256000)) {
let l = '';
if(audio)
l = `${Math.round(audio/1000)}kbps`
if(video) {
if(l)
l = l + ' + ';
l = l + `${Math.round(video/1000)}kbps`
}
setLabel(id, l)
} else {
setLabel(id);
}
}
async function addIceCandidates(conn) { async function addIceCandidates(conn) {
let promises = [] let promises = []
conn.iceCandidates.forEach(c => { conn.iceCandidates.forEach(c => {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment