diff --git a/client.go b/client.go index dfc8d1b..d631c24 100644 --- a/client.go +++ b/client.go @@ -37,7 +37,7 @@ type ClientOnDecodeErrorFunc func(err error) type ClientOnTracksFunc func([]*Track) error // ClientOnDataAV1Func is the prototype of the function passed to OnDataAV1(). -type ClientOnDataAV1Func func(pts time.Duration, obus [][]byte) +type ClientOnDataAV1Func func(pts time.Duration, tu [][]byte) // ClientOnDataVP9Func is the prototype of the function passed to OnDataVP9(). type ClientOnDataVP9Func func(pts time.Duration, frame []byte) @@ -73,6 +73,8 @@ type Client struct { // // callbacks (all optional) // + // called when tracks are available. + OnTracks ClientOnTracksFunc // called before downloading a primary playlist. OnDownloadPrimaryPlaylist ClientOnDownloadPrimaryPlaylistFunc // called before downloading a stream playlist. @@ -88,7 +90,6 @@ type Client struct { ctx context.Context ctxCancel func() - onTracks ClientOnTracksFunc onData map[*Track]interface{} playlistURL *url.URL @@ -101,6 +102,11 @@ func (c *Client) Start() error { if c.HTTPClient == nil { c.HTTPClient = http.DefaultClient } + if c.OnTracks == nil { + c.OnTracks = func(_ []*Track) error { + return nil + } + } if c.OnDownloadPrimaryPlaylist == nil { c.OnDownloadPrimaryPlaylist = func(_ string) {} } @@ -140,11 +146,6 @@ func (c *Client) Wait() chan error { return c.outErr } -// OnTracks sets a callback that is called when tracks are read. -func (c *Client) OnTracks(cb ClientOnTracksFunc) { - c.onTracks = cb -} - // OnDataAV1 sets a callback that is called when data from an AV1 track is received. func (c *Client) OnDataAV1(forma *Track, cb ClientOnDataAV1Func) { c.onData[forma] = cb @@ -185,7 +186,7 @@ func (c *Client) runInner() error { c.OnDownloadSegment, c.OnDecodeError, rp, - c.onTracks, + c.OnTracks, c.onData, ) rp.add(dl) diff --git a/client_processor_fmp4.go b/client_processor_fmp4.go index 02d52c5..a39af45 100644 --- a/client_processor_fmp4.go +++ b/client_processor_fmp4.go @@ -195,18 +195,18 @@ func (p *clientProcessorFMP4) initializeTrackProcs(ctx context.Context, track *f switch track.Codec.(type) { case *codecs.AV1: - var onDataCasted ClientOnDataAV1Func = func(pts time.Duration, obus [][]byte) {} + var onDataCasted ClientOnDataAV1Func = func(pts time.Duration, tu [][]byte) {} if onData != nil { onDataCasted = onData.(ClientOnDataAV1Func) } postProcess = func(pts time.Duration, dts time.Duration, sample *fmp4.PartSample) error { - obus, err := sample.GetAV1() + tu, err := sample.GetAV1() if err != nil { return err } - onDataCasted(pts, obus) + onDataCasted(pts, tu) return nil } diff --git a/client_test.go b/client_test.go index fd63a95..69f509b 100644 --- a/client_test.go +++ b/client_test.go @@ -270,17 +270,6 @@ func TestClientMPEGTS(t *testing.T) { prefix = "https" } - c := &Client{ - URI: prefix + "://localhost:5780/stream.m3u8", - HTTPClient: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - }, - } - onH264 := func(pts time.Duration, dts time.Duration, au [][]byte) { require.Equal(t, 2*time.Second, pts) require.Equal(t, time.Duration(0), dts) @@ -292,12 +281,23 @@ func TestClientMPEGTS(t *testing.T) { close(packetRecv) } - c.OnTracks(func(tracks []*Track) error { - require.Equal(t, 1, len(tracks)) - require.Equal(t, &codecs.H264{}, tracks[0].Codec) - c.OnDataH26x(tracks[0], onH264) - return nil - }) + var c *Client + c = &Client{ + URI: prefix + "://localhost:5780/stream.m3u8", + HTTPClient: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + OnTracks: func(tracks []*Track) error { + require.Equal(t, 1, len(tracks)) + require.Equal(t, &codecs.H264{}, tracks[0].Codec) + c.OnDataH26x(tracks[0], onH264) + return nil + }, + } err = c.Start() require.NoError(t, err) @@ -355,18 +355,18 @@ segment.mp4 close(packetRecv) } - c := &Client{ + var c *Client + c = &Client{ URI: "http://localhost:5780/stream.m3u8", + OnTracks: func(tracks []*Track) error { + require.Equal(t, 1, len(tracks)) + _, ok := tracks[0].Codec.(*codecs.H264) + require.Equal(t, true, ok) + c.OnDataH26x(tracks[0], onH264) + return nil + }, } - c.OnTracks(func(tracks []*Track) error { - require.Equal(t, 1, len(tracks)) - _, ok := tracks[0].Codec.(*codecs.H264) - require.Equal(t, true, ok) - c.OnDataH26x(tracks[0], onH264) - return nil - }) - err = c.Start() require.NoError(t, err) @@ -429,10 +429,6 @@ segment1.ts } require.NoError(t, err) - c.OnTracks(func(tracks []*Track) error { - return nil - }) - err = c.Start() require.NoError(t, err) diff --git a/examples/client/main.go b/examples/client/main.go index 6bee66b..543c778 100644 --- a/examples/client/main.go +++ b/examples/client/main.go @@ -11,44 +11,45 @@ import ( // This example shows how to read a HLS stream. func main() { - // setup client. - c := gohlslib.Client{ + // setup client + var c *gohlslib.Client + c = &gohlslib.Client{ URI: "https://myserver/mystream/index.m3u8", - } - // setup a hook that is called when tracks are parsed - c.OnTracks(func(tracks []*gohlslib.Track) error { - for _, track := range tracks { - ttrack := track - - log.Printf("detected track with codec %T\n", track.Codec) - - // setup a hook that is called when data is received - switch track.Codec.(type) { - case *codecs.AV1: - c.OnDataAV1(track, func(pts time.Duration, obus [][]byte) { - log.Printf("received data from track %T, pts = %v", ttrack, pts) - }) - - case *codecs.H264, *codecs.H265: - c.OnDataH26x(track, func(pts time.Duration, dts time.Duration, au [][]byte) { - log.Printf("received data from track %T, pts = %v", ttrack, pts) - }) - - case *codecs.MPEG4Audio: - c.OnDataMPEG4Audio(track, func(pts time.Duration, aus [][]byte) { - log.Printf("received data from track %T, pts = %v", ttrack, pts) - }) - - case *codecs.Opus: - c.OnDataOpus(track, func(pts time.Duration, packets [][]byte) { - log.Printf("received data from track %T, pts = %v", ttrack, pts) - }) - } + // set a callback that is called when tracks are parsed + OnTracks: func(tracks []*gohlslib.Track) error { + for _, track := range tracks { + ttrack := track + + log.Printf("detected track with codec %T\n", track.Codec) + + // set a callback that is called when data is received + switch track.Codec.(type) { + case *codecs.AV1: + c.OnDataAV1(track, func(pts time.Duration, tu [][]byte) { + log.Printf("received data from track %T, pts = %v", ttrack, pts) + }) + + case *codecs.H264, *codecs.H265: + c.OnDataH26x(track, func(pts time.Duration, dts time.Duration, au [][]byte) { + log.Printf("received data from track %T, pts = %v", ttrack, pts) + }) - } - return nil - }) + case *codecs.MPEG4Audio: + c.OnDataMPEG4Audio(track, func(pts time.Duration, aus [][]byte) { + log.Printf("received data from track %T, pts = %v", ttrack, pts) + }) + + case *codecs.Opus: + c.OnDataOpus(track, func(pts time.Duration, packets [][]byte) { + log.Printf("received data from track %T, pts = %v", ttrack, pts) + }) + } + + } + return nil + }, + } // start reading err := c.Start() diff --git a/muxer.go b/muxer.go index 8c6b73b..b96edb6 100644 --- a/muxer.go +++ b/muxer.go @@ -194,14 +194,14 @@ func (m *Muxer) Close() { m.segmenter.close() } -// WriteAV1 writes an AV1 OBU sequence. -func (m *Muxer) WriteAV1(ntp time.Time, pts time.Duration, obus [][]byte) error { +// WriteAV1 writes an AV1 temporal unit. +func (m *Muxer) WriteAV1(ntp time.Time, pts time.Duration, tu [][]byte) error { codec := m.VideoTrack.Codec.(*codecs.AV1) update := false sequenceHeader := codec.SequenceHeader randomAccess := false - for _, obu := range obus { + for _, obu := range tu { var h av1.OBUHeader err := h.Unmarshal(obu) if err != nil { @@ -236,7 +236,7 @@ func (m *Muxer) WriteAV1(ntp time.Time, pts time.Duration, obus [][]byte) error forceSwitch = true } - return m.segmenter.writeAV1(ntp, pts, obus, randomAccess, forceSwitch) + return m.segmenter.writeAV1(ntp, pts, tu, randomAccess, forceSwitch) } // WriteVP9 writes a VP9 frame. diff --git a/muxer_segmenter_fmp4.go b/muxer_segmenter_fmp4.go index 5888cc3..1cb8ad3 100644 --- a/muxer_segmenter_fmp4.go +++ b/muxer_segmenter_fmp4.go @@ -196,7 +196,7 @@ func (m *muxerSegmenterFMP4) adjustPartDuration(sampleDuration time.Duration) { func (m *muxerSegmenterFMP4) writeAV1( ntp time.Time, dts time.Duration, - obus [][]byte, + tu [][]byte, randomAccess bool, forceSwitch bool, ) error { @@ -214,7 +214,7 @@ func (m *muxerSegmenterFMP4) writeAV1( ps, err := fmp4.NewPartSampleAV1( randomAccess, - obus) + tu) if err != nil { return err }