From d92149cc94acd829cd1652bfe3e72d3468182649 Mon Sep 17 00:00:00 2001 From: hugo-ijw Date: Wed, 8 Jan 2025 16:28:10 +0000 Subject: [PATCH 01/11] fix: Solve CUDA av1 decoding --- CONTRIBUTING.md | 3 +- .../decoders/_core/CPUOnlyDevice.cpp | 7 ++++ src/torchcodec/decoders/_core/CudaDevice.cpp | 32 ++++++++++++++++++ .../decoders/_core/DeviceInterface.h | 5 +++ .../decoders/_core/VideoDecoder.cpp | 6 ++++ test/decoders/test_video_decoder.py | 19 ++++++++++- test/generate_reference_resources.sh | 16 +++++++++ test/resources/av1_video.mkv | Bin 0 -> 16375 bytes .../av1_video.mkv.stream0.frame000010.pt | Bin 0 -> 692574 bytes test/utils.py | 15 ++++++++ 10 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 test/resources/av1_video.mkv create mode 100644 test/resources/av1_video.mkv.stream0.frame000010.pt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 55e5545c..bc3ec3bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -42,11 +42,12 @@ git clone git@github.com:pytorch/torchcodec.git cd torchcodec pip install -e ".[dev]" --no-build-isolation -vv +# Or, for cuda support: ENABLE_CUDA=1 pip install -e ".[dev]" --no-build-isolation -vv ``` ### Running unit tests -To run python tests run: +To run python tests run (please make sure `torchvision` is installed): ```bash pytest test -vvv diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index eb605465..94a56c1b 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -35,4 +35,11 @@ void releaseContextOnCuda( throwUnsupportedDeviceError(device); } +void forceCudaCodec( + const torch::Device& device, + const AVCodec** codec, + const AVCodecID& codecId) { + throwUnsupportedDeviceError(device); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 5d48d26a..192ed544 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -256,4 +256,36 @@ void convertAVFrameToDecodedOutputOnCuda( << " took: " << duration.count() << "us" << std::endl; } +// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 +void forceCudaCodec( + const torch::Device& device, + const AVCodec** codec, + const AVCodecID& codecId) { + if (device.type() != torch::kCUDA) { + return; + } + + const AVCodec* c; + void* i = NULL; + bool found = false; + + while (!found && (c = av_codec_iterate(&i))) { + const AVCodecHWConfig* config; + + if (c->id != codecId || !av_codec_is_decoder(c)) { + continue; + } + + for (int j = 0; config = avcodec_get_hw_config(c, j); j++) { + if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { + found = true; + } + } + } + + if (found) { + *codec = c; + } +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 42dd63fc..ba3992d9 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -43,4 +43,9 @@ void releaseContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); +void forceCudaCodec( + const torch::Device& device, + const AVCodec** codec, + const AVCodecID& codecId); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 5fa7a872..92b1af50 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -456,6 +456,12 @@ void VideoDecoder::addVideoStreamDecoder( "Stream with index " + std::to_string(streamNumber) + " is not a video stream."); } + + if (options.device.type() == torch::kCUDA) { + forceCudaCodec( + options.device, &codec, streamInfo.stream->codecpar->codec_id); + } + AVCodecContext* codecContext = avcodec_alloc_context3(codec); codecContext->thread_count = options.ffmpegThreadCount.value_or(0); TORCH_CHECK(codecContext != nullptr); diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 1ff8266c..14d78927 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -11,7 +11,14 @@ from torchcodec.decoders import _core, VideoDecoder -from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, in_fbcode, NASA_VIDEO +from ..utils import ( + assert_frames_equal, + AV1_VIDEO, + cpu_and_cuda, + H265_VIDEO, + in_fbcode, + NASA_VIDEO, +) class TestVideoDecoder: @@ -409,6 +416,16 @@ def test_get_frames_at_fails(self, device): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_at([0.3]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_at_av1(self, device): + decoder = VideoDecoder(AV1_VIDEO.path, device=device) + ref_frame11 = AV1_VIDEO.get_frame_data_by_index(10) + ref_frame_info11 = AV1_VIDEO.get_frame_info(10) + decoded_frame11 = decoder.get_frame_at(10) + assert decoded_frame11.duration_seconds == ref_frame_info11.duration_seconds + assert decoded_frame11.pts_seconds == ref_frame_info11.pts_seconds + assert_frames_equal(decoded_frame11.data, ref_frame11.to(device=device)) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) diff --git a/test/generate_reference_resources.sh b/test/generate_reference_resources.sh index e10f451c..fba098a7 100755 --- a/test/generate_reference_resources.sh +++ b/test/generate_reference_resources.sh @@ -61,3 +61,19 @@ do python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp" rm -f "$bmp" done + +# This video was generated by running the following: +# ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv +# Note that this video only has 1 stream, at index 0. +VIDEO_PATH=$RESOURCES_DIR/av1_video.mkv +FRAMES=(10) +for frame in "${FRAMES[@]}"; do + frame_name=$(printf "%06d" "$frame") + ffmpeg -y -i "$VIDEO_PATH" -vf select="eq(n\,$frame)" -vsync vfr -q:v 2 "$VIDEO_PATH.stream0.frame$frame_name.bmp" +done + +for bmp in "$RESOURCES_DIR"/*.bmp +do + python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp" + rm -f "$bmp" +done diff --git a/test/resources/av1_video.mkv b/test/resources/av1_video.mkv new file mode 100644 index 0000000000000000000000000000000000000000..429a1e9a7d75eb84f8a37ee735689b5cce20dcbc GIT binary patch literal 16375 zcmeIZbySr>_cnaf-672(rQ^^7a_BDU1}UWk>5_Bk1`!EC>W~6TiG(zg(x`xdG$!e4gL?z2D#OTHik(OIYj3o|!#+X0AQgwFf%oA4sG^Gy+ZG6#PC!U@H7U;3|a2 zy4(3Yed6tErw{}FR}=!B#sKb&2B5y=ddReMd^$TK;meh(+N}u&m625H6;1|Kh{AuL zs14t`f}i%ZTNA%OV9*NwGF+}~=U;Dt!ELAicYmP#8SIFZq+MI! z`Boc2q5#u|Cu-UGIf)AkNC^lF37g3OA0NGoQzIJu9C>w`_F51uJ5qT;5fo~H5AT`tw0094gZ$)S=B_+TD z2mzR|XvtC+k(jNb`|N*yquWD%+6bJMP-(k!1>XNDnK^Q9UNP@c{FdxRUyb4~x!G~G za4~z8glS`+iyAMvW?og#nu?b^pH?SG{d->qF=F8EVwrm-tS)pO?p=w+hXpkNmKiI5 zCRWhR67WCdC9Kc>p60|Cjndchv*P%is+jNMU4QkGEMfTL5Asg*@D^|P@sT5dj?(33 zGhDF%z;|0($)V9HkX0zlD2uT!lANiDx{1)#3Li*8%)hVm_V^Gn<=Nk9tMFYVf0DT< zkTkJxIPfQ?B5(?skPhYS)p_BEHp$eaD56F9se1N~;4go~-IU&b_P4A!cTZ?}?lp%~ zr@WtCR=;;B;VZFNvx!$biBn!UM3Rv*ZjXp~5b(q!Cp+|6qXNO_CphOj&g_ItWf4LI zuc>#x-{pi5uNPL*dVGAfrbQd~rrLLBPY9l}%veZg5a!He(8%X;`Ql0Qn+@4-yBN;{ zHa{0`tZ;G0di_9d+wZi#<}Xg{5%C{y6IGJJB(x8R!++;C`b#;1sN^X_B6UKx-`Dr z1Ol2E95QHDd_qwpu6{wE9(@0j=RUPNnif%4v^g5v1^2W~X2rFvSHzqXl@lP6L|krl z++~nX^Lo_Pd^nDHSIVB<8gJV+;&J%So-=I){g>IjBK^qg_lsg_1e(S%zZUOlHefg? zmj!IqslMgY)kj~B%ZYw#~PiY*@=yrO5+)avz<3n zlV_hJmlj%Hc`2-_&Gw7TmH4x2B-=tSJaqF+_8oOafyg}b#qDf`y3M*f80@*$=CU8`>Hw%;Kt!CAS5lEYE$ z&nWN-pQDFo&}dtXWiE#Kw!RdGmMXzF2DXyTai46jM z?l1MDg+JV^Cr>jWYNjH$x*NZLAX<{;P@r79vK!wk9Tb1$GQ{~^G0PL*%%k_hiujuz z!$He|kmWrsE<%68emjrZ5BV(Q%ZhA`lgko$p5HllvnYR4wnC8;>#V~2tBf`fTSoH)5DJAvV|51Pt0dgRRVlc_t-H)nc@W7p>OqAFlXxf^j zjBImI9PuiBbc*`f%+vBrr=y33 zWI|?hih#-5rRW9YQhR=Vy3saI!R7+Q3vX%d4Oe)kejY83VZf)^n5>|>0yb>>v1@lC zlO_54L&F90cNdsyms+y%yP~8rMTk#dt1BU{zEX6S=x!M^YKwEU)3Kjo^QSz_Js})Q z_ce7Ndj}aCSgiUp^`5)E(x^|zza=X8}X-|L_hBlzvZw{m}pi6jcwQmuoC9fvmd zg>;XMo;f85;|*Otv=}F6`&)^*d|~AlLM$w3)|(OhS#^ZG?>M~hQQk()Q(G(ZE|qu6 zNrm($7jM^&jx#1(8$Up{HUIon**i#5y!g;HKo}KhXgnJJftsX==X*iU!S>MQuIRm) zqSW4Ph46dd3YT?KtZW)6{I{hgpEB5Xy-f3z|2?qi>Aa`46FC9Dhy=5n}S}*&04n#G0n!6Qw6A0Fkzge zDn6qTIE2hB&hhJ!^)&9NI78&er2>)}^ zmtFkGR(3~&Fu5mUeSBx*-76!xqJcjvjzv%?Jw+L}#o}r{>01%;M`HAq%Rtg3J#uGl z1vZ*%&|B#mUR=?0zJVv2G_b4vK_8D_Mm{!4ebA0Hy7bPc!7=<{4nka!@SWpGqu6Ae z^vpZn*eWzT!W;K9f+7~~O2tv7RnMympJyBmyo|p2@rMn5B^B=4Wd4LvvW!_|%%(5l zo#AhoLwPyWH;e3vlSu2mboa8Fv>AOu=K_^08a@h6Q+ciCFvm=sPnxN{7f}P3D{Kb5 z9df*{{S;m(=ua@vcsOL<4XOR1Y;4zT&gb*qec~`fP}OcOK4PTDs|W9%sSl~x^HVc| z`@K)_(XQBS?(A~uVf3aOg;HF;qv^7poRx}_EJz8fe4=KLoDm+_f7u#JOFKT+&qo`5 z_<6=?#PsDeQVpht_B-~$l3y4ke~3>A>~BSG(-b%Nve5(%YF!aen_sv zme-siO43nKc21@REq-$L2fmxM(_m{NKa9CJ2jL=;1rz>|%sZW2-)a?EJf<*q5XdJt zf5MsLtF!LR*}Wb9Yw6vlVjh476DyB1Z)V5Q_gsqIE;W5@xyQ#ZOQ{#*@r}u^l83g?CU(>OL z;(YveIc7Z7BFuu+MV1&J{lYF{k62JFKce_@HY=p!)n91xLCcDK%9nG{yh(eHNsQ0D z>qs8yr2l1OXXi(5rJK-;7xZqI60ZXbX7WGxtk`}Gwoizs#b{xjhzd*JJ7}(cDDp0| zy_;u^NOQJgSwMIAeN2bjRckJZw|XsH8bU^`k3(Sko_CZ_L@2)KY_hYtPHW6l4(3E+ zh@pBzZ&2u1JxRxy#O$1#)6+z}3A=nu&-Ykc&gkm}?$TAy8OM!lN!$M5>FO*?$-~7$ zQ(8PZznBMBUxZaC-?H*JDG$1}kffc~`l?n-0Fv zbsY5aXLNKHY891>;sfXfG^08v=^IK}vO25_8ArH&@wIVm4omYt3{pp(H7n)SBgY+I zDtNf-&g3q67Nh^@un|#ZlXc`iQV{?s#W25o(l35}@OdL%6G^TX@AP8l$$--S$L-;h zZ-sT-qfG?^$(jluPpHIJUSn{&#YK(Jp7IgV zehc~azShjJWYZM3_W_{9GBDY==<~Zs%DeWMUcM$t#)6%yWBuktk7@nwUT0`QaUx~3 z`i8>g&Ef=cnC8Wv+4?!P=X1mHtqdqjXOrWR&TmH**ot@^&fe0Km`mH>EjcL%bbA(E zCMgo;2emfV_Td2u89DF0_crK@rT?M};(rVHG|=Hw0q4)n|33YeAf|?Ge|<+7W^>}v znCYf}H_7xVoKBOjEmu`goboOr;dCQM(B%C)j@>7&YXdgl#Nux1u0-c%O3y+Cmbgea zANNsrjJw#o5zzP}0}K^iM(hJrQu7n|bf@T-?yoa9ah%#)ESa!)++C=Ix7$2?MYb1t ziODQ?G+OkS^FFhQA2oM1rB&7HkF}{pPVbXGO#l1-ZG@W**n97qqLhx8R1^r>ANpTK zc2Im^FAKw39I9@?rra5BoqWvDt$xNgm#^&r!HdNlqW5t}S_VBumsU@hiE@@k?|NnX zfbwT+UPQNiQFLQ?&shH5! zErb42ujLU7&p{57`UmmA!9Ex`%mfC&URZb{jU`6y(GNRhNYsc!M3m79yFbk^%)GL$IxUkyGS)k z>C773l_ASogTUs$@!Bd_0SDxh?lOItKQR~Lv>cXQy4z|JcZFwkh&fE>{QwR;S z-iN0`W6-%VJz)S~9pBl7ODYU0LR$iIe-3{2<7kT_%6MB}E|u28+X!#JtSTO^G$AT2 z{uBgjfqAu?dM{+1Rh+`F2Px(zrVaZI5@S#p z=IA?TyqZz7XzLATq&3_68C1(nVpG@rV~?fMq2jU0U~~*aPignJ?ltyopL`8=nLlV3 z8}pVX2I5<4LwxbxWDNDG-)%NuRO`-%x#ozLEHuvj9&DIyD(^6qMSi0FjJUwv?4fM1 zg(`BR$pe6>buz3g8dsR58?gjz{3H@3{6k^&M=NkruaxGgAOl-!zLdEX?2Xyi#FLC@ z^o$?d3*Y_U1*sE`>0&uAjzh(uXlnpzVjClz#&&?54j{f=Uf*ZaS#-OHag-XtH}-?x zA@HqlXgsx>pFbrX(eP(@Eg@qNfF3}HHh?TZ0097$4vN8r2Yb}U0QM3~4uHK34e8kV zHr71qL75hSrcKBt1TX+V#keuJFvfYTlOf&Mq;??aE&)3O8be|DWMkQrN&@-!L5WJt zud*v=B)+-)@nS_nsjg}u9Lfg8V#g`iYGdpKZ2N9USmcn7vEn{kUhW_(kd?c+D<6&; zY9Acx3|WLgqXD1{P#jhgQ`0tv|7o*c+J%N)Upz+>0t1cFV4!h164r@8djvqOxv|;N zbB(Q&I=cV{d59r=3E0~q;5AFc+$Js0W~x2T4GN&d&?eZ6Q9f`xSapLf*^sr~DZ#6O z$D;i$PgDNPaI^Q?#Hyj!!QyfgaIY#inSstP6>%`6F8>;7fI}s?v0&l3E~w?gB$U6M z_%UQ^zUjFL-)X0BaOJw@)dbI$_D*_vzsMzPsX@&p^)|-f#G6nPClJEX8wZyIK)W4impK*>D^zBuVbI`q z;!B|M7IGeejtf9Vxp7#qcMPqQ;hN>={G94c2AbLQb zPI*=zL`l_c357t3AW1XnsP%w-x^?Q%;kSGkX2&1p~i=@3bW9eP`hji4)+iw5`hi}po~ybfPi$qjR~Lu(7}!| zfvfhr9l~@acVN-hhDaqB1bR6D&BJ0L6N-vWnkY;K6(4zgezr;KXcu78o)6Qg%WAT+ z6WAM_48b~~a(JGa&GEL=xAhBkGs`?nVi39chN^=O=_C;m?_rwyj-g=R|DE1r1`QJMb!vw0+BjTk{iJ_a7Q> z$(Zb8bA<+2zYm{qX4CY@$$-(~)$6F14fW~vDx+PhKW@b!AF zCB0s?mIg;3uqBW8ok1oQ6d`J@TxVx`@&w}VXS`)BRQf_}k4rebmz|#QGuQN9>1E0; znX*)@S@^!9I|B zES7+eb+>Y8|1ooYL}3WpyG?N%Gu@Ws3l0MA+=x#98s-hQD)HayyQ5$EYhC$CG=&9a zTHZB!&s{AAA8w}!GPh|^r}v&7%3X^9Viq}Y{5#MtVe8$d`%hDzSiBUPGuHVvFz;Eb z>dS^qXXhs@HfXBDcx+d{)@Mi1mY`!r8x08Qy<+bnq%x0d+4=NcW=&C|L6B2I*w+^i9b-1mAw`+hf^7xK%>#Q3A3sca&{Su z=6|uiD`z+Sr@X-d!tPU5&u~=n$I9YzR@}iXiZ^uO_hDtW!T9d_EDv7Jjx| zliO{-c+}|5In8UVm!vCf9dSbNOgbxA*BPrN;W|a(42yCAyW7!V?M@_*Pg>@C%c8*k ztGI5?dl9ezL(fhx64uCsL(C=Y`!yJLs?vGmM3=n#0+o8k&MP_d_P;FMSL~&N3<}pN z0crg(`;)Jwcud=@L5M-TKkrLLOqbx& zvBR~#@APUdr!Vwz$8)`H=IN&D?F< ztvPTX#@0=UMwcZ;M2NmSTgDJ-@Eha8?!724`q3rS@;mSou zbp2`{ZkBxAp?g2}kRrTFfipe5NCysr4H-HCnpqnoj0R2l7R62(rMd8*rd>acVgHMQ zu<_dzk&i$>1E91}3U(~+t~NS24S;oPz0VV#-zvJp4_abCZZY*M0z(ajsRS(8<@<}t zKo^`;{??K6fV&%vFtWfb3cbUQTf5vw3#UPgAjhX62L2oo!_mjXw*%65 zjgUX_3-xP=0>7R{hP}dW*yA82wkT>Bc!MlNU~~dddMFJp0oigJ{cU4Zw*JE)S}iwI zh2sE64YUOgbcRZcn4|#Ilbd*4bvkx25u6I8K3g=nerFDSqx~}soZDf|B73wC7+Q;> z8H?+Cwaj>A9n&aMjmD|O(gN2#By8ETl}@VSczHK-Ev9LIbL`xgI%T+f{mAu~$?(!% zG}g(-T3^ zlz|H`P`-^GG>f@4+kGY%Cq8h6GK>Wc%aNA|%o+fy&rMQ}J|4daUIj2|4JCO>Ov%Ig;MvRo}`f99;3}jqA ze^}NdhuX36;80d56FXMnavL3t6!vdz55s(Qu)=of6{EbExea|S0t*13 zyif=$i9T%`eJ3D~OO6g6o&E5od!N@EI9-2OR@ZX=eY*1u4z29=6B61-mLssZ0B{r* zT)dxnZS(-C9q#Q6|FC5TGYFN@IomfststKZ2msSDmL4h{vjRj?RA8DVMZ*UzF@f`N z+UmZ8%6e;ce|p-SC8Xj20xJN3e#c~?6{3C}um7J+dvYYFQuYg&mK4XYr#n36iu36q zb}3vnU@5yB0)IchEUdvLa7;$Xth!w3?ZFUMDHoG6eU{c~lz`YJVl3~5_&(g_kRk1^wu3R4T;)ufgp-%T<8oKZbTA0eIot(d>oQkdSu z%!K0j!N>NgEM~(;c=m{S4Seh>jjkS7`V>X)5BcxL;+dT*QgkeTel{FTe3|;PcK3Xw z2=k9daNMA0XRlBDqvo0(&(nsPl(K&3T8CD}<0wxdhv&VbWt7261)3TJ>UN9ma+@4K zL(RYT+Y>tjgUx5@u|mUSFSI5|K>LV!^I_G&EsO+hnMye$*gngM>W*%CQOd{IM>H> zKB}eM7ZsI3LKqZZdA+?RF^VWyuvfOh^Wz`FGM6zj4fL)5Dmklq>HC2!7w_E9d+pCK z;|KK`$5;A)y#v_l4GB6UJBaF{hv0!-wdu5?72h|)+dPIcDBiygWgq8Owy1PEl*~=( zU$RbV2r4!pdD8Yct)kJAO-RBO9c$lfc&_}X)jOuJkDjOgpX7L+qArIgKV zv0u#!pl{K*sURw^UopkncAGeT9h>TlCdUpMV8UtiR72r#S8z`~DhCd50g|>&r%b>{g7^ z>b=#G@2QX6$3Bdj+14hz28;g%mcx4~=5U2czIHqy{!;9i-5o`(ek2GFS$SqUr|qZa zgWt4q;Zvk`@(Je|&5cvgvyqeMzuNaH6vv@9+!PAfuZx)CAYE7tERcIBI6P60PxtG8 zN;mjhI!3x^10MWEh!}g_r~HrDusrqG9#V58!tRCXQ-j@uady*P`cuXtjcxY7#?etj zJk}Sh#aaW!(dvr$T7NH`Rw<#9+~jK5uP2$3;9XIa|5C?ASTwz9#!yCdzHCt({vDC$N6j<_Ah2MeDA*aY z5`i@UKslfsTqGjJvUH%K61gM@mxuXZ#+5hC4<6#Nl!ZWo81p(wtzK zJGKXTZDfL_*`Xg-QO#1pW*}b!+cg1DGwwUA=o5KNiIC1nFi@o+NS`#gbJx*B&R@)p z|E|93xO9~nc|jmZKU0j{7+i#lJ#c1gyF8Idp!)iSt~pewO}$axCWhFsIE~Q^=!{l+lyxE=|)e1Z<3Mneug=w4|;a=N1k%< z62Z%gm2f1>*a`U$Kid);`SsT~tOTt}pi2N6WCe`O`buZZO(e2n9 zYOIqmFnfFsi~K3Da?wxy=E_M;XQqgKK{XAQhohI|)1GeQiBWg+PP@Uqc${RA_4$L4 zCkeg`%dhOQc_taR4g5Mn2BXSehm@&KxfkJPEk|(Mz;g-w9;&glzimCEl}LJF_S~=U z?k2qjw~d%hE&u%HukXf8dV#D|EKCi$dBP&<@{^3g&#`%Y$6j&LUJyC9i#Qd#Cr2!b zsS&w31CZ)3mve%AT|D$ZJXV&OSgu`ikNTzWe2R1RVA)eWZ5IS`5=dB^vd2HR<4dv3 zv#AOiYCOPbikr#pq9PAr6vwYPl1DFpa`RfgmJs8((va|>C+2fXG`7UJSz(-J3;QJ3 z>>lh90R)X+J~xTw3%Gsv<;rn@$FI%#7t5KA^o-%z(pXRO&rP)a7S)Z0JeRE}d2rxb zffXmY0Vx=5wO7H4KH>7g7vEJ zo#F0P)7ZSgRU0`|ZCen%o(AT9jW_j!7%t&4u~z0BE7--Mb+MV>81k4}aKnoylKcju z`D1ii_~ZCIYAz>f_hu*5b64M(TT?4qS};&=S6d8~5*3a4I9AR4dis{eZDcjsSUlSE zZp&v1EwPqXL}BI)>2aEiO33%4ossK-Rx3UX{abdVPU@*j*Fl=U;-%VX#ap%!*EHJQ-#A(RNjPr%th|z1de8hiG!s%%2_3Z{ zDx>52T35w9SDV<1*o$SI^R?6sNicpiy#&fo^a?n7PS)|ZU2qx}7Ig6NoTu-b#zNS10w!wL7sT>vZpOG8Y`ACM>sza=y^{Q9_p>g?&%>(H7sx$vu*M zUV&eZeZ@8Rm!RxW5g<7W6nbDZEEu=Yuu%@>XZRkpWRDw6Y5jchbEID z6$o54095P5gm9axP)rF8?Y7Zs9aigO?nWYN$3SN;89;Mc`%uP;qVs zi1NOZbz%WHap7CX_L8M1kP@2_xTVuAeQAqduvlfYk#3Bus#;i5@D! zN-$&E#@I=N2EQF^jW7L(HjhAMqaHOz=Pk|^B5=1*%tuNHi|cb61BhDxu_c0i6b=<= zxTRQlq@e8<5>|=8lLVmJ+{~=lTawoCaIh#u0|T?<#GOs76`n4GXE_1}3H=|~`YTvy z(4T7*i3wyjcd3&lON?ngM-!8{($T0;O%@hq68qD$=<_+ji4#*S44fa6JoXkC_yVxI zTaMz@nNBl+^WZLCJy1WhOb|2Xs~Xo9p_=j_0rugF8JFh`v{7tZ99bdUnkp1)K)Vg%`Wk)1oZem)%CP1(-pdHQ z82~E74PnRH7qw1;baqny)BR5MOfM(w#xvRlFnnAqvEB*z%6|(CP-WuY6#h+ROgsNl zPh0>E);;^O6HH@rkc6erVAKggIS)0Rb&>0;xLlj@Lpm9q@`5p`+iRh#qQNBmVo6R6E!N27t3IJZJ56DXx&`_3`zyUTZE9FLJ zD#fJOMKF&iQt>?kKLCLCV6t!vMGmx(fxHBUfvY5+?jmc~?ou z&*+t6>6RsC5NIf;@lN8)JvYbPz$E{u;LRe*EZ6SL)EjU0Vb85QB?eJr^(-VqKWAYA zCiks+m0nX5wZ7pI4G*f=zj*Hq(I0!hP|TRp+8wjZ+K{;a+TYx0)3@7Lz+jne{D;Hr ziOUDqlakQXxFV`M4~33M&I{h9wMAMr&n|QiRN@uNk+&XH|AJDf1?{yIJ)j3B$}i0) z2kmzIlN+d#L_@F+(*h;~7O%s)3f(PD)dz@KzmdQ0(dZ0~iI?I%occuw<;A8LV)yUM zCQmpuAJxM;TJw6I_#$W7=&`PqSWnmuF4jgPslzlsAv=ZM8yxQ*!E$)6Ik51Fb9+H$ zM(5pvan(v3bnIP)&SYMm-YJA98;#ca`A3@uNOU+lL@W@C^f+V-{ z;QMhOXHMbi4Hd9HA*uyE)2H$c@s@r&w-s*27@#}zNONQ$E<+j*scsbsb@ACI{kqF^ zFqWk(gqC7qLV2@!Cge!H*1X*tUxN%f^0f!Loa$QC7sdlMJe&PFY~f=CP8n6>4@d(Z z)NXi%R7K-*1bPm7#e6~J*`zVEn-+YTTA6Q8$?!;zN*7fc9Tu73!SrLUtFYL5{T{a7 zJ+E`V{LyJQ;|;lhZkoPrsu;b9yjEXg>1R!7HySqw4dy`nVoX6dK;s4AgZ$%gzG448 z(_nTlR=j9y-;AGq-h=Y=aAx&|mm(NO}NJ+JFzN5ExdhhbD>{6Uq)67FYt18)kD771Y zrN3mD!0?@}`e{P;OI-C1E@J1(`_1Xpr-GHd2flGOBa}CJ!C$}DGA*_nR&WSBDEDt{ zi?QDzFEu$DEwhELj&AG*K#jQB6fg(&7vtdF0IodC?RE9zdwDqiAJXrPs*yt4Kdgmp zEG}AbS6!H0_kZJ*Xwt^m+?)CM=)>TP&vNX)twKDnnTwF`9j1Tc?xeKqpD&euo2&uV zG*+0>!1y8v5OCaofP8*&B`KvCQh;$d`es-8xJ|C*23M24<`_FSt_2R39Tai#5Jq6C z_z(2_{6J3N8i22q&C6PVf`kGP63W5i4E_MBMxo4#Ran?Y2MQ4XG@H%QRNfpdzpmgz zUX=0vEmln-2xtJP2sb-BN<;+j`v0)=R4`uMu&Fk=Y)*&D+~8U6c4VgFgBBDhiITQCXyIBl||0Pkp(Uq4$l2`(mwY(!S>U%e^tZ zZW+?P###waD3qHMrkVu?9RRzRzXkA<_$2J^^`zh+$I1a2^7B9l*o06BY%+t~KoFV&(EHq6TqrdgOwIt!KTU4$ z7&$UXPcG@1k9IGxX|-*R&sENA=4s5t@8{kNslqaU?&p6_#tKPP44w7{2IV6MOi4(+}Wgbv8-$z}M2rM6K)kcuEzJp-Wam@L9Vk*+h@kOmE;deQ7J_Q|{hIPZKW^BU_GB+k*gcVPj!HF_wABnA z%4+a2)1%10YY*d3EDk^N*Ztm}O8dDL+%ege2PGRz$Y^$3DNiN1@)wYDFBcb{mF-dJbR)Dm2P!nU27|*t#6QGvhx?4`i z<*&f4-T3oM2iB~7=~kxBe%59WbB0)mdtJla;3V?%$*)pk;4&mR{`dYpXl#-!#j$pA&2=?fWg76N~(ky3X@6t4(dd zd)yJtJ zE;!;OpqNjBr28IqZpD>v^uTCX^V=7?Bs(3YV)UBgb9xQLW9F~${kyOUp3hvF*%^f0 zUKsYF=4cDw+n(gxcJnjgp!EDH(MZnKhCoV z+JwfQrE!pp<99M=SxqhyBsGMSmF7kn@e74(}5TeUx2$S>zJP$K2k%Hlb_s z9uIQzL($-D{<#<}Un;@ms#vPPKlxvo!6nC{D{vGN4yn|tm3l9qh5xLq|M!q_#C*D^q8BqYyL zM;$Y+ZTwb3_KnT(m|UX2>t*T}bKBhpUXm>Ea?7KHevVR1IuZX>tu&R^_YZXc@B|Xc z(wzPBp7s){d2HhzX}cN|LC{s1kZ%&8211bmDc0~bC`IAwE0L2zh=H{)Zt?M9M+%K+ zwF&t)mM&f(G7v#PWY_|$7es0(v^IhyO%=<6)eADU+q>kiJU@tG>i+#|4V10tEo{vq zh_X=f2_F}hSvON63@l#ALkdB#K6CQx6YsuCBDhe)(wtg|5jCR^y@ioSmQ{F)3&`fP|GJh$R6qPYbeQTIGV?biyEZEC6`dIAO~P8z*#(D5U(; zX=Gxr&p_fGUAo!UHB$;*E3<3$R(M;g*-9g%LS6sRSGOjM1vA-04c0%iapb8OXpbxL zo^`YU%+mr&m{wry|0g+;8L{v|h!*x=ZuY&$UwdxkLGw=N>^!qhpxQ7&G!?8fSqR}K zOv}=AlK!WT_2*OA$1f?WXQ%)Y+ujBni6D+esZ~~xq}5_sRAuEK`#Yw=J{0(~HZ{E_ z;)t@YxefO?f_Mf1xup;bmKi9?fw_^?>o$3v>6s0%A>94gOv8dX-H$sJ@ZStDsG--2 zRcUO3j@Ss{*GJ0!&pU;Eash}-1*^uoIM_7_a2vcr1PK>v)$Kr)kPEdWK>ux&Pbl@$ ztscd_e_Q|xi6kI(;DkN|h@?@H6c-xG1QWPsjIYMNGN;b_+hkd|6VpmDT~jHFR18Lt z1OU(qOcp62&RkHJDhHnsu>LA2&;u5Xr;6e$-k(s6$vR{_s1HK0AIv!W>?h)Wnsd(U zk>nGy*Raull^j)#b_Y2a^#fP>z)w5JKx<%vd5iY>6tiM_DKnjV4!ZUs z;%}fZgq*Ea(ob0S9PLzxewky|$7j&?nqF2jt1C7ZXbNv#jd#f*x4sE}B_TSZus%w(0gN|UGW?V!bI++w6Ex%aF= z*&iU)hbVwY_t2FE!%A>~jwac%`FiI3QBH4!t;yFP^mvT&lhI$ILw;eowce>x)D!t8 zTzH|+S8zshO&Q=gCvk@HTXnryQGAk9xAC#}-q=B}_KVTPUiRb+(a@Ljt|eN-cILZJ zxLxb67;<_Zv61k~J6t~__!FFjWiI<#R(kZ)gCDsRO}*%6Fkzfhot4f!7;aA_~~KoSd}y*XvAi37i&}T~9`^ds5gf&>#_{k5GvBmNSgqtCbtAD`uKh8L|8WfDZKUOD8u!I@Ylk{;B5hO{<<3O zt!v2Zwx%{kyXqOz%O$4Fen$dxElI67;ab> z0I|F6?AUJSi3_F>EFcydDT_HA0~>jL3=Wx+aP~M zkeLEdVQxu?^1LMINhcQW^IK02rg@IbEVf5B7Wk7LK7Dh;p$!mBH!u{uJ7QpG`V9sK z?g=YFlNQJ-NcH2$?+JlJWRlkE0BqoGzrcdO*M4;7blStEnOUNDYA>SNV0r zMB*~M&S@O&)dOY_dXMojn{sj1`B2gwhH%S-PJMqYo7)Wum} z7dxrfigOoN!yp^y9?T;abFpsFZKwS>WK1gZDxY-X#n?i}J6mnjic+55vIdKAU0t*nK0X>7=_8 zDG6D3a3~Sm-_;4^z|8@>mP)ein zyN)es^tf-^_JoG$@*@U)1A;pCMFZQ=N!rWy-tx{J2nFzj#9Gzwz(c*LO&5>V$L;I- z0kaDu?(0ZH06%mwd{kIyfjRV#mrUQV-_twAU~Nd*iy5obZ^cnEsYh@V7JRp#(7Yy>tU}kj{m!5QeY_G zmL5mDnzMRnVN?PshyO;`%>oV{{d=G`zI02D|MR}*G}M8td|qCHkk@K92d=BYBk_2B zTvIgE5efggnT763l!l;OfQj0In++)CE0;Dk@O z7Yc~;){ULAf#<&&uwF<)aq-v%y>Axy@guLGP{lma{htZoS@Hj;IKWWPcTgA1KlYGX zhOYL#+NPdP)G(dkFsOPh`4*~P8^}%~j0c_-OtgP*Uw233XTl;&2wXlIqm%Z3QgYgs zc>|Jr3Q`og4+0XFX+bdnkpK6j=}!?iljzej&Os`!!5lXODLznYrI18D0SN}M0?EL0 zMG6ewo~uKP6R$$;%2wH5PXjhnV8|H+1^8^{0W>5QE0`3tgUI8NV?M$J7Wk_ujJsW^ zlUPg@>+G9rB{J7m2DUlg7&H2*vuiAkER8VQ<|m(O%d95&pDm`x#yUWnG~YYTIgXi> mA^C;29-{u#KC>gz>Z{XsYf3!=_z(~NiT^=3#`Hb{?|%V!r_hT4 literal 0 HcmV?d00001 diff --git a/test/resources/av1_video.mkv.stream0.frame000010.pt b/test/resources/av1_video.mkv.stream0.frame000010.pt new file mode 100644 index 0000000000000000000000000000000000000000..6263e1ca2d9bdf8c48ced396c454876669b4ac4b GIT binary patch literal 692574 zcmeF42b?5D^~Yb&?96PsNH7Z`@&^O*#{m8XNs@y{PAU>k$vKFEWD&_Ahy=+Q_=68!|A3QDJnYC5S`RpW_CY5dcH|)koH+aF z6OK9lfRi_DJr34gcb^Q}cLJt7@Mp7+vz%L6w>&kDqb(Edh=^?8Fm~oXv|kZ_|!Xr7?T{pMMeS!rB!S!9SI+Z?MglYp(P& z!NMBq9|6!naKzz1eB*V<)0K(=X((y&@k+E%LZG1p3>r$6H-&@WcmuW44x!EyutK4s z)KKun%2wW<>p3OD*YJ13=L%gu;YwRyL%|m-+feYu%2wJAuxmT~aKm5a*lqY*c?++H zl$CqF&V%lRkkUBP@KrGp4?F@f| zdnDDp;3{#Z-0gX4Sy$Ot~=7rrbYIo_+)c)o_ehQD5ZsUBnaE88QfH5RNKK&v#^@K?4yN147F z&)4wR@K<%(dhIv-mF<7j@`9_xnR2%q{>rxJXwp~j`5OKj{(AYPdW_-k+?Z|w|!t7%opUsSj~Ot`t# zKQY6#(D1jK_8iu|weWnM??of{yzV->eZxiwdkLj?jNxzfea@9%T7zAB0Wiu{!5QLF zrU{pB^`D#JZ}siD`W$NE`5OKj{_4kK!{6$A8{{v_@V6f5Z1`J!d#)yj+IhZ)zlOht zzpC@Pn!Y9Y-Z3xO9*xqyU1?eMz{DtRF#HW}rM13VJI~kf*YMZy*YMY`2OIN(4Sx%T zD1lp6mf^2o`-~=jr4NKvCUB39HbS@&nCq^ieuNFKk?K5^wZGQA&~`iQFlX-EDm$4F zZ!g1>_P-kb*1ClU=O1FzoiM9LV6MB4ZZgj`fkLs@){>F=V!{6%L z@Z@s9eXeG|_11*Y3^CHprbY-i0(0GU)Q_;iHBz0Yvi47|Z;`#Les+hBmxj@oZul$P zR>Mk<@HPAm{gJ@UrbY+{(@MXqviwp#eOSFsa6a|3J9P8}jJHI?-{6*6i>uLizK#!1Hn>Kr^HkRUTJ%4DdDg@3(DBkR`qB-5{aRz~;-m3=4S(I7(ls1R zEB&s@@=Nvfwd;kvIoHGP(9shx-VzOez3s4eF~3`9ZB&TEyDQDA&fV3}&C`{$e)ZCi zu)#G_ou{()*RC(}=6vUxYZiX}>y>sXHBAsN4WlpJ@YmZ8YZrs?4Sk2_;IBjIfUENi z{Wzs-d@!x_yDG~s)zjCmFY@M${6z)#C$+vBJptn_(eT&X4r>t;SKF?0T!wC*uAKF& zmwto|u950Im9@VXeNk|p=p)z{{kGPe^~*nA8b)8b;cswj4C^X?w-d+7U!eo8&NKAm zl&A7O)Q zq&iP!?GLwisysKRAG`evYZcD_P3^p0Z_YJ^T^(j__*;3KtR!U_f62ucdfQl(F?E8) z$jtELl&$vB7!6BuVOuoH+n!D{)W|h-QsK%R2pe1@)p;swe^|X!HThijKWhzN9FZ3sJptn_(eSsLmg&B?jK8t0_BH(V zH(>~qrobUYn=kHo$jas))+59Jjl*W;|a@Mb2 z`Vls`Mym5v*8XaGr&`=w>(i0KrCa^yX82o8D=p)1oy>W5Gg@QRPFOWRPU#vSOe_7a z%JNI~^wsoEwYay|dkWP*F~hac@VA;)T6*)0BfoI+al8K~(!EmMob?M&2q}#tb>*yI zz4Rk&aE(;wsjU6g^iH+Aw|1u^g-f^k&&}|+npW!OFT&UGH+Hi7aZ1 zr>~}Ws@=V{I~}S1i5aeihQHOcQis1g?6B?Jxeq|z7NUNxXe{xu12JL)bmgpHz4Rk& zaE(;wsjU6g^iH+Dw-y~*xOA)k+zfxKX{Cn05#x2G#rbhc*Z5#s>33C@U#h3Crgy6K zy@tQl^WE^bnpW!Yx5cKYTC~RM^IR#gG)tf>XZ`A>A7O)Qq&iP!?XRYHs`b5NdmS(e zsC+o8o@i?2VE9{2D>eMBH+=hXO4s;cTIqLHmS3u;ucmjZ^}U9_)wAC4cXF-NAurBD zIONSY!+Q3sMLzY$$6DK?D`)-cr5|B~Yot0)W$m9_?-b2lP9M0(J8z-F>bt5Znp!y+ z{!Xr~I&BZ(YxrAhq)O%Q$0=RogK4GTRat(io_=z@Q#5yt`5-5qVE8*}=o|h{uC3nw z`Oo7azraXX4}LYuM+#mVB-WL)e)ZCiu)#G_ou{()Pp)^0<}Qc7w|`;H!uh`rtNWTX z$P~TN@ON@;75Q5@=N#VZ>KnltAXTagew@-ZKA2YeU6tjR>ggxfJ4JI>iNA6qjwaZ5Xq)Ju6k5jtF2h&QwtFru3J-ux26j%DNcEnPbq?5DU z@K?6Qj+=EK>e{zjP*@6H8YI@0vwro`kFdctQk|!=_RIE8aiw?OMOVKQ!sIBX?p=nz zvMtu|x1KTW$0=RogK4GTRat(io?famHZ?+6 zSI+v?OFzN}*GP4q%GzIf?^F*dtA7hUIf|)!m*H>a?Y6|<1QlD3hqvpvmz!b5Gy-$o zb@bzuuJOUN((kG)zf@0OdGAyYDGh&Rp9BnlD`~gxkzb2ZR9Jkhbv|x3H9}Zd&id6$ zKf(srNOhjd+FwcURBvgj83mT*m+I*ae=BLd?vY=^-+Jb_AE$JU52lrVS7rI7diqLw zr+Q0c_*?l?f#Gi@t+$#MRcjA7IRAQPWi4#em9u{J(vPsgHBz0Yvi4WfJJn~J$<2_d zoLqXOH2kfk^&0-xf|a$%(~ncS#s||%zpJwRQayboy;FUrG5oFM3Bm9;xb@;W3wNxs zMrwF8J(3IdH0$$fEf`nNdFskpzk2CM*x(we&Qn?YgL|iXzIye#>#$2NmhBZQVU|v5 z41a^$Fv7QR_Swi^lx$n7hxGN#%39dw$0=RogK4GTRat(io<6vDs^_c7UzBVgs1sW& zsf*!na2v*Rj&}dgd$O&le$v;1arK<1uAKF&mwto|u950Im9;;(cdGxZd%gD_s`CD@ z5@zX?#_%_|4I_LFf9n~zWw!ZoO4s;cTIqLHmS3u;5AL1n|0xJIh;RMvjK-bt1CA8-D>7w){h^1d+uRa#y${PkCCIQO!Am8uJSbe+z}ugM{I)wAxrV$bUiqN$aGuAKF&mwto|u950I zm9^j7JE<7cp8!ut}5ziKhpZLJ@tbd3+Dm3~)c`K5Y# zZ||gHjKkm7(HtsxU+4zGZmF7vzwY+ze)F856y^3MRSdh?)Cgf+IqO$1{RkUeBh`5- zYrng9Qa5Y^ptj$3>p6eAFSr*ifvG_PhQIDsZTPD?qy0FgYkV-R^t&p{FV)k#dna|n z4u5yte%l8hxZm*C{b*tM>u%MutHp!cQaxN%ou{()>&jWbdg({l;2NpUQ(61ny^|i{ z;9k+qrbY-G{<>SW;jb+3>Mgw=r*w@Erj>qIW%;FgdUx-nN7(S!{aj)A>u%M!Tlk(b zr*#CkrFytpZ+xt^J-Tw%uU`5QHn>Kr^HkP;ckiS}7(3X2KGw~qMhF}Jx?6Q{+o%Vg z%GywC-qlgnCRlOAC>=_JG7(xZjpuhXh;|H4`*S-RDK`g-GI zt?kj3vwro`kFdctQk|!=_Upl1cP;K%V~xVuXZ!V>B``He!0@-k-@>`)`Lzml<7>^k zdduIBQ@X|n(@MXqviwp#y_;Nmgpt1}zaF(guv@C8;jh!G^&9zhT6n$jvDWtJ%2~g9 z=||Y$8mZ1xS^M>1uD2F`BYh<>HAukl*J&Gm?ZNLkO_lgs^RC|V_v4hV@xipx@2V`n zR8Q|EmmV>GqreS<-BL9Tf1UQM3QuM2tT#T^+8$jw>sK%R2pe1@)p;swzaF&p*5ZII zHv8M{w*texpoIzygi2s)kbvQ@;cu;(7-jx`oYFNum{$5-mF1V}>AmFABjzVtZa&6e zJct1iZ1~$4B=)vm{r22&`^ZtUuO7e zUAO*BKf(srNOhjd+OJ!C_5V(`H`PBQV!pN5fwwUL#N1NdND*o;_JGt@OJp z%P-Z_>)C22`j#60x;dq5*x_#@%)B!G??m5n-TE{A2pe1@)p;swzi#c-|2yf-!fFt= zw^U7CIqO$1!{14}mP*{{|NYjpCkv*PephArrFwcjTkS;OQo~VohsoOgk)wDs02 zM6@Z9)%fqU{=HW)t@OJp%P-Z_>)(Ql-&kiBbc0~GRLx45rBj-1xBvcwfBq9N;zgB~ zF-kx)L?ese=0@qa^doF=ja28Uto@DBj*H*h5||n!pc7jwsf*!nh_I`}tcu^`M(MYL zX{Fy)S$?UWzERq7@q61K*ezAF5@u8!Gku^1`_&siveoH^X z2G>Y+p32(aDDAlTy)A*MK>}K_wSvm*`_8+2zxB4^Z%rW>W9G&0aijEG!L-uvsw}@$ zPv0o*xcI%rF1$g=>r}G>Wa*Ve>@w>06LpP=5rwf7UHl$5O24HaVS{U=I!|ToZV;cRZ47_4-VwU^J#LhKE0|XLU6tjR>ggM$9T&g1hQA(S=@U^X6m<1-{puBq z6B=Wo#qV*W^jrE7Hn>Kr^HkRUMrp^z@2vxy#(4#;4_qs#iQ#XIWQ?Zh;`g{w`mJDE z>33C@U#h2Xly+SF-WvW^z%f;m82&~B!5B*^evcca-_nn;!8KByr?U1pN;@uoZw-G{ zg=+JG498)q0!q@OO8VJT%O7VN# zDE*dxgbl8d>O7UTzfsz8@q1fBOM?snm+Gs1!0QdUkFdctQk|!=_BTp9E`D!ItTFs8{VDVa z82-kH#mI^-evcca-wLLcephArrF!~CX~)IyZHYC8zokEg1_1}Y_|tEltgi8qfv-*q zDt?a}rQgzzu)#G_ou{()H%dD$esA5ZX#@nj>#K3t@V8F*SZgzi-{VHt>MQului|blC8>*5Ip?{EOe?M(MZoBW!SuROhLz{f*L&i{D!}Wek7a ze+{L>hQD>f$6A|F{2n(-zZFa?{jSRLOZD`P(vFMYTQ_A4f8Bo#b;Bk2VnEk%UcY+P zns@bO7UTzt-(|7y#9uaFpmzgseWxs#K8bgH#!1v`27ReYO16 z@VDM5Svz|Se}nm@`c+wesh-~OS8oJs3EZ+u8UEIeclDOD;jik14Ypr(p32%^YyP$X zOsK->NSWB^NCYbDx>x$E>U3Ud1vN-u_*-w3teriEzrp-c{i-a#R8MdC8w?=B-{8Mb zJzRBvMt3pa@Mb2^-0OvS!MVe%rDii%JNI~^oGB|By#udDu#n=r@_^-55!PT-)L|4xG)vG=!Sv#u?e}nm@`c+wesh-~OH<&?X-MezX2G>sYtKN^pk5jtF z*N$}cnX}=q>VyrpUv-|!+Fxt_dV5GGCa#{}Ymc*f)XyGodbMKw>ZM1#_p{KIvwroe zPfFI#D#PDkeyM&{mS3u;H~cmHRhPwvzqKP>edcWVt2$wW?N^Gp@RM!4l^OqC~ zs0rQ^#?7xjtSnsFl$0=Ro>ywhTv&!%{m|v=2mF1V}=?#Ake?tIa99iH8 zXZTw?($#0qhQF#4HrRgEc`9pvH2n2?!Etohyi9M8ZTK74XHjLjhQH|4*C!=wXO-b^ zFuzp4D$6g`(;NO8{#MDsxNbN6HT=~c$Ex#G)_&D_DrxZuPor|JiRQxR8MdCYy98P$rw$BhQEfthQD>rQ(5y==c%myhQEft(cm)lQVf3$ ze+_?S0axegW%;Fgdc$AC-_XezO@@ZQhQEftbO7UTUv-|!+8+&nyBM+@ z)M#S7_89&~^Op~|6vJP`UpJ?84a@RN_4Km*Qa!!luiPf0f{=tdXkoRMvjOU&G%Bh#77XhQEfthQG3CQzE@Azf@0e_-ptZE*vAs&G6Uo z*YH;fp2`}jI!|ToH~cmHjewZp7Gd~n_-ptpi#8?F%koS0^oGBNzv03$g4_&$4Sx-P zmEftYk*f1l)_%iZ!`}#q8Ez4VzlOhtzp`jkBE2lXR8MdCYxo;393#lh@YnFy@K*_* z${MLUPi5^l{5AZIfSBPHVfbtKYxpaRHYL)_@=NvfhQEft;leS3+zfvWe+_??;Hj*U zs`FIVe#2kG-w22qZV`sRhQEftvS?Exy)3^}PjC2Z_!}-9BgoD0*YMZyR|%fV8mT%@ zW$mvuec)!(YQ+S+psUUY1{~r#Ji!CWzs0 z@ZTqvtA@XZzlOg`@Kn}F)p;swf35lJ$FErWZ7dYK?g{_-+2d!wD)D~xQaT<#LFe!9 ze3b6bq6DS}2{0fHr3S%nshYCZNqN;cw~bqM_i6 zl`Vm(K?17tRMvjgc`9pvtsAr-v*POi9ol+-`q|@Wzbf&5^-?G*)Z{NyO6?p&z;161d!to??+hQG?O*zmXXz@Q2@!?M#$D7|AM z#$W8k$?{9}^oGBNzlOghoGBLwi)*Unqc?56wTm5y5gVX7Pi5^l{5AYF{8bLSC4*Jr zW>|K538i;T#Q2N7I9YzFp5E}+@YnFSgfry=VR22BeDtQRw|21uF=7K$=c%myhQEft zhQG>Tw`8y?+ziW3FQN2~i5P#e7bnXv)zcgP8vYvomT;zAAS|w_l8@fB_0}$SAVzF} z>O7UT-|*M)*YH<4?3N5xg_~j7=_QojF%jc0_TpsurFwe9U&CL+-xAK03xvfrRr1lB zw%*#s4#bEJP@Si;_8a~h{wl{`G}!snyg!u#Xq5)5!p*Sk^b$(%n27NgdvUV-Qa!!l zui>xZZ

?@%FH|rb<3~)7D$N*nt?a0jl#<)_%iZ!(YSS7y(*k(W-DWEIYk~(mN(% z{Ka0JEWcDwZ}@BYYxr9w1Y^8CEUu}NkKVNP)-HA+Mr?rUJe9TI@YnFy@Ha+)R#~(v z+ziW3FQN2~i5P#e7bnXv)zcgP8vYvoRtdovZx4%Ws^p_LZN0UN9f%PdpgK=w?Kk{2 z{5AZI5ujBTtqM28veQc_y<;NAU+l%n@=NvfhQEfthQC!pFvi=%;+iV?=uKO1?P3RF z#0IF&Q(5~Be+_>Pe`5q_l|`$<&9LnB5=!rwi18PDakBhUJ-y+t;jfDPMVp-uuIy)( zI8*NSu(+m5K6=yETf5kS7_kAW^HkP;!(YQ+!(ZjFTQXP`ZiZ#2mr#1gM2x@Ki<9M- z>gf%C4Sx-POE^<55Ej=|$wzP6dTSRu5F<7~b)L%FZ}@BYYxt`ic1s4U!p*Sk^b$(% zn27NgdvUV-Qa!!lui>xZZwY701;XN*D*5P5TW{@R2V%qqsLoSa`wf2$e+_??!*0o7 zRk#_JonAue9TPGBVlPgXU#h1!{5AYF{4L>3xj_Cjz0M&UaYro;I z;jg0nMYEmH&izw4fL3X+D%=dqPA{SKj)@q5u@@)HFV)i<{u=%o{#FUW7;g`YYpUd< zH*LMOiyep&8=yK*W$icoHT*UFjS-+#7Oe_5!?M#$D7|AM#$W8k$?{9}^oGBNzlOh6 zLNLbL!{VAM`RGksZ|!0SV#Efh&Qn?Y4Sx-P4S!<xZZ

?@%FH|rb<3~)7D$N*nt?a0jl#<*8W=a*ROAPs2BZDKYRS*qqi{J z68-9>O?=s(D@GK?QnV`E49iY0q4bW47=N)BC(AF@(;NN=TIv zf-x2v7S~kCM{nACYZp5ZBQ`*Fp32%E7k>+`e`Q^0O@EaoPEC!8AamLJDwn?OZx|yA zV<}n{ZiZ#2mr#1gM2x@Ki<9M->gf%CgNakty({-?aP2g>TK0*=0>Kyy4U224m|+b``_2^;I#w`*VpAg|QT^3OB>D(@QA5VO7UTzt;R620-;E z93}b_@ySZ}Tj?_5xVvsG!)5i=@>ll<7b6N|DOwe7hGnOhPO7UTzt;S%xBLUTZg)4XO}MPS zy8c@FykbOQEJdrr&9LnB5=!rwi18PDakBhUJ$=3LvDWsu0aWU=H4eDzt8o~2b8$Yz zSRfc&z;161d!to^m- zUA^V+W{}3%Emh53U%kUlm1BGr981xva5F4By@b*`CSv@>UYsnyR8L=Te5|!SB?grU zrFWoIU+n|VfgZ6yFvdc|;+iV?=uKO1?P3RF#0IF&Q(60K&AWQbzl0&Zv0JKLslM6= z41ZN&Xf<|v38i;T#Q2N7I9YzFp1$7rSZjMq7%CA;??9=(+6N4OtKpg|_v%esZ|!0S zV#Efh&Qn?YYt6fQ%fEyny|G)WU8%l}5I_XS_$oM-qE+E$Sax~|rFTrk_=~+bS$?UW zzTWs)YkM3FmGGo}pb@ZJs-xj=SeUEZeDtQRw|21uF=7K$=c%mywdP&D|)K38i;T#Q2N7I9YzFp1$7rSZjL>fAzy|sYZssVPURr^U<5O-rB_u#E1=0 zou{()*P3_rmcQX|iB@_B@-&ul*tiOFfXC>j)@q5u@@)HFV)l68y{fOy3}Kw#V>S2g>@_;X6Or`R2bo9K9bJSBarOO`%Nm zrmeSju>&z;161d!to^m-UA^V+emC0)VSiHTdByNoA(&3WR4<|Qj)@q5u@@)HFV)l6 z8y{tCzAw%YEgdmf0(8;!o)Ngzz=TlA)_w|21uF=7K$=c%mywdP&@ z6|btf{Y=>5>RN?a+hG)tg2ZN0UN9f%PdpgK=w?XNZU>L>r7Y`OUhci!%1QzL}^ zNu}o%!(YvSTUH$}q4bW47=N)BC(AF@)7K9lYi-E^TWt2?U3WABbKP~+ld}G`B9K7M z`tQ4_xO~zm%Pqr`_P^>)TW{@R2V%qqsLoSa`)f_T`pF+oICj5x-*U655yJkY(({VL zUpW0#!(S~~>?M@mF%jc0_TpsurF#1M;bW~WLHO?f-rJ49Tz4Jyq^y6fYyi}cH{bH_ zci#()a%sIIrIwHlsFdEc_0}$SAVzF}>O7UTzt+^Nll-wyhWseX&89{O`;$t~D~7)T z@Y3Ltmr#1gM2x@Ki<9M->gnr*kF_?#@K;aVdTPaRe{X7Fg2UZ$JjN>mQ=KXUSzf}L#ODMf#BF10r#mVwZ z_4M_?$6DKf{C)nedklXy#I2`F4u72%wa5GaML|B#<%|VuHzKS$Pi5`bo3`HCF?@ua zNa2EV!XiqM;g%Jg+ZazNKuOM5TGOAkW?wzz@63EQ{Pn;}!w5Hj;k45Ye`Ntz=jpwK z(mSRE!LmRw#bmFRWnT8HY;M)0H~g&!VuovjGxOQ-R|dBUfH7GBCCTIvPg-}&*?Y^) zpS}0)ah|W+3!cCi7M`--eCcubqC-92ryh4z%)728#<9U|p-*lE%6q#B0h$Mj5|2IVFbPeZ=gMfC11OWbMJN}x%TQSMYx!qiYz;PrlDuikGQnS+G~Rf3j5nWJ z#H?_8EA)R$$uOXxj~s7{bY`pi{jAb93RVu1qsS3vqpbWWO369T`2t(13WW))HdzYl zz~4PL{~qUW9)aTz#BsFMmcK<5BGf{MFJ?F~D~s`yfG^%Mah?svU*aa9H|$5i5_~b> zP2-o3H{KoN&8HSIhQC!(PwI9*@;dQSS(PP5_>SXmo*4dC#^M#^ExEuVP}$v;rt&ML z+V}!MNm9kSM~B2VT}{F#>BYS=pf6T(gk#jdE5#y8 z;%>D|rTB4y0?;uPjhld~&cnonKLhTn!KH~w>PJ}rcwWaMf8B1+q{G-}RPHjBbua@x|vCHn9LkL@tbX0;b28M*vBLorb>jJx}TAHb) z1%GQm(g`SEK`FK1Z&|N^y>MAfg}+wb9aoqkCNun<_^Du}jLzX-PL`*#=_~)hBS~3# zs|rr(<)%BveXl#dlCb;U3c~^XEyGhNzgCajaXCx^zIcldf13bGa&yjw;u4c54!G|s zW#tN3O)#beSVH5OUm1`6Qk90D{Q_?DFg$2+4J5goujY4WZFTFT_qN(yF7>q9V`?Tn zLd%-Tm@;%(Nh!Kas3%2xGJ$HLOHj>hd708|`Gji9FE7cWy(J}BUr9}@FG+jdDkb9q zwhYzN(pjb@tz7}?ZtW^lXIob|YVGNi)6&k?0QGgYqLvFeVg3`Ex&yTYn#!~cn*5dQ zg_Wxxk}^v4L4YR6Y7mkWDKP+vfo1urTx_~WGncV3^|F~I_0yRZYJkoZnm$0MyKgDR zuR1`>wE6%Y(-2wRqfZVI6ptVF*HCU%N5iA5uh7URhl@&QSDlnrPYtb>85~ope-*bf z{i_A2WpFiA@!n5l2FKJg_z6d?L!WFPT)j-KL#vA$S7{j>r?!DrN}4*jYJl1YS49oM z456t*P^R`F7^h)qFViTryEJta#%UO)3az+n6k1ENp_M>eN%*S`%^@lWZCKW42MeF> z&(9d*nlfag*a!_&o(?k^73lD!RbG%iLL8A%T8+jkIBK&|I|bkfiRngFP5FxjxpC>o z)=J{t`AisgR|><~B-daQm-7I-*&f^8_Sk(6eA`I3(2_PXrYiZ{%#gWUmGjqOdI`Q5 z7)MnXi{0dP!sAl<@V5-dWxsZ{bfVk{4hOynaXI)RtOMYSmwfn}UIG&;vlO_NkEulA zbT5?2jbk@_Q~fZZCLrd@aO~#q%5QhUS1=c`TapI+Eiu;}SOH%IaS^^4m=t^w*{h?P zhq#I)Eg=>0I!02TfG>ZMypGD?i@yf**I{b_eDPB4_?v@is1cavk>>w>=}5Kk^ZhY8n64mEsm^T;7BBZsz*9@dsUvUT{Vw&CMi zM~-jHp3*XMYFqZS=8<2uX3uOHKC3l*cFV}QZNulcj$Y6{`rEeLh3zAkw2xjoW#sbK z+?DO4SGAFArsl3~CD*r)-q=EJY0ur>X5Z01a#zdf-L2VsTSo408+o90bWYpI+zff7 zZS>LR++!^xkEe1^W=5V)ljmATUr6U(%#8dqoqM@?_+OdPS6hZ(Z_d5dHvCTO=zHzi z|F#W(&_41}dv4y;(FId-9aBd;r{%h)k1m~(TQ=41pF#(x<+4-l;b|iyQ$|Op4(HlY z19mH_*KR>A%eA7Gjzq2VX()JhVR4y_d^)#J^dMrQf zQ8x9VJlBJTsdMsE?kjNJoo`3oQJ99hsW9cneEaoX zre8>=UO=Y*#-8?TG6QusnT|T0OgYVNKb1~Foy=yOM5mv?xQ=GijwDkKCGDt#NgL`% zq!qO@BoAYuSx9&thCV=bLw=>1{{`wI$7LLDJu&EvWC3so%-9@6tDY=iX`C z_fAJ`)wjx)y)!rKoB6%o85{R=eY1Z$YQz5N8}v^7O5b$Umj-68-9O{g1FL?jf0fny zS3`ZGZ*|lsd$~Tbe6=#I-uo%ks=c2@&FuX=YI^V5sHweQ7W(S6zOPLi_y$-1MqGWH z3T-j1548*-;k>^o_$54eWr33aB2Ix|B{ zh59l>r~$6QAyo6w2&!$!7Amw271{<1E&XH4^cGUfSkuxxD!GI~mkkSSm^YgED2K{? z)Rp>RaqIhwTK==J<=w^2?<{J0vn&1DqUL{bwZ7ET{Ex2I7kXNr?{59uvgW@oX@0UN z{diaA(Iu%zx>FBzH~qOQ{a|@{9@LRMpfmB~#mW6TP>H=eti2Yb_FT}k$NZ*U z7hAh@BzIbv*s-H&+eL}3J5pOMOm4BD>3a(kn=D9vrz5d3{#=~ccv14(i{Tp`)`kn= ze>>po3t{~Z_(}(Sc>%230bl$G)|v-j_!!pcfX^?mKD`h=xd2x0fYlaQGZ(;&`7mWZ zw9m6zJD_EO)!YG@1<6bYDwXL-H!sX&xLSnTGmB8uniryGHZMZ0(%dmKvtae~f=^^R zK9lbFT=SwYW)`m9ylB1jqAzC_Z;=Vr-MxSFk{@Td4s2eE`V}la1D2g7bRH~4T?jp>OQ7dcSaP}0 z)zFRl11v#Z4?U=xVF~Is=)N6#?uMRwV96Xksmh{5%KInDS551_tF?9_> zR~8lseUybUjlf4jZ;!y+9W5WC7Bqj%m3|*Jko-4lDDjfe--i-^A4tA{>Q6o^^mKpn z>Hen2g&yf|dStNa0n|Y9KB2n?lBhelnr`V&-rV1Gt0eUt+hu#H?QH2fc|MgtqBTY}1?ETxgTt)MdN_RVcx{to}C1rhA#@8}b0`>E#Kv(0Y) z+OM>X{JR;&_@WTL74p}M;9&l8_{#H_!`H>%8GZY^ z;ENz$!QYv^M+qIXYX5Po_ML=s_?ya}CB*rg8op5I;?!_Sm#0UrNDp5vbVGXhhRpCS zLbqjxQFn1=?-%;B!{3(SCt8M|S|JYKG5$7ZUlw}3jK7)T54p1Qg}O4?t`dKnv;EE4 zL9U@}OLn*=JBn)05}`5vItWJiqMDY{L=UyP35W0EQRpCu;M9EDG>??{+x%V!!Z-8w z!sa(RQ?D*W_-0;N*z(VAhri8#>uPy!Y3A9U%oE+n-}GO4-282Npfhz}SBbyr+q#oC zFV5W9h5T)~wlj0hA~%227cF-4H+gmk!Z&$l2f{abS_i^6_4CD*@V9CIg^7I^IQ&KU zCU*NcIcrhk2MfdG@8{-E%-^<;5x&qoAK?qBd5QD_R8x8ZM{t^}S*R_&5H&T^ftrz7 zh(h=}{6z#Ke^G0tI}pL^rWPWCzn*UhiQ74RUHt7!Ughw&KY5YR`TfcB`PAINVZf3Za`1$LQbUkT)QK(7Jz2cXvjy8-N* zfZhV+HlViyxeF*RvArMie*!WG=z~BWg8W~AJqG#5fjkBDX`s(R{&^rT0{uJCmw~-v zv42~I*OK`+t-@Og`Yy2dfV>a+4}p9PbRN)!z!m}R0@e*=DX`^02OvKLY#8z*V2=XH z!KlqOEYxoUYB_;rHeeryT12_#Gp>&)*MEiHAb@(6a=k*hUZ7mhG59Nmr-dG)fO>?% zLj>kfc!0ot6z(N(7v;K5=tgS&k-)VS{y^Xw%5??dx`cB5mcj)D&ZBTHfwL&r=?qSz z@N)*IP&iTOXacAsC>&1UPzpaIa4>~~2pmAU_7mEhCifzVSvH_{6xxo!wl-`*xi+V; z8G-LGu5Sr#zyP&=-ufzq^%$&6VQmIqq_8H1H3)o$a;=`XR%0-gS?vT`DOU@HW&&x- zWeJT9oI9`YDUPF~GF+^ci%t`o0Ti_w7x|RXnn2bBdu^e0!Cn{0SGd?WxbolOV&4_o z4A^Esw-VYG*mgjFAhat`)b2ui1Kk_g4}}f{b|BE%LWco69O$uJ>;$1xfSm&LmqMol z`_-7}*??l$h*4JxT?Gtvt`?M*n&~>x$H$kJz)0$StisEq#sz8i{`kTMw}%un1en7y*4Z(u+)Yg z!i6ok!{#mcV*(#w^Cj>egLf&s&EO3iUc+{5!z&bCCh(%o^*0-y%fZux>v0B;5_p8d zLpIE1@PG~X5x6G@cQLqwz%2}JB5)&v8?X@zT}|OCY{7ZB3|nv>E+lXPh4XDVkHI+v z&LnV#4W|+Kr46T2IK_q&37lYC$K`n=J`x+TZT*bFY+_*xKG22(82pI9z6AEMVNV9T z6WFZ)vl#5eU`J|gM`0Tqwr22sY{UiFEN^X+hwn1jn9IVZyrFGqrn}>BM ze6hgW@#h(Q4%;zf&%6iyR8G4{czbAX-; z>^DGA7Xw9I2KE&|uNHDzFg9G&tze_DC%yydJz!%CegMdW7JG>I#N<&Rj{$uWY}7N5 zL%kq+;(r2t3E01Zpx)puIRBo|2T(xG12!M>i-0T!+6}aa_rU}uS}+*~8#M~H4ODc) zqZEdP1}F?LSSHlT0ELf|^9gUksP|C}-oX}3;dP;xMe}_DyI}&)2t7sMDF%-T{e?kE z*oaYoqHsTjyM=D2a694c7{@O=O%wKw%;jb@$=io=qID!$rsOb5-8|4U&hQA5M5j-(}S+g8}$$A35Uz&jd_wNG1 z4u79z@RVK7-#Ns>ehU%o=PAn7`--p-TLH6)5r-RmNXLaEZT& zU@w0Kd{@F>0pD@_74TibUji=)_^zD46dn}tUBO=hx0UhN=Iz+I+^ZNRMc349q{U^34C{Tu<*Cz!Q5dJN2(I^O?nqJSFDUkjKeya(a` zhQr@Yfua2TU&r$;`o97EW$64)(et_T7jv=8h1~wHy}nR2{QVIy)I|KnHaUK9hI#DL=dub00zyp1aP zzZLNJVFswVb;Vzscf+W1|Chin{DE#f6;qQ3=w?6s1caFb*-h;w! zoWIu2`Lg%V))s;9pTB|+k>dUH|KK>vc>aExj(`7b_{$OO_`im~wSNDsB!6+->>q&R zX6zL{@KAWos$~3E@ONzd7Y82b2nhdoY~)v6{}+cG_^3nh_%E~e8tb}XN5ZuOh3$p& zTk?N5rGWDDe{tOGO9Z|^xa;8fzn`MkCy56B@6r72op=0K+W);>jQrlsCoxbQ`9=TN zJMvqF|4Vv+`o4c2^M8j4WUGGv%=ydUIh?_O18Al9&v|%=@^P~{gpZry`{%p!a3>u% z{(D^>u4NOvfBvmAiJ|!Z+3o-OkN@%+45jhk{c!wu9DiY#JnV!mI1k%WXZ&{p{(iFn z|4YZledy_WHj|`Ai(r zC*=6Q&iiKyIQWYi_Kg2_Gd>fiLukeL?|A?BUt$igJWpf{42v>zxh{&zs~$C!(Z?GXBy4?t4Z?1x2$lQ46xK)8(fnt_->Usz!{6fg zuidb$xn7?D|f1UZy&MIJM_2kXKZZ59;jdlFj_`k;gb*=w3n##^yg){{DF!f1N4NVi&AP{a^cQ zK)(*`+d|IHSKjr%I{@2JY-GlhnK?m?1WVn7NbOf*?i|db00)}$! zf90(F9b5lL+vfZ*m=DF6A-6R-bO*Z*aF^DC6I|JCH? zKR@NS|5XL+f3LFPcN6Wuah46J(@}o@FXbyyQ0V+tvjK)O6Fe~{u=&{+keu`zbfrNS-$?a-sWE!{u=%o{_@^S-S^Mo^7r#%3beES zH=^-hU;lRtU|Rs)TFCAHI_rx;4+Y;JDE_f1&#Izh?Z`JN|30x+jQ^6;fp30=ItOy+f?YiS$i5hC_xVTH14G>?&Oh?*f8}@n!e4=) zD%}6-W5~Pre|GPGH2|Zy1s0xvB#uCG&Y;5`%!lysBcWxK4*{>>uh@nkcjA8^CGY?G z?SFMEp5sOM30|l_2u*(cx9I==3{Rq?{3N=A$FKk0lK{$D|GSF~yZDX&YUcmu@#r?! z`Gd9If5Y%M49{qxxVllArfc_sY4i1PCXSGxaXY5%JunAv}0QtN-c^RJZQFCy6R zm-1t9#?8O-@)zfgvO~(xAH?13o%4UmDg69j-2aN5K6d^uIgg+JYva(_g+MM@`TAcx zH-_B>~H)4nLJK`F7E%F$AdRs2iyYMd;Tv!bpuaN z?cxVdy7zzfp8xCK{|XwqpG+9{!6;i}Qb1#@{Kxrc`?VFF$|qYh2F! zXFRu!YZIaGiSvJpj{o8mX!2t&`ye2Ai003@sF?rkod3&Dus#XM zsY2L-Pak*kFAn~q@bpwJoCS@Gf5rGOy`Gc^D{cjcg#TJaa=YAm0|8>v5dX4c$ z{0fDC=J?^ixc@V5{)~I*j@$pac>eDkJjA+q{x9zTj4yG}d-x;ey3TX{U@(7wPR7pv zRhz$f1Ta6i8udwjb~XK+I0EkBXDn@%T)XKlf>QCNn; zGuXv3?45+4f7F2z$FSpI;D=dzfgI2S?I+)yokrJ3*iyK{1jm4O(ZsA zG59-n{vaNrDTL1NIRBTyRWxY+Ge7@#W%Hk9{ok!!`(N?-&*S;a_rFq?zogsnHw6A7 zf)T!kzlOhXPmZ7eYxv90ADn={h+v1mUlIro;O}>VB795y-Buvj@YnEHLH>fD{~JAj zA7}jB@<)W`<^e~3EAfAS#{hL%p6f!!-`kxhR44xLg#2BN!KxJS2w+r;P|f{cwlSbq z;{VPP3xSRQOQ47F6RbOIsD=M~J2U>T*Z)<5zt;dojUWFl_kZW`H<9EKU?}u|QI7xn zH^A4Bta|@f? zFM*{tcYclk`!xbzsonWUhQE^^|8@M|ei%XxzyN9pmZ2OMSl0ig@UqZ9aVmp4{;#wD zm3#bGCI8p(x7PEo41f9LD>MJg_EyVI{&JE{9nUg z@A-qOyni8 z8vph4e+_@d;-h+6e`NR@!}$yK^8VTI*YG#`^}ojdHU6&{{KX;Qh9Cbm{5AY_y?;In zD9UgDt6IE&ehc!b_kf~40EU_eY(C@{0a*;RQ^;BWyA*OL@Az+V`d>BXKO6oU{u=&{oxfoC z+mQS{N^HS_{_japz`@^NKpr=r#PMIl-{ShCO5Q&k{!V89E5qNo`0MO{rSSY`as<#L zL8tw%41W!O$M~Daod|2k-@gJwIr~4m_rI#d|Lq35hi?WvvH$Cy|LooW%I*IWcn8Jc zEtDAi9XI|<_zrG31YEoE-(QL?uuj392q_#dbSwoF9{!7}Wc_cj|LdH8RPXD54Sx-P zC+Yth{=SV{_t{X*{*z|@a~;fop2+`27dUVDYxrC1{h!D9`zFx0&H7)n{&$Ds{?8Sx z|J9GbS)4>@=07von9GXj{NMQJU%C1F3t+$G@1M*0`&%HWOTk860rv02=2tibT>1W2 zj{`$J4HV_@T!00aWGVze@SPhQBpCf5C75 zvw8nqck`bOe;fY&v*9mpzJZGS{j(1I+U~5-iQ2$Kt2XK51jLl>Tv&Osr5%^GeG&Be{>{1a*-OIe-s9P z0T=*H3T1#b^TUrObN*4N>wk^^JHh*B3NMhc^N;Y{m?v@FB!fqZ*f-bgzkw~-ZUT~q z9I6d$6mI_P+5Z`jK%!p&hWZk)bs%3nf5BcK>rS3lq)1nwY;7o)h!0#`r6aAaqV%@b14+`T2tl!QcF!z@7`_&p`fy zM^0Mw3BXe~Y#eoX7VPJM$oRi>A+SY2yMT4`DR*=k&;ekBU}u4j0LgLL)Iz!U-@wCv zQC*B{F%zf9P*_0VBg*BhI&;qdeT~4YHoStTrxJJpyDmI8RS5kC)T7R+lMLn(e*O{a zPQrCNp0^=${z9<-i}SBghYS6b0_tF)11O;O6WW`>UKDm?Tsu*&?FejZ^HZNun^V|~ z@biCB-t&L!g}?0UW9R?ozXK??snBNQ&i`dQ0{wx|t}8$PsFD5O_pu$@ppgH&uMK24=dDw!1^Zxle`N@p`&SWqh&+#fiJ6pluChFoZ zS(*Qv`xMxC_>nmOXeItHy1>Qrf5-ga?~B8K@%&$Oekb&Qe*^?|pm2fR{;zZX?}_~U zUw-(nv;QP6L2~?Gbbe762p9NLpeP*wMd9G@q{n}s6GOn08~-g^|BFk2-TP0DU;n#A zY=(pg##4~^`M+-dio<{T`M(rich3JU`oH}Bv%_BkPoWULLJyZtKB90hiohL0<@_aZ zl|AwKzm3A*+=Cp!xrc#13gmG>1m~XO2(Cl^B7*(+i=R3o_#qd6@1lT0{-SOcx`Dy< z6#l@t{LViz`~~tgF#M%UfCTP8S^4;{;qT`ee2&6r7<`(7vxv^{mmh}1ZU=(87s!1; z-RBRw=U+w4U&g&)!(U*B067%sVPNCZNipicjs=2GT_*uK1@gb(tG?OkK+fdPUgSJr z=ko_IdLfWYfL&(UR{{|e@7N#2#Jd~7@BE|DIh?ivfGbs-KRxzI^i<5%0OgGT@~1A!`HS-BuXS18x{$$d zDV!&S@1H&VO`Jlk==r-Q1=QyVeAWh;^+#dMe-7RMrAGjpaQ~}c0{JDdGr8#5LgxcL zAIQZ*mjk;J*zdV09==dd=P$V5KQAS)#Q*$X1^CP0K%BwA;70`Z!=;lIj{o8$1_I9b zZw&mc?fCBjz)+R>zxWz*Z2t{>4T-t{*l!_U9RI~3U_SU;8vpg@FPp>Hpl}u6Kj)u@ z{Ifty`(M2R4D~ADP7HVkyR-i@ZULSD2pDSo{?B$N*hT+$WGRfG^IL)c+s#&t|9(sW z<>&vtEQWyN@_&z}fI31*z4_0Ezs2!i!{0T5tO@qoLhAxqZ`}BASpKi!Z(YBCHvILj zKQjC^{Pm3gp1>#G`HcTsm5=`h^EdaDYyR^KK>zNn|Fvvq{qGxKzvW*4Yl-#07VQAG zaQysd>G5A{QJMdXQ%5PhL*Xslx{q+ZM7ePCGs>C&{4@pBV`Bc5pZ|LsP25VX8<^#q ze`WoSS(oOmiz)m@=se;7qVtP7lUipG_$7m1@b}O0`@in<2ipmtT7=RRPzhYR!Poo> zVW9f;zY+Vt6rEpG{rkTJC>c9{A+J>ug*5wQ?qi_+0 z3kh6+O_^Bd;v#k;+A)FCupQfc1m$GgI)T7(IXDJg2LeYhIGkEPE5K}K9pvKgz6ACu z`oB1Xe-^QJ&g0QBiS3!S9ZhUaVJq92|7>ki1OJz=JaX0_;rwUX1T4*$0FySLsPXH6 zYwiE?^}mz#fB6#NqW@d@{O5`NUvz#;{_jgbP;USC9iXU>pn!7x-^D;t&iO}0|95Qv z=Rp6r8`mG9^DF89I-|he@n2{D^UBvBP4@lsWc**7FGDFi|7a5en=q*O{3F9(!(TuC zItVV~FU2i5rc}EB26ukf6O*5d>yM)Ge+_>%=I>e**2q^n|Jm@@^ZvOS{`LStmCj!< z{FULaHNpJn75q))e+dPg0(~ZqJy`TS$e-`{zm|9YGrJ!0Hvqj6@;3vy4cHx!y9>xY zfHR+Q^3}+n`Q+!3hxz7LxyQj4>yPNOmRNse*(>~C$o-q|e`UYH#~z&huX6w8&tA^{ zR}-xNb@{)!oP`>;0X0CldMVdZp-u`YXZ>$!{P(}b^}nxCK>bUoPG^x`e{T=sF0UPXKie zwga@3TO&gD;J$1v>MRW04I=geMV?azKsz15*NeuM_+^dH=*z? zC~VA?-$dwpkVkEavhv#s?PBG3;kyU1Sr*#^1#BNK`a>?9{LF=`Cw~g;aIXAOTXb;R2s&v0tMge=!%s@n6(cT=WkXyVhbia?x8+33d+`eNgD(BzweS zk8;r`g`Tx=5j(}^D_StV#Gx-{;~3XwDegJiw-$Ukv#kE{2U5 zHG;ATI==*Rd6YGr=l<^?;aXo%XkfN$?UpzvP;UjO%{Jp5fOV(0#E z9-hU`uN?ok0DoZ?w&Oojn9E=evF^{qeb|l}+({F+72p<@_#-Z&BSrt0!d0Z`|K{OB z1{dVvH_Y*Wtuym*27}WW{F1`S3{Ilf2?aPVZylY7qtN*+z+rjE{{_eY{c!<)$Y9?* z?1L-V3$RC?uV63tfAg>vwYDh0<^^k$Jbag}82|kSvA$k__4DwRJgi4y-8^r~Yv*A+ zf2~gxU{wl!^PeZ+FS@|2=>O)YLw+@2D4hI^5@RQ1ZJ?-if#S2{SFr6`g$;nBoE_Y7 z^D7i?eudf+2x@B}+W^HAtak(^nr{Icu4udYeIUOd(EY{WFOL7>^nGUr|Dk;P0&ada zc7oT*kVpL-@~4VESd9PT5ODrHD4=lg7j+3A0xnzzg)5+N4HW(Wh3kcGgaYbj$ln5m zJNYqmI0u%ZaPU`*{}vvB!lOb@KmqlP(DP70y(IK6D4<@4{F_jC2lDT7=hyyNY(Ut- z5C64M==q|S^J9NY8e9&8eb9^Qhoz_?STYDb!_bYgp%W#X-_GLtBW%GaXZ&|;CJu!+ z8Sji=E%wJ_{x81D5!*04N#HS|hY6tOa#{BY-DSg_wsk9t!XLS;Yn}O5=mImqmpGR* z{xb3+8_pA2G5^ZuZ*Z^?yU#!3{x9~#yfcQ`74H8$-v8YkH)BBOS7;*}zD=R*{G*2E zFT&T$-&e%@XWVbpd39IDUqrC8{s`fVLim>P_a?p(M+p2y_@WTLljHB^6#g%taQKS| zcKG|X0!Q$A1&-h^VyDHp)8OzI5xiOySd05p@%HoL6tnhW#dC6 zZld~((DOi0FLEJ&xm^6E@A7j*2=1f$AD|x!;Z~~HlXY+e<5n^#B31#r`8&#BgpTpI zKL>rIur!DKP4wVCGB#g2iTuTf$j0ICM${Uce^r35G0xw035V}G6uyX$lEVM}Ea8iP zKSlWZ-&F|ApfJtmn{TvHzWG&}0B*j4`(Lf#@7S{=!WY#lUO`Xgub}g*^0qqxe-XYF z@)zN|E#USsT(h8n+7t39gf9x=i~0%VQ3&5dfFXQQ2w#W4#|!v|%ijkCeCGm(FX}NN zgzu9AzRwEyhRELmScXFWqDuUwV=tc#f7j2>hT)%D!-rZUhb4xONRA$r${yP^a(sH| zO%Wn>{>bbaX0eD96=jbM*@K*i%rQb~|b@XPiIr8 z$463+4L3bBl1AM(+I;Uw<{H+H`dz;LO4fFHehTW6!jy~i?dLPDbMx)zu(n^(cGPK% z>ldv3RN8hjn}RxywI55{j%HI(hf%Jdl6KTj$Q0BM>2%b-bQ)?;Iu*4Gor2n)w4t`9 zTwAfWEotlKq-8U^c~jE7iJh4>Fmso_={xq%+@WvAw!PC)TlP=iqHo%!{WCY|o3U~K zjE(xHf3tu3hP_k2+Ryc+f$3}aPy6h^^w0E9TYZ3Q)xjC4RR(5q^{Mdfz8p z`t0<+&ra`MQ)un!y{L7#mai-H<>|}6Jfrt(LL1KL-Ed~_MqJA`Ma}5jQfS+m{oBp# z+ev8F%)VJOd-o99cV_RttMvX@=%7`4QHP*b>pL8^YX31pC#>3k!m53zqLRa>rAB|1 z8aZ3&g4D)iuH?#IRZ=p0EvhMdv(T+gS=8-ZLwBdL_oT89qLSH%g&s?0 zAIl6s#g%jeM#t*^zgqjBX4jG|3~P<%Xn8u@ zFEp4L9?E2gQO(0SRO>Ji$~Od6Jo zk1tF=x}fEej?6>zTOM4Hd0<}i{T-S67Bt`0k-2k0%WaD?w{*1JxHxme!shF`GuL)y zuI@}) zCoM>xFt6#@dC6lwP95=a)8QYb4qK4?*}SGh<|PmQxaq)oi32`r+W({Eejg_HnP=_w zQDV=JVfPPHvpz`f@_uT^51V%QplRDhiLE$-TBu4b-{N! zt&Mu%o1L)XV%T6ItiK4p+5ul)2G}n6RPSjU~)=w`+ZJ6#vZItdpeK);$)8?+tGhJIWcW<5H+CI~T+Bv;= zmrUoZbmy*_u07I=_sn$eo95a--T9;D?w@41PKEAYK+l;%=fKi)VcBnB3F;E)L0t}A zs4JoCchG$`bYBBYu7f4l!_u3CZiOYNJ7CG3u=IXd`X^ZSAS{~;%O8Q|k3!!QupIR) z^rD`FWq*U^FT(P_L+?Lf8R`{Sj(P)@q27XJ?^wO>K|g99^v#FnJ4sN zPx^z=%TQ4_c=WCQd~SStp`~ zt>d{8N29XVVM4RB*6eKJC#V5yf3C#dLVFG*_8dy=I%v%rOzbGM{a^yM^?>#LfyCxQ zn+zm28A$$Lf8sm+iEj#R*q_+2jK3pCHZ>rB9r%7cl=^R$3*r0TP!q!Ut)Udc_g~qj ze++RUd{KWLN}`?`YC<7=QGXszp%A_${vv!Q^iCgf1V_wY zM6ghW{QZ2UTZ+GkVB{}qLjEFxk-w;YQi~n_B7#@S-%DYM!(W6i3gL@F_&WT(8J0Qx zMfiI8iwJi3`;0(v8Gm1btw-wc4GG(hEQ0!S8>04T|_z$FW!`iCV| zsH_DOYA69E4I}|o`V|9COkWa8a>p!h0@R*6ZTiUle{9MxuTV+0EI<}pRwgI0JM(e3 z)81sc`jfPOS-QKVrJ3%MdYZcfw4|k{r0&+PnyRz4$45&$TgTMhvecvH-7OOc(%sP1 z3F##jXc?pf)C=jNQvI+}l@gZ%<8B#*RDgo>cW%r=DhtVBuEhVhcRs&yBvl+Q$3GIh zOO)7N6c+6Qt7X^Whu4WWiG##t4-42A6M>MpWWBLrolOX`tw_Kr0)!CkCPFJCE^vSY zCqTaN4{+lU9D&3x2M%z9CHO}i+cWLqwX3J6W~yJiySiU>)%Z=GC*{{w->R8;^FICD zJw16Ahdm5VDK_*lNVJiML1gRo5zU6y>#?=AZq!E(wXNESVjtS7#dc<^mSEep zGqLULok_FZr+0Rr?te7WVZGZUV6n4rYq)7o^=(Dgzx5Pi!~I*49p4;Iv=3|!gFUg4 zX0nfdpFT=%)kmqGb#W4GQ?o5`5^VdxcEm}y#Y3~5!$KEAkl_40?qb@HV(1+c2L-A^=$()n^&3W@oqi8F@GhFb=vBl7{1WZ z`FrJypZ&vce+f)f5oz2$w*TJ_J47n@Utgzf^aUF+e4!D;7aA$>g-tPhp+O2?5J%t( zP%gmVKA`ZS_HWUbV(*K-ru9Bd<1cIli#x{+>5uN$MQ@57D75ugJ#yft=mlf`woogX z%lt*)3rKCD6X08%zXyV+cf@HT{+?dnEev14sKUhdcJR24+kWuaz!brqy&%7dV?8ZX zYuBE%TUh1BH)T{?^{e)YJ6GM1A6-@N;d6ko?FzBewCKZu%2wG8{P8AvI+FBn39%`} zmG}PsrFY>R*Lt$;?~BmtEmM1{e%V&HW@_(?pv?jiXx0csg0&6nP)YhV5Vcg(=S}w1 zM37?I;*&+ z?78{|#@h1DY4aXiiSE}|rT01T=pB>eKuc3Z-NcP+H~w}%N^Vg$lQ~VfZ%;u<$pcL` zLEc-v5yXa#jXU6tTD@UnK~O7ZE%*Y66_TdVw-cets1AkBo(}x3hj|>RK-Cz&(2$(P z?&T=??vQQZFPK2{MGgcQq%1rw2iSFH<-UVvvPZwzdH7^c=Y6g zd2}_Tr`|EnSCus0_{*Ob{t8#7JO~lyZ!!1+yJODAY>Yi|Bi_^n4eFr%%|}1pxbwY% zokw6v{8cJnz1k)4q?V{r0lP11iI!!oYkfpmAW^@_4(adJ;j6Bbi_^r4mCh+CU$Qf-xZ4x^VC82sSf%MuT@r3@S6Fnu%7wrEILbBn7<@7Te;54ay{~`BhSUU_UHRx zBG@}HdB%d=dsNO9)X&H&^{5KQLACx>`(S>xkIQZ5FY~vXYGldZa&uG`H8qck>({S6 zc%av+6rNI`=8=dQ#rz$WCH9K+X^Y#}t}=fu!Ig_ytfdv7y;sd?=$NzGQCHyVrgm+?F` z_4(D+-?e=rSiPPz*%20udf0JZ%#Z1~=4bwPA%B%mB>BS4vesEV|D-F(^Q>eanW^Lz zx9oh{SFfz8dFPtG8O8h^l^%^0TbaK$8O{7<{+7i)7j!5xlB&u4t(M6ClGJSFYV6Uq z)u@)U$WH{nwss$^?Yhupm9N~l#Mp<`xtJf*am~;C^`5_}*Yt3Ki8j|jm!ov$=o@4f zuFtJK;QH4Nm(3{V?B0FK4lD`07=C2+2!Yuxh z)NJKg?9utvqbki~EBY0&sLn0KV29PYm><(|&CmR0{!%m1raoO>(v|Ih+K7+;no-Q( zQF_!-dzimA8O{7<{@Uat^<321N!4fmS}S|wiTow0*~-;;tEvrOqZ!}QS!Te`|5Qty(YEkiNSP48fFypca)uYPM;DC z)Gs<@{xW}U@{xM3^h{FLR>(;3-M-9U`_T87q-HBGD`Jc5L{%{Jm-#Eqqp7=Tt3Ntw z-npi49vRv8M1TFvGhlZ97MG=twokxnXEdXjzoV$*IlF#T1v7t{zs%p%&m5@~c8!nJ z>&c!^%z6TR8+Yz7f2*Rfza%wVxmlFbje31;4J^x^p;;g7_$+pu7xQB}uKBa=_0zqc zwv|;j9fdYE?_ATTjwi9a^yT;7HOnm&qpIRyRr8ur%->Pe@tjv5zT2AlTO!1nzuiq? z*Z4@ip1k=aZ-0O<^S3Ja!o2>H)NJK>CHeoV(TKl7LQOU*=^`gHk+ zde7giI<63fl}v3$F@HzV#d8s5y7vOcM~t(WAJcKo&-`Wn+TR(j~Hh$Kc?fFpZUxDrDmc{eTs|}?d@Cn{9|4gSBSz& zrZ%IPzoRJPxhzeJDj4G<##zjd>A2=+{xW}U@{xM3vT}>66Yb|Se;bWTn5$0eFGHRT2Vw}bNn2u|H<}dS?nu#{`%gs^I6(*e|FaQ17jA(w1DmR`f)Y5sH zcPqw6jI)>@({atu{AK>yuUdP$(DsX-#aCNyJX5HZdtEGyj~Hh$Kc?fF zpZUxDrDmc{{jxI5%INc8qF#$kmcVdNeEpxi`Ytfh<{F9<%`b(r;+amT-0Nate8f15 z`7s^W{LEkGuT4Hu&sA2Q0r&=Ye;mR$Ca;r!g|^8)7v}Pdp2b&NRy|e8f15 z`7s^W{LEkGFEtZw>KD&5E2GbWbz;|53cJ8Wn`bZ*K z9JKOv?3^p+uU`x;vf3i?T-DejRWQa!jI)>@({atu{AK=9Gts8LSs)n`s5PDj1`uN?zYsohUmtmz zcg^Y1>;BffRj|q?waxr(mcZ`6=4mT*hphoz+My<>*Tek9q5;SYmT=71$?I%L+Uob& zkDPW#&%FHdy+0PczWDQ2!76j1T^D=O!~A9bmdi@!FU7t=_Phk}-Fxf%3CDbyzYcuQ z5w^E5GkyHTN1prc74efn5&q0f_x1^K;)KrGv*N6ny?yTL?S;AdTVprhygjzGyf{C5 zbA0UUi?cW9$KiHv{H5FTi%Sc)zWM4apAcWma+??oSFrYrZ+zt#Od0;*E0`2#GOS>1 zZgzS0rKS1Bh1nYm-TQ{?g?6gNHIX zadCWN;{2tH=O!jze))2^%KyCauk}$_VfZsM-4iZ0ex|mcIKTZX_&xmjWVoU2cKGua zE_&=+Z{FzC{lVw{`QZW33V%M;ra93gGCUive|vfU*5cg!VyCeBqK6Ntbz8|KeDExf zcdDu{`ybvCdb6g^uWNUBeiFHKNIWrN3G`ne}pF_ jyiMMP_Xylz-0|@L`aSu3xOrx}KYWQqKl}^(%f9~yEo_a+ literal 0 HcmV?d00001 diff --git a/test/utils.py b/test/utils.py index 14e7db0f..1e2a462f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -312,3 +312,18 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: }, }, ) + +AV1_VIDEO = TestVideo( + filename="av1_video.mkv", + default_stream_index=0, + # This metadata is extracted manually. + # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json + stream_infos={ + 0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3), + }, + frames={ + 0: { + 10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000), + }, + }, +) From b4825a76546bf7a845d09663a01a281197eedebc Mon Sep 17 00:00:00 2001 From: hugo-ijw Date: Thu, 9 Jan 2025 16:48:45 +0000 Subject: [PATCH 02/11] test: Remove av1 cuda test --- test/decoders/test_video_decoder.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 14d78927..10ecc712 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -416,15 +416,16 @@ def test_get_frames_at_fails(self, device): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_at([0.3]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_at_av1(self, device): - decoder = VideoDecoder(AV1_VIDEO.path, device=device) + def test_get_frame_at_av1(self): + # We don't parametrize with CUDA because the current GPUs on CI do not + # support AV1: + decoder = VideoDecoder(AV1_VIDEO.path, device="cpu") ref_frame11 = AV1_VIDEO.get_frame_data_by_index(10) ref_frame_info11 = AV1_VIDEO.get_frame_info(10) decoded_frame11 = decoder.get_frame_at(10) assert decoded_frame11.duration_seconds == ref_frame_info11.duration_seconds assert decoded_frame11.pts_seconds == ref_frame_info11.pts_seconds - assert_frames_equal(decoded_frame11.data, ref_frame11.to(device=device)) + assert_frames_equal(decoded_frame11.data, ref_frame11.to(device="cpu")) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at(self, device): From 7e6bc92365df895ac35dffaa5107a7fd92f83460 Mon Sep 17 00:00:00 2001 From: hugo-ijw Date: Thu, 9 Jan 2025 16:49:46 +0000 Subject: [PATCH 03/11] fix: Use AVCodecPtr in forceCudaCodec --- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 2 +- src/torchcodec/decoders/_core/CudaDevice.cpp | 2 +- src/torchcodec/decoders/_core/DeviceInterface.h | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 94a56c1b..354f8b7b 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -37,7 +37,7 @@ void releaseContextOnCuda( void forceCudaCodec( const torch::Device& device, - const AVCodec** codec, + AVCodecPtr* codec, const AVCodecID& codecId) { throwUnsupportedDeviceError(device); } diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 192ed544..0a2c306c 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -259,7 +259,7 @@ void convertAVFrameToDecodedOutputOnCuda( // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 void forceCudaCodec( const torch::Device& device, - const AVCodec** codec, + AVCodecPtr* codec, const AVCodecID& codecId) { if (device.type() != torch::kCUDA) { return; diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index ba3992d9..cd37ec16 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -10,6 +10,7 @@ #include #include #include +#include "FFMPEGCommon.h" #include "src/torchcodec/decoders/_core/VideoDecoder.h" extern "C" { @@ -45,7 +46,7 @@ void releaseContextOnCuda( void forceCudaCodec( const torch::Device& device, - const AVCodec** codec, + AVCodecPtr* codec, const AVCodecID& codecId); } // namespace facebook::torchcodec From a29287c61528759e02e17c408e3f4ccc17b086f4 Mon Sep 17 00:00:00 2001 From: hugo-ijw Date: Fri, 10 Jan 2025 10:37:58 +0000 Subject: [PATCH 04/11] fix: Change findCudaCodec signature --- .../decoders/_core/CPUOnlyDevice.cpp | 3 +-- src/torchcodec/decoders/_core/CudaDevice.cpp | 21 ++++++++----------- .../decoders/_core/DeviceInterface.h | 3 +-- .../decoders/_core/VideoDecoder.cpp | 5 +++-- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 354f8b7b..5feea5d2 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -35,9 +35,8 @@ void releaseContextOnCuda( throwUnsupportedDeviceError(device); } -void forceCudaCodec( +std::optional forceCudaCodec( const torch::Device& device, - AVCodecPtr* codec, const AVCodecID& codecId) { throwUnsupportedDeviceError(device); } diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 0a2c306c..1eeb6422 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -257,19 +257,18 @@ void convertAVFrameToDecodedOutputOnCuda( } // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 -void forceCudaCodec( +// we have to do this because of an FFmpeg bug where hardware decoding is not +// appropriately set, so we just go off and find the matching codec for the CUDA +// device +std::optional forceCudaCodec( const torch::Device& device, - AVCodecPtr* codec, const AVCodecID& codecId) { - if (device.type() != torch::kCUDA) { - return; - } + throwErrorIfNonCudaDevice(device); - const AVCodec* c; + AVCodecPtr c; void* i = NULL; - bool found = false; - while (!found && (c = av_codec_iterate(&i))) { + while ((c = av_codec_iterate(&i))) { const AVCodecHWConfig* config; if (c->id != codecId || !av_codec_is_decoder(c)) { @@ -278,14 +277,12 @@ void forceCudaCodec( for (int j = 0; config = avcodec_get_hw_config(c, j); j++) { if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { - found = true; + return c; } } } - if (found) { - *codec = c; - } + return std::nullopt; } } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index cd37ec16..77729111 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -44,9 +44,8 @@ void releaseContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); -void forceCudaCodec( +std::optional forceCudaCodec( const torch::Device& device, - AVCodecPtr* codec, const AVCodecID& codecId); } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 92b1af50..a80d9179 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -458,8 +458,9 @@ void VideoDecoder::addVideoStreamDecoder( } if (options.device.type() == torch::kCUDA) { - forceCudaCodec( - options.device, &codec, streamInfo.stream->codecpar->codec_id); + codec = + forceCudaCodec(options.device, streamInfo.stream->codecpar->codec_id) + .value_or(codec); } AVCodecContext* codecContext = avcodec_alloc_context3(codec); From 5e82e4805343734c7ff77b1a0112b4df6f66cf8d Mon Sep 17 00:00:00 2001 From: hugo-ijw Date: Fri, 10 Jan 2025 10:40:19 +0000 Subject: [PATCH 05/11] test: Follow frame index convention in AV1 test --- test/decoders/test_video_decoder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 10ecc712..cddf7fec 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -420,12 +420,12 @@ def test_get_frame_at_av1(self): # We don't parametrize with CUDA because the current GPUs on CI do not # support AV1: decoder = VideoDecoder(AV1_VIDEO.path, device="cpu") - ref_frame11 = AV1_VIDEO.get_frame_data_by_index(10) - ref_frame_info11 = AV1_VIDEO.get_frame_info(10) - decoded_frame11 = decoder.get_frame_at(10) - assert decoded_frame11.duration_seconds == ref_frame_info11.duration_seconds - assert decoded_frame11.pts_seconds == ref_frame_info11.pts_seconds - assert_frames_equal(decoded_frame11.data, ref_frame11.to(device="cpu")) + ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10) + ref_frame_info10 = AV1_VIDEO.get_frame_info(10) + decoded_frame10 = decoder.get_frame_at(10) + assert decoded_frame10.duration_seconds == ref_frame_info10.duration_seconds + assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds + assert_frames_equal(decoded_frame10.data, ref_frame10.to(device="cpu")) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at(self, device): From a9fa4bbed61be3151cd60121eb5ae21d7d3ab18a Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 10 Jan 2025 10:07:44 -0800 Subject: [PATCH 06/11] Make CUDA Linux test runner linux.g5.4xlarge.nvidia.gpu --- .github/workflows/linux_cuda_wheel.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linux_cuda_wheel.yaml b/.github/workflows/linux_cuda_wheel.yaml index c2248b88..9bfbed5b 100644 --- a/.github/workflows/linux_cuda_wheel.yaml +++ b/.github/workflows/linux_cuda_wheel.yaml @@ -56,7 +56,7 @@ jobs: build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation" install-and-test: - runs-on: linux.4xlarge.nvidia.gpu + runs-on: linux.g5.4xlarge.nvidia.gpu strategy: fail-fast: false matrix: From d49fde543e9cbd7dff543a5eeca66a3a8b26e7fd Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 10 Jan 2025 11:39:12 -0800 Subject: [PATCH 07/11] Re-enable AV1 test on CUDA --- test/decoders/test_video_decoder.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index cddf7fec..db5cc7b7 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -416,10 +416,9 @@ def test_get_frames_at_fails(self, device): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_at([0.3]) - def test_get_frame_at_av1(self): - # We don't parametrize with CUDA because the current GPUs on CI do not - # support AV1: - decoder = VideoDecoder(AV1_VIDEO.path, device="cpu") + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_at_av1(self, device): + decoder = VideoDecoder(AV1_VIDEO.path, device=device) ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10) ref_frame_info10 = AV1_VIDEO.get_frame_info(10) decoded_frame10 = decoder.get_frame_at(10) From 25437b0229c5b349a3b17925aa79d028a33d2899 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 10 Jan 2025 12:20:17 -0800 Subject: [PATCH 08/11] Make sure device is general in test --- test/decoders/test_video_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index db5cc7b7..2fcd324d 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -424,7 +424,7 @@ def test_get_frame_at_av1(self, device): decoded_frame10 = decoder.get_frame_at(10) assert decoded_frame10.duration_seconds == ref_frame_info10.duration_seconds assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds - assert_frames_equal(decoded_frame10.data, ref_frame10.to(device="cpu")) + assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at(self, device): From c64b8337f7b5006f79f3eee5506df72a82448662 Mon Sep 17 00:00:00 2001 From: hugo-ijw Date: Mon, 13 Jan 2025 08:55:16 +0000 Subject: [PATCH 09/11] fix: Move AVCodecPtr declaration to while loop in forceCudaCodec --- src/torchcodec/decoders/_core/CudaDevice.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 1eeb6422..8975e526 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -265,10 +265,9 @@ std::optional forceCudaCodec( const AVCodecID& codecId) { throwErrorIfNonCudaDevice(device); - AVCodecPtr c; void* i = NULL; - while ((c = av_codec_iterate(&i))) { + while ((AVCodecPtr c = av_codec_iterate(&i))) { const AVCodecHWConfig* config; if (c->id != codecId || !av_codec_is_decoder(c)) { From 9a9b50f3050cdf57466900bc61cdbf7721594b89 Mon Sep 17 00:00:00 2001 From: hugo-ijw Date: Mon, 13 Jan 2025 08:56:11 +0000 Subject: [PATCH 10/11] fix: Rename forceCudaCodec to findCudaCodec --- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 2 +- src/torchcodec/decoders/_core/CudaDevice.cpp | 2 +- src/torchcodec/decoders/_core/DeviceInterface.h | 2 +- src/torchcodec/decoders/_core/VideoDecoder.cpp | 5 ++--- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 5feea5d2..7d058130 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -35,7 +35,7 @@ void releaseContextOnCuda( throwUnsupportedDeviceError(device); } -std::optional forceCudaCodec( +std::optional findCudaCodec( const torch::Device& device, const AVCodecID& codecId) { throwUnsupportedDeviceError(device); diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 8975e526..b002c17d 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -260,7 +260,7 @@ void convertAVFrameToDecodedOutputOnCuda( // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device -std::optional forceCudaCodec( +std::optional findCudaCodec( const torch::Device& device, const AVCodecID& codecId) { throwErrorIfNonCudaDevice(device); diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 77729111..289308cb 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -44,7 +44,7 @@ void releaseContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); -std::optional forceCudaCodec( +std::optional findCudaCodec( const torch::Device& device, const AVCodecID& codecId); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 40cc7fed..737cf478 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -463,9 +463,8 @@ void VideoDecoder::addVideoStreamDecoder( } if (options.device.type() == torch::kCUDA) { - codec = - forceCudaCodec(options.device, streamInfo.stream->codecpar->codec_id) - .value_or(codec); + codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id) + .value_or(codec); } AVCodecContext* codecContext = avcodec_alloc_context3(codec); From e453891c9d61b2d09b18384ea6cbc5340ee10fd2 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 13 Jan 2025 03:32:40 -0800 Subject: [PATCH 11/11] Undo my incorrect suggestion --- src/torchcodec/decoders/_core/CudaDevice.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index b002c17d..69fef471 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -267,7 +267,8 @@ std::optional findCudaCodec( void* i = NULL; - while ((AVCodecPtr c = av_codec_iterate(&i))) { + AVCodecPtr c; + while (c = av_codec_iterate(&i)) { const AVCodecHWConfig* config; if (c->id != codecId || !av_codec_is_decoder(c)) {