From 192cf312201657d0dd820c1a3e5493bf8d01e327 Mon Sep 17 00:00:00 2001 From: Chris Abraham Date: Tue, 29 Oct 2024 11:53:52 -0400 Subject: [PATCH 1/2] Add blog "Triton Kernel Compilation Stages" Signed-off-by: Chris Abraham --- ...-10-29-triton-kernel-compilation-stages.md | 205 ++++++++++++++++++ .../triton-kernel-compilation-stages.jpg | Bin 0 -> 111770 bytes 2 files changed, 205 insertions(+) create mode 100644 _posts/2024-10-29-triton-kernel-compilation-stages.md create mode 100644 assets/images/triton-kernel-compilation-stages.jpg diff --git a/_posts/2024-10-29-triton-kernel-compilation-stages.md b/_posts/2024-10-29-triton-kernel-compilation-stages.md new file mode 100644 index 000000000000..10b0e3d88785 --- /dev/null +++ b/_posts/2024-10-29-triton-kernel-compilation-stages.md @@ -0,0 +1,205 @@ +--- +layout: blog_detail +title: "Triton Kernel Compilation Stages" +author: Sara Kokkila-Schumacher*, Brian Vaughan*, Raghu Ganti*, and Less Wright+ (*IBM Research, +Meta) +--- + +The Triton open-source programming language and compiler offers a high-level, python-based approach to create efficient GPU code. In this blog, we highlight the underlying details of how a triton program is compiled and the intermediate representations. For an introduction to Triton, we refer readers to this [blog](https://openai.com/index/triton/). + + +## Triton Language and Compilation + +The Triton programming language supports different types of modern GPUs and follows a blocked programming approach. As an example, we will follow the [Triton vector add tutorial](https://github.com/triton-lang/triton/blob/main/python/tutorials/01-vector-add.py) with minor modifications. The vector addition kernel and helper function is defined as: + + +``` +import torch +import triton +import triton.language as tl + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + +def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.is_cuda and y.is_cuda and output.is_cuda + n_elements = output.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + triton_kernel=add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + torch.cuda.synchronize() + + # Save compilation stages - some of the stages identified here are specific to NVIDIA devices: + with open('triton_IR.txt', 'w') as f: + print(triton_kernel.asm['ttir'], file=f) + with open('triton_TTGIR.txt', 'w') as f: + print(triton_kernel.asm['ttgir'], file=f) + with open('triton_LLVMIR.txt', 'w') as f: + print(triton_kernel.asm['llir'], file=f) + with open('triton_PTX.ptx', 'w') as f: + print(triton_kernel.asm['ptx'], file=f) + with open('triton_cubin.txt', 'w') as f: + print(triton_kernel.asm['cubin'], file=f) + + return output + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device='cuda') +y = torch.rand(size, device='cuda') +output_torch = x + y +output_triton = add(x, y) +print(output_torch) +print(output_triton) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') +``` + + +The Triton vector add kernel includes the `@triton.jit` decorator. The Triton compiler will compile functions marked by `@triton.jit`, which lowers the function through multiple compilation stages. The helper function `add` allocates the output tensor, computes the appropriate GPU grid size, and additionally saves the intermediate compilation stages. + +Focusing on the compilation process, the Triton kernel is lowered to device specific assembly through a series of stages outlined in the following figure. + + + +![compilation process](/assets/images/triton-kernel-compilation-stages.jpg){:style="width:100%; max-width: 500px; margin-left: auto; margin-right: auto; display: block"} + + + +The kernel is compiled by first walking the abstract syntax tree (AST) of the decorated python function to create the Triton Intermediate Representation (Triton-IR). The Triton-IR is an unoptimized, machine independent intermediate representation. It introduces tile-level programming requirements and is based on the open-source LLVM compiler project. Next the Triton compiler optimizes and converts the Triton-IR into the stages Triton-GPU IR (Triton-TTGIR) and then LLVM-IR. Both the Triton-IR and Triton-GPUIR representations are written as MLIR dialects, where MLIR is a subproject of LLVM that aims to improve compilation for heterogeneous hardware. + +For the Triton vector add tutorial kernel, the example Triton IR snippet is: + + +``` +module { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0)) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4) + %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5) + %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5) + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6) + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6) + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc7) + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc7) + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr> loc(#loc8) + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc9) + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc9) + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr> loc(#loc10) + %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11) + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc12) + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc12) + tt.store %15, %13, %6 : tensor<1024x!tt.ptr> loc(#loc13) + tt.return loc(#loc14) + } loc(#loc) +} loc(#loc) +``` + + +Notice that the main functions in the Triton kernel are now represented as: + + + + + + + + + + + + + + + + + + + + + + + +
Triton kernel + Triton IR +
x = tl.load(x_ptr + offsets, mask=mask) + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc8) +
y = tl.load(y_ptr + offsets, mask=mask) + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc10) +
output = x + y + %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11) +
tl.store(output_ptr + offsets, output, mask=mask) + tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13) +
+ + +At the Triton IR stage, the `%arg0: !tt.ptr<f32>` and the following tensor references show that the intermediate representation is already specialized by the data type. + +We ran this example on a Tesla V100-SXM2-32GB GPU with CUDA Version 12.2, Python version 3.11.9, and PyTorch 2.4.1 with the default version of Triton that is installed with PyTorch. On this device, the simple vector addition has the following Triton GPU IR snippet with lines omitted for clarity: + + +``` +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:70", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} + ⋮ + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> loc(#loc8) + ⋮ + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> loc(#loc10) + %13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11) + ⋮ + tt.store %15, %13, %6 : tensor<1024x!tt.ptr, #blocked> loc(#loc13) + ⋮ + } loc(#loc) +} loc(#loc) +``` + + +At this stage, some of the hardware specific information is included. For example, the compute capability is included along with details on how the tensors are distributed to cores and warps or for AMD GPUs on wavefronts. In this example, the tensors are represented as a `#blocked` layout. In this encoding, each warp owns a contiguous portion of the tensor. Currently, other possible memory optimizations include layouts such as `slice` (restructures and distributes a tensor along a dimension), `dot_op`(optimized layout for block matrix product), `shared`(indicates GPU shared memory), `nvidia_mma` (produced by NVIDIA tensor cores), `amd_mfma` (produced by AMD MFMA matrix core), and `amd_wmma` (produced by AMD WMMA matrix core). As announced at the recent Triton conference, this layout representation will transition to a new linear layout to unify layouts within and across backends. The stage from Triton-GPUIR to LLVM-IR converts the Triton-GPUIR to LLVM's representation. At this time, Triton has third-party backend support for NVIDIA and AMD devices, but other device support is under active development by the open-source community. + +A small subset of the LLVM-IR vector add arguments shown below for illustration: + + +``` + %19 = extractvalue { i32, i32, i32, i32 } %18, 0, !dbg !16 + %39 = extractvalue { i32, i32, i32, i32 } %38, 0, !dbg !18 + %23 = bitcast i32 %19 to float, !dbg !16 + %43 = bitcast i32 %39 to float, !dbg !18 + %56 = fadd float %23, %43, !dbg !19 +``` + + +After some pointer arithmetic and an inline assembly call to retrieve the data from global memory, the vector elements are extracted and cast to the correct type. Finally they are added together and later written to global memory through an inline assembly expression. + +The final stages of the Triton compilation process lower the LLVM-IR to a device specific binary. For the example vector add, on an NVIDIA GPU, the next intermediate is PTX (Parallel Thread Execution). The low-level PTX syntax specifies the execution at the thread level of NVIDIA devices, starting with the CUDA 1.0 release. For an in-depth guide on PTX, see [NVIDIA's documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#). In the vector add, the kernel parameters are passed from the host to the kernel, addresses are assigned and `mov` instructions facilitate the thread-level data access, ultimately representing the element addition calls with `add.f32` such as the example below: + + +``` + add.f32 %f17, %f1, %f9// add type float32, output register, input register for x, input register for y +``` + + +The Triton compiler orchestrates the final stage with different hardware backends managing how the assembly code is compiled into binary. The Triton kernel is now ready for use. + + +## Summary + +Triton provides a high-level abstraction to program and compile kernels for different types of hardware. In this post, we highlight the different stages of the Triton code representations and Triton compiler. For details on including custom Triton kernels or accelerating different workloads with Triton kernels, check out the [PyTorch Triton tutorial](https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html), the blog posts on [Triton GPTQ kernels](https://pytorch.org/blog/accelerating-triton), [Llama3 FP8 Inference with Triton](https://pytorch.org/blog/accelerating-llama3/), and [CUDA-Free Inference for LLMs](https://pytorch.org/blog/cuda-free-inference-for-llms/), or the [PyTorch 2.2 Section on Triton code generation](https://pytorch.org/assets/pytorch2-2.pdf). \ No newline at end of file diff --git a/assets/images/triton-kernel-compilation-stages.jpg b/assets/images/triton-kernel-compilation-stages.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c65829b5181f4659a3144628cc0b6037d9241838 GIT binary patch literal 111770 zcmc$`2UwHYwl^L?MNz;4A|PNPQ~?D-6|f92fC*9~U8M;adO``s28>h0iLUSJJUI{N9y6o7ZkfMZ7?H>IU-GJg6!)}0eaod2R_kLX{Y!$H6sz^N0= zCz((1@m=M-3cP+@2zZT8NJv-=cSmQR8&}4OypNe_`0yL)Da;eAucK`BrPr_ zCM^t<5@I;T!oqgpEIa$z^P-%boWf#~QX-;~ayNw)rLTd+MWm%9L?o|B2un(Uzd&cH~06b z&XXJ$r9tu@FQdQacgV`!(>HpR`nhVH^SZ%P&({UjGC!r&ToSxL2DE;&yrQ5^q1o69 zi=-a`a54P){DJ)2FCzof5kuhIQ82^tzvVJ8p8E~-$pz`d0l=9fof*$Dodc)=%I^G~ zfd4lynjpEN@1Au!hB&I}tf-<^Zip>e*#q&@Y3V_@c%V@bIx-!#|9*JQe0`x3Y+a!E z5C@6D3Q`qCGN>4Pt{N0rc5sX;K2WPE-p`-eAUIJ14)ht;nGAK>3`1QTq~-AuBFsW3 z6z-ElR-)ev)H_nOA~*X;m9mKe^)n8(2^X4X0&8nqb+bh829_Lj-hI=r5mlqM#b#+Z zc?hsuwfaJ((BvjIG-h%Tz}YfoxZoim7vn!UqFAR2TjzcF_=~UngVBsc*O+MY_*a!| z;!!fUHH;p$02o4@jC&moBHP8?tlrc0*yq;+6@6ODa4)NklBf#Ap3djA9eIZR7!}S< zE*bTz%%tR~DJg~DW04Lbz8wN=+134PcI26Ui^E&FlEAa}wJ~O*ki`#Lp3?RF?IiY9 zwO&%U0=5VBVEw&wR`eaf)A*0LSxRJPt((V-B&HKj8@RGH!&qQ_XisL^bkA!;W~6kK7oVym97`}DX-xW08w*5 zr*_;w6BCRY8ZvJwzxq>YvS_+pOJ*lU!0m5>K>l;EnOC+NUeiyp*hnbfH zHAdC)Lq1o~(py?byDZx@*U*Wy0=-|6|BQ5m%=|m@0WW9N*F|j;)$0>ned_t1?KSN+ zxs#&au@z~>Q&2{k!phpqy5y!=!G=0**X+|EcxMU+4`ui_sfb^Vnxy^#IJVFG;-HZ~ zbT=~L5OA3e;`&Xq|H|yLKl0-!uYWHO)zw*1o%Gh%imB#)}5Jm zMSR5}gj_SL@of9eb;w2Srcb+y zq~nV7L5Wws5CaDe0fUk<85KOi*GKNEW~VPq$r(TDV*77~2mYGVM~45dLfFVwM1=Zp z-ql*B?g$B6L<@JqHrJh*ep5`~4>inNq8z~>DH)zfD7ugU~8f#r(tyOJ*+aIOtapnZe1rnBXC-VsL z;jf5w1fCdL;SF8RrL==Y)}-6rH!mp_lwvRL41g0NeuY&1Z;*d!AbGM@?4Iwo+aZ90 z+o-LqOmh|Jzu#fywyASGC{;OAU9}KpXP$7gV;XSwOV^D~yvQWMJ8vs`=3rNDeKk%fPgx_6-@7KzyNVMuG7WkvIeuE$1JP@|l`>l$(_tI5bWm?>B2MVS1 zO8@x7sT!!4B?Z{F?n<54nU$W%wo_96c1&Wd_25MI51sh}y=Oh?ILR6JIe8vE{fH>F zKgvV=av`|EIQ6%>Io0K8fYz3q+L8I)JNZEQH$GP?ag5C772~H5V9=ULg+2LlyuRv# z2?5;a)xJn6>4uqE6wcV=pa8+!QYuofA{8aKy>b3Fr&(urD-7bUTa2W^VE5I1Asn_oC&tQhb`dj`AJuB~ z@kBd8Riiswu`U1%nenyU_BpUKY%m;3kY*VH!;pf_@ezhII{1~0Lx5S@2I__q2sbV& zlpgOF@1{-Uwg^3mQt+& z$JYui;|1!u^Lsu$K%S&)a;9LPgFaDeA|C!g@+!eZUQwp$w zm^4r~5a~@SdzK4gtC!xHT^A#{tl9e1?JX=x@eM3r{BqQYbi>p!1?BSk-OmauRoq_h z+cwpSt1-03cboKl$o4fdjd~z(Ni89*fme;YFxJ5!51bLL8}DZzu40vcxl+LAp<^P+ zx;3%vRDfnnO;xvtROTS|L1g4Wq+SY$B@Uz44-_6ESGJ(pq;UNk<|#3DEp)FB$V~S< zt$JBg2W61H|3$w<>!rNj(yk7lb8*RHPIK=VLjN-v@sFgOl%#+LIv>GLR&UrP1W|@U z&3RoiM&eV9=HGU_>fW(Go|TV=1SsfaoS(BYun>RwH8$tbyq%OxeMzf+;*Tw2UD3#~ z3NKsEn$&9@?MR^{&#ZAHv_j`hfSHx8-%>mfsn%b?{=gKY1IfEj^k5g-bLo`~xaMor zG;XJ6@yHx>Mu<=2mR_iFs5Tkcjv9>|J^AP#srzs8n8Gf<09=-Pui;dVqq!;88uMUM z!6V3B<6@KC_&_U=lhE^Uf-j>~Dk$!>@kFmhGwl#?p6`~wQ?69%UTcf=L+d;LrUw5c zxyHtDR?WBem9|{D>3icGw(c6 z4_-mcL@&Ebw-SKbyNy(@u_>lJzVMofEk%fX;nRd@Ka_d1<;U)i{8?CGjjZH_m|_J& zfVp6F0nNVY^{joiO!dU}ftG@z@F^^}Y_BAw?c3v`z|YeDyxQ2OIMI?pA2PODK7SNt zJ~&4kXS8HVfbCR*DC^S*mfh09-5#n}(LxEfr|C3g@?pJ=r_HRe=sngIbF)yFknl6b z%n6*4RinLd{adT?R*(4K`79SVOG?X!)1)rw&Z+`~qOa>(W~sp%NUyV&H}Ym5WNcg>Y7_YlFQzbz>=v zPclIVXbQp0@#!@Th0{mNyUT#sW_>zjP$&$qPmR~99up6!{T|R-1bS0v+mZdYbQl^T zJD{a(=98_wbz$sV{8Ref&tESgZzx9@gipYyU}{M>g}5ypUbHNiuFql%u6W&cU;5iqvBmlv z-dTO{b~IBDo0N5vuFN;1JW^Iw6Hd#j1^IFgFpIN{5U$tqroLEUhFrgJ|1@M3rGqa- zC#5KuJ9Aws{aA0QqdUvgu0hL{9U=%2N|ok`RB9nU05M#&u*Wr3^HK1MY!M9WFc%0D zLye5U<>liRcgLyt^0tV)T?Gco*0pf9`>~C96X$BVVT5z9qxdWEn$xFgif^=N*Ost; z1b7;!K(yx<6?W1RZu~iNp+=123mzx0%0=2#=Qc_v7OhqqZKX(Y1xH@h3{Y-EkxFMK+R$f}zx`D` zJPLPhjCM>=q}S80y+P)}AD8GPO>k~gov`t=C45cr^PB?X3>%Yrnd{9o<3zm3b~@?l z4rMp~M?_r5PRdau4A;LbAli^|enxjr7xO5aUYUCFgbSy(N>gC3dVtQ}#%N${*Vbyu zsgC$GPK*pjU)~L2bbY&HInc+$(mBo|zCqCwpU#YyXn0a^`6$DKNBo(!eWqe@*+Ivs!}G6Aco;yo3c5M`;Goca2fAVzLN?eH z$xuO6vOFpKxc{tVJ#W069fhSoJj-tLlR?M3ZS@W5)0p+#>;Y9USUVpZl1{u@HDFdQ zCU4f3N-mWYb$^z$drj^9*y{V8kHCKUx=8`@!3}m!utH%GIjenSgo=fdiWW8#F^x;P zL*LO~>-nOu*GiTEcQNSP$_uvygw=4dm0$Bl{WQ@6nJ4+Xx{GRJv?U zOyD?p(O_BC*?1;@zo^9YSqu8$~zROYoz=b7pIIeFLS)b zhuPj7SyE;$S`MS9QxxBVB8zA zWJMP|hJ3!2sKM^IlE^v1sjX*R{v)X~e|&MF>il!I?`siy`v@5>iESz0oFPB1&1`aL z{@hKZ{?>dNEtUTM%Qq)}`xni0_1sT1qhe+4x!^r12MwLrUV*;u5{Zv*h@U>*+XZqR z0xs5UU)bqDjvfMzMZNXEE0y!Id^fU71Vr}%t$!($(p!W8S;w!kzWVS!|_s*2axn5ut-D|Vmz z7k}k;-8y@har1c=9UGn|g{oPt$8P6g_z7lV9}T@J-u1fd!W|_A)$_1QKg-IPkQMR$ zqlsPI;QVDpQjAvxQZ!)Z@xt!-#cIagX`u3ncVL>T!1J= zW;K!vTU$QQW1j6LJr#vU>pA)E2__))!;F6{Yz!43xK>~~4F(I}H(~Cg^GoM`f##Zm(WenDHKHDHq)LFb$S|^kQ1)KIv zQYzV!FWd1qSSpDHic1wRIuu^Qrn*^OH<eXKbS$q8#IV9@4T(JPqo6Zg)<~QM0z< z^SPp}s`&xJ&X^1T-kr8-|NE>47+en_5$#FBT%>I0=9aN?lpj9?>~bF*X#5?4fNR?BHLPMk-shTMH zSA*-0y#qi{%>f68OH=q%>ps`>xI=&@EUyCRgF6K5-Mpq0CEQ9^POip?wHjt@t2I)` z$3RYQ&_dLaPxCUx*l~E6=O#Vc?XS~T6%~&uO6xE)=-L9YbmJhROTZzG))vc^SAGd& z$cwoz8r?qfN2O_Fh@avSUY6U0hPua~5vH&<*Imt8O}0urPN|SG6Y5dwl+}291c)%% zVYVrzs%6fMJ3fucASQnPOdX<$R4kNM`Xd)tt|drqE#yr=cOUdU{Oe}>SB|O}CCCv5 z^4tRTe8dlGcd~0uSu`GGd3mNx4knt-S@5~7$k?)9w2vM2Fckxd;sm=;j1E7o+kx5E&d?H=EHm?)W9RgUPFc2dt%6c-rQoxn<`hVJQ%dZdoG3m>PMhAK^bwY z2l@U~-Q(Mu8&Gz>5iMy2O+hgLp#3kEKh8WKyJf)oqis&T4HL*Y?H8+-9~Hjzz~41=|2SkYbLn5$dcp=shbCyD|TK=Wa;M3 zGik=CL<5p;>`v<*ks+w*ruNxa;nvln?su#~z2kP*e@tyi6X~Em_FqbwC0g}cjQ^SbhQt?;=8 z*FP0uv0^sGzS4=Y7Jd3fb6%)I1CnuppC8w_$%h=!8B4x;y`v@zj`e1#S!Wm9AjLfoz7a467|OVTJUB4Ib@M}Gg}iKAM+&t z-6yJC!Rvi-&#kqe%GSp&c4PIH$ewj3B6VB7Rt9z+ zz;=?gU=Bu?^*FE*?4sqJP@WQW{tjr$(TXp|&r`Ki?oze&ddB?^lc71&8b+LWVI+Ro zLgmDB7WRJe9b9)RcX77H{;7}N>USTpfBrJ4@dBT9 z(+#gJpMBh}QuVS0I!JC{uGB(38IX!V770v;UrP-QFO8f8onJ&*MJueq+)1>?xdb8a z^1|s^1w_}Ba0lm!gk_8Lw-B#1&q;s%jhpQ;l}bJ6B$MrCIWAisaaR3<9^J;6^s<(h z5JR8EK0=HoEdpe-H7uK2gCylcbn6fc)lI>5Mvbnq((?Ba?L+g3*~rC(vIc|&dNjXT zI!MXw4HU@nx$OG}1D+(KT=zW1#4yd2-Q8TUa~iy87&@aH(W|QBR#3RcmpeVr&p%7xia1O?T;^N=1Vmm;sYJ)ws>Ja312uP8&?yX zOq{l@vozZE-7D|M;xFZ1HssQM>6!0FTng)5eRlv13aWWn{iMs9^HT09L6i;-SjGRG zCepN)!#`o_=te^L<(4q8_gU2ur1*4+*GCAIVP%@pdwSO!-OJd(uMB6ZG{opSsM6AX zY>zE^Oy)KRu9Za(U7;JzNT4eP01dZI{6Ac<_!MycX@v&1&%P4)eYH8{%$PRnjoGk+ zyEECU`TcDDlV{dnP;b(_$qHmuuqvzHG)q)p|1xJ_OjQ3(RkB)0reDIY`XI`N8*&$% zWtetc*az6zsM9aJyw4sT`0Nzo{lKE$?eo~T8x`PC>NX7*N@7w@57F?c96<#MU6ysBpp{fK`Yx ztQS%2UNpOw!0=Qhw0*#gE8N3d8`(l7@XBqE1Qm}&e9QBH=WB4`!{eFpCnu<~WBQFt z9`Wc<$<48CsNV%?l9bc-i6_ww1zq#-^M#Aa^YMG>oqpT>E0+AuEpFg*mz(TNm2brk zpBMjLrrPu~3(e6CY(miXbcNgd&y`fKxWg3c#iL_4^JqHQ*vfR*5A!GDF+2tCnV(Cq zJ1_aH_*p4ij@+91?rE7w*^tLQp>YSm^Q4qZWk#Kbn+}wne2fk#avas8vKjZ$c9Df5 z8X&$UAhLWhz~i5<0yDKlG(2(NQ;wOsROjLD09^Z`a;0`=Up>}Q>3pnMLZkH1kqP;J zlCjQKt4?9F?*=n!;JUz{|n_V zOUDi_8ks5#fO3aST^;wX1M4Luwdasgud6L=tQ<^a+GH&r0ja=Bb}@DzspOvRmWCpe zYFuBW&%QvGh0tz%QwW-4V6DN&kIc1_`O*0aRHETtj!zXXkF|Jk%3Re0=1Q0u|MC?z zAsD{`@-%i+u;ly3c=7*2!~b>E(NIUKf6UB-dX8BQ7k@=dQ4lB64XPXUQu>ioWZ$jf zpz#jSRxG|OhXS4O>~x%bmMwSumFstLLt+Jv zRHhdm7#8r9B3<;4`uStJR&#!mZ>#U=V?zlUUq*yp#Jm77DrdrT6DjTl^Pdd%y6Y{&oeu#|BASY@pVwG+vpLXl ziCoEAQUk}@1#ZRIUN|1@#HW@t$(E!vYoJwBJlynps_TRwqw?#X9T&d%tXs9}EksU! zc~W`r{NtZr+_ao?7nm7wqTBYnz*^B zNNTrbk6}iFhh8)xc^tOtWqz&y@!2GKgS${&d=%oymupv? z*lD=zZP}Ps)rK(8&Fu(EA|?3wmN2+6wY|U@7oA|!?sS9b#3q;{>qQyVVwb$N`$H&J6nhk!S5WR4(y6y-w4Y<$SG zo=j?#5!-NVy_GmRp32pL*z?NSvLM*(+}LR%3cR0Cj67HE`9!3`p^hQ4V}K&ma|tSD zFkZfHgI5^Ad*-pg&Wp!DbPD!0F^E*|jx^!Bkio$Y8bl}5fh=Dd5R0sz-8(f**|5|_ z&+M#Q*z)8iP@{N>!0K<1MRz^r3q2n-z|lrVD{|5CZo{xrXUy1)WLd$@>d6n6y58h* zi`Q*@Qh8kE$ZGM

2zBG`m|IUUkZYTO8D!3L0n)tuxko3Y*Z-vqYd6UD?8g#M#4Z zL^T}s99JBupq}_i|L$x^gwWv8fUn||)M=(KXZIi!PiKBYIgxq1)Y3Elx@oD?yVOv_ zFEP8f2g#fPzr3M!`fvw^s)Ji5bFY%PfHAT5j&q$A7sF7mw&SWhGiwZ3lAVTK_6L-?oHstPo zU;MKwRkIeMg|3py58)jHt#9#WUW2OMWGQ81Mt62OX5OIYn~koCwp`@8jFi@>pZ{@# z!!)TS>kk0kH<;*bgaaJ|6ZExwB!<;U9sctnG6$&k!mSmoulv3wbh)^=GBHPP&W`tPi4BgiZ)( zwm*W2-cvCgqwArsuN>cHU48ZnDR^+KX4`#NaLM)% zz}4pxeMkG})(-fY^@^x$Bgq-yB_DkU;Wf)cz%3zhSfxgKQNRKj!nL+PdJmyfFfW;x zdfQmo*HPb5w1_yc*$I;!;Xh^J9>8dyMOV&P*)}g_m2A+HoWKetSg^|V@ta*#WoS5M7#ld!iP;W-IVq+E@M3*ZS7bS;Y$jL~bo@#2I&ikgW*&H{n)r z%T@2BtfOF@5sZD zBUyCrUUzt4i5HP~%_VrBkj=YQR`gVc z_#%gumjYTX8Dq%qXN&_0EM3OH&bWQ&1tBU;7+$S6)sX&W7cy0uA+{rZ2-y0h?s~KK zInhZE!!56!_R7m|Ox|YXv?bdeCACKt1y|YxtA+dvF=>jCvh@+AZHp zdoc}h;}8Br4%H@E4RJF{jIBw1kfMuhNu#^*Sl4{)4@&-MNvT!5j;76kMq>9RLM3dD zT)H?swC!OVIRAXlDsMGx9(|H$Bi}00^bjz$pKmn(ZqFcXHGBR+=scPhSGj!%cpi%s zRO6W5%i6ubbh-eB?(Fh;AUNp~)b2gMvQXv(hY}uc2|Gh&%n+6-3 zbSP>#CMG6%X+n+y&P|3cUxYLsO~Qxo?es?sqU$g&d}>;PAeK$h#@fTMY8y0LQES5&{d^fjCyjdtVj}jpt z27=7_*Ic7avzNk5som+lGy( zRj=0{b`2X7nstjA*h_{9Zmsw!noAn`tZZp>2eUjxmDwcPkSNCXsh7ZJnalA|Avl&> zAtRcqRWZ}=;q-z70WYw4^4UWk$|?#Gqq`mgenA_iZd+WA)=hdp(bZ*l<~aQq1!E=cLq2Ti~Qeii|)S@lL(#mSIf$G%*Shy>m6?`Vt6IG z4gtW=E8j+M_MgyBg|H*eA-|4~u*uO7jSYQ+TP0Q#T>a{pBs)WQzYL5CvQM2nn1m|k z&}ds{dy)+YnHCQi>2J0`q(f|CQUs~V^ZFC~I&kr8)!k;Iom)eZj>&xY93`siDjZx% z^=GOv4&2sCz$GI8PVcA-T0IJnP zektn!r?)srIhPQPufHn535GNKwx#S|zSF-k|F&g?++I_+DN~kSshAOCuq^D=jkH4D zs;I9FzQzAjK#^VfdlSnOMJC$W*_T@7!=G+wh1C-sUhuN{xA2T9LH z$7_j-Hw0Gq&BZCHC{TvW(MozqZDE zILOg}oQRx3BHL|=G3nhiU4*rLA)2D4XY#^eshi`93%#L+^hA8#MKSgSCZUU3FkEKq zV5~YBZmKF7!Hj4guwo$YHiSMy5R(yz+zhgM~5{fFS1j#+y_U ztv!8&tD~&{xp99$NOOMay*K(!C5eKA%gWRxzs-`(+@ipDb3W8IM3%8ZM8f{F5t7cN zKbW&s+}4+-gP)fZZ5)D~QgZj_-ze{`sTnWkriR(@U?14WI&@)x>KPP7L87TvwmfWl zq?r`@)(;UY$`3sj5-xc@V8{}P|6<)f~LWWvGR zPJ~+d7q#p?*78=>Q@9^Gvvw);pOKm=M|)OG8T@H#yf0(#f1U8`HGEq6R3#*ZX^X=z zc7y+Q{I`Dq%Q2E~F#13ZEtRjn-ZS8)c;Y@Jif~=xd*+CL_FxsMB{aq`_=nPq&uS}7 zp-UqB*Ee+7kBgas`*0*p7ij`s&>Cto9uF3deKO54fH7|8w&!q;3jxaTxyVoK_ zO1Y57{;K+k|NM1J)kpRt>VBwte=2%wfKo>`D28#bB`cy zQ7kVYcH4SaUxu#8WI(rlQ9CPP9h^Ssij^cg|B_igwWh5lOp{CF3&KQJX`hGcL!@=1 z!6r$K=Pa1pd64%-8jQP=Df)pcM4sv4Fjse?DO7IhcyQXQs`_)r;rijwdwRw?&4@*f zy{Jzj3%HpZbHMH}fsimi!*o<>c!@lzxR~rg4BlHrx+47=lUZHHvP9Q23u%G84l66J z%V@qRpn0AoUVky1cg_vGJ%60(z3uj7*T%Y=G`2)rQt}XhN@Mo_UMew3UEGlzxu@P# z7*%@%O4gkfbPEktbdWu+@@I#AZZ(a=pCQi-xHXY?5I6|5B(v8ksfkt8Hrcz2uCj@~ zyx0$|o8;XuaG;hM~Jr(lbZln&*m*#;L*EQ!U%=5{8O*^zpd*!Wm;gI~Ps>`j2A zoo=sl>g7d)!!Kcw6(z1rHrFT`!p1g+aG@le24ryV^{i94Rdc7E z=JERjeG2uPR4uFA#s+Vm4wT&v(dNu9PbYf3 zLJo6picjg4u3ei?Qh>-g9>_H5iZ(eIlG2Xtbm<2m8S-`lA&x2;jvK$=!C6oUI9t8` zw&9hOiz!x;62t(u?Jggumck8Ft-*lvua+%j!%^OAXag5-`7?_Oo*+J-vo!2)o^8(yNlGN< z6hov5CEDZ+)Obqr9*{e@oxoP{I_-6h_AN9^tjz)`g}Kf(A}@fIoU1)i6Z-A(O7c#k zZFgDpPRF2jtdi{f(Son*oriz)T@-26@1{3vXg4=vpWzyYt5%#_RDR@iM^7OovsWbv zm(r{d*<$@{Ju~BCH34Wh5y^-?6K%1*0W^H=lggF^vT!!0tNc?szGkOHinhl#Yv$DI zUiLKdI$%7-c-{7B@j)QTn>@+vXe!06aWD|NtfejC(<=}^jYl-(Nvjy3?*$$22UaB0 z=ZyVO5uLC``Y4{p*xTm5ccQEVB4*L@2xkymqUwg2s#>#@HHga03)rx75vM0StA~Uiq(i8Dfgq@i7iOK+E4FrUaYd75$6JvbgHUKC8@2<~R}@v}V4)(cXNob#04bO1cX3k~a3Jv5?E4^7JOSq(2^7GzjsJiOtBvEbs${z$(={NDZh?ZrTr*K24V0Y0r zD(AWrE7W)B0#&ava$&HhbnW7rR`R0>A5SM2#3iddxK-O@QH#i#d75+&rM@eD?BcAJ zg_*CB)RF;>bTk*?%&YJHTKE3@KWKt~V==uxolW3mm3U9;e2s#LHTEJ~o08C47rx68 z(DpvvO5*JLlFnozuO#bAfy<|7EyfCff4)ll&p`htpJ7b|_tQ$?eOs$OgLUgo{$=J< zf{}OrM?8kLvRvl&D?W(MywtyP)qYB+OF`w^A5_O3_kSMHrV(d%;v=eokV6xhPxIfc16!8H6jX($G5~Ywv+K`V+4eZ7-qqwxb+aw;}$yFyGcrc zn#g|reipBec1*NvbD&`5zVa3lhc}W}bplBfbtCmRKiXvmw^+!_FC_KE5Q~?rUJtKa zczEpk9a~?lR42jum&^&HMhI)0FS%Mw&Y!Q-Pc(d^-!n z+WO|neA9nq8YvB48bA___sbwu>o`Y0=eJlE8OQv2WE&E`90jL(uB=7hIVP#-j9qZ* z+Z-wKk|-~4@ZH-i)8x^;(R~SgT0hBzXpxnW=exMz#n-Ez@*)1kmC~Dy1Ur-0SznDV zrwZEcb?rNDO^0sSWM}cB-SjP$7siz)sz^y zM(%fl%Uh?u;m4H=E~h-|9xxQsNTcVw4M*-rk1idVe2l?tfeF~N9gM@?P=WGj-!eF| z&C%U`j51V+M5M(!4s`~pId-`@GzCzGNMkPC%7A-CFXJ4;YjpG@`JF+ApC>noE}W-? z8?dBj&?eCR{;Z6Q{Q^>S+$P`L3Uc_GBOu6OzQFJa9j*iy8opYe6;90Ej7KK*suz1l zMSPs_&AZV!n*-Jql^>Ey>qfS}2Q zA}*_LoN=}%FRp}dH$5?rZR&GV^P_JptGT*4!<25z@B8?GTO(tzCgh0tk}9YOjCUJo zjF~oa41GBNRu{1VE^(b%gwbw1KJHJronm4aD~dEu3)S{VQHIJlz5M)z>w7&T?XwCO zL?~~}X|jVbVOYO#-InR-qZcgJ(tV3Orsu^txQg(sQx%&o@bZ?$)k@7@-aD`ge6-<} zO9Gv2xFFJH*5SAaE^nwH#g?gu0^U3y5${%X)-+a#k2G+1W^M9qf=l{ecJ`4oiJrTh zg|UJex|2eVj9h>%EOCKoH?ACqr%S_}-t>a(O|v}(Z04N;=w;DS9v#oNBOQJWbig$8 z>Cgp>$xW{FV-cw%@kH)jnV3m}$wm zXSj)}w1>kM=Z>y16ea+!n%J4(rsv6>F-HfNy0pzUk*)e}i)E+s&bD*gWj3w1{5&o+ zVN~KQFR{p5rm*fN^y%b@#?=U`Sf$~C%nAS22A`#Y!QEj=x0StykOvMhu$i7QotpJ9 zZqqtEJp&E*-jt+|P2OlfewFo|5xx*A`Y?)Mjv2!B54tV(3#FHqy6V!od3F_>^OI8b zvxBjh9!5@XdgfK{4lPNFr|<*I0J(_NXeaYF0q~e(ah*DD0EzY*tHHipI@)H;D!6?d z6qHE=3QcwqCGj;4wFxh-`esXXU&5a{sh?ofGr2ZgFj|^Bhuuni*aW)QADupVWq9>D_|Dn#3rBUEEegl7=AN^;X-VG%4XsR%d(IQZx<6X-WYrnE!y4SzM)I{r z>e?YMIEUM${q51SEcvxY?irMMhc@3f-?Zm4MBvJpRFO9N(?LUb9c z0Cy0Wn7K2Zvi{k6^@IEW<;aM_GRHh9tHnT-8lU;Z?DRu zii$xHk;#!M=hng+3AIHD{hUYsJ3Xx2gnCu!(H5S2RJS%wdV6p7a2TDes2QR|hY0l+ z6dKh+Gn<~GY~1UU5sFi%q;c1qQKNlyf{0!)!A>xBP51rN%a2YDXZ2j-1I8NQOr^~h z)H7EK3S6>gOzW+rT5fvp<#HO9zQu?Tt7pwiHj=mxk{jU{FSz4@pbjL}JH9}-I{xOs zu11rg13VWT#^4MXzigMKdl!2qeL3Civ>tenTzZt&{S6}RgdD$v`d5uX5; zsehmTKlJ9a(ml+3T+|HZ(wwVLEi2%!33RB5-QV2+zmLm13x6`JRP4Fa8Q@V^|Kjfu zna^w`nYt1Nprd6{IB)Gk!0i@79bZ}}782?y;pGgu&5@YCKcCjY*;W(45Q{Cm9hp%` z#|$3=IDT}em2W3{4-cRA6fb(}p2V4XGvNX5t#@{(mtI&IMud!Pa@EsAb967+20TY= zf`2rrdPmmqH+Di0v>Gf<*|kzxkaVV)U)#j_M5LS%rvk2Vurq%cVGQ-7zMuO$CI9so zhL=4`)~B=bg*EhlJ#=+jAs1>MU&+fodfew@%Y3%|IYS%?6%maP3TWw7owD+3Vg z0mDw^x4~xGXngWd6&a4adrTXFFDU|g=!oiX2ge0~{+d1=FZWnfwRa-d{%)Owi8p2H zBO3m%k7<9w^NvxGkX!M_50rKHRh1I8T{87vL-^$&0qB(fo7M&u}Hc- z3hIfhb4qmPA>6P{y@;tA7WFW?>g+L1TLal81!+slc9!o+CEWffWbo?yYH=r)_nsKp zAGT4nMeJ?{d!hKS_00vC==E&0KyWhf?5HK4i6Nm1NPM{UNdK*`qTdrutPk zZWEuUd8T*-`lfM|W@eiZsyr=QF3_6H7~W(@L})@J-M|`Pi61GlDIixBksrB*{8`;T ztyy`Hq6_ZlCJq5jf$N?MAfR_jh~pW{Aq%!@JX;6RMk2pm1_z5Vd2m4|19BM z5cSH@=lsscOqXT$?t>G|~dlSycaFm@`=N=w$0|{;aXiE2VtVu>wDd+xJsT z6z_WI_-hy=iBryy`p16*`yV3D@aOvs$x<~#h_*xzcpx819^bep)!gvBRSCrdIeosA zfRtO3=Z`o9Ecq=A^wQiPHEuprR3$WnIGICoVigFYLJOA~r4sS=Zcf2VQF$*OJ>HI_ zFZKBUsoycDXFMV2GGv>xHAD-?QLzhDEy*xd6f;IHFVWsKE>=`8BW8j$FiaX9IbG_> zn+7t!!t|Fn!2Xo_Zxh?mRCS+Jc#jO!zAyAb^}tQ%mJB@az^z1JgHNt%0Lzd&@<9(= z*3RfuecWyg-dPzZ2_p-kWn*`Co|D<`h9AT>igwQ1uvfhk%c2qbE)uoxHT;ZA|GtU)`TFeuDQ0 zZG&?u-)yv5K2|LY7*wDZwkJ<+m}PvRNq@cdHguaO#TLb=Rwz(w4*lM>dxNE^uXjuC zdILKyxy=SWzQG<7Xq&8S4ITXuby+pTcOPXH)}^YqA9qaW!Mqh$rZf*cPVAO}8QGcW zk~}@7&@Y+Q`N78+^NT+9X!;>ym_I&G@-ypQ=HhQs{e3Smw#%o7>1v~}etKu9F2)iv zblyT@{fhX)+qpdiCVaXlcIp_0+uc1%7~WItW=^3sF7|*n`xfGi@^FqldWF`kZO11m z`TlZxJ9#a<$6oxquEeJh?~xtipxu%YZB8Mo-d6A!_X3D7pnvcXfwfTY_(n(Pxv&^w zY|o}(j41)luk!|0UnRbuVWU*4Sg%JX`1fIbGB&L}DPahxQ&KPIsIaH9@KflNEz-Rb z{Nj%$kb!Z~zHV=i6bR#(Yz-7}9k$xFF+aL)Qu$j?nb_Y(TjV|3dw+h^Vojml3IU9K z|E^vCr}URECqHk@nYLX+O@1C&Np~87Ej(F3E>R`|eMVO!-p3Wm@+p`!X7EVih4vcU zBK{O{`(v(Bq)D71P6~Ol(S1c-?%>naNB*Dn8elBm0!~Zf#W*Z{sD5=g%siSjwGlm? z{+_+vG*6_L#pp?hTde9tiO2|PyvwGTP-TTy@A&oaX8-`FWb`3mgHpYtl`A&_MyGep#uRwa@Dp-YUO8 zA<+L=qcFD>6ad5G41-mwr6Xv3+u4%YlKj?OM+XqjjoN_=bZi zYaUXpaekI#WkDMSY9B|RO`3SlkS%4x(}1I>0>-QPe0QU#d`&cBMJV(B%`;g4o^9m{ z37xRg@c}Pb1|$N}=PYeZ4*@Bxc036-vd$O+a?r6~;o%xj5tAq(W0S7nY=S^OIz}3h9vUPu*@SlxD zy*Kny{^v+^_0MCSm%tWDwu=0kqCL=z1#WQ{D`ZfjFVs}V#>`?|Hid9imrEqGg?L;% zyEs>JE;gp5)Pt!%E$zJ_Irf2yLVc2EdA0J}3Me(7D4C{Yy?U9SNs2sEjot%smelVz zjTcu)Zcf_#rakARSpkazYrsl!AFQM+d2)$;j^pSoY{Vjh_Ou4z8hMB5INoy|Zk&+B ztkSYDJlk6|A5&uCK`2S$IjvozX=4#&=&_iX+g&mkl)Llx3L|f!#VK9$v6;Fa*CE() zeLqHAvLZ$7r)FaXzj=TebFaEnycYW(#zvw($H7);oo(Q=Re6Ykny9Wyh z+d#cLCS`BR5*Lis6Kh2-ctkGX*9wTw;)RlnzujXi2>y|nn}%E_ngU197^BIwcMdzA zUYo>=3kyzrLjeV10L~V0oZcpU;rqC`A~j zEz-bx-(MGrEOvLuD7H}6m|6c^bxm2PyEuaPo!HlOtdB^4DbSO;?KX}gzBy$is;IB~ z4RHx}HNmj0uuiN+aEen1Ew}>-r8_AC ziWhNqaW>V9I9hgHimog-fBW(HfKco5&d3Fbe$H5Nys_5sf@;dy3lAhjo-EIbqqL;X zk4DnQW4uJu{a^fPO2q5}?7agISGEBufPRjq=tn(WJz#lO@I|)h9M~swYb<}`q$h;X zz2%|3vWEZRvEctz=MVs8D;v&c>^yF^VDU8&7P1y7LeK51EUMYAj5+-?rwaj5w_haC zvX?ec`t)Fsee&R;#1*58?m;ekU!eSIl;c9!?sK(-?K_1pjd!8Xtp=a${eSGeXIxX; z);AhNL}V*l0cip@dQ+q$U;_jUNS6|#0%GVjl+aWZr7A7bMQWr;4G)1179*>u(wF(uxxbidaF zIdh0Be`kjMHH=9sl}|~TN}8)uLutYDh|ZSK;qsPo^>y40=|#F=<|5Cm!qWdn?89 zr*L%8-l-e;VxfLJM}pgGN#ZZ;U&80UZ1-YGZX;-G#s1E(Act;F*`bwGm)0$b-Ksv> zuz6;7MfRd*S@AIGgA%M5d)c+@uOPKQEo-+hc4Ch3X1%~`%d=>oD89T-lPq213$)61 zP<^m3x3hI$a9KJKBD-39K-C2odGPd$$Og&T9z1l$UGgRdKI$*l-EH|%C&}$k%gLTeApS8r+i^eo-naTc_h&CYOZtzF z0mL1TswvgQgcZAO--96-8LDUf(czWdzuNVG^ge&SqrSg9@c)R>!Ts01U*+D!Kk88N zOwj(Q{(~3f(EV?pbO3wb2Yy|~-uPhp$Q--@_W$)e|F$CjXEtsBZf^eD=F!h}v9kL= z_vb4p^~bn>qyFL=`pbN7dH&05{pT7gZiM6CKo{Kh%)P#Xet&zg9@*(dI73j>2<70W z!&r0OGiOSR+1m08ozyuX3TjUSwqFcrU@4cyRF1W1SUO*v*Q@&dC!c+UNA6U_@D-6%qKU zcfXt=e@bS-lk@tl3u@ZK_!L+JMXSXT`5OkG?^S!{X9V9H@2R&rZeVwDeH+b!Nq4-& zm7jj_B0TM-mt>FSc|0=%aSgOt6fq$m3eCle*6(vZPT9 z{T8+LJFKxAoUSVhO$&H!qZY`T<*lN6jp6qLw=}26N;|hQiWp-AEw&F{Ura=PHNPU+_<@%wk2Q6c<4ta);LY1xSTp zk!Bk;S=>rGht&qk098=!-aPr+y918Me7flD`4w(*yt2r#?Z?;|b{18qHWN85E2n4E z1)6htG}C2zq-~_rmS9;zm6SC>39pBij%J8r=*KXRMzftgg{_vcO0k#R(SReyV<#_xRq<44>yi$l{+(>SyHEX?dg8la<`WKZ>)Q$E`cdbd`G48J* z$y=GCgE|6-UgN#?ZkiYQS%hl4*Ta&R4KZ!Sa8+-zpQ@(=hR8;n%(zXRnBnHadDDYU z`s)U$vScso?saUhT6&G;tAaw%9>j_F*EGyUhF7GpJStNHy7=pSNXs#Cc&GG(RF) z4~tv$)D6NNURtE^vWtu{%HcWhykzT><)7J@`U*mUt{`M9Zt}O?bGaCfs>V1L+u2N5 z<=!yPpC8_35B0?>+&bvVK0)a#)p#LnjGnisFo6l2+(b&plaVMy312}1^GaK{;eeq9 z(e4>?Xzb^2A4S-*5;Jn0bwv73g$72eg%>R>KH&g?D6FF9ym%U%5LG3bFr^!sI=SI< z4be)=?i+x#w{{z3`>UVwq2fld>}Q`W)XWh^s*n(o^{h-+k!)*sL5ct}9dm*F`}Sv+ zk{Mn^+eXZoD>qW#q3gEUL6>n&T{stE)eRmtPnL2zH@K#kZ>eM(0usKb_EDH*!k6q(wS#hPAQ+GZ5nxm%+R97}xQ z=kC<6zJh2qXDg^L{d~H@B2bc)4Q}=|>FY;3Valw!;`wURGg{v4I)#il9f^wNbrR>` zVo~uch_XSlU9(GY94KZLXJ!rca;lrRCJnfv>zBq54i_&vTpxW)lt}>~Lcp6`oMZ5L zc<)5*)4|&n-`jdJm$!Jx6CzjWXJaN8Tx27(n=8bMGTD5ANK91?6HggC65c+<#JN$I zT;~)X_p((gTcERpSM9eE#`-113Nd2FFpMz1b<%#pUlOr|4W5qE!s1>zg{9~&?rPQ} zx9kM-O&9aJ&mPkH*|q;YmE=fS3-$u;xu@oM1jrU~&}xIZTNrSiL@|i*-QQJ5=`Xp# znfgWu74_mP2($|wv>N|gJLyz#T0_myeXm2oApnT*eXqBdr_kie1Bj7 zJm3ZIf_YkC<;djD>96aLGV3y^il!A8rwn%(^XVnGcSEny^4briM6S7OhJ+m~92#r( z(@r~%UqKyDuAW8MZscWub590^PtALC@5IQ~3v)`Vxsxf50^xFR{;ms74{cZM?m08b zOBBGkszo+af?5wd_i={IGt2sTjz9|wBJK+l6D4lM7_(@S9)=idU@5w+m@)|F8y%j! zP>nr_dYn&1TXqAw3Fxhlkd5@oy*Cgg0A$buEd2;U;Njz8V$?zq=?E&6WFkR9Eb zM2Vs-E0Oli=9m+95syU27IqafTA?L@W8Irg z`qO;AD(E8+15td3h^TUKo=tI5yhBOEjC)ae`A3<1yF%~nm=m;PU~3jx_61|DeL{~m z9U!+-83HQ@LmXm^J#}HwW{aS_4IuVP*0=Kx#%ay3a>nrRIjIo80CZ?z)`mgSv0%QA zdy!TpEwqBwc(Z1k90FWt$b(>5Zi~-?`lmnc^oiF1qGXwipR_X?s^Vl)@f;W|w~cLo z{=)(PcbFB``lVg=*tN%9F>y12Q92{m*}tErdQm$aD;LpCjN=W1YNgWc9w^=WrARL)5d0Ou=avj-CTcoS2j{El8sfVf8B7K^#AMM|Gkb|dvO z{-&kAUWQbCS|~BZOPwTI3$>hJ$(!o~A6v5X*LA0xLVuDy4VG2&RincV*;rz9WsFnk zdh>_Zd7Z>U5DlL$&X1H#9s&n6UUK0?Pb-Lz!NYeZIu>2bU92WVG0p2b=zyfMP))h^PwtD5D6K#GOh7wtRt$1Pn2}&t^DJ zrQ-EzZxC5VROkq>DWj5CEAV-GqFl9r`9NJ|#+yaVd<}6Ug3oF@*Uv73Gb+SgQ&k-- z+8!n5)7iIGc1Dn}o_sO!3|LaBoQi|HaNx;VG4eja?=yFpw%~ZA$rT(uZ+jA|ImU21 z?qoptE8fO0ipn-p^DOZ}p1V`+Z^QgQ%LmQOiaGnYF9)#}=oS-UY31fS&=a~B zZy_KLmorXB#e7`s{}AJ>o$blH!~S&8X?FW~+>l7A@k^`DQgpPrDxtYg(<1y8&zV-2 z(s|aDME+aO=KTisUAY)E&B+V<%^V7?-;?{`ssW&Va&-A=@@?#)##DACkovqEk$)qg zNZhwSmi}0ZU|{0*2AAM7@KoI#5|ji4Xa(pS9u1PS*xD9&>U8P#Q}-B@mY*kI9ofhG z?`#R$y<)si!)(>Vpqy!)u#KqjNMp+?^=TRT=N}g7q*?s^g6+5=BJ=an*v7+y|&NP}wSAtW$>eV?y&?@dknOj+wSG-utb7 z`__&Uf9%@qoNrzPQ;70m)rtbE6orQqLXBiW`uETNT9V?zJ6S1*oV({3C;Yzxdx2F+FE`cES5-S*#6IY~QSqa~Lk!st0wNY0zb}wTqSyGF?ioj?1Buk{Eey zQv40gz^ol=xA_W7r{8qgPtuKu%-1(bZ+dr?uJRl8TW*g>jda9lm)6%bF%t@AkMOC? z2}}3@(k%Uhv764W(5&g5VkL(0*bY6T_YL&QPPD4sy1M=aa~!SRn&ega=5#sCG;~qe zRZXv}njv0Sl){~JQZf}^ey-UO028W+=hX|#jGxr6lujzVcQO|sa7G6}%k8I2t&xx& zO5dL4*IgB`t?=a282*swTMy_zZA4r53#NEiZga{Vt=#wk_zw!m#O|Gc;0a(s%InR( z-MK(LoW56y!w(FzF^r;i!zfu4l=W%omo&He({}>^_dJF@-QCKzUany_zR?dJ5!(GV zHXeX=mTzzHlniQS?q>3pmp1|#l_ciHx$9Ne(W7<0XZ3AOl)*e$)eL*?9=0(n?w3=C zNjaH+Qn2qYQA*0>L>?iGB9(5Ls|9=Qy+?MFKWZ}Rh-atX0Sje;Eb$+D(PV0(gx{vb z+EGTAtQ?eq*tVfzE2RucbZzBV5WULWXKM0%Hu(WjFK5!ZUf>!L)NFPJe5|QjB&)su znR`5+$N)9rNzOQf-k6GXdjZPRk;8_!+kZ3a+Cwo&dVda1* zj+3(#!Mf~aZtfOlfHnziw^>T4sdFL~?%eLN{ke?xDG7JK?g!9A54}EjY+vF}{gMf% zU*P200juSK_#UwhTopAiOLIs_B-b;Zv^zmcdj^OGA=>_lVN3ql({3_vd2iDsf{sGb(fhvCHy{W zb4Zg=d^ObJXegCzbU{^OIMF}N*2NI3s1tCGQC>4Pk?uW`vX1ZUg!8eG!}Rr`IO1l@MA&h2AP=z>AG^{kP!3RC=uIcRV8_WaYC+gw^g~) zd4^nf47sk^Tlm!VvTpcr_I&c-g}pD9TZNERo;H7rTy&ZB+sbaJbjW7tHMrTf{TKTk zl|U(tbZdLCg3#w)tBpD9^nhp9dM6&dQb(z3Wp9E}@-Dr3M$!gZHM%CuHwEfeVtHeO z1e5Ra)Pp(aM56pWj&%(dmzzoJw?%p)lBdC%e(9>xtsxMc`ta~vlc;0pD6t|=JyYeR zypjk`7iJ|krX=HSZpAM6S$$AM}RC-MA z=9swoUk>ECp*6wv7?R2)UT4FnSMHxKi^+e*7SXzSPSfuE6k)2h~SRfoTpWv7&j}UWXW5J^8EN&>54>|P}yy&J?GtA z*iG$0#}^H0+VKoM?-E!RW6-)3tLWC3&|I^D2Xa6FQK4?+s^zCg5;j&hR6|nL?BwxM zI!3GI3?4G#Lr$NbT*0q}H;TPU;|TbS!C@)tz**{klc{Xc4(<|VixFV zTP!24=$u2!Z0gsU%pt{uowGgPi-V53Nu5SIJ^N(pC%$Geq@MWmA*P% zWJ0xT`3b2#Dx0hGt-zLMS!NkSXyP#C)y|L02y;a0uwD3!2|vf zd7gKRS|ZJO2(Px*l&UIK^}KwIirX1YSP_PAgPkw5#YApc%TwLs4vnWx0Z}I0lrE(s z=7T&~Q^(%V9UVJcjwOq(6Ae6ryIE@X&|EAAZ5@+I$DrMdvOmo=|S}G4`eQ)2S^b)1>p+8V8#P_C6 z;D9~^BEvBd2Qx$p*cYQ){|Hm?7qetrKxFJlFg$HQzC?_`MlqXn%SJ| zFug1hC2>bOEGdyOy$)>>GEQ(0;f^KW#cYPm{pk|~ zxEHXbg?t4iYR*evC}4e`(l~*hMcKd(FOw6(e@RzHjiE>bA9>>Vb_w_0xDhF4l1@hY0y{DXRfcH zRxiLLxoJj}Goj~dIkabv{t9|{h(vLj-Q);)>BLla(Ol3#Zrmzr;w8iHyxeo^=jY9_ znDuxI8#jmTWBrfa1`V#EFdkFbS-#DLnV(OJ=2 z3cRAV;+93!sJlxj`MT~!*`svqY{3HEY~rxR zezmlPwTU)pZ$m4xPnQ7hkWw@l;ZZU+lxDUFo62TA|H&BC^j}(Y%Bcx@^&g>^t20`&hwo`hhpgDxIVW@?PsQtV7e^AP~jp09$xW{qo7& zzPBBp2TS2jEhy~r@3U`qR!v1p`NFWqi{%ciDbt4W8kQ-gC>`Je=<8Ikvh~Hmf4H>b zo<3smm>d^PZosP`+yF>@Q^F(6s581Rc_lOmnT5E2awJ;Xs_gRPJ)T1^q7ImAjL!_fVR1J5{5rUwJCI z!A)Ejtq)sWQjjy-=2~Y-lNK*TZzEaYGX3xok4-11(~rq%Q2G73AB1aFGsKZ1o4_%& z#okpU{XV5;_EhFo6~~-H2um9&Wa<6J<{KApo?Y>e8EA5UjZ8=fnx2C73zmmiSZq%m zt63MQr-MD1OiKz;t`nr`u7JHAbR5;Is%f(T3SruiaK=D3~kFL_my)y0@T@3}n`i z0?T1sn#-hbC$PcV#>-8&eEr#*Zb7Qe}@CS`91UWJwG%RNk3{TM451Dj>oH+`)t}W%t~BY<=&K;l1NHatn^a9Qo0;DIGbX=l~U&m7t`zj&1l1#s?uwA zldRL}sV`z=l#eAC zRVmFuVhp6H(jpXEAle=cIV*BPFi!l|P!e$BJwH6x^*$7O<<4A}Yv7q|rEU~g0`8Nui%_im^QC-X_2*J z4K6e+v$bSrQf99vgF{}8Rx^R5IAPGTkw{zc6%=%geRQ`F!hGe<;7Xi4^QFlVE}W#S zy{Abj)%l#>N1-iwkd*dxRiEUeYkYP34RCOI^ujfW>G9c!9XTL^Z)+IJA%hMSiZkao zrBn_`ElDl2u*C!x7j#hMMC`EAaLQ@OaK)2R-6TJ}bv|^N&6~T$q@}#lD;DqqN0zZa z>oO3k!Xc@6?>bK4n5?s<+@toy9(h!H`6D6Y_D;M!jm7XK++FBG%IvyV$?Skb zXJW5Ar^m#KyK}J(Tyht&joys|32#~}=JER*+z2z} zco`5eYoneyo>fF}4xL$y?f+u)`Rz`UXNrp$BrjutCXMY1J&+WMr0x%EaMM-QK%Y>1 zs+RixOUqo|#m411K1PkC34xW_%=(sK7P2tB(Dy_bbu+=AXBF#Orp-bhZJg23d!MEm zbHBPz8#yJCyH^iFTllJuAHo58A; zD=8kSc~R}@H|}aYRl?XgFNodS8c-Q>@{AAXk+!!D&l=*W`wF5Qw0Zq|*_B-4$(X5W zQ3M^PNI?yQ6btW`^>y@^zzE~}uLH?t?kQgi%vw*`%tul2ufF6~Cd z&a<7cTdajEEAPy1xe)xZHBO4KXeA$5%Lx>&Y=BagVbZ?eJ`}B?_L(9x?=kExY+_D_ zD6rmw7|??bByE*aXdg(HO-UkqEGn=Kb^nrJaUACcdKoLs3x1H`^hhb2jgRX17>`Z5 z{-Vu*g1Qw%#lK``5`Jj4xQ@MzF=Lr270~mF1A5*%K3nVwWZTfipF1BwX-JWeq+tEf zy|0|ysBa8!nmRq;H!pR`Ig0T)%)^Nop7XA18(KQI-Bd%uC%O+xzOslb+*!P zn#nsrWT5nC-L+sWvTVk~=!ywAGi{wEIesUIeKKJ>*sjE&6j%oGG+GV&ddccbY6>tw z6ufW~C=gF`3nu;tJ+giZbl7fE=Hge-XxbJvARD)=JOjknlTMkJkbCG?hm8 zwURdCm|CN<_etHnSk`{6oq_FGdwfbDawkU4reDST_c}|og36wbtn>QoOy`_<>tR$W zKTjo(?3ZL7K1{lSPl+Ie!l{61CT#?U&_uvna`mn+mY}f<(ATUvJ!^ipBfid4q>1&I zwU{>c9u-`Ch8m9zTy%KY;|XX^^1a+gWU5^0W<aPO>^4tXFJ*453 z7X0*x1+y|D&|Stl8l@J#v9d_6lONP1%9)m*sAeC%Q@o(FAZR~S4_ud23nChY1KIPU zI>nQv$}0QaPnF$<(~QpH_Sa;+gX!^?h@rnRYvjiNV)6jxRvc9E zn_yOhaK^y-aH;kJajS_kOkGORMq`^#my}z?2i(%T@!@2P{7~~Iz`6O%PO#=SXPoQKkm5Efb`L8s14uo(ppO1i&4p> z`%f8`Nj$6vhULIIi51XqcMk?=pCGS}ai`V%g6wFz>E-R2y&$$M@wbL%BN;8T5!;V) zPZa9_cMQk2`I4m8}m5FQoSicR!$fFV#9p9(% z%%fz5i3L#*HNY){a}4YPsPx0oQz`pyGX1+ z_kZ^L%&$`CY0X~*Qt#9E4k{&M>yD)GP)jjWxT0=P%yWk6rM3-UyofNmpBGSjr4pL6&i3krAn_OxQl-Bd`Xz1oXc{+%uctkGWMg=H^8MlqthBG@O+ zEFpIGpvwoG>~jR7Iqd+%UvnPBU-$8w(pV00un=5DJ=4?azJea3;4BEl`+)`Q-UF8o zQz=76_KL*O?6i5QNzGXijcFy$Ybvehn zi_85QE{k;W#$4=0M0ad+QKu80p;^d;Q#ukHW;arX2eA*eNWOE3sglE~r|ZOx5<=|qbn;-f?yjdUeQ!zDRHE8JR@)rB zJvMQcV68N}^~bhj@n~hL3i_Z($ZHSJav6-u!hdAUK;jvxwEG zCvmf(StjRO_URL?PKbx!@@}tH19_WmrLT9+`F%@=O`rpkbyv($>K1YXnVf?yzvZYV z&;ReE|1S+w5XY;kYQ*ZqHRqE==d{VGSaB(&L6bEO8Y@=hT8uqp=Br`FqH4mE)qs_2 z^mUvcvM3Io5z&#j%H~xNdo5MEG60fQHmki+a}}MkDQVyf&IrZJrgVhfcjZu)I9BAl z>G4^pDGgn^=W4qPrLqF7}J&xX`k4CEX|`o z9r+HqTwlW}I~7_OH7v6?ez=Ym9zGm@aGqS+lA}HIC@g^kr%JP~{`O#D|C0o{?=QBz zL$%qWs{S5IxrIF!oJ7iC1fNCPFn2yUd(9;EFUW6zKb#^-6Wok79v}i}NobqW_6eBN zD$_D`dS1+t>u;1ZZ`<4uE6H{5Re3f@sqnlbZpN|CVoj!Lf&)_qc+Yi9E5>GjdHq&~ zHFj*v_k5uKUE(tU%f=AUfK)|Y@&4YEYlVdW4K&Qv$Oy>Orb|KeJT>a>9pP;=l;ZLR zF_xN~)mqlEFu;f!GAEHuzny(G(sz#IZPvY>Et+n@ch=LfIm&AxjhHQ2pEA2%72)~G zty8NFX)GmZOO3LmJi)!-1R+ECj{^47E=h64z#Us%mlUTA-xn^ky(&y)+zrF#fg8~?D5X%hvECRj_lWmyvvgA_~+GU5a_aD z{0-Q|NtJHaeRfN{iQqr4zT-byyt71}Z+{4D_qh^$C$96q9Gm}aapL(uA7QcTKWBdK z|BI3Tbto%)UPu`(=2)Z_dDU2!TqEig@>vEOyB>jROEx=M)eE-yM)F~4;R~lZS!fig zuParm+myEO;fxb8gZ33I7!putFp}DMF%O8Zpu$d zZ4fom`k()FHvEl=lDS@D!{l23*dK@`xJe= zfH3HwKG&wrzg^;=gM)}8k90p*n6EV+$SV!e_mXc*>P$*>&4PzvP(qOyk6Rb+++QRp zcG}f+J_|es+;0UaH&ypGcc#Rx!$L~@%Zj{z^2i8*LN8RA2)p>i3bJ(jhm2=_;@l*> zYXW}vuD`)Ks%w{@J&1M$*u%l zz5spmkbWFjz?M~aIX{Bs#CTHNw$_+Nlc)`%!TU4?;AdZIsFs%%Y8|y`CwKP z{gj~C@TmT2N%E|43H6kaZ`q~mI9lNc>joeyr+4LUYMBTp@)V);=k#}&J;WhL9tT@A zNo#7uUbSihDv9CFc4CNokrNpqaRrDm992(bDpnB4CuS4Hwj>-QE5+Z55W!q253V;U zRo4;upswhKV7Y(skoF(NsZ9)ClLk-8H0YXuN7K}rpAN4+ToK%A+5{Zaru{`jiuKO!0*6zHP2PvEj-LqxQb7+J0`n~Pq zF|R?&X7BoC_xL+Cv99x4s;V0O_GP2J!y8M>r*4S9ge@|3ClzJ<}uZabGGe_+eU&Rcdh=SSz;#D>7HpM}?y9ek*X&uNvo4 zd|ipFXFX{I@b(3O56LcnSMRuM5@#*e_HS~VDC0T%ABtA;)9r9!o3q*9gjgX_9%Ff(^H7P_rh4s$c+lLlJ8^*^RPd-d3ZzWBh zk>^>hMx(ALsjNXH`a$3}(QxQC?chIdf4{_GXQT1RIB@cWMu7MS0b&oS*yx*%$DQgV zd$)ez6j)ShH`vSBxww38J<4!)Ys-5Yzh@OSp>IUbr`ea*_eV-EF(hT}8;%Plv>Gyc zDT&Y=l8vMQ0dR9DXl_VP`w=|zS7QhNvrtp4_xwIqwb{FeSZ%zaoGVL2+u+Ss9JXJ* za9EK41*j24Yhi~{73E1OaY>v#h(9_AL_Dff-sa||J&Y>pOr|{C;cUv-Kj-NP*>t4U zpvd=#{pQ`Gl1H-U%*2!-^0&J0(tdy`90iv3H+i>bf8aqhV&Y6M8Wx%zq5^cqxlGF8 zlu7jEW`$tM7oSO*&z0K&2cO#F<-+L2q;8)X1NA1Z2Ciz+Q8_K@OZ4%UI}OZy2S`#mDTusCGNGdM%GkZ>^=->*m8UD!0T!PlqiQ z`bm|}LyZ(>1Ubpw<9;g~(?R&nmPt$91fl<5R2^@oXT%~|pQ>!W7-i=rrxqy7R|)I) z3L6%_I-bn5J*fob7><9HE;{(A3f|Hz_s75`k?S`weJB9)&hu1ru<+*eAy@qbXO>(X zQwI;#^h89g%O!P;VoJJVN;cqk2H>XljXy?P^1rX+Uz(r2Ot}-G!@vZ7 zr{|rq5dG*&7l%UTu@0!=NYzAM|8BcS_@koVMyP*ex6Iy+51_FZVf+_Wk*ymgdQ^L<{dE9Bv2 zT*$K#wG-KNNQHIfp)I17y{uy_MLxtylR;L5yR7jll#8R*t)+boQ8dsWG3UK7rl$Z2 z9qNF!2b`MpyS?uHuZ$RFG^=;zBf2wYN9%>6LcpHw`7MPu0blHdPI{+JHQ9yp_0tQ~ z)PDuZQC6ip7M$QetYX$nph^D-mLO^S9?Ks%63>^k(YqtDsP>GbTC?CxZr0moo^cWF z80MgJLc6!?>*o_c)oy1msfY{Iv^M5$X!)0#(Q{+Tb+5TyJK*~_&YUWfw`FqN!xzK3 ztj0%&i_`7Ejy=F^g}fdT6_n?_5*=V_QDCwmvZV=Oz2^S7QBd*OmkVdEh9*0gbi{;{ z=Bv@|W_Q?vzg%)NJ|D;kuD@)RGCSZ1_zD{isK*?%^B*C|Yxf1_8-Lhe#hEiDbW(7J z?36>dEDFYh=Xs4@X{3Di{h>RWXuQC9qu+S)r?BjmUpk^U_O{-FTG2yfeLHYQI^8jw z>OOFZUtKd!w98LCb`p@5`*--;Cl{aGZ?|{s_J0LnLxhHlwER#Y9frapV{$#HUPie) zi)%toSY4Nk{1W!N40=bedlB~)qyaJ!(+jWSZq+)EWMJKNnIOuE&JEv=V4RZ*!L4|z zvYRsf{>#`bKU^#_JTBdA=Qjp#hvqMBC)6Q^zR1c7p4e(1FM7>m7dbF8;GDF7% z-A*^T$ZtB3w56-<$*5si)2f1<&Ds?nTS-t&M@z1o6SIpSEcFpO9B0^lN%-0iHd5J3 z*B&h1b5|yr#J@^iCF>^}u6a}pKa!lPbJBfwG!)%Tna_SH%#3>Avj>PWo++T}KB`|+ z9b%P)eOzgLsb^Z(XN04EDkXcFY?2bm6#QxM{2$S-rR~_5*{X++A5QWpC@U=iC?7%h zoyi*l8*T4V+jIojG`J>vNsh1I3|X(1Q}P%z`4pTOYkKdMp6;_nC69?>k&3Nfx=R+= z;p)Uca5i0ODTafczMFDH&2QoL?u$pFmwS(0b}r+okKW9@SZ;c;{NX1golf+=Yo+(; z`vkcr1E-)tMe+&u}q{6Yw2v)nywI=0P9WH zipMhI0$eGH*E6^=S93t*yF~%=kU^C=)j$x*4s3w3uvx3BG=Cmax*ot+4NH z{buyO<&u;GpOd-R`zY1=dH>wO(u)Z0UhAmlE4MX#Qp?02_aEo$-5Gc}d4P)1zYn6S z6~ozYa1F$9?SSw93_*6Z_ekqUeLi|s;Bsw~f&&8q@3;}p(JK)1vdgKLwL0Alnw3cF z$KXoo%VUjHKu*re3H%~B_3-|d{Li^6oD_IJ`DMoP1KqaZ_$+$fV&Iw1cu76rR7A6~ zy|~4`jdUTu3-?Oy%I;7OHX)Xp^7ZbmQoOy%^CMsy9&)@BlWgk zZFobT{_PxfEF?Vm0ock0YF1*bcOzEH?~-jma$jn*hFVn}W5A5LTekYR@QV{+X!`ID zaUk-$bhCzMg|pNqfuiBVQlwcgJd<9gTcRkn(`%@8hv6K{1$>uT;$>|vjSQnLb@gXf zw$2zU#o!x5tC1Tc%Na)+IW!W@OSx6;t@uzi0*flX*H8Dv<+X?!ah7C?tka#b3@PZ# zZuoo{@V1}0g6lQUG-aKjaE#9Nl6;2uEXdBxtq%?7IF7f*i_L$MUR-=FlR50Pw#Ug! z(_1RhE4$W)Rb4N~UZ=s=@Imi3aDIGUDvl{X>CuUt5N*3J=5sQZO!saVu5Q@a9MfJh z)R5S}z&H2K3c>WEnT?&Af``XiQnbv~KL30?gXbAd%xKBA;!ii$WC}9G_Et*cr_rT= z+h1S(?7BryhPCa-jWq72276%@c>5Se0{v;=dHeWm`*20^HEy0cGl}}u-P_tIw@WA; zL0&%Bv|{e0t$AyH>J>UsuyOvbQOv^18J>Y+qWZ`dZoOY5!$mkY&4TrEnl`^iNy)6A zy0W@jilCTk-ZPwZaPutX}w-T{sriBG7t83EyGo z+ejRnP|Bt~l+_irEZdErxL-BTE!J^?0j9N3w=iTa=6-6T+V=G*fn1}?sE`BYn8ug9 z+z=nh7H=G(sK8`ucoTv{8v9TCDPdd)x+9yihb9mCbVO5E4}K)9@D6{w?0O4=gjWC!tgm2W3+?@>-YIeGnBNJyE z8fIz|9_nZ43D4RQcw)jF`>JjvO)OK#KZ3M(KEtM&*jI8q`zegM?zKAtyC~|}Q4y(f zpV7jTK}g_0o`i7rT%pBr_6zLjyUp9o4`dtsJWi)h zkN~{)%|?cUe;t2bSAH~gw02tYag|x{tCQi%3MZN6fQrVOv)7k)W=|JDdy|7)!2 zp@yHk>4lO-v+O%+z2EdxV`!Tu=z#5*r^!nI7ob+G-m!Gz221B@(lpVbY&4Xl%TUz_ zD6aP6*Dupp4&?S5bQ8x)*FEM>=!_~I?P^)RIw0P@fHI03Qy42_V4%L(3uzen65{;+hALaOPCPVDB2D%Kq8D5`$O{Z zGK!$e1@bqX1eZ^%`y>wC zwO5;aCzQ^-JsVj6_<=*cs74-&iO0q8jLT-S&D)-fcjM}=Ix*u@*rMDYFx49!UuW@n zK~eR7`B-sQNP>pZme%mt5PWzKZqclZd~n61us&v_jOihtji4_KoW3Dg>=;e)?xWL3 zWw@$MFMWD!W~oW0sXZhG!$1j9G0iJmxOf%}WBgx1vu^sWlHr!#iW78ANERrfC&97d@k zr2QbYQOKb6ttfsydu|%zwTvp$jmV*tc5Sv@F=A+Vh}~j-c&X8a`=D-g{%*2!_Ab-z z1d^$c)!7@S8g{|cpks_>BdaYZGs!}=%K-{ExE082JuRL8zMF1zz_GK07QTB|j7DW7 z{gIpbb0q^e$013B@G&}KDZ``eM=m>X-GXDCJT=Qf=hOZ2oDoZthI$1>zlqr1SM_?&{*g(0`>Jj^oaT+BfGE+P4?s86NFR=Z|fnZ_)guk=W2<(Pv<_P)o%WcL_ zHy-uM)m`BKmagHVrb>Hn`WK$H<9pR+J+o82-Ol6l%{z6ve-sEhI3@0s_jxCaI{ImgT82#Dbt?rRZ7U$|y&b zamlZh<>9oC`NE`ftlD6rpe)^y)(fFNzkr5O(){AlND=lo-i{OzKX+Zl)#i7o67LJ<5U_bU*q1_fd;x?1wv+I99LR9T;Q>GB%g)XKYTu zR>kh7|ESUYe?d{_Vc=*K^OOb{OGJIFLRb(m^W>zwS1Dox-{ zFX(Dl`0PzVh7OdAeH7I$ocbbNVvRe!@>EFq$_9ipyPBbkN!0gSLY>Ftl$n#ApC@PU z7@r>F)uz+XwJO3~M4led;1e<)SBf@eY*aK!(1`62cOK&IZ`Lhb(#>)7zREOipdwK> zt>K=^z%W>nlg(tUk`IX)$IWVSa0ni+Jdwm1KD)u$nY9%kgTH?mYE~TQ90RMqDlbqU z-L5Jqc!oi6+|Xq;{81i0#Vhs7#H&xl^m*>80o;tT&_Fe|P9?N(Mt%+SF0LIOHD$f6 zG;fc6fbwE3hNgy>Qa*Zxe}RZME63Lub@$uGGigmJC2hYc zW2}WG!N;|4=5y8t9!JgBK(jct&1teP=u50Gl zfiCGYKerSMtWr=nR}L4w*coNcJC#~3+VRZ^>9mCcEtwn*1JbF! zys!+Q2eO!+p4Y&P37)YMc-D#k4W;ILhr^;sC1C|xnU$6G3Pi5&(KN?m=Cah10i zD^=h~K9Vsx5}nH-dI^F#I$em7MTk^gS91@LKieoJ*ucn)ms`^hA1X!cG({Kpzieh# z<<>r({IvQ8**$yhc8>{x$TnE=-)2ozpJ^}0Hlg~6G$l~}@Fds`btyGFL;Ra&+IhhZ zktyZo6L0FiJLees;V;DVn~?2RtSLjU6L+aPHF2c{9Ube|uVOBM$T~N&=Dq6!+OcJ^ zsBNaa(jvK1@q@EybZVJ}PSsFkB(T-pwGHBUtxGH%cF6*s;e0*Dyvqtu57ECJk^OMxhly1veeC1*_Y~i#>%Dj;+KSMTa-c-`ag=YW(h~XA@vyUE)0dU&CVZ`@W>HaMHDKDAnbL z$lR)d!%Cbr_2Er?PMujTSgG1mgiBOm$I=}Tk$FZyXyo>U{8%$o+hxWV=Z(JIHR{#Wo zux9O8o5d?_EZ3=m%GGH zD7%W)NJgd9U~kFU^~J$K<)t5MX=B=SJkBsvA|t`_9ksZ-*>P8k5SY=ATj#{faU;4{ zZ^eFuENTaEqgsbBEe+!bE@8&lMeD!7Sm&)9Us z;hWPf3E8o^7Cu4>-5w$|xi7*U?lcJD^fRx=<+{T7XDHE|x8pV<%8n`1KObCsB`dgC zX4%EaAp99*7dYf05o2w0)8u+T+ttPK)jP3ciZ$lZb;8lFj!w_c$>ca%bxoBnCD+a7 z>8y1~K>bIdGtYsHm;0YV^1I!`Z7B9ihfPGZ!xOYy_)3&P5%s`K)f5-z#8$k9{1AGp zF}n5Qm^n>NWr+I%GIBycBzQRz)4g=OEJqK%6Ftryj6|peW<-{qEgck>TgXYE$s3R3 z4zILd_n&OiRCVe`j#ei+smBkj9M)sVAN8)Y^vKTN7Mi!agy?5uQArPk(a1F{i5_!+ zmU^6#Nr(%93N-gJUen*=1QK9Ns_^Dm?!>0&-Ey*#r`rapCugzw-Q9^XA<6JLcsVeXhgIw0c#6ixh##jECQpj1zYUe z0zbU)kkD*6fgS1tRKSmCyH%O07P&h77HR<(uwAsZX+yJVrJ_+7QZGUDooGT&?;uRVqT z@%P&K-?UFNUfUVNNh{=`L6j=FygJ{kd~_#EH+pBeVra5CfF3`hJ$3WjLMBwxBz5RU z`aUV|_n@64`=b0{Ro2;kEUKCm2A7J9MTkp=X5GYakCkrHzdF(tEu$hSvs3tL&lDEv zC!W(#b)8>k1bfgwiFplmQtj>6LVi*$_WR~WJUIVbBHsZMTm>k3(Pc~U!*lBQG`{ck z^7Kf1Mc-n0qO25?TBh8$wPqm6Q(KnK4&@dum%X|DaaMn6PCfqp79Y~>q4mQ@zhD*B zxeWEUZYSN&jk-0n_G8ut{l@y;gch~K8GG$5o^J>U)A33GJ}LUrpcEWxohlR>AEYSV$I(sXLTa%GJgVT zbWoUDd@FFWHz$u2u@cqgj!}FU$i8HCWc|m0cLlXYE{W}U@hRbD^X;_d+FG87*^N_| zJd+nXHkZ{Jk4D2}9F5!6lme|Db>Gc>3!in?L8+3*)<$Ylw*z#rb7LXM`TW~EiZxj> zq=EI<&sQ}^#8Xei{49a;bz4xA43sY2XqEibhvHhg01#Rm?510%N0%+Y07jPUHme(n zftnM2HAx=E5H;?O8f)o+q|ika&m!O5w#DqOlevQ|6WapDKSUNdDaTJgRy^<9PdtI-xJJMsJM zVz58EJVx;c4&6|10)U*ZDT}X{SNX+?INP;sUlG_Kt~LnR_p=o{N1lf~K1O!m;hMhG z9v!p7*L`Xez1A=uw(jvMiCxv%IDwn@v6GA8c%a5z6BK8jPyGX5Zu<0R8(S~@wVNTU za|_M}^^eFJ&ND7py=jKr2}830c?H=Osk)w`>T+G-B=Y$idCw5*@bb?Sc-XxOO)$&yQyyQh0^Hoo>QT#zvQ! zSh8uM|MMI9o*nW=ruFeD6LYDt$T;)t=|{4h%K`z(4xmndzo^2YU1I>{dammlLgJ>m zMwqYFqEqt`wBzEyHdlaM7eq;Ryf$tF&TAyKL_QXS&2Vp@3^1txoRD&kn8pot+UxdG ze&cj;BQ9BNy%6PIUj?KtK3lgkn3h|`xJxefb=-{JK++f~WLSbWZnI0ZP2y39PbWAv zV<4d|Dodo|9E^;#qBm>u-t6|DL1}HpkRbZjogMkO0HpD(?UGvn>)SMsV>~fz2)$4v zbk2pdr6*=?HlJ)s-&Mn`2dl3(J*PEwvyOPME1cBxB(LVzF(QB@YNTGo&Th&H}jhFofyIvT+I7e zi`B`&)7?KuSNa&8>>TD$kRL!hS&v6h=h011NC{$bSBY!=jpO{p8>AwLBt8$Gvt$p= zfl3S=w!Qh{eIAAJVzPhd>`aMe{n6?QZcpM9bkLhsHItcGmYgQMM9)MxvlMM>hjJp&(ZDCEd94%ig-M3(iOZ--fu;IDVA{kO87x7e;a zb9sii&(%bc=Vu#Abz3}^Xl^(p(Ha5POR11mxovgpkuI{t?Y{V_QKJ(k2r2EVchScx z-3w7q9b_%dP6S1rfG!j6y61;Ek+UP;Pgv@5L3zgFvE@3w%3NJ+yi9qHue-vEk=@d4 zc&RHquj?E%9ZS#!n>L$g%Y+urJ05dOJCp8t7g9lu$7=a_K$PEj&Rbe~a};W34;I4G z7b#E{st*L9a-o_v35ic4r-qzXg({u>W$P4?=<%+=7cbQcd80eqPcC{t zS?KMHPqG9TIcMvRY1X`^dy^pXgme4PpSmd-22S#-B1Hb9=ck4N5C`n9Kk zV;=Q0!*_x!@Oww!=mSuQ_-ODJckou+rn`1 z)#@`ftPV3i1r2rqyK29^dQ%>?50mkZyoutbmbn@Fw9Knga;SeU&K>**ZqSyuI*%!sgqE%H!NKT zE>r=Dw=A8_&fdO#tV!aHpo5Zc^fjA!%ggxLGy!{Yk?_TdH$X;V)h-*+e%Y~8+}9;c z2_~)`r1rMM$>^C)p>-Bap{d2oGlkEoX8~5-U6Gn}H}2Bi>ekwUWCcPE7$T7+>uzH1 z+U8BsLK{R2BaD6)*NzLt@RFZJecIs~r%N7&6abL2J)vHuHtQKkKbBrU`cwR}lEr4G z57-p&YT*311#yR$9hMm*8=ZPt!iIR14#E%@05j|KDn%Mwa)l+<^O^Y=$xHYXxaYW| z->$NZ_2%_Y_#YL-n;Dq_)~)gJu&$ONQK`@#D~)sy&D3Dr!+1O?>~%-xsJIk-?tR7e zGh0gC(jA+Dc@RQsd~A@OexF`Je}$#I@%)LTZS@ARE46}1Tm+uhL83?xrP}ySrxnO zC$V6-+Y3#UC7Mp`+z@-&k)wT7jYWR4?(BMXYAjJr1j{q$OwCEpc?HDM@4O!p9K59n z0c(g9QqY!8dDrA#rKLMkVC;G!QMNmlA_3;kp=I>MPmuwP2!YE^+Q!BOR&@?7xlqj` zRi#PYA!1_jGen72s$t|zG*PMd=YXRc9zFLk@&nM9d}L>lhv27eDa6ylYH}F1rCjc! z&?~oYPdP4Of5>6+uOT#pZW0pZW?&oYN!z9uU~h9&htYW40%H!7eqSgk@BH_3 ze`L*TXU?xkjCP=t<+S9F^n$ssSTsL}+J@IHQq+aIf};7HMut*LGQ7ms@`@fGD`!c= z!?RRRw%!YhM1u(xqXn%)(ql(7$2i0to@#n~3zD2>u`GuGS9AU7`{p~}>-n$Pp&4#U z78*5TA?`&=nDoy2Q?*h0MG`aP9Yb=TIHjm^O6S0tTz=Y9%5T2|%Iwk2+^KOCwO7yi;IDRFGvVWlJygb`N=7)QXq;e&{ftV4am166y__`&K$lhgF!01|4g5EG#>J zBZ?-#@lBeXEa;g|Eh3OeU*5`&l~xPiQ~lxhnkZm1p5QDwy#7(uGHJQ^6(Wd zak^9gsKbx(o_Tsm@?Pu7jBC?}tT}Fxa3Y$={rE~ZG?4jcJ9<;rlC(6bjp)m!BY%H2 z;|s9H$&LFsZameazAIUW!o_uJT>!PUAqUt-DvWQwGr}?Cli`!|j#)gSxU==Sw+bMk zKL?=jB0cZv%DxxJ%M=BBS4c$!fX;EAZ#jd@_lK7Yi(RS`a*3e&`LRh7c&)^jMirxjyPn;UgL!yZ#GJJFXR9*myQP9TjCOGh?sERKUusonW5>5t+$yZ3hQ zVhY7R6+1gy^>&na-`)e}te5ItIvp*XN4cNn85_YLJ2YEzcL`7%FNrUcsg9l^U7U)I19#Dw$PEuvn=EFx6f(SA(2 zxXhgV@DAByZup6|c){6bW7He8U)(@1d18t$?)Eo6wcqjxLL!4he8OIWe){H^_%8py zeJ98Bg}G&0P21{ggq%ORo)MrWTMOC_`60V7Ib(%Cwx-QcPzW!iXmuhXv!i_4_&_DL z{L>yUpF=eOg9U^si1h#ML#q;C5b;c+tk6=@_7cWzA*IQh(z7-rQ_PdR(UdXnxlNLjm0(?q&XDwE8L zzsR)sI;k_1c@_40Tx3g!w*bKW)*|fu+@D{5R<(cf*ba8`Oar-A!z*58ajVvmo2AcX zsbP+`1BlhQ>_wFTGkR@aBEozxFRs36wH#+usaV)Q`Rxjd*v7r3-FAtx`*ezI)r}S zA1*vrok}8QgX5S!<^7X+a=s0rZ^K0kQxG0Jr{FWX7QTNjUB?=sCjOh z`-?5-{XRDx;L=`R5i`JHFEmwadj*3FN=cIo9F@vBQYu!-Bql6<8pOOXvOGH0pUW3C z(`x1#6Gd8hnL6@%gamIXbppr-RWz3mUkW9e*~sAZbij2bBKdRtU213d!X<#u53D@_ z662$!ty?%bxAm=uiva5U>EZ8L31jHIY*MCkL`+%sdtg_0tM_G0Fdnt6m#8D;r9w6>4l$0bvTQS`%h;O zkPD{Io|V%)X_553DfM0>y(nhU(dz-FSmMx=Q(q*BBRAI%)q0a88{?1Kz;JO#kZj+rSTxX6hPbfa(oLrlcK)S&WJ^NG zz&b-;vA5}?M~dKeL@C(aHvKTx6_syM@4WB*Q$IpdnNV)E4BwbZiy~&r{y=o+i@lQBk}t1 zPA73c@7J2wWpq*=W{_I5UlGbivQRx91ubLJ17EAAdH3PZWcb&P^H|hJP1}>S<3hA+ zTC@E#WFrckKCy|QCRV(gVh-0ga4TGfu$yM5uOqt;m)0@j*xTk8^*!;Hto>d;7<=_Aj9S@c)|vn;k7xwGCmhQP!j4 zqbqxM$fbFU_8*OTvYD6!>&)~JQkGC<*>Smu!oJ6INxezw2Fh9)*gPtdA?gkNMJz!% zKk}XhS7sOY#A5DRUH@}dx;Jaus7pdI{CczLUNUp7he%x%btQ1j2I!rW)`jaX&+a2x zL-m1J?{&U+0bt60fB*0O_`?uX4DZapk~FAJP&cH1drr~#4e3SsFRy}A`;8ZM8`ob9 z{O$n&9cD8b=kUA)U?rL=0$7P)Fd-jvxNU52MVbiyF_>`W5f*(8}RW@9q}OSoUH=~bM{C1g|<->n7nTz`h`*6Y+i0E0Dx1 zHgf3h7a+Nr^)E}(eR@R@Nab3gLr~q~h9O>-y*4gzXTfP+JI*h-kR$t7vWLkh&Z-+F z&ILPmn^Z)0T0Z7V+`x-+r8vDVL|A`;OUq^raYfj*V0&sBt-n2DZtBbI(#~vy1~87^ z@eLecjJzFjd9=_qWlB5Mh|?>8eLf;VTAiB%J3u^;yo}rmou{RnRll;9z?&S=t6Dpm zxb)@Uf&cyYfBWITc|0J+IIv^iM&-nTEwdVS3gSE*i*tZ+fDmzoEJTaE&h(jZQ@I@rkO;mAq+j$M+~S~-yE zrdrn>N|G7Y)oHcfE(VCM-u{E0LIM~CAke_E^9r=U%Jdj$gp2DAfN;RJ_b>4K*!eHp zKYYD)%#2BV*=B@SdtStQFu6ve8Pj;9P(ChDduB7jg3UG%|1N;~B!38d1VA%C5q{~3 zrRf&{^E=;5Iv8kaC%YmZo50@Q>q^Nlp3l>L-SYD3af<()W@%Lkg{yI4S_A+d<79rM zf7?*8&)eY(lbf5?P|eVE=i(C8SrxB1e>Evh&B8#;xS|#2zDQGjx@#Xo#|g*Z<-T6Q zV*~kG)j+DR?GhOMD`L|$ac z$@6E4vB7V_kGsNFl}T>7rRA9k#h>V-6@~WDspZ+A`xs|`gHHWhiVuy*cQUjqp5GHe z|JZiAd#ZRPMKR-q*!!{mI*qBGZ0SogYsT#lEPlbUK=i>CYzC=0W6sSMu^|Ht6~F8@ z2uj#S443OZ#q^{et_6XDs2G`AyUo)EIK4{9RdBYSHB;5EnbT4&?pSa~`Z%9awS1G< z2W6igwZgRpF9~tpYvHR30el@BF#q7|W}`wySGo*fGehx{HUcNz!ZMQ%DJ)G)LVaF% zJau|ABBA!NLo`P?uv2Tre{o|tF!Ak8cOLHo9yHc?+n!Uq9&ll{UuL-UX*fwhC!3?f zy0fafQff)HYmhJ{K7Sge(=v{eu`~-djx>+zheQ?gmswahUR*u?0xovJxWKGnTtV|= zu=Bmhp3*|Vyhvn0H!e4QjcM+aO@d7 zfbY@FZs!m$aKSx}KU2bHw{F?{MouHMyc59CFNRtd!I5ZIBp)b(`-h;CCkYZ;gA$LE|f-qK}VN6_}t8S=N@$?c8~ItQsLSZHur zost2-9fr)0*&zxi53;YnPUwHI4thc&>!ddpoDTRDWvdiY ztXl%EbFr4Eg@u7zas^=tFKdFldZRfjN)#&f1H{It19G1LQx&31nhRm@5?Jrba|j%f zZL&h@t}3&O-5W1SMWKzO&2r-m!15o945Uyl>v>tjxZRY?{crLXJc;!Zxp$(e_oQk@ zPoBU++>mIOJd5=@fh`6{-7$0h_U9%_vgyr2vqj}B4tCV)?1v$a2qJ5TUwWVuyRNO* z5Y`#++uNOgQYP0uzwK<^cAse9O-WvDiXV{H&}A-4_V0%D6r7rKEO@7^2CrLj&HKb; zS^msah3_a}aN`rIA{0RrA_F*Gi2Cz1zs^TFaBf5-)4q)Ad{o)Lb<+BV`6|c2AIJD# zWJp%co^ptCZ|vii2RQ5@Em-R$fWrQ)m(@5&xCVCII9gUQi}IjVsj zF|!uMrG8>+JHMh@`b-fqq{PQ%v(XEnRtMsSwPy?tbg2pa(0o+3<;5zfV+P;iU~aDG zToMMY9mJ{WVlKi6?y0wz$Jq@K+ia|QRv&tw7`$(L&o@WRND8yk+lgx!)3jCS63(uQ z)?V{C^NBaQM#@joOQ>O7JMI7tx)ODfybG{8l_osfPEa;)j=$fr>{#Lrb^HTC=oy$DdHLb_1 z=DuHSpt;;dwFR2&Qd+nx$yX^$h|(w(gAVuc|^SI;Ty^7{sDs zPIckV_Uw95v;1^l$!sSlOQKskpK*d|B@;5GY!OQAV z=KdMf{)BT?E zsk}|vBfd5KRAEe{?+WTll^+AXp&n_ zNjZT~xmr%P*mYC8;o<2ML@N0)z5QO4tz$~0pP2=q7lvtDUU>+VHp!pdiX${by2^{p zE!N%MoIdBPuKE^!C0G{ru0||B60CNt1u!dY^SgW%<_EP;3ksq958{&*HRcS4S%>!6 z)~`kLx48MW6_!-t9;cI%n8*2(+{9spA=ck{C{g=7lz?#<74~1{$>x$nE0q^K++N@K z$o+N@`SrDPPww-ck$ISGWE7SfURxYnIJPrO+MeinrY2I3o8$)Lk;@4TTfsv3HP^7& zj$*AO-OH>;fe?<%fXNw{J-r#fy)_;D&*cNsa;JR;u}%9hZHab-#ZU7hIY!b~dFsmP zjyNO8i=jFG-2qNnh)5^6!2`IHtoypmdTME?L~)ZypF-g4N9l)507M7%&!D2uAOQGh z!y!kghqxT>sFRfML=>#tiCOJ~Sy95Qx<*~=0M0m6E#val72Q%LZRpkXlwS>*bSBeX zpt5y1#8Dcip8o2QvSNRGgS@b`t;cot+T`**lG*+*B(n(ZC1=!f9A8Ak{IwWfPpm?- zsy;;f(On9A3o_Q^8vz|d4nRa4_!;E<8I<)Iv{)zJQkXN_aMhC+KIo}M19ur&Dsc))8pBf8~X976hm0cV&4@cAq zO;Thb!P}2hwx#^+c75a;KY|=Rdn#nkjF2|F-^{N)uC>iy1wS7u8h?mY3~+zNq5Vi) zO#W@1ILH-5Yg(AR(5f!}=u_(L$>3^OV+coRe++9Bu5vNz^a%AG625%`jGT~q35=pW zv=j3^w3BBME5J!AFXJR$nd~^L;Nac`dV!hKoMFRA$|ihkSD4}g&^a7 zQ=%}8v{0I#J~caWbvY~-FW`J#=MjrLok`6)x;*AF_8UfYC*5Biq>%3$8u;bdVL16} z-=e-g*nji=chO+~cVfC;y=~XW3V2+M-E&}8Pd+{^;IhA(B+Yj~LgEb=^eM%7G za7sr<@wztz*GXZMm??0zrF_6KioA3xqt^A@E5x5eT^4oFt`xZW6qT@M z;x!x)sqj^*gGatAlJA7J$nT~}HwX2IVrPk+TqWBMw}xUeh^FcLJI3F%>Oc2`W_Um2 zC}~vzrG**!45D>jiU@-H07*A#&U^eAbXQb0t>lmF!y9VSeM?Vv+^?D8GYIZ3TkWNB zy_YDF+x;#8F~?D`Gi$AQ{rFMr&C(FT55uCes*`hAzS7a&{Pk{cr`qW3gSaEzs#%=b zawZEsPkL?g3(uUZhE3^baFfGDaz)JACI?$P?azO$;$O;8+L!L1-5%rp9bw@Sp}!$B zJRX^mJUWDT<F)d=o*|pe`np{Nta$IHr?l#`-63dmJOdL+%1=wVSG;CeNd>IAW)jUzA#Y|r2vyT zC~v=+BiYf-m2Q~wvYi|jg^s!%3#(t7gyLF1UionNLcT}bboU@P-BHR)8KozRSA@?| zCQfH$=bQZhx!2@>q7#5q{~{-$l@l9y3VI=-HR0sevq$t~*ZX8wlftPYzP8pIxr90? zinI2u4gvj))w1|ecD#@lI`mAnm5OpUmw3X)!Vw1HR<)pSask1Qg7zOygoor9dWPWIEhj5n_;fn-sJ?GmXuWgC4>*UZO%X*d*|9R5xd z=><$8)@JC^0NbYc>x)UcTjb}$5X$R*3$3l<5~`k*G# zvgJ#OkDlGAS*6{7T?S?eylRr}Udk#I$Pzh|^J}j&myRRmIXjz0d#3 z3F{g9%rGDpeCWgY?W12}6AwP;{Rth9*tfs}h0;_FnQRsr*>OsS+Hy&p%&ME^V_wf1 z8zr4hk~mRnVHD);#+5_!Vs@@Abd6`(3+;(9BObXsQ8^?u^y?Iu0_IOLgHOz!QqqPo zd$8tYJeoQ2aJ1cgs3!yt0lPd3noCX~|FS~!VXvm&4+j38UD4@vz*vL_@tN(jkUj|` zg{i7*jSQ~pq5q`4H0)hwEwr!a&&3Za3Tb;F$ZG|gW(RZ_}=#@!3>!Q4W9>r4m!X%41vs>$Eg^$qZ8Rhhf|2R_4pOw6F-F4OPk)@JbXW= zMzR-UIaswiAHe^_M|W~_`ey7d5QSbf#Ir-!%Hdk@8RV7P`N&-SPx|nX@(em~`@o}m z_1uQXspV*3p>c1pYi)G(eExDpFI-@;dR$sQpGV@qa56m{^yTt^pIP(c=BH z+lXGaFaqD;X>-N~hdUqjBFByO(|1v4cRiRJ)!LRAN~^9KG^ABTi(bJ9)~GlIsFNTGus?j2pq z$bycImy*ibZ^yeUVxn9+PDIa%C_^AB!yE$dD1NI2t+8#N{JM_fCTj{bl<;0oBFL%v zZWp4i@@e_pCZ>qiVzXe6wY8L9{WW}mM-um>`@-NmA>Qqn@=V3q^NlGDP{X@j5+4sm%^{G%|tsf#_W3k!p- zBJSkXN$DS-*vHc9j*w%#(@3Z$WpTJzzfZMW8PS2$Y(N{~YOH$|1iOf@F7r4_q{SZ$t|aQqCa5P4xWO3t`?d+2p? z=>2rPQJpfDnNsGQ_EKK=Quz??sT|Fi(~C8W|_->VGa0tbWjPsKV~qX!wW8ch_zz6adrs7LEOZLh2PFcxXqG{30N- zdh`0VQ&|K9i8gQI70*=bl!bujtd~U)?G4L`5+rUsW>UpE^$?n&r$A~)AY{1~ZojF^ zHhI@B=>MTYz}hH2ULcXXYQtdj*{QFMmn!HHN1Iti^KN&*3C)2EeB}g3*M}i(`~OlQF00yXpFy*xM1VN(x`0nAGRB_SE8X}$+gb`Hc6-;R zsHyN?>{O%nt2?Rvzf9Me{bcDl2mcI;O_Dx?x$OYCkv<1COrJwF>qbO)c>-((&AKiM zT>O>~_3ocx&xQz{t*|>7c;QwVPeenku5U_lN7_cyPW^1dxF7iPmuK`CBacIz{e~eE z-uPS}T(S>lSpoLQWy!y7Rs1Spu=Aw+J&c!qTQS+G-}9(@cG9t_8&fTwkui3q_3YO! zdPPTJ5%djx-EL^-dyYZ#`EAttRpc}K(n`}bT-od}=bB_xPC4H&5uhL{(n!&-cTvc! zWU+^+qUqBb;=uE;49#}7`5tC~-^W!F)z_NKx*mG!VzZiTs z%Kjt4pM+yqB-GUiSs9PD0W)qy^QBefD`{?l7oE{pw9{~NK3LB=Aq7g-7_~M@O~sTI zBvO7TOM2%5K0_b>p!dEPmP1yszNa$?w&Y|<@sL)()43t8ZlHaHL#>W=Hj`SF#PXJ2 zQ1UZ~bj^A>1l#?pn8mjDUgbJritQLAG)Dp6iB`6Hj>^gD{)%go`+Ny zm0|C`{e0*0ip8-DK#4A2VhiX%Uw$Cno`MFx_Us{`lR=8#{&#T&0Raefbf;)K(@7p; z9swB(k_8?-+S>AlmfoJ4Gq2o7QP#m0CdX9HHyq$dmwOUDS+1eWr~s`;6~gOA2 zo*%BXzgKmM{(dNvlZIeMVGr0m#Ls-b!oi#N>0SrzjkaQ77wG+=@l$)_@c$F%%HK83gz*Ex!0k32v-FOMajWnjbzO$+FB2k?&dUO#F zc&HV@-2}dtAi}wIzt150bgBC3qmRw!%jB~>IC}2ojq@@Ltam?m!i<_F)1%u>OPrnS z>IB3l0jNuVdNr81iC+m#RNH1PU8ZYk-Ye9*;;d1}^7bKW0JI`eh2u%Lr(1?wtcR6w zNYJdM__$XV*gWZoKy79eY)3H))+k36mJFIKl1CJ-6BAxs7b&O37HER;(}r1(-Ee$- zP$hIc#yuW^t?J8uC`DmQ94w3!)qXDkE|EPt>C$Ixp*kn^9%Rzp9Nr>H(8N6nbH{5` z`Es_IP{!Z6sL6%AUZR2%^%F!d-nT)c`_M#>0PRPo--9l!CScvd^n!3FWuv&vUA>|Z zS)&5@lrvybvMyPzsOlsd$-XOrI zjiiLPZjAb4wN7MYzmF}G$JVquV#Mu(K-sdhk6Q{3XJXA_e!eIfdPZ;#Z0+e92^hPT zkt*cRAp6Cpg=hGONUYq*6l5`cb&V&y2o-bZKyqi#>G$K;Fb=rK1tuls<&`3y4Dun6 zwOJE&`qspm485BQ5HZ^m2}*jY9iuQ=A$hUs{;8&j6*s2H`BaLU@P$CVS+eP0d}P}XroJvWe08SZ8*;RsHWc{EtZ zLz#8BKd!eoVYH@esI+7`Ny+P`89~CsYXLT=kN;uK%`+;CATE;J9P*q7I; zScd?YouR!@y+{DOuCOiFc&ki+U@}889zTip5E^Q~!Jk<)mURnTC*r2@0hB7s+F0Ou z<5Q6&459x((aX&w%Gt-o!g+9LQ@5-Kke&ti>+*0g#JvN$MwBI%t4=IHw%LxUl!{y2 zVYq3&*K`6~A>c)_ckAm+(ZlmVw9~d3Y8kuN_(&KYX{ntY!8Cn&hGb&586+!g6ZN`?WWMCpcmH)l zY!SVT#Y+aQ)7>|7-#WODi`pPVllhCrW+Ms4@cuw>>x-Era~o-ndW8=uy|Ly?#snvE z3&5W=P=21dGp9Q>yIEEY614z77g-7%G}cBNkNHhmzJSlv9*0)3qQjJ2qQ_;u*vYy< zUc^U9V7EI$@%=Y_iEOxep!LDgvUX`>1w%SM8D};NwBInKTcBB?s zXSz}@>7*>cDw#ro+ux@h0<_=9Ok zqfkjOB%bT$+E`KIl2@tjt7Q!P0Zh&9Dw7UU6h70Y|9(0|H^9Fw_Tdy!jWZ|t!Mb~N zO{+|d7uFVRoS8aaPJlto34|icE~~}Sb|)W=Eq>mWVrv%bt?Ry0+%KlciY&Lv%*roc zxJZhl9hAJ?5IM+-SedKF-Nyzci>CdiNrvKlI_t}|%s-7Yd6 z2%|%RJ6Y;0eA#mLfa~kvf*KOLusj^NR8Hh?r=b%X#F&?-doyox9 zX}tBZw>o7I!aPoCat1@74MD>meANpBe#5J_85xuHGM!@1@7i5aOAzxdLYY$lXHM+o zfW-g{0fLfNJn# zP9YmZddvDw%!3S+asW*v+hVU^({8io?c8M63Z94bnh~4*;m}z4Ox@MqYQb4OoQ{8Q z$0+BVcF=H3+|cH2KmpiFcgJMEhfOgLhd>Os_?P-)`})wxVaIt(ztUs1c~J9N^_}dDyKCDb- zy?}3%uId2KiDaZPpdn`UN1{t!kAn=%FgIg~QtlR#df}rd4T1Y_b=5ekSX_R>Ib`z9 zQrztZ6B}2hpTfnQ4Gau`iJ{q`6kv)G=L)a^u66BB&GNmVw@U6eN`O)~7DDdZB%`eR z1*$OMr90g*S+7#3NZtiRd1&=q(`?$Dv&hnLz=PV9k~_@J@!hF84QVC@)f%f1Y22|{ zXHUdc*n*{=p9wP(3v^v>V^@H80g&%6yAbv~QhiULXfy zeNFS@M<|4}Vq8s6{o3$=7E%8;x&ju2)0ZW`*Gv$UM_nk^(9|z4>hV7|U1@%vVJx+o zRpgSe5sTswGuIho6VI}E3~Ns1T=hpgcQs7Yq8P7!7^G(@GESf1*%z!GGW_l}9CUYv z2zD?ti#9SV++CfW<4`tRTkDg~I6pD7q&o%a5H>LgCH0!Op^Bn6)VPXWM1MNu6+Kb5 zbtM0Vjn&Ue(55tJ?T3rGgAI4r;&06?apc3jpLao`EaYF6iz}C{yY8Kz4`oY&tL#QY zOdtGE&Bs^H7cf}?caOweK3SVgw5MDM22a=U3#=ac(io7aeBA zK_MunzMfqNxT)qnW*WfbO6GJ%##t3(ozrCrnF%{<5ib)!r#O{bH~Qk6rggIH61z^3o2^``4y+M6`j2XVpuyKKdnF& z(>DxUZ1e0xmga`_=tIrs9`PhrW}%x(2P8hkcorYN(N)6#^N zK3Rqpnw;JqEi2Nr;Odp{d%EmM)`Gu8G*)nLh5FBY1S9MSpFvH7enx7=Y$X>_eAb^q z$E$p#H-`=_XqcfF&ddzD{uYq*7fYsncxeiFw#sd(MQmk+tUi~13mgK~kP$sY4LJLr zhORXZ+)!ZUZyXA%7w5W8UR$u>bX>IDvB|<1FHIz^jda-~?)zGAgln`3z5C}~` zK&p~Z1ECk`y%+zu&)IvQv(LHbe&7F{bMLt0zhMl19alAmw+|A~)B@!us^oErXJ%A$7iSLuqMk&#jV z;}PhZw5mTMBj^2m!*Q_|JtC?JiD@hoClXxXPR78Un)LQ(W{bQg)Yl!!G!75HFWBXL zw>+{A*Z?lwoE*TQ+71$M%wF>Z*Kpsihh+`mrpxh!KG(`R)8(+;M~U2S`4z1cKZ&2Q zWChO{>wwlwN4=ags#{tvI?x%rvgjQp9@^z&5p~m#EQRz!$9KV^VBpYB&PT<|yyEc7 zO3%XDTn`dJ**yv>!ljLwTHN2O zzS0$aIM+xVCiI@*vD|=7U z2Qo|)T)@{2uD1^oY;MN9*y^!`gkdVFB6k*1M#H5fh}L{Vw=;TC z`)a4F4J1^=QFFiq;Nlp;OVbsP4WRk<&%5(Ksy*F!{|-(np5VhVe}iYtE!m@MS|>CPhV=q~|SAk_Z*gK-7L6d*pqEnt(U4R533wed)t5rhjn94h^sO(s&hgnEpkp#C0j>m^-=L6OtHS#v>tX_7AKAtuqFQuIC7x?L)8eC89fEE8t(=dK`xfp%E56GRPk^ZF|!l0Ts6v_X&OPEf1;u##xue zHZB1FMj`269_eqt+{w&8U4L|;cwy|BiXD9lv4?EVs8N7YY8El5igYNxW7H? z|FnFpD>Q%6o)3dirxJwgZmrd$0Q(T1{N=arVsIAK>Af*P&$ zSvYR$U-CtNi!#ASFAYI_dkE?|SB%62yFC4nkVjtckQ~16M8X|Jm!zSnrGSK=9n4jH zOHI9X-hJd7_kr^A07hEN_Ra99{=M@3V#K_rY^05iZr1Bi3yeJWYiPnP#D6Vl-v@k=*h5oU zDr5|$iAyF)MUswwx>TP2g%gGGQ>BSY*VG)8*xpl_9rfJ=g;o+hkQ=)yBa!pm@<|iY z!sFLu4J&~t<1aD8MZ#7Du_AtoPRtLQxtdGsCK}VUq|!r8ZLxtN3~N)-!&kw!i!1&K z{z+-93tnGRo)L@kNe0FW@1>clMVAMh zy=Zp67kZ6?hw|4OE70{F_bgbj!gd=b&h>omju!E`{%E9cr?f)Fw*`6tIycZLr0;x6)PK;x~f%1;b|h7H&yUg-U87%k*yv zQ~x`C`oG*~)mr&q35{{z++_bJ;eT0N!s>@tr>%aVVN18BToBIMM7u~JH_grzY`nCWdoqaH2imhal1(N#;*H?EtvPh8%feEm+2_|=@46`Z zzRHIoe8*s@pJhem!#9Zc!~sNXYm$ifAP(P1S)FNWa{BAn)TXcC+}u67BPQ{t{P5R>wNh$Gz+#UlZt}L>4-_dshy29&K&m8J5L= z!bwBzxalM9INl=ND?>7uD!u2;QZ}zKW=p+_fL=|-TM|Q&9%wFystgZsputOJ0lNILtLCP~K@ie)WX7ges*FEJK?)=Tz*@{En z^%$GuMT`EqwVGeQ|DZ`k)%KcY%xw5~8h)f}D?KVYlvYkVa;7@%4gC_5yMEBfS_J@r zeS2*iLq`$2&31>DdkBBEgQ@=`j=WWL2%=iUwIyW0D!s}4imVmU1?#R!FgEu?E$wMM zK4$p4_2hpu^Yw2hE&s5d%pZm>zQr8-oB1YwGqg(wje*H;KTP;;-9SI;fovY4<*LHL75J45a|Axp}TDhc7aPeL1ioDovo^Ord!$oMS^Y^g5y?k$_ znFv>J&42c!J9g}oeJy|e-0ZmVl(eiCnU$1cf96wQcm;G-vBWjHR&<0zO!%Dd4M&49(o%SVw!M6QQ^%Atd-8-`QXa3}fzeo2%$33nHhWiaNc;s`^d$M|O+v^~ud z*?|K~`jVJ?>-ZsZ;B?Mf^L!b#Mw?n(!S$bS{eS*2cizs}3}>{_#pz!3x}b0Tq=pie zFa#^ZU*Hn^R2Xb6*~58@#DOj~Dj6#1O6+DF3N?ztX%lo@CQ0Mv8*!qbQu4s?M~uoC zllxR<^?(?Her*Uk)r$!2vHcH`1_b$_gpk(?j@27bGcVXp7MAsFbLVvo0=+oj275;o zwn}@q0L5O6H9vc*q~iZkDnijLQ>1Hwjyp`+_i&yJR=`6|cAq^C_>&?2+mI8wW_!79 z>8|(jBIp+*5SL-i{XR~k4&>WM4zEo0vB+01v?KLqtoNI12gG(zoGn?KpG#40vE>S4 zEq50_3RfvnN(Fc@*jMR$Js8%_O3x6^yNhOgJ)mxm1$|zjmSOP~RRgHssmM7&wuMs+ zHyV~Y6D@h7`Qi^T$I28FyGW?d4>fgNmL~jnTznF}++iy#DEuij{hJf4=H%?MOd`Bv z7$aRF-M5=RV?3hqAaUrz?XzE{u<0(w8ZU6w{2jMtVW}N?(?{oq=31rjY3mQb5mru) z&qqRn{c%lz8Jn~HayQxHJ?sIIlaY~lp0j3Y$?a-xO|_(WEYB=4dnn2{-Z{!cAaa~& zOddUvPJ1>Rq@{y#dY1%W%2MnfkQyxAnoPr>fA6`Q5^r!X~tU#d=L;#a}UHrbjwTnyrfk+S+9 zAQY{X?FICAh$S>Ale%;G9d?94ot;C__N7p&ax}L@-|JS9vBB*KKb{8_`SIW%G`Yr) zk4?50UuV&O>6!((5j;J)P`)$5K~acLT4T#jYTn}5q)OIej%q6gKPPtX1%9FOFyPT> zvxVohVe)7tzeu+8ya~^#WTpdp0Sh*+=b)_3-bCVhsZoo@I>aS5&c&4Vsq(8qtx%jn zXZpmL6Y4P5zu+7bkNqk6bMMFW#y~Cn;u6_K4uoP-l3P!F=&_ zg)?*o{ZN>4rKEjZhOAbK;qXrAF-NjsY=ldjJ@Z?jqsFQw3ll+uwe9&^wk*@g(xfYq z<3rrm!axQW`4(l~lhm8y+tqk88zVctr_#c`RBl}<^(O5`KZa=pg)f)xm97Iop>-?1 zpfF6SX$C9}uQ2k)PgL8jqDEm&W%IQW9tU~6hk-w6G9{``CACtF(6)4aaA~_rAl>qN z_8U(msq@yR0ShEeaf-DPNH0lE=n-6xE1(IP^YzBMm4S3~^qVvt`O<`)DoxIAimL_X zb7^s_yTCz6YpeX0wbf%}>J#n12(bSB*>~)g1E>C6Xp=`U^NsxybDhkuF5o$4`S96o zF|QHj+dO}`Nd&sKBCf9wvT55bE$P|$`)x1ii!QtC=Ob!(F8wZ>$$gWGG8nX>IW2wb z_fz~E3#04P%$T50R8h$gyJ@Hpf+Ckj?ptt^H5+crX_Hbl>c5QqDV=6hQ9K>&^c&E!Y#JR31~- z>Ho`0Ek8q1F!qmWpv-mWRp73WZ`6s<5C?gN^I8F)Gb(hWYCKlWl#fI1pc>pOWv;qdwYgr+Y^0eUJMBO&ysls#Y#@OT6&ALVBeDRO*z*xqb68}C!g+6 zacE?Bcf#DfUq{+PGXLjE*Q`6}!25!HY4s8sA(}2)eZ>Y7U!U=|N+L+}mU=@syB<1) zJRgZYQ<7UfbEj}ND+PsvBxv)VQ)1X_aju;^_HIjY4m=R9si1)IYqUOhWI<5Sv76Ss zdXBJ#EW3v;k?spA)Sp{XJ@V-Pun=jrJ00FleVpkHnMj|k6rc=!?32l337J*!`%?9y z*s{A&*g0gu)t?f}h989xR@>@Nwz0&6W)SRwCTjWcCn=5f-%L$Z6XqxJj_aE&elL1T zR=p*{uR0!a@ejx?YaRHwsLLZOuNL*^9nZj$>5cV%&|Fkn$5*vwo-^06jB;^^ z>BaCoY|sMW$O%wC){Y%;Uw_vQ0?>lpY%fD@k4ph2Z4ZO=hMOW9oJ}>A^B!w%b!9tq z*r}5*kTSQg?mfkMg2wPC_Bbym@D2o-Bv~r;=dW!fy)dxMv9jluB-ms)4FGchq zCHCBS&Qblhf%@QYkIc}ox=%n27hY&vvFlMdOu|z&8<5ollI1E`q;)UaPf3-we>bFK z<$Z98$Q@gjk)RDD^Jn&S|=>G+Fw1^41+&A5~hhAf5}9cd~l$RSH#UPFQ~A z6OQ;5^%bRoeq3auRV}(#1cpGi3u?J3!v(DVz`zq@EC^_g!t2gD z!DUABhV=K)D@76^14rG3@8PPCzRR?L_+|i-yZIjKD8-M(zN|*YIYNq?^pvKN>f!38 z*hefjryGslVKnhA&VHALMl7s6=AdaoAybDIeSzCIsOTC7x)qj~fG;u0UEOd{qw4w^ zrpc_dzt?A7q0u^i&Exf5&n*?19BGd_3@4_kH&}aK;1G{`He9eL@}Yoj2=~>x(u(49 zd3s{hZEk-!-`it&h%sbNN}nSl&$WGBcY@idn~4L4e;B0;buja+QP6xlkl|NKWe9<6 zqhZl<+i7(p&Gn`N0a?E9b`JDPnK__K%!(Ve)xzr1>q0q5lg$``x8FGUXLn?}u=?r( zGr&c>rJq>r8>1$u>31Kz(I}JEkR|7mrfb-4jvJO()hoQaP-zCYzZ*?gIRrhBEw2EE zeQ#}hr9bj~)OYckHCy`&YC6$Nj1HE9VVdg4vf!*62iDcsioX{@r}a(C006A8N)sg1 zIC46{Sa&7kVH~gzWRTGA%_EY4Hr)lwM?>a~FWEJfZ^mw&X1{ymcL=8KcY5=-?~jpE zz0-vt`8fUfXi9!1l5J!SdZG&47BBVy(IJVFpsiTE59H$VQ1@_;@45^eaW;N?$~mop`6~kKq2C$4(3YUtn>7uZ*uG&L2#15(U`rlWZ<)1qi)PH0dPV~P0;ez8!>Z9jQqOPq*D|jexnRD&B=5Ml1=f6nl z)w;PCxJHLxEAr*SkOfBA*~C^eBU!Ma(Y$%V`2d`?M%-s=dqjW# z+k38c$kPv(q8+~IEoE;SDU85Y%|c4|dS(OqdqW$ZNLzw+eZ`|D6LwR_V-q^{VYEA+ zsW@8yPXhrUM~ODfM^rHzDY8W5c3q{^gk)ZFhwweJq+Q04pQP|EQ%!f+8_W+HdXIqa zY<7gQESVWBIR)3tGGINLcNwmIf!a0!JoB-DT_qY?(F;Qs(0HV9+44BJ#i#YqAzHQ5 zPl2@+x(}j{r8U~BO%eKDv$K9P27VHoXFaLzOr>6ww3HnGqUKDf0({FK6ZdO>5zA%II&jlzr}en*)pm)HwGi@ zl4xvZL@4bGbOr639C$ZR+4vU>`~v#=v?9#mX35zxY=hs%)Kp#dFlux^z1*2FFC|=g z&9rPLT935X`$2?DdQt3nmez<7#m}Xqhg_e;9jn23+8Po+G=&sNRDgycjxLKlYJ%Mx zyi2HJ8VM9}t1zn*s9Gy5vdK6M_dJ&p-SZi=JUprJe%-EF(6I;j>#E@I!JgPJI27}W zfZ&0qb}`LAeYIvnkdVU~iBEFI%vHhZB1RBfVWnmrVQ0-6sj-IUr`LYNYcg_YSQ&}2 z2j**q$?RL|Ee*j-1txS*TJ!D^nmemkz2zF7n&%aKbs75+z3Y%oS2dCTSriF@3Qr-)&HX zP$*>~q554RSz`tyLduZ^!?T2CO7a0KMcT=a7;pWl&;O(4)+KJvQpGFLvvN!`hFNzK z&G{`Pscr&%OC>bkO#Xu$bSriOR8$y92 zQiPWBlUj!q)OUih3LTrk-clxWIv>`m;O!F9jBxAl1;^|zlM9n6g%oDhY~SbFpJo8Q z7R|z~;tL7qD!@@scrw$g4jt=1t=M$Pz9JI5n4T^ZDzUpnzugV;GWz>HfTNdb84{aj za1o&gBY+p{4QfyB;si&bqSu1*PAV2TV5iPq4y4;;meqw7Mvk$oxJ-N_j#+<=p9SB| z-AcoX=c0YXt8De)nL_fbzM^?&f&S98l}ZCr=uT&#$(wrZgclUe&87XpgC=Fs{txK7 z?y6|;2ff@{$+W(xXb^HxUXrYf zKfNw5L3vsKt?o&;IWt?MkuZRdfL3#*eYg@Ia`Vf^(zP=^m3%SV+gvbRQ|@`boej=o zg+756@jUc-bNQR-Lr3rZGWlZP=vjpc$!J4j%tq1_yhY&Pe!T8WcpZO>)2+vwKWP48 zS!GFy8?~|WbenB)mk0?Z0Mm1Kc0z%XUFZ}kLHH6wkJ z`-`JgclrpY$!d|Y*{de7_+c=diyu}hyTcCBR1(P+~X0a=P~ z3Fa5WKaoD)kzRl|VPd)atT5vt)dzkO*NT0!Ey@>yf592*W}QeKqC_p$sVhN!g*Qe% z7SEiN8SK2}niK)&>ddiC$QYTH7mEkZQk8!_%EC&|N6=!aAd2p}4w+`>5!3-!V(du- z+^sn0F%`x1zC9ab+heP)o=F;U$WaT4a{M-gLv{|+4o2CV-=FHybxzXo^C|b46BZn| z3{;JlVQzk_(u7$r5228JnRZcXE4#g<6HB%coc=Dmx1aCc-;y z*96E?ksVGVrey57yqk$pe1x9#o+E_sDSz3PwS(p3vppGD)^A6p#NJzni1mc%lk+3Eal;h!9ItleKvi&V}5b>QAh^^9>8Uq}Q`Y@^M%< z9l1G58WbbOow}kW>7JfTN*Jc5KWjht<21{RrxMRyGqH+?8g<41ONR7k+q{-h>MPO7 zoJutrajqIBvqEFr)s?El#;^@X|9jkxyGqIFpEC#~29g zZ#wK&whnunZ9Oxx)J@Aq5wUkfZ};dLP-S$jg+X0aBp-kI+?vg$qU~B_BG;$QJRYC(e8Rwm5YxbLe+^L*?&}lO)GS57-xJ5BRELTetm8W?o-eE z5LCWSz^_)S5ZNAFi^U7qB&aN_AG9v)_i&unvW>6o5!yCOVwn(2%JY|HO(~!NG%mV} z;QZ{+8l9=hi^sWBTq$>PsDx^cQf6jBh|aKgVeob#YOHYp*4#;5nINcSVay7dzUdl8 zRBajf5Q8~J{f?s4^a#u`#E+#D#B(>GI}V1?mu4ZsrIJr08;_ZJbNOG&_My@eiuc|i zKp$9o>593{s$^C;R+jCQ#A-hm>Ntku{C(NuaxG&EiIwbx55Vt$4Q!XFNt8R{Mge^$ zO{t?3#1cEU{v5{kq+Cyv8?wB^Nd(jgxqS- zI@Tb~E>E;oeAUS@;&y?oLAP6w`E$rt`x8Is`9IlN z^_#=^$wj&2i--F6cigWhFGh5mdYKto7-cCnLOn5_0LeZzrUOC0bP~G~wqB?=Jw3MC zM>pb49$-$vNcXOVk@iJloflf_!ls%lT8EcCP!-tTlrH8=&?r`m!1BsQc}(vR0yhS# z_gu0}fX~UOG^(a2vOY*$NUgD&BJG;e+7}-To<>`&Q)yNkr4AI;k@yzsqDvGldOJVK zTa}}Pbx2m@K28SZ>Ul*Y_0<-yrD*=q<2ow z{p8RpegyV3xfjdI4r}i)+h)#I%yG2Un?jf#U!|rDkRLFN>a0P#-ILh(CSC-uP+76B zAb2kw`zYlsdL>h&cVW)=9ms(~8Enz`uwDyrxpx@d+qn^h`2u3Ljn2ucSI^{^1^1zS z($U5;PRE=6d;$MYup%)TI2Lh3lcpnG8=IYfCNxst{t-jmXt=!7^}~~C1#%Mn3(np| z#HK(`;z&ps`a5)N%mFX(*RK8-rLq6GFhH1cF?c}r8#ZLC-WHN0+A#^^&D!aiLqEEw%k%Z(S47`hqst(>5#GXdgK-^ZEM@$~x0~8DYX(iUY005+hWx z56}9Xw5n0%e}*#aCkmK`hI)}(U$gAWZRu*rj)enT^5w3>2HRD>%Z87O2m`$-J-cP8 zxbrHk9M0N>c+L7{PyWUgAGv_SfIlw#*V5lm9Ub@hFH%zTGNNBP&WI#k$(Nv*mk8G< z1UJ+#ZYz?4N6R-`jeY12Z^16VUR2XDFR4LC8+li@>X;A-#V8*cW*!DR?_15f2&QZN zE#<^W^-UA0Q;oEyfJJw+sn+#|IPbUv?kT;m+embikD18#J zeX}@8YDaRrVsI;?il6kX@murggiPxy8WLhc6~6X57S=7>$*4U^!cif%W=(sSR`7Mq zsCMLZL2rCuW43LScXpV{NQ6+vRL7zEGv{AzuIFt`Mp$E9zv!f<_)=yo0P8-N#tF6P1AwJqO-~Z-(`R`|sjqoA#FxI4O!Z~+x35~KZ4Jg)l}FQc?=$)^ z<$L+dqns;OF~m=0w^z>!H*G7u=Xzeyw{O9)Cpzk<@-Xo^3w9_-#Hi+BM z?arkvzACYZ4__uOmQXMjIbOJPYCmY=)SHx$m7C8BY4XZ^OYMsdpn6k0&CnESkFRyC zb;CpUeZgCnZ=7ULvX;>2%xNm-+IBt1R-WcG6|K1$GQLDGbbxcKyzF5=0l2#s>qao= z?-Xm?*H6CB{_)JIevltLZOzsdQ|YOVTn4OA9C65F$|uaJ-e?Jk4M9H}<^|g>Dq6A@ z%H#67bj5A+$o-we?J;M1IS7ghaq-?q$ZzDgI7TnGcGfxkHhlk!bd!G!cDti4-(i%o zn$_?Y-oiIx*oYFjS7Xu#aNYfK^qbh-eNu70&aQ27-APUQV|kc>bThAQd^DfndH1CYQG}q zgpOLF>flLF%ZJ(p-EzB7+hm)&hZBCX&BSO6b@U?Y%xW~ZCyQT9MZI|5<2eP^NfRQ# z%LX;6VVShf6Qh3^<}amZiw)zQwc*KoZv|ie!C4z{sakR6ByZ%a0o8lW@Z|uZgG8lP z@ymbN<(AYya#h6L7BOG_E&fr6X3M%6I)1gpLy8vNAlBs@L+LS;+S`6Q^t8&D(^UB^ z_n7^|BORwyyB;w{3gQAy7P`;#HM4egZd)BB4j<@ifyvRpV4d^Euc~wU_ z6_8Ln6|pbgL}R(=2{(eA+__}l(hl|)X?|a;7VA_+@uQRI+Q4+3-Ck*N@`q|z)MB+w zqw7#ca!url34Wre_FT}nDUVknZBN8HLt(o_3kif7OY~m|qNVG#-`%hOR8i+l9a zp>3Rqr`qft?a2kl0|;&{f8kbHaAg0;4hA}9*8j}mnVG{Y74K}){;~REZqtmcZ?fJ*YB_nojd%cqy81~j)IO^D^%tDaH%OubY7X!Xu3m$c7@Tyt2HzCX z3me^@;YV=RY+alq5H?_!y8Pabe3n1ASxrxweY= zp@@?+-W?v-_inwkf5B{ae^JB(Ug>BWIp-2m#miIb?Wx+tv5rs7E>q9V8h_H@*SR$X zNig;*$BjscX(}c>{OTfye7d70Vk$_if>V^~j&tK9Lb7%r?%)3?Ol9N*CSXHfp!B6- z#$!SQO{Rilv!+aD`<^?53$qa-7rz0=vxX4RCr!{b(h{jwf4Bi)Kw z#t}H3Jr_SyztpmBddh3_Lt4YEiM$+f!qO(cGjcRGIT=!M&!*}tYcYK)I1Fm*;vMPu z+;-U|b;Y^%P35YwBJ6%=A*V{8Q>q?Wbldt;V`(JWRq3@?FP8vI? zD}nTYROc{t9PBW80C|8cV%p}*ja1s@$h}l^R9n;&aqehc%t@R2^h+x%(Ss}a#vMcS zmeK(g^cAO^ec8ho&HHAf#?bdooWlzrW6~OSEAGwKeLq3suH}}1Wf}ddt9ENL`SBxW zcs)*DGJaM&SQuBhKoSerSI}fbv?}&I%J&x9XaTC{A%-|=)!lgPecAdxb&5JU!wYnvC7rI6+7HVDl$P z%a>$5^FEGu{^G-X;~Ku`{%vI16k4AEonIL~hDKYk( zhhulxvzK~T?@6h>3}Uc;!!%6-_a6(fr1bWipBsNuqIeiNr|W@sF-^VFfLjCPyIPg`R`hvti_zF4Kdg^r4I&G~$V#{9W~PfcmB7 z&(&r4P!E%Yr+kE`Y3o`#5j|H4TcEMjRgX6Vrr(0+jbVO%Jdc5iX(FE{`%h{mGChq( zn>FKh%ozJ1iMDkadC8JHX$!}2gU~LJTW40_14UUxTP2mLxv*2zr3eKr)<_F0=QE2p zyMc?Jozv&%1V0Qxm*N2K;ICOjtu!ESIJ;_N_YdXxRKVP_!X z7nh-I!qvX&_*%7B8NkCci#=w%$up4s8-kTC-$qD=?q$Nwd#!er8i`}u%^@o}c?vY8 zPxys~G7r(pA(Bea!`z*It%^qnm+JLm-A^8B9DL_r_5mtee;(Lgcz0l!vR1qJSY3*L z)%?{D8i_;0cm5Uz8X^NK*Y6m^+fbTrs>aD3{EmrH_8Q~O3`f?~g%GDs9jTd`;c_Xy zI;%pAOCB>_z_mb9{E*t^W-CpBq5Rm@Uiio@`X@$3H5V51%Eg&pO`dLLG|P+kf8yDe z_?J22zsmRfm)6w(Cr_97raqOxzOGl03l3DpZzt56XX5JHMqv4!s$0J+)jH&t2Jz|> z`bA_6358j9WCd@g{T%1oyHH6DTe)^ z3D)=1KYY#7&qleO_(c5t2c*>^TTU$l#FS;{n(;=H&U*(l6D^Etw@hoA_OM#`?Za5#nVx8mrz6VMOyE^2zfXdr+x|HRcrKzzu&!)PEwLE5QdZ@Up5#`W31qWnv2Ys1{T%?00sThw{LU+>n+KA$q< ztWVKCqxj&doeJ!f_M!yi!rKFF%v#&Rg9iH?IES6)_wU3kSz%%Il2r2eQu#-*YYaIO zw(J54IR@GtqUOu25x(t-GBY$v|zGk)MoY-^t89bEOokl71Qv-vkv zmW{BWFw&eT7Glz4P+^wkajmVPaU%|=`|b%6s3IpZ#b3A4m?xaiIDSI+$w00IcQPq# zKO6BdHmRJNjr4g?@>5s)+5jnf+nW83!3wG)aP7JPeKwuGK;J1+=(RETVn4p}jS>|p+4Oa8m8|Ltm^Vt?HN=0f*nPa} zq7S0nw8iByY;POz!6C(t1YRV0=ir(NW~EmXx8$t4od_{Fft~K>&J!JVc*|b0!%k{P zknQ(gO8WN$(w#Xo`z`{uy85ixp{(?>+4F%7d!B*D8SC3N3slXXYT#D z(D02Yd{!wHUaEvp0!AgD%ENtXU4-gCDfR1g+BC1*g6Ty}x7?<9B!fmRd#H-WJn1ch zs&+49mEPC95u>w8pHoP?L8?gqu1BEw)C!nDOt_$eUsE=o6~Qj;6gE&7kVm=g(^a-| zyuTvrLU}&dTzC3GGwq)vK`L!``DDNS3$;`79};wqE<=@?0>vE0tMra4Vk;hcbBVV+ z**LTS!yK-UB45pA}nGMv*B1v84TN@ zjJ%P#yema}2hvN7mTedW>Pw>|U54F9mK+h4n~VW}n%DkKoB!xJUgyxwkL~WOY+PeeRysPa<_@~>e*cQ&XoZMR`p2u_3;mGNQC zB}84;?0Kl}Om9otWiqF=o?DD3Bs+I%fKgL2I-g%`^u#~l2;Zl?eAy=>%?}j zH)*#=3Sbj%)^)YJ?2^z`)~u}Kwx7%;m4Gsbyswf?WPI;;GNT$@=_(xL%IU&HqWA@w z`_9T!klR{E!dJWgVxHrSrH68xEWz>MHSD2<++eV~nF7pU9NzWC?)1_ z)|HSejiEr1_Bj%VwFt4FmF@;5BxYT+!Wj;9phhCZsXbsD`a#3@$(AcmFI#mbIxnwJ zAbvb|H;s>`em3l8)T}XAit2@%k;8C2P)Y++^D?y5*0xljeAQsAZgnLo;7qtiu`d8^ zm7Ob8y))d=3MuxLinKlOM32(^Lf{l7%GfvRhq1B?I|{OQASy?mYqfZY+BpdXoMGvA z3mq5fwcef{~^y$cAP3K?}@hNMBw0d)Ogq%BL>|_t1RcfI!`|OduBgG?qgyDFt z`qLRNbAhjgl2Zo32^oWdyUsJnmVDYVSFtlq7bj15g?6*^(CSzsq95T;QKMt|*2qT5 z-Bt0&=Xqj@aZ=GKLS0xw`Z5p*wRKR1IQI#;^Y>`(g*~N}ue!ut{e*VD@>XFAa`NPty?KHPt zlF;9W%2It__)-IJ6=G_`NA(WiXgSHhhO-H1c^ca6YC4&Cpm;ER1)h^Zi1%S9(*#)U zU+x-+N?6ch_(2m35R<6#e=^m!1xQQh$LLZhNe2vZFjZfL z1Z%v~wG97q>e(kdA1Pn`aa+>_?ZOSU$4KiMab%qJ7hA(#+*tf z@o%P|?4`>HB7(xNvWDv-d$UdaKD&F@^qHkt&MNDjyT$BZxrp3j@r&HlQXH|%|5mh- zoP9Etu9mVzI6SAUzL)k};iC~v&f(7%#ncwRHvL&R`SJgoO;hYTP$Qpyz1R=!rwV!o zwBEHP&S^(XHhRz2q}PUwzP9wS>Xn$!?83YiY-r7E!3@8(FYP~L5UJjk7? z3qGtobR(a`enQ=M79rV>`d zrFb<#Rl4QVyqEUE*O|tmn>r<`t>bLLU7vLm;f+d7Q`*kb)s}EWvuEZt_fFrM4<+eZ z_rSl>SFj0mSLV$TH88N)Uj#@0sjdD+CD)A`UVS>>hy(=2x*2H%SWn+k@!Z`6L~Av; zeJ)OH%8<*hbU_F9odp|PD>FzlcljxitL-kss8?r~nRU&6Wv#_ML zuL5m12Be^L4-G=Erq&>8&rN??n~*F^Zq{0_Cq}N+i|w#IFm0360=j$L&lk{*SO`TQ zzHC2c+}EEMmYI|S8c#s24>deRokoKdI|Z`+xO;0bn=w1bU*{K`H&$Gd65WB9A41-( zE50PeD)0HnHiI+03zTRPsw|(m<~U1%h$Rhidopv`>NN@ z75UursABC9ao!hYOEicXIK}^irs-o<{H%Y~StNa`&E5~1^{69v$Gj%p(^=wNK8xq| zg_Sb9L;MS)m87lrujebVhdvq4*F4O5|G>1}%^o{s?FKno$BL1&)#hlQn9FFS?}Lgy zyl8kV*1tX^zN-JQHr{?08`x9fW;%_$y~83bz27-fLLYsqq#5b}pw(U}*JxBwRNLo@ z8=)_@GZs91Ct<#d!?A;}_Aw_99bNdyI@@~v?MK|LZfG#W7MJWD)nB2qGcXEuY-$AtRxh&h%&sM5{R%nrQ>#+_BWkPSM!C94AB zNNPsy$gfP0y4pYe&J=Wyq-XnJ-FUjR;9^6QRF>#KON&637FD_6aP4*1{*CT9)vPoF z-w+g4U|nEDpj@=4uh+QhYA5tDO<=&EPy27!{>Rr%iNE8cJ2}y#>i3Uge_vohx80<= zglvblaQsw(u8ct!tu1r}K#X0Eik`F8jvU;ck+C0#s7Ev{|hDC z_RFlMsIcIn#WJ0#bf;N%`RbTA5dA~|4VnSvnAvga=#5hmlNpQ1mODy&&|2A2OGMZ( z7w2(Ufu^9s1=w}HK6mJ_S^O)hdC!0ADOb&FL?O zycp(7<$YD=8*v^^{IDn_)*3rhQ7b;)Mc=yR7;l>;(!WPlXwiGR$mTk_*!E(I6GRgH zyx6%yO|4CEOFQcv$YJtMWfj)vX*%sXUS#~)_- zk>)unrAu*%+VP3s`X%`IjG!$y1H4Z(~6e6`6y5%7@~)EXxcxdt73BvK(bn~rhir6jO%-qG5dr#r|$W#zvbZK}wz`rjPKUD<4B) z&OKi*lYE24DsdjWHx959?mDJYdkm_L8^NS_6D^%ma0?hZ!*=w*hl`s$@8b!8#zPrk zw%*#&p4k4ZQo#RW@4e%i%GSMcELdQ`f`E#^j7k@<&=If+^&l? zB@V6r%mJLddIeu)8@0^k%l)bl#wodP1W~bk>Z;le)s3{|Dk06!jF5N z=Uh{Owt76!SI3zurb`$oFaWYM!!P1fM5Um21sbxM&P+@|GeiMFeGoNMlHV_9PDn7PFYuWk-^rWOxB7y-yF!GD>{`-Uy#OGLo$u@fjve1 z+EK`obqDomI-}IFt@y?BsS38ovmX7xDy7n%^Lmd*8q#>4?V~ug^=LLbohUOx3yt}! zQ1`~hG1qbq)iUyYzhHEAP2eS_U{>#0s*>IlfDxK|R}!aZvG6++`tjootXq?TQer`G zjO?UO@tV$-il$%Qz7I>$bVWaY(%Ab-jeQ} z-cN$%dI!|KAx)81W%~HEiozD(7wu*|YHcHmLB5)h;AX|eU2&Z5Mdc~wqA!fc(|}pbxWpwEPjyb*E?ZV-ze5CExy*Dvi&3MH?rlkLAq_K z9mbP`4zCQc(oyw-aUI*xu_pt){WE8NCuwGg;FGM1A98sAj|AKW)4-AoKXZT`e&*<@ zq1y!3d(UbALY=saj-8g%>st=sdXv#HBA%cbqxo$z(AbNsk=Z_hoV+QXi&ta1YtLqR zdiTLqt>t=A-Zhw}y*UMz(2D2uq2t%qEUIWMoLi%q`0yG@Fwm`zokp(AyYi=s`mb*H z?De;ra8}cYK=ETVZorRAesJsHBa>;8W}UO}b^TqsG5lYlLqTh&F|4^IxA+z0hx0vl z-(z)7rIb{>17p^+ zVb6c&@H&@Ym^!~HKi5B_{YyXqCkQwaHB4Q=Pu*<3FLOvIz6Gv$42|whzL|hVxDy}i z)kew#V+S|}7&|U$<<3T^X{sep+&CER0x@lpmMdutWGJq7jY!ONU_UfXGn9mc*N&B; zmQT(DvA6!Xg`=B3mBmV=^gK1_HfqPzN zVQ*t}|GSZmsYce@9l!$!TWS%CS93;gEbR6Nor%atimlT#{ zum-BUq_+uEP2V>y4bLA-?Hb`$KDPNvygg78;%z$DUlqG@c{Cu~4T%;LXR7uHet}3( z2BEWKd*lnFB+@RaO;m{EpQXG@TA6^Wg3gj{eAX=+28$V~3meWCrjUA%*-_c3Lo`nr z0^XF6>6f1g_6v_wGZ;58V!Hy7AQeF}6{qA+A%BFUqY<}iu1zkWje!}|ETb}Qg;1EC zYHsg=g~r_|X*Vh(1-dmVR#n88)g|W^_cf0PnrO1lRD8&`-j|e+4xaD#HC}~jFb`&I z!4PLHxV|C!-seJ^R!#7bRP@4q-$U6#cwfcAq8-2Hw@>X3-0lGO?J|_wy6q`RO;Ezd z;8eey#!;IK`EMR9<*E3dok<|OAoY=Z4&e)PEbv}TXqKO$p{XAnPaO0cchz`GUsp@r zFWb(tXqk3br?P76riIL|YY#hTOvIvN?F6PnNkN)=m3~vJMt2qc&?Z1Y(gGsfu#X(W z-Pe>c(Z?alJTMVLR%s2NT7Y;%49iikW<0G0?pI}ZW3*G1dd9SZx-D5d(-!P4R_n+6 zWoNNM{k41L?*}@3*W|>^DXEfv9U|3$ug@uiyq33fvokF=c$v0-?G|Fr9Vd=z<=&b? zx)Rmi$9_?J?M}vTTLO4g)VJarrE;;l<gkrd(05BqwgfjcfKXW%4uU z3E~3V_1W{&CRxf?X-gd&o=_7zGNIXA`AdG@-%ws6&PdRFqYK-ZnS<%2Vk&ORoc`qY z#2Isl=0U}pThr1ZQjI(nVAzEP2*o|@4*I1cCjCdR+5GbW~9u@XM_ z@9Pxt`$FZDgPFT)cDx~^n53INp1M%xpF%jkV=S_S86GZCAUbyLz5VX5tH=}7BVrrg zquKVu({#B(p874jk*?y{7v_ ziH6~mB5~5WY_#P}4u+_zSTCCc@j(nl>|#bd&9b#ef#HGSE_Ztb4dGEnYH^uv+VohX zzNaOnS9{u!!6_T(tegaPs}dEG1aA+^#U@QwEQtXyBuEl?$dMc#Bh9cd`qcGGC%hjU zdJuf^$y{b5PJ6jqW~uK(7Y8S2!uWQU3og?Xj8SK23>}y8DXA+3Elc=@<~S6+TS^ku zg)f6IQYw~401ALa5^OK((b4mqSBsk8*cD7Qu#TM@v&V90&RWDG{a{5%ilyNGPjhSF zP54TF#(O!!xSz&s5~O==sypbRDRb?AKsfHlRz;UyFa-v1c%HWPebG#ev zqGgT#y)2)s8b5e*kGhtH;EoYo+&kF7eUZ`gQQoeHeiOAlqXHDh6S zykv9WH`IJcXh9h;U(5u~Gkng>^NZ5)!Ah4=AxHA0nPXM}Ir_r57s88V{DUAxh_#8& z_*lLWT)>TV_zCdBt&(i0rM{ZxxC$%`v*n6va8i;c;e`%vM=e(!7;e}$|CTAjoEfxB zp_#W1Sc!T>p(Tz{DdlXj_5{7k`zI=o@p3z=b>4Q&%D9LRpKAo04pbkB8WP&-4d_P2 z8^sTJ1(T1x7t%Q>N!-1`8tvE9i2ugq@b_ye$E2La@{+sKIxWZG00~INDGf1TF$w=W zGXpH>a0qU!2l;)|59c%lHOE(UDP^&S!*u__ennc#aF1(d+|0vvbBa=XNx1La#Nj7o za$#y&-;@z7l8I7eHuoN^&^HpYk;Y2D2D?c$%Yh-D;ML(ZZ;f|A(fPWzj<5qwuUP(r9r+jficBk z^h1tn7bz!ChOdYSJD%OVwl>{X?gkqOb<}d=SLCMW1|;!)p;oYiNzAzPHh2Ct_R;)< zMH*u0iEJK4y(6WRv%P^5b>=WjcV4*HM8R&q10C{)BPR_EQf>Iyjb>m|7~9f8JUbSNk{Qnm$x<2)iAK{ zyaGR`8T}znVKuF985HW1JZFjOYvh+!_;RfgK%>Z4|9ku7IHiQR{5HekG3`7w9h6z; z?b;o-9>SgW+~u-HNXk8n0nbxsXWhVXb(JRIn1qGJz7e*#Ry#>M8N-{~22J!s-Ly;bv$PrvZEOg+~8-gs#&I|n>fssC(RVyKwcnd1J4^>v?}9;x!gdKnwrl28D*+!!t1XK{@~?8+bH z{`cM-s7U+tc3*0<6cYzscw*^Z4e~TV31dj1B}}ZHOF4 zP@WIkX!P&5zy0nzfgrq>hE~*5>5g08kov7`a=+MTs;0)x7ONw!dhhaOL5bIjZj`UN zb>q|(Pq@34Q0iO5gXQbyFC;<%R;er>%sG9#-o639L@yF}#GZPO@6)y7SMd{Q(`4+a zb3JH#QbPQsr-5_lI0e4daH3Y-zK@v!5OtHg_H12=lHOPf=0i`j*G$gdF8-CK{Arg*co@su z3!j=bT}2pX;p8&_XlV?0FCWIx zEw(!)>US=oNqe#e3P9JqL%Z3Sq=kw_BzvaRE0^!iH6Iaf6I-9S8-wC()3K6XB-|}s zNU@1rI@;}KlRYO*EFNXA1_=+RyN3WOUg4XAk@@60-dGu#-+t9d4`RPk#fh=!J}2`I zPgXjRN7Zk7$AN-i4@{#soLNNw(y`fW&Pc9Jji%*vND2d^+SJqJ9pavrtT8jXI-ypK zil0Qgc|r!9#5hG}7zVJuiBOJncdo7^KVrU_sz|7At_L=Q=mrDshcqSn3agy6p%K`V ze9rkE{XvcDlI1h5HiT)y6#6W8S`X=#|$ zVUeuTPGpTfmALYSNMiT){#oI12-&K&t2)Uh~;yyStQr!mXrcGCZGP2z@eziBipo z8z!m9T|OP6E7!p;^j{sBwpi#lCE1b6%YhBS=|RqyTw|~i^4O6aI5-L<&7HUMV+O(x zR^vGx5-d`xC^)P37o$xNVk*?zzz$2 zx40*3Vruj-_NyL2%^(S%j<3U2HlW4{AVY&A^l0Kr$F!ctaNlAqDG1N%W7T>WeB{Mf zjrThj@g!^CKUSGs{vAu2X9UB9C z-S!!uB`hI9t0AqL)*k2Xz$oCRqKr!yQ|1Z~;rjXJVC6wVYch8d0SlO;`Yw>IczbNp zc3Nt6uj=C9-p_{#dXJeV#bE~w6K_Wt^*Mvo@$RdZ$@-DkIkSSyX56W+Zr)$5eQU<@ zVZvj+_vh8{2h%2Wk0?!$Wb20$^?cptt%Z_>a?M|%@3P>8s=l7*4lgCAr< zFe3?*`!lmHLqKY7@}8I|u8Z6oLg<*WXIi(9D-{G1y8$ouj7$ex@y4WKKNtvW^YU-I3$=2;>OlYJ|x z<9e55>t9AJOG8s~tXIa=v^cYJH^kFiAmrIF-&ygsoG}1NVg0M(1*drNl_PNi(yon7 zC1tTC7FA8JqPwFlay(SybR#lH1o=w?rDXbJIGY{{sK~H2P?KQwYj`yzvZj2e9t}hr zxUDKvW!9pz9W8Pk3Ym(A*37<`uQO0xC@=$+e9hyDa=8g5q)PK4Qb_93Mj*}US&!gN z;@LzC1XZw=oYw+> z%=A@)AHVi>nq|+*%!-pXb1Z==^Kj&oJ$T8f(KKmD8FU)bP?kv1_a+aMFYBr+RXLXl z=zZ&!&xg}{XL912qAC-2Q@TSf2y-gr3RdQDLW~`tIlCVv_asUT<#J!Mj3CWGzT1?r zgTm7TvAhS_{+G!Q^R8Mp$#x1CkT%bjPl2bon`RR<;_ zKzR)2bCES4gc&XF+U^xU6pdxvHY;fj`h77`b88%k z;ZW~hR+G+|w#T_Q1?F4;LHaIRtx*pg&1PJx98r-lOxHe+x7J5c@)s-i9TLt8DNVt? zsu44!L8aI>zCD(*L3kyUKJN=}yjuHj|L`OdmJ$&!unV&c~zPGr)MomWPPp6=N2a=A6ait0_E{ty%8vTC05&$3< z%qTtNxc=_nyBz+dG{EcY#;r6~@KMVAX&KffotYg$QH65D>*%c`Mjue^&DX;o0{opa zr;J=Y--{U@Twc&@iTo9;2l2Z+gAL#Fo;&Bo@%i@fuDWN@S59%CdZ;3x%HetE?|!@e z_YZUPKJ0m9_tXB}Uvu()%a9k*zTPJUwKtK`i{+KsRoyI1$~N$(c`#PR?{9jpf4%n69P=coMw`@f0*%Wq%x`2FGSZ+iTap%T#T z(T6<|Uo(8w&Gzra|EgOhpwh!%_5GUeYaQf$__fKdtgz<8wB(Xm+P|ZL3mwTSToE)H zZiyTb6JV6xPxfM1wxbl=wHGvy!=-YbTzLdD8t*u$y0gI2b??Rl8tQ2`t&-#D@3sDG z(f^J=|MEWaA75@tykE*MtMdhn9&(OOOe~&&m-J(%!}jkrI|5HNycC%zgmJO7ypuDe zw&;?c*+_DrZ=gUP_b*vq9+F)t0)?4D{CRdPHP;sAEs_t*BKF6*YRP^qDyYi3WN3ur z8FOVZ3|$w?lO{>K16-a*{>-twUmjJvZE~e>xt`&nc1h8A>{@D>F)}Kf*)#4E_o`U% zh(7=GNF*5n(JW-P!8U)msA5-`=91%+-q4t65LOOwxw0JfZA)jD>Yqp2RHSVR&7Vcv zyk*tD5jb(ElXa!3@OlRJAgE=$)eTFR>931_R}AiuNMu7-Ot4shKY;x2HyYD)(s$3GFIl^CEoM-?e=}*HQ84 zFh-#QUc0a<8U}=VeMHUeQ6(;~?7zMuyJ3W~?<)pE-rV|DT^3))xcA1CXYI`P-Q7mo zEF_PG^+@#W|Gshlt|b}l!P2?iS`S${Z)!~nam*Lb(mH2=1*`2{`|cL2 zq&sn7lVgJ$u=Y-ctwXPbTUH#>yo1&Cd5&woGMYJ?Y{Pq~?ra;_cqguFryK;L@is3> z7qZ4JNhl#-#rC!5EkE6S0)Ol0k9T-x?3ga7z$%{g=m}9Cy6gIapnv2lNRvmr>}Dq+ zVawlSVRAcTU!90$tMK)6O=E`8wx^cqgW(hnZP?V~)()%mGuoOH1>0Q;{g@fN4}1B% zMvtCLsS>u;PTLT70=$`HT10p26&gjJvorA~XhnYN3g(qo?n9gC26clJ#RS{Y=BDNB zwEH`-Xn=z&7bn=;BCS`;tXzKh%?v#%Bd7U{jpz@|eZ!hCIH3Nr_A)Cn9(D96-hUSg+fNnI{~yV6$DVNJ4H zx-Hxwd_wQ>oBB81)?$E>Bha>bQs9I7%T#Cj>R&E6rDQA33J%GU4w0E$p`U1OU?~X zf+d95)h-D;G6V&hmQL8l6I9b+XcpEK*?eJBs*|2-Qqdrus`+$qds=&tDx}q7X|Aqt zb4Y_)jA0O7!E+!gO!%YRzfox8H(TdGrQ|}@c zg^s%9^mx39lK;vzF5^DauK&J;mu|Uk;kWHi1QtX(vK9N=Fp4)^bx zeBFBFEwi?1!E`wIVjK|GF7x2#itD2Br@HYe4Hf+zS-iU9tZZ+8CCl$iSukfr!R05t zGZE>K%}MuGTC(8i=ZZTQEPU3A@k7{^K_DiY`Z8=(Z%u?v4z@6vXKZ?2wr=RJ)mjwW z;)vW7r;cc~{LE2kE~K_~y?QHFlLmyO*mrJW*ZWr@t~=Sh3;>|+yRX~`x9gU8mbOrK zd73Q+mZ@d#6h@%mZNCm&KBr@=$D0s@LT~X}o-XsFxKuwZSTPC@XvPdA8!LzG8``0# zbQ2qVZRDsu4*eG)%?lO8C$ z5L2wb!IC!|-=EdnL^L1sV35rTZg1*l*6aa4_Z~&cwE@N2yi2vVw&F7bQay(v3KIa= zKji0!rZLvTCHS?Z1uFrN z&k^|G3R}et=Q-AUYco=$>B8-4p~5$GydV!iUO6U6TNo0od7&=mR{WFD`EOQv#jg%u z-Zbu)0x-Ma1FjPQchrw1^4vep5N;TzvHbKR05;jhXHk$GS8Y;gm3~wSE3aH?0#$r-P_BrcI#9fgr_qweua+vF}(ce(3e%b82@;|%T z`M18!6Ve^i@cay|Hl*NA%rkI8+*C}i%rx;!z(iS0DECGWyd*5pr(ul5iyd;1&$~wx zy7eA_6ERXsn>2>wysYo2b1k}=2{kV%q{WSZvLfne*7u5c?p9w-`*bq7V5{Irwr?Dl zcTn4781)#AF+>@3WsND-345kP8~e{UR~79j}n_P($iMlw%;Tt?OWGfh8+^7wK}PW%G^u5 zmL!)o^ZhDfl(>G_8lWX!Gn`a9@V7SqSN}U8zzUTg8)0nT!z%p_-7@qaE!d|aYYgT2 zKLw|@2g_%!a{qGL`dQX`N!*xQ$c3(13assCj^3m((Y?kC4$GHrC4UCs*STJ+S|>*6 zq~JiBMc%o;)5A;uHTAjl+|{Bh)KBQ|(ThwQoDc)Nu6`JsNuF5KGcp}pbG?9_Fhxn` zaOdbhg|$o%YU}2mdU7OFk@IGVwS3@=_mFDP*kHJy6?em7S;)bROt^+J7r1lGfJjRm z{h8zO+}Rsf{+by7RjS?li(7v?EJ5#R4wO+>K6IZNlByiLEeK6hPU&&g+C$#XI3xe* zfs<@~ecJhPu-v@!mM9t*070i!?l%6f3jbH1%{z(yS1hoCF(xP0de8k1=_*>8I!pu* zUGMz)eP_%`7Q+cTY5Z`Id&bZ`(yVK!^`$12Fx~ku43o?sDkTHJu)P!e zOk7&&EX)XuxYdVIRlc|KBO*-AOK%nd_Q0%ksF#<&^E}*R=;OhkPHW^T zEcwb8p+`<8vf^XSVZm|ho&#C#elAgF2VCy`uJ6?Vz^Da9Z$bC;a+pwKN~HnaVKPWh2xjy*zw z60qy}7WtuaFBX;d7VZ|xQbMIKb^K86Ynf46&mBnIGmM;5+WUdO#h^Iqhqrf+Sz@8v zil-5ab&c{#bx&Rpb@^TxMj|5-o^ZATfH+I zQtqn;(F@QHe+aR~-+#njyBwgeeqmddkZh@kJp&&uEPELe%LyuG??QfFTC6opKk3FR z?JV6EYrhKfDiKe#G#p}lMohg>h|caF9Wf~Dj_K*v}lzF3kr*K3`QTMK}aP+awuS@P#K)N`rAoH!uB`xQL;1 zZ5#el^->JLeE}sNmwXh(3gf}uT969Eow2==r8Jb1{&u9VCxe=uMxl7e3JR?awAO{D z>3Z|zf-9Phjj4+lq3C2oy(lP1;_RA5EYVH7h)^VV2s}Q{#3XZHlV2F_Jg~KWSw*Je z$!(9W?>gmE^+mhRbZyGgNuB|e9`<2ihg+PSBbZTFs9%~Lli#A%!2}hqS6^|Uu+xK~^{KUU|4pkc(l*N$}k@=Ap1ESn>OBOF7J~NLPX8T~r{SJUwlvsopk|ad>w1MLM z6z+8Ytq9)@XPU`9F;20H*gzeRVDky+&5Ug{OXg@L4xE!Ma_;v| zu&TbelTR2D@{SCXwszyL-r8?ks}NJVt=#P(WK4joFGgyK1;~et4S^PRS62;We$IXU z)b3uLWStyH%)YE4;NLP9*cMs8uYm`@?LgZPqD@j=R-sLadWn;#yLme<2G6Qu!3xe; zW&7dx4pkLOi6O$ej;s;aJ@Q|7)-(!yLqeV}^mn~B^hgZ6zs3>wZxQG(6e1Z05Z0Ht zi(@;F#7+?41?cbK?bYOV1T+@u+*U*fPTaveE8xl0XD{+?=6aQpMw736?Z+=CCBGKM z?IQL(QST5kC31OGmoUnw8~lFez(^G$L7E5UDtLN2lcE^yKE;VxZS|R)@f{z4Y96yN zp##@G^M%XFZ9`tEIj?a3*t;=`aAW4?a7u6L*-~mSgCuuK^)2KeaaqS3T^)Cy`!Rgh zn2k1}PdL+NtIc%XvR#->(>cMUWOWPr{dO_ezIGl#fyBOYX?7sD$UQWrg) zKOVezC~F`x71Aino@V-=xH?C+G!IsldlBRzf8>sorBctOk3Pz_@v8c$gCP96I!dpq zeDNp3yK#a`I1ySD0QQxZ@;@Abx*O^0(94!#&&_$d_I8gPPfgqvK|nE@6|o>c|h}T~kQK%<`-WI!fKf&!Q|X!I142Y-D4Qw_jlQ%T~YiWap1bK9VDJXo?4t zf{86rKtvjj!`w%I?6gx;K9RO?LM@yrtuP`IlqFNDpR#4)H%YpBCo?SsXm>SF#c|u* z8N;~LW4(#Ph{67k{`>;{iuSo?_dRk$*w%`&&DY-t0IR*bA$PPdK2R*qD>`T#x5=Y~ zNtXRkF?u;n<7+h68CwMCh;aP4Vp?JL-Ogpi%mkHtm}=rN&t5|Q1>^cFNn{9Q6N5I8 zZ~LoCsifs^T~@Okkx#fl+HB+RiWctZL~ZqWqcHVa zGfto#%h`>nXB?;4f8|yGY?|FAofgm}REseG^F9O^=EaUm5AuR6U2KpNt{9n#cfCd^ z0X$tpj{~I4p!KMR!V$W60gp{2bKSuVdciOX(w-XRR#Jy&Qz*s6Dv#k;IPLqQmsTo( zB|kNc=gZpXwXJErRkH_?SaIct9Csi7K|GPmelXN1k30P-q3MBC&o)fse59?A*x;P# zoj;`G*u}KYzIV$j7~#6r3uV+`8g<39eZ&#!DS>Nvq4qJH^D7b&CgKBQ$QFW~dO zPymx+qa1Ey0*TIc&N-t;`}7jL2TDaO-wSC*OKaTEJfo_ud3jBXYph`mv{g5fyX}8X z0^pR zc}$naL`itkkAl<4MRs6Bc+pAU0Y{`2#7g-1cO*`fl$%kcKW2oCpcLgyGwfeml(eaN z<2(Lmb|!Z1VQJo;-#F<7o=J?-{tW8TAB}LYeKo*{sA1IEItY}26t*C@bxgCz&vn~B zNQ$z4+2|g4amHa3LVeSdHscI+M;X%!)MC?Hj(X0gH*QU|jw#7Fbk8mU9}KLr_uH0R zp5yDx1PVOHEOo~KeuQ?PK0IzDGQoYt@IJ>=4q;H(7oeb(I(T}{ug zGTL@1+gfnl`J}mN`(+4oTb8+TfXA+FeV+vyh%hJAF;e$#62{on zy)b#qoR$9TTs0?+?iN(jrtmq4WDcr45VTypMjdFgQ~pySycNGzg+}gDSSWodhWcyi zL67MgBf6ii1pC#rZExz)h5sKt5Awq30|CLZ`g~A#97uU|)Pc6D)h&|nQHliRje>fj zh4NjH?DEoOvo$Ct**UnMC#iK$xIdorfzeKfrh^|!f{xS=cYOvN=>9LfcE_=l5Dj+~ zzmuk?ORZ>^it+oG3w2XZ1z62ll&v9zUyq;Kyt$%lYG9CH!_2v8Xqo6zI^sj8iaFIp z{RHrlzG7f3|3TEQsfD-6g}2rxAIHTx&pj=(*;Mvwma)p+I)eg$o?di3%*!;uzIc}s zhcMN{oyq-A>xC0mtYXx3sZ2o31MQrtR9d1bt30NM4_uCaZ|b`E`J;SO!^9o)JO63^ z|L%d>Bolir!WDo^4*_{AQiTdNt}M%f+b0`!3}^1H)%bcu4Q_M_x~P-t1TG66l!BX~z8N?s zj`IY>-ORcpI9UW{>DuOf*_V!mMx90E15ss1Eu*W2|2r}NFP`fv`sLGnGEOdw1!@Hy zn2)V+d$v7hZ(HnMczmJ<`E>BTfxkhxKU8GjddWg?Db~T;#}W6LqyFI^UdicFEc(<7 zz<4^p(dNLP(-nH83BZA>nY?0lCIOUeBKlDCv&}vJ_$b* z<-;BDd9-qMktbugG&Fk!AGx|Z#GOWRSGR~-_rH-SB_8F>>=$(z`4Yz3<50!0Iih<= zhz61?rf;;IU&}fk*W0Ycr$4c$?D-@l)jmP}#BpOoFGP6?GQT`$rt`(en+AtEdiLl~ zY*~t#&@RU89Xd-w8Igl8n+69y5^uV&cAp{t?1p;1r(wiM6`m;gayjW{0q653uQj8` zg|__1bNoPNp5yaNn3?Qj1fTtP)Fzl%9YI&Ijy2Ka!KR{ovrdAYj>+x&zka>!`f7E( zU#-p|!z3|yO__UPe&db6n?KanU%vsic?(V{8N|1I+vXkIV13ag#5%HyG=i<2a~oCa z1!9T5-9nZpC!Yql=xt_ppK#ufyOgc-t_9e59MR-nyjwmM>ab}MBG#FQky{9kC;<8U zW_McJuFKTVmTkd(_>-EJ65H1oylj|8&L&!A>-Ab1Ia?xzjtzBZ!)a@gPQ-V+cpv_C zf&a;r2gI2QPUQReuZIYstDE!VG4s}??+WBsV_ zlD6`UQwIqAgRA4gY`V3kfB8_bj#2y19M_#c$`Kl4I}gh7p9#GX<~^DWdp`UVQ}|Nm z1+uqdl({?2iZAA-&0NSK)C&iNoU}^kV4U%jIj~mTL7xojCV4QjNnptA-L60E{Bm{P z*E%o)kdoOhr_la>c*_4;Cp@7&jwLO=;+X-LFDHbf?1IXKL0GuwR$bG3OuzslK^SsTF(5ta*V=a!^wJ!lt9L1uDkr6(8jiP7dsYW?s>dC&N#-u%oF^=#hc zXO3d)^k42`zTI(R3tDS$%70*NN_IA<52eSuJ|O&D}HsN3hy-sV9~KwP!%-<_9}v$W&325ClkL<&BV7$_(`-eE3(x%>G= ze%zqQq+jGvZ=+TSMXXMNK8JPkDb>K77BZ_gtOw9)lYjI_MawE~-V%wzOoaI+nDU#0 ztyfYkV^drD_BRL6u-!6q6`~)KJ0h`pM1CoUt~WkX*YxmYLz^ib`QM&BEGfKJ98tX$ zC7;J#z5q<5mV4Ud{RsdvV-5_fZG0E@QV;HDf=1+H)j*;X9Fp z##Hbrl+t=*h3lHxSIE)-xog(~{oV$L0mZ1hH@^o_YY^iBfIDwRmx`?u6K#-#BL~~M zTilUOuj9CBhR*6}rZ&)BT@1I9{gVD(MG5KaM#}Clw_H=VDjLw?xgTM92N8AI z`PLlNQ*RlzTlamj{L>GMoTf$Eh6Y2&)9*=i%=E23d8}7cUF9%LuWP5l$w}LF?%ihY z7BRfh8v9DezpGA4V*NN#@I2ZTrGt@QH(xY)8dlEF0zJYE;`>zmmLA=Jw;q50Y>(gnxzuiyGE+A-)vtJlJ-M}8Ki?Yx9jCDS-yL*(xuy5Zz;X0RN3y8a*ryF#;-Q~D zOT-QP;FSHjJf659J;Pz^Q!uj`+I(r!r|n{Voo>ES(lll>s!wpF!H3%zrvYH8_NHIQr)Ka^>5$?nl+xN+qm5M;L1_Gm zH8B|e1H^No)Ym4V`-Iyokv6}m=Zsjknk$$DQ-p3`oLkmS9M9JXQt32SbmaCkjk=dJ zqrBR*nEH9r>d*_L2g^^Y?=GHmjyxmPqW{>^L4wW-DsNelHXh=@>uCiL)2{ z|2Ph@zGNC!Mdb1WTy$1G{o)^|4zSRe2S-;Yd{`ED&0o(}k~sgWBe2JsF@p{o0W*&$}@`~ z|7B6AR={dt8f7B}D_VB&`~MFm4Q!<}i}kmMsBC|}K_lMx#=aeUbME4hs)-TY`LpGi6R#U4y<}aYj8}{TRXAxqYAws#V;I zvw7^cWsz(4&)Omf$KJ5p6TKAB9pfx_jmb>OjNSxT(Yic_xa<4#<_vv$d(Q)(WF_MM z^?b{_TU(-!2QOyI+U68sU^1W?0H#%7Y71KvHfA~}q9K*;?P?Oi#)^!nL*4tXHvj=L zJ$M)Dfh1#qL7eT7{*E#Os52e;0> zGCXKNmNSqVraYVec}{l|{~`)u_XL^v?w?ofU&yg1;t0g7VC#iZ#t2XMPVl+(#W`oa z`Z80o7Xo{%i$rL?yoRjDCbn&qYL^DVosQY|m*--?+;y$L7^BNCN$p}NbJRPT?K4U; zC^Em09OY>((H#7-G`Um#MV)P)<>D>J?(`ebGA#5kls-5sac@;W=K5&2#~WFhv;~3; z*g4kRD5zyN5HE8!amZo7J@Ha4^8<%&@sseCq*t|iS!lwLbG_pC`%-PazYFC$2t;Ep zyR{hD0kHqv+oBaswrH|G`44t-=p&OzUUow8hhpF-8-}VVBM@6?k8xiPO^7;b&_F@0&fp zB-_?mcQDGdi0;9 zd4c~XS2tunSA*yi7HGJn>%tc8GcM=rOu4|HXRh-!U3_@Fu6Ui~cE93K7vgD|`Hp8X>&mS(-tKBMI?G zu+|<>Oix6<+f=ur1Nod>Ct0fu_u$Mo$zKuser<|hMWS^ZZ42$`S=C*QlmGEou@yJj zerdDh3gpENU1O)UDA=`&we9fsaCl>P#`=ZaJ~HK7RY!j-`Yifzq1~^=H@HlHt-YhS zTt?}N;(?DlO$JUuZq%uCODV*SS8cl%*gKHrdhw9qw+gfRhoFam$dUE+w??KKozBor z{P=}m+b%9YHEam}yWPCK&@KMH#ar0*c9C5;f{0tCKHlQ z{DyVLAmdC#pC1U4R#xAbeI6n0x-C+}N-o z=UrJl-zxd7$CU1+hS9e?FvIK&riFxrU;MCe+4Nh_wgOpG6^^)vk<(}RUsireScpX1 zr&i_gcv?n<-R{>mQFZ|ETkG3T{jM3Ruj#@Ybm%Lr zR5MNg!8C=$z9grnv`x*{na#oZHP;L5<$+qJh1qjc3h|+jPBSpPDT>2nZ+|~6{UWqUKpR8nzRG{EQc6$ zm$7S&%EyJQc6DL+wF!T=UpXl-E{x3n#nO%I1jQLo{#;FoTa(H=O6ZfNFPP-0LA#$h z67elsu^VlL_t1H5MzWRwU~JQRuC!4d{+l987upLvgvpy&bwR+O>dR(OJcwT5iz4dbhJ@ z{pCK3v}FOu%}80@0Z&isKr^j3@eVQ_-dIRL@5m@fzHdE>A>e|PEszaEk8U)0 z6&Vky6!F_5V(&A{Tpp}a#Xgk2-p#O({*Zs!-dPpu9Z_cC2U<&Y*;X?AMX!zAp5_GuAj|9asTbH)hvwT?S)3>85mW}TLbLcE;!WkmXEBSzNuS>ou=r}$iDNX_hdqa!2S9G{{m zc=)l)A{U?p2`Mph($-K-vSg!t*NuN%+P-_weizdJv!4raf_vZTl)7PrVZ>CNaKcDc zGj=h>a@RW!o&!1V1#DULs3G3)iFX_+Y5PWBqgeGtJ%rw4v430@dlarwjK5T@gz3RL zHP!{5Fdl@RpNh?#p*2fTRXu{&z>v`Dn_r9exSl6;zO4)WW7^!zf}8#ZEqt0%!R0sB zeiM4-Z{n%{R+jj?OFzo|)^y&oSf9}A8F-p8+^zI%+A3Bi9l*cW>T;^Q#{mSLNy5?? z!_4GYKQZ$$$55YJ$E7!hyLSD<7WzBG7!8Atf5by1FXtsg6iknR#&9hHGui%HPcC zzcCGPMX6lX0PYUpIb-t8+z|PLGyS~rEpsP1II5gA#1PJK?G9`?C8TfS|EzH9+xlva z`_nz!;R$zE!&$60FuUkU>s*{LtL|EjWqoA$b+v>iCJc-{$Dc^rrFI3+C&MivD~1J^7{Wb0eJ!Ky8r+H literal 0 HcmV?d00001 From d988b05d30e5b81581e350bfb07ae85b8db160ea Mon Sep 17 00:00:00 2001 From: Chris Abraham Date: Wed, 30 Oct 2024 14:34:00 -0400 Subject: [PATCH 2/2] Update publish date Signed-off-by: Chris Abraham --- ...n-stages.md => 2024-10-30-triton-kernel-compilation-stages.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename _posts/{2024-10-29-triton-kernel-compilation-stages.md => 2024-10-30-triton-kernel-compilation-stages.md} (100%) diff --git a/_posts/2024-10-29-triton-kernel-compilation-stages.md b/_posts/2024-10-30-triton-kernel-compilation-stages.md similarity index 100% rename from _posts/2024-10-29-triton-kernel-compilation-stages.md rename to _posts/2024-10-30-triton-kernel-compilation-stages.md