Commit 5dd27e50 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Implement rate estimation.

parent 10526d47
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/estimator"
"sfu/packetcache" "sfu/packetcache"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -290,6 +291,7 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -290,6 +291,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
track := &upTrack{ track := &upTrack{
track: remote, track: remote,
cache: packetcache.New(96), cache: packetcache.New(96),
rate: estimator.New(time.Second),
maxBitrate: ^uint64(0), maxBitrate: ^uint64(0),
} }
u.tracks = append(u.tracks, track) u.tracks = append(u.tracks, track)
...@@ -324,21 +326,22 @@ func upLoop(conn *upConnection, track *upTrack) { ...@@ -324,21 +326,22 @@ func upLoop(conn *upConnection, track *upTrack) {
localTime = now localTime = now
} }
i, err := track.track.Read(buf) bytes, err := track.track.Read(buf)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Printf("%v", err) log.Printf("%v", err)
} }
break break
} }
track.rate.Add(uint32(bytes))
err = packet.Unmarshal(buf[:i]) err = packet.Unmarshal(buf[:bytes])
if err != nil { if err != nil {
log.Printf("%v", err) log.Printf("%v", err)
continue continue
} }
first := track.cache.Store(packet.SequenceNumber, buf[:i]) first := track.cache.Store(packet.SequenceNumber, buf[:bytes])
if packet.SequenceNumber-first > 24 { if packet.SequenceNumber-first > 24 {
first, bitmap := track.cache.BitmapGet() first, bitmap := track.cache.BitmapGet()
if bitmap != ^uint16(0) { if bitmap != ^uint16(0) {
...@@ -357,6 +360,7 @@ func upLoop(conn *upConnection, track *upTrack) { ...@@ -357,6 +360,7 @@ func upLoop(conn *upConnection, track *upTrack) {
if err != nil && err != io.ErrClosedPipe { if err != nil && err != io.ErrClosedPipe {
log.Printf("%v", err) log.Printf("%v", err)
} }
l.rate.Add(uint32(bytes))
} }
} }
} }
...@@ -568,6 +572,7 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn ...@@ -568,6 +572,7 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
track: local, track: local,
remote: remoteTrack, remote: remoteTrack,
maxBitrate: new(timeStampedBitrate), maxBitrate: new(timeStampedBitrate),
rate: estimator.New(time.Second),
} }
conn.tracks = append(conn.tracks, track) conn.tracks = append(conn.tracks, track)
remoteTrack.addLocal(track) remoteTrack.addLocal(track)
...@@ -758,6 +763,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) { ...@@ -758,6 +763,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
if err != nil { if err != nil {
log.Printf("%v", err) log.Printf("%v", err)
} }
track.rate.Add(uint32(len(raw)))
} }
} }
} }
......
package estimator
import (
"sync"
"sync/atomic"
"time"
)
type Estimator struct {
interval time.Duration
count uint32
mu sync.Mutex
rate uint32
time time.Time
}
func New(interval time.Duration) *Estimator {
return &Estimator{
interval: interval,
time: time.Now(),
}
}
func (e *Estimator) swap(now time.Time) {
interval := now.Sub(e.time)
count := atomic.SwapUint32(&e.count, 0)
if interval < time.Millisecond {
e.rate = 0
} else {
e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond))
}
e.time = now
}
func (e *Estimator) Add(count uint32) {
atomic.AddUint32(&e.count, count)
}
func (e *Estimator) estimate(now time.Time) uint32 {
if now.Sub(e.time) > e.interval {
e.swap(now)
}
return e.rate
}
func (e *Estimator) Estimate() uint32 {
now := time.Now()
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(now)
}
package estimator
import (
"testing"
"time"
)
func TestEstimator(t *testing.T) {
now := time.Now()
e := New(time.Second)
e.estimate(now)
e.Add(42)
e.Add(128)
e.estimate(now.Add(time.Second))
rate := e.estimate(now.Add(time.Second + time.Millisecond))
if rate != 42+128 {
t.Errorf("Expected %v, got %v", 42+128, rate)
}
}
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/estimator"
"sfu/packetcache" "sfu/packetcache"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
...@@ -23,6 +24,7 @@ import ( ...@@ -23,6 +24,7 @@ import (
type upTrack struct { type upTrack struct {
track *webrtc.Track track *webrtc.Track
rate *estimator.Estimator
cache *packetcache.Cache cache *packetcache.Cache
maxBitrate uint64 maxBitrate uint64
lastPLI uint64 lastPLI uint64
...@@ -76,6 +78,7 @@ type downTrack struct { ...@@ -76,6 +78,7 @@ type downTrack struct {
remote *upTrack remote *upTrack
isMuted uint32 isMuted uint32
maxBitrate *timeStampedBitrate maxBitrate *timeStampedBitrate
rate *estimator.Estimator
loss uint32 loss uint32
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment