Commit a50e9c67 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Buffer last keyframe.

parent bbd5ce0c
......@@ -5,6 +5,7 @@ import (
)
const BufSize = 1500
const maxKeyframe = 1024
type entry struct {
seqno uint16
......@@ -24,6 +25,9 @@ type Cache struct {
// bitmap
first uint16
bitmap uint32
// buffered keyframe
kfTimestamp uint32
kfEntries []entry
// packet cache
tail uint16
entries []entry
......@@ -75,7 +79,7 @@ func (cache *Cache) set(seqno uint16) {
}
// Store a packet, setting bitmap at the same time
func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) {
func (cache *Cache) Store(seqno uint16, timestamp uint32, keyframe bool, buf []byte) (uint16, uint16) {
cache.mu.Lock()
defer cache.mu.Unlock()
......@@ -97,9 +101,39 @@ func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) {
}
}
}
cache.set(seqno)
doit := false
if keyframe {
if cache.kfTimestamp != timestamp {
cache.kfTimestamp = timestamp
cache.kfEntries = cache.kfEntries[:0]
}
doit = true
} else if len(cache.kfEntries) > 0 {
doit = cache.kfTimestamp == timestamp
}
if doit {
i := 0
for i < len(cache.kfEntries) {
if cache.kfEntries[i].seqno >= seqno {
break
}
i++
}
if i >= len(cache.kfEntries) || cache.kfEntries[i].seqno != seqno {
if len(cache.kfEntries) >= maxKeyframe {
cache.kfEntries = cache.kfEntries[:maxKeyframe-1]
}
cache.kfEntries = append(cache.kfEntries, entry{})
copy(cache.kfEntries[i+1:], cache.kfEntries[i:])
}
cache.kfEntries[i].seqno = seqno
cache.kfEntries[i].length = uint16(len(buf))
copy(cache.kfEntries[i].buf[:], buf)
}
i := cache.tail
cache.entries[i].seqno = seqno
copy(cache.entries[i].buf[:], buf)
......@@ -118,23 +152,36 @@ func (cache *Cache) Expect(n int) {
cache.expected += uint32(n)
}
func (cache *Cache) Get(seqno uint16, result []byte) uint16 {
cache.mu.Lock()
defer cache.mu.Unlock()
for i := range cache.entries {
if cache.entries[i].length == 0 ||
cache.entries[i].seqno != seqno {
func get(seqno uint16, entries []entry, result []byte) uint16 {
for i := range entries {
if entries[i].length == 0 || entries[i].seqno != seqno {
continue
}
return uint16(copy(
result[:cache.entries[i].length],
cache.entries[i].buf[:]),
result[:entries[i].length],
entries[i].buf[:]),
)
}
return 0
}
func (cache *Cache) Get(seqno uint16, result []byte) uint16 {
cache.mu.Lock()
defer cache.mu.Unlock()
n := get(seqno, cache.kfEntries, result)
if n > 0 {
return n
}
n = get(seqno, cache.entries, result)
if n > 0 {
return n
}
return 0
}
func (cache *Cache) GetAt(seqno uint16, index uint16, result []byte) uint16 {
cache.mu.Lock()
defer cache.mu.Unlock()
......@@ -151,6 +198,17 @@ func (cache *Cache) GetAt(seqno uint16, index uint16, result []byte) uint16 {
)
}
func (cache *Cache) Keyframe() (uint32, []uint16) {
cache.mu.Lock()
defer cache.mu.Unlock()
seqnos := make([]uint16, len(cache.kfEntries))
for i := range cache.kfEntries {
seqnos[i] = cache.kfEntries[i].seqno
}
return cache.kfTimestamp, seqnos
}
func (cache *Cache) resize(capacity int) {
if len(cache.entries) == capacity {
return
......
......@@ -20,8 +20,8 @@ func TestCache(t *testing.T) {
buf1 := randomBuf()
buf2 := randomBuf()
cache := New(16)
_, i1 := cache.Store(13, buf1)
_, i2 := cache.Store(17, buf2)
_, i1 := cache.Store(13, 0, false, buf1)
_, i2 := cache.Store(17, 0, false, buf2)
buf := make([]byte, BufSize)
......@@ -62,7 +62,7 @@ func TestCacheOverflow(t *testing.T) {
cache := New(16)
for i := 0; i < 32; i++ {
cache.Store(uint16(i), []byte{uint8(i)})
cache.Store(uint16(i), 0, false, []byte{uint8(i)})
}
for i := 0; i < 32; i++ {
......@@ -84,7 +84,7 @@ func TestCacheGrow(t *testing.T) {
cache := New(16)
for i := 0; i < 24; i++ {
cache.Store(uint16(i), []byte{uint8(i)})
cache.Store(uint16(i), 0, false, []byte{uint8(i)})
}
cache.Resize(32)
......@@ -107,7 +107,7 @@ func TestCacheShrink(t *testing.T) {
cache := New(16)
for i := 0; i < 24; i++ {
cache.Store(uint16(i), []byte{uint8(i)})
cache.Store(uint16(i), 0, false, []byte{uint8(i)})
}
cache.Resize(12)
......@@ -150,6 +150,65 @@ func TestCacheGrowCond(t *testing.T) {
}
}
func TestKeyframe(t *testing.T) {
cache := New(16)
packet := make([]byte, 1)
buf := make([]byte, BufSize)
cache.Store(7, 57, true, packet)
cache.Store(8, 57, true, packet)
ts, kf := cache.Keyframe()
if ts != 57 || len(kf) != 2 {
t.Errorf("Got %v %v, expected %v %v", ts, len(kf), 57, 2)
}
for _, i := range kf {
l := cache.Get(i, buf)
if int(l) != len(packet) {
t.Errorf("Couldn't get %v", i)
}
}
for i := 0; i < 32; i++ {
cache.Store(uint16(9 + i), uint32(58 + i), false, packet)
}
ts, kf = cache.Keyframe()
if ts != 57 || len(kf) != 2 {
t.Errorf("Got %v %v, expected %v %v", ts, len(kf), 57, 2)
}
for _, i := range kf {
l := cache.Get(i, buf)
if int(l) != len(packet) {
t.Errorf("Couldn't get %v", i)
}
}
}
func TestKeyframeUnsorted(t *testing.T) {
cache := New(16)
packet := make([]byte, 1)
cache.Store(7, 57, true, packet)
cache.Store(9, 57, true, packet)
cache.Store(8, 57, true, packet)
cache.Store(10, 57, true, packet)
cache.Store(6, 57, true, packet)
cache.Store(8, 57, true, packet)
_, kf := cache.Keyframe()
if len(kf) != 5 {
t.Errorf("Got length %v, expected 5", len(kf))
}
for i, v := range kf {
if v != uint16(i + 6) {
t.Errorf("Position %v, expected %v, got %v\n",
i, i + 6, v)
}
}
}
func TestBitmap(t *testing.T) {
value := uint64(0xcdd58f1e035379c0)
packet := make([]byte, 1)
......@@ -159,7 +218,7 @@ func TestBitmap(t *testing.T) {
var first uint16
for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 {
first, _ = cache.Store(uint16(42+i), packet)
first, _ = cache.Store(uint16(42+i), 0, false, packet)
}
}
......@@ -175,13 +234,13 @@ func TestBitmapWrap(t *testing.T) {
cache := New(16)
cache.Store(0x7000, packet)
cache.Store(0xA000, packet)
cache.Store(0x7000, 0, false, packet)
cache.Store(0xA000, 0, false, packet)
var first uint16
for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 {
first, _ = cache.Store(uint16(42+i), packet)
first, _ = cache.Store(uint16(42+i), 0, false, packet)
}
}
......@@ -199,7 +258,7 @@ func TestBitmapGet(t *testing.T) {
for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 {
cache.Store(uint16(42+i), packet)
cache.Store(uint16(42+i), 0, false, packet)
}
}
......@@ -241,7 +300,7 @@ func TestBitmapPacket(t *testing.T) {
for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 {
cache.Store(uint16(42+i), packet)
cache.Store(uint16(42+i), 0, false, packet)
}
}
......@@ -299,7 +358,7 @@ func BenchmarkCachePutGet(b *testing.B) {
for i := 0; i < b.N; i++ {
seqno := uint16(i)
cache.Store(seqno, buf)
cache.Store(seqno, 0, false, buf)
for _, ch := range chans {
ch <- seqno
}
......@@ -350,7 +409,7 @@ func BenchmarkCachePutGetAt(b *testing.B) {
for i := 0; i < b.N; i++ {
seqno := uint16(i)
_, index := cache.Store(seqno, buf)
_, index := cache.Store(seqno, 0, false, buf)
for _, ch := range chans {
ch <- is{index, seqno}
}
......
......@@ -5,12 +5,24 @@ import (
"log"
"github.com/pion/rtp"
"github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3"
"sfu/packetcache"
"sfu/rtptime"
)
func isVP8Keyframe(packet *rtp.Packet) bool {
var vp8 codecs.VP8Packet
_, err := vp8.Unmarshal(packet.Payload)
if err != nil {
return false
}
return vp8.S != 0 && vp8.PID == 0 &&
len(vp8.Payload) > 0 && (vp8.Payload[0]&0x1) == 0
}
func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
writers := rtpWriterPool{conn: conn, track: track}
defer func() {
......@@ -19,6 +31,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
}()
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
codec := track.track.Codec().Name
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
for {
......@@ -39,8 +52,14 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
track.jitter.Accumulate(packet.Timestamp)
first, index :=
track.cache.Store(packet.SequenceNumber, buf[:bytes])
kf := false
if isvideo && codec == webrtc.VP8 {
kf = isVP8Keyframe(&packet)
}
first, index := track.cache.Store(
packet.SequenceNumber, packet.Timestamp, kf, buf[:bytes],
)
if packet.SequenceNumber-first > 24 {
found, first, bitmap := track.cache.BitmapGet()
if found {
......
......@@ -138,7 +138,7 @@ func (wp *rtpWriterPool) write(seqno uint16, index uint16, delay uint32, isvideo
continue
}
// audio, try again with a delay
d := delay/uint32(2*len(wp.writers))
d := delay / uint32(2*len(wp.writers))
timer := time.NewTimer(rtptime.ToDuration(
uint64(d), rtptime.JiffiesPerSec,
))
......@@ -208,6 +208,31 @@ func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error {
}
}
func sendKeyframe(track conn.DownTrack, cache *packetcache.Cache) {
_, kf := cache.Keyframe()
if len(kf) == 0 {
return
}
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
for _, seqno := range kf {
bytes := cache.Get(seqno, buf)
if(bytes == 0) {
return
}
err := packet.Unmarshal(buf[:bytes])
if err != nil {
return
}
err = track.WriteRTP(&packet)
if err != nil && err != conn.ErrKeyframeNeeded {
return
}
track.Accumulate(uint32(bytes))
}
}
// rtpWriterLoop is the main loop of an rtpWriter.
func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
defer close(writer.done)
......@@ -245,6 +270,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
if cname != "" {
action.track.SetCname(cname)
}
go sendKeyframe(action.track, track.cache)
} else {
found := false
for i, t := range local {
......@@ -286,8 +312,9 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
if err != nil {
if err == conn.ErrKeyframeNeeded {
kfNeeded = true
} else {
continue
}
continue
}
l.Accumulate(uint32(bytes))
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment