Commit da97560c authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Protect upConn.tracks by the upConn mutex rather than the client mutex.

Also don't rely on tracks being immutable in sendRR.
parent 8ba50bd2
...@@ -107,17 +107,25 @@ type upConnection struct { ...@@ -107,17 +107,25 @@ type upConnection struct {
id string id string
label string label string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
tracks []*upTrack
labels map[string]string labels map[string]string
iceCandidates []*webrtc.ICECandidateInit iceCandidates []*webrtc.ICECandidateInit
mu sync.Mutex mu sync.Mutex
closed bool closed bool
tracks []*upTrack
local []downConnection local []downConnection
} }
var ErrConnectionClosed = errors.New("connection is closed") var ErrConnectionClosed = errors.New("connection is closed")
func (up *upConnection) getTracks() []*upTrack {
up.mu.Lock()
defer up.mu.Unlock()
tracks := make([]*upTrack, len(up.tracks))
copy(tracks, up.tracks)
return tracks
}
func (up *upConnection) addLocal(local downConnection) error { func (up *upConnection) addLocal(local downConnection) error {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
...@@ -206,6 +214,7 @@ func getUpMid(pc *webrtc.PeerConnection, track *webrtc.Track) string { ...@@ -206,6 +214,7 @@ func getUpMid(pc *webrtc.PeerConnection, track *webrtc.Track) string {
return "" return ""
} }
// called locked
func (up *upConnection) complete() bool { func (up *upConnection) complete() bool {
for mid, _ := range up.labels { for mid, _ := range up.labels {
found := false found := false
......
...@@ -565,7 +565,8 @@ func getClientStats(c *webClient) clientStats { ...@@ -565,7 +565,8 @@ func getClientStats(c *webClient) clientStats {
for _, up := range c.up { for _, up := range c.up {
conns := connStats{id: up.id} conns := connStats{id: up.id}
for _, t := range up.tracks { tracks := up.getTracks()
for _, t := range tracks {
expected, lost, _, _ := t.cache.GetStats(false) expected, lost, _, _ := t.cache.GetStats(false)
if expected == 0 { if expected == 0 {
expected = 1 expected = 1
......
...@@ -282,12 +282,12 @@ func getUpConn(c *webClient, id string) *upConnection { ...@@ -282,12 +282,12 @@ func getUpConn(c *webClient, id string) *upConnection {
return conn return conn
} }
func getUpConns(c *webClient) []string { func getUpConns(c *webClient) []*upConnection {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
up := make([]string, 0, len(c.up)) up := make([]*upConnection, 0, len(c.up))
for id := range c.up { for _, u := range c.up {
up = append(up, id) up = append(up, u)
} }
return up return up
} }
...@@ -337,22 +337,16 @@ func addUpConn(c *webClient, id string) (*upConnection, error) { ...@@ -337,22 +337,16 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
}) })
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) { pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
c.mu.Lock() conn.mu.Lock()
u, ok := c.up[id] defer conn.mu.Unlock()
if !ok {
log.Printf("Unknown connection")
c.mu.Unlock()
return
}
mid := getUpMid(pc, remote) mid := getUpMid(pc, remote)
if mid == "" { if mid == "" {
log.Printf("Couldn't get track's mid") log.Printf("Couldn't get track's mid")
c.mu.Unlock()
return return
} }
label, ok := u.labels[mid] label, ok := conn.labels[mid]
if !ok { if !ok {
log.Printf("Couldn't get track's label") log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
...@@ -373,25 +367,24 @@ func addUpConn(c *webClient, id string) (*upConnection, error) { ...@@ -373,25 +367,24 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
localCh: make(chan localTrackAction, 2), localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
} }
u.tracks = append(u.tracks, track)
var tracks []*upTrack conn.tracks = append(conn.tracks, track)
if u.complete() {
tracks = make([]*upTrack, len(u.tracks))
copy(tracks, u.tracks)
}
if remote.Kind() == webrtc.RTPCodecTypeVideo { if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.group.videoCount, 1) atomic.AddUint32(&c.group.videoCount, 1)
} }
c.mu.Unlock()
go readLoop(conn, track) go readLoop(conn, track)
go rtcpUpListener(conn, track, receiver) go rtcpUpListener(conn, track, receiver)
if tracks != nil { if conn.complete() {
// cannot call getTracks, we're locked
tracks := make([]*upTrack, len(conn.tracks))
copy(tracks, conn.tracks)
clients := c.group.getClients(c) clients := c.group.getClients(c)
for _, cc := range clients { for _, cc := range clients {
cc.pushConn(u, tracks, u.label) cc.pushConn(conn, tracks, conn.label)
} }
go rtcpUpSender(conn) go rtcpUpSender(conn)
} }
...@@ -573,7 +566,7 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) { ...@@ -573,7 +566,7 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
} }
} }
if(firstSR) { if firstSR {
// this is the first SR we got for at least one track, // this is the first SR we got for at least one track,
// quickly propagate the time offsets downstream // quickly propagate the time offsets downstream
local := conn.getLocal() local := conn.getLocal()
...@@ -591,6 +584,9 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) { ...@@ -591,6 +584,9 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
} }
func sendRR(conn *upConnection) error { func sendRR(conn *upConnection) error {
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.tracks) == 0 { if len(conn.tracks) == 0 {
return nil return nil
} }
...@@ -650,6 +646,8 @@ func rtcpUpSender(conn *upConnection) { ...@@ -650,6 +646,8 @@ func rtcpUpSender(conn *upConnection) {
} }
func sendSR(conn *rtpDownConnection) error { func sendSR(conn *rtpDownConnection) error {
// since this is only called after all tracks have been created,
// there is no need for locking.
packets := make([]rtcp.Packet, 0, len(conn.tracks)) packets := make([]rtcp.Packet, 0, len(conn.tracks))
now := time.Now() now := time.Now()
...@@ -716,17 +714,19 @@ func rtcpDownSender(conn *rtpDownConnection) { ...@@ -716,17 +714,19 @@ func rtcpDownSender(conn *rtpDownConnection) {
func delUpConn(c *webClient, id string) bool { func delUpConn(c *webClient, id string) bool {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil { if c.up == nil {
c.mu.Unlock()
return false return false
} }
conn := c.up[id] conn := c.up[id]
if conn == nil { if conn == nil {
c.mu.Unlock()
return false return false
} }
delete(c.up, id)
c.mu.Unlock()
conn.mu.Lock()
for _, track := range conn.tracks { for _, track := range conn.tracks {
if track.track.Kind() == webrtc.RTPCodecTypeVideo { if track.track.Kind() == webrtc.RTPCodecTypeVideo {
count := atomic.AddUint32(&c.group.videoCount, count := atomic.AddUint32(&c.group.videoCount,
...@@ -737,9 +737,9 @@ func delUpConn(c *webClient, id string) bool { ...@@ -737,9 +737,9 @@ func delUpConn(c *webClient, id string) bool {
} }
} }
} }
conn.mu.Unlock()
conn.Close() conn.Close()
delete(c.up, id)
return true return true
} }
...@@ -1007,59 +1007,59 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint ...@@ -1007,59 +1007,59 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint
} }
} }
func updateUpTrack(up *upConnection, maxVideoRate uint64) { func updateUpTrack(track *upTrack, maxVideoRate uint64) uint64 {
now := rtptime.Jiffies() now := rtptime.Jiffies()
for _, track := range up.tracks { isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo clockrate := track.track.Codec().ClockRate
clockrate := track.track.Codec().ClockRate minrate := uint64(minAudioRate)
minrate := uint64(minAudioRate) rate := ^uint64(0)
rate := ^uint64(0) if isvideo {
if isvideo { minrate = minVideoRate
minrate = minVideoRate rate = maxVideoRate
rate = maxVideoRate if rate < minrate {
if rate < minrate { rate = minrate
rate = minrate }
} }
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
} }
local := track.getLocal() if bitrate <= minrate {
var maxrto uint64 rate = minrate
for _, l := range local { break
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
} }
track.maxBitrate = rate if rate > bitrate {
_, r := track.rate.Estimate() rate = bitrate
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
} }
if packets > 256 { ll, ok := l.(*rtpDownTrack)
packets = 256 if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
} }
track.cache.ResizeCond(packets)
} }
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
return rate
} }
var ErrUnsupportedFeedback = errors.New("unsupported feedback type") var ErrUnsupportedFeedback = errors.New("unsupported feedback type")
...@@ -1468,8 +1468,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -1468,8 +1468,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
}) })
case pushConnsAction: case pushConnsAction:
for _, u := range c.up { for _, u := range c.up {
tracks := make([]*upTrack, len(u.tracks)) tracks := u.getTracks()
copy(tracks, u.tracks)
go a.c.pushConn(u, tracks, u.label) go a.c.pushConn(u, tracks, u.label)
} }
case connectionFailedAction: case connectionFailedAction:
...@@ -1490,12 +1489,12 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -1490,12 +1489,12 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
Permissions: c.permissions, Permissions: c.permissions,
}) })
if !c.permissions.Present { if !c.permissions.Present {
ids := getUpConns(c) up := getUpConns(c)
for _, id := range ids { for _, u := range up {
found := delUpConn(c, id) found := delUpConn(c, u.id)
if found { if found {
failConnection( failConnection(
c, id, c, u.id,
"permission denied", "permission denied",
) )
} }
...@@ -1727,13 +1726,6 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1727,13 +1726,6 @@ func handleClientMessage(c *webClient, m clientMessage) error {
} }
func sendRateUpdate(c *webClient) { func sendRateUpdate(c *webClient) {
type remb struct {
pc *webrtc.PeerConnection
ssrc uint32
bitrate uint64
}
rembs := make([]remb, 0)
maxVideoRate := ^uint64(0) maxVideoRate := ^uint64(0)
count := atomic.LoadUint32(&c.group.videoCount) count := atomic.LoadUint32(&c.group.videoCount)
if count >= 3 { if count >= 3 {
...@@ -1743,27 +1735,22 @@ func sendRateUpdate(c *webClient) { ...@@ -1743,27 +1735,22 @@ func sendRateUpdate(c *webClient) {
} }
} }
c.mu.Lock() up := getUpConns(c)
for _, u := range c.up {
updateUpTrack(u, maxVideoRate) for _, u := range up {
for _, t := range u.tracks { tracks := u.getTracks()
for _, t := range tracks {
rate := updateUpTrack(t, maxVideoRate)
if !t.hasRtcpFb("goog-remb", "") { if !t.hasRtcpFb("goog-remb", "") {
continue continue
} }
bitrate := t.maxBitrate if rate == ^uint64(0) {
if bitrate == ^uint64(0) {
continue continue
} }
rembs = append(rembs, err := sendREMB(u.pc, t.track.SSRC(), rate)
remb{u.pc, t.track.SSRC(), bitrate}) if err != nil {
} log.Printf("sendREMB: %v", err)
} }
c.mu.Unlock()
for _, r := range rembs {
err := sendREMB(r.pc, r.ssrc, r.bitrate)
if err != nil {
log.Printf("sendREMB: %v", err)
} }
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment