package h264

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// casesAnnexB is the shared table for the AnnexB unmarshal, marshal and fuzz
// tests. Unmarshal(encin) must yield dec, and Marshal(dec) must yield encout.
// encout always uses 4-byte start codes (0x00 0x00 0x00 0x01), so it differs
// from encin whenever encin uses 3-byte start codes or extra zero padding.
var casesAnnexB = []struct {
	name   string // subtest name
	encin  []byte // input passed to AnnexB.Unmarshal
	encout []byte // expected output of AnnexB.Marshal applied to dec
	dec    AnnexB // expected result of Unmarshal(encin); also the input to Marshal
}{
	{
		// Every NALU is preceded by a 3-byte start code (0x00 0x00 0x01);
		// Marshal re-emits them with 4-byte start codes.
		"2 zeros",
		[]byte{
			0x00, 0x00, 0x01, 0xaa, 0xbb, 0x00, 0x00, 0x01,
			0xcc, 0xdd, 0x00, 0x00, 0x01, 0xee, 0xff,
		},
		[]byte{
			0x00, 0x00, 0x00, 0x01, 0xaa, 0xbb,
			0x00, 0x00, 0x00, 0x01, 0xcc, 0xdd,
			0x00, 0x00, 0x00, 0x01, 0xee, 0xff,
		},
		[][]byte{
			{0xaa, 0xbb},
			{0xcc, 0xdd},
			{0xee, 0xff},
		},
	},
	{
		// Every NALU is preceded by a 4-byte start code, so encin and encout
		// are byte-for-byte identical.
		"3 zeros",
		[]byte{
			0x00, 0x00, 0x00, 0x01, 0xaa, 0xbb,
			0x00, 0x00, 0x00, 0x01, 0xcc, 0xdd,
			0x00, 0x00, 0x00, 0x01, 0xee, 0xff,
		},
		[]byte{
			0x00, 0x00, 0x00, 0x01, 0xaa, 0xbb,
			0x00, 0x00, 0x00, 0x01, 0xcc, 0xdd,
			0x00, 0x00, 0x00, 0x01, 0xee, 0xff,
		},
		[][]byte{
			{0xaa, 0xbb},
			{0xcc, 0xdd},
			{0xee, 0xff},
		},
	},
	{
		// used by Apple inside HLS test streams.
		// Mixes 3- and 4-byte start codes. The run "0, 0, 0, 0, 1" is decoded
		// as a trailing zero of the previous NALU followed by a 4-byte start
		// code, so that zero shows up at the end of the fourth decoded NALU.
		"2 or 3 zeros",
		[]byte{
			0, 0, 0, 1, 9, 240,
			0, 0, 0, 1, 39, 66, 224, 21, 169, 24, 60, 23, 252, 184, 3, 80, 96, 16, 107, 108, 43, 94, 247, 192, 64,
			0, 0, 0, 1, 40, 222, 9, 200,
			0, 0, 1, 6, 0, 7, 131, 236, 119,
			0, 0, 0, 0, 1, 3, 0, 64, 128,
			0, 0, 1, 6, 5, 17, 3, 135, 244, 78, 205, 10, 75, 220, 161, 148, 58, 195, 212, 155, 23, 31, 0, 128,
		},
		[]byte{
			0, 0, 0, 1, 9, 240,
			0, 0, 0, 1, 39, 66, 224, 21, 169, 24, 60, 23, 252, 184, 3, 80, 96, 16, 107, 108, 43, 94, 247, 192, 64,
			0, 0, 0, 1, 40, 222, 9, 200,
			0, 0, 0, 1, 6, 0, 7, 131, 236, 119, 0,
			0, 0, 0, 1, 3, 0, 64, 128,
			0, 0, 0, 1, 6, 5, 17, 3, 135, 244, 78, 205, 10, 75, 220, 161, 148, 58, 195, 212, 155, 23, 31, 0, 128,
		},
		[][]byte{
			{9, 240},
			{39, 66, 224, 21, 169, 24, 60, 23, 252, 184, 3, 80, 96, 16, 107, 108, 43, 94, 247, 192, 64},
			{40, 222, 9, 200},
			{6, 0, 7, 131, 236, 119, 0},
			{3, 0, 64, 128},
			{6, 5, 17, 3, 135, 244, 78, 205, 10, 75, 220, 161, 148, 58, 195, 212, 155, 23, 31, 0, 128},
		},
	},
	{
		// Zero bytes before a start code that are not consumed by the start
		// code itself (at most three zeros, as these bytes show) are kept as
		// trailing bytes of the preceding NALU. Marshal does not strip them.
		"AUs end with zeros",
		[]byte{
			0x00, 0x00, 0x00, 0x01, 0xaa, 0xbb, 0x00,
			0x00, 0x00, 0x00, 0x01, 0xcc, 0xdd, 0x00, 0x00, 0x00,
			0x00, 0x00, 0x00, 0x01, 0xee, 0xff, 0x00, 0x00,
			0x00, 0x00, 0x01, 0x1a, 0x1b, 0x1c,
		},
		[]byte{
			0x00, 0x00, 0x00, 0x01, 0xaa, 0xbb, 0x00,
			0x00, 0x00, 0x00, 0x01, 0xcc, 0xdd, 0x00, 0x00, 0x00,
			0x00, 0x00, 0x00, 0x01, 0xee, 0xff, 0x00,
			0x00, 0x00, 0x00, 0x01, 0x1a, 0x1b, 0x1c,
		},
		[][]byte{
			{0xaa, 0xbb, 0},
			{0xcc, 0xdd, 0, 0, 0},
			{0xee, 0xff, 0},
			{0x1a, 0x1b, 0x1c},
		},
	},
}

// TestAnnexBUnmarshal decodes the encin of every entry in casesAnnexB and
// checks that the resulting NALU list matches the entry's dec field.
func TestAnnexBUnmarshal(t *testing.T) {
	for i := range casesAnnexB {
		tc := casesAnnexB[i]

		t.Run(tc.name, func(t *testing.T) {
			var got AnnexB
			require.NoError(t, got.Unmarshal(tc.encin))
			require.Equal(t, tc.dec, got)
		})
	}
}

// TestAnnexBUnmarshalEmpty checks the handling of buffers made of start
// codes with no NALU payload between them:
//   - a buffer containing only start codes is rejected with ErrAnnexBNoNALUs;
//   - consecutive start codes followed by a single payload byte decode to
//     exactly one NALU.
//
// Each scenario uses its own AnnexB so that state left behind by a failed
// Unmarshal cannot influence the next check.
func TestAnnexBUnmarshalEmpty(t *testing.T) {
	t.Run("only start codes", func(t *testing.T) {
		var dec AnnexB
		err := dec.Unmarshal([]byte{0, 0, 0, 1, 0, 0, 0, 1})
		// ErrorIs keeps the check valid if Unmarshal ever wraps the sentinel.
		require.ErrorIs(t, err, ErrAnnexBNoNALUs)
	})

	t.Run("start codes then one NALU", func(t *testing.T) {
		var dec AnnexB
		err := dec.Unmarshal([]byte{0, 0, 0, 1, 0, 0, 0, 1, 1})
		require.NoError(t, err)
		require.Equal(t, AnnexB{{1}}, dec)
	})
}

// TestAnnexBMarshal encodes the dec of every entry in casesAnnexB and checks
// that the produced bytes match the entry's encout field.
func TestAnnexBMarshal(t *testing.T) {
	for i := range casesAnnexB {
		tc := casesAnnexB[i]

		t.Run(tc.name, func(t *testing.T) {
			got, err := tc.dec.Marshal()
			require.NoError(t, err)
			require.Equal(t, tc.encout, got)
		})
	}
}

// BenchmarkAnnexBUnmarshal measures AnnexB.Unmarshal on a buffer holding
// eight 4-byte NALUs, each preceded by a 4-byte start code.
func BenchmarkAnnexBUnmarshal(b *testing.B) {
	// Build the input once, outside the timed loop. The benchmark should
	// measure Unmarshal, not the per-iteration construction (and likely heap
	// allocation) of the slice literal.
	buf := []byte{
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
		0x00, 0x00, 0x00, 0x01,
		0x01, 0x02, 0x03, 0x04,
	}

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		var dec AnnexB
		// A decoding failure would make the timings meaningless, so surface it.
		if err := dec.Unmarshal(buf); err != nil {
			b.Fatal(err)
		}
	}
}

// FuzzAnnexBUnmarshal feeds arbitrary bytes to AnnexB.Unmarshal. The corpus is
// seeded with the encin of every entry in casesAnnexB. Inputs rejected by
// Unmarshal are skipped. Accepted inputs must decode to at least one NALU,
// contain no zero-length NALU, and be accepted by Marshal.
func FuzzAnnexBUnmarshal(f *testing.F) {
	for i := range casesAnnexB {
		f.Add(casesAnnexB[i].encin)
	}

	f.Fuzz(func(t *testing.T, b []byte) {
		var au AnnexB
		if au.Unmarshal(b) != nil {
			return
		}

		require.NotZero(t, len(au))
		for i := range au {
			require.NotZero(t, len(au[i]))
		}

		_, err := au.Marshal()
		require.NoError(t, err)
	})
}
