Commit da97560c authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Protect upConn.tracks by the upConn mutex rather than the client mutex.

Also don't rely on tracks being immutable in sendRR.
parent 8ba50bd2
......@@ -107,17 +107,25 @@ type upConnection struct {
id string
label string
pc *webrtc.PeerConnection
tracks []*upTrack
labels map[string]string
iceCandidates []*webrtc.ICECandidateInit
mu sync.Mutex
closed bool
tracks []*upTrack
local []downConnection
}
var ErrConnectionClosed = errors.New("connection is closed")
func (up *upConnection) getTracks() []*upTrack {
up.mu.Lock()
defer up.mu.Unlock()
tracks := make([]*upTrack, len(up.tracks))
copy(tracks, up.tracks)
return tracks
}
func (up *upConnection) addLocal(local downConnection) error {
up.mu.Lock()
defer up.mu.Unlock()
......@@ -206,6 +214,7 @@ func getUpMid(pc *webrtc.PeerConnection, track *webrtc.Track) string {
return ""
}
// called locked
func (up *upConnection) complete() bool {
for mid, _ := range up.labels {
found := false
......
......@@ -565,7 +565,8 @@ func getClientStats(c *webClient) clientStats {
for _, up := range c.up {
conns := connStats{id: up.id}
for _, t := range up.tracks {
tracks := up.getTracks()
for _, t := range tracks {
expected, lost, _, _ := t.cache.GetStats(false)
if expected == 0 {
expected = 1
......
......@@ -282,12 +282,12 @@ func getUpConn(c *webClient, id string) *upConnection {
return conn
}
func getUpConns(c *webClient) []string {
func getUpConns(c *webClient) []*upConnection {
c.mu.Lock()
defer c.mu.Unlock()
up := make([]string, 0, len(c.up))
for id := range c.up {
up = append(up, id)
up := make([]*upConnection, 0, len(c.up))
for _, u := range c.up {
up = append(up, u)
}
return up
}
......@@ -337,22 +337,16 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
})
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
c.mu.Lock()
u, ok := c.up[id]
if !ok {
log.Printf("Unknown connection")
c.mu.Unlock()
return
}
conn.mu.Lock()
defer conn.mu.Unlock()
mid := getUpMid(pc, remote)
if mid == "" {
log.Printf("Couldn't get track's mid")
c.mu.Unlock()
return
}
label, ok := u.labels[mid]
label, ok := conn.labels[mid]
if !ok {
log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
......@@ -373,25 +367,24 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}),
}
u.tracks = append(u.tracks, track)
var tracks []*upTrack
if u.complete() {
tracks = make([]*upTrack, len(u.tracks))
copy(tracks, u.tracks)
}
conn.tracks = append(conn.tracks, track)
if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.group.videoCount, 1)
}
c.mu.Unlock()
go readLoop(conn, track)
go rtcpUpListener(conn, track, receiver)
if tracks != nil {
if conn.complete() {
// cannot call getTracks, we're locked
tracks := make([]*upTrack, len(conn.tracks))
copy(tracks, conn.tracks)
clients := c.group.getClients(c)
for _, cc := range clients {
cc.pushConn(u, tracks, u.label)
cc.pushConn(conn, tracks, conn.label)
}
go rtcpUpSender(conn)
}
......@@ -573,7 +566,7 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
}
}
if(firstSR) {
if firstSR {
// this is the first SR we got for at least one track,
// quickly propagate the time offsets downstream
local := conn.getLocal()
......@@ -591,6 +584,9 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
}
func sendRR(conn *upConnection) error {
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.tracks) == 0 {
return nil
}
......@@ -650,6 +646,8 @@ func rtcpUpSender(conn *upConnection) {
}
func sendSR(conn *rtpDownConnection) error {
// since this is only called after all tracks have been created,
// there is no need for locking.
packets := make([]rtcp.Packet, 0, len(conn.tracks))
now := time.Now()
......@@ -716,17 +714,19 @@ func rtcpDownSender(conn *rtpDownConnection) {
func delUpConn(c *webClient, id string) bool {
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
c.mu.Unlock()
return false
}
conn := c.up[id]
if conn == nil {
c.mu.Unlock()
return false
}
delete(c.up, id)
c.mu.Unlock()
conn.mu.Lock()
for _, track := range conn.tracks {
if track.track.Kind() == webrtc.RTPCodecTypeVideo {
count := atomic.AddUint32(&c.group.videoCount,
......@@ -737,9 +737,9 @@ func delUpConn(c *webClient, id string) bool {
}
}
}
conn.mu.Unlock()
conn.Close()
delete(c.up, id)
return true
}
......@@ -1007,59 +1007,59 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint
}
}
func updateUpTrack(up *upConnection, maxVideoRate uint64) {
func updateUpTrack(track *upTrack, maxVideoRate uint64) uint64 {
now := rtptime.Jiffies()
for _, track := range up.tracks {
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = maxVideoRate
if rate < minrate {
rate = minrate
}
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = maxVideoRate
if rate < minrate {
rate = minrate
}
}
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
if bitrate <= minrate {
rate = minrate
break
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
if rate > bitrate {
rate = bitrate
}
if packets > 256 {
packets = 256
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
track.cache.ResizeCond(packets)
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
return rate
}
var ErrUnsupportedFeedback = errors.New("unsupported feedback type")
......@@ -1468,8 +1468,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
})
case pushConnsAction:
for _, u := range c.up {
tracks := make([]*upTrack, len(u.tracks))
copy(tracks, u.tracks)
tracks := u.getTracks()
go a.c.pushConn(u, tracks, u.label)
}
case connectionFailedAction:
......@@ -1490,12 +1489,12 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
Permissions: c.permissions,
})
if !c.permissions.Present {
ids := getUpConns(c)
for _, id := range ids {
found := delUpConn(c, id)
up := getUpConns(c)
for _, u := range up {
found := delUpConn(c, u.id)
if found {
failConnection(
c, id,
c, u.id,
"permission denied",
)
}
......@@ -1727,13 +1726,6 @@ func handleClientMessage(c *webClient, m clientMessage) error {
}
func sendRateUpdate(c *webClient) {
type remb struct {
pc *webrtc.PeerConnection
ssrc uint32
bitrate uint64
}
rembs := make([]remb, 0)
maxVideoRate := ^uint64(0)
count := atomic.LoadUint32(&c.group.videoCount)
if count >= 3 {
......@@ -1743,27 +1735,22 @@ func sendRateUpdate(c *webClient) {
}
}
c.mu.Lock()
for _, u := range c.up {
updateUpTrack(u, maxVideoRate)
for _, t := range u.tracks {
up := getUpConns(c)
for _, u := range up {
tracks := u.getTracks()
for _, t := range tracks {
rate := updateUpTrack(t, maxVideoRate)
if !t.hasRtcpFb("goog-remb", "") {
continue
}
bitrate := t.maxBitrate
if bitrate == ^uint64(0) {
if rate == ^uint64(0) {
continue
}
rembs = append(rembs,
remb{u.pc, t.track.SSRC(), bitrate})
}
}
c.mu.Unlock()
for _, r := range rembs {
err := sendREMB(r.pc, r.ssrc, r.bitrate)
if err != nil {
log.Printf("sendREMB: %v", err)
err := sendREMB(u.pc, t.track.SSRC(), rate)
if err != nil {
log.Printf("sendREMB: %v", err)
}
}
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment