Commit a6c8f8e0 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Handle NACKs locally.

parent d23cac10
......@@ -16,6 +16,8 @@ import (
"sync/atomic"
"time"
"sfu/packetlist"
"github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/rtp"
......@@ -269,8 +271,10 @@ func addUpConn(c *client, id string) (*upConnection, error) {
c.mu.Unlock()
return
}
list := packetlist.New(32)
track := &upTrack{
track: remote,
list: list,
maxBitrate: ^uint64(0),
}
u.tracks = append(u.tracks, track)
......@@ -286,7 +290,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
}
go func() {
buf := make([]byte, 1500)
buf := make([]byte, packetlist.BufSize)
var packet rtp.Packet
var local []*downTrack
var localTime time.Time
......@@ -311,6 +315,8 @@ func addUpConn(c *client, id string) (*upConnection, error) {
continue
}
list.Store(packet.SequenceNumber, buf[:i])
for _, l := range local {
if l.muted() {
continue
......@@ -523,6 +529,8 @@ func rtcpListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RT
uint64(ms),
)
case *rtcp.ReceiverReport:
case *rtcp.TransportLayerNack:
sendRecovery(p, track)
default:
log.Printf("RTCP: %T", p)
}
......@@ -592,6 +600,25 @@ func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
})
}
func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
var packet rtp.Packet
for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() {
raw := track.remote.list.Get(seqno)
if raw != nil {
err := packet.Unmarshal(raw)
if err != nil {
continue
}
err = track.track.WriteRTP(&packet)
if err != nil {
log.Printf("%v", err)
}
}
}
}
}
func countMediaStreams(data string) (int, error) {
desc := sdp.NewJSEPSessionDescription(false)
err := desc.Unmarshal(data)
......
......@@ -15,11 +15,14 @@ import (
"sync/atomic"
"time"
"sfu/packetlist"
"github.com/pion/webrtc/v2"
)
type upTrack struct {
track *webrtc.Track
list *packetlist.List
maxBitrate uint64
mu sync.Mutex
......@@ -172,6 +175,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) {
webrtc.DefaultPayloadTypeVP8, 90000,
[]webrtc.RTCPFeedback{
{"goog-remb", ""},
{"nack", ""},
{"nack", "pli"},
},
"",
......
package packetlist
import (
"sync"
)
const BufSize = 1500
type entry struct {
seqno uint16
length int
buf [BufSize]byte
}
type List struct {
mu sync.Mutex
tail int
entries []entry
}
func New(capacity int) *List {
return &List{
entries: make([]entry, capacity),
}
}
func (list *List) Store(seqno uint16, buf []byte) {
list.mu.Lock()
defer list.mu.Unlock()
list.entries[list.tail].seqno = seqno
copy(list.entries[list.tail].buf[:], buf)
list.entries[list.tail].length = len(buf)
list.tail = (list.tail + 1) % len(list.entries)
}
func (list *List) Get(seqno uint16) []byte {
list.mu.Lock()
defer list.mu.Unlock()
for i := range list.entries {
if list.entries[i].length == 0 ||
list.entries[i].seqno != seqno {
continue
}
buf := make([]byte, list.entries[i].length)
copy(buf, list.entries[i].buf[:])
return buf
}
return nil
}
package packetlist
import (
"bytes"
"math/rand"
"testing"
)
func randomBuf() []byte {
length := rand.Int31n(BufSize-1) + 1
buf := make([]byte, length)
rand.Read(buf)
return buf
}
func TestList(t *testing.T) {
buf1 := randomBuf()
buf2 := randomBuf()
list := New(16)
list.Store(13, buf1)
list.Store(17, buf2)
if bytes.Compare(list.Get(13), buf1) != 0 {
t.Errorf("Couldn't get 13")
}
if bytes.Compare(list.Get(17), buf2) != 0 {
t.Errorf("Couldn't get 17")
}
if list.Get(42) != nil {
t.Errorf("Creation ex nihilo")
}
}
func TestOverflow(t *testing.T) {
list := New(16)
for i := 0; i < 32; i++ {
list.Store(uint16(i), []byte{uint8(i)})
}
for i := 0; i < 32; i++ {
buf := list.Get(uint16(i))
if i < 16 {
if buf != nil {
t.Errorf("Creation ex nihilo: %v", i)
}
} else {
L if len(buf) != 1 || buf[0] != uint8(i) {
t.Errorf("Expected [%v], got %v", i, buf)
}
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment